From 3ec35590f14c1fb4b7f67c6806abb74ada144bbf Mon Sep 17 00:00:00 2001 From: Eric Jolibois Date: Thu, 25 Feb 2021 21:04:16 +0100 Subject: [PATCH] Support `Field` in `dataclass` + `'metadata'` kwarg of `dataclasses.field` (#2384) * Support `Field` in `dataclass` + `'metadata'` kwarg of `dataclasses.field` Please enter the commit message for your changes. Lines starting * add `__has_field_info_default__` for minimal effect on perf * lower complexity of `_process_class` --- changes/2384-PrettyWood.md | 1 + docs/examples/dataclasses_default_schema.py | 9 ++- pydantic/dataclasses.py | 65 ++++++++++++++------- tests/test_dataclasses.py | 18 +++++- 4 files changed, 70 insertions(+), 23 deletions(-) create mode 100644 changes/2384-PrettyWood.md diff --git a/changes/2384-PrettyWood.md b/changes/2384-PrettyWood.md new file mode 100644 index 0000000000..f2dcd54486 --- /dev/null +++ b/changes/2384-PrettyWood.md @@ -0,0 +1 @@ +Improve field declaration for _pydantic_ `dataclass` by allowing the usage of _pydantic_ `Field` or `'metadata'` kwarg of `dataclasses.field` \ No newline at end of file diff --git a/docs/examples/dataclasses_default_schema.py b/docs/examples/dataclasses_default_schema.py index 125b5e7763..c0c72f6fca 100644 --- a/docs/examples/dataclasses_default_schema.py +++ b/docs/examples/dataclasses_default_schema.py @@ -1,5 +1,7 @@ import dataclasses -from typing import List +from typing import List, Optional + +from pydantic import Field from pydantic.dataclasses import dataclass @@ -8,6 +10,11 @@ class User: id: int name: str = 'John Doe' friends: List[int] = dataclasses.field(default_factory=lambda: [0]) + age: Optional[int] = dataclasses.field( + default=None, + metadata=dict(title='The age of the user', description='do not lie!') + ) + height: Optional[int] = Field(None, title='The height in cm', ge=50, le=300) user = User(id='42') diff --git a/pydantic/dataclasses.py b/pydantic/dataclasses.py index 485437d8a4..42ae685cac 100644 --- a/pydantic/dataclasses.py +++ b/pydantic/dataclasses.py @@ -3,14 +3,14 @@ from .class_validators import gather_all_validators from .error_wrappers import ValidationError from .errors import DataclassTypeError -from .fields import Required +from .fields import Field, FieldInfo, Required, Undefined from .main import create_model, validate_model from .typing import resolve_annotations from .utils import ClassAttribute if TYPE_CHECKING: from .main import BaseConfig, BaseModel # noqa: F401 - from .typing import CallableGenerator + from .typing import CallableGenerator, NoArgAnyCallable DataclassT = TypeVar('DataclassT', bound='Dataclass') @@ -19,6 +19,7 @@ class Dataclass: __initialised__: bool __post_init_original__: Optional[Callable[..., None]] __processed__: Optional[ClassAttribute] + __has_field_info_default__: bool # whether or not a `pydantic.Field` is used as default value def __init__(self, *args: Any, **kwargs: Any) -> None: pass @@ -80,6 +81,30 @@ def is_builtin_dataclass(_cls: Type[Any]) -> bool: return not hasattr(_cls, '__processed__') and dataclasses.is_dataclass(_cls) +def _generate_pydantic_post_init( + post_init_original: Optional[Callable[..., None]], post_init_post_parse: Optional[Callable[..., None]] +) -> Callable[..., None]: + def _pydantic_post_init(self: 'Dataclass', *initvars: Any) -> None: + if post_init_original is not None: + post_init_original(self, *initvars) + + if getattr(self, '__has_field_info_default__', False): + # We need to remove `FieldInfo` values since they are not valid as input + # It's ok to do that because they are obviously the default values! + input_data = {k: v for k, v in self.__dict__.items() if not isinstance(v, FieldInfo)} + else: + input_data = self.__dict__ + d, _, validation_error = validate_model(self.__pydantic_model__, input_data, cls=self.__class__) + if validation_error: + raise validation_error + object.__setattr__(self, '__dict__', d) + object.__setattr__(self, '__initialised__', True) + if post_init_post_parse is not None: + post_init_post_parse(self, *initvars) + + return _pydantic_post_init + + def _process_class( _cls: Type[Any], init: bool, @@ -100,16 +125,7 @@ def _process_class( post_init_post_parse = getattr(_cls, '__post_init_post_parse__', None) - def _pydantic_post_init(self: 'Dataclass', *initvars: Any) -> None: - if post_init_original is not None: - post_init_original(self, *initvars) - d, _, validation_error = validate_model(self.__pydantic_model__, self.__dict__, cls=self.__class__) - if validation_error: - raise validation_error - object.__setattr__(self, '__dict__', d) - object.__setattr__(self, '__initialised__', True) - if post_init_post_parse is not None: - post_init_post_parse(self, *initvars) + _pydantic_post_init = _generate_pydantic_post_init(post_init_original, post_init_post_parse) # If the class is already a dataclass, __post_init__ will not be called automatically # so no validation will be added. @@ -144,22 +160,31 @@ def _pydantic_post_init(self: 'Dataclass', *initvars: Any) -> None: ) cls.__processed__ = ClassAttribute('__processed__', True) - fields: Dict[str, Any] = {} + field_definitions: Dict[str, Any] = {} for field in dataclasses.fields(cls): + default: Any = Undefined + default_factory: Optional['NoArgAnyCallable'] = None + field_info: FieldInfo - if field.default != dataclasses.MISSING: - field_value = field.default + if field.default is not dataclasses.MISSING: + default = field.default # mypy issue 7020 and 708 - elif field.default_factory != dataclasses.MISSING: # type: ignore - field_value = field.default_factory() # type: ignore + elif field.default_factory is not dataclasses.MISSING: # type: ignore + default_factory = field.default_factory # type: ignore + else: + default = Required + + if isinstance(default, FieldInfo): + field_info = default + cls.__has_field_info_default__ = True else: - field_value = Required + field_info = Field(default=default, default_factory=default_factory, **field.metadata) - fields[field.name] = (field.type, field_value) + field_definitions[field.name] = (field.type, field_info) validators = gather_all_validators(cls) cls.__pydantic_model__ = create_model( - cls.__name__, __config__=config, __module__=_cls.__module__, __validators__=validators, **fields + cls.__name__, __config__=config, __module__=_cls.__module__, __validators__=validators, **field_definitions ) cls.__initialised__ = False diff --git a/tests/test_dataclasses.py b/tests/test_dataclasses.py index 80d79dc97e..dde9ad4b2d 100644 --- a/tests/test_dataclasses.py +++ b/tests/test_dataclasses.py @@ -429,7 +429,7 @@ class User: assert fields['id'].default is None assert fields['aliases'].required is False - assert fields['aliases'].default == {'John': 'Joey'} + assert fields['aliases'].default_factory() == {'John': 'Joey'} def test_default_factory_singleton_field(): @@ -456,6 +456,10 @@ class User: name: str = 'John Doe' aliases: Dict[str, str] = dataclasses.field(default_factory=lambda: {'John': 'Joey'}) signup_ts: datetime = None + age: Optional[int] = dataclasses.field( + default=None, metadata=dict(title='The age of the user', description='do not lie!') + ) + height: Optional[int] = pydantic.Field(None, title='The height in cm', ge=50, le=300) user = User(id=123) assert user.__pydantic_model__.schema() == { @@ -466,11 +470,21 @@ class User: 'name': {'title': 'Name', 'default': 'John Doe', 'type': 'string'}, 'aliases': { 'title': 'Aliases', - 'default': {'John': 'Joey'}, 'type': 'object', 'additionalProperties': {'type': 'string'}, }, 'signup_ts': {'title': 'Signup Ts', 'type': 'string', 'format': 'date-time'}, + 'age': { + 'title': 'The age of the user', + 'description': 'do not lie!', + 'type': 'integer', + }, + 'height': { + 'title': 'The height in cm', + 'minimum': 50, + 'maximum': 300, + 'type': 'integer', + }, }, 'required': ['id'], }