Skip to content

Commit

Permalink
[mypyc] Add initial support for compiling singledispatch functions (#…
Browse files Browse the repository at this point in the history
…10753)

This PR adds initial support for compiling functions marked with singledispatch by generating IR that checks the type of the first argument and calls the correct implementation, falling back to the main singledispatch function if none of the registered implementations have a dispatch type that matches the argument.

Currently, this only supports both one-argument versions of register (passing a type as an argument to register or using type annotations), and only works if register is used as a decorator.
  • Loading branch information
pranavrajpal committed Jul 4, 2021
1 parent e07ad3b commit a5a9e15
Show file tree
Hide file tree
Showing 6 changed files with 191 additions and 34 deletions.
1 change: 1 addition & 0 deletions mypyc/irbuild/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ def __init__(self,
self.encapsulating_funcs = pbv.encapsulating_funcs
self.nested_fitems = pbv.nested_funcs.keys()
self.fdefs_to_decorators = pbv.funcs_to_decorators
self.singledispatch_impls = pbv.singledispatch_impls

self.visitor = visitor

Expand Down
4 changes: 3 additions & 1 deletion mypyc/irbuild/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ def __init__(self,
is_nested: bool = False,
contains_nested: bool = False,
is_decorated: bool = False,
in_non_ext: bool = False) -> None:
in_non_ext: bool = False,
is_singledispatch: bool = False) -> None:
self.fitem = fitem
self.name = name if not is_decorated else decorator_helper_name(name)
self.class_name = class_name
Expand All @@ -47,6 +48,7 @@ def __init__(self,
self.contains_nested = contains_nested
self.is_decorated = is_decorated
self.in_non_ext = in_non_ext
self.is_singledispatch = is_singledispatch

# TODO: add field for ret_type: RType = none_rprimitive

Expand Down
78 changes: 66 additions & 12 deletions mypyc/irbuild/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@
instance of the callable class.
"""

from typing import Optional, List, Tuple, Union, Dict
from typing import NamedTuple, Optional, List, Sequence, Tuple, Union, Dict

from mypy.nodes import (
ClassDef, FuncDef, OverloadedFuncDef, Decorator, Var, YieldFromExpr, AwaitExpr, YieldExpr,
FuncItem, LambdaExpr, SymbolNode, ARG_NAMED, ARG_NAMED_OPT
FuncItem, LambdaExpr, SymbolNode, ARG_NAMED, ARG_NAMED_OPT, TypeInfo
)
from mypy.types import CallableType, get_proper_type

Expand All @@ -28,7 +28,9 @@
)
from mypyc.ir.class_ir import ClassIR, NonExtClassInfo
from mypyc.primitives.generic_ops import py_setattr_op, next_raw_op, iter_op
from mypyc.primitives.misc_ops import check_stop_op, yield_from_except_op, coro_op, send_op
from mypyc.primitives.misc_ops import (
check_stop_op, yield_from_except_op, coro_op, send_op, slow_isinstance_op
)
from mypyc.primitives.dict_ops import dict_set_item_op
from mypyc.common import SELF_NAME, LAMBDA_NAME, decorator_helper_name
from mypyc.sametype import is_same_method_signature
Expand Down Expand Up @@ -84,7 +86,10 @@ def transform_decorator(builder: IRBuilder, dec: Decorator) -> None:
decorated_func = load_decorated_func(builder, dec.func, func_reg)
builder.assign(get_func_target(builder, dec.func), decorated_func, dec.func.line)
func_reg = decorated_func
else:
# If the prebuild pass didn't put this function in the function to decorators map (for example
# if this is a registered singledispatch implementation with no other decorators), we should
# treat this function as a regular function, not a decorated function
elif dec.func in builder.fdefs_to_decorators:
# Obtain the the function name in order to construct the name of the helper function.
name = dec.func.fullname.split('.')[-1]
helper_name = decorator_helper_name(name)
Expand Down Expand Up @@ -206,6 +211,7 @@ def c() -> None:
is_nested = fitem in builder.nested_fitems or isinstance(fitem, LambdaExpr)
contains_nested = fitem in builder.encapsulating_funcs.keys()
is_decorated = fitem in builder.fdefs_to_decorators
is_singledispatch = fitem in builder.singledispatch_impls
in_non_ext = False
class_name = None
if cdef:
Expand All @@ -214,7 +220,8 @@ def c() -> None:
class_name = cdef.name

builder.enter(FuncInfo(fitem, name, class_name, gen_func_ns(builder),
is_nested, contains_nested, is_decorated, in_non_ext))
is_nested, contains_nested, is_decorated, in_non_ext,
is_singledispatch))

# Functions that contain nested functions need an environment class to store variables that
# are free in their nested functions. Generator functions need an environment class to
Expand Down Expand Up @@ -247,6 +254,9 @@ def c() -> None:
if builder.fn_info.contains_nested and not builder.fn_info.is_generator:
finalize_env_class(builder)

if builder.fn_info.is_singledispatch:
add_singledispatch_registered_impls(builder)

builder.ret_types[-1] = sig.ret_type

# Add all variables and functions that are declared/defined within this
Expand Down Expand Up @@ -628,6 +638,23 @@ def gen_glue(builder: IRBuilder, sig: FuncSignature, target: FuncIR,
return gen_glue_method(builder, sig, target, cls, base, fdef.line, do_py_ops)


class ArgInfo(NamedTuple):
args: List[Value]
arg_names: List[Optional[str]]
arg_kinds: List[int]


def get_args(builder: IRBuilder, rt_args: Sequence[RuntimeArg], line: int) -> ArgInfo:
# The environment operates on Vars, so we make some up
fake_vars = [(Var(arg.name), arg.type) for arg in rt_args]
args = [builder.read(builder.add_local_reg(var, type, is_arg=True), line)
for var, type in fake_vars]
arg_names = [arg.name if arg.kind in (ARG_NAMED, ARG_NAMED_OPT) else None
for arg in rt_args]
arg_kinds = [concrete_arg_kind(arg.kind) for arg in rt_args]
return ArgInfo(args, arg_names, arg_kinds)


def gen_glue_method(builder: IRBuilder, sig: FuncSignature, target: FuncIR,
cls: ClassIR, base: ClassIR, line: int,
do_pycall: bool,
Expand Down Expand Up @@ -664,13 +691,8 @@ def f(builder: IRBuilder, x: object) -> int: ...
if target.decl.kind == FUNC_NORMAL:
rt_args[0] = RuntimeArg(sig.args[0].name, RInstance(cls))

# The environment operates on Vars, so we make some up
fake_vars = [(Var(arg.name), arg.type) for arg in rt_args]
args = [builder.read(builder.add_local_reg(var, type, is_arg=True), line)
for var, type in fake_vars]
arg_names = [arg.name if arg.kind in (ARG_NAMED, ARG_NAMED_OPT) else None
for arg in rt_args]
arg_kinds = [concrete_arg_kind(arg.kind) for arg in rt_args]
arg_info = get_args(builder, rt_args, line)
args, arg_kinds, arg_names = arg_info.args, arg_info.arg_kinds, arg_info.arg_names

if do_pycall:
retval = builder.builder.py_method_call(
Expand Down Expand Up @@ -739,3 +761,35 @@ def get_func_target(builder: IRBuilder, fdef: FuncDef) -> AssignmentTarget:
return builder.lookup(fdef)

return builder.add_local_reg(fdef, object_rprimitive)


def check_if_isinstance(builder: IRBuilder, obj: Value, typ: TypeInfo, line: int) -> Value:
if typ in builder.mapper.type_to_ir:
class_ir = builder.mapper.type_to_ir[typ]
return builder.builder.isinstance_native(obj, class_ir, line)
else:
class_obj = builder.load_module_attr_by_fullname(typ.fullname, line)
return builder.call_c(slow_isinstance_op, [obj, class_obj], line)


def add_singledispatch_registered_impls(builder: IRBuilder) -> None:
fitem = builder.fn_info.fitem
assert isinstance(fitem, FuncDef)
impls = builder.singledispatch_impls[fitem]
line = fitem.line
current_func_decl = builder.mapper.func_to_decl[fitem]
arg_info = get_args(builder, current_func_decl.sig.args, line)
for dispatch_type, impl in impls:
func_decl = builder.mapper.func_to_decl[impl]
call_impl, next_impl = BasicBlock(), BasicBlock()
should_call_impl = check_if_isinstance(builder, arg_info.args[0], dispatch_type, line)
builder.add_bool_branch(should_call_impl, call_impl, next_impl)

# Call the registered implementation
builder.activate_block(call_impl)

ret_val = builder.builder.call(
func_decl, arg_info.args, arg_info.arg_kinds, arg_info.arg_names, line
)
builder.nonlocal_control[-1].gen_return(builder, ret_val, line)
builder.activate_block(next_impl)
67 changes: 65 additions & 2 deletions mypyc/irbuild/prebuildvisitor.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from typing import Dict, List, Set
from mypy.types import Instance, get_proper_type
from typing import DefaultDict, Dict, List, NamedTuple, Set, Optional, Tuple
from collections import defaultdict

from mypy.nodes import (
Decorator, Expression, FuncDef, FuncItem, LambdaExpr, NameExpr, SymbolNode, Var, MemberExpr
Decorator, Expression, FuncDef, FuncItem, LambdaExpr, NameExpr, SymbolNode, Var, MemberExpr,
CallExpr, RefExpr, TypeInfo
)
from mypy.traverser import TraverserVisitor

Expand Down Expand Up @@ -50,6 +53,10 @@ def __init__(self) -> None:
# Map function to its non-special decorators.
self.funcs_to_decorators: Dict[FuncDef, List[Expression]] = {}

# Map of main singledispatch function to list of registered implementations
self.singledispatch_impls: DefaultDict[
FuncDef, List[Tuple[TypeInfo, FuncDef]]] = defaultdict(list)

def visit_decorator(self, dec: Decorator) -> None:
if dec.decorators:
# Only add the function being decorated if there exist
Expand All @@ -63,6 +70,20 @@ def visit_decorator(self, dec: Decorator) -> None:
# Property setters are not treated as decorated methods.
self.prop_setters.add(dec.func)
else:
removed: List[int] = []
for i, d in enumerate(dec.decorators):
impl = get_singledispatch_register_call_info(d, dec.func)
if impl is not None:
self.singledispatch_impls[impl.singledispatch_func].append(
(impl.dispatch_type, dec.func))
removed.append(i)
for i in reversed(removed):
del dec.decorators[i]
# if the only decorators are register calls, we shouldn't treat this
# as a decorated function because there aren't any decorators to apply
if not dec.decorators:
return

self.funcs_to_decorators[dec.func] = dec.decorators
super().visit_decorator(dec)

Expand Down Expand Up @@ -141,3 +162,45 @@ def add_free_variable(self, symbol: SymbolNode) -> None:
# and mark is as a non-local symbol within that function.
func = self.symbols_to_funcs[symbol]
self.free_variables.setdefault(func, set()).add(symbol)


class RegisteredImpl(NamedTuple):
singledispatch_func: FuncDef
dispatch_type: TypeInfo


def get_singledispatch_register_call_info(decorator: Expression, func: FuncDef
) -> Optional[RegisteredImpl]:
# @fun.register(complex)
# def g(arg): ...
if (isinstance(decorator, CallExpr) and len(decorator.args) == 1
and isinstance(decorator.args[0], RefExpr)):
callee = decorator.callee
dispatch_type = decorator.args[0].node
if not isinstance(dispatch_type, TypeInfo):
return None

if isinstance(callee, MemberExpr):
return registered_impl_from_possible_register_call(callee, dispatch_type)
# @fun.register
# def g(arg: int): ...
elif isinstance(decorator, MemberExpr):
# we don't know if this is a register call yet, so we can't be sure that the function
# actually has arguments
if not func.arguments:
return None
arg_type = get_proper_type(func.arguments[0].variable.type)
if not isinstance(arg_type, Instance):
return None
info = arg_type.type
return registered_impl_from_possible_register_call(decorator, info)
return None


def registered_impl_from_possible_register_call(expr: MemberExpr, dispatch_type: TypeInfo
) -> Optional[RegisteredImpl]:
if expr.name == 'register' and isinstance(expr.expr, NameExpr):
node = expr.expr.node
if isinstance(node, Decorator):
return RegisteredImpl(node.func, dispatch_type)
return None
2 changes: 1 addition & 1 deletion mypyc/primitives/misc_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@
is_borrowed=True)

# isinstance(obj, cls)
function_op(
slow_isinstance_op = function_op(
name='builtins.isinstance',
arg_types=[object_rprimitive, object_rprimitive],
return_type=c_int_rprimitive,
Expand Down
73 changes: 55 additions & 18 deletions mypyc/test-data/run-singledispatch.test
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# Most of these tests are marked as xfails because mypyc doesn't support singledispatch yet
# (These tests will be re-enabled when mypyc supports singledispatch)

[case testSpecializedImplementationUsed-xfail]
[case testSpecializedImplementationUsed]
from functools import singledispatch

@singledispatch
Expand All @@ -17,7 +17,7 @@ def test_specialize() -> None:
assert fun('a')
assert not fun(3)

[case testSubclassesOfExpectedTypeUseSpecialized-xfail]
[case testSubclassesOfExpectedTypeUseSpecialized]
from functools import singledispatch
class A: pass
class B(A): pass
Expand Down Expand Up @@ -76,7 +76,7 @@ def test_singledispatch() -> None:
assert fun('a') == 'str'
assert fun({'a': 'b'}) == 'default'

[case testCanRegisterCompiledClasses-xfail]
[case testCanRegisterCompiledClasses]
from functools import singledispatch
class A: pass

Expand Down Expand Up @@ -136,21 +136,6 @@ def fun_specialized(arg: int) -> bool:
def test_singledispatch() -> None:
assert fun_specialized('a')

[case testTypeAnnotationsDisagreeWithRegisterArgument-xfail]
from functools import singledispatch

@singledispatch
def fun(arg) -> bool:
return False

@fun.register(int)
def fun_specialized(arg: str) -> bool:
return True

def test_singledispatch() -> None:
assert fun(3) # type: ignore
assert not fun('a')

[case testNoneIsntATypeWhenUsedAsArgumentToRegister-xfail]
from functools import singledispatch

Expand Down Expand Up @@ -385,3 +370,55 @@ def test_verify() -> None:
assert verify_list(MypyFile(), 5, ['a', 'b']) == ['in TypeInfo', 'hello']
assert verify_list(TypeInfo(), str, ['a', 'b']) == ['in TypeInfo', 'hello']
assert verify_list(TypeVarExpr(), 'a', ['x', 'y']) == ['x', 'y']

[case testArgsInRegisteredImplNamedDifferentlyFromMainFunction]
from functools import singledispatch

@singledispatch
def f(a) -> bool:
return False

@f.register
def g(b: int) -> bool:
return True

def test_singledispatch():
assert f(5)
assert not f('a')

[case testKeywordArguments-xfail]
from functools import singledispatch

@singledispatch
def f(arg, *, kwarg: bool = False) -> bool:
return not kwarg

@f.register
def g(arg: int, *, kwarg: bool = True) -> bool:
return kwarg

def test_keywords():
assert f('a')
assert f('a', kwarg=False)
assert not f('a', kwarg=True)

assert f(1)
assert f(1, kwarg=True)
assert not f(1, kwarg=False)

[case testGeneratorAndMultipleTypesOfIterable-xfail]
from functools import singledispatch
from typing import *

@singledispatch
def f(arg: Any) -> Iterable[int]:
yield 1

@f.register
def g(arg: str) -> Iterable[int]:
return [0]

def test_iterables():
assert f(1) != [1]
assert list(f(1)) == [1]
assert f('a') == [0]

0 comments on commit a5a9e15

Please sign in to comment.