Skip to content

Commit

Permalink
Annotated: backport bpo-46491 (#1049)
Browse files Browse the repository at this point in the history
  • Loading branch information
GBeauregard committed Jan 25, 2022
1 parent 523cf02 commit 3b53f01
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 4 deletions.
2 changes: 2 additions & 0 deletions typing_extensions/CHANGELOG
@@ -1,5 +1,7 @@
# Release 4.x.x

- `Annotated` can now wrap `ClassVar` and `Final`. Backport from
bpo-46491. Patch by Gregory Beauregard (@GBeauregard).
- Add missed `Required` and `NotRequired` to `__all__`. Patch by
Yuri Karabas (@uriyyo).
- The `@final` decorator now sets the `__final__` attribute on the
Expand Down
18 changes: 18 additions & 0 deletions typing_extensions/src/test_typing_extensions.py
Expand Up @@ -1799,6 +1799,24 @@ class C:
A.x = 5
self.assertEqual(C.x, 5)

@skipIf(sys.version_info[:2] in ((3, 9), (3, 10)), "Waiting for bpo-46491 bugfix.")
def test_special_form_containment(self):
class C:
classvar: Annotated[ClassVar[int], "a decoration"] = 4
const: Annotated[Final[int], "Const"] = 4

if sys.version_info[:2] >= (3, 7):
self.assertEqual(get_type_hints(C, globals())["classvar"], ClassVar[int])
self.assertEqual(get_type_hints(C, globals())["const"], Final[int])
else:
self.assertEqual(
get_type_hints(C, globals())["classvar"],
Annotated[ClassVar[int], "a decoration"]
)
self.assertEqual(
get_type_hints(C, globals())["const"], Annotated[Final[int], "Const"]
)

def test_hash_eq(self):
self.assertEqual(len({Annotated[int, 4, 5], Annotated[int, 4, 5]}), 1)
self.assertNotEqual(Annotated[int, 4, 5], Annotated[int, 5, 4])
Expand Down
18 changes: 14 additions & 4 deletions typing_extensions/src/typing_extensions.py
Expand Up @@ -1251,8 +1251,12 @@ def __class_getitem__(cls, params):
raise TypeError("Annotated[...] should be used "
"with at least two arguments (a type and an "
"annotation).")
msg = "Annotated[t, ...]: t must be a type."
origin = typing._type_check(params[0], msg)
allowed_special_forms = (ClassVar, Final)
if get_origin(params[0]) in allowed_special_forms:
origin = params[0]
else:
msg = "Annotated[t, ...]: t must be a type."
origin = typing._type_check(params[0], msg)
metadata = tuple(params[1:])
return _AnnotatedAlias(origin, metadata)

Expand Down Expand Up @@ -1377,8 +1381,14 @@ def __getitem__(self, params):
"with at least two arguments (a type and an "
"annotation).")
else:
msg = "Annotated[t, ...]: t must be a type."
tp = typing._type_check(params[0], msg)
if (
isinstance(params[0], typing._TypingBase) and
type(params[0]).__name__ == "_ClassVar"
):
tp = params[0]
else:
msg = "Annotated[t, ...]: t must be a type."
tp = typing._type_check(params[0], msg)
metadata = tuple(params[1:])
return self.__class__(
self.__name__,
Expand Down

0 comments on commit 3b53f01

Please sign in to comment.