Skip to content

Commit

Permalink
Extract Field from Annotated
Browse files Browse the repository at this point in the history
  • Loading branch information
JacobHayes committed Jan 13, 2021
1 parent 687db0d commit 83771ef
Show file tree
Hide file tree
Showing 11 changed files with 241 additions and 49 deletions.
13 changes: 13 additions & 0 deletions docs/examples/schema_annotated.py
@@ -0,0 +1,13 @@
from uuid import uuid4

try:
from typing import Annotated
except ImportError:
from typing_extensions import Annotated

from pydantic import BaseModel, Field


class Foo(BaseModel):
id: Annotated[str, Field(default_factory=lambda: uuid4().hex)]
name: Annotated[str, Field(max_length=256)] = 'Bar'
1 change: 1 addition & 0 deletions docs/requirements.txt
Expand Up @@ -6,5 +6,6 @@ mkdocs-exclude==1.0.2
mkdocs-material==6.2.3
markdown-include==0.6.0
sqlalchemy
typing-extensions==3.7.4
orjson
ujson
15 changes: 15 additions & 0 deletions docs/usage/schema.md
Expand Up @@ -106,6 +106,21 @@ to `Field()` with the raw schema attribute name:
```
_(This script is complete, it should run "as is")_

### typing.Annotated Fields

Rather than assigning a `Field` value, it can be specified in the type hint with `typing.Annotated`:

```py
{!.tmp_examples/schema_annotated.py!}
```
_(This script is complete, it should run "as is")_

`Field` can only be supplied once per field - an error will be raised if used in `Annotated` and as the assigned value.
Defaults can be set outside `Annotated` as the assigned value or with `Field.default_factory` inside `Annotated` - the
`Field.default` argument is not supported inside `Annotated`.

For versions of Python prior to 3.9, `typing_extensions.Annotated` can be used.

## Modifying schema in custom fields

Custom field types can customise the schema generated for them using the `__modify_schema__` class method;
Expand Down
5 changes: 5 additions & 0 deletions docs/usage/types.md
Expand Up @@ -72,6 +72,11 @@ with custom properties and validation.
`typing.Any`
: allows any value include `None`, thus an `Any` field is optional

`typing.Annotated`
: allows wrapping another type with arbitrary metadata, as per [PEP-593](https://www.python.org/dev/peps/pep-0593/). The
`Annotated` hint may contain a single call to the [`Field` function](schema.md#typingannotated-fields), but otherwise
the additional metadata is ignored and the root type is used.

`typing.TypeVar`
: constrains the values allowed based on `constraints` or `bound`, see [TypeVar](#typevar)

Expand Down
59 changes: 48 additions & 11 deletions pydantic/fields.py
Expand Up @@ -120,6 +120,10 @@ def __init__(self, default: Any = Undefined, **kwargs: Any) -> None:
self.regex = kwargs.pop('regex', None)
self.extra = kwargs

def _validate(self) -> None:
if self.default is not Undefined and self.default_factory is not None:
raise ValueError('cannot specify both default and default_factory')


def Field(
default: Any = Undefined,
Expand Down Expand Up @@ -171,10 +175,7 @@ def Field(
pattern string. The schema will have a ``pattern`` validation keyword
:param **extra: any additional keyword arguments will be added as is to the schema
"""
if default is not Undefined and default_factory is not None:
raise ValueError('cannot specify both default and default_factory')

return FieldInfo(
field_info = FieldInfo(
default,
default_factory=default_factory,
alias=alias,
Expand All @@ -193,6 +194,8 @@ def Field(
regex=regex,
**extra,
)
field_info._validate()
return field_info


def Schema(default: Any, **kwargs: Any) -> Any:
Expand Down Expand Up @@ -288,6 +291,46 @@ def __init__(
def get_default(self) -> Any:
return smart_deepcopy(self.default) if self.default_factory is None else self.default_factory()

@staticmethod
def _get_field_info(
field_name: str, annotation: Any, value: Any, config: Type['BaseConfig']
) -> Tuple[FieldInfo, Any]:
"""
Get a FieldInfo from a root typing.Annotated annotation, value, or config default.
The FieldInfo may be set in typing.Annotated or the value, but not both. If neither contain
a FieldInfo, a new one will be created using the config.
:param field_name: name of the field for use in error messages
:param annotation: a type hint such as `str` or `Annotated[str, Field(..., min_length=5)]`
:param value: the field's assigned value
:param config: the model's config object
:return: the FieldInfo contained in the `annotation`, the value, or a new one from the config.
"""
field_info_from_config = config.get_field_info(field_name)

field_info = None
if get_origin(annotation) is Annotated:
field_infos = [arg for arg in get_args(annotation)[1:] if isinstance(arg, FieldInfo)]
if len(field_infos) > 1:
raise ValueError(f'cannot specify multiple `Annotated` `Field`s for {field_name!r}')
field_info = next(iter(field_infos), None)
if field_info is not None:
if field_info.default is not Undefined:
raise ValueError(f'`Field` default cannot be set in `Annotated` for {field_name!r}')
if value is not Undefined:
field_info.default = value
if isinstance(value, FieldInfo):
if field_info is not None:
raise ValueError(f'cannot specify `Annotated` and value `Field`s together for {field_name!r}')
field_info = value
if field_info is None:
field_info = FieldInfo(value, **field_info_from_config)
field_info.alias = field_info.alias or field_info_from_config.get('alias')
value = None if field_info.default_factory is not None else field_info.default
field_info._validate()
return field_info, value

@classmethod
def infer(
cls,
Expand All @@ -298,21 +341,15 @@ def infer(
class_validators: Optional[Dict[str, Validator]],
config: Type['BaseConfig'],
) -> 'ModelField':
field_info_from_config = config.get_field_info(name)
from .schema import get_annotation_from_field_info

if isinstance(value, FieldInfo):
field_info = value
value = None if field_info.default_factory is not None else field_info.default
else:
field_info = FieldInfo(value, **field_info_from_config)
field_info, value = cls._get_field_info(name, annotation, value, config)
required: 'BoolUndefined' = Undefined
if value is Required:
required = True
value = None
elif value is not Undefined:
required = False
field_info.alias = field_info.alias or field_info_from_config.get('alias')
annotation = get_annotation_from_field_info(annotation, field_info, name)
return cls(
name=name,
Expand Down
45 changes: 33 additions & 12 deletions pydantic/typing.py
Expand Up @@ -160,35 +160,56 @@ def get_args(tp: Type[Any]) -> Tuple[Any, ...]:
return typing_get_args(tp) or getattr(tp, '__args__', ()) or generic_get_args(tp)


if sys.version_info < (3, 9):
if sys.version_info >= (3, 9):
from typing import Annotated
else:
if TYPE_CHECKING:
from typing_extensions import Annotated, _AnnotatedAlias
else: # due to different mypy warnings raised during CI for python 3.7 and 3.8
try:
from typing_extensions import Annotated, _AnnotatedAlias
from typing_extensions import Annotated

try:
from typing_extensions import _AnnotatedAlias
except ImportError:
# py3.6 typing_extensions doesn't have _AnnotatedAlias, but has AnnotatedMeta, which
# will satisfy our `isinstance` checks.
from typing_extensions import AnnotatedMeta as _AnnotatedAlias
except ImportError:
Annotated, _AnnotatedAlias = None, None
# Create mock Annotated/_AnnotatedAlias values distinct from `None`, which is a valid
# `get_origin` return value.
class _FalseMeta(type):
# Allow short circuiting with "Annotated[...] if Annotated else None".
def __bool__(cls):
return False

# Give a nice suggestion for unguarded use
def __getitem__(cls, key):
raise RuntimeError(
'Annotated is not supported in this python version, please `pip install typing-extensions`.'
)

class Annotated(metaclass=_FalseMeta):
pass

class _AnnotatedAlias(metaclass=_FalseMeta):
pass

# Our custom get_{args,origin} for <3.8 and the builtin 3.8 get_{args,origin} don't recognize
# typing_extensions.Annotated, so wrap them to short-circuit. We still want to use our wrapped
# get_origins defined above for non-Annotated data.
_get_args, _get_origin = get_args, get_origin

def get_args(tp: Type[Any]) -> Type[Any]:
if _AnnotatedAlias is not None and isinstance(tp, _AnnotatedAlias):
def get_args(tp: Type[Any], _get_args=get_args) -> Type[Any]:
if isinstance(tp, _AnnotatedAlias):
return tp.__args__ + tp.__metadata__
return _get_args(tp)

def get_origin(tp: Type[Any]) -> Type[Any]:
if _AnnotatedAlias is not None and isinstance(tp, _AnnotatedAlias):
def get_origin(tp: Type[Any], _get_origin=get_origin) -> Type[Any]:
if isinstance(tp, _AnnotatedAlias):
return Annotated
return _get_origin(tp)


else:
from typing import Annotated


if TYPE_CHECKING:
from .fields import ModelField

Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Expand Up @@ -4,5 +4,5 @@ Cython==0.29.21;sys_platform!='win32'
devtools==0.6.1
email-validator==1.1.2
dataclasses==0.6; python_version < '3.7'
typing-extensions==3.7.4.1; python_version < '3.8'
typing-extensions==3.7.4.1; python_version < '3.9'
python-dotenv==0.15.0
2 changes: 1 addition & 1 deletion setup.py
Expand Up @@ -130,7 +130,7 @@ def extra(self):
],
extras_require={
'email': ['email-validator>=1.0.3'],
'typing_extensions': ['typing-extensions>=3.7.2'],
'typing_extensions': ['typing-extensions>=3.7.4'],
'dotenv': ['python-dotenv>=0.10.4'],
},
ext_modules=ext_modules,
Expand Down
121 changes: 121 additions & 0 deletions tests/test_annotated.py
@@ -0,0 +1,121 @@
import sys
from typing import get_type_hints

import pytest

from pydantic import BaseModel, Field
from pydantic.fields import Undefined
from pydantic.typing import Annotated


@pytest.mark.skipif(not Annotated, reason='typing_extensions not installed')
@pytest.mark.parametrize(
['hint_fn', 'value'],
[
# Test Annotated types with arbitrary metadata
pytest.param(
lambda: Annotated[int, 0],
5,
id='misc-default',
),
pytest.param(
lambda: Annotated[int, 0],
Field(default=5, ge=0),
id='misc-field-default-constraint',
),
# Test valid Annotated Field uses
pytest.param(
lambda: Annotated[int, Field(description='Test')],
5,
id='annotated-field-value-default',
),
pytest.param(
lambda: Annotated[int, Field(default_factory=lambda: 5, description='Test')],
Undefined,
id='annotated-field-default_factory',
),
],
)
def test_annotated(hint_fn, value):
hint = hint_fn()

class M(BaseModel):
x: hint = value

assert M().x == 5
assert M(x=10).x == 10

# get_type_hints doesn't recognize typing_extensions.Annotated, so will return the full
# annotation. 3.9 w/ stock Annotated will return the wrapped type by default, but return the
# full thing with the new include_extras flag.
if sys.version_info >= (3, 9):
assert get_type_hints(M)['x'] is int
assert get_type_hints(M, include_extras=True)['x'] == hint
else:
assert get_type_hints(M)['x'] == hint


@pytest.mark.skipif(not Annotated, reason='typing_extensions not installed')
@pytest.mark.parametrize(
['hint_fn', 'value', 'subclass_ctx'],
[
pytest.param(
lambda: Annotated[int, Field(5)],
Undefined,
pytest.raises(ValueError, match='`Field` default cannot be set in `Annotated`'),
id='annotated-field-default',
),
pytest.param(
lambda: Annotated[int, Field(), Field()],
Undefined,
pytest.raises(ValueError, match='cannot specify multiple `Annotated` `Field`s'),
id='annotated-field-dup',
),
pytest.param(
lambda: Annotated[int, Field()],
Field(),
pytest.raises(ValueError, match='cannot specify `Annotated` and value `Field`'),
id='annotated-field-value-field-dup',
),
pytest.param(
lambda: Annotated[int, Field(default_factory=lambda: 5)], # The factory is not used
5,
pytest.raises(ValueError, match='cannot specify both default and default_factory'),
id='annotated-field-default_factory-value-default',
),
],
)
def test_annotated_model_exceptions(hint_fn, value, subclass_ctx):
hint = hint_fn()
with subclass_ctx:

class M(BaseModel):
x: hint = value


@pytest.mark.skipif(not Annotated, reason='typing_extensions not installed')
@pytest.mark.parametrize(
['hint_fn', 'value', 'empty_init_ctx'],
[
pytest.param(
lambda: Annotated[int, 0],
Undefined,
pytest.raises(ValueError, match='field required'),
id='misc-no-default',
),
pytest.param(
lambda: Annotated[int, Field()],
Undefined,
pytest.raises(ValueError, match='field required'),
id='annotated-field-no-default',
),
],
)
def test_annotated_instance_exceptions(hint_fn, value, empty_init_ctx):
hint = hint_fn()

class M(BaseModel):
x: hint = value

with empty_init_ctx:
assert M().x == 5
23 changes: 1 addition & 22 deletions tests/test_main.py
Expand Up @@ -18,8 +18,7 @@
root_validator,
validator,
)
from pydantic.fields import Undefined
from pydantic.typing import Annotated, Literal
from pydantic.typing import Literal


def test_success():
Expand Down Expand Up @@ -1426,23 +1425,3 @@ class M(BaseModel):
a: int

get_type_hints(M.__config__)


@pytest.mark.skipif(not Annotated, reason='typing_extensions not installed')
@pytest.mark.parametrize(['value'], [(Undefined,), (Field(default=5),), (Field(default=5, ge=0),)])
def test_annotated(value):
x_hint = Annotated[int, 5]

class M(BaseModel):
x: x_hint = value

assert M(x=5).x == 5

# get_type_hints doesn't recognize typing_extensions.Annotated, so will return the full
# annotation. 3.9 w/ stock Annotated will return the wrapped type by default, but return the
# full thing with the new include_extras flag.
if sys.version_info >= (3, 9):
assert get_type_hints(M)['x'] is int
assert get_type_hints(M, include_extras=True)['x'] == x_hint
else:
assert get_type_hints(M)['x'] == x_hint

0 comments on commit 83771ef

Please sign in to comment.