From f0b6777b5423bc29b96288cc2345bf84ec9d4bf8 Mon Sep 17 00:00:00 2001 From: PrettyWood Date: Sun, 21 Mar 2021 17:58:16 +0100 Subject: [PATCH] wip --- pydantic/dataclasses2.py | 124 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 124 insertions(+) create mode 100644 pydantic/dataclasses2.py diff --git a/pydantic/dataclasses2.py b/pydantic/dataclasses2.py new file mode 100644 index 00000000000..e1dc26b7502 --- /dev/null +++ b/pydantic/dataclasses2.py @@ -0,0 +1,124 @@ +""" +The main purpose is to enhance stdlib dataclasses by adding validation +We also want to keep the dataclass untouched to still support the default hashing, +equality, repr, ... +This means we **don't want to create a new dataclass that inherits from it** + +To make this happen, we first attach a `BaseModel` to the dataclass +and magic methods to trigger the validation of the data. + +Now the problem is: for a stdlib dataclass `Item` that now has magic attributes for pydantic +how can we have a new class `ValidatedItem` to trigger validation by default and keep `Item` +behaviour untouched! + +To do this `ValidatedItem` will in fact be a wrapper around `Item`, that acts like a proxy. +This wrapper will just inject an extra kwarg `__pydantic_run_validation__` for `ValidatedItem` +and not for `Item`! (Note that this can always be injected "a la mano" if needed) +""" +from typing import Any, Callable, Dict, Optional, Type, Union + +from pydantic import create_model, validate_model +from pydantic.class_validators import gather_all_validators +from pydantic.fields import Field, FieldInfo, Required, Undefined + + +def dataclass( + _cls: Optional[Type[Any]] = None, + *, + init: bool = True, + repr: bool = True, + eq: bool = True, + order: bool = False, + unsafe_hash: bool = False, + frozen: bool = False, + config: Type[Any] = None, +) -> Union[Callable[[Type[Any]], Type['Dataclass']], Type['Dataclass']]: + """ + Like the python standard lib dataclasses but with type validation. + + Arguments are the same as for standard dataclasses, except for `validate_assignment`, which + has the same meaning as `Config.validate_assignment`. + """ + + def wrap(cls: Type[Any]) -> Type['Dataclass']: + import dataclasses + + if not dataclasses.is_dataclass(cls): + cls = dataclasses.dataclass( + cls, init=init, repr=repr, eq=eq, order=order, unsafe_hash=unsafe_hash, frozen=frozen + ) + return WithValidationWrapper(cls, config) + + if _cls is None: + return wrap + + return wrap(_cls) + + +class WithValidationWrapper: + def __init__(self, dc, config): + if not hasattr(dc, '__pydantic_model__'): + add_pydantic_validation_attributes(dc, config) + self.dc = dc + + def __getattr__(self, attr): + return getattr(self.dc, attr) + + def __call__(self, *args, **kwargs): + # By default we run the validation with the wrapper but can still be overwritten + kwargs.setdefault('__pydantic_run_validation__', True) + return self.dc(*args, **kwargs) + + +def add_pydantic_validation_attributes(cls, config): + # We need to replace the right method. If no `__post_init__` has been set in the stdlib dataclass + # it won't even exist (code is generated on the fly by `dataclasses`) + init_or_post_init_name = '__post_init__' if hasattr(cls, '__post_init__') else '__init__' + init_or_post_init = getattr(cls, init_or_post_init_name) + + def new_init_or_post_init(self, *args, __pydantic_run_validation__: bool = False, **kwargs): + init_or_post_init(self, *args, **kwargs) + if __pydantic_run_validation__: + self.__pydantic_validate_values__() + + setattr(cls, init_or_post_init_name, new_init_or_post_init) + setattr(cls, '__pydantic_model__', create_pydantic_model_from_dataclass(cls, config)) + setattr(cls, '__pydantic_validate_values__', dataclass_validate_values) + + +def create_pydantic_model_from_dataclass(cls, config=None): + import dataclasses + + field_definitions: Dict[str, Any] = {} + for field in dataclasses.fields(cls): + default: Any = Undefined + default_factory = None + field_info: FieldInfo + + if field.default is not dataclasses.MISSING: + default = field.default + # mypy issue 7020 and 708 + elif field.default_factory is not dataclasses.MISSING: # type: ignore + default_factory = field.default_factory # type: ignore + else: + default = Required + + if isinstance(default, FieldInfo): + field_info = default + cls.__has_field_info_default__ = True + else: + field_info = Field(default=default, default_factory=default_factory, **field.metadata) + + field_definitions[field.name] = (field.type, field_info) + + validators = gather_all_validators(cls) + return create_model( + cls.__name__, __config__=config, __module__=cls.__module__, __validators__=validators, **field_definitions + ) + + +def dataclass_validate_values(self): + d, _, validation_error = validate_model(self.__pydantic_model__, self.__dict__, cls=self.__class__) + if validation_error: + raise validation_error + object.__setattr__(self, '__dict__', d)