diff --git a/pydantic/__init__.py b/pydantic/__init__.py index 2e7aab406b3..2b8fbfa644b 100644 --- a/pydantic/__init__.py +++ b/pydantic/__init__.py @@ -1,5 +1,5 @@ # flake8: noqa -from . import dataclasses +from . import dataclasses, dataclasses2 from .annotated_types import create_model_from_namedtuple, create_model_from_typeddict from .class_validators import root_validator, validator from .decorator import validate_arguments @@ -22,6 +22,7 @@ 'create_model_from_typeddict', # dataclasses 'dataclasses', + 'dataclasses2', # class_validators 'root_validator', 'validator', diff --git a/pydantic/dataclasses2.py b/pydantic/dataclasses2.py new file mode 100644 index 00000000000..ac13c4164fb --- /dev/null +++ b/pydantic/dataclasses2.py @@ -0,0 +1,188 @@ +""" +The main purpose is to enhance stdlib dataclasses by adding validation +We also want to keep the dataclass untouched to still support the default hashing, +equality, repr, ... +This means we **don't want to create a new dataclass that inherits from it** + +To make this happen, we first attach a `BaseModel` to the dataclass +and magic methods to trigger the validation of the data. + +Now the problem is: for a stdlib dataclass `Item` that now has magic attributes for pydantic +how can we have a new class `ValidatedItem` to trigger validation by default and keep `Item` +behaviour untouched! + +To do this `ValidatedItem` will in fact be an instance of `DataclassWrapper`, a simple wrapper +around `Item` that acts like a proxy. +This wrapper will just inject an extra kwarg `__pydantic_run_validation__` for `ValidatedItem` +and not for `Item`! (Note that this can always be injected "a la mano" if needed) +""" +from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type, Union, overload + +from .class_validators import gather_all_validators +from .error_wrappers import ValidationError +from .fields import Field, FieldInfo, Required, Undefined +from .main import create_model, validate_model + +if TYPE_CHECKING: + from .main import BaseConfig, BaseModel # noqa: F401 + + class Dataclass: + # stdlib attributes + __dataclass_params__: Any # in reality `dataclasses._DataclassParams` + + # Added by pydantic + __pydantic_model__: Type[BaseModel] + __pydantic_validate_values__: Callable[['Dataclass'], None] + __has_field_info_default__: bool # whether or not a `pydantic.Field` is used as default value + + def __init__(self, *args: Any, **kwargs: Any) -> None: + pass + + +@overload +def dataclass( + *, + init: bool = True, + repr: bool = True, + eq: bool = True, + order: bool = False, + unsafe_hash: bool = False, + frozen: bool = False, + config: Type[Any] = None, +) -> Callable[[Type[Any]], 'DataclassWrapper']: + ... + + +@overload +def dataclass( + _cls: Type[Any], + *, + init: bool = True, + repr: bool = True, + eq: bool = True, + order: bool = False, + unsafe_hash: bool = False, + frozen: bool = False, + config: Type[Any] = None, +) -> 'DataclassWrapper': + ... + + +def dataclass( + _cls: Optional[Type[Any]] = None, + *, + init: bool = True, + repr: bool = True, + eq: bool = True, + order: bool = False, + unsafe_hash: bool = False, + frozen: bool = False, + config: Optional[Type['BaseConfig']] = None, +) -> Union[Callable[[Type[Any]], 'DataclassWrapper'], 'DataclassWrapper']: + """ + Like the python standard lib dataclasses but with type validation. + + Arguments are the same as for standard dataclasses, except for `validate_assignment`, which + has the same meaning as `Config.validate_assignment`. + """ + + def wrap(cls: Type[Any]) -> DataclassWrapper: + import dataclasses + + if not dataclasses.is_dataclass(cls): + cls = dataclasses.dataclass( # type: ignore + cls, init=init, repr=repr, eq=eq, order=order, unsafe_hash=unsafe_hash, frozen=frozen + ) + return DataclassWrapper(cls, config) + + if _cls is None: + return wrap + + return wrap(_cls) + + +class DataclassWrapper: + def __init__(self, dc_cls: Type['Dataclass'], config: Optional[Type['BaseConfig']]) -> None: + if not hasattr(dc_cls, '__pydantic_model__'): + add_pydantic_validation_attributes(dc_cls, config) + self.dc_cls = dc_cls + + def __getattr__(self, attr: str) -> Any: + return getattr(self.dc_cls, attr) + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + # By default we run the validation with the wrapper but can still be overwritten + kwargs.setdefault('__pydantic_run_validation__', True) + return self.dc_cls(*args, **kwargs) + + +def add_pydantic_validation_attributes(dc_cls: Type['Dataclass'], config: Optional[Type['BaseConfig']]) -> None: + # We need to replace the right method. If no `__post_init__` has been set in the stdlib dataclass + # it won't even exist (code is generated on the fly by `dataclasses`) + init_or_post_init_name = '__post_init__' if hasattr(dc_cls, '__post_init__') else '__init__' + init_or_post_init = getattr(dc_cls, init_or_post_init_name) + + def new_init_or_post_init( + self: 'Dataclass', *args: Any, __pydantic_run_validation__: bool = False, **kwargs: Any + ) -> None: + init_or_post_init(self, *args, **kwargs) + if __pydantic_run_validation__: + self.__pydantic_validate_values__() + + setattr(dc_cls, init_or_post_init_name, new_init_or_post_init) + setattr(dc_cls, '__pydantic_model__', create_pydantic_model_from_dataclass(dc_cls, config)) + setattr(dc_cls, '__pydantic_validate_values__', dataclass_validate_values) + if dc_cls.__pydantic_model__.__config__.validate_assignment and not dc_cls.__dataclass_params__.frozen: + setattr(dc_cls, '__setattr__', dataclass_validate_assignment_setattr) + + +def create_pydantic_model_from_dataclass( + dc_cls: Type['Dataclass'], config: Optional[Type['BaseConfig']] = None +) -> Type['BaseModel']: + import dataclasses + + field_definitions: Dict[str, Any] = {} + for field in dataclasses.fields(dc_cls): + default: Any = Undefined + default_factory = None + field_info: FieldInfo + + if field.default is not dataclasses.MISSING: + default = field.default + # mypy issue 7020 and 708 + elif field.default_factory is not dataclasses.MISSING: # type: ignore + default_factory = field.default_factory # type: ignore + else: + default = Required + + if isinstance(default, FieldInfo): + field_info = default + dc_cls.__has_field_info_default__ = True + else: + field_info = Field(default=default, default_factory=default_factory, **field.metadata) + + field_definitions[field.name] = (field.type, field_info) + + validators = gather_all_validators(dc_cls) + return create_model( + dc_cls.__name__, __config__=config, __module__=dc_cls.__module__, __validators__=validators, **field_definitions + ) + + +def dataclass_validate_values(self: 'Dataclass') -> None: + d, _, validation_error = validate_model(self.__pydantic_model__, self.__dict__, cls=self.__class__) + if validation_error: + raise validation_error + object.__setattr__(self, '__dict__', d) + + +def dataclass_validate_assignment_setattr(self: 'Dataclass', name: str, value: Any) -> None: + d = dict(self.__dict__) + d.pop(name, None) + known_field = self.__pydantic_model__.__fields__.get(name, None) + if known_field: + value, error_ = known_field.validate(value, d, loc=name, cls=self.__class__) + if error_: + raise ValidationError([error_], self.__class__) + + object.__setattr__(self, name, value) diff --git a/pydantic/types.py b/pydantic/types.py index 2e4eb284e31..bbbd876e54b 100644 --- a/pydantic/types.py +++ b/pydantic/types.py @@ -106,10 +106,11 @@ if TYPE_CHECKING: from .dataclasses import Dataclass # noqa: F401 + from .dataclasses2 import Dataclass as Dataclass2 # noqa: F401 from .main import BaseConfig, BaseModel # noqa: F401 from .typing import CallableGenerator - ModelOrDc = Type[Union['BaseModel', 'Dataclass']] + ModelOrDc = Type[Union['BaseModel', 'Dataclass', 'Dataclass2']] T = TypeVar('T') _DEFINED_TYPES: 'WeakSet[type]' = WeakSet()