Skip to content

Commit

Permalink
[mypyc] Support top level function ops via CallC (#8902)
Browse files Browse the repository at this point in the history
Relates to mypyc/mypyc#709

This PR supports top-level function ops via recently added CallC IR. To demonstrate 
the idea, it transform to_list op from PrimitiveOp to CallC. It also refines CallC with
arguments coercing and support of steals.
  • Loading branch information
TH3CHARLie committed Jun 1, 2020
1 parent 273a865 commit 8457e50
Show file tree
Hide file tree
Showing 7 changed files with 119 additions and 26 deletions.
19 changes: 16 additions & 3 deletions mypyc/ir/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1145,13 +1145,19 @@ class CallC(RegisterOp):
A call to a C function
"""

error_kind = ERR_MAGIC

def __init__(self, function_name: str, args: List[Value], ret_type: RType, line: int) -> None:
def __init__(self,
function_name: str,
args: List[Value],
ret_type: RType,
steals: StealsDescription,
error_kind: int,
line: int) -> None:
self.error_kind = error_kind
super().__init__(line)
self.function_name = function_name
self.args = args
self.type = ret_type
self.steals = steals

def to_str(self, env: Environment) -> str:
args_str = ', '.join(env.format('%r', arg) for arg in self.args)
Expand All @@ -1160,6 +1166,13 @@ def to_str(self, env: Environment) -> str:
def sources(self) -> List[Value]:
return self.args

def stolen(self) -> List[Value]:
if isinstance(self.steals, list):
assert len(self.steals) == len(self.args)
return [arg for arg, steal in zip(self.args, self.steals) if steal]
else:
return [] if not self.steals else self.sources()

def accept(self, visitor: 'OpVisitor[T]') -> T:
return visitor.visit_call_c(self)

Expand Down
12 changes: 10 additions & 2 deletions mypyc/irbuild/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
)
from mypyc.ir.func_ir import FuncIR, INVALID_FUNC_DEF
from mypyc.ir.class_ir import ClassIR, NonExtClassInfo
from mypyc.primitives.registry import func_ops
from mypyc.primitives.registry import func_ops, CFunctionDescription, c_function_ops
from mypyc.primitives.list_ops import list_len_op, to_list, list_pop_last
from mypyc.primitives.dict_ops import dict_get_item_op, dict_set_item_op
from mypyc.primitives.generic_ops import py_setattr_op, iter_op, next_op
Expand Down Expand Up @@ -229,6 +229,9 @@ def gen_method_call(self,
def load_module(self, name: str) -> Value:
return self.builder.load_module(name)

def call_c(self, desc: CFunctionDescription, args: List[Value], line: int) -> Value:
return self.builder.call_c(desc, args, line)

@property
def environment(self) -> Environment:
return self.builder.environment
Expand Down Expand Up @@ -498,7 +501,7 @@ def process_iterator_tuple_assignment(self,
# Assign the starred value and all values after it
if target.star_idx is not None:
post_star_vals = target.items[split_idx + 1:]
iter_list = self.primitive_op(to_list, [iterator], line)
iter_list = self.call_c(to_list, [iterator], line)
iter_list_len = self.primitive_op(list_len_op, [iter_list], line)
post_star_len = self.add(LoadInt(len(post_star_vals)))
condition = self.binary_op(post_star_len, iter_list_len, '<=', line)
Expand Down Expand Up @@ -715,6 +718,11 @@ def call_refexpr_with_args(

# Handle data-driven special-cased primitive call ops.
if callee.fullname is not None and expr.arg_kinds == [ARG_POS] * len(arg_values):
call_c_ops_candidates = c_function_ops.get(callee.fullname, [])
target = self.builder.matching_call_c(call_c_ops_candidates, arg_values,
expr.line, self.node_type(expr))
if target:
return target
ops = func_ops.get(callee.fullname, [])
target = self.builder.matching_primitive_op(
ops, arg_values, expr.line, self.node_type(expr)
Expand Down
33 changes: 24 additions & 9 deletions mypyc/irbuild/ll_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from mypyc.ir.rtypes import (
RType, RUnion, RInstance, optional_value_type, int_rprimitive, float_rprimitive,
bool_rprimitive, list_rprimitive, str_rprimitive, is_none_rprimitive, object_rprimitive,
void_rtype
)
from mypyc.ir.func_ir import FuncDecl, FuncSignature
from mypyc.ir.class_ir import ClassIR, all_concrete_classes
Expand All @@ -35,7 +34,7 @@
)
from mypyc.primitives.registry import (
binary_ops, unary_ops, method_ops, func_ops,
c_method_call_ops, CFunctionDescription
c_method_call_ops, CFunctionDescription, c_function_ops
)
from mypyc.primitives.list_ops import (
list_extend_op, list_len_op, new_list_op
Expand Down Expand Up @@ -592,6 +591,10 @@ def builtin_call(self,
args: List[Value],
fn_op: str,
line: int) -> Value:
call_c_ops_candidates = c_function_ops.get(fn_op, [])
target = self.matching_call_c(call_c_ops_candidates, args, line)
if target:
return target
ops = func_ops.get(fn_op, [])
target = self.matching_primitive_op(ops, args, line)
assert target, 'Unsupported builtin function: %s' % fn_op
Expand Down Expand Up @@ -667,13 +670,25 @@ def add_bool_branch(self, value: Value, true: BasicBlock, false: BasicBlock) ->
self.add(Branch(value, true, false, Branch.BOOL_EXPR))

def call_c(self,
function_name: str,
desc: CFunctionDescription,
args: List[Value],
line: int,
result_type: Optional[RType]) -> Value:
result_type: Optional[RType] = None) -> Value:
# handle void function via singleton RVoid instance
ret_type = void_rtype if result_type is None else result_type
target = self.add(CallC(function_name, args, ret_type, line))
coerced = []
for i, arg in enumerate(args):
formal_type = desc.arg_types[i]
arg = self.coerce(arg, formal_type, line)
coerced.append(arg)
target = self.add(CallC(desc.c_function_name, coerced, desc.return_type, desc.steals,
desc.error_kind, line))
if result_type and not is_runtime_subtype(target.type, result_type):
if is_none_rprimitive(result_type):
# Special case None return. The actual result may actually be a bool
# and so we can't just coerce it.
target = self.none()
else:
target = self.coerce(target, result_type, line)
return target

def matching_call_c(self,
Expand All @@ -697,7 +712,7 @@ def matching_call_c(self,
else:
matching = desc
if matching:
target = self.call_c(matching.c_function_name, args, line, result_type)
target = self.call_c(matching, args, line, result_type)
return target
return None

Expand Down Expand Up @@ -786,8 +801,8 @@ def translate_special_method_call(self,
"""
ops = method_ops.get(name, [])
call_c_ops_candidates = c_method_call_ops.get(name, [])
call_c_op = self.matching_call_c(call_c_ops_candidates, [base_reg] + args, line,
result_type=result_type)
call_c_op = self.matching_call_c(call_c_ops_candidates, [base_reg] + args,
line, result_type)
if call_c_op is not None:
return call_c_op
return self.matching_primitive_op(ops, [base_reg] + args, line, result_type=result_type)
Expand Down
10 changes: 5 additions & 5 deletions mypyc/primitives/list_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
)
from mypyc.primitives.registry import (
name_ref_op, binary_op, func_op, method_op, custom_op, name_emit,
call_emit, call_negative_bool_emit,
call_emit, call_negative_bool_emit, c_function_op
)


Expand All @@ -20,12 +20,13 @@
is_borrowed=True)

# list(obj)
to_list = func_op(
to_list = c_function_op(
name='builtins.list',
arg_types=[object_rprimitive],
result_type=list_rprimitive,
return_type=list_rprimitive,
c_function_name='PySequence_List',
error_kind=ERR_MAGIC,
emit=call_emit('PySequence_List'))
)


def emit_new(emitter: EmitterInterface, args: List[str], dest: str) -> None:
Expand Down Expand Up @@ -83,7 +84,6 @@ def emit_new(emitter: EmitterInterface, args: List[str], dest: str) -> None:
error_kind=ERR_FALSE,
emit=call_emit('CPyList_SetItem'))


# list.append(obj)
list_append_op = method_op(
name='append',
Expand Down
31 changes: 26 additions & 5 deletions mypyc/primitives/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,10 @@
CFunctionDescription = NamedTuple(
'CFunctionDescription', [('name', str),
('arg_types', List[RType]),
('result_type', Optional[RType]),
('return_type', RType),
('c_function_name', str),
('error_kind', int),
('steals', StealsDescription),
('priority', int)])

# Primitive binary ops (key is operator such as '+')
Expand All @@ -65,8 +66,12 @@
# Primitive ops for reading module attributes (key is name such as 'builtins.None')
name_ref_ops = {} # type: Dict[str, OpDescription]

# CallC op for method call(such as 'str.join')
c_method_call_ops = {} # type: Dict[str, List[CFunctionDescription]]

# CallC op for top level function call(such as 'builtins.list')
c_function_ops = {} # type: Dict[str, List[CFunctionDescription]]


def simple_emit(template: str) -> EmitCallback:
"""Construct a simple PrimitiveOp emit callback function.
Expand Down Expand Up @@ -323,14 +328,30 @@ def custom_op(arg_types: List[RType],

def c_method_op(name: str,
arg_types: List[RType],
result_type: Optional[RType],
return_type: RType,
c_function_name: str,
error_kind: int,
priority: int = 1) -> None:
steals: StealsDescription = False,
priority: int = 1) -> CFunctionDescription:
ops = c_method_call_ops.setdefault(name, [])
desc = CFunctionDescription(name, arg_types, result_type,
c_function_name, error_kind, priority)
desc = CFunctionDescription(name, arg_types, return_type,
c_function_name, error_kind, steals, priority)
ops.append(desc)
return desc


def c_function_op(name: str,
arg_types: List[RType],
return_type: RType,
c_function_name: str,
error_kind: int,
steals: StealsDescription = False,
priority: int = 1) -> CFunctionDescription:
ops = c_function_ops.setdefault(name, [])
desc = CFunctionDescription(name, arg_types, return_type,
c_function_name, error_kind, steals, priority)
ops.append(desc)
return desc


# Import various modules that set up global state.
Expand Down
2 changes: 1 addition & 1 deletion mypyc/primitives/str_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
c_method_op(
name='join',
arg_types=[str_rprimitive, object_rprimitive],
result_type=str_rprimitive,
return_type=str_rprimitive,
c_function_name='PyUnicode_Join',
error_kind=ERR_MAGIC
)
Expand Down
38 changes: 37 additions & 1 deletion mypyc/test-data/irbuild-basic.test
Original file line number Diff line number Diff line change
Expand Up @@ -3383,7 +3383,7 @@ L0:
r5 = None
return r5

[case testCallCWithStrJoin]
[case testCallCWithStrJoinMethod]
from typing import List
def f(x: str, y: List[str]) -> str:
return x.join(y)
Expand All @@ -3395,3 +3395,39 @@ def f(x, y):
L0:
r0 = PyUnicode_Join(x, y)
return r0

[case testCallCWithToListFunction]
from typing import List, Iterable, Tuple, Dict
# generic object
def f(x: Iterable[int]) -> List[int]:
return list(x)

# need coercing
def g(x: Tuple[int, int, int]) -> List[int]:
return list(x)

# non-list object
def h(x: Dict[int, str]) -> List[int]:
return list(x)

[out]
def f(x):
x :: object
r0 :: list
L0:
r0 = PySequence_List(x)
return r0
def g(x):
x :: tuple[int, int, int]
r0 :: object
r1 :: list
L0:
r0 = box(tuple[int, int, int], x)
r1 = PySequence_List(r0)
return r1
def h(x):
x :: dict
r0 :: list
L0:
r0 = PySequence_List(x)
return r0

0 comments on commit 8457e50

Please sign in to comment.