From b742c6f527f32fd0c6eb96f0a35f9e632a3af6bd Mon Sep 17 00:00:00 2001 From: Jacob Hayes Date: Sat, 13 Feb 2021 11:13:21 -0500 Subject: [PATCH] Support Annotated type hints and extracting Field from Annotated (#2147) * Infer root type from Annotated * Extract Field from Annotated * Add changelog * Extend existing get_args/get_origin * Fix Annotated on py3.6 without typing-extensions * Handle Ellipsis default * Fix field reuse after FieldInfo.default mutation * Fix ci --- changes/2147-JacobHayes.md | 1 + docs/examples/schema_annotated.py | 13 +++ docs/requirements.txt | 1 + docs/usage/schema.md | 15 ++++ docs/usage/types.md | 5 ++ pydantic/fields.py | 64 +++++++++++--- pydantic/schema.py | 3 + pydantic/typing.py | 43 ++++++++++ setup.py | 2 +- tests/test_annotated.py | 136 ++++++++++++++++++++++++++++++ tests/test_utils.py | 24 ++++++ 11 files changed, 295 insertions(+), 12 deletions(-) create mode 100644 changes/2147-JacobHayes.md create mode 100644 docs/examples/schema_annotated.py create mode 100644 tests/test_annotated.py diff --git a/changes/2147-JacobHayes.md b/changes/2147-JacobHayes.md new file mode 100644 index 0000000000..322da5ffbc --- /dev/null +++ b/changes/2147-JacobHayes.md @@ -0,0 +1 @@ +Support `typing.Annotated` hints on model fields. A `Field` may now be set in the type hint with `Annotated[..., Field(...)`; all other annotations are ignored but still visible with `get_type_hints(..., include_extras=True)`. diff --git a/docs/examples/schema_annotated.py b/docs/examples/schema_annotated.py new file mode 100644 index 0000000000..ab29ad38a9 --- /dev/null +++ b/docs/examples/schema_annotated.py @@ -0,0 +1,13 @@ +from uuid import uuid4 + +try: + from typing import Annotated +except ImportError: + from typing_extensions import Annotated + +from pydantic import BaseModel, Field + + +class Foo(BaseModel): + id: Annotated[str, Field(default_factory=lambda: uuid4().hex)] + name: Annotated[str, Field(max_length=256)] = 'Bar' diff --git a/docs/requirements.txt b/docs/requirements.txt index 87e939b08a..a12fff6093 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -7,5 +7,6 @@ mkdocs-exclude==1.0.2 mkdocs-material==6.2.8 markdown-include==0.6.0 sqlalchemy +typing-extensions==3.7.4 orjson ujson diff --git a/docs/usage/schema.md b/docs/usage/schema.md index f2f889df8b..7ec256c901 100644 --- a/docs/usage/schema.md +++ b/docs/usage/schema.md @@ -106,6 +106,21 @@ to `Field()` with the raw schema attribute name: ``` _(This script is complete, it should run "as is")_ +### typing.Annotated Fields + +Rather than assigning a `Field` value, it can be specified in the type hint with `typing.Annotated`: + +```py +{!.tmp_examples/schema_annotated.py!} +``` +_(This script is complete, it should run "as is")_ + +`Field` can only be supplied once per field - an error will be raised if used in `Annotated` and as the assigned value. +Defaults can be set outside `Annotated` as the assigned value or with `Field.default_factory` inside `Annotated` - the +`Field.default` argument is not supported inside `Annotated`. + +For versions of Python prior to 3.9, `typing_extensions.Annotated` can be used. + ## Modifying schema in custom fields Custom field types can customise the schema generated for them using the `__modify_schema__` class method; diff --git a/docs/usage/types.md b/docs/usage/types.md index 39aeb92a8e..929188fffd 100644 --- a/docs/usage/types.md +++ b/docs/usage/types.md @@ -72,6 +72,11 @@ with custom properties and validation. `typing.Any` : allows any value include `None`, thus an `Any` field is optional +`typing.Annotated` +: allows wrapping another type with arbitrary metadata, as per [PEP-593](https://www.python.org/dev/peps/pep-0593/). The + `Annotated` hint may contain a single call to the [`Field` function](schema.md#typingannotated-fields), but otherwise + the additional metadata is ignored and the root type is used. + `typing.TypeVar` : constrains the values allowed based on `constraints` or `bound`, see [TypeVar](#typevar) diff --git a/pydantic/fields.py b/pydantic/fields.py index 578f2a8825..f75e16f1f2 100644 --- a/pydantic/fields.py +++ b/pydantic/fields.py @@ -29,6 +29,7 @@ from .types import Json, JsonWrapper from .typing import ( NONE_TYPES, + Annotated, Callable, ForwardRef, NoArgAnyCallable, @@ -120,6 +121,10 @@ def __init__(self, default: Any = Undefined, **kwargs: Any) -> None: self.regex = kwargs.pop('regex', None) self.extra = kwargs + def _validate(self) -> None: + if self.default not in (Undefined, Ellipsis) and self.default_factory is not None: + raise ValueError('cannot specify both default and default_factory') + def Field( default: Any = Undefined, @@ -171,10 +176,7 @@ def Field( pattern string. The schema will have a ``pattern`` validation keyword :param **extra: any additional keyword arguments will be added as is to the schema """ - if default is not Undefined and default_factory is not None: - raise ValueError('cannot specify both default and default_factory') - - return FieldInfo( + field_info = FieldInfo( default, default_factory=default_factory, alias=alias, @@ -193,6 +195,8 @@ def Field( regex=regex, **extra, ) + field_info._validate() + return field_info def Schema(default: Any, **kwargs: Any) -> Any: @@ -288,6 +292,46 @@ def __init__( def get_default(self) -> Any: return smart_deepcopy(self.default) if self.default_factory is None else self.default_factory() + @staticmethod + def _get_field_info( + field_name: str, annotation: Any, value: Any, config: Type['BaseConfig'] + ) -> Tuple[FieldInfo, Any]: + """ + Get a FieldInfo from a root typing.Annotated annotation, value, or config default. + + The FieldInfo may be set in typing.Annotated or the value, but not both. If neither contain + a FieldInfo, a new one will be created using the config. + + :param field_name: name of the field for use in error messages + :param annotation: a type hint such as `str` or `Annotated[str, Field(..., min_length=5)]` + :param value: the field's assigned value + :param config: the model's config object + :return: the FieldInfo contained in the `annotation`, the value, or a new one from the config. + """ + field_info_from_config = config.get_field_info(field_name) + + field_info = None + if get_origin(annotation) is Annotated: + field_infos = [arg for arg in get_args(annotation)[1:] if isinstance(arg, FieldInfo)] + if len(field_infos) > 1: + raise ValueError(f'cannot specify multiple `Annotated` `Field`s for {field_name!r}') + field_info = next(iter(field_infos), None) + if field_info is not None: + if field_info.default not in (Undefined, Ellipsis): + raise ValueError(f'`Field` default cannot be set in `Annotated` for {field_name!r}') + if value not in (Undefined, Ellipsis): + field_info.default = value + if isinstance(value, FieldInfo): + if field_info is not None: + raise ValueError(f'cannot specify `Annotated` and value `Field`s together for {field_name!r}') + field_info = value + if field_info is None: + field_info = FieldInfo(value, **field_info_from_config) + field_info.alias = field_info.alias or field_info_from_config.get('alias') + value = None if field_info.default_factory is not None else field_info.default + field_info._validate() + return field_info, value + @classmethod def infer( cls, @@ -298,21 +342,15 @@ def infer( class_validators: Optional[Dict[str, Validator]], config: Type['BaseConfig'], ) -> 'ModelField': - field_info_from_config = config.get_field_info(name) from .schema import get_annotation_from_field_info - if isinstance(value, FieldInfo): - field_info = value - value = None if field_info.default_factory is not None else field_info.default - else: - field_info = FieldInfo(value, **field_info_from_config) + field_info, value = cls._get_field_info(name, annotation, value, config) required: 'BoolUndefined' = Undefined if value is Required: required = True value = None elif value is not Undefined: required = False - field_info.alias = field_info.alias or field_info_from_config.get('alias') annotation = get_annotation_from_field_info(annotation, field_info, name) return cls( name=name, @@ -427,6 +465,10 @@ def _type_analysis(self) -> None: # noqa: C901 (ignore complexity) if isinstance(self.type_, type) and isinstance(None, self.type_): self.allow_none = True return + if origin is Annotated: + self.type_ = get_args(self.type_)[0] + self._type_analysis() + return if origin is Callable: return if origin is Union: diff --git a/pydantic/schema.py b/pydantic/schema.py index 6d80ea302c..e1e14912b7 100644 --- a/pydantic/schema.py +++ b/pydantic/schema.py @@ -59,6 +59,7 @@ ) from .typing import ( NONE_TYPES, + Annotated, ForwardRef, Literal, get_args, @@ -917,6 +918,8 @@ def go(type_: Any) -> Type[Any]: # forward refs cause infinite recursion below return type_ + if origin is Annotated: + return go(args[0]) if origin is Union: return Union[tuple(go(a) for a in args)] # type: ignore diff --git a/pydantic/typing.py b/pydantic/typing.py index 91cdf4e6bd..cf4f25e58f 100644 --- a/pydantic/typing.py +++ b/pydantic/typing.py @@ -80,6 +80,38 @@ def evaluate_forwardref(type_: ForwardRef, globalns: Any, localns: Any) -> Any: AnyCallable = TypingCallable[..., Any] NoArgAnyCallable = TypingCallable[[], Any] + +if sys.version_info >= (3, 9): + from typing import Annotated +else: + if TYPE_CHECKING: + from typing_extensions import Annotated + else: # due to different mypy warnings raised during CI for python 3.7 and 3.8 + try: + from typing_extensions import Annotated + except ImportError: + # Create mock Annotated values distinct from `None`, which is a valid `get_origin` + # return value. + class _FalseMeta(type): + # Allow short circuiting with "Annotated[...] if Annotated else None". + def __bool__(cls): + return False + + # Give a nice suggestion for unguarded use + def __getitem__(cls, key): + raise RuntimeError( + 'Annotated is not supported in this python version, please `pip install typing-extensions`.' + ) + + class Annotated(metaclass=_FalseMeta): + pass + + +# Annotated[...] is implemented by returning an instance of one of these classes, depending on +# python/typing_extensions version. +AnnotatedTypeNames = ('AnnotatedMeta', '_AnnotatedAlias') + + if sys.version_info < (3, 8): # noqa: C901 if TYPE_CHECKING: from typing_extensions import Literal @@ -100,6 +132,8 @@ def get_args(t: Type[Any]) -> Tuple[Any, ...]: python 3.6). """ + if Annotated and type(t).__name__ in AnnotatedTypeNames: + return t.__args__ + t.__metadata__ return getattr(t, '__args__', ()) else: @@ -111,6 +145,8 @@ def get_args(t: Type[Any]) -> Tuple[Any, ...]: Mostly compatible with the python 3.8 `typing` module version and able to handle almost all use cases. """ + if Annotated and type(t).__name__ in AnnotatedTypeNames: + return t.__args__ + t.__metadata__ if isinstance(t, _GenericAlias): res = t.__args__ if t.__origin__ is Callable and res and res[0] is not Ellipsis: @@ -119,6 +155,8 @@ def get_args(t: Type[Any]) -> Tuple[Any, ...]: return getattr(t, '__args__', ()) def get_origin(t: Type[Any]) -> Optional[Type[Any]]: + if Annotated and type(t).__name__ in AnnotatedTypeNames: + return cast(Type[Any], Annotated) # mypy complains about _SpecialForm in py3.6 return getattr(t, '__origin__', None) @@ -132,6 +170,8 @@ def get_origin(tp: Type[Any]) -> Type[Any]: It should be useless once https://github.com/cython/cython/issues/3537 is solved and https://github.com/samuelcolvin/pydantic/pull/1753 is merged. """ + if Annotated and type(tp).__name__ in AnnotatedTypeNames: + return cast(Type[Any], Annotated) # mypy complains about _SpecialForm return typing_get_origin(tp) or getattr(tp, '__origin__', None) def generic_get_args(tp: Type[Any]) -> Tuple[Any, ...]: @@ -156,6 +196,8 @@ def get_args(tp: Type[Any]) -> Tuple[Any, ...]: get_args(Union[int, Tuple[T, int]][str]) == (int, Tuple[str, int]) get_args(Callable[[], T][int]) == ([], int) """ + if Annotated and type(tp).__name__ in AnnotatedTypeNames: + return tp.__args__ + tp.__metadata__ # the fallback is needed for the same reasons as `get_origin` (see above) return typing_get_args(tp) or getattr(tp, '__args__', ()) or generic_get_args(tp) @@ -178,6 +220,7 @@ def get_args(tp: Type[Any]) -> Tuple[Any, ...]: __all__ = ( 'ForwardRef', 'Callable', + 'Annotated', 'AnyCallable', 'NoArgAnyCallable', 'NoneType', diff --git a/setup.py b/setup.py index bbed2bec9c..e6b72c4461 100644 --- a/setup.py +++ b/setup.py @@ -131,7 +131,7 @@ def extra(self): ], extras_require={ 'email': ['email-validator>=1.0.3'], - 'typing_extensions': ['typing-extensions>=3.7.2'], + 'typing_extensions': ['typing-extensions>=3.7.4'], 'dotenv': ['python-dotenv>=0.10.4'], }, ext_modules=ext_modules, diff --git a/tests/test_annotated.py b/tests/test_annotated.py new file mode 100644 index 0000000000..4aff92c77f --- /dev/null +++ b/tests/test_annotated.py @@ -0,0 +1,136 @@ +import sys +from typing import get_type_hints + +import pytest + +from pydantic import BaseModel, Field +from pydantic.fields import Undefined +from pydantic.typing import Annotated + + +@pytest.mark.skipif(not Annotated, reason='typing_extensions not installed') +@pytest.mark.parametrize( + ['hint_fn', 'value'], + [ + # Test Annotated types with arbitrary metadata + pytest.param( + lambda: Annotated[int, 0], + 5, + id='misc-default', + ), + pytest.param( + lambda: Annotated[int, 0], + Field(default=5, ge=0), + id='misc-field-default-constraint', + ), + # Test valid Annotated Field uses + pytest.param( + lambda: Annotated[int, Field(description='Test')], + 5, + id='annotated-field-value-default', + ), + pytest.param( + lambda: Annotated[int, Field(default_factory=lambda: 5, description='Test')], + Undefined, + id='annotated-field-default_factory', + ), + ], +) +def test_annotated(hint_fn, value): + hint = hint_fn() + + class M(BaseModel): + x: hint = value + + assert M().x == 5 + assert M(x=10).x == 10 + + # get_type_hints doesn't recognize typing_extensions.Annotated, so will return the full + # annotation. 3.9 w/ stock Annotated will return the wrapped type by default, but return the + # full thing with the new include_extras flag. + if sys.version_info >= (3, 9): + assert get_type_hints(M)['x'] is int + assert get_type_hints(M, include_extras=True)['x'] == hint + else: + assert get_type_hints(M)['x'] == hint + + +@pytest.mark.skipif(not Annotated, reason='typing_extensions not installed') +@pytest.mark.parametrize( + ['hint_fn', 'value', 'subclass_ctx'], + [ + pytest.param( + lambda: Annotated[int, Field(5)], + Undefined, + pytest.raises(ValueError, match='`Field` default cannot be set in `Annotated`'), + id='annotated-field-default', + ), + pytest.param( + lambda: Annotated[int, Field(), Field()], + Undefined, + pytest.raises(ValueError, match='cannot specify multiple `Annotated` `Field`s'), + id='annotated-field-dup', + ), + pytest.param( + lambda: Annotated[int, Field()], + Field(), + pytest.raises(ValueError, match='cannot specify `Annotated` and value `Field`'), + id='annotated-field-value-field-dup', + ), + pytest.param( + lambda: Annotated[int, Field(default_factory=lambda: 5)], # The factory is not used + 5, + pytest.raises(ValueError, match='cannot specify both default and default_factory'), + id='annotated-field-default_factory-value-default', + ), + ], +) +def test_annotated_model_exceptions(hint_fn, value, subclass_ctx): + hint = hint_fn() + with subclass_ctx: + + class M(BaseModel): + x: hint = value + + +@pytest.mark.skipif(not Annotated, reason='typing_extensions not installed') +@pytest.mark.parametrize( + ['hint_fn', 'value', 'empty_init_ctx'], + [ + pytest.param( + lambda: Annotated[int, 0], + Undefined, + pytest.raises(ValueError, match='field required'), + id='misc-no-default', + ), + pytest.param( + lambda: Annotated[int, Field()], + Undefined, + pytest.raises(ValueError, match='field required'), + id='annotated-field-no-default', + ), + ], +) +def test_annotated_instance_exceptions(hint_fn, value, empty_init_ctx): + hint = hint_fn() + + class M(BaseModel): + x: hint = value + + with empty_init_ctx: + assert M().x == 5 + + +@pytest.mark.skipif(not Annotated, reason='typing_extensions not installed') +def test_field_reuse(): + field = Field(description='Long description') + + class Model(BaseModel): + one: int = field + + assert Model(one=1).dict() == {'one': 1} + + class AnnotatedModel(BaseModel): + one: Annotated[int, field] + + assert AnnotatedModel(one=1).dict() == {'one': 1} diff --git a/tests/test_utils.py b/tests/test_utils.py index 1850c0f73f..f7cd163607 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,3 +1,4 @@ +import collections.abc import os import re import string @@ -14,11 +15,13 @@ from pydantic.dataclasses import dataclass from pydantic.fields import Undefined from pydantic.typing import ( + Annotated, ForwardRef, Literal, all_literal_values, display_as_type, get_args, + get_origin, is_new_type, new_type_supertype, resolve_annotations, @@ -432,6 +435,24 @@ def test_smart_deepcopy_collection(collection, mocker): T = TypeVar('T') +@pytest.mark.skipif(sys.version_info < (3, 7), reason='get_origin is only consistent for python >= 3.7') +@pytest.mark.parametrize( + 'input_value,output_value', + [ + (Annotated[int, 10] if Annotated else None, Annotated), + (Callable[[], T][int], collections.abc.Callable), + (Dict[str, int], dict), + (List[str], list), + (Union[int, str], Union), + (int, None), + ], +) +def test_get_origin(input_value, output_value): + if input_value is None: + pytest.skip('Skipping undefined hint for this python version') + assert get_origin(input_value) is output_value + + @pytest.mark.skipif(sys.version_info < (3, 8), reason='get_args is only consistent for python >= 3.8') @pytest.mark.parametrize( 'input_value,output_value', @@ -444,9 +465,12 @@ def test_smart_deepcopy_collection(collection, mocker): (Union[int, Union[T, int], str][int], (int, str)), (Union[int, Tuple[T, int]][str], (int, Tuple[str, int])), (Callable[[], T][int], ([], int)), + (Annotated[int, 10] if Annotated else None, (int, 10)), ], ) def test_get_args(input_value, output_value): + if input_value is None: + pytest.skip('Skipping undefined hint for this python version') assert get_args(input_value) == output_value