Skip to content

Commit

Permalink
Support Field in dataclass + 'metadata' kwarg of `dataclasses.f…
Browse files Browse the repository at this point in the history
…ield`

Please enter the commit message for your changes. Lines starting
  • Loading branch information
PrettyWood committed Feb 22, 2021
1 parent b7a8ef2 commit 149b10c
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 14 deletions.
1 change: 1 addition & 0 deletions changes/2384-PrettyWood.md
@@ -0,0 +1 @@
Improve field declaration for _pydantic_ `dataclass` by allowing the usage of _pydantic_ `Field` or `'metadata'` kwarg of `dataclasses.field`
9 changes: 8 additions & 1 deletion docs/examples/dataclasses_default_schema.py
@@ -1,5 +1,7 @@
import dataclasses
from typing import List
from typing import List, Optional

from pydantic import Field
from pydantic.dataclasses import dataclass


Expand All @@ -8,6 +10,11 @@ class User:
id: int
name: str = 'John Doe'
friends: List[int] = dataclasses.field(default_factory=lambda: [0])
age: Optional[int] = dataclasses.field(
default=None,
metadata=dict(title='The age of the user', description='do not lie!')
)
height: Optional[int] = Field(None, title='The height in cm', ge=50, le=300)


user = User(id='42')
Expand Down
34 changes: 23 additions & 11 deletions pydantic/dataclasses.py
Expand Up @@ -3,14 +3,14 @@
from .class_validators import gather_all_validators
from .error_wrappers import ValidationError
from .errors import DataclassTypeError
from .fields import Required
from .fields import Field, FieldInfo, Required, Undefined
from .main import create_model, validate_model
from .typing import resolve_annotations
from .utils import ClassAttribute

if TYPE_CHECKING:
from .main import BaseConfig, BaseModel # noqa: F401
from .typing import CallableGenerator
from .typing import CallableGenerator, NoArgAnyCallable

DataclassT = TypeVar('DataclassT', bound='Dataclass')

Expand Down Expand Up @@ -103,7 +103,11 @@ def _process_class(
def _pydantic_post_init(self: 'Dataclass', *initvars: Any) -> None:
if post_init_original is not None:
post_init_original(self, *initvars)
d, _, validation_error = validate_model(self.__pydantic_model__, self.__dict__, cls=self.__class__)

# We need to remove `FieldInfo` values since they are not valid as input
# It's ok since they are obviously the default values
input_data = {k: v for k, v in self.__dict__.items() if not isinstance(v, FieldInfo)}
d, _, validation_error = validate_model(self.__pydantic_model__, input_data, cls=self.__class__)
if validation_error:
raise validation_error
object.__setattr__(self, '__dict__', d)
Expand Down Expand Up @@ -144,22 +148,30 @@ def _pydantic_post_init(self: 'Dataclass', *initvars: Any) -> None:
)
cls.__processed__ = ClassAttribute('__processed__', True)

fields: Dict[str, Any] = {}
field_definitions: Dict[str, Any] = {}
for field in dataclasses.fields(cls):
default: Any = Undefined
default_factory: Optional['NoArgAnyCallable'] = None
field_info: FieldInfo

if field.default != dataclasses.MISSING:
field_value = field.default
if field.default is not dataclasses.MISSING:
default = field.default
# mypy issue 7020 and 708
elif field.default_factory != dataclasses.MISSING: # type: ignore
field_value = field.default_factory() # type: ignore
elif field.default_factory is not dataclasses.MISSING: # type: ignore
default_factory = field.default_factory # type: ignore
else:
default = Required

if isinstance(default, FieldInfo):
field_info = default
else:
field_value = Required
field_info = Field(default=default, default_factory=default_factory, **field.metadata)

fields[field.name] = (field.type, field_value)
field_definitions[field.name] = (field.type, field_info)

validators = gather_all_validators(cls)
cls.__pydantic_model__ = create_model(
cls.__name__, __config__=config, __module__=_cls.__module__, __validators__=validators, **fields
cls.__name__, __config__=config, __module__=_cls.__module__, __validators__=validators, **field_definitions
)

cls.__initialised__ = False
Expand Down
18 changes: 16 additions & 2 deletions tests/test_dataclasses.py
Expand Up @@ -429,7 +429,7 @@ class User:
assert fields['id'].default is None

assert fields['aliases'].required is False
assert fields['aliases'].default == {'John': 'Joey'}
assert fields['aliases'].default_factory() == {'John': 'Joey'}


def test_default_factory_singleton_field():
Expand All @@ -456,6 +456,10 @@ class User:
name: str = 'John Doe'
aliases: Dict[str, str] = dataclasses.field(default_factory=lambda: {'John': 'Joey'})
signup_ts: datetime = None
age: Optional[int] = dataclasses.field(
default=None, metadata=dict(title='The age of the user', description='do not lie!')
)
height: Optional[int] = pydantic.Field(None, title='The height in cm', ge=50, le=300)

user = User(id=123)
assert user.__pydantic_model__.schema() == {
Expand All @@ -466,11 +470,21 @@ class User:
'name': {'title': 'Name', 'default': 'John Doe', 'type': 'string'},
'aliases': {
'title': 'Aliases',
'default': {'John': 'Joey'},
'type': 'object',
'additionalProperties': {'type': 'string'},
},
'signup_ts': {'title': 'Signup Ts', 'type': 'string', 'format': 'date-time'},
'age': {
'title': 'The age of the user',
'description': 'do not lie!',
'type': 'integer',
},
'height': {
'title': 'The height in cm',
'minimum': 50,
'maximum': 300,
'type': 'integer',
},
},
'required': ['id'],
}
Expand Down

0 comments on commit 149b10c

Please sign in to comment.