Skip to content

Commit

Permalink
Support Field in dataclass + 'metadata' kwarg of `dataclasses.f…
Browse files Browse the repository at this point in the history
…ield` (#2384)

* Support `Field` in `dataclass` + `'metadata'` kwarg of `dataclasses.field`

Please enter the commit message for your changes. Lines starting

* add `__has_field_info_default__` for minimal effect on perf

* lower complexity of `_process_class`
  • Loading branch information
PrettyWood committed Feb 25, 2021
1 parent f32832a commit 3ec3559
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 23 deletions.
1 change: 1 addition & 0 deletions changes/2384-PrettyWood.md
@@ -0,0 +1 @@
Improve field declaration for _pydantic_ `dataclass` by allowing the usage of _pydantic_ `Field` or `'metadata'` kwarg of `dataclasses.field`
9 changes: 8 additions & 1 deletion docs/examples/dataclasses_default_schema.py
@@ -1,5 +1,7 @@
import dataclasses
from typing import List
from typing import List, Optional

from pydantic import Field
from pydantic.dataclasses import dataclass


Expand All @@ -8,6 +10,11 @@ class User:
id: int
name: str = 'John Doe'
friends: List[int] = dataclasses.field(default_factory=lambda: [0])
age: Optional[int] = dataclasses.field(
default=None,
metadata=dict(title='The age of the user', description='do not lie!')
)
height: Optional[int] = Field(None, title='The height in cm', ge=50, le=300)


user = User(id='42')
Expand Down
65 changes: 45 additions & 20 deletions pydantic/dataclasses.py
Expand Up @@ -3,14 +3,14 @@
from .class_validators import gather_all_validators
from .error_wrappers import ValidationError
from .errors import DataclassTypeError
from .fields import Required
from .fields import Field, FieldInfo, Required, Undefined
from .main import create_model, validate_model
from .typing import resolve_annotations
from .utils import ClassAttribute

if TYPE_CHECKING:
from .main import BaseConfig, BaseModel # noqa: F401
from .typing import CallableGenerator
from .typing import CallableGenerator, NoArgAnyCallable

DataclassT = TypeVar('DataclassT', bound='Dataclass')

Expand All @@ -19,6 +19,7 @@ class Dataclass:
__initialised__: bool
__post_init_original__: Optional[Callable[..., None]]
__processed__: Optional[ClassAttribute]
__has_field_info_default__: bool # whether or not a `pydantic.Field` is used as default value

def __init__(self, *args: Any, **kwargs: Any) -> None:
pass
Expand Down Expand Up @@ -80,6 +81,30 @@ def is_builtin_dataclass(_cls: Type[Any]) -> bool:
return not hasattr(_cls, '__processed__') and dataclasses.is_dataclass(_cls)


def _generate_pydantic_post_init(
post_init_original: Optional[Callable[..., None]], post_init_post_parse: Optional[Callable[..., None]]
) -> Callable[..., None]:
def _pydantic_post_init(self: 'Dataclass', *initvars: Any) -> None:
if post_init_original is not None:
post_init_original(self, *initvars)

if getattr(self, '__has_field_info_default__', False):
# We need to remove `FieldInfo` values since they are not valid as input
# It's ok to do that because they are obviously the default values!
input_data = {k: v for k, v in self.__dict__.items() if not isinstance(v, FieldInfo)}
else:
input_data = self.__dict__
d, _, validation_error = validate_model(self.__pydantic_model__, input_data, cls=self.__class__)
if validation_error:
raise validation_error
object.__setattr__(self, '__dict__', d)
object.__setattr__(self, '__initialised__', True)
if post_init_post_parse is not None:
post_init_post_parse(self, *initvars)

return _pydantic_post_init


def _process_class(
_cls: Type[Any],
init: bool,
Expand All @@ -100,16 +125,7 @@ def _process_class(

post_init_post_parse = getattr(_cls, '__post_init_post_parse__', None)

def _pydantic_post_init(self: 'Dataclass', *initvars: Any) -> None:
if post_init_original is not None:
post_init_original(self, *initvars)
d, _, validation_error = validate_model(self.__pydantic_model__, self.__dict__, cls=self.__class__)
if validation_error:
raise validation_error
object.__setattr__(self, '__dict__', d)
object.__setattr__(self, '__initialised__', True)
if post_init_post_parse is not None:
post_init_post_parse(self, *initvars)
_pydantic_post_init = _generate_pydantic_post_init(post_init_original, post_init_post_parse)

# If the class is already a dataclass, __post_init__ will not be called automatically
# so no validation will be added.
Expand Down Expand Up @@ -144,22 +160,31 @@ def _pydantic_post_init(self: 'Dataclass', *initvars: Any) -> None:
)
cls.__processed__ = ClassAttribute('__processed__', True)

fields: Dict[str, Any] = {}
field_definitions: Dict[str, Any] = {}
for field in dataclasses.fields(cls):
default: Any = Undefined
default_factory: Optional['NoArgAnyCallable'] = None
field_info: FieldInfo

if field.default != dataclasses.MISSING:
field_value = field.default
if field.default is not dataclasses.MISSING:
default = field.default
# mypy issue 7020 and 708
elif field.default_factory != dataclasses.MISSING: # type: ignore
field_value = field.default_factory() # type: ignore
elif field.default_factory is not dataclasses.MISSING: # type: ignore
default_factory = field.default_factory # type: ignore
else:
default = Required

if isinstance(default, FieldInfo):
field_info = default
cls.__has_field_info_default__ = True
else:
field_value = Required
field_info = Field(default=default, default_factory=default_factory, **field.metadata)

fields[field.name] = (field.type, field_value)
field_definitions[field.name] = (field.type, field_info)

validators = gather_all_validators(cls)
cls.__pydantic_model__ = create_model(
cls.__name__, __config__=config, __module__=_cls.__module__, __validators__=validators, **fields
cls.__name__, __config__=config, __module__=_cls.__module__, __validators__=validators, **field_definitions
)

cls.__initialised__ = False
Expand Down
18 changes: 16 additions & 2 deletions tests/test_dataclasses.py
Expand Up @@ -429,7 +429,7 @@ class User:
assert fields['id'].default is None

assert fields['aliases'].required is False
assert fields['aliases'].default == {'John': 'Joey'}
assert fields['aliases'].default_factory() == {'John': 'Joey'}


def test_default_factory_singleton_field():
Expand All @@ -456,6 +456,10 @@ class User:
name: str = 'John Doe'
aliases: Dict[str, str] = dataclasses.field(default_factory=lambda: {'John': 'Joey'})
signup_ts: datetime = None
age: Optional[int] = dataclasses.field(
default=None, metadata=dict(title='The age of the user', description='do not lie!')
)
height: Optional[int] = pydantic.Field(None, title='The height in cm', ge=50, le=300)

user = User(id=123)
assert user.__pydantic_model__.schema() == {
Expand All @@ -466,11 +470,21 @@ class User:
'name': {'title': 'Name', 'default': 'John Doe', 'type': 'string'},
'aliases': {
'title': 'Aliases',
'default': {'John': 'Joey'},
'type': 'object',
'additionalProperties': {'type': 'string'},
},
'signup_ts': {'title': 'Signup Ts', 'type': 'string', 'format': 'date-time'},
'age': {
'title': 'The age of the user',
'description': 'do not lie!',
'type': 'integer',
},
'height': {
'title': 'The height in cm',
'minimum': 50,
'maximum': 300,
'type': 'integer',
},
},
'required': ['id'],
}
Expand Down

0 comments on commit 3ec3559

Please sign in to comment.