Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
PrettyWood committed Mar 21, 2021
1 parent 4ec6c52 commit f0b6777
Showing 1 changed file with 124 additions and 0 deletions.
124 changes: 124 additions & 0 deletions pydantic/dataclasses2.py
@@ -0,0 +1,124 @@
"""
The main purpose is to enhance stdlib dataclasses by adding validation
We also want to keep the dataclass untouched to still support the default hashing,
equality, repr, ...
This means we **don't want to create a new dataclass that inherits from it**
To make this happen, we first attach a `BaseModel` to the dataclass
and magic methods to trigger the validation of the data.
Now the problem is: for a stdlib dataclass `Item` that now has magic attributes for pydantic
how can we have a new class `ValidatedItem` to trigger validation by default and keep `Item`
behaviour untouched!
To do this `ValidatedItem` will in fact be a wrapper around `Item`, that acts like a proxy.
This wrapper will just inject an extra kwarg `__pydantic_run_validation__` for `ValidatedItem`
and not for `Item`! (Note that this can always be injected "a la mano" if needed)
"""
from typing import Any, Callable, Dict, Optional, Type, Union

from pydantic import create_model, validate_model
from pydantic.class_validators import gather_all_validators
from pydantic.fields import Field, FieldInfo, Required, Undefined


def dataclass(
_cls: Optional[Type[Any]] = None,
*,
init: bool = True,
repr: bool = True,
eq: bool = True,
order: bool = False,
unsafe_hash: bool = False,
frozen: bool = False,
config: Type[Any] = None,
) -> Union[Callable[[Type[Any]], Type['Dataclass']], Type['Dataclass']]:
"""
Like the python standard lib dataclasses but with type validation.
Arguments are the same as for standard dataclasses, except for `validate_assignment`, which
has the same meaning as `Config.validate_assignment`.
"""

def wrap(cls: Type[Any]) -> Type['Dataclass']:
import dataclasses

if not dataclasses.is_dataclass(cls):
cls = dataclasses.dataclass(
cls, init=init, repr=repr, eq=eq, order=order, unsafe_hash=unsafe_hash, frozen=frozen
)
return WithValidationWrapper(cls, config)

if _cls is None:
return wrap

return wrap(_cls)


class WithValidationWrapper:
def __init__(self, dc, config):
if not hasattr(dc, '__pydantic_model__'):
add_pydantic_validation_attributes(dc, config)
self.dc = dc

def __getattr__(self, attr):
return getattr(self.dc, attr)

def __call__(self, *args, **kwargs):
# By default we run the validation with the wrapper but can still be overwritten
kwargs.setdefault('__pydantic_run_validation__', True)
return self.dc(*args, **kwargs)


def add_pydantic_validation_attributes(cls, config):
# We need to replace the right method. If no `__post_init__` has been set in the stdlib dataclass
# it won't even exist (code is generated on the fly by `dataclasses`)
init_or_post_init_name = '__post_init__' if hasattr(cls, '__post_init__') else '__init__'
init_or_post_init = getattr(cls, init_or_post_init_name)

def new_init_or_post_init(self, *args, __pydantic_run_validation__: bool = False, **kwargs):
init_or_post_init(self, *args, **kwargs)
if __pydantic_run_validation__:
self.__pydantic_validate_values__()

setattr(cls, init_or_post_init_name, new_init_or_post_init)
setattr(cls, '__pydantic_model__', create_pydantic_model_from_dataclass(cls, config))
setattr(cls, '__pydantic_validate_values__', dataclass_validate_values)


def create_pydantic_model_from_dataclass(cls, config=None):
import dataclasses

field_definitions: Dict[str, Any] = {}
for field in dataclasses.fields(cls):
default: Any = Undefined
default_factory = None
field_info: FieldInfo

if field.default is not dataclasses.MISSING:
default = field.default
# mypy issue 7020 and 708
elif field.default_factory is not dataclasses.MISSING: # type: ignore
default_factory = field.default_factory # type: ignore
else:
default = Required

if isinstance(default, FieldInfo):
field_info = default
cls.__has_field_info_default__ = True
else:
field_info = Field(default=default, default_factory=default_factory, **field.metadata)

field_definitions[field.name] = (field.type, field_info)

validators = gather_all_validators(cls)
return create_model(
cls.__name__, __config__=config, __module__=cls.__module__, __validators__=validators, **field_definitions
)


def dataclass_validate_values(self):
d, _, validation_error = validate_model(self.__pydantic_model__, self.__dict__, cls=self.__class__)
if validation_error:
raise validation_error
object.__setattr__(self, '__dict__', d)

0 comments on commit f0b6777

Please sign in to comment.