Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Backport generic TypedDicts #46

Merged
merged 4 commits into from May 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Expand Up @@ -3,6 +3,9 @@
- Add `typing_extensions.NamedTuple`, allowing for generic `NamedTuple`s on
Python <3.11 (backport from python/cpython#92027, by Serhiy Storchaka). Patch
by Alex Waygood (@AlexWaygood).
- Adjust `typing_extensions.TypedDict` to allow for generic `TypedDict`s on
Python <3.11 (backport from python/cpython#27663, by Samodya Abey). Patch by
Alex Waygood (@AlexWaygood).

# Release 4.2.0 (April 17, 2022)

Expand Down
3 changes: 2 additions & 1 deletion README.md
Expand Up @@ -113,7 +113,8 @@ Certain objects were changed after they were added to `typing`, and
- `TypedDict` does not store runtime information
about which (if any) keys are non-required in Python 3.8, and does not
honor the `total` keyword with old-style `TypedDict()` in Python
3.9.0 and 3.9.1.
3.9.0 and 3.9.1. `TypedDict` also does not support multiple inheritance
with `typing.Generic` on Python <3.11.
- `get_origin` and `get_args` lack support for `Annotated` in
Python 3.8 and lack support for `ParamSpecArgs` and `ParamSpecKwargs`
in 3.9.
Expand Down
8 changes: 8 additions & 0 deletions src/_typed_dict_test_helper.py
@@ -0,0 +1,8 @@
from __future__ import annotations

from typing import Generic, Optional, T
from typing_extensions import TypedDict


class FooGeneric(TypedDict, Generic[T]):
a: Optional[T]
138 changes: 138 additions & 0 deletions src/test_typing_extensions.py
Expand Up @@ -29,6 +29,7 @@
from typing_extensions import assert_type, get_type_hints, get_origin, get_args
from typing_extensions import clear_overloads, get_overloads, overload
from typing_extensions import NamedTuple
from _typed_dict_test_helper import FooGeneric

# Flags used to mark tests that only apply after a specific
# version of the typing module.
Expand Down Expand Up @@ -1664,6 +1665,15 @@ class CustomProtocolWithoutInitB(Protocol):
self.assertEqual(CustomProtocolWithoutInitA.__init__, CustomProtocolWithoutInitB.__init__)


class Point2DGeneric(Generic[T], TypedDict):
a: T
b: T


class BarGeneric(FooGeneric[T], total=False):
b: int


class TypedDictTests(BaseTestCase):

def test_basics_iterable_syntax(self):
Expand Down Expand Up @@ -1769,14 +1779,24 @@ def test_pickle(self):
global EmpD # pickle wants to reference the class by name
EmpD = TypedDict('EmpD', name=str, id=int)
jane = EmpD({'name': 'jane', 'id': 37})
point = Point2DGeneric(a=5.0, b=3.0)
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
# Test non-generic TypedDict
z = pickle.dumps(jane, proto)
jane2 = pickle.loads(z)
self.assertEqual(jane2, jane)
self.assertEqual(jane2, {'name': 'jane', 'id': 37})
ZZ = pickle.dumps(EmpD, proto)
EmpDnew = pickle.loads(ZZ)
self.assertEqual(EmpDnew({'name': 'jane', 'id': 37}), jane)
# and generic TypedDict
y = pickle.dumps(point, proto)
point2 = pickle.loads(y)
self.assertEqual(point, point2)
self.assertEqual(point2, {'a': 5.0, 'b': 3.0})
YY = pickle.dumps(Point2DGeneric, proto)
Point2DGenericNew = pickle.loads(YY)
self.assertEqual(Point2DGenericNew({'a': 5.0, 'b': 3.0}), point)

def test_optional(self):
EmpD = TypedDict('EmpD', name=str, id=int)
Expand Down Expand Up @@ -1854,6 +1874,124 @@ class PointDict3D(PointDict2D, total=False):
assert is_typeddict(PointDict2D) is True
assert is_typeddict(PointDict3D) is True

def test_get_type_hints_generic(self):
self.assertEqual(
get_type_hints(BarGeneric),
{'a': typing.Optional[T], 'b': int}
)

class FooBarGeneric(BarGeneric[int]):
c: str

self.assertEqual(
get_type_hints(FooBarGeneric),
{'a': typing.Optional[T], 'b': int, 'c': str}
)

def test_generic_inheritance(self):
class A(TypedDict, Generic[T]):
a: T

self.assertEqual(A.__bases__, (Generic, dict))
self.assertEqual(A.__orig_bases__, (TypedDict, Generic[T]))
self.assertEqual(A.__mro__, (A, Generic, dict, object))
self.assertEqual(A.__parameters__, (T,))
self.assertEqual(A[str].__parameters__, ())
self.assertEqual(A[str].__args__, (str,))

class A2(Generic[T], TypedDict):
a: T

self.assertEqual(A2.__bases__, (Generic, dict))
self.assertEqual(A2.__orig_bases__, (Generic[T], TypedDict))
self.assertEqual(A2.__mro__, (A2, Generic, dict, object))
self.assertEqual(A2.__parameters__, (T,))
self.assertEqual(A2[str].__parameters__, ())
self.assertEqual(A2[str].__args__, (str,))

class B(A[KT], total=False):
b: KT

self.assertEqual(B.__bases__, (Generic, dict))
self.assertEqual(B.__orig_bases__, (A[KT],))
self.assertEqual(B.__mro__, (B, Generic, dict, object))
self.assertEqual(B.__parameters__, (KT,))
self.assertEqual(B.__total__, False)
self.assertEqual(B.__optional_keys__, frozenset(['b']))
self.assertEqual(B.__required_keys__, frozenset(['a']))

self.assertEqual(B[str].__parameters__, ())
self.assertEqual(B[str].__args__, (str,))
self.assertEqual(B[str].__origin__, B)

class C(B[int]):
c: int

self.assertEqual(C.__bases__, (Generic, dict))
self.assertEqual(C.__orig_bases__, (B[int],))
self.assertEqual(C.__mro__, (C, Generic, dict, object))
self.assertEqual(C.__parameters__, ())
self.assertEqual(C.__total__, True)
self.assertEqual(C.__optional_keys__, frozenset(['b']))
self.assertEqual(C.__required_keys__, frozenset(['a', 'c']))
assert C.__annotations__ == {
'a': T,
'b': KT,
'c': int,
}
with self.assertRaises(TypeError):
C[str]


class Point3D(Point2DGeneric[T], Generic[T, KT]):
c: KT

self.assertEqual(Point3D.__bases__, (Generic, dict))
self.assertEqual(Point3D.__orig_bases__, (Point2DGeneric[T], Generic[T, KT]))
self.assertEqual(Point3D.__mro__, (Point3D, Generic, dict, object))
self.assertEqual(Point3D.__parameters__, (T, KT))
self.assertEqual(Point3D.__total__, True)
self.assertEqual(Point3D.__optional_keys__, frozenset())
self.assertEqual(Point3D.__required_keys__, frozenset(['a', 'b', 'c']))
assert Point3D.__annotations__ == {
'a': T,
'b': T,
'c': KT,
}
self.assertEqual(Point3D[int, str].__origin__, Point3D)

with self.assertRaises(TypeError):
Point3D[int]

with self.assertRaises(TypeError):
class Point3D(Point2DGeneric[T], Generic[KT]):
c: KT

def test_implicit_any_inheritance(self):
class A(TypedDict, Generic[T]):
a: T

class B(A[KT], total=False):
b: KT

class WithImplicitAny(B):
c: int

self.assertEqual(WithImplicitAny.__bases__, (Generic, dict,))
self.assertEqual(WithImplicitAny.__mro__, (WithImplicitAny, Generic, dict, object))
# Consistent with GenericTests.test_implicit_any
self.assertEqual(WithImplicitAny.__parameters__, ())
self.assertEqual(WithImplicitAny.__total__, True)
self.assertEqual(WithImplicitAny.__optional_keys__, frozenset(['b']))
self.assertEqual(WithImplicitAny.__required_keys__, frozenset(['a', 'c']))
assert WithImplicitAny.__annotations__ == {
'a': T,
'b': KT,
'c': int,
}
with self.assertRaises(TypeError):
WithImplicitAny[str]


class AnnotatedTests(BaseTestCase):

Expand Down
81 changes: 50 additions & 31 deletions src/typing_extensions.py
Expand Up @@ -381,6 +381,46 @@ def _is_callable_members_only(cls):
return all(callable(getattr(cls, attr, None)) for attr in _get_protocol_attrs(cls))


def _maybe_adjust_parameters(cls):
"""Helper function used in Protocol.__init_subclass__ and _TypedDictMeta.__new__.

The contents of this function are very similar
to logic found in typing.Generic.__init_subclass__
on the CPython main branch.
"""
tvars = []
if '__orig_bases__' in cls.__dict__:
tvars = typing._collect_type_vars(cls.__orig_bases__)
# Look for Generic[T1, ..., Tn] or Protocol[T1, ..., Tn].
# If found, tvars must be a subset of it.
# If not found, tvars is it.
# Also check for and reject plain Generic,
# and reject multiple Generic[...] and/or Protocol[...].
gvars = None
for base in cls.__orig_bases__:
if (isinstance(base, typing._GenericAlias) and
base.__origin__ in (typing.Generic, Protocol)):
# for error messages
the_base = base.__origin__.__name__
if gvars is not None:
raise TypeError(
"Cannot inherit from Generic[...]"
" and/or Protocol[...] multiple types.")
gvars = base.__parameters__
if gvars is None:
gvars = tvars
else:
tvarset = set(tvars)
gvarset = set(gvars)
if not tvarset <= gvarset:
s_vars = ', '.join(str(t) for t in tvars if t not in gvarset)
s_args = ', '.join(str(g) for g in gvars)
raise TypeError(f"Some type variables ({s_vars}) are"
f" not listed in {the_base}[{s_args}]")
tvars = gvars
cls.__parameters__ = tuple(tvars)


# 3.8+
if hasattr(typing, 'Protocol'):
Protocol = typing.Protocol
Expand Down Expand Up @@ -477,43 +517,13 @@ def __class_getitem__(cls, params):
return typing._GenericAlias(cls, params)

def __init_subclass__(cls, *args, **kwargs):
tvars = []
if '__orig_bases__' in cls.__dict__:
error = typing.Generic in cls.__orig_bases__
else:
error = typing.Generic in cls.__bases__
if error:
raise TypeError("Cannot inherit from plain Generic")
if '__orig_bases__' in cls.__dict__:
tvars = typing._collect_type_vars(cls.__orig_bases__)
# Look for Generic[T1, ..., Tn] or Protocol[T1, ..., Tn].
# If found, tvars must be a subset of it.
# If not found, tvars is it.
# Also check for and reject plain Generic,
# and reject multiple Generic[...] and/or Protocol[...].
gvars = None
for base in cls.__orig_bases__:
if (isinstance(base, typing._GenericAlias) and
base.__origin__ in (typing.Generic, Protocol)):
# for error messages
the_base = base.__origin__.__name__
if gvars is not None:
raise TypeError(
"Cannot inherit from Generic[...]"
" and/or Protocol[...] multiple types.")
gvars = base.__parameters__
if gvars is None:
gvars = tvars
else:
tvarset = set(tvars)
gvarset = set(gvars)
if not tvarset <= gvarset:
s_vars = ', '.join(str(t) for t in tvars if t not in gvarset)
s_args = ', '.join(str(g) for g in gvars)
raise TypeError(f"Some type variables ({s_vars}) are"
f" not listed in {the_base}[{s_args}]")
tvars = gvars
cls.__parameters__ = tuple(tvars)
_maybe_adjust_parameters(cls)

# Determine if this is a protocol or a concrete subclass.
if not cls.__dict__.get('_is_protocol', None):
Expand Down Expand Up @@ -614,6 +624,7 @@ def __index__(self) -> int:
# keyword with old-style TypedDict(). See https://bugs.python.org/issue42059
# The standard library TypedDict below Python 3.11 does not store runtime
# information about optional and required keys when using Required or NotRequired.
# Generic TypedDicts are also impossible using typing.TypedDict on Python <3.11.
TypedDict = typing.TypedDict
_TypedDictMeta = typing._TypedDictMeta
is_typeddict = typing.is_typeddict
Expand Down Expand Up @@ -696,8 +707,16 @@ def __new__(cls, name, bases, ns, total=True):
# Subclasses and instances of TypedDict return actual dictionaries
# via _dict_new.
ns['__new__'] = _typeddict_new if name == 'TypedDict' else _dict_new
# Don't insert typing.Generic into __bases__ here,
# or Generic.__init_subclass__ will raise TypeError
# in the super().__new__() call.
# Instead, monkey-patch __bases__ onto the class after it's been created.
tp_dict = super().__new__(cls, name, (dict,), ns)

if any(issubclass(base, typing.Generic) for base in bases):
tp_dict.__bases__ = (typing.Generic, dict)
_maybe_adjust_parameters(tp_dict)

annotations = {}
own_annotations = ns.get('__annotations__', {})
msg = "TypedDict('Name', {f0: t0, f1: t1, ...}); each t must be a type"
Expand Down