diff --git a/CHANGELOG.md b/CHANGELOG.md index b6721cd2..d6c14835 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,9 @@ - Add `typing_extensions.NamedTuple`, allowing for generic `NamedTuple`s on Python <3.11 (backport from python/cpython#92027, by Serhiy Storchaka). Patch by Alex Waygood (@AlexWaygood). +- Adjust `typing_extensions.TypedDict` to allow for generic `TypedDict`s on + Python <3.11 (backport from python/cpython#27663, by Samodya Abey). Patch by + Alex Waygood (@AlexWaygood). # Release 4.2.0 (April 17, 2022) diff --git a/README.md b/README.md index 79112d1c..c960e663 100644 --- a/README.md +++ b/README.md @@ -113,7 +113,8 @@ Certain objects were changed after they were added to `typing`, and - `TypedDict` does not store runtime information about which (if any) keys are non-required in Python 3.8, and does not honor the `total` keyword with old-style `TypedDict()` in Python - 3.9.0 and 3.9.1. + 3.9.0 and 3.9.1. `TypedDict` also does not support multiple inheritance + with `typing.Generic` on Python <3.11. - `get_origin` and `get_args` lack support for `Annotated` in Python 3.8 and lack support for `ParamSpecArgs` and `ParamSpecKwargs` in 3.9. diff --git a/src/_typed_dict_test_helper.py b/src/_typed_dict_test_helper.py new file mode 100644 index 00000000..396a94fe --- /dev/null +++ b/src/_typed_dict_test_helper.py @@ -0,0 +1,8 @@ +from __future__ import annotations + +from typing import Generic, Optional, T +from typing_extensions import TypedDict + + +class FooGeneric(TypedDict, Generic[T]): + a: Optional[T] diff --git a/src/test_typing_extensions.py b/src/test_typing_extensions.py index 407a4860..ee498e56 100644 --- a/src/test_typing_extensions.py +++ b/src/test_typing_extensions.py @@ -29,6 +29,7 @@ from typing_extensions import assert_type, get_type_hints, get_origin, get_args from typing_extensions import clear_overloads, get_overloads, overload from typing_extensions import NamedTuple +from _typed_dict_test_helper import FooGeneric # Flags used to mark tests that only apply after a specific # version of the typing module. @@ -1664,6 +1665,15 @@ class CustomProtocolWithoutInitB(Protocol): self.assertEqual(CustomProtocolWithoutInitA.__init__, CustomProtocolWithoutInitB.__init__) +class Point2DGeneric(Generic[T], TypedDict): + a: T + b: T + + +class BarGeneric(FooGeneric[T], total=False): + b: int + + class TypedDictTests(BaseTestCase): def test_basics_iterable_syntax(self): @@ -1769,7 +1779,9 @@ def test_pickle(self): global EmpD # pickle wants to reference the class by name EmpD = TypedDict('EmpD', name=str, id=int) jane = EmpD({'name': 'jane', 'id': 37}) + point = Point2DGeneric(a=5.0, b=3.0) for proto in range(pickle.HIGHEST_PROTOCOL + 1): + # Test non-generic TypedDict z = pickle.dumps(jane, proto) jane2 = pickle.loads(z) self.assertEqual(jane2, jane) @@ -1777,6 +1789,14 @@ def test_pickle(self): ZZ = pickle.dumps(EmpD, proto) EmpDnew = pickle.loads(ZZ) self.assertEqual(EmpDnew({'name': 'jane', 'id': 37}), jane) + # and generic TypedDict + y = pickle.dumps(point, proto) + point2 = pickle.loads(y) + self.assertEqual(point, point2) + self.assertEqual(point2, {'a': 5.0, 'b': 3.0}) + YY = pickle.dumps(Point2DGeneric, proto) + Point2DGenericNew = pickle.loads(YY) + self.assertEqual(Point2DGenericNew({'a': 5.0, 'b': 3.0}), point) def test_optional(self): EmpD = TypedDict('EmpD', name=str, id=int) @@ -1854,6 +1874,124 @@ class PointDict3D(PointDict2D, total=False): assert is_typeddict(PointDict2D) is True assert is_typeddict(PointDict3D) is True + def test_get_type_hints_generic(self): + self.assertEqual( + get_type_hints(BarGeneric), + {'a': typing.Optional[T], 'b': int} + ) + + class FooBarGeneric(BarGeneric[int]): + c: str + + self.assertEqual( + get_type_hints(FooBarGeneric), + {'a': typing.Optional[T], 'b': int, 'c': str} + ) + + def test_generic_inheritance(self): + class A(TypedDict, Generic[T]): + a: T + + self.assertEqual(A.__bases__, (Generic, dict)) + self.assertEqual(A.__orig_bases__, (TypedDict, Generic[T])) + self.assertEqual(A.__mro__, (A, Generic, dict, object)) + self.assertEqual(A.__parameters__, (T,)) + self.assertEqual(A[str].__parameters__, ()) + self.assertEqual(A[str].__args__, (str,)) + + class A2(Generic[T], TypedDict): + a: T + + self.assertEqual(A2.__bases__, (Generic, dict)) + self.assertEqual(A2.__orig_bases__, (Generic[T], TypedDict)) + self.assertEqual(A2.__mro__, (A2, Generic, dict, object)) + self.assertEqual(A2.__parameters__, (T,)) + self.assertEqual(A2[str].__parameters__, ()) + self.assertEqual(A2[str].__args__, (str,)) + + class B(A[KT], total=False): + b: KT + + self.assertEqual(B.__bases__, (Generic, dict)) + self.assertEqual(B.__orig_bases__, (A[KT],)) + self.assertEqual(B.__mro__, (B, Generic, dict, object)) + self.assertEqual(B.__parameters__, (KT,)) + self.assertEqual(B.__total__, False) + self.assertEqual(B.__optional_keys__, frozenset(['b'])) + self.assertEqual(B.__required_keys__, frozenset(['a'])) + + self.assertEqual(B[str].__parameters__, ()) + self.assertEqual(B[str].__args__, (str,)) + self.assertEqual(B[str].__origin__, B) + + class C(B[int]): + c: int + + self.assertEqual(C.__bases__, (Generic, dict)) + self.assertEqual(C.__orig_bases__, (B[int],)) + self.assertEqual(C.__mro__, (C, Generic, dict, object)) + self.assertEqual(C.__parameters__, ()) + self.assertEqual(C.__total__, True) + self.assertEqual(C.__optional_keys__, frozenset(['b'])) + self.assertEqual(C.__required_keys__, frozenset(['a', 'c'])) + assert C.__annotations__ == { + 'a': T, + 'b': KT, + 'c': int, + } + with self.assertRaises(TypeError): + C[str] + + + class Point3D(Point2DGeneric[T], Generic[T, KT]): + c: KT + + self.assertEqual(Point3D.__bases__, (Generic, dict)) + self.assertEqual(Point3D.__orig_bases__, (Point2DGeneric[T], Generic[T, KT])) + self.assertEqual(Point3D.__mro__, (Point3D, Generic, dict, object)) + self.assertEqual(Point3D.__parameters__, (T, KT)) + self.assertEqual(Point3D.__total__, True) + self.assertEqual(Point3D.__optional_keys__, frozenset()) + self.assertEqual(Point3D.__required_keys__, frozenset(['a', 'b', 'c'])) + assert Point3D.__annotations__ == { + 'a': T, + 'b': T, + 'c': KT, + } + self.assertEqual(Point3D[int, str].__origin__, Point3D) + + with self.assertRaises(TypeError): + Point3D[int] + + with self.assertRaises(TypeError): + class Point3D(Point2DGeneric[T], Generic[KT]): + c: KT + + def test_implicit_any_inheritance(self): + class A(TypedDict, Generic[T]): + a: T + + class B(A[KT], total=False): + b: KT + + class WithImplicitAny(B): + c: int + + self.assertEqual(WithImplicitAny.__bases__, (Generic, dict,)) + self.assertEqual(WithImplicitAny.__mro__, (WithImplicitAny, Generic, dict, object)) + # Consistent with GenericTests.test_implicit_any + self.assertEqual(WithImplicitAny.__parameters__, ()) + self.assertEqual(WithImplicitAny.__total__, True) + self.assertEqual(WithImplicitAny.__optional_keys__, frozenset(['b'])) + self.assertEqual(WithImplicitAny.__required_keys__, frozenset(['a', 'c'])) + assert WithImplicitAny.__annotations__ == { + 'a': T, + 'b': KT, + 'c': int, + } + with self.assertRaises(TypeError): + WithImplicitAny[str] + class AnnotatedTests(BaseTestCase): diff --git a/src/typing_extensions.py b/src/typing_extensions.py index 3b9a39cf..31d3564e 100644 --- a/src/typing_extensions.py +++ b/src/typing_extensions.py @@ -381,6 +381,46 @@ def _is_callable_members_only(cls): return all(callable(getattr(cls, attr, None)) for attr in _get_protocol_attrs(cls)) +def _maybe_adjust_parameters(cls): + """Helper function used in Protocol.__init_subclass__ and _TypedDictMeta.__new__. + + The contents of this function are very similar + to logic found in typing.Generic.__init_subclass__ + on the CPython main branch. + """ + tvars = [] + if '__orig_bases__' in cls.__dict__: + tvars = typing._collect_type_vars(cls.__orig_bases__) + # Look for Generic[T1, ..., Tn] or Protocol[T1, ..., Tn]. + # If found, tvars must be a subset of it. + # If not found, tvars is it. + # Also check for and reject plain Generic, + # and reject multiple Generic[...] and/or Protocol[...]. + gvars = None + for base in cls.__orig_bases__: + if (isinstance(base, typing._GenericAlias) and + base.__origin__ in (typing.Generic, Protocol)): + # for error messages + the_base = base.__origin__.__name__ + if gvars is not None: + raise TypeError( + "Cannot inherit from Generic[...]" + " and/or Protocol[...] multiple types.") + gvars = base.__parameters__ + if gvars is None: + gvars = tvars + else: + tvarset = set(tvars) + gvarset = set(gvars) + if not tvarset <= gvarset: + s_vars = ', '.join(str(t) for t in tvars if t not in gvarset) + s_args = ', '.join(str(g) for g in gvars) + raise TypeError(f"Some type variables ({s_vars}) are" + f" not listed in {the_base}[{s_args}]") + tvars = gvars + cls.__parameters__ = tuple(tvars) + + # 3.8+ if hasattr(typing, 'Protocol'): Protocol = typing.Protocol @@ -477,43 +517,13 @@ def __class_getitem__(cls, params): return typing._GenericAlias(cls, params) def __init_subclass__(cls, *args, **kwargs): - tvars = [] if '__orig_bases__' in cls.__dict__: error = typing.Generic in cls.__orig_bases__ else: error = typing.Generic in cls.__bases__ if error: raise TypeError("Cannot inherit from plain Generic") - if '__orig_bases__' in cls.__dict__: - tvars = typing._collect_type_vars(cls.__orig_bases__) - # Look for Generic[T1, ..., Tn] or Protocol[T1, ..., Tn]. - # If found, tvars must be a subset of it. - # If not found, tvars is it. - # Also check for and reject plain Generic, - # and reject multiple Generic[...] and/or Protocol[...]. - gvars = None - for base in cls.__orig_bases__: - if (isinstance(base, typing._GenericAlias) and - base.__origin__ in (typing.Generic, Protocol)): - # for error messages - the_base = base.__origin__.__name__ - if gvars is not None: - raise TypeError( - "Cannot inherit from Generic[...]" - " and/or Protocol[...] multiple types.") - gvars = base.__parameters__ - if gvars is None: - gvars = tvars - else: - tvarset = set(tvars) - gvarset = set(gvars) - if not tvarset <= gvarset: - s_vars = ', '.join(str(t) for t in tvars if t not in gvarset) - s_args = ', '.join(str(g) for g in gvars) - raise TypeError(f"Some type variables ({s_vars}) are" - f" not listed in {the_base}[{s_args}]") - tvars = gvars - cls.__parameters__ = tuple(tvars) + _maybe_adjust_parameters(cls) # Determine if this is a protocol or a concrete subclass. if not cls.__dict__.get('_is_protocol', None): @@ -614,6 +624,7 @@ def __index__(self) -> int: # keyword with old-style TypedDict(). See https://bugs.python.org/issue42059 # The standard library TypedDict below Python 3.11 does not store runtime # information about optional and required keys when using Required or NotRequired. + # Generic TypedDicts are also impossible using typing.TypedDict on Python <3.11. TypedDict = typing.TypedDict _TypedDictMeta = typing._TypedDictMeta is_typeddict = typing.is_typeddict @@ -696,8 +707,16 @@ def __new__(cls, name, bases, ns, total=True): # Subclasses and instances of TypedDict return actual dictionaries # via _dict_new. ns['__new__'] = _typeddict_new if name == 'TypedDict' else _dict_new + # Don't insert typing.Generic into __bases__ here, + # or Generic.__init_subclass__ will raise TypeError + # in the super().__new__() call. + # Instead, monkey-patch __bases__ onto the class after it's been created. tp_dict = super().__new__(cls, name, (dict,), ns) + if any(issubclass(base, typing.Generic) for base in bases): + tp_dict.__bases__ = (typing.Generic, dict) + _maybe_adjust_parameters(tp_dict) + annotations = {} own_annotations = ns.get('__annotations__', {}) msg = "TypedDict('Name', {f0: t0, f1: t1, ...}); each t must be a type"