Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for functools.partial #16939

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
34 changes: 17 additions & 17 deletions mypy/checkexpr.py
Expand Up @@ -1216,14 +1216,14 @@ def apply_function_plugin(
assert callback is not None # Assume that caller ensures this
return callback(
FunctionContext(
formal_arg_types,
formal_arg_kinds,
callee.arg_names,
formal_arg_names,
callee.ret_type,
formal_arg_exprs,
context,
self.chk,
arg_types=formal_arg_types,
arg_kinds=formal_arg_kinds,
callee_arg_names=callee.arg_names,
arg_names=formal_arg_names,
default_return_type=callee.ret_type,
args=formal_arg_exprs,
context=context,
api=self.chk,
)
)
else:
Expand All @@ -1233,15 +1233,15 @@ def apply_function_plugin(
object_type = get_proper_type(object_type)
return method_callback(
MethodContext(
object_type,
formal_arg_types,
formal_arg_kinds,
callee.arg_names,
formal_arg_names,
callee.ret_type,
formal_arg_exprs,
context,
self.chk,
type=object_type,
arg_types=formal_arg_types,
arg_kinds=formal_arg_kinds,
callee_arg_names=callee.arg_names,
arg_names=formal_arg_names,
default_return_type=callee.ret_type,
args=formal_arg_exprs,
context=context,
api=self.chk,
)
)

Expand Down
15 changes: 12 additions & 3 deletions mypy/plugins/default.py
Expand Up @@ -47,6 +47,10 @@ def get_function_hook(self, fullname: str) -> Callable[[FunctionContext], Type]
return ctypes.array_constructor_callback
elif fullname == "functools.singledispatch":
return singledispatch.create_singledispatch_function_callback
elif fullname == "functools.partial":
import mypy.plugins.functools

return mypy.plugins.functools.partial_new_callback

return None

Expand Down Expand Up @@ -118,6 +122,10 @@ def get_method_hook(self, fullname: str) -> Callable[[MethodContext], Type] | No
return singledispatch.singledispatch_register_callback
elif fullname == singledispatch.REGISTER_CALLABLE_CALL_METHOD:
return singledispatch.call_singledispatch_function_after_register_argument
elif fullname == "functools.partial.__call__":
import mypy.plugins.functools

return mypy.plugins.functools.partial_call_callback
return None

def get_attribute_hook(self, fullname: str) -> Callable[[AttributeContext], Type] | None:
Expand Down Expand Up @@ -155,12 +163,13 @@ def get_class_decorator_hook(self, fullname: str) -> Callable[[ClassDefContext],
def get_class_decorator_hook_2(
self, fullname: str
) -> Callable[[ClassDefContext], bool] | None:
from mypy.plugins import attrs, dataclasses, functools
import mypy.plugins.functools
from mypy.plugins import attrs, dataclasses

if fullname in dataclasses.dataclass_makers:
return dataclasses.dataclass_class_maker_callback
elif fullname in functools.functools_total_ordering_makers:
return functools.functools_total_ordering_maker_callback
elif fullname in mypy.plugins.functools.functools_total_ordering_makers:
return mypy.plugins.functools.functools_total_ordering_maker_callback
elif fullname in attrs.attr_class_makers:
return attrs.attr_class_maker_callback
elif fullname in attrs.attr_dataclass_makers:
Expand Down
144 changes: 142 additions & 2 deletions mypy/plugins/functools.py
Expand Up @@ -4,10 +4,22 @@

from typing import Final, NamedTuple

import mypy.checker
import mypy.plugin
from mypy.nodes import ARG_POS, ARG_STAR2, Argument, FuncItem, Var
from mypy.argmap import map_actuals_to_formals
from mypy.nodes import ARG_POS, ARG_STAR2, ArgKind, Argument, FuncItem, Var
from mypy.plugins.common import add_method_to_class
from mypy.types import AnyType, CallableType, Type, TypeOfAny, UnboundType, get_proper_type
from mypy.types import (
AnyType,
CallableType,
Instance,
Overloaded,
Type,
TypeOfAny,
UnboundType,
UninhabitedType,
get_proper_type,
)

functools_total_ordering_makers: Final = {"functools.total_ordering"}

Expand Down Expand Up @@ -102,3 +114,131 @@ def _analyze_class(ctx: mypy.plugin.ClassDefContext) -> dict[str, _MethodInfo |
comparison_methods[name] = None

return comparison_methods


def partial_new_callback(ctx: mypy.plugin.FunctionContext) -> Type:
"""Infer a more precise return type for functools.partial"""
if not isinstance(ctx.api, mypy.checker.TypeChecker): # use internals
return ctx.default_return_type
if len(ctx.arg_types) != 3: # fn, *args, **kwargs
return ctx.default_return_type
if len(ctx.arg_types[0]) != 1:
return ctx.default_return_type

if isinstance(get_proper_type(ctx.arg_types[0][0]), Overloaded):
# TODO: handle overloads, just fall back to whatever the non-plugin code does
return ctx.default_return_type
fn_type = ctx.api.extract_callable_type(ctx.arg_types[0][0], ctx=ctx.default_return_type)
if fn_type is None:
return ctx.default_return_type

defaulted = fn_type.copy_modified(
arg_kinds=[
(
ArgKind.ARG_OPT
if k == ArgKind.ARG_POS
else (ArgKind.ARG_NAMED_OPT if k == ArgKind.ARG_NAMED else k)
)
for k in fn_type.arg_kinds
]
)
if defaulted.line < 0:
# Make up a line number if we don't have one
defaulted.set_line(ctx.default_return_type)

actual_args = [a for param in ctx.args[1:] for a in param]
actual_arg_kinds = [a for param in ctx.arg_kinds[1:] for a in param]
actual_arg_names = [a for param in ctx.arg_names[1:] for a in param]
actual_types = [a for param in ctx.arg_types[1:] for a in param]

_, bound = ctx.api.expr_checker.check_call(
callee=defaulted,
args=actual_args,
arg_kinds=actual_arg_kinds,
arg_names=actual_arg_names,
context=defaulted,
)
bound = get_proper_type(bound)
if not isinstance(bound, CallableType):
return ctx.default_return_type

formal_to_actual = map_actuals_to_formals(
actual_kinds=actual_arg_kinds,
actual_names=actual_arg_names,
formal_kinds=fn_type.arg_kinds,
formal_names=fn_type.arg_names,
actual_arg_type=lambda i: actual_types[i],
)

partial_kinds = []
partial_types = []
partial_names = []
# We need to fully apply any positional arguments (they cannot be respecified)
# However, keyword arguments can be respecified, so just give them a default
for i, actuals in enumerate(formal_to_actual):
if len(bound.arg_types) == len(fn_type.arg_types):
arg_type = bound.arg_types[i]
if isinstance(get_proper_type(arg_type), UninhabitedType):
arg_type = fn_type.arg_types[i] # bit of a hack
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm this could refer to a type variable type, so I wonder if this can leak type variables, in the target function is a generic one, and one of the provided arguments can be used to bind the type variable.

else:
# TODO: I assume that bound and fn_type have the same arguments. It appears this isn't
# true when PEP 646 things are happening. See testFunctoolsPartialTypeVarTuple
arg_type = fn_type.arg_types[i]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar to above.


if not actuals or fn_type.arg_kinds[i] in (ArgKind.ARG_STAR, ArgKind.ARG_STAR2):
partial_kinds.append(fn_type.arg_kinds[i])
partial_types.append(arg_type)
partial_names.append(fn_type.arg_names[i])
elif actuals:
if any(actual_arg_kinds[j] == ArgKind.ARG_POS for j in actuals):
continue
kind = actual_arg_kinds[actuals[0]]
if kind == ArgKind.ARG_NAMED:
kind = ArgKind.ARG_NAMED_OPT
partial_kinds.append(kind)
partial_types.append(arg_type)
partial_names.append(fn_type.arg_names[i])

ret_type = bound.ret_type
if isinstance(get_proper_type(ret_type), UninhabitedType):
ret_type = fn_type.ret_type # same kind of hack as above
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar to above -- it seems that this might leak type variables. If that is the case, it would probably be better to fall back to the default return type.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this doesn't leak type variables, because partially_applied has the same variables as fn_type (from copy_modified). partially_applied just remains generic. See the test cases in testFunctoolsPartialGeneric. I can add some more comments to the code

But let me know if I'm off the mark!

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you are right, and this is good already.


partially_applied = fn_type.copy_modified(
arg_types=partial_types,
arg_kinds=partial_kinds,
arg_names=partial_names,
ret_type=ret_type,
)

ret = ctx.api.named_generic_type("functools.partial", [ret_type])
ret = ret.copy_with_extra_attr("__mypy_partial", partially_applied)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

extra_attrs is currently a bit poorly thought out, so this will cause some problems. First, it isn't serialized, so storing the additional information there breaks in incremental mode. For example, consider a module like this:

...
p = partial(f, 1)

Now if we import p from another module, the second run which uses serialized data will produce a different type for p compared to the first run.

I'm not sure what is the best way to fix this. Probably the simplest option would be to serialize extra_attrs. It seems that everything we put there can be serialized -- it just hasn't been implemented.

Also it would be good to have an incremental mode test case.

It looks like mypy daemon doesn't keep track of extra_attrs. mypy.server.astdiff should look into extra_attrs to detect changes in extra_attrs, as these may need to be propagated.

return ret


def partial_call_callback(ctx: mypy.plugin.MethodContext) -> Type:
"""Infer a more precise return type for functools.partial.__call__."""
if (
not isinstance(ctx.api, mypy.checker.TypeChecker) # use internals
or not isinstance(ctx.type, Instance)
or ctx.type.type.fullname != "functools.partial"
or not ctx.type.extra_attrs
or "__mypy_partial" not in ctx.type.extra_attrs.attrs
):
return ctx.default_return_type

partial_type = ctx.type.extra_attrs.attrs["__mypy_partial"]
if len(ctx.arg_types) != 2: # *args, **kwargs
return ctx.default_return_type

args = [a for param in ctx.args for a in param]
arg_kinds = [a for param in ctx.arg_kinds for a in param]
arg_names = [a for param in ctx.arg_names for a in param]

result = ctx.api.expr_checker.check_call(
callee=partial_type,
args=args,
arg_kinds=arg_kinds,
arg_names=arg_names,
context=ctx.context,
)
return result[0]
9 changes: 5 additions & 4 deletions mypy/types.py
Expand Up @@ -1495,13 +1495,14 @@ def copy_modified(
last_known_value: Bogus[LiteralType | None] = _dummy,
) -> Instance:
new = Instance(
self.type,
args if args is not _dummy else self.args,
self.line,
self.column,
typ=self.type,
args=args if args is not _dummy else self.args,
line=self.line,
column=self.column,
last_known_value=(
last_known_value if last_known_value is not _dummy else self.last_known_value
),
extra_attrs=self.extra_attrs,
)
# We intentionally don't copy the extra_attrs here, so they will be erased.
new.can_be_true = self.can_be_true
Expand Down