From 0df8cf532918f888610a5afd7bb88192712de984 Mon Sep 17 00:00:00 2001 From: Jelle Zijlstra Date: Sat, 16 Apr 2022 13:48:20 -0700 Subject: [PATCH] Support typing_extensions.overload (#12602) This always existed in typing_extensions, but was an alias for typing.overload. With python/typing#1140, it will actually make a difference at runtime which one you use. Note that this shouldn't change mypy's behaviour, since we alias typing_extensions.overload to typing.overload in typeshed, but this makes the logic less fragile. --- mypy/checker.py | 5 +- mypy/semanal.py | 4 +- mypy/stubgen.py | 23 +++--- mypy/stubtest.py | 3 +- mypy/types.py | 5 ++ test-data/unit/check-overloading.test | 18 +++++ test-data/unit/lib-stub/typing_extensions.pyi | 2 +- test-data/unit/stubgen.test | 79 ++++++++++++++++++- 8 files changed, 122 insertions(+), 17 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index e53e306a7e5d..24f101421ff4 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -37,7 +37,8 @@ UnionType, TypeVarId, TypeVarType, PartialType, DeletedType, UninhabitedType, is_named_instance, union_items, TypeQuery, LiteralType, is_optional, remove_optional, TypeTranslator, StarType, get_proper_type, ProperType, - get_proper_types, is_literal_type, TypeAliasType, TypeGuardedType, ParamSpecType + get_proper_types, is_literal_type, TypeAliasType, TypeGuardedType, ParamSpecType, + OVERLOAD_NAMES, ) from mypy.sametypes import is_same_type from mypy.messages import ( @@ -3981,7 +3982,7 @@ def visit_decorator(self, e: Decorator) -> None: # may be different from the declared signature. sig: Type = self.function_type(e.func) for d in reversed(e.decorators): - if refers_to_fullname(d, 'typing.overload'): + if refers_to_fullname(d, OVERLOAD_NAMES): self.fail(message_registry.MULTIPLE_OVERLOADS_REQUIRED, e) continue dec = self.expr_checker.accept(d) diff --git a/mypy/semanal.py b/mypy/semanal.py index 3ffc20cead77..1ec37309ce8e 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -99,7 +99,7 @@ TypeTranslator, TypeOfAny, TypeType, NoneType, PlaceholderType, TPDICT_NAMES, ProperType, get_proper_type, get_proper_types, TypeAliasType, TypeVarLikeType, Parameters, ParamSpecType, PROTOCOL_NAMES, TYPE_ALIAS_NAMES, FINAL_TYPE_NAMES, FINAL_DECORATOR_NAMES, REVEAL_TYPE_NAMES, - ASSERT_TYPE_NAMES, is_named_instance, + ASSERT_TYPE_NAMES, OVERLOAD_NAMES, is_named_instance, ) from mypy.typeops import function_type, get_type_vars from mypy.type_visitor import TypeQuery @@ -835,7 +835,7 @@ def analyze_overload_sigs_and_impl( if isinstance(item, Decorator): callable = function_type(item.func, self.named_type('builtins.function')) assert isinstance(callable, CallableType) - if not any(refers_to_fullname(dec, 'typing.overload') + if not any(refers_to_fullname(dec, OVERLOAD_NAMES) for dec in item.decorators): if i == len(defn.items) - 1 and not self.is_stub_file: # Last item outside a stub is impl diff --git a/mypy/stubgen.py b/mypy/stubgen.py index 5d8e6a57c212..eade0bbdc363 100755 --- a/mypy/stubgen.py +++ b/mypy/stubgen.py @@ -54,7 +54,7 @@ from collections import defaultdict from typing import ( - List, Dict, Tuple, Iterable, Mapping, Optional, Set, cast, + List, Dict, Tuple, Iterable, Mapping, Optional, Set, Union, cast, ) from typing_extensions import Final @@ -84,7 +84,7 @@ from mypy.options import Options as MypyOptions from mypy.types import ( Type, TypeStrVisitor, CallableType, UnboundType, NoneType, TupleType, TypeList, Instance, - AnyType, get_proper_type + AnyType, get_proper_type, OVERLOAD_NAMES ) from mypy.visitor import NodeVisitor from mypy.find_sources import create_source_list, InvalidSourceList @@ -93,6 +93,10 @@ from mypy.traverser import all_yield_expressions, has_return_statement, has_yield_expression from mypy.moduleinspect import ModuleInspect +TYPING_MODULE_NAMES: Final = ( + 'typing', + 'typing_extensions', +) # Common ways of naming package containing vendored modules. VENDOR_PACKAGES: Final = [ @@ -768,13 +772,15 @@ def process_name_expr_decorator(self, expr: NameExpr, context: Decorator) -> Tup self.add_decorator('property') self.add_decorator('abc.abstractmethod') is_abstract = True - elif self.refers_to_fullname(name, 'typing.overload'): + elif self.refers_to_fullname(name, OVERLOAD_NAMES): self.add_decorator(name) self.add_typing_import('overload') is_overload = True return is_abstract, is_overload - def refers_to_fullname(self, name: str, fullname: str) -> bool: + def refers_to_fullname(self, name: str, fullname: Union[str, Tuple[str, ...]]) -> bool: + if isinstance(fullname, tuple): + return any(self.refers_to_fullname(name, fname) for fname in fullname) module, short = fullname.rsplit('.', 1) return (self.import_tracker.module_for.get(name) == module and (name == short or @@ -825,8 +831,8 @@ def process_member_expr_decorator(self, expr: MemberExpr, context: Decorator) -> expr.expr.name + '.coroutine', expr.expr.name) elif (isinstance(expr.expr, NameExpr) and - (expr.expr.name == 'typing' or - self.import_tracker.reverse_alias.get(expr.expr.name) == 'typing') and + (expr.expr.name in TYPING_MODULE_NAMES or + self.import_tracker.reverse_alias.get(expr.expr.name) in TYPING_MODULE_NAMES) and expr.name == 'overload'): self.import_tracker.require_name(expr.expr.name) self.add_decorator('%s.%s' % (expr.expr.name, 'overload')) @@ -1060,7 +1066,7 @@ def visit_import_from(self, o: ImportFrom) -> None: and name not in self.referenced_names and (not self._all_ or name in IGNORED_DUNDERS) and not is_private - and module not in ('abc', 'typing', 'asyncio')): + and module not in ('abc', 'asyncio') + TYPING_MODULE_NAMES): # An imported name that is never referenced in the module is assumed to be # exported, unless there is an explicit __all__. Note that we need to special # case 'abc' since some references are deleted during semantic analysis. @@ -1118,8 +1124,7 @@ def get_init(self, lvalue: str, rvalue: Expression, typename = self.print_annotation(annotation) if (isinstance(annotation, UnboundType) and not annotation.args and annotation.name == 'Final' and - self.import_tracker.module_for.get('Final') in ('typing', - 'typing_extensions')): + self.import_tracker.module_for.get('Final') in TYPING_MODULE_NAMES): # Final without type argument is invalid in stubs. final_arg = self.get_str_type_of_node(rvalue) typename += '[{}]'.format(final_arg) diff --git a/mypy/stubtest.py b/mypy/stubtest.py index 582c467ee2b0..7fa0f5937f99 100644 --- a/mypy/stubtest.py +++ b/mypy/stubtest.py @@ -912,9 +912,8 @@ def apply_decorator_to_funcitem( return None if decorator.fullname in ( "builtins.staticmethod", - "typing.overload", "abc.abstractmethod", - ): + ) or decorator.fullname in mypy.types.OVERLOAD_NAMES: return func if decorator.fullname == "builtins.classmethod": assert func.arguments[0].variable.name in ("cls", "metacls") diff --git a/mypy/types.py b/mypy/types.py index 1d0274f38330..213d8de7d8bb 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -137,6 +137,11 @@ 'typing_extensions.assert_type', ) +OVERLOAD_NAMES: Final = ( + 'typing.overload', + 'typing_extensions.overload', +) + # Attributes that can optionally be defined in the body of a subclass of # enum.Enum but are removed from the class __dict__ by EnumMeta. ENUM_REMOVED_PROPS: Final = ( diff --git a/test-data/unit/check-overloading.test b/test-data/unit/check-overloading.test index 0376b62ab202..e2a87ea62a92 100644 --- a/test-data/unit/check-overloading.test +++ b/test-data/unit/check-overloading.test @@ -40,6 +40,24 @@ class A: pass class B: pass [builtins fixtures/isinstance.pyi] +[case testTypingExtensionsOverload] +from typing import Any +from typing_extensions import overload +@overload +def f(x: 'A') -> 'B': ... +@overload +def f(x: 'B') -> 'A': ... + +def f(x: Any) -> Any: + pass + +reveal_type(f(A())) # N: Revealed type is "__main__.B" +reveal_type(f(B())) # N: Revealed type is "__main__.A" + +class A: pass +class B: pass +[builtins fixtures/isinstance.pyi] + [case testOverloadNeedsImplementation] from typing import overload, Any @overload # E: An overloaded function outside a stub file must have an implementation diff --git a/test-data/unit/lib-stub/typing_extensions.pyi b/test-data/unit/lib-stub/typing_extensions.pyi index 7ad334d6a24e..95f45f3b8947 100644 --- a/test-data/unit/lib-stub/typing_extensions.pyi +++ b/test-data/unit/lib-stub/typing_extensions.pyi @@ -1,6 +1,6 @@ from typing import TypeVar, Any, Mapping, Iterator, NoReturn as NoReturn, Dict, Type from typing import TYPE_CHECKING as TYPE_CHECKING -from typing import NewType as NewType +from typing import NewType as NewType, overload as overload import sys diff --git a/test-data/unit/stubgen.test b/test-data/unit/stubgen.test index cbce46b35605..927cc5617c75 100644 --- a/test-data/unit/stubgen.test +++ b/test-data/unit/stubgen.test @@ -2461,6 +2461,50 @@ class A: def f(self, x: Tuple[int, int]) -> int: ... +@overload +def f(x: int, y: int) -> int: ... +@overload +def f(x: Tuple[int, int]) -> int: ... + +[case testOverload_fromTypingExtensionsImport] +from typing import Tuple, Union +from typing_extensions import overload + +class A: + @overload + def f(self, x: int, y: int) -> int: + ... + + @overload + def f(self, x: Tuple[int, int]) -> int: + ... + + def f(self, *args: Union[int, Tuple[int, int]]) -> int: + pass + +@overload +def f(x: int, y: int) -> int: + ... + +@overload +def f(x: Tuple[int, int]) -> int: + ... + +def f(*args: Union[int, Tuple[int, int]]) -> int: + pass + + +[out] +from typing import Tuple +from typing_extensions import overload + +class A: + @overload + def f(self, x: int, y: int) -> int: ... + @overload + def f(self, x: Tuple[int, int]) -> int: ... + + @overload def f(x: int, y: int) -> int: ... @overload @@ -2468,6 +2512,7 @@ def f(x: Tuple[int, int]) -> int: ... [case testOverload_importTyping] import typing +import typing_extensions class A: @typing.overload @@ -2506,9 +2551,21 @@ def f(x: typing.Tuple[int, int]) -> int: def f(*args: typing.Union[int, typing.Tuple[int, int]]) -> int: pass +@typing_extensions.overload +def g(x: int, y: int) -> int: + ... + +@typing_extensions.overload +def g(x: typing.Tuple[int, int]) -> int: + ... + +def g(*args: typing.Union[int, typing.Tuple[int, int]]) -> int: + pass + [out] import typing +import typing_extensions class A: @typing.overload @@ -2527,10 +2584,14 @@ class A: def f(x: int, y: int) -> int: ... @typing.overload def f(x: typing.Tuple[int, int]) -> int: ... - +@typing_extensions.overload +def g(x: int, y: int) -> int: ... +@typing_extensions.overload +def g(x: typing.Tuple[int, int]) -> int: ... [case testOverload_importTypingAs] import typing as t +import typing_extensions as te class A: @t.overload @@ -2570,8 +2631,20 @@ def f(*args: t.Union[int, t.Tuple[int, int]]) -> int: pass +@te.overload +def g(x: int, y: int) -> int: + ... + +@te.overload +def g(x: t.Tuple[int, int]) -> int: + ... + +def g(*args: t.Union[int, t.Tuple[int, int]]) -> int: + pass + [out] import typing as t +import typing_extensions as te class A: @t.overload @@ -2590,6 +2663,10 @@ class A: def f(x: int, y: int) -> int: ... @t.overload def f(x: t.Tuple[int, int]) -> int: ... +@te.overload +def g(x: int, y: int) -> int: ... +@te.overload +def g(x: t.Tuple[int, int]) -> int: ... [case testProtocol_semanal] from typing import Protocol, TypeVar