New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add support for functools.partial #16939
base: master
Are you sure you want to change the base?
Changes from all commits
5b56460
ff4914f
c0084eb
a847234
2a5f3f1
325cae4
1ee376d
b7ca434
03b397e
d2886ac
65e356a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,10 +4,22 @@ | |
|
||
from typing import Final, NamedTuple | ||
|
||
import mypy.checker | ||
import mypy.plugin | ||
from mypy.nodes import ARG_POS, ARG_STAR2, Argument, FuncItem, Var | ||
from mypy.argmap import map_actuals_to_formals | ||
from mypy.nodes import ARG_POS, ARG_STAR2, ArgKind, Argument, FuncItem, Var | ||
from mypy.plugins.common import add_method_to_class | ||
from mypy.types import AnyType, CallableType, Type, TypeOfAny, UnboundType, get_proper_type | ||
from mypy.types import ( | ||
AnyType, | ||
CallableType, | ||
Instance, | ||
Overloaded, | ||
Type, | ||
TypeOfAny, | ||
UnboundType, | ||
UninhabitedType, | ||
get_proper_type, | ||
) | ||
|
||
functools_total_ordering_makers: Final = {"functools.total_ordering"} | ||
|
||
|
@@ -102,3 +114,131 @@ def _analyze_class(ctx: mypy.plugin.ClassDefContext) -> dict[str, _MethodInfo | | |
comparison_methods[name] = None | ||
|
||
return comparison_methods | ||
|
||
|
||
def partial_new_callback(ctx: mypy.plugin.FunctionContext) -> Type: | ||
"""Infer a more precise return type for functools.partial""" | ||
if not isinstance(ctx.api, mypy.checker.TypeChecker): # use internals | ||
return ctx.default_return_type | ||
if len(ctx.arg_types) != 3: # fn, *args, **kwargs | ||
return ctx.default_return_type | ||
if len(ctx.arg_types[0]) != 1: | ||
return ctx.default_return_type | ||
|
||
if isinstance(get_proper_type(ctx.arg_types[0][0]), Overloaded): | ||
# TODO: handle overloads, just fall back to whatever the non-plugin code does | ||
return ctx.default_return_type | ||
fn_type = ctx.api.extract_callable_type(ctx.arg_types[0][0], ctx=ctx.default_return_type) | ||
if fn_type is None: | ||
return ctx.default_return_type | ||
|
||
defaulted = fn_type.copy_modified( | ||
arg_kinds=[ | ||
( | ||
ArgKind.ARG_OPT | ||
if k == ArgKind.ARG_POS | ||
else (ArgKind.ARG_NAMED_OPT if k == ArgKind.ARG_NAMED else k) | ||
) | ||
for k in fn_type.arg_kinds | ||
] | ||
) | ||
if defaulted.line < 0: | ||
# Make up a line number if we don't have one | ||
defaulted.set_line(ctx.default_return_type) | ||
|
||
actual_args = [a for param in ctx.args[1:] for a in param] | ||
actual_arg_kinds = [a for param in ctx.arg_kinds[1:] for a in param] | ||
actual_arg_names = [a for param in ctx.arg_names[1:] for a in param] | ||
actual_types = [a for param in ctx.arg_types[1:] for a in param] | ||
|
||
_, bound = ctx.api.expr_checker.check_call( | ||
callee=defaulted, | ||
args=actual_args, | ||
arg_kinds=actual_arg_kinds, | ||
arg_names=actual_arg_names, | ||
context=defaulted, | ||
) | ||
bound = get_proper_type(bound) | ||
if not isinstance(bound, CallableType): | ||
return ctx.default_return_type | ||
|
||
formal_to_actual = map_actuals_to_formals( | ||
actual_kinds=actual_arg_kinds, | ||
actual_names=actual_arg_names, | ||
formal_kinds=fn_type.arg_kinds, | ||
formal_names=fn_type.arg_names, | ||
actual_arg_type=lambda i: actual_types[i], | ||
) | ||
|
||
partial_kinds = [] | ||
partial_types = [] | ||
partial_names = [] | ||
# We need to fully apply any positional arguments (they cannot be respecified) | ||
# However, keyword arguments can be respecified, so just give them a default | ||
for i, actuals in enumerate(formal_to_actual): | ||
if len(bound.arg_types) == len(fn_type.arg_types): | ||
arg_type = bound.arg_types[i] | ||
if isinstance(get_proper_type(arg_type), UninhabitedType): | ||
arg_type = fn_type.arg_types[i] # bit of a hack | ||
else: | ||
# TODO: I assume that bound and fn_type have the same arguments. It appears this isn't | ||
# true when PEP 646 things are happening. See testFunctoolsPartialTypeVarTuple | ||
arg_type = fn_type.arg_types[i] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Similar to above. |
||
|
||
if not actuals or fn_type.arg_kinds[i] in (ArgKind.ARG_STAR, ArgKind.ARG_STAR2): | ||
partial_kinds.append(fn_type.arg_kinds[i]) | ||
partial_types.append(arg_type) | ||
partial_names.append(fn_type.arg_names[i]) | ||
elif actuals: | ||
if any(actual_arg_kinds[j] == ArgKind.ARG_POS for j in actuals): | ||
continue | ||
kind = actual_arg_kinds[actuals[0]] | ||
if kind == ArgKind.ARG_NAMED: | ||
kind = ArgKind.ARG_NAMED_OPT | ||
partial_kinds.append(kind) | ||
partial_types.append(arg_type) | ||
partial_names.append(fn_type.arg_names[i]) | ||
|
||
ret_type = bound.ret_type | ||
if isinstance(get_proper_type(ret_type), UninhabitedType): | ||
ret_type = fn_type.ret_type # same kind of hack as above | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Similar to above -- it seems that this might leak type variables. If that is the case, it would probably be better to fall back to the default return type. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this doesn't leak type variables, because But let me know if I'm off the mark! There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think you are right, and this is good already. |
||
|
||
partially_applied = fn_type.copy_modified( | ||
arg_types=partial_types, | ||
arg_kinds=partial_kinds, | ||
arg_names=partial_names, | ||
ret_type=ret_type, | ||
) | ||
|
||
ret = ctx.api.named_generic_type("functools.partial", [ret_type]) | ||
ret = ret.copy_with_extra_attr("__mypy_partial", partially_applied) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Now if we import I'm not sure what is the best way to fix this. Probably the simplest option would be to serialize Also it would be good to have an incremental mode test case. It looks like mypy daemon doesn't keep track of |
||
return ret | ||
|
||
|
||
def partial_call_callback(ctx: mypy.plugin.MethodContext) -> Type: | ||
"""Infer a more precise return type for functools.partial.__call__.""" | ||
if ( | ||
not isinstance(ctx.api, mypy.checker.TypeChecker) # use internals | ||
or not isinstance(ctx.type, Instance) | ||
or ctx.type.type.fullname != "functools.partial" | ||
or not ctx.type.extra_attrs | ||
or "__mypy_partial" not in ctx.type.extra_attrs.attrs | ||
): | ||
return ctx.default_return_type | ||
|
||
partial_type = ctx.type.extra_attrs.attrs["__mypy_partial"] | ||
if len(ctx.arg_types) != 2: # *args, **kwargs | ||
return ctx.default_return_type | ||
|
||
args = [a for param in ctx.args for a in param] | ||
arg_kinds = [a for param in ctx.arg_kinds for a in param] | ||
arg_names = [a for param in ctx.arg_names for a in param] | ||
|
||
result = ctx.api.expr_checker.check_call( | ||
callee=partial_type, | ||
args=args, | ||
arg_kinds=arg_kinds, | ||
arg_names=arg_names, | ||
context=ctx.context, | ||
) | ||
return result[0] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm this could refer to a type variable type, so I wonder if this can leak type variables, in the target function is a generic one, and one of the provided arguments can be used to bind the type variable.