From 83771ef4c015572447a7d839710f5e1dab08605e Mon Sep 17 00:00:00 2001 From: Jacob Hayes Date: Tue, 1 Dec 2020 22:05:13 -0500 Subject: [PATCH] Extract Field from Annotated --- docs/examples/schema_annotated.py | 13 ++++ docs/requirements.txt | 1 + docs/usage/schema.md | 15 ++++ docs/usage/types.md | 5 ++ pydantic/fields.py | 59 ++++++++++++--- pydantic/typing.py | 45 ++++++++--- requirements.txt | 2 +- setup.py | 2 +- tests/test_annotated.py | 121 ++++++++++++++++++++++++++++++ tests/test_main.py | 23 +----- tests/test_utils.py | 4 +- 11 files changed, 241 insertions(+), 49 deletions(-) create mode 100644 docs/examples/schema_annotated.py create mode 100644 tests/test_annotated.py diff --git a/docs/examples/schema_annotated.py b/docs/examples/schema_annotated.py new file mode 100644 index 00000000000..ab29ad38a97 --- /dev/null +++ b/docs/examples/schema_annotated.py @@ -0,0 +1,13 @@ +from uuid import uuid4 + +try: + from typing import Annotated +except ImportError: + from typing_extensions import Annotated + +from pydantic import BaseModel, Field + + +class Foo(BaseModel): + id: Annotated[str, Field(default_factory=lambda: uuid4().hex)] + name: Annotated[str, Field(max_length=256)] = 'Bar' diff --git a/docs/requirements.txt b/docs/requirements.txt index 1b8ba1000e1..a6b92373815 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -6,5 +6,6 @@ mkdocs-exclude==1.0.2 mkdocs-material==6.2.3 markdown-include==0.6.0 sqlalchemy +typing-extensions==3.7.4 orjson ujson diff --git a/docs/usage/schema.md b/docs/usage/schema.md index f2f889df8b4..7ec256c9014 100644 --- a/docs/usage/schema.md +++ b/docs/usage/schema.md @@ -106,6 +106,21 @@ to `Field()` with the raw schema attribute name: ``` _(This script is complete, it should run "as is")_ +### typing.Annotated Fields + +Rather than assigning a `Field` value, it can be specified in the type hint with `typing.Annotated`: + +```py +{!.tmp_examples/schema_annotated.py!} +``` +_(This script is complete, it should run "as is")_ + +`Field` can only be supplied once per field - an error will be raised if used in `Annotated` and as the assigned value. +Defaults can be set outside `Annotated` as the assigned value or with `Field.default_factory` inside `Annotated` - the +`Field.default` argument is not supported inside `Annotated`. + +For versions of Python prior to 3.9, `typing_extensions.Annotated` can be used. + ## Modifying schema in custom fields Custom field types can customise the schema generated for them using the `__modify_schema__` class method; diff --git a/docs/usage/types.md b/docs/usage/types.md index 7f1c4d301fb..787f16f49de 100644 --- a/docs/usage/types.md +++ b/docs/usage/types.md @@ -72,6 +72,11 @@ with custom properties and validation. `typing.Any` : allows any value include `None`, thus an `Any` field is optional +`typing.Annotated` +: allows wrapping another type with arbitrary metadata, as per [PEP-593](https://www.python.org/dev/peps/pep-0593/). The + `Annotated` hint may contain a single call to the [`Field` function](schema.md#typingannotated-fields), but otherwise + the additional metadata is ignored and the root type is used. + `typing.TypeVar` : constrains the values allowed based on `constraints` or `bound`, see [TypeVar](#typevar) diff --git a/pydantic/fields.py b/pydantic/fields.py index cca10a27ba7..2e5d5c0d549 100644 --- a/pydantic/fields.py +++ b/pydantic/fields.py @@ -120,6 +120,10 @@ def __init__(self, default: Any = Undefined, **kwargs: Any) -> None: self.regex = kwargs.pop('regex', None) self.extra = kwargs + def _validate(self) -> None: + if self.default is not Undefined and self.default_factory is not None: + raise ValueError('cannot specify both default and default_factory') + def Field( default: Any = Undefined, @@ -171,10 +175,7 @@ def Field( pattern string. The schema will have a ``pattern`` validation keyword :param **extra: any additional keyword arguments will be added as is to the schema """ - if default is not Undefined and default_factory is not None: - raise ValueError('cannot specify both default and default_factory') - - return FieldInfo( + field_info = FieldInfo( default, default_factory=default_factory, alias=alias, @@ -193,6 +194,8 @@ def Field( regex=regex, **extra, ) + field_info._validate() + return field_info def Schema(default: Any, **kwargs: Any) -> Any: @@ -288,6 +291,46 @@ def __init__( def get_default(self) -> Any: return smart_deepcopy(self.default) if self.default_factory is None else self.default_factory() + @staticmethod + def _get_field_info( + field_name: str, annotation: Any, value: Any, config: Type['BaseConfig'] + ) -> Tuple[FieldInfo, Any]: + """ + Get a FieldInfo from a root typing.Annotated annotation, value, or config default. + + The FieldInfo may be set in typing.Annotated or the value, but not both. If neither contain + a FieldInfo, a new one will be created using the config. + + :param field_name: name of the field for use in error messages + :param annotation: a type hint such as `str` or `Annotated[str, Field(..., min_length=5)]` + :param value: the field's assigned value + :param config: the model's config object + :return: the FieldInfo contained in the `annotation`, the value, or a new one from the config. + """ + field_info_from_config = config.get_field_info(field_name) + + field_info = None + if get_origin(annotation) is Annotated: + field_infos = [arg for arg in get_args(annotation)[1:] if isinstance(arg, FieldInfo)] + if len(field_infos) > 1: + raise ValueError(f'cannot specify multiple `Annotated` `Field`s for {field_name!r}') + field_info = next(iter(field_infos), None) + if field_info is not None: + if field_info.default is not Undefined: + raise ValueError(f'`Field` default cannot be set in `Annotated` for {field_name!r}') + if value is not Undefined: + field_info.default = value + if isinstance(value, FieldInfo): + if field_info is not None: + raise ValueError(f'cannot specify `Annotated` and value `Field`s together for {field_name!r}') + field_info = value + if field_info is None: + field_info = FieldInfo(value, **field_info_from_config) + field_info.alias = field_info.alias or field_info_from_config.get('alias') + value = None if field_info.default_factory is not None else field_info.default + field_info._validate() + return field_info, value + @classmethod def infer( cls, @@ -298,21 +341,15 @@ def infer( class_validators: Optional[Dict[str, Validator]], config: Type['BaseConfig'], ) -> 'ModelField': - field_info_from_config = config.get_field_info(name) from .schema import get_annotation_from_field_info - if isinstance(value, FieldInfo): - field_info = value - value = None if field_info.default_factory is not None else field_info.default - else: - field_info = FieldInfo(value, **field_info_from_config) + field_info, value = cls._get_field_info(name, annotation, value, config) required: 'BoolUndefined' = Undefined if value is Required: required = True value = None elif value is not Undefined: required = False - field_info.alias = field_info.alias or field_info_from_config.get('alias') annotation = get_annotation_from_field_info(annotation, field_info, name) return cls( name=name, diff --git a/pydantic/typing.py b/pydantic/typing.py index 5cc5abbd61b..15d4eff6593 100644 --- a/pydantic/typing.py +++ b/pydantic/typing.py @@ -160,35 +160,56 @@ def get_args(tp: Type[Any]) -> Tuple[Any, ...]: return typing_get_args(tp) or getattr(tp, '__args__', ()) or generic_get_args(tp) -if sys.version_info < (3, 9): +if sys.version_info >= (3, 9): + from typing import Annotated +else: if TYPE_CHECKING: from typing_extensions import Annotated, _AnnotatedAlias else: # due to different mypy warnings raised during CI for python 3.7 and 3.8 try: - from typing_extensions import Annotated, _AnnotatedAlias + from typing_extensions import Annotated + + try: + from typing_extensions import _AnnotatedAlias + except ImportError: + # py3.6 typing_extensions doesn't have _AnnotatedAlias, but has AnnotatedMeta, which + # will satisfy our `isinstance` checks. + from typing_extensions import AnnotatedMeta as _AnnotatedAlias except ImportError: - Annotated, _AnnotatedAlias = None, None + # Create mock Annotated/_AnnotatedAlias values distinct from `None`, which is a valid + # `get_origin` return value. + class _FalseMeta(type): + # Allow short circuiting with "Annotated[...] if Annotated else None". + def __bool__(cls): + return False + + # Give a nice suggestion for unguarded use + def __getitem__(cls, key): + raise RuntimeError( + 'Annotated is not supported in this python version, please `pip install typing-extensions`.' + ) + + class Annotated(metaclass=_FalseMeta): + pass + + class _AnnotatedAlias(metaclass=_FalseMeta): + pass # Our custom get_{args,origin} for <3.8 and the builtin 3.8 get_{args,origin} don't recognize # typing_extensions.Annotated, so wrap them to short-circuit. We still want to use our wrapped # get_origins defined above for non-Annotated data. - _get_args, _get_origin = get_args, get_origin - def get_args(tp: Type[Any]) -> Type[Any]: - if _AnnotatedAlias is not None and isinstance(tp, _AnnotatedAlias): + def get_args(tp: Type[Any], _get_args=get_args) -> Type[Any]: + if isinstance(tp, _AnnotatedAlias): return tp.__args__ + tp.__metadata__ return _get_args(tp) - def get_origin(tp: Type[Any]) -> Type[Any]: - if _AnnotatedAlias is not None and isinstance(tp, _AnnotatedAlias): + def get_origin(tp: Type[Any], _get_origin=get_origin) -> Type[Any]: + if isinstance(tp, _AnnotatedAlias): return Annotated return _get_origin(tp) -else: - from typing import Annotated - - if TYPE_CHECKING: from .fields import ModelField diff --git a/requirements.txt b/requirements.txt index 065d6b6d785..25e8b907a98 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,5 +4,5 @@ Cython==0.29.21;sys_platform!='win32' devtools==0.6.1 email-validator==1.1.2 dataclasses==0.6; python_version < '3.7' -typing-extensions==3.7.4.1; python_version < '3.8' +typing-extensions==3.7.4.1; python_version < '3.9' python-dotenv==0.15.0 diff --git a/setup.py b/setup.py index 3171cb77cb1..8c5ec722782 100644 --- a/setup.py +++ b/setup.py @@ -130,7 +130,7 @@ def extra(self): ], extras_require={ 'email': ['email-validator>=1.0.3'], - 'typing_extensions': ['typing-extensions>=3.7.2'], + 'typing_extensions': ['typing-extensions>=3.7.4'], 'dotenv': ['python-dotenv>=0.10.4'], }, ext_modules=ext_modules, diff --git a/tests/test_annotated.py b/tests/test_annotated.py new file mode 100644 index 00000000000..a80c70595ab --- /dev/null +++ b/tests/test_annotated.py @@ -0,0 +1,121 @@ +import sys +from typing import get_type_hints + +import pytest + +from pydantic import BaseModel, Field +from pydantic.fields import Undefined +from pydantic.typing import Annotated + + +@pytest.mark.skipif(not Annotated, reason='typing_extensions not installed') +@pytest.mark.parametrize( + ['hint_fn', 'value'], + [ + # Test Annotated types with arbitrary metadata + pytest.param( + lambda: Annotated[int, 0], + 5, + id='misc-default', + ), + pytest.param( + lambda: Annotated[int, 0], + Field(default=5, ge=0), + id='misc-field-default-constraint', + ), + # Test valid Annotated Field uses + pytest.param( + lambda: Annotated[int, Field(description='Test')], + 5, + id='annotated-field-value-default', + ), + pytest.param( + lambda: Annotated[int, Field(default_factory=lambda: 5, description='Test')], + Undefined, + id='annotated-field-default_factory', + ), + ], +) +def test_annotated(hint_fn, value): + hint = hint_fn() + + class M(BaseModel): + x: hint = value + + assert M().x == 5 + assert M(x=10).x == 10 + + # get_type_hints doesn't recognize typing_extensions.Annotated, so will return the full + # annotation. 3.9 w/ stock Annotated will return the wrapped type by default, but return the + # full thing with the new include_extras flag. + if sys.version_info >= (3, 9): + assert get_type_hints(M)['x'] is int + assert get_type_hints(M, include_extras=True)['x'] == hint + else: + assert get_type_hints(M)['x'] == hint + + +@pytest.mark.skipif(not Annotated, reason='typing_extensions not installed') +@pytest.mark.parametrize( + ['hint_fn', 'value', 'subclass_ctx'], + [ + pytest.param( + lambda: Annotated[int, Field(5)], + Undefined, + pytest.raises(ValueError, match='`Field` default cannot be set in `Annotated`'), + id='annotated-field-default', + ), + pytest.param( + lambda: Annotated[int, Field(), Field()], + Undefined, + pytest.raises(ValueError, match='cannot specify multiple `Annotated` `Field`s'), + id='annotated-field-dup', + ), + pytest.param( + lambda: Annotated[int, Field()], + Field(), + pytest.raises(ValueError, match='cannot specify `Annotated` and value `Field`'), + id='annotated-field-value-field-dup', + ), + pytest.param( + lambda: Annotated[int, Field(default_factory=lambda: 5)], # The factory is not used + 5, + pytest.raises(ValueError, match='cannot specify both default and default_factory'), + id='annotated-field-default_factory-value-default', + ), + ], +) +def test_annotated_model_exceptions(hint_fn, value, subclass_ctx): + hint = hint_fn() + with subclass_ctx: + + class M(BaseModel): + x: hint = value + + +@pytest.mark.skipif(not Annotated, reason='typing_extensions not installed') +@pytest.mark.parametrize( + ['hint_fn', 'value', 'empty_init_ctx'], + [ + pytest.param( + lambda: Annotated[int, 0], + Undefined, + pytest.raises(ValueError, match='field required'), + id='misc-no-default', + ), + pytest.param( + lambda: Annotated[int, Field()], + Undefined, + pytest.raises(ValueError, match='field required'), + id='annotated-field-no-default', + ), + ], +) +def test_annotated_instance_exceptions(hint_fn, value, empty_init_ctx): + hint = hint_fn() + + class M(BaseModel): + x: hint = value + + with empty_init_ctx: + assert M().x == 5 diff --git a/tests/test_main.py b/tests/test_main.py index 8b8235a8774..e026f59d058 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -18,8 +18,7 @@ root_validator, validator, ) -from pydantic.fields import Undefined -from pydantic.typing import Annotated, Literal +from pydantic.typing import Literal def test_success(): @@ -1426,23 +1425,3 @@ class M(BaseModel): a: int get_type_hints(M.__config__) - - -@pytest.mark.skipif(not Annotated, reason='typing_extensions not installed') -@pytest.mark.parametrize(['value'], [(Undefined,), (Field(default=5),), (Field(default=5, ge=0),)]) -def test_annotated(value): - x_hint = Annotated[int, 5] - - class M(BaseModel): - x: x_hint = value - - assert M(x=5).x == 5 - - # get_type_hints doesn't recognize typing_extensions.Annotated, so will return the full - # annotation. 3.9 w/ stock Annotated will return the wrapped type by default, but return the - # full thing with the new include_extras flag. - if sys.version_info >= (3, 9): - assert get_type_hints(M)['x'] is int - assert get_type_hints(M, include_extras=True)['x'] == x_hint - else: - assert get_type_hints(M)['x'] == x_hint diff --git a/tests/test_utils.py b/tests/test_utils.py index 7b6b4d04cf4..f7cd1636074 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -439,7 +439,7 @@ def test_smart_deepcopy_collection(collection, mocker): @pytest.mark.parametrize( 'input_value,output_value', [ - (Annotated and Annotated[int, 10], Annotated), + (Annotated[int, 10] if Annotated else None, Annotated), (Callable[[], T][int], collections.abc.Callable), (Dict[str, int], dict), (List[str], list), @@ -465,7 +465,7 @@ def test_get_origin(input_value, output_value): (Union[int, Union[T, int], str][int], (int, str)), (Union[int, Tuple[T, int]][str], (int, Tuple[str, int])), (Callable[[], T][int], ([], int)), - (Annotated and Annotated[int, 10], (int, 10)), + (Annotated[int, 10] if Annotated else None, (int, 10)), ], ) def test_get_args(input_value, output_value):