Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for functools.partial #16939

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
34 changes: 17 additions & 17 deletions mypy/checkexpr.py
Expand Up @@ -1216,14 +1216,14 @@ def apply_function_plugin(
assert callback is not None # Assume that caller ensures this
return callback(
FunctionContext(
formal_arg_types,
formal_arg_kinds,
callee.arg_names,
formal_arg_names,
callee.ret_type,
formal_arg_exprs,
context,
self.chk,
arg_types=formal_arg_types,
arg_kinds=formal_arg_kinds,
callee_arg_names=callee.arg_names,
arg_names=formal_arg_names,
default_return_type=callee.ret_type,
args=formal_arg_exprs,
context=context,
api=self.chk,
)
)
else:
Expand All @@ -1233,15 +1233,15 @@ def apply_function_plugin(
object_type = get_proper_type(object_type)
return method_callback(
MethodContext(
object_type,
formal_arg_types,
formal_arg_kinds,
callee.arg_names,
formal_arg_names,
callee.ret_type,
formal_arg_exprs,
context,
self.chk,
type=object_type,
arg_types=formal_arg_types,
arg_kinds=formal_arg_kinds,
callee_arg_names=callee.arg_names,
arg_names=formal_arg_names,
default_return_type=callee.ret_type,
args=formal_arg_exprs,
context=context,
api=self.chk,
)
)

Expand Down
15 changes: 12 additions & 3 deletions mypy/plugins/default.py
Expand Up @@ -47,6 +47,10 @@ def get_function_hook(self, fullname: str) -> Callable[[FunctionContext], Type]
return ctypes.array_constructor_callback
elif fullname == "functools.singledispatch":
return singledispatch.create_singledispatch_function_callback
elif fullname == "functools.partial":
import mypy.plugins.functools

return mypy.plugins.functools.partial_new_callback

return None

Expand Down Expand Up @@ -118,6 +122,10 @@ def get_method_hook(self, fullname: str) -> Callable[[MethodContext], Type] | No
return singledispatch.singledispatch_register_callback
elif fullname == singledispatch.REGISTER_CALLABLE_CALL_METHOD:
return singledispatch.call_singledispatch_function_after_register_argument
elif fullname == "functools.partial.__call__":
import mypy.plugins.functools

return mypy.plugins.functools.partial_call_callback
return None

def get_attribute_hook(self, fullname: str) -> Callable[[AttributeContext], Type] | None:
Expand Down Expand Up @@ -155,12 +163,13 @@ def get_class_decorator_hook(self, fullname: str) -> Callable[[ClassDefContext],
def get_class_decorator_hook_2(
self, fullname: str
) -> Callable[[ClassDefContext], bool] | None:
from mypy.plugins import attrs, dataclasses, functools
import mypy.plugins.functools
from mypy.plugins import attrs, dataclasses

if fullname in dataclasses.dataclass_makers:
return dataclasses.dataclass_class_maker_callback
elif fullname in functools.functools_total_ordering_makers:
return functools.functools_total_ordering_maker_callback
elif fullname in mypy.plugins.functools.functools_total_ordering_makers:
return mypy.plugins.functools.functools_total_ordering_maker_callback
elif fullname in attrs.attr_class_makers:
return attrs.attr_class_maker_callback
elif fullname in attrs.attr_dataclass_makers:
Expand Down
137 changes: 135 additions & 2 deletions mypy/plugins/functools.py
Expand Up @@ -4,10 +4,21 @@

from typing import Final, NamedTuple

import mypy.checker
import mypy.plugin
from mypy.nodes import ARG_POS, ARG_STAR2, Argument, FuncItem, Var
from mypy.argmap import map_actuals_to_formals
from mypy.nodes import ARG_POS, ARG_STAR2, ArgKind, Argument, FuncItem, Var
from mypy.plugins.common import add_method_to_class
from mypy.types import AnyType, CallableType, Type, TypeOfAny, UnboundType, get_proper_type
from mypy.types import (
AnyType,
CallableType,
Instance,
Type,
TypeOfAny,
UnboundType,
UninhabitedType,
get_proper_type,
)

functools_total_ordering_makers: Final = {"functools.total_ordering"}

Expand Down Expand Up @@ -102,3 +113,125 @@ def _analyze_class(ctx: mypy.plugin.ClassDefContext) -> dict[str, _MethodInfo |
comparison_methods[name] = None

return comparison_methods


def partial_new_callback(ctx: mypy.plugin.FunctionContext) -> Type:
"""Infer a more precise return type for functools.partial"""
if not isinstance(ctx.api, mypy.checker.TypeChecker): # use internals
return ctx.default_return_type
if len(ctx.arg_types) != 3: # fn, *args, **kwargs
return ctx.default_return_type
if len(ctx.arg_types[0]) != 1:
return ctx.default_return_type

fn_type = get_proper_type(ctx.arg_types[0][0])
if not isinstance(fn_type, CallableType):
return ctx.default_return_type

defaulted = fn_type.copy_modified(
arg_kinds=[
(
ArgKind.ARG_OPT
if k == ArgKind.ARG_POS
else (ArgKind.ARG_NAMED_OPT if k == ArgKind.ARG_NAMED else k)
)
for k in fn_type.arg_kinds
]
)

actual_args = [a for param in ctx.args[1:] for a in param]
actual_arg_kinds = [a for param in ctx.arg_kinds[1:] for a in param]
actual_arg_names = [a for param in ctx.arg_names[1:] for a in param]
actual_types = [a for param in ctx.arg_types[1:] for a in param]

_, bound = ctx.api.expr_checker.check_call(
callee=defaulted,
args=actual_args,
arg_kinds=actual_arg_kinds,
arg_names=actual_arg_names,
context=defaulted,
)
bound = get_proper_type(bound)
if not isinstance(bound, CallableType):
return ctx.default_return_type

formal_to_actual = map_actuals_to_formals(
actual_kinds=actual_arg_kinds,
actual_names=actual_arg_names,
formal_kinds=fn_type.arg_kinds,
formal_names=fn_type.arg_names,
actual_arg_type=lambda i: actual_types[i],
)

partial_kinds = []
partial_types = []
partial_names = []
# We need to fully apply any positional arguments (they cannot be respecified)
# However, keyword arguments can be respecified, so just give them a default
for i, actuals in enumerate(formal_to_actual):
if len(bound.arg_types) == len(fn_type.arg_types):
arg_type = bound.arg_types[i]
if isinstance(get_proper_type(arg_type), UninhabitedType):
arg_type = fn_type.arg_types[i] # bit of a hack
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm this could refer to a type variable type, so I wonder if this can leak type variables, in the target function is a generic one, and one of the provided arguments can be used to bind the type variable.

else:
# TODO: I assume that bound and fn_type have the same arguments. It appears this isn't
# true when PEP 646 things are happening. See testFunctoolsPartialTypeVarTuple
arg_type = fn_type.arg_types[i]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar to above.


if not actuals or fn_type.arg_kinds[i] in (ArgKind.ARG_STAR, ArgKind.ARG_STAR2):
partial_kinds.append(fn_type.arg_kinds[i])
partial_types.append(arg_type)
partial_names.append(fn_type.arg_names[i])
elif actuals:
if any(actual_arg_kinds[j] == ArgKind.ARG_POS for j in actuals):
continue
kind = actual_arg_kinds[actuals[0]]
if kind == ArgKind.ARG_NAMED:
kind = ArgKind.ARG_NAMED_OPT
partial_kinds.append(kind)
partial_types.append(arg_type)
partial_names.append(fn_type.arg_names[i])

ret_type = bound.ret_type
if isinstance(get_proper_type(ret_type), UninhabitedType):
ret_type = fn_type.ret_type # same kind of hack as above
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar to above -- it seems that this might leak type variables. If that is the case, it would probably be better to fall back to the default return type.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this doesn't leak type variables, because partially_applied has the same variables as fn_type (from copy_modified). partially_applied just remains generic. See the test cases in testFunctoolsPartialGeneric. I can add some more comments to the code

But let me know if I'm off the mark!

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you are right, and this is good already.


partially_applied = fn_type.copy_modified(
arg_types=partial_types,
arg_kinds=partial_kinds,
arg_names=partial_names,
ret_type=ret_type,
)

ret = ctx.api.named_generic_type("functools.partial", [ret_type])
ret = ret.copy_with_extra_attr("__mypy_partial", partially_applied)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

extra_attrs is currently a bit poorly thought out, so this will cause some problems. First, it isn't serialized, so storing the additional information there breaks in incremental mode. For example, consider a module like this:

...
p = partial(f, 1)

Now if we import p from another module, the second run which uses serialized data will produce a different type for p compared to the first run.

I'm not sure what is the best way to fix this. Probably the simplest option would be to serialize extra_attrs. It seems that everything we put there can be serialized -- it just hasn't been implemented.

Also it would be good to have an incremental mode test case.

It looks like mypy daemon doesn't keep track of extra_attrs. mypy.server.astdiff should look into extra_attrs to detect changes in extra_attrs, as these may need to be propagated.

return ret


def partial_call_callback(ctx: mypy.plugin.MethodContext) -> Type:
"""Infer a more precise return type for functools.partial.__call__."""
if (
not isinstance(ctx.api, mypy.checker.TypeChecker) # use internals
or not isinstance(ctx.type, Instance)
or ctx.type.type.fullname != "functools.partial"
or not ctx.type.extra_attrs
or "__mypy_partial" not in ctx.type.extra_attrs.attrs
):
return ctx.default_return_type

partial_type = ctx.type.extra_attrs.attrs["__mypy_partial"]
if len(ctx.arg_types) != 2: # *args, **kwargs
return ctx.default_return_type

args = [a for param in ctx.args for a in param]
arg_kinds = [a for param in ctx.arg_kinds for a in param]
arg_names = [a for param in ctx.arg_names for a in param]

result = ctx.api.expr_checker.check_call(
callee=partial_type,
args=args,
arg_kinds=arg_kinds,
arg_names=arg_names,
context=ctx.context,
)
return result[0]
9 changes: 5 additions & 4 deletions mypy/types.py
Expand Up @@ -1495,13 +1495,14 @@ def copy_modified(
last_known_value: Bogus[LiteralType | None] = _dummy,
) -> Instance:
new = Instance(
self.type,
args if args is not _dummy else self.args,
self.line,
self.column,
typ=self.type,
args=args if args is not _dummy else self.args,
line=self.line,
column=self.column,
last_known_value=(
last_known_value if last_known_value is not _dummy else self.last_known_value
),
extra_attrs=self.extra_attrs,
)
# We intentionally don't copy the extra_attrs here, so they will be erased.
new.can_be_true = self.can_be_true
Expand Down
96 changes: 96 additions & 0 deletions test-data/unit/check-functools.test
Expand Up @@ -144,3 +144,99 @@ def f(d: D[C]) -> None:

d: D[int] # E: Type argument "int" of "D" must be a subtype of "C"
[builtins fixtures/dict.pyi]

[case testFunctoolsPartialBasic]
from typing import Callable
import functools

def foo(a: int, b: str, c: int = 5) -> int: ... # N: "foo" defined here

p1 = functools.partial(foo)
p1(1, "a", 3) # OK
p1(1, "a", c=3) # OK
p1(1, b="a", c=3) # OK
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about using reveal_type(p1) etc. in tests?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added one of these, it's not very useful because it'll just show you partial[return_type]... would need a custom mypy type / some very involved generics logic / non-standard features to make it show up in the reveal

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah makes sense. The current approach seems fine then.


def takes_callable_int(f: Callable[..., int]) -> None: ...
def takes_callable_str(f: Callable[..., str]) -> None: ...
takes_callable_int(p1)
takes_callable_str(p1) # E: Argument 1 to "takes_callable_str" has incompatible type "partial[int]"; expected "Callable[..., str]" \
# N: "partial[int].__call__" has type "Callable[[VarArg(Any), KwArg(Any)], int]"

p2 = functools.partial(foo, 1)
p2("a") # OK
p2("a", 3) # OK
p2("a", c=3) # OK
p2(1, 3) # E: Argument 1 to "foo" has incompatible type "int"; expected "str"
p2(1, "a", 3) # E: Too many arguments for "foo" \
# E: Argument 1 to "foo" has incompatible type "int"; expected "str" \
# E: Argument 2 to "foo" has incompatible type "str"; expected "int"
p2(a=1, b="a", c=3) # E: Unexpected keyword argument "a" for "foo"

p3 = functools.partial(foo, b="a")
p3(1) # OK
p3(1, c=3) # OK
p3(a=1) # OK
p3(1, b="a", c=3) # OK, keywords can be clobbered
p3(1, 3) # E: Too many positional arguments for "foo" \
# E: Argument 2 to "foo" has incompatible type "int"; expected "str"

functools.partial(foo, "a") # E: Argument 1 to "foo" has incompatible type "str"; expected "int"
functools.partial(foo, b=1) # E: Argument 1 to "foo" has incompatible type "int"; expected "str"
functools.partial(foo, a=1, b=2, c=3) # E: Argument 2 to "foo" has incompatible type "int"; expected "str"
functools.partial(1) # E: Argument 1 to "partial" has incompatible type "int"; expected "Callable[..., Never]"
[builtins fixtures/dict.pyi]

[case testFunctoolsPartialStar]
import functools

def foo(a: int, b: str, *args: int, d: str, **kwargs: int) -> int: ...

p1 = functools.partial(foo, 1, d="a", x=9)
p1("a", 2, 3, 4) # OK
p1("a", 2, 3, 4, d="a") # OK
p1("a", 2, 3, 4, "a") # E: Argument 5 to "foo" has incompatible type "str"; expected "int"
p1("a", 2, 3, 4, x="a") # E: Argument "x" to "foo" has incompatible type "str"; expected "int"

p2 = functools.partial(foo, 1, "a")
p2(2, 3, 4, d="a") # OK
p2("a") # E: Missing named argument "d" for "foo" \
# E: Argument 1 to "foo" has incompatible type "str"; expected "int"
p2(2, 3, 4) # E: Missing named argument "d" for "foo"

functools.partial(foo, 1, "a", "b", "c", d="a") # E: Argument 3 to "foo" has incompatible type "str"; expected "int" \
# E: Argument 4 to "foo" has incompatible type "str"; expected "int"

[builtins fixtures/dict.pyi]

[case testFunctoolsPartialGeneric]
from typing import TypeVar
import functools

T = TypeVar("T")
U = TypeVar("U")

def foo(a: T, b: T) -> T: ...

p1 = functools.partial(foo, 1)
reveal_type(p1(2)) # N: Revealed type is "builtins.int"
p1("a") # E: Argument 1 to "foo" has incompatible type "str"; expected "int"

p2 = functools.partial(foo, "a")
p2(1) # E: Argument 1 to "foo" has incompatible type "int"; expected "str"
reveal_type(p2("a")) # N: Revealed type is "builtins.str"

def bar(a: T, b: U) -> U: ...

p3 = functools.partial(bar, 1)
reveal_type(p3(2)) # N: Revealed type is "builtins.int"
reveal_type(p3("a")) # N: Revealed type is "builtins.str"
[builtins fixtures/dict.pyi]

[case testFunctoolsPartialTypeVarTuple]
import functools
import typing
Ts = typing.TypeVarTuple("Ts")
def foo(fn: typing.Callable[[typing.Unpack[Ts]], None], /, *arg: typing.Unpack[Ts], kwarg: str) -> None: ...
# just test that this doesn't crash
functools.partial(foo, kwarg="asdf")
[builtins fixtures/dict.pyi]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ideas for test cases (it's not important to support all of these cases precisely, but they at least shouldn't crash or behave erratically):

  • Test with a value of type[x] as an argument to partial.
  • Test with a class reference as an argument to partial (e.g. partial(Foo)). This is different from the above.
  • Test calling partial with *args and/or **kwargs.
  • Test passing Callable[..., Foo] to partial (with explicit ...).
  • Test passing a type guard to partial.
  • Test passing an instance that has a __call__ method to partial.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the review! Added the various tests and some small fixes. The TypeGuard one surfaces a pre-existing bug, I can make a separate PR edit: already fixed by TypeIs PR. There's also some work to be done with overloads, but similarly gets into some pre-existing behaviour

6 changes: 5 additions & 1 deletion test-data/unit/lib-stub/functools.pyi
@@ -1,4 +1,4 @@
from typing import Generic, TypeVar, Callable, Any, Mapping, overload
from typing import Generic, TypeVar, Callable, Any, Mapping, Self, overload

_T = TypeVar("_T")

Expand Down Expand Up @@ -33,3 +33,7 @@ class cached_property(Generic[_T]):
def __get__(self, instance: object, owner: type[Any] | None = ...) -> _T: ...
def __set_name__(self, owner: type[Any], name: str) -> None: ...
def __class_getitem__(cls, item: Any) -> Any: ...

class partial(Generic[_T]):
def __new__(cls, __func: Callable[..., _T], *args: Any, **kwargs: Any) -> Self: ...
def __call__(__self, *args: Any, **kwargs: Any) -> _T: ...