Skip to content

Commit

Permalink
feat: add support for NamedTuple and TypedDict types
Browse files Browse the repository at this point in the history
  • Loading branch information
PrettyWood committed Dec 23, 2020
1 parent 3496a47 commit da3d88f
Show file tree
Hide file tree
Showing 6 changed files with 131 additions and 0 deletions.
1 change: 1 addition & 0 deletions changes/2216-PrettyWood.md
@@ -0,0 +1 @@
add support for `NamedTuple` and `TypedDict` types
8 changes: 8 additions & 0 deletions docs/usage/types.md
Expand Up @@ -85,9 +85,17 @@ with custom properties and validation.
`typing.Tuple`
: see [Typing Iterables](#typing-iterables) below for more detail on parsing and validation

`subclass of typing.NamedTuple (or collections.namedtuple)`
: Same as `tuple` but instantiates with the given namedtuple.
_pydantic_ will validate the tuple if you use `typing.NamedTuple` since fields are annotated.
If you use `collections.namedtuple`, no validation will be done.

`typing.Dict`
: see [Typing Iterables](#typing-iterables) below for more detail on parsing and validation

`subclass of typing.TypedDict`
: Same as `dict` but _pydantic_ will validate the dictionary since keys are annotated

`typing.Set`
: see [Typing Iterables](#typing-iterables) below for more detail on parsing and validation

Expand Down
3 changes: 3 additions & 0 deletions pydantic/fields.py
Expand Up @@ -37,6 +37,7 @@
get_origin,
is_literal_type,
is_new_type,
is_typed_dict_type,
new_type_supertype,
)
from .utils import PyObjectStr, Representation, lenient_issubclass, sequence_like, smart_deepcopy
Expand Down Expand Up @@ -415,6 +416,8 @@ def _type_analysis(self) -> None: # noqa: C901 (ignore complexity)
return
elif is_literal_type(self.type_):
return
elif is_typed_dict_type(self.type_):
return

origin = get_origin(self.type_)
if origin is None:
Expand Down
14 changes: 14 additions & 0 deletions pydantic/typing.py
Expand Up @@ -155,6 +155,8 @@ def get_args(tp: Type[Any]) -> Tuple[Any, ...]:
'is_literal_type',
'literal_values',
'Literal',
'is_named_tuple_type',
'is_typed_dict_type',
'is_new_type',
'new_type_supertype',
'is_classvar',
Expand Down Expand Up @@ -258,6 +260,18 @@ def all_literal_values(type_: Type[Any]) -> Tuple[Any, ...]:
return tuple(x for value in values for x in all_literal_values(value))


def is_named_tuple_type(type_: Type[Any]) -> bool:
from .utils import lenient_issubclass

return lenient_issubclass(type_, tuple) and hasattr(type_, '_fields')


def is_typed_dict_type(type_: Type[Any]) -> bool:
from .utils import lenient_issubclass

return lenient_issubclass(type_, dict) and getattr(type_, '__annotations__', None)


test_type = NewType('test_type', str)


Expand Down
44 changes: 44 additions & 0 deletions pydantic/validators.py
Expand Up @@ -15,6 +15,7 @@
FrozenSet,
Generator,
List,
NamedTuple,
Pattern,
Set,
Tuple,
Expand All @@ -34,6 +35,8 @@
get_class,
is_callable_type,
is_literal_type,
is_named_tuple_type,
is_typed_dict_type,
)
from .utils import almost_equal_floats, lenient_issubclass, sequence_like

Expand Down Expand Up @@ -523,6 +526,40 @@ def pattern_validator(v: Any) -> Pattern[str]:
raise errors.PatternError()


NamedTupleT = TypeVar('NamedTupleT', bound=NamedTuple)


def make_named_tuple_validator(type_: Type[NamedTupleT]) -> Callable[[Tuple[Any, ...]], NamedTupleT]:
from .main import BaseModel

# A named tuple can be created with `typing,NamedTuple` with types
# but also with `collections.namedtuple` with just the fields
# in which case we consider the type to be `Any`
named_tuple_annotations = getattr(type_, '__annotations__', {k: Any for k in type_._fields})

class NamedTupleModel(BaseModel):
__annotations__ = named_tuple_annotations

def named_tuple_validator(v: Tuple[Any, ...]) -> NamedTupleT:
values: Dict[str, Any] = dict(zip(named_tuple_annotations, v))
validated_values: Dict[str, Any] = dict(NamedTupleModel(**values))
return type_(**validated_values)

return named_tuple_validator


def make_typed_dict_validator(type_: Type[Dict[str, Any]]) -> Callable[[Any], Dict[str, Any]]:
from .main import BaseModel

class TypedDictModel(BaseModel):
__annotations__ = type_.__annotations__

def typed_dict_validator(values: Dict[str, Any]) -> Dict[str, Any]:
return dict(TypedDictModel(**values))

return typed_dict_validator


class IfConfig:
def __init__(self, validator: AnyCallable, *config_attr_names: str) -> None:
self.validator = validator
Expand Down Expand Up @@ -610,6 +647,13 @@ def find_validators( # noqa: C901 (ignore complexity)
if type_ is IntEnum:
yield int_enum_validator
return
if is_named_tuple_type(type_):
yield tuple_validator
yield make_named_tuple_validator(type_)
return
if is_typed_dict_type(type_):
yield make_typed_dict_validator(type_)
return

class_ = get_class(type_)
if class_ is not None:
Expand Down
61 changes: 61 additions & 0 deletions tests/test_main.py
Expand Up @@ -1425,3 +1425,64 @@ class M(BaseModel):
a: int

get_type_hints(M.__config__)


def test_named_tuple():
from collections import namedtuple
from typing import NamedTuple

Position = namedtuple('Pos', 'x y')

class Event(NamedTuple):
a: int
b: int
c: int
d: str

class Model(BaseModel):
pos: Position
events: List[Event]

model = Model(pos=('1', 2), events=[[b'1', '2', 3, 'qwe']])
assert isinstance(model.pos, Position)
assert isinstance(model.events[0], Event)
assert model.pos.x == '1'
assert model.pos == Position('1', 2)
assert model.events[0] == Event(1, 2, 3, 'qwe')
assert repr(model) == "Model(pos=Pos(x='1', y=2), events=[Event(a=1, b=2, c=3, d='qwe')])"

with pytest.raises(ValidationError) as exc_info:
Model(pos=('1', 2), events=[['qwe', '2', 3, 'qwe']])
assert exc_info.value.errors() == [
{
'loc': ('events', 0, 'a'),
'msg': 'value is not a valid integer',
'type': 'type_error.integer',
}
]


def test_typed_dict():
from typing import TypedDict

class TD(TypedDict):
a: int
b: int
c: int
d: str

class Model(BaseModel):
td: TD

m = Model(td={'a': '3', 'b': b'1', 'c': 4, 'd': 'qwe'})
assert m.td == {'a': 3, 'b': 1, 'c': 4, 'd': 'qwe'}

with pytest.raises(ValidationError) as exc_info:
Model(td={'a': [1], 'b': 2, 'c': 3, 'd': 'qwe'})
assert exc_info.value.errors() == [
{
'loc': ('td', 'a'),
'msg': 'value is not a valid integer',
'type': 'type_error.integer',
}
]

0 comments on commit da3d88f

Please sign in to comment.