Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a backport of generic NamedTuples #44

Merged
merged 21 commits into from May 26, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -1,3 +1,9 @@
# Unreleased

- Add `typing_extensions.NamedTuple`, allowing for generic `NamedTuple`s on
Python <3.11 (backport from python/cpython#92027, by Serhiy Storchaka). Patch
by Alex Waygood (@AlexWaygood).

# Release 4.2.0 (April 17, 2022)

- Re-export `typing.Unpack` and `typing.TypeVarTuple` on Python 3.11.
Expand Down
3 changes: 3 additions & 0 deletions README.md
Expand Up @@ -96,6 +96,7 @@ This module currently contains the following:
- `Counter`
- `DefaultDict`
- `Deque`
- `NamedTuple`
- `NewType`
- `NoReturn`
- `overload`
Expand All @@ -121,6 +122,8 @@ Certain objects were changed after they were added to `typing`, and
introspectable at runtime. In order to access overloads with
`typing_extensions.get_overloads()`, you must use
`@typing_extensions.overload`.
- `NamedTuple` was changed in Python 3.11 to allow for multiple inheritance
with `typing.Generic`.

There are a few types whose interface was modified between different
versions of typing. For example, `typing.Sequence` was modified to
Expand Down
306 changes: 304 additions & 2 deletions src/test_typing_extensions.py
Expand Up @@ -5,6 +5,7 @@
import collections
from collections import defaultdict
import collections.abc
import copy
from functools import lru_cache
import inspect
import pickle
Expand All @@ -17,7 +18,7 @@
from typing import TypeVar, Optional, Union, Any, AnyStr
from typing import T, KT, VT # Not in __all__.
from typing import Tuple, List, Dict, Iterable, Iterator, Callable
from typing import Generic, NamedTuple
from typing import Generic
from typing import no_type_check
import typing_extensions
from typing_extensions import NoReturn, ClassVar, Final, IntVar, Literal, Type, NewType, TypedDict, Self
Expand All @@ -27,10 +28,12 @@
from typing_extensions import TypeVarTuple, Unpack, dataclass_transform, reveal_type, Never, assert_never, LiteralString
from typing_extensions import assert_type, get_type_hints, get_origin, get_args
from typing_extensions import clear_overloads, get_overloads, overload
from typing_extensions import NamedTuple

# Flags used to mark tests that only apply after a specific
# version of the typing module.
TYPING_3_8_0 = sys.version_info[:3] >= (3, 8, 0)
TYPING_3_9_0 = sys.version_info[:3] >= (3, 9, 0)
TYPING_3_10_0 = sys.version_info[:3] >= (3, 10, 0)

# 3.11 makes runtime type checks (_type_check) more lenient.
Expand Down Expand Up @@ -2874,7 +2877,7 @@ def test_typing_extensions_defers_when_possible(self):
if sys.version_info < (3, 10):
exclude |= {'get_args', 'get_origin'}
if sys.version_info < (3, 11):
exclude.add('final')
exclude |= {'final', 'NamedTuple'}
for item in typing_extensions.__all__:
if item not in exclude and hasattr(typing, item):
self.assertIs(
Expand All @@ -2892,6 +2895,305 @@ def test_typing_extensions_compiles_with_opt(self):
self.fail('Module does not compile with optimize=2 (-OO flag).')


class CoolEmployee(NamedTuple):
name: str
cool: int


class CoolEmployeeWithDefault(NamedTuple):
name: str
cool: int = 0


class XMeth(NamedTuple):
x: int

def double(self):
return 2 * self.x


class XRepr(NamedTuple):
x: int
y: int = 1

def __str__(self):
return f'{self.x} -> {self.y}'

def __add__(self, other):
return 0


@skipIf(TYPING_3_11_0, "These invariants should all be tested upstream on 3.11+")
class NamedTupleTests(BaseTestCase):
class NestedEmployee(NamedTuple):
name: str
cool: int

def test_basics(self):
Emp = NamedTuple('Emp', [('name', str), ('id', int)])
self.assertIsSubclass(Emp, tuple)
joe = Emp('Joe', 42)
jim = Emp(name='Jim', id=1)
self.assertIsInstance(joe, Emp)
self.assertIsInstance(joe, tuple)
self.assertEqual(joe.name, 'Joe')
self.assertEqual(joe.id, 42)
self.assertEqual(jim.name, 'Jim')
self.assertEqual(jim.id, 1)
self.assertEqual(Emp.__name__, 'Emp')
self.assertEqual(Emp._fields, ('name', 'id'))
self.assertEqual(Emp.__annotations__,
collections.OrderedDict([('name', str), ('id', int)]))

def test_annotation_usage(self):
tim = CoolEmployee('Tim', 9000)
self.assertIsInstance(tim, CoolEmployee)
self.assertIsInstance(tim, tuple)
self.assertEqual(tim.name, 'Tim')
self.assertEqual(tim.cool, 9000)
self.assertEqual(CoolEmployee.__name__, 'CoolEmployee')
self.assertEqual(CoolEmployee._fields, ('name', 'cool'))
self.assertEqual(CoolEmployee.__annotations__,
collections.OrderedDict(name=str, cool=int))

def test_annotation_usage_with_default(self):
jelle = CoolEmployeeWithDefault('Jelle')
self.assertIsInstance(jelle, CoolEmployeeWithDefault)
self.assertIsInstance(jelle, tuple)
self.assertEqual(jelle.name, 'Jelle')
self.assertEqual(jelle.cool, 0)
JelleZijlstra marked this conversation as resolved.
Show resolved Hide resolved
cooler_employee = CoolEmployeeWithDefault('Sjoerd', 1)
self.assertEqual(cooler_employee.cool, 1)

self.assertEqual(CoolEmployeeWithDefault.__name__, 'CoolEmployeeWithDefault')
self.assertEqual(CoolEmployeeWithDefault._fields, ('name', 'cool'))
self.assertEqual(CoolEmployeeWithDefault.__annotations__,
dict(name=str, cool=int))

with self.assertRaisesRegex(
TypeError,
'Non-default namedtuple field y cannot follow default field x'
):
class NonDefaultAfterDefault(NamedTuple):
x: int = 3
y: int

@skipUnless(
(
TYPING_3_8_0
or hasattr(CoolEmployeeWithDefault, '_field_defaults')
),
'"_field_defaults" attribute was added in a micro version of 3.7'
)
def test_field_defaults(self):
self.assertEqual(CoolEmployeeWithDefault._field_defaults, dict(cool=0))

def test_annotation_usage_with_methods(self):
self.assertEqual(XMeth(1).double(), 2)
self.assertEqual(XMeth(42).x, XMeth(42)[0])
self.assertEqual(str(XRepr(42)), '42 -> 1')
self.assertEqual(XRepr(1, 2) + XRepr(3), 0)

bad_overwrite_error_message = 'Cannot overwrite NamedTuple attribute'

with self.assertRaisesRegex(AttributeError, bad_overwrite_error_message):
class XMethBad(NamedTuple):
x: int
def _fields(self):
return 'no chance for this'

with self.assertRaisesRegex(AttributeError, bad_overwrite_error_message):
class XMethBad2(NamedTuple):
x: int
def _source(self):
return 'no chance for this as well'

def test_multiple_inheritance(self):
class A:
pass
with self.assertRaisesRegex(
TypeError,
'can only inherit from a NamedTuple type and Generic'
):
class X(NamedTuple, A):
x: int

with self.assertRaisesRegex(
TypeError,
'can only inherit from a NamedTuple type and Generic'
):
class X(NamedTuple, tuple):
x: int

with self.assertRaisesRegex(TypeError, 'duplicate base class'):
class X(NamedTuple, NamedTuple):
x: int

class A(NamedTuple):
x: int
with self.assertRaisesRegex(
TypeError,
'can only inherit from a NamedTuple type and Generic'
):
class X(NamedTuple, A):
y: str

def test_generic(self):
class X(NamedTuple, Generic[T]):
x: T
self.assertEqual(X.__bases__, (tuple, Generic))
self.assertEqual(X.__orig_bases__, (NamedTuple, Generic[T]))
self.assertEqual(X.__mro__, (X, tuple, Generic, object))

class Y(Generic[T], NamedTuple):
x: T
self.assertEqual(Y.__bases__, (Generic, tuple))
self.assertEqual(Y.__orig_bases__, (Generic[T], NamedTuple))
self.assertEqual(Y.__mro__, (Y, Generic, tuple, object))

for G in X, Y:
with self.subTest(type=G):
self.assertEqual(G.__parameters__, (T,))
A = G[int]
self.assertIs(A.__origin__, G)
self.assertEqual(A.__args__, (int,))
self.assertEqual(A.__parameters__, ())

a = A(3)
self.assertIs(type(a), G)
self.assertEqual(a.x, 3)

with self.assertRaisesRegex(TypeError, 'Too many parameters'):
G[int, str]

@skipUnless(TYPING_3_9_0, "tuple.__class_getitem__ was added in 3.9")
def test_non_generic_subscript_py39_plus(self):
# For backward compatibility, subscription works
# on arbitrary NamedTuple types.
class Group(NamedTuple):
key: T
group: list[T]
A = Group[int]
self.assertEqual(A.__origin__, Group)
self.assertEqual(A.__parameters__, ())
self.assertEqual(A.__args__, (int,))
a = A(1, [2])
self.assertIs(type(a), Group)
self.assertEqual(a, (1, [2]))

@skipIf(TYPING_3_9_0, "Test isn't relevant to 3.9+")
def test_non_generic_subscript_error_message_py38_minus(self):
class Group(NamedTuple):
key: T
group: List[T]

with self.assertRaisesRegex(TypeError, 'not subscriptable'):
Group[int]

for attr in ('__args__', '__origin__', '__parameters__'):
with self.subTest(attr=attr):
self.assertFalse(hasattr(Group, attr))

def test_namedtuple_keyword_usage(self):
LocalEmployee = NamedTuple("LocalEmployee", name=str, age=int)
nick = LocalEmployee('Nick', 25)
self.assertIsInstance(nick, tuple)
self.assertEqual(nick.name, 'Nick')
self.assertEqual(LocalEmployee.__name__, 'LocalEmployee')
self.assertEqual(LocalEmployee._fields, ('name', 'age'))
self.assertEqual(LocalEmployee.__annotations__, dict(name=str, age=int))
with self.assertRaisesRegex(
TypeError,
'Either list of fields or keywords can be provided to NamedTuple, not both'
):
NamedTuple('Name', [('x', int)], y=str)

def test_namedtuple_special_keyword_names(self):
NT = NamedTuple("NT", cls=type, self=object, typename=str, fields=list)
self.assertEqual(NT.__name__, 'NT')
self.assertEqual(NT._fields, ('cls', 'self', 'typename', 'fields'))
a = NT(cls=str, self=42, typename='foo', fields=[('bar', tuple)])
self.assertEqual(a.cls, str)
self.assertEqual(a.self, 42)
self.assertEqual(a.typename, 'foo')
self.assertEqual(a.fields, [('bar', tuple)])

def test_empty_namedtuple(self):
NT = NamedTuple('NT')

class CNT(NamedTuple):
pass # empty body

for struct in [NT, CNT]:
with self.subTest(struct=struct):
self.assertEqual(struct._fields, ())
self.assertEqual(struct.__annotations__, {})
self.assertIsInstance(struct(), struct)
# Attribute was added in a micro version of 3.7
# and is tested more fully elsewhere
if hasattr(struct, "_field_defaults"):
AlexWaygood marked this conversation as resolved.
Show resolved Hide resolved
self.assertEqual(struct._field_defaults, {})

def test_namedtuple_errors(self):
with self.assertRaises(TypeError):
NamedTuple.__new__()
with self.assertRaises(TypeError):
NamedTuple()
with self.assertRaises(TypeError):
NamedTuple('Emp', [('name', str)], None)
with self.assertRaisesRegex(ValueError, 'cannot start with an underscore'):
NamedTuple('Emp', [('_name', str)])
with self.assertRaises(TypeError):
NamedTuple(typename='Emp', name=str, id=int)

def test_copy_and_pickle(self):
global Emp # pickle wants to reference the class by name
Emp = NamedTuple('Emp', [('name', str), ('cool', int)])
for cls in Emp, CoolEmployee, self.NestedEmployee:
with self.subTest(cls=cls):
jane = cls('jane', 37)
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
z = pickle.dumps(jane, proto)
jane2 = pickle.loads(z)
self.assertEqual(jane2, jane)
self.assertIsInstance(jane2, cls)

jane2 = copy.copy(jane)
self.assertEqual(jane2, jane)
self.assertIsInstance(jane2, cls)

jane2 = copy.deepcopy(jane)
self.assertEqual(jane2, jane)
self.assertIsInstance(jane2, cls)

def test_docstring(self):
self.assertEqual(NamedTuple.__doc__, typing.NamedTuple.__doc__)
self.assertIsInstance(NamedTuple.__doc__, str)

@skipUnless(TYPING_3_8_0, "NamedTuple had a bad signature on <=3.7")
def test_signature_is_same_as_typing_NamedTuple(self):
self.assertEqual(inspect.signature(NamedTuple), inspect.signature(typing.NamedTuple))

@skipIf(TYPING_3_8_0, "tests are only relevant to <=3.7")
def test_signature_on_37(self):
self.assertIsInstance(inspect.signature(NamedTuple), inspect.Signature)
self.assertFalse(hasattr(NamedTuple, "__text_signature__"))

@skipUnless(TYPING_3_9_0, "NamedTuple was a class on 3.8 and lower")
def test_same_as_typing_NamedTuple_39_plus(self):
self.assertEqual(
set(dir(NamedTuple)),
set(dir(typing.NamedTuple)) | {"__text_signature__"}
)
self.assertIs(type(NamedTuple), type(typing.NamedTuple))

@skipIf(TYPING_3_9_0, "tests are only relevant to <=3.8")
def test_same_as_typing_NamedTuple_38_minus(self):
self.assertEqual(
self.NestedEmployee.__annotations__,
self.NestedEmployee._field_types
)


if __name__ == '__main__':
main()