diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index e8a2e501a452..4b0f5fe533d8 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -1229,14 +1229,14 @@ def apply_function_plugin( assert callback is not None # Assume that caller ensures this return callback( FunctionContext( - formal_arg_types, - formal_arg_kinds, - callee.arg_names, - formal_arg_names, - callee.ret_type, - formal_arg_exprs, - context, - self.chk, + arg_types=formal_arg_types, + arg_kinds=formal_arg_kinds, + callee_arg_names=callee.arg_names, + arg_names=formal_arg_names, + default_return_type=callee.ret_type, + args=formal_arg_exprs, + context=context, + api=self.chk, ) ) else: @@ -1246,15 +1246,15 @@ def apply_function_plugin( object_type = get_proper_type(object_type) return method_callback( MethodContext( - object_type, - formal_arg_types, - formal_arg_kinds, - callee.arg_names, - formal_arg_names, - callee.ret_type, - formal_arg_exprs, - context, - self.chk, + type=object_type, + arg_types=formal_arg_types, + arg_kinds=formal_arg_kinds, + callee_arg_names=callee.arg_names, + arg_names=formal_arg_names, + default_return_type=callee.ret_type, + args=formal_arg_exprs, + context=context, + api=self.chk, ) ) diff --git a/mypy/fixup.py b/mypy/fixup.py index 849a6483d724..f2b5bc17d32e 100644 --- a/mypy/fixup.py +++ b/mypy/fixup.py @@ -239,6 +239,9 @@ def visit_instance(self, inst: Instance) -> None: a.accept(self) if inst.last_known_value is not None: inst.last_known_value.accept(self) + if inst.extra_attrs: + for v in inst.extra_attrs.attrs.values(): + v.accept(self) def visit_type_alias_type(self, t: TypeAliasType) -> None: type_ref = t.type_ref diff --git a/mypy/plugins/default.py b/mypy/plugins/default.py index 170d3c85b5f9..3ad301a15f6c 100644 --- a/mypy/plugins/default.py +++ b/mypy/plugins/default.py @@ -47,6 +47,10 @@ def get_function_hook(self, fullname: str) -> Callable[[FunctionContext], Type] return ctypes.array_constructor_callback elif fullname == "functools.singledispatch": return singledispatch.create_singledispatch_function_callback + elif fullname == "functools.partial": + import mypy.plugins.functools + + return mypy.plugins.functools.partial_new_callback return None @@ -118,6 +122,10 @@ def get_method_hook(self, fullname: str) -> Callable[[MethodContext], Type] | No return singledispatch.singledispatch_register_callback elif fullname == singledispatch.REGISTER_CALLABLE_CALL_METHOD: return singledispatch.call_singledispatch_function_after_register_argument + elif fullname == "functools.partial.__call__": + import mypy.plugins.functools + + return mypy.plugins.functools.partial_call_callback return None def get_attribute_hook(self, fullname: str) -> Callable[[AttributeContext], Type] | None: @@ -155,12 +163,13 @@ def get_class_decorator_hook(self, fullname: str) -> Callable[[ClassDefContext], def get_class_decorator_hook_2( self, fullname: str ) -> Callable[[ClassDefContext], bool] | None: - from mypy.plugins import attrs, dataclasses, functools + import mypy.plugins.functools + from mypy.plugins import attrs, dataclasses if fullname in dataclasses.dataclass_makers: return dataclasses.dataclass_class_maker_callback - elif fullname in functools.functools_total_ordering_makers: - return functools.functools_total_ordering_maker_callback + elif fullname in mypy.plugins.functools.functools_total_ordering_makers: + return mypy.plugins.functools.functools_total_ordering_maker_callback elif fullname in attrs.attr_class_makers: return attrs.attr_class_maker_callback elif fullname in attrs.attr_dataclass_makers: diff --git a/mypy/plugins/functools.py b/mypy/plugins/functools.py index 792ed6669503..81a3b4d96ef3 100644 --- a/mypy/plugins/functools.py +++ b/mypy/plugins/functools.py @@ -4,10 +4,22 @@ from typing import Final, NamedTuple +import mypy.checker import mypy.plugin -from mypy.nodes import ARG_POS, ARG_STAR2, Argument, FuncItem, Var +from mypy.argmap import map_actuals_to_formals +from mypy.nodes import ARG_POS, ARG_STAR2, ArgKind, Argument, FuncItem, Var from mypy.plugins.common import add_method_to_class -from mypy.types import AnyType, CallableType, Type, TypeOfAny, UnboundType, get_proper_type +from mypy.types import ( + AnyType, + CallableType, + Instance, + Overloaded, + Type, + TypeOfAny, + UnboundType, + UninhabitedType, + get_proper_type, +) functools_total_ordering_makers: Final = {"functools.total_ordering"} @@ -102,3 +114,131 @@ def _analyze_class(ctx: mypy.plugin.ClassDefContext) -> dict[str, _MethodInfo | comparison_methods[name] = None return comparison_methods + + +def partial_new_callback(ctx: mypy.plugin.FunctionContext) -> Type: + """Infer a more precise return type for functools.partial""" + if not isinstance(ctx.api, mypy.checker.TypeChecker): # use internals + return ctx.default_return_type + if len(ctx.arg_types) != 3: # fn, *args, **kwargs + return ctx.default_return_type + if len(ctx.arg_types[0]) != 1: + return ctx.default_return_type + + if isinstance(get_proper_type(ctx.arg_types[0][0]), Overloaded): + # TODO: handle overloads, just fall back to whatever the non-plugin code does + return ctx.default_return_type + fn_type = ctx.api.extract_callable_type(ctx.arg_types[0][0], ctx=ctx.default_return_type) + if fn_type is None: + return ctx.default_return_type + + defaulted = fn_type.copy_modified( + arg_kinds=[ + ( + ArgKind.ARG_OPT + if k == ArgKind.ARG_POS + else (ArgKind.ARG_NAMED_OPT if k == ArgKind.ARG_NAMED else k) + ) + for k in fn_type.arg_kinds + ] + ) + if defaulted.line < 0: + # Make up a line number if we don't have one + defaulted.set_line(ctx.default_return_type) + + actual_args = [a for param in ctx.args[1:] for a in param] + actual_arg_kinds = [a for param in ctx.arg_kinds[1:] for a in param] + actual_arg_names = [a for param in ctx.arg_names[1:] for a in param] + actual_types = [a for param in ctx.arg_types[1:] for a in param] + + _, bound = ctx.api.expr_checker.check_call( + callee=defaulted, + args=actual_args, + arg_kinds=actual_arg_kinds, + arg_names=actual_arg_names, + context=defaulted, + ) + bound = get_proper_type(bound) + if not isinstance(bound, CallableType): + return ctx.default_return_type + + formal_to_actual = map_actuals_to_formals( + actual_kinds=actual_arg_kinds, + actual_names=actual_arg_names, + formal_kinds=fn_type.arg_kinds, + formal_names=fn_type.arg_names, + actual_arg_type=lambda i: actual_types[i], + ) + + partial_kinds = [] + partial_types = [] + partial_names = [] + # We need to fully apply any positional arguments (they cannot be respecified) + # However, keyword arguments can be respecified, so just give them a default + for i, actuals in enumerate(formal_to_actual): + if len(bound.arg_types) == len(fn_type.arg_types): + arg_type = bound.arg_types[i] + if isinstance(get_proper_type(arg_type), UninhabitedType): + arg_type = fn_type.arg_types[i] # bit of a hack + else: + # TODO: I assume that bound and fn_type have the same arguments. It appears this isn't + # true when PEP 646 things are happening. See testFunctoolsPartialTypeVarTuple + arg_type = fn_type.arg_types[i] + + if not actuals or fn_type.arg_kinds[i] in (ArgKind.ARG_STAR, ArgKind.ARG_STAR2): + partial_kinds.append(fn_type.arg_kinds[i]) + partial_types.append(arg_type) + partial_names.append(fn_type.arg_names[i]) + elif actuals: + if any(actual_arg_kinds[j] == ArgKind.ARG_POS for j in actuals): + continue + kind = actual_arg_kinds[actuals[0]] + if kind == ArgKind.ARG_NAMED: + kind = ArgKind.ARG_NAMED_OPT + partial_kinds.append(kind) + partial_types.append(arg_type) + partial_names.append(fn_type.arg_names[i]) + + ret_type = bound.ret_type + if isinstance(get_proper_type(ret_type), UninhabitedType): + ret_type = fn_type.ret_type # same kind of hack as above + + partially_applied = fn_type.copy_modified( + arg_types=partial_types, + arg_kinds=partial_kinds, + arg_names=partial_names, + ret_type=ret_type, + ) + + ret = ctx.api.named_generic_type("functools.partial", [ret_type]) + ret = ret.copy_with_extra_attr("__mypy_partial", partially_applied) + return ret + + +def partial_call_callback(ctx: mypy.plugin.MethodContext) -> Type: + """Infer a more precise return type for functools.partial.__call__.""" + if ( + not isinstance(ctx.api, mypy.checker.TypeChecker) # use internals + or not isinstance(ctx.type, Instance) + or ctx.type.type.fullname != "functools.partial" + or not ctx.type.extra_attrs + or "__mypy_partial" not in ctx.type.extra_attrs.attrs + ): + return ctx.default_return_type + + partial_type = ctx.type.extra_attrs.attrs["__mypy_partial"] + if len(ctx.arg_types) != 2: # *args, **kwargs + return ctx.default_return_type + + args = [a for param in ctx.args for a in param] + arg_kinds = [a for param in ctx.arg_kinds for a in param] + arg_names = [a for param in ctx.arg_names for a in param] + + result = ctx.api.expr_checker.check_call( + callee=partial_type, + args=args, + arg_kinds=arg_kinds, + arg_names=arg_names, + context=ctx.context, + ) + return result[0] diff --git a/mypy/server/astdiff.py b/mypy/server/astdiff.py index 5323bf2c57cb..f8a874005adb 100644 --- a/mypy/server/astdiff.py +++ b/mypy/server/astdiff.py @@ -378,11 +378,20 @@ def visit_deleted_type(self, typ: DeletedType) -> SnapshotItem: return snapshot_simple_type(typ) def visit_instance(self, typ: Instance) -> SnapshotItem: + extra_attrs: SnapshotItem + if typ.extra_attrs: + extra_attrs = ( + tuple(sorted((k, v.accept(self)) for k, v in typ.extra_attrs.attrs.items())), + tuple(typ.extra_attrs.immutable), + ) + else: + extra_attrs = () return ( "Instance", encode_optional_str(typ.type.fullname), snapshot_types(typ.args), ("None",) if typ.last_known_value is None else snapshot_type(typ.last_known_value), + extra_attrs, ) def visit_type_var(self, typ: TypeVarType) -> SnapshotItem: diff --git a/mypy/types.py b/mypy/types.py index 5573dc9efe0e..0ef3803c5687 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -1322,6 +1322,23 @@ def copy(self) -> ExtraAttrs: def __repr__(self) -> str: return f"ExtraAttrs({self.attrs!r}, {self.immutable!r}, {self.mod_name!r})" + def serialize(self) -> JsonDict: + return { + ".class": "ExtraAttrs", + "attrs": {k: v.serialize() for k, v in self.attrs.items()}, + "immutable": list(self.immutable), + "mod_name": self.mod_name, + } + + @classmethod + def deserialize(cls, data: JsonDict) -> ExtraAttrs: + assert data[".class"] == "ExtraAttrs" + return ExtraAttrs( + {k: deserialize_type(v) for k, v in data["attrs"].items()}, + set(data["immutable"]), + data["mod_name"], + ) + class Instance(ProperType): """An instance type of form C[T1, ..., Tn]. @@ -1434,6 +1451,7 @@ def serialize(self) -> JsonDict | str: data["args"] = [arg.serialize() for arg in self.args] if self.last_known_value is not None: data["last_known_value"] = self.last_known_value.serialize() + data["extra_attrs"] = self.extra_attrs.serialize() if self.extra_attrs else None return data @classmethod @@ -1452,6 +1470,8 @@ def deserialize(cls, data: JsonDict | str) -> Instance: inst.type_ref = data["type_ref"] # Will be fixed up by fixup.py later. if "last_known_value" in data: inst.last_known_value = LiteralType.deserialize(data["last_known_value"]) + if data.get("extra_attrs") is not None: + inst.extra_attrs = ExtraAttrs.deserialize(data["extra_attrs"]) return inst def copy_modified( @@ -1461,13 +1481,14 @@ def copy_modified( last_known_value: Bogus[LiteralType | None] = _dummy, ) -> Instance: new = Instance( - self.type, - args if args is not _dummy else self.args, - self.line, - self.column, + typ=self.type, + args=args if args is not _dummy else self.args, + line=self.line, + column=self.column, last_known_value=( last_known_value if last_known_value is not _dummy else self.last_known_value ), + extra_attrs=self.extra_attrs, ) # We intentionally don't copy the extra_attrs here, so they will be erased. new.can_be_true = self.can_be_true diff --git a/test-data/unit/check-functools.test b/test-data/unit/check-functools.test index e721a56850e1..5af5dfc8e469 100644 --- a/test-data/unit/check-functools.test +++ b/test-data/unit/check-functools.test @@ -144,3 +144,183 @@ def f(d: D[C]) -> None: d: D[int] # E: Type argument "int" of "D" must be a subtype of "C" [builtins fixtures/dict.pyi] + +[case testFunctoolsPartialBasic] +from typing import Callable +import functools + +def foo(a: int, b: str, c: int = 5) -> int: ... # N: "foo" defined here + +p1 = functools.partial(foo) +p1(1, "a", 3) # OK +p1(1, "a", c=3) # OK +p1(1, b="a", c=3) # OK + +reveal_type(p1) # N: Revealed type is "functools.partial[builtins.int]" + +def takes_callable_int(f: Callable[..., int]) -> None: ... +def takes_callable_str(f: Callable[..., str]) -> None: ... +takes_callable_int(p1) +takes_callable_str(p1) # E: Argument 1 to "takes_callable_str" has incompatible type "partial[int]"; expected "Callable[..., str]" \ + # N: "partial[int].__call__" has type "Callable[[VarArg(Any), KwArg(Any)], int]" + +p2 = functools.partial(foo, 1) +p2("a") # OK +p2("a", 3) # OK +p2("a", c=3) # OK +p2(1, 3) # E: Argument 1 to "foo" has incompatible type "int"; expected "str" +p2(1, "a", 3) # E: Too many arguments for "foo" \ + # E: Argument 1 to "foo" has incompatible type "int"; expected "str" \ + # E: Argument 2 to "foo" has incompatible type "str"; expected "int" +p2(a=1, b="a", c=3) # E: Unexpected keyword argument "a" for "foo" + +p3 = functools.partial(foo, b="a") +p3(1) # OK +p3(1, c=3) # OK +p3(a=1) # OK +p3(1, b="a", c=3) # OK, keywords can be clobbered +p3(1, 3) # E: Too many positional arguments for "foo" \ + # E: Argument 2 to "foo" has incompatible type "int"; expected "str" + +functools.partial(foo, "a") # E: Argument 1 to "foo" has incompatible type "str"; expected "int" +functools.partial(foo, b=1) # E: Argument 1 to "foo" has incompatible type "int"; expected "str" +functools.partial(foo, a=1, b=2, c=3) # E: Argument 2 to "foo" has incompatible type "int"; expected "str" +functools.partial(1) # E: "int" not callable \ + # E: Argument 1 to "partial" has incompatible type "int"; expected "Callable[..., Never]" +[builtins fixtures/dict.pyi] + +[case testFunctoolsPartialStar] +import functools + +def foo(a: int, b: str, *args: int, d: str, **kwargs: int) -> int: ... + +p1 = functools.partial(foo, 1, d="a", x=9) +p1("a", 2, 3, 4) # OK +p1("a", 2, 3, 4, d="a") # OK +p1("a", 2, 3, 4, "a") # E: Argument 5 to "foo" has incompatible type "str"; expected "int" +p1("a", 2, 3, 4, x="a") # E: Argument "x" to "foo" has incompatible type "str"; expected "int" + +p2 = functools.partial(foo, 1, "a") +p2(2, 3, 4, d="a") # OK +p2("a") # E: Missing named argument "d" for "foo" \ + # E: Argument 1 to "foo" has incompatible type "str"; expected "int" +p2(2, 3, 4) # E: Missing named argument "d" for "foo" + +functools.partial(foo, 1, "a", "b", "c", d="a") # E: Argument 3 to "foo" has incompatible type "str"; expected "int" \ + # E: Argument 4 to "foo" has incompatible type "str"; expected "int" + +def bar(*a: bytes, **k: int): + p1("a", 2, 3, 4, d="a", **k) + p1("a", d="a", **k) + p1("a", **k) # E: Argument 2 to "foo" has incompatible type "**Dict[str, int]"; expected "str" + p1(**k) # E: Argument 1 to "foo" has incompatible type "**Dict[str, int]"; expected "str" + p1(*a) # E: List or tuple expected as variadic arguments +[builtins fixtures/dict.pyi] + +[case testFunctoolsPartialGeneric] +from typing import TypeVar +import functools + +T = TypeVar("T") +U = TypeVar("U") + +def foo(a: T, b: T) -> T: ... + +p1 = functools.partial(foo, 1) +reveal_type(p1(2)) # N: Revealed type is "builtins.int" +p1("a") # E: Argument 1 to "foo" has incompatible type "str"; expected "int" + +p2 = functools.partial(foo, "a") +p2(1) # E: Argument 1 to "foo" has incompatible type "int"; expected "str" +reveal_type(p2("a")) # N: Revealed type is "builtins.str" + +def bar(a: T, b: U) -> U: ... + +p3 = functools.partial(bar, 1) +reveal_type(p3(2)) # N: Revealed type is "builtins.int" +reveal_type(p3("a")) # N: Revealed type is "builtins.str" +[builtins fixtures/dict.pyi] + +[case testFunctoolsPartialCallable] +from typing import Callable +import functools + +def main1(f: Callable[[int, str], int]) -> None: + p = functools.partial(f, 1) + p("a") # OK + p(1) # E: Argument 1 has incompatible type "int"; expected "str" + + functools.partial(f, a=1) # E: Unexpected keyword argument "a" + +class CallbackProto: + def __call__(self, a: int, b: str) -> int: ... + +def main2(f: CallbackProto) -> None: + p = functools.partial(f, b="a") + p(1) # OK + p("a") # E: Argument 1 to "__call__" of "CallbackProto" has incompatible type "str"; expected "int" +[builtins fixtures/dict.pyi] + +[case testFunctoolsPartialOverload] +from typing import overload +import functools + +@overload +def foo(a: int, b: str) -> int: ... +@overload +def foo(a: str, b: int) -> str: ... +def foo(*a, **k): ... + +p1 = functools.partial(foo) +reveal_type(p1(1, "a")) # N: Revealed type is "builtins.int" +reveal_type(p1("a", 1)) # N: Revealed type is "builtins.int" +p1(1, 2) # TODO: false negative +p1("a", "b") # TODO: false negative +[builtins fixtures/dict.pyi] + +[case testFunctoolsPartialTypeGuard] +import functools +from typing_extensions import TypeGuard + +def is_str_list(val: list[object]) -> TypeGuard[list[str]]: ... # E: "list" is not subscriptable, use "typing.List" instead + +reveal_type(functools.partial(is_str_list, [1, 2, 3])) # N: Revealed type is "functools.partial[builtins.bool]" +reveal_type(functools.partial(is_str_list, [1, 2, 3])()) # N: Revealed type is "builtins.bool" +[builtins fixtures/dict.pyi] + +[case testFunctoolsPartialType] +import functools +from typing import Type + +class A: + def __init__(self, a: int, b: str) -> None: ... # N: "A" defined here + +p = functools.partial(A, 1) +reveal_type(p) # N: Revealed type is "functools.partial[__main__.A]" + +p("a") # OK +p(1) # E: Argument 1 to "A" has incompatible type "int"; expected "str" +p(z=1) # E: Unexpected keyword argument "z" for "A" + +def main(t: Type[A]) -> None: + p = functools.partial(t, 1) # E: "Type[A]" not callable + reveal_type(p) # N: Revealed type is "functools.partial[__main__.A]" + + p("a") # OK + p(1) # False negative + p(z=1) # False negative + +[builtins fixtures/dict.pyi] + +[case testFunctoolsPartialTypeVarTuple] +import functools +import typing +Ts = typing.TypeVarTuple("Ts") +def foo(fn: typing.Callable[[typing.Unpack[Ts]], None], /, *arg: typing.Unpack[Ts], kwarg: str) -> None: ... +p = functools.partial(foo, kwarg="asdf") + +def bar(a: int, b: str, c: float) -> None: ... +p(bar, 1, "a", 3.0) # OK +p(bar, 1, "a", 3.0, kwarg="asdf") # OK +p(bar, 1, "a", "b") # E: Argument 1 to "foo" has incompatible type "Callable[[int, str, float], None]"; expected "Callable[[int, str, str], None]" +[builtins fixtures/dict.pyi] diff --git a/test-data/unit/check-incremental.test b/test-data/unit/check-incremental.test index a7f4fafc579e..ead896b8e458 100644 --- a/test-data/unit/check-incremental.test +++ b/test-data/unit/check-incremental.test @@ -6575,6 +6575,67 @@ class TheClass: [out2] tmp/a.py:3: note: Revealed type is "def (value: builtins.object) -> lib.TheClass.pyenum@6" + +[case testIncrementalFunctoolsPartial] +import a + +[file a.py] +from typing import Callable +from partial import p1, p2 + +p1(1, "a", 3) # OK +p1(1, "a", c=3) # OK +p1(1, b="a", c=3) # OK + +reveal_type(p1) + +def takes_callable_int(f: Callable[..., int]) -> None: ... +def takes_callable_str(f: Callable[..., str]) -> None: ... +takes_callable_int(p1) +takes_callable_str(p1) + +p2("a") # OK +p2("a", 3) # OK +p2("a", c=3) # OK +p2(1, 3) +p2(1, "a", 3) +p2(a=1, b="a", c=3) + +[file a.py.2] +from typing import Callable +from partial import p3 + +p3(1) # OK +p3(1, c=3) # OK +p3(a=1) # OK +p3(1, b="a", c=3) # OK, keywords can be clobbered +p3(1, 3) + +[file partial.py] +from typing import Callable +import functools + +def foo(a: int, b: str, c: int = 5) -> int: ... + +p1 = functools.partial(foo) +p2 = functools.partial(foo, 1) +p3 = functools.partial(foo, b="a") +[builtins fixtures/dict.pyi] +[out] +tmp/a.py:8: note: Revealed type is "functools.partial[builtins.int]" +tmp/a.py:13: error: Argument 1 to "takes_callable_str" has incompatible type "partial[int]"; expected "Callable[..., str]" +tmp/a.py:13: note: "partial[int].__call__" has type "Callable[[VarArg(Any), KwArg(Any)], int]" +tmp/a.py:18: error: Argument 1 to "foo" has incompatible type "int"; expected "str" +tmp/a.py:19: error: Too many arguments for "foo" +tmp/a.py:19: error: Argument 1 to "foo" has incompatible type "int"; expected "str" +tmp/a.py:19: error: Argument 2 to "foo" has incompatible type "str"; expected "int" +tmp/a.py:20: error: Unexpected keyword argument "a" for "foo" +tmp/partial.py:4: note: "foo" defined here +[out2] +tmp/a.py:8: error: Too many positional arguments for "foo" +tmp/a.py:8: error: Argument 2 to "foo" has incompatible type "int"; expected "str" + + [case testStartUsingTypeGuard] import a [file a.py] diff --git a/test-data/unit/lib-stub/functools.pyi b/test-data/unit/lib-stub/functools.pyi index e665b2bad0c2..b8d47e1da2b5 100644 --- a/test-data/unit/lib-stub/functools.pyi +++ b/test-data/unit/lib-stub/functools.pyi @@ -1,4 +1,4 @@ -from typing import Generic, TypeVar, Callable, Any, Mapping, overload +from typing import Generic, TypeVar, Callable, Any, Mapping, Self, overload _T = TypeVar("_T") @@ -33,3 +33,7 @@ class cached_property(Generic[_T]): def __get__(self, instance: object, owner: type[Any] | None = ...) -> _T: ... def __set_name__(self, owner: type[Any], name: str) -> None: ... def __class_getitem__(cls, item: Any) -> Any: ... + +class partial(Generic[_T]): + def __new__(cls, __func: Callable[..., _T], *args: Any, **kwargs: Any) -> Self: ... + def __call__(__self, *args: Any, **kwargs: Any) -> _T: ...