New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add support for functools.partial #16939
base: master
Are you sure you want to change the base?
Changes from 4 commits
5b56460
ff4914f
c0084eb
a847234
2a5f3f1
325cae4
1ee376d
b7ca434
03b397e
d2886ac
65e356a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,10 +4,21 @@ | |
|
||
from typing import Final, NamedTuple | ||
|
||
import mypy.checker | ||
import mypy.plugin | ||
from mypy.nodes import ARG_POS, ARG_STAR2, Argument, FuncItem, Var | ||
from mypy.argmap import map_actuals_to_formals | ||
from mypy.nodes import ARG_POS, ARG_STAR2, ArgKind, Argument, FuncItem, Var | ||
from mypy.plugins.common import add_method_to_class | ||
from mypy.types import AnyType, CallableType, Type, TypeOfAny, UnboundType, get_proper_type | ||
from mypy.types import ( | ||
AnyType, | ||
CallableType, | ||
Instance, | ||
Type, | ||
TypeOfAny, | ||
UnboundType, | ||
UninhabitedType, | ||
get_proper_type, | ||
) | ||
|
||
functools_total_ordering_makers: Final = {"functools.total_ordering"} | ||
|
||
|
@@ -102,3 +113,125 @@ def _analyze_class(ctx: mypy.plugin.ClassDefContext) -> dict[str, _MethodInfo | | |
comparison_methods[name] = None | ||
|
||
return comparison_methods | ||
|
||
|
||
def partial_new_callback(ctx: mypy.plugin.FunctionContext) -> Type: | ||
"""Infer a more precise return type for functools.partial""" | ||
if not isinstance(ctx.api, mypy.checker.TypeChecker): # use internals | ||
return ctx.default_return_type | ||
if len(ctx.arg_types) != 3: # fn, *args, **kwargs | ||
return ctx.default_return_type | ||
if len(ctx.arg_types[0]) != 1: | ||
return ctx.default_return_type | ||
|
||
fn_type = get_proper_type(ctx.arg_types[0][0]) | ||
if not isinstance(fn_type, CallableType): | ||
return ctx.default_return_type | ||
|
||
defaulted = fn_type.copy_modified( | ||
arg_kinds=[ | ||
( | ||
ArgKind.ARG_OPT | ||
if k == ArgKind.ARG_POS | ||
else (ArgKind.ARG_NAMED_OPT if k == ArgKind.ARG_NAMED else k) | ||
) | ||
for k in fn_type.arg_kinds | ||
] | ||
) | ||
|
||
actual_args = [a for param in ctx.args[1:] for a in param] | ||
actual_arg_kinds = [a for param in ctx.arg_kinds[1:] for a in param] | ||
actual_arg_names = [a for param in ctx.arg_names[1:] for a in param] | ||
actual_types = [a for param in ctx.arg_types[1:] for a in param] | ||
|
||
_, bound = ctx.api.expr_checker.check_call( | ||
callee=defaulted, | ||
args=actual_args, | ||
arg_kinds=actual_arg_kinds, | ||
arg_names=actual_arg_names, | ||
context=defaulted, | ||
) | ||
bound = get_proper_type(bound) | ||
if not isinstance(bound, CallableType): | ||
return ctx.default_return_type | ||
|
||
formal_to_actual = map_actuals_to_formals( | ||
actual_kinds=actual_arg_kinds, | ||
actual_names=actual_arg_names, | ||
formal_kinds=fn_type.arg_kinds, | ||
formal_names=fn_type.arg_names, | ||
actual_arg_type=lambda i: actual_types[i], | ||
) | ||
|
||
partial_kinds = [] | ||
partial_types = [] | ||
partial_names = [] | ||
# We need to fully apply any positional arguments (they cannot be respecified) | ||
# However, keyword arguments can be respecified, so just give them a default | ||
for i, actuals in enumerate(formal_to_actual): | ||
if len(bound.arg_types) == len(fn_type.arg_types): | ||
arg_type = bound.arg_types[i] | ||
if isinstance(get_proper_type(arg_type), UninhabitedType): | ||
arg_type = fn_type.arg_types[i] # bit of a hack | ||
else: | ||
# TODO: I assume that bound and fn_type have the same arguments. It appears this isn't | ||
# true when PEP 646 things are happening. See testFunctoolsPartialTypeVarTuple | ||
arg_type = fn_type.arg_types[i] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Similar to above. |
||
|
||
if not actuals or fn_type.arg_kinds[i] in (ArgKind.ARG_STAR, ArgKind.ARG_STAR2): | ||
partial_kinds.append(fn_type.arg_kinds[i]) | ||
partial_types.append(arg_type) | ||
partial_names.append(fn_type.arg_names[i]) | ||
elif actuals: | ||
if any(actual_arg_kinds[j] == ArgKind.ARG_POS for j in actuals): | ||
continue | ||
kind = actual_arg_kinds[actuals[0]] | ||
if kind == ArgKind.ARG_NAMED: | ||
kind = ArgKind.ARG_NAMED_OPT | ||
partial_kinds.append(kind) | ||
partial_types.append(arg_type) | ||
partial_names.append(fn_type.arg_names[i]) | ||
|
||
ret_type = bound.ret_type | ||
if isinstance(get_proper_type(ret_type), UninhabitedType): | ||
ret_type = fn_type.ret_type # same kind of hack as above | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Similar to above -- it seems that this might leak type variables. If that is the case, it would probably be better to fall back to the default return type. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this doesn't leak type variables, because But let me know if I'm off the mark! There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think you are right, and this is good already. |
||
|
||
partially_applied = fn_type.copy_modified( | ||
arg_types=partial_types, | ||
arg_kinds=partial_kinds, | ||
arg_names=partial_names, | ||
ret_type=ret_type, | ||
) | ||
|
||
ret = ctx.api.named_generic_type("functools.partial", [ret_type]) | ||
ret = ret.copy_with_extra_attr("__mypy_partial", partially_applied) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Now if we import I'm not sure what is the best way to fix this. Probably the simplest option would be to serialize Also it would be good to have an incremental mode test case. It looks like mypy daemon doesn't keep track of |
||
return ret | ||
|
||
|
||
def partial_call_callback(ctx: mypy.plugin.MethodContext) -> Type: | ||
"""Infer a more precise return type for functools.partial.__call__.""" | ||
if ( | ||
not isinstance(ctx.api, mypy.checker.TypeChecker) # use internals | ||
or not isinstance(ctx.type, Instance) | ||
or ctx.type.type.fullname != "functools.partial" | ||
or not ctx.type.extra_attrs | ||
or "__mypy_partial" not in ctx.type.extra_attrs.attrs | ||
): | ||
return ctx.default_return_type | ||
|
||
partial_type = ctx.type.extra_attrs.attrs["__mypy_partial"] | ||
if len(ctx.arg_types) != 2: # *args, **kwargs | ||
return ctx.default_return_type | ||
|
||
args = [a for param in ctx.args for a in param] | ||
arg_kinds = [a for param in ctx.arg_kinds for a in param] | ||
arg_names = [a for param in ctx.arg_names for a in param] | ||
|
||
result = ctx.api.expr_checker.check_call( | ||
callee=partial_type, | ||
args=args, | ||
arg_kinds=arg_kinds, | ||
arg_names=arg_names, | ||
context=ctx.context, | ||
) | ||
return result[0] |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -144,3 +144,99 @@ def f(d: D[C]) -> None: | |
|
||
d: D[int] # E: Type argument "int" of "D" must be a subtype of "C" | ||
[builtins fixtures/dict.pyi] | ||
|
||
[case testFunctoolsPartialBasic] | ||
from typing import Callable | ||
import functools | ||
|
||
def foo(a: int, b: str, c: int = 5) -> int: ... # N: "foo" defined here | ||
|
||
p1 = functools.partial(foo) | ||
p1(1, "a", 3) # OK | ||
p1(1, "a", c=3) # OK | ||
p1(1, b="a", c=3) # OK | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What about using There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I added one of these, it's not very useful because it'll just show you There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah makes sense. The current approach seems fine then. |
||
|
||
def takes_callable_int(f: Callable[..., int]) -> None: ... | ||
def takes_callable_str(f: Callable[..., str]) -> None: ... | ||
takes_callable_int(p1) | ||
takes_callable_str(p1) # E: Argument 1 to "takes_callable_str" has incompatible type "partial[int]"; expected "Callable[..., str]" \ | ||
# N: "partial[int].__call__" has type "Callable[[VarArg(Any), KwArg(Any)], int]" | ||
|
||
p2 = functools.partial(foo, 1) | ||
p2("a") # OK | ||
p2("a", 3) # OK | ||
p2("a", c=3) # OK | ||
p2(1, 3) # E: Argument 1 to "foo" has incompatible type "int"; expected "str" | ||
p2(1, "a", 3) # E: Too many arguments for "foo" \ | ||
# E: Argument 1 to "foo" has incompatible type "int"; expected "str" \ | ||
# E: Argument 2 to "foo" has incompatible type "str"; expected "int" | ||
p2(a=1, b="a", c=3) # E: Unexpected keyword argument "a" for "foo" | ||
|
||
p3 = functools.partial(foo, b="a") | ||
p3(1) # OK | ||
p3(1, c=3) # OK | ||
p3(a=1) # OK | ||
p3(1, b="a", c=3) # OK, keywords can be clobbered | ||
p3(1, 3) # E: Too many positional arguments for "foo" \ | ||
# E: Argument 2 to "foo" has incompatible type "int"; expected "str" | ||
|
||
functools.partial(foo, "a") # E: Argument 1 to "foo" has incompatible type "str"; expected "int" | ||
functools.partial(foo, b=1) # E: Argument 1 to "foo" has incompatible type "int"; expected "str" | ||
functools.partial(foo, a=1, b=2, c=3) # E: Argument 2 to "foo" has incompatible type "int"; expected "str" | ||
functools.partial(1) # E: Argument 1 to "partial" has incompatible type "int"; expected "Callable[..., Never]" | ||
[builtins fixtures/dict.pyi] | ||
|
||
[case testFunctoolsPartialStar] | ||
import functools | ||
|
||
def foo(a: int, b: str, *args: int, d: str, **kwargs: int) -> int: ... | ||
|
||
p1 = functools.partial(foo, 1, d="a", x=9) | ||
p1("a", 2, 3, 4) # OK | ||
p1("a", 2, 3, 4, d="a") # OK | ||
p1("a", 2, 3, 4, "a") # E: Argument 5 to "foo" has incompatible type "str"; expected "int" | ||
p1("a", 2, 3, 4, x="a") # E: Argument "x" to "foo" has incompatible type "str"; expected "int" | ||
|
||
p2 = functools.partial(foo, 1, "a") | ||
p2(2, 3, 4, d="a") # OK | ||
p2("a") # E: Missing named argument "d" for "foo" \ | ||
# E: Argument 1 to "foo" has incompatible type "str"; expected "int" | ||
p2(2, 3, 4) # E: Missing named argument "d" for "foo" | ||
|
||
functools.partial(foo, 1, "a", "b", "c", d="a") # E: Argument 3 to "foo" has incompatible type "str"; expected "int" \ | ||
# E: Argument 4 to "foo" has incompatible type "str"; expected "int" | ||
|
||
[builtins fixtures/dict.pyi] | ||
|
||
[case testFunctoolsPartialGeneric] | ||
from typing import TypeVar | ||
import functools | ||
|
||
T = TypeVar("T") | ||
U = TypeVar("U") | ||
|
||
def foo(a: T, b: T) -> T: ... | ||
|
||
p1 = functools.partial(foo, 1) | ||
reveal_type(p1(2)) # N: Revealed type is "builtins.int" | ||
p1("a") # E: Argument 1 to "foo" has incompatible type "str"; expected "int" | ||
|
||
p2 = functools.partial(foo, "a") | ||
p2(1) # E: Argument 1 to "foo" has incompatible type "int"; expected "str" | ||
reveal_type(p2("a")) # N: Revealed type is "builtins.str" | ||
|
||
def bar(a: T, b: U) -> U: ... | ||
|
||
p3 = functools.partial(bar, 1) | ||
reveal_type(p3(2)) # N: Revealed type is "builtins.int" | ||
reveal_type(p3("a")) # N: Revealed type is "builtins.str" | ||
[builtins fixtures/dict.pyi] | ||
|
||
[case testFunctoolsPartialTypeVarTuple] | ||
import functools | ||
import typing | ||
Ts = typing.TypeVarTuple("Ts") | ||
def foo(fn: typing.Callable[[typing.Unpack[Ts]], None], /, *arg: typing.Unpack[Ts], kwarg: str) -> None: ... | ||
# just test that this doesn't crash | ||
functools.partial(foo, kwarg="asdf") | ||
[builtins fixtures/dict.pyi] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ideas for test cases (it's not important to support all of these cases precisely, but they at least shouldn't crash or behave erratically):
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. First, thank you for working on this issue 🙏 Pyright Tests:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thank you for the review! Added the various tests and some small fixes. The TypeGuard one surfaces a pre-existing bug, I can make a separate PR edit: already fixed by TypeIs PR. There's also some work to be done with overloads, but similarly gets into some pre-existing behaviour |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm this could refer to a type variable type, so I wonder if this can leak type variables, in the target function is a generic one, and one of the provided arguments can be used to bind the type variable.