From c1f1f9566feefa2ade91b880f731d0b0cf76ff09 Mon Sep 17 00:00:00 2001 From: Xuanda Yang Date: Mon, 1 Jun 2020 23:52:16 +0800 Subject: [PATCH] Support binary ops via CallC (#8929) related mypyc/mypyc#709, mypyc/mypyc#734 * support binary ops, implement str += * support list * int, int * list --- mypyc/irbuild/ll_builder.py | 7 ++++++- mypyc/primitives/list_ops.py | 26 ++++++++++++-------------- mypyc/primitives/registry.py | 17 +++++++++++++++++ mypyc/primitives/str_ops.py | 14 +++++++------- mypyc/test-data/irbuild-lists.test | 4 ++-- 5 files changed, 44 insertions(+), 24 deletions(-) diff --git a/mypyc/irbuild/ll_builder.py b/mypyc/irbuild/ll_builder.py index b8720a58bb96..98ab82f0c569 100644 --- a/mypyc/irbuild/ll_builder.py +++ b/mypyc/irbuild/ll_builder.py @@ -34,7 +34,8 @@ ) from mypyc.primitives.registry import ( binary_ops, unary_ops, method_ops, func_ops, - c_method_call_ops, CFunctionDescription, c_function_ops + c_method_call_ops, CFunctionDescription, c_function_ops, + c_binary_ops ) from mypyc.primitives.list_ops import ( list_extend_op, list_len_op, new_list_op @@ -541,6 +542,10 @@ def binary_op(self, if value is not None: return value + call_c_ops_candidates = c_binary_ops.get(expr_op, []) + target = self.matching_call_c(call_c_ops_candidates, [lreg, rreg], line) + if target: + return target ops = binary_ops.get(expr_op, []) target = self.matching_primitive_op(ops, [lreg, rreg], line) assert target, 'Unsupported binary operation: %s' % expr_op diff --git a/mypyc/primitives/list_ops.py b/mypyc/primitives/list_ops.py index 7d177782dcc9..76254dd35292 100644 --- a/mypyc/primitives/list_ops.py +++ b/mypyc/primitives/list_ops.py @@ -7,8 +7,8 @@ int_rprimitive, short_int_rprimitive, list_rprimitive, object_rprimitive, bool_rprimitive ) from mypyc.primitives.registry import ( - name_ref_op, binary_op, func_op, method_op, custom_op, name_emit, - call_emit, call_negative_bool_emit, c_function_op + name_ref_op, func_op, method_op, custom_op, name_emit, + call_emit, call_negative_bool_emit, c_function_op, c_binary_op ) @@ -125,20 +125,18 @@ def emit_new(emitter: EmitterInterface, args: List[str], dest: str) -> None: emit=call_emit('CPyList_Count')) # list * int -binary_op(op='*', - arg_types=[list_rprimitive, int_rprimitive], - result_type=list_rprimitive, - error_kind=ERR_MAGIC, - format_str='{dest} = {args[0]} * {args[1]} :: list', - emit=call_emit("CPySequence_Multiply")) +c_binary_op(name='*', + arg_types=[list_rprimitive, int_rprimitive], + return_type=list_rprimitive, + c_function_name='CPySequence_Multiply', + error_kind=ERR_MAGIC) # int * list -binary_op(op='*', - arg_types=[int_rprimitive, list_rprimitive], - result_type=list_rprimitive, - error_kind=ERR_MAGIC, - format_str='{dest} = {args[0]} * {args[1]} :: list', - emit=call_emit("CPySequence_RMultiply")) +c_binary_op(name='*', + arg_types=[int_rprimitive, list_rprimitive], + return_type=list_rprimitive, + c_function_name='CPySequence_RMultiply', + error_kind=ERR_MAGIC) def emit_len(emitter: EmitterInterface, args: List[str], dest: str) -> None: diff --git a/mypyc/primitives/registry.py b/mypyc/primitives/registry.py index eee929db2e35..8ed180f48576 100644 --- a/mypyc/primitives/registry.py +++ b/mypyc/primitives/registry.py @@ -72,6 +72,9 @@ # CallC op for top level function call(such as 'builtins.list') c_function_ops = {} # type: Dict[str, List[CFunctionDescription]] +# CallC op for binary ops +c_binary_ops = {} # type: Dict[str, List[CFunctionDescription]] + def simple_emit(template: str) -> EmitCallback: """Construct a simple PrimitiveOp emit callback function. @@ -354,6 +357,20 @@ def c_function_op(name: str, return desc +def c_binary_op(name: str, + arg_types: List[RType], + return_type: RType, + c_function_name: str, + error_kind: int, + steals: StealsDescription = False, + priority: int = 1) -> CFunctionDescription: + ops = c_binary_ops.setdefault(name, []) + desc = CFunctionDescription(name, arg_types, return_type, + c_function_name, error_kind, steals, priority) + ops.append(desc) + return desc + + # Import various modules that set up global state. import mypyc.primitives.int_ops # noqa import mypyc.primitives.str_ops # noqa diff --git a/mypyc/primitives/str_ops.py b/mypyc/primitives/str_ops.py index 03bd386ef05f..2e261131257b 100644 --- a/mypyc/primitives/str_ops.py +++ b/mypyc/primitives/str_ops.py @@ -8,7 +8,7 @@ ) from mypyc.primitives.registry import ( func_op, binary_op, simple_emit, name_ref_op, method_op, call_emit, name_emit, - c_method_op + c_method_op, c_binary_op ) @@ -68,12 +68,12 @@ # # PyUnicodeAppend makes an effort to reuse the LHS when the refcount # is 1. This is super dodgy but oh well, the interpreter does it. -binary_op(op='+=', - arg_types=[str_rprimitive, str_rprimitive], - steals=[True, False], - result_type=str_rprimitive, - error_kind=ERR_MAGIC, - emit=call_emit('CPyStr_Append')) +c_binary_op(name='+=', + arg_types=[str_rprimitive, str_rprimitive], + return_type=str_rprimitive, + c_function_name='CPyStr_Append', + error_kind=ERR_MAGIC, + steals=[True, False]) def emit_str_compare(comparison: str) -> Callable[[EmitterInterface, List[str], str], None]: diff --git a/mypyc/test-data/irbuild-lists.test b/mypyc/test-data/irbuild-lists.test index cd12bfcdf10e..de760d71914b 100644 --- a/mypyc/test-data/irbuild-lists.test +++ b/mypyc/test-data/irbuild-lists.test @@ -121,13 +121,13 @@ def f(a): r7 :: None L0: r0 = 2 - r1 = a * r0 :: list + r1 = CPySequence_Multiply(a, r0) b = r1 r2 = 3 r3 = 4 r4 = box(short_int, r3) r5 = [r4] - r6 = r2 * r5 :: list + r6 = CPySequence_RMultiply(r2, r5) b = r6 r7 = None return r7