Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Field in dataclass + 'metadata' kwarg of dataclasses.field #2384

Merged
merged 4 commits into from Feb 25, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/2384-PrettyWood.md
@@ -0,0 +1 @@
Improve field declaration for _pydantic_ `dataclass` by allowing the usage of _pydantic_ `Field` or `'metadata'` kwarg of `dataclasses.field`
9 changes: 8 additions & 1 deletion docs/examples/dataclasses_default_schema.py
@@ -1,5 +1,7 @@
import dataclasses
from typing import List
from typing import List, Optional

from pydantic import Field
from pydantic.dataclasses import dataclass


Expand All @@ -8,6 +10,11 @@ class User:
id: int
name: str = 'John Doe'
friends: List[int] = dataclasses.field(default_factory=lambda: [0])
age: Optional[int] = dataclasses.field(
default=None,
metadata=dict(title='The age of the user', description='do not lie!')
)
height: Optional[int] = Field(None, title='The height in cm', ge=50, le=300)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

😄 👀

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wink wink



user = User(id='42')
Expand Down
65 changes: 45 additions & 20 deletions pydantic/dataclasses.py
Expand Up @@ -3,14 +3,14 @@
from .class_validators import gather_all_validators
from .error_wrappers import ValidationError
from .errors import DataclassTypeError
from .fields import Required
from .fields import Field, FieldInfo, Required, Undefined
from .main import create_model, validate_model
from .typing import resolve_annotations
from .utils import ClassAttribute

if TYPE_CHECKING:
from .main import BaseConfig, BaseModel # noqa: F401
from .typing import CallableGenerator
from .typing import CallableGenerator, NoArgAnyCallable

DataclassT = TypeVar('DataclassT', bound='Dataclass')

Expand All @@ -19,6 +19,7 @@ class Dataclass:
__initialised__: bool
__post_init_original__: Optional[Callable[..., None]]
__processed__: Optional[ClassAttribute]
__has_field_info_default__: bool # whether or not a `pydantic.Field` is used as default value

def __init__(self, *args: Any, **kwargs: Any) -> None:
pass
Expand Down Expand Up @@ -80,6 +81,30 @@ def is_builtin_dataclass(_cls: Type[Any]) -> bool:
return not hasattr(_cls, '__processed__') and dataclasses.is_dataclass(_cls)


def _generate_pydantic_post_init(
post_init_original: Optional[Callable[..., None]], post_init_post_parse: Optional[Callable[..., None]]
) -> Callable[..., None]:
def _pydantic_post_init(self: 'Dataclass', *initvars: Any) -> None:
if post_init_original is not None:
post_init_original(self, *initvars)

if getattr(self, '__has_field_info_default__', False):
# We need to remove `FieldInfo` values since they are not valid as input
# It's ok to do that because they are obviously the default values!
input_data = {k: v for k, v in self.__dict__.items() if not isinstance(v, FieldInfo)}
else:
input_data = self.__dict__
d, _, validation_error = validate_model(self.__pydantic_model__, input_data, cls=self.__class__)
if validation_error:
raise validation_error
object.__setattr__(self, '__dict__', d)
object.__setattr__(self, '__initialised__', True)
if post_init_post_parse is not None:
post_init_post_parse(self, *initvars)

return _pydantic_post_init


def _process_class(
_cls: Type[Any],
init: bool,
Expand All @@ -100,16 +125,7 @@ def _process_class(

post_init_post_parse = getattr(_cls, '__post_init_post_parse__', None)

def _pydantic_post_init(self: 'Dataclass', *initvars: Any) -> None:
if post_init_original is not None:
post_init_original(self, *initvars)
d, _, validation_error = validate_model(self.__pydantic_model__, self.__dict__, cls=self.__class__)
if validation_error:
raise validation_error
object.__setattr__(self, '__dict__', d)
object.__setattr__(self, '__initialised__', True)
if post_init_post_parse is not None:
post_init_post_parse(self, *initvars)
_pydantic_post_init = _generate_pydantic_post_init(post_init_original, post_init_post_parse)

# If the class is already a dataclass, __post_init__ will not be called automatically
# so no validation will be added.
Expand Down Expand Up @@ -144,22 +160,31 @@ def _pydantic_post_init(self: 'Dataclass', *initvars: Any) -> None:
)
cls.__processed__ = ClassAttribute('__processed__', True)

fields: Dict[str, Any] = {}
field_definitions: Dict[str, Any] = {}
for field in dataclasses.fields(cls):
default: Any = Undefined
default_factory: Optional['NoArgAnyCallable'] = None
field_info: FieldInfo

if field.default != dataclasses.MISSING:
field_value = field.default
if field.default is not dataclasses.MISSING:
default = field.default
# mypy issue 7020 and 708
elif field.default_factory != dataclasses.MISSING: # type: ignore
field_value = field.default_factory() # type: ignore
elif field.default_factory is not dataclasses.MISSING: # type: ignore
default_factory = field.default_factory # type: ignore
else:
default = Required

if isinstance(default, FieldInfo):
field_info = default
cls.__has_field_info_default__ = True
else:
field_value = Required
field_info = Field(default=default, default_factory=default_factory, **field.metadata)

fields[field.name] = (field.type, field_value)
field_definitions[field.name] = (field.type, field_info)

validators = gather_all_validators(cls)
cls.__pydantic_model__ = create_model(
cls.__name__, __config__=config, __module__=_cls.__module__, __validators__=validators, **fields
cls.__name__, __config__=config, __module__=_cls.__module__, __validators__=validators, **field_definitions
)

cls.__initialised__ = False
Expand Down
18 changes: 16 additions & 2 deletions tests/test_dataclasses.py
Expand Up @@ -429,7 +429,7 @@ class User:
assert fields['id'].default is None

assert fields['aliases'].required is False
assert fields['aliases'].default == {'John': 'Joey'}
assert fields['aliases'].default_factory() == {'John': 'Joey'}


def test_default_factory_singleton_field():
Expand All @@ -456,6 +456,10 @@ class User:
name: str = 'John Doe'
aliases: Dict[str, str] = dataclasses.field(default_factory=lambda: {'John': 'Joey'})
signup_ts: datetime = None
age: Optional[int] = dataclasses.field(
default=None, metadata=dict(title='The age of the user', description='do not lie!')
)
height: Optional[int] = pydantic.Field(None, title='The height in cm', ge=50, le=300)

user = User(id=123)
assert user.__pydantic_model__.schema() == {
Expand All @@ -466,11 +470,21 @@ class User:
'name': {'title': 'Name', 'default': 'John Doe', 'type': 'string'},
'aliases': {
'title': 'Aliases',
'default': {'John': 'Joey'},
'type': 'object',
'additionalProperties': {'type': 'string'},
},
'signup_ts': {'title': 'Signup Ts', 'type': 'string', 'format': 'date-time'},
'age': {
'title': 'The age of the user',
'description': 'do not lie!',
'type': 'integer',
},
'height': {
'title': 'The height in cm',
'minimum': 50,
'maximum': 300,
'type': 'integer',
},
},
'required': ['id'],
}
Expand Down