Skip to content

Commit

Permalink
Infer root type from Annotated
Browse files Browse the repository at this point in the history
  • Loading branch information
JacobHayes committed Jan 13, 2021
1 parent 13a5c7d commit 687db0d
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 1 deletion.
5 changes: 5 additions & 0 deletions pydantic/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from .types import Json, JsonWrapper
from .typing import (
NONE_TYPES,
Annotated,
Callable,
ForwardRef,
NoArgAnyCallable,
Expand Down Expand Up @@ -424,6 +425,10 @@ def _type_analysis(self) -> None: # noqa: C901 (ignore complexity)
if isinstance(self.type_, type) and isinstance(None, self.type_):
self.allow_none = True
return
if origin is Annotated:
self.type_ = get_args(self.type_)[0]
self._type_analysis()
return
if origin is Callable:
return
if origin is Union:
Expand Down
3 changes: 3 additions & 0 deletions pydantic/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
)
from .typing import (
NONE_TYPES,
Annotated,
ForwardRef,
Literal,
get_args,
Expand Down Expand Up @@ -901,6 +902,8 @@ def go(type_: Any) -> Type[Any]:
# forward refs cause infinite recursion below
return type_

if origin is Annotated:
return go(args[0])
if origin is Union:
return Union[tuple(go(a) for a in args)] # type: ignore

Expand Down
30 changes: 30 additions & 0 deletions pydantic/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,35 @@ def get_args(tp: Type[Any]) -> Tuple[Any, ...]:
return typing_get_args(tp) or getattr(tp, '__args__', ()) or generic_get_args(tp)


if sys.version_info < (3, 9):
if TYPE_CHECKING:
from typing_extensions import Annotated, _AnnotatedAlias
else: # due to different mypy warnings raised during CI for python 3.7 and 3.8
try:
from typing_extensions import Annotated, _AnnotatedAlias
except ImportError:
Annotated, _AnnotatedAlias = None, None

# Our custom get_{args,origin} for <3.8 and the builtin 3.8 get_{args,origin} don't recognize
# typing_extensions.Annotated, so wrap them to short-circuit. We still want to use our wrapped
# get_origins defined above for non-Annotated data.
_get_args, _get_origin = get_args, get_origin

def get_args(tp: Type[Any]) -> Type[Any]:
if _AnnotatedAlias is not None and isinstance(tp, _AnnotatedAlias):
return tp.__args__ + tp.__metadata__
return _get_args(tp)

def get_origin(tp: Type[Any]) -> Type[Any]:
if _AnnotatedAlias is not None and isinstance(tp, _AnnotatedAlias):
return Annotated
return _get_origin(tp)


else:
from typing import Annotated


if TYPE_CHECKING:
from .fields import ModelField

Expand All @@ -178,6 +207,7 @@ def get_args(tp: Type[Any]) -> Tuple[Any, ...]:
__all__ = (
'ForwardRef',
'Callable',
'Annotated',
'AnyCallable',
'NoArgAnyCallable',
'NoneType',
Expand Down
23 changes: 22 additions & 1 deletion tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
root_validator,
validator,
)
from pydantic.typing import Literal
from pydantic.fields import Undefined
from pydantic.typing import Annotated, Literal


def test_success():
Expand Down Expand Up @@ -1425,3 +1426,23 @@ class M(BaseModel):
a: int

get_type_hints(M.__config__)


@pytest.mark.skipif(not Annotated, reason='typing_extensions not installed')
@pytest.mark.parametrize(['value'], [(Undefined,), (Field(default=5),), (Field(default=5, ge=0),)])
def test_annotated(value):
x_hint = Annotated[int, 5]

class M(BaseModel):
x: x_hint = value

assert M(x=5).x == 5

# get_type_hints doesn't recognize typing_extensions.Annotated, so will return the full
# annotation. 3.9 w/ stock Annotated will return the wrapped type by default, but return the
# full thing with the new include_extras flag.
if sys.version_info >= (3, 9):
assert get_type_hints(M)['x'] is int
assert get_type_hints(M, include_extras=True)['x'] == x_hint
else:
assert get_type_hints(M)['x'] == x_hint
24 changes: 24 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import collections.abc
import os
import re
import string
Expand All @@ -14,11 +15,13 @@
from pydantic.dataclasses import dataclass
from pydantic.fields import Undefined
from pydantic.typing import (
Annotated,
ForwardRef,
Literal,
all_literal_values,
display_as_type,
get_args,
get_origin,
is_new_type,
new_type_supertype,
resolve_annotations,
Expand Down Expand Up @@ -432,6 +435,24 @@ def test_smart_deepcopy_collection(collection, mocker):
T = TypeVar('T')


@pytest.mark.skipif(sys.version_info < (3, 7), reason='get_origin is only consistent for python >= 3.7')
@pytest.mark.parametrize(
'input_value,output_value',
[
(Annotated and Annotated[int, 10], Annotated),
(Callable[[], T][int], collections.abc.Callable),
(Dict[str, int], dict),
(List[str], list),
(Union[int, str], Union),
(int, None),
],
)
def test_get_origin(input_value, output_value):
if input_value is None:
pytest.skip('Skipping undefined hint for this python version')
assert get_origin(input_value) is output_value


@pytest.mark.skipif(sys.version_info < (3, 8), reason='get_args is only consistent for python >= 3.8')
@pytest.mark.parametrize(
'input_value,output_value',
Expand All @@ -444,9 +465,12 @@ def test_smart_deepcopy_collection(collection, mocker):
(Union[int, Union[T, int], str][int], (int, str)),
(Union[int, Tuple[T, int]][str], (int, Tuple[str, int])),
(Callable[[], T][int], ([], int)),
(Annotated and Annotated[int, 10], (int, 10)),
],
)
def test_get_args(input_value, output_value):
if input_value is None:
pytest.skip('Skipping undefined hint for this python version')
assert get_args(input_value) == output_value


Expand Down

0 comments on commit 687db0d

Please sign in to comment.