diff --git a/changes/2557-PrettyWood.md b/changes/2557-PrettyWood.md new file mode 100644 index 0000000000..e9cb1c52a9 --- /dev/null +++ b/changes/2557-PrettyWood.md @@ -0,0 +1,9 @@ +Refactor the whole _pydantic_ `dataclass` decorator to really act like its standard lib equivalent. +It hence keeps `__eq__`, `__hash__`, ... and makes comparison with its non-validated version possible. +It also fixes usage of `frozen` dataclasses in fields and usage of `default_factory` in nested dataclasses. +The support of `Config.extra` has been added. +Finally, config customization directly via a `dict` is now possible. +

+**BREAKING CHANGES** +- The `compiled` boolean (whether _pydantic_ is compiled with cython) has been moved from `main.py` to `version.py` +- Now that `Config.extra` is supported, `dataclass` ignores by default extra arguments (like `BaseModel`) diff --git a/docs/examples/dataclasses_config.py b/docs/examples/dataclasses_config.py new file mode 100644 index 0000000000..d7c4de52de --- /dev/null +++ b/docs/examples/dataclasses_config.py @@ -0,0 +1,26 @@ +from pydantic import ConfigDict +from pydantic.dataclasses import dataclass + + +# Option 1 - use directly a dict +# Note: `mypy` will still raise typo error +@dataclass(config=dict(validate_assignment=True)) +class MyDataclass1: + a: int + + +# Option 2 - use `ConfigDict` +# (same as before at runtime since it's a `TypedDict` but with intellisense) +@dataclass(config=ConfigDict(validate_assignment=True)) +class MyDataclass2: + a: int + + +# Option 3 - use a `Config` class like for a `BaseModel` +class Config: + validate_assignment = True + + +@dataclass(config=Config) +class MyDataclass3: + a: int diff --git a/docs/examples/dataclasses_stdlib_run_validation.py b/docs/examples/dataclasses_stdlib_run_validation.py new file mode 100644 index 0000000000..6d13c0af32 --- /dev/null +++ b/docs/examples/dataclasses_stdlib_run_validation.py @@ -0,0 +1,30 @@ +import dataclasses + +from pydantic import ValidationError +from pydantic.dataclasses import dataclass as pydantic_dataclass, set_validation + + +@dataclasses.dataclass +class User: + id: int + name: str + + +# Enhance stdlib dataclass +pydantic_dataclass(User) + + +user1 = User(id='whatever', name='I want') + +# validate data of `user1` +try: + user1.__pydantic_validate_values__() +except ValidationError as e: + print(e) + +# Enforce validation +try: + with set_validation(User, True): + User(id='whatever', name='I want') +except ValidationError as e: + print(e) diff --git a/docs/examples/dataclasses_stdlib_to_pydantic.py b/docs/examples/dataclasses_stdlib_to_pydantic.py index 88686b8df2..a3a04f62ce 100644 --- a/docs/examples/dataclasses_stdlib_to_pydantic.py +++ b/docs/examples/dataclasses_stdlib_to_pydantic.py @@ -16,20 +16,32 @@ class File(Meta): filename: str -File = pydantic.dataclasses.dataclass(File) +# `ValidatedFile` will be a proxy around `File` +ValidatedFile = pydantic.dataclasses.dataclass(File) -file = File( +# the original dataclass is the `__dataclass__` attribute +assert ValidatedFile.__dataclass__ is File + + +validated_file = ValidatedFile( filename=b'thefilename', modified_date='2020-01-01T00:00', seen_count='7', ) -print(file) +print(validated_file) try: - File( + ValidatedFile( filename=['not', 'a', 'string'], modified_date=None, seen_count=3, ) except pydantic.ValidationError as e: print(e) + +# `File` is not altered and still does no validation by default +print(File( + filename=['not', 'a', 'string'], + modified_date=None, + seen_count=3, +)) diff --git a/docs/usage/dataclasses.md b/docs/usage/dataclasses.md index a2bc2227cf..9bf09399a8 100644 --- a/docs/usage/dataclasses.md +++ b/docs/usage/dataclasses.md @@ -18,7 +18,7 @@ You can use all the standard _pydantic_ field types, and the resulting dataclass created by the standard library `dataclass` decorator. The underlying model and its schema can be accessed through `__pydantic_model__`. -Also, fields that require a `default_factory` can be specified by a `dataclasses.field`. +Also, fields that require a `default_factory` can be specified by either a `pydantic.Field` or a `dataclasses.field`. ```py {!.tmp_examples/dataclasses_default_schema.py!} @@ -34,6 +34,20 @@ keyword argument `config` which has the same meaning as [Config](model_config.md For more information about combining validators with dataclasses, see [dataclass validators](validators.md#dataclass-validators). +## Dataclass Config + +If you want to modify the `Config` like you would with a `BaseModel`, you have three options: + +```py +{!.tmp_examples/dataclasses_config.py!} +``` + +!!! warning + After v1.10, _pydantic_ dataclasses support `Config.extra` but some default behaviour of stdlib dataclasses + may prevail. For example, when `print`ing a _pydantic_ dataclass with allowed extra fields, it will still + use the `__str__` method of stdlib dataclass and show only the required fields. + This may be improved further in the future. + ## Nested dataclasses Nested dataclasses are supported both in dataclasses and normal models. @@ -51,12 +65,25 @@ Dataclasses attributes can be populated by tuples, dictionaries or instances of Stdlib dataclasses (nested or not) can be easily converted into _pydantic_ dataclasses by just decorating them with `pydantic.dataclasses.dataclass`. +_Pydantic_ will enhance the given stdlib dataclass but won't alter the default behaviour (i.e. without validation). +It will instead create a wrapper around it to trigger validation that will act like a plain proxy. +The stdlib dataclass can still be accessed via the `__dataclass__` attribute (see example below). ```py {!.tmp_examples/dataclasses_stdlib_to_pydantic.py!} ``` _(This script is complete, it should run "as is")_ +### Choose when to trigger validation + +As soon as your stdlib dataclass has been decorated with _pydantic_ dataclass decorator, magic methods have been +added to validate input data. If you want, you can still keep using your dataclass and choose when to trigger it. + +```py +{!.tmp_examples/dataclasses_stdlib_run_validation.py!} +``` +_(This script is complete, it should run "as is")_ + ### Inherit from stdlib dataclasses Stdlib dataclasses (nested or not) can also be inherited and _pydantic_ will automatically validate @@ -95,6 +122,11 @@ When you initialize a dataclass, it is possible to execute code *after* validati with the help of `__post_init_post_parse__`. This is not the same as `__post_init__`, which executes code *before* validation. +!!! tip + If you use a stdlib `dataclass`, you may only have `__post_init__` available and wish the validation to + be done before. In this case you can set `Config.post_init_call = 'after_validation'` + + ```py {!.tmp_examples/dataclasses_post_init_post_parse.py!} ``` diff --git a/docs/usage/model_config.md b/docs/usage/model_config.md index 1f1d3f97be..b6c15e06c1 100644 --- a/docs/usage/model_config.md +++ b/docs/usage/model_config.md @@ -118,6 +118,10 @@ not be included in the model schemas. **Note**: this means that attributes on th **`smart_union`** : whether _pydantic_ should try to check all types inside `Union` to prevent undesired coercion; see [the dedicated section](#smart-union) +**`post_init_call`** +: whether stdlib dataclasses `__post_init__` should be run before (default behaviour with value `'before_validation'`) + or after (value `'after_validation'`) parsing and validation when they are [converted](dataclasses.md#stdlib-dataclasses-and-_pydantic_-dataclasses). + ## Change behaviour globally If you wish to change the behaviour of _pydantic_ globally, you can create your own custom `BaseModel` diff --git a/pydantic/__init__.py b/pydantic/__init__.py index 982ea4755c..7c79fd1c87 100644 --- a/pydantic/__init__.py +++ b/pydantic/__init__.py @@ -2,7 +2,7 @@ from . import dataclasses from .annotated_types import create_model_from_namedtuple, create_model_from_typeddict from .class_validators import root_validator, validator -from .config import BaseConfig, Extra +from .config import BaseConfig, ConfigDict, Extra from .decorator import validate_arguments from .env_settings import BaseSettings from .error_wrappers import ValidationError @@ -13,7 +13,7 @@ from .parse import Protocol from .tools import * from .types import * -from .version import VERSION +from .version import VERSION, compiled __version__ = VERSION @@ -30,6 +30,7 @@ 'validator', # config 'BaseConfig', + 'ConfigDict', 'Extra', # decorator 'validate_arguments', @@ -42,7 +43,6 @@ 'Required', # main 'BaseModel', - 'compiled', 'create_model', 'validate_model', # network @@ -120,5 +120,6 @@ 'PastDate', 'FutureDate', # version + 'compiled', 'VERSION', ] diff --git a/pydantic/config.py b/pydantic/config.py index ef4b3c008f..adb9fd4b21 100644 --- a/pydantic/config.py +++ b/pydantic/config.py @@ -4,6 +4,7 @@ from .typing import AnyCallable from .utils import GetterDict +from .version import compiled if TYPE_CHECKING: from typing import overload @@ -27,7 +28,7 @@ def __call__(self, schema: Dict[str, Any], model_class: Type[BaseModel]) -> None else: SchemaExtraCallable = Callable[..., None] -__all__ = 'BaseConfig', 'Extra', 'inherit_config', 'prepare_config' +__all__ = 'BaseConfig', 'ConfigDict', 'get_config', 'Extra', 'inherit_config', 'prepare_config' class Extra(str, Enum): @@ -36,6 +37,46 @@ class Extra(str, Enum): forbid = 'forbid' +# https://github.com/cython/cython/issues/4003 +# Will be fixed with Cython 3 but still in alpha right now +if not compiled: + from typing_extensions import Literal, TypedDict + + class ConfigDict(TypedDict, total=False): + title: Optional[str] + anystr_lower: bool + anystr_strip_whitespace: bool + min_anystr_length: int + max_anystr_length: Optional[int] + validate_all: bool + extra: Extra + allow_mutation: bool + frozen: bool + allow_population_by_field_name: bool + use_enum_values: bool + fields: Dict[str, Union[str, Dict[str, str]]] + validate_assignment: bool + error_msg_templates: Dict[str, str] + arbitrary_types_allowed: bool + orm_mode: bool + getter_dict: Type[GetterDict] + alias_generator: Optional[Callable[[str], str]] + keep_untouched: Tuple[type, ...] + schema_extra: Union[Dict[str, Any], 'SchemaExtraCallable'] + json_loads: Callable[[str], Any] + json_dumps: Callable[..., str] + json_encoders: Dict[Type[Any], AnyCallable] + underscore_attrs_are_private: bool + + # whether or not inherited models as fields should be reconstructed as base model + copy_on_model_validation: bool + # whether dataclass `__post_init__` should be run after validation + post_init_call: Literal['before_validation', 'after_validation'] + +else: + ConfigDict = dict # type: ignore + + class BaseConfig: title: Optional[str] = None anystr_lower: bool = False @@ -66,6 +107,8 @@ class BaseConfig: copy_on_model_validation: bool = True # whether `Union` should check all allowed types before even trying to coerce smart_union: bool = False + # whether dataclass `__post_init__` should be run before or after validation + post_init_call: Literal['before_validation', 'after_validation'] = 'before_validation' @classmethod def get_field_info(cls, name: str) -> Dict[str, Any]: @@ -100,6 +143,25 @@ def prepare_field(cls, field: 'ModelField') -> None: pass +def get_config(config: Union[ConfigDict, Type[BaseConfig], None]) -> Type[BaseConfig]: + if config is None: + return BaseConfig + + else: + config_dict = ( + config + if isinstance(config, dict) + else {k: getattr(config, k) for k in dir(config) if not k.startswith('__')} + ) + + class Config(BaseConfig): + ... + + for k, v in config_dict.items(): + setattr(Config, k, v) + return Config + + def inherit_config(self_config: 'ConfigType', parent_config: 'ConfigType', **namespace: Any) -> 'ConfigType': if not self_config: base_classes: Tuple['ConfigType', ...] = (parent_config,) diff --git a/pydantic/dataclasses.py b/pydantic/dataclasses.py index 692bfb9f3a..8bc12561d7 100644 --- a/pydantic/dataclasses.py +++ b/pydantic/dataclasses.py @@ -1,26 +1,69 @@ -from typing import TYPE_CHECKING, Any, Callable, ClassVar, Dict, Optional, Type, TypeVar, Union, overload +""" +The main purpose is to enhance stdlib dataclasses by adding validation +A pydantic dataclass can be generated from scratch or from a stdlib one. + +Behind the scene, a pydantic dataclass is just like a regular one on which we attach +a `BaseModel` and magic methods to trigger the validation of the data. +`__init__` and `__post_init__` are hence overridden and have extra logic to be +able to validate input data. + +When a pydantic dataclass is generated from scratch, it's just a plain dataclass +with validation triggered at initialization + +The tricky part if for stdlib dataclasses that are converted after into pydantic ones e.g. + +```py +@dataclasses.dataclass +class M: + x: int + +ValidatedM = pydantic.dataclasses.dataclass(M) +``` + +We indeed still want to support equality, hashing, repr, ... as if it was the stdlib one! + +```py +assert isinstance(ValidatedM(x=1), M) +assert ValidatedM(x=1) == M(x=1) +``` + +This means we **don't want to create a new dataclass that inherits from it** +The trick is to create a wrapper around `M` that will act as a proxy to trigger +validation without altering default `M` behaviour. +""" +from contextlib import contextmanager +from functools import wraps +from typing import TYPE_CHECKING, Any, Callable, ClassVar, Dict, Generator, Optional, Type, TypeVar, Union, overload from .class_validators import gather_all_validators +from .config import BaseConfig, ConfigDict, Extra, get_config from .error_wrappers import ValidationError from .errors import DataclassTypeError from .fields import Field, FieldInfo, Required, Undefined from .main import __dataclass_transform__, create_model, validate_model -from .typing import resolve_annotations from .utils import ClassAttribute if TYPE_CHECKING: - from .config import BaseConfig from .main import BaseModel from .typing import CallableGenerator, NoArgAnyCallable DataclassT = TypeVar('DataclassT', bound='Dataclass') + DataclassClassOrWrapper = Union[Type['Dataclass'], 'DataclassProxy'] + class Dataclass: + # stdlib attributes + __dataclass_fields__: ClassVar[Dict[str, Any]] + __dataclass_params__: ClassVar[Any] # in reality `dataclasses._DataclassParams` + __post_init__: ClassVar[Callable[..., None]] + + # Added by pydantic + __pydantic_run_validation__: ClassVar[bool] + __post_init_post_parse__: ClassVar[Callable[..., None]] + __pydantic_initialised__: ClassVar[bool] __pydantic_model__: ClassVar[Type[BaseModel]] - __initialised__: ClassVar[bool] - __post_init_original__: ClassVar[Optional[Callable[..., None]]] - __processed__: ClassVar[Optional[ClassAttribute]] - __has_field_info_default__: ClassVar[bool] # whether or not a `pydantic.Field` is used as default value + __pydantic_validate_values__: ClassVar[Callable[['Dataclass'], None]] + __pydantic_has_field_info_default__: ClassVar[bool] # whether a `pydantic.Field` is used as default value def __init__(self, *args: Any, **kwargs: Any) -> None: pass @@ -33,136 +76,233 @@ def __get_validators__(cls: Type['Dataclass']) -> 'CallableGenerator': def __validate__(cls: Type['DataclassT'], v: Any) -> 'DataclassT': pass - def __call__(self: 'DataclassT', *args: Any, **kwargs: Any) -> 'DataclassT': - pass +__all__ = [ + 'dataclass', + 'set_validation', + 'create_pydantic_model_from_dataclass', + 'is_builtin_dataclass', + 'make_dataclass_validator', +] -def _validate_dataclass(cls: Type['DataclassT'], v: Any) -> 'DataclassT': - if isinstance(v, cls): - return v - elif isinstance(v, (list, tuple)): - return cls(*v) - elif isinstance(v, dict): - return cls(**v) - # In nested dataclasses, v can be of type `dataclasses.dataclass`. - # But to validate fields `cls` will be in fact a `pydantic.dataclasses.dataclass`, - # which inherits directly from the class of `v`. - elif is_builtin_dataclass(v) and cls.__bases__[0] is type(v): + +@__dataclass_transform__(kw_only_default=True, field_descriptors=(Field, FieldInfo)) +@overload +def dataclass( + *, + init: bool = True, + repr: bool = True, + eq: bool = True, + order: bool = False, + unsafe_hash: bool = False, + frozen: bool = False, + config: Union[ConfigDict, Type[Any], None] = None, + validate_on_init: Optional[bool] = None, +) -> Callable[[Type[Any]], 'DataclassClassOrWrapper']: + ... + + +@__dataclass_transform__(kw_only_default=True, field_descriptors=(Field, FieldInfo)) +@overload +def dataclass( + _cls: Type[Any], + *, + init: bool = True, + repr: bool = True, + eq: bool = True, + order: bool = False, + unsafe_hash: bool = False, + frozen: bool = False, + config: Union[ConfigDict, Type[Any], None] = None, + validate_on_init: Optional[bool] = None, +) -> 'DataclassClassOrWrapper': + ... + + +@__dataclass_transform__(kw_only_default=True, field_descriptors=(Field, FieldInfo)) +def dataclass( + _cls: Optional[Type[Any]] = None, + *, + init: bool = True, + repr: bool = True, + eq: bool = True, + order: bool = False, + unsafe_hash: bool = False, + frozen: bool = False, + config: Union[ConfigDict, Type[Any], None] = None, + validate_on_init: Optional[bool] = None, +) -> Union[Callable[[Type[Any]], 'DataclassClassOrWrapper'], 'DataclassClassOrWrapper']: + """ + Like the python standard lib dataclasses but with type validation. + The result is either a pydantic dataclass that will validate input data + or a wrapper that will trigger validation around a stdlib dataclass + to avoid modifying it directly + """ + the_config = get_config(config) + + def wrap(cls: Type[Any]) -> 'DataclassClassOrWrapper': import dataclasses - return cls(**dataclasses.asdict(v)) - else: - raise DataclassTypeError(class_name=cls.__name__) + if is_builtin_dataclass(cls): + dc_cls_doc = '' + dc_cls = DataclassProxy(cls) + default_validate_on_init = False + else: + dc_cls_doc = cls.__doc__ or '' # needs to be done before generating dataclass + dc_cls = dataclasses.dataclass( # type: ignore + cls, init=init, repr=repr, eq=eq, order=order, unsafe_hash=unsafe_hash, frozen=frozen + ) + default_validate_on_init = True + should_validate_on_init = default_validate_on_init if validate_on_init is None else validate_on_init + _add_pydantic_validation_attributes(cls, the_config, should_validate_on_init, dc_cls_doc) + dc_cls.__pydantic_model__.__try_update_forward_refs__(**{cls.__name__: cls}) + return dc_cls -def _get_validators(cls: Type['Dataclass']) -> 'CallableGenerator': - yield cls.__validate__ + if _cls is None: + return wrap + return wrap(_cls) -def setattr_validate_assignment(self: 'Dataclass', name: str, value: Any) -> None: - if self.__initialised__: - d = dict(self.__dict__) - d.pop(name, None) - known_field = self.__pydantic_model__.__fields__.get(name, None) - if known_field: - value, error_ = known_field.validate(value, d, loc=name, cls=self.__class__) - if error_: - raise ValidationError([error_], self.__class__) - object.__setattr__(self, name, value) +@contextmanager +def set_validation(cls: Type['DataclassT'], value: bool) -> Generator[Type['DataclassT'], None, None]: + original_run_validation = cls.__pydantic_run_validation__ + try: + cls.__pydantic_run_validation__ = value + yield cls + finally: + cls.__pydantic_run_validation__ = original_run_validation -def is_builtin_dataclass(_cls: Type[Any]) -> bool: +class DataclassProxy: + __slots__ = '__dataclass__' + + def __init__(self, dc_cls: Type['Dataclass']) -> None: + object.__setattr__(self, '__dataclass__', dc_cls) + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + with set_validation(self.__dataclass__, True): + return self.__dataclass__(*args, **kwargs) + + def __getattr__(self, name: str) -> Any: + return getattr(self.__dataclass__, name) + + def __instancecheck__(self, instance: Any) -> bool: + return isinstance(instance, self.__dataclass__) + + +def _add_pydantic_validation_attributes( # noqa: C901 (ignore complexity) + dc_cls: Type['Dataclass'], + config: Type[BaseConfig], + validate_on_init: bool, + dc_cls_doc: str, +) -> None: """ - `dataclasses.is_dataclass` is True if one of the class parents is a `dataclass`. - This is why we also add a class attribute `__processed__` to only consider 'direct' built-in dataclasses + We need to replace the right method. If no `__post_init__` has been set in the stdlib dataclass + it won't even exist (code is generated on the fly by `dataclasses`) + By default, we run validation after `__init__` or `__post_init__` if defined """ - import dataclasses + init = dc_cls.__init__ - return not hasattr(_cls, '__processed__') and dataclasses.is_dataclass(_cls) + @wraps(init) + def handle_extra_init(self: 'Dataclass', *args: Any, **kwargs: Any) -> None: + if config.extra == Extra.ignore: + init(self, *args, **{k: v for k, v in kwargs.items() if k in self.__dataclass_fields__}) + elif config.extra == Extra.allow: + for k, v in kwargs.items(): + self.__dict__.setdefault(k, v) + init(self, *args, **{k: v for k, v in kwargs.items() if k in self.__dataclass_fields__}) -def _generate_pydantic_post_init( - post_init_original: Optional[Callable[..., None]], post_init_post_parse: Optional[Callable[..., None]] -) -> Callable[..., None]: - def _pydantic_post_init(self: 'Dataclass', *initvars: Any) -> None: - if post_init_original is not None: - post_init_original(self, *initvars) - - if getattr(self, '__has_field_info_default__', False): - # We need to remove `FieldInfo` values since they are not valid as input - # It's ok to do that because they are obviously the default values! - input_data = {k: v for k, v in self.__dict__.items() if not isinstance(v, FieldInfo)} else: - input_data = self.__dict__ - d, _, validation_error = validate_model(self.__pydantic_model__, input_data, cls=self.__class__) - if validation_error: - raise validation_error - object.__setattr__(self, '__dict__', {**getattr(self, '__dict__', {}), **d}) - object.__setattr__(self, '__initialised__', True) - if post_init_post_parse is not None: - post_init_post_parse(self, *initvars) + init(self, *args, **kwargs) - return _pydantic_post_init + if hasattr(dc_cls, '__post_init__'): + post_init = dc_cls.__post_init__ + @wraps(post_init) + def new_post_init(self: 'Dataclass', *args: Any, **kwargs: Any) -> None: + if config.post_init_call == 'before_validation': + post_init(self, *args, **kwargs) -def _process_class( - _cls: Type[Any], - init: bool, - repr: bool, - eq: bool, - order: bool, - unsafe_hash: bool, - frozen: bool, - config: Optional[Type[Any]], -) -> Type['Dataclass']: - import dataclasses + if self.__class__.__pydantic_run_validation__: + self.__pydantic_validate_values__() + if hasattr(self, '__post_init_post_parse__'): + self.__post_init_post_parse__(*args, **kwargs) + + if config.post_init_call == 'after_validation': + post_init(self, *args, **kwargs) + + setattr(dc_cls, '__init__', handle_extra_init) + setattr(dc_cls, '__post_init__', new_post_init) - post_init_original = getattr(_cls, '__post_init__', None) - if post_init_original and post_init_original.__name__ == '_pydantic_post_init': - post_init_original = None - if not post_init_original: - post_init_original = getattr(_cls, '__post_init_original__', None) - - post_init_post_parse = getattr(_cls, '__post_init_post_parse__', None) - - _pydantic_post_init = _generate_pydantic_post_init(post_init_original, post_init_post_parse) - - # If the class is already a dataclass, __post_init__ will not be called automatically - # so no validation will be added. - # We hence create dynamically a new dataclass: - # ``` - # @dataclasses.dataclass - # class NewClass(_cls): - # __post_init__ = _pydantic_post_init - # ``` - # with the exact same fields as the base dataclass - # and register it on module level to address pickle problem: - # https://github.com/samuelcolvin/pydantic/issues/2111 - if is_builtin_dataclass(_cls): - uniq_class_name = f'_Pydantic_{_cls.__name__}_{id(_cls)}' - _cls = type( - # for pretty output new class will have the name as original - _cls.__name__, - (_cls,), - { - '__annotations__': resolve_annotations(_cls.__annotations__, _cls.__module__), - '__post_init__': _pydantic_post_init, - # attrs for pickle to find this class - '__module__': __name__, - '__qualname__': uniq_class_name, - }, - ) - globals()[uniq_class_name] = _cls else: - _cls.__post_init__ = _pydantic_post_init - cls: Type['Dataclass'] = dataclasses.dataclass( # type: ignore - _cls, init=init, repr=repr, eq=eq, order=order, unsafe_hash=unsafe_hash, frozen=frozen - ) - cls.__processed__ = ClassAttribute('__processed__', True) + + @wraps(init) + def new_init(self: 'Dataclass', *args: Any, **kwargs: Any) -> None: + handle_extra_init(self, *args, **kwargs) + + if self.__class__.__pydantic_run_validation__: + self.__pydantic_validate_values__() + + if hasattr(self, '__post_init_post_parse__'): + # We need to find again the initvars. To do that we use `__dataclass_fields__` instead of + # public method `dataclasses.fields` + import dataclasses + + # get all initvars and their default values + initvars_and_values: Dict[str, Any] = {} + for i, f in enumerate(self.__class__.__dataclass_fields__.values()): + if f._field_type is dataclasses._FIELD_INITVAR: # type: ignore[attr-defined] + try: + # set arg value by default + initvars_and_values[f.name] = args[i] + except IndexError: + initvars_and_values[f.name] = f.default + initvars_and_values.update(kwargs) + + self.__post_init_post_parse__(**initvars_and_values) + + setattr(dc_cls, '__init__', new_init) + + setattr(dc_cls, '__pydantic_run_validation__', ClassAttribute('__pydantic_run_validation__', validate_on_init)) + setattr(dc_cls, '__pydantic_initialised__', False) + setattr(dc_cls, '__pydantic_model__', create_pydantic_model_from_dataclass(dc_cls, config, dc_cls_doc)) + setattr(dc_cls, '__pydantic_validate_values__', _dataclass_validate_values) + setattr(dc_cls, '__validate__', classmethod(_validate_dataclass)) + setattr(dc_cls, '__get_validators__', classmethod(_get_validators)) + + if dc_cls.__pydantic_model__.__config__.validate_assignment and not dc_cls.__dataclass_params__.frozen: + setattr(dc_cls, '__setattr__', _dataclass_validate_assignment_setattr) + + +def _get_validators(cls: 'DataclassClassOrWrapper') -> 'CallableGenerator': + yield cls.__validate__ + + +def _validate_dataclass(cls: Type['DataclassT'], v: Any) -> 'DataclassT': + with set_validation(cls, True): + if isinstance(v, cls): + v.__pydantic_validate_values__() + return v + elif isinstance(v, (list, tuple)): + return cls(*v) + elif isinstance(v, dict): + return cls(**v) + else: + raise DataclassTypeError(class_name=cls.__name__) + + +def create_pydantic_model_from_dataclass( + dc_cls: Type['Dataclass'], + config: Type[Any] = BaseConfig, + dc_cls_doc: Optional[str] = None, +) -> Type['BaseModel']: + import dataclasses field_definitions: Dict[str, Any] = {} - for field in dataclasses.fields(cls): + for field in dataclasses.fields(dc_cls): default: Any = Undefined default_factory: Optional['NoArgAnyCallable'] = None field_info: FieldInfo @@ -176,102 +316,87 @@ def _process_class( if isinstance(default, FieldInfo): field_info = default - cls.__has_field_info_default__ = True + dc_cls.__pydantic_has_field_info_default__ = True else: field_info = Field(default=default, default_factory=default_factory, **field.metadata) field_definitions[field.name] = (field.type, field_info) - validators = gather_all_validators(cls) - cls.__pydantic_model__ = create_model( - cls.__name__, + validators = gather_all_validators(dc_cls) + model: Type['BaseModel'] = create_model( + dc_cls.__name__, __config__=config, - __module__=_cls.__module__, + __module__=dc_cls.__module__, __validators__=validators, __cls_kwargs__={'__resolve_forward_refs__': False}, **field_definitions, ) + model.__doc__ = dc_cls_doc if dc_cls_doc is not None else dc_cls.__doc__ or '' + return model - cls.__initialised__ = False - cls.__validate__ = classmethod(_validate_dataclass) # type: ignore[assignment] - cls.__get_validators__ = classmethod(_get_validators) # type: ignore[assignment] - if post_init_original: - cls.__post_init_original__ = post_init_original - - if cls.__pydantic_model__.__config__.validate_assignment and not frozen: - cls.__setattr__ = setattr_validate_assignment # type: ignore[assignment] - cls.__pydantic_model__.__try_update_forward_refs__(**{cls.__name__: cls}) - - return cls +def _dataclass_validate_values(self: 'Dataclass') -> None: + if getattr(self, '__pydantic_has_field_info_default__', False): + # We need to remove `FieldInfo` values since they are not valid as input + # It's ok to do that because they are obviously the default values! + input_data = {k: v for k, v in self.__dict__.items() if not isinstance(v, FieldInfo)} + else: + input_data = self.__dict__ + d, _, validation_error = validate_model(self.__pydantic_model__, input_data, cls=self.__class__) + if validation_error: + raise validation_error + self.__dict__.update(d) + object.__setattr__(self, '__pydantic_initialised__', True) -@__dataclass_transform__(kw_only_default=True, field_descriptors=(Field, FieldInfo)) -@overload -def dataclass( - *, - init: bool = True, - repr: bool = True, - eq: bool = True, - order: bool = False, - unsafe_hash: bool = False, - frozen: bool = False, - config: Type[Any] = None, -) -> Callable[[Type[Any]], Type['Dataclass']]: - ... - +def _dataclass_validate_assignment_setattr(self: 'Dataclass', name: str, value: Any) -> None: + if self.__pydantic_initialised__: + d = dict(self.__dict__) + d.pop(name, None) + known_field = self.__pydantic_model__.__fields__.get(name, None) + if known_field: + value, error_ = known_field.validate(value, d, loc=name, cls=self.__class__) + if error_: + raise ValidationError([error_], self.__class__) -@__dataclass_transform__(kw_only_default=True, field_descriptors=(Field, FieldInfo)) -@overload -def dataclass( - _cls: Type[Any], - *, - init: bool = True, - repr: bool = True, - eq: bool = True, - order: bool = False, - unsafe_hash: bool = False, - frozen: bool = False, - config: Type[Any] = None, -) -> Type['Dataclass']: - ... + object.__setattr__(self, name, value) -@__dataclass_transform__(kw_only_default=True, field_descriptors=(Field, FieldInfo)) -def dataclass( - _cls: Optional[Type[Any]] = None, - *, - init: bool = True, - repr: bool = True, - eq: bool = True, - order: bool = False, - unsafe_hash: bool = False, - frozen: bool = False, - config: Type[Any] = None, -) -> Union[Callable[[Type[Any]], Type['Dataclass']], Type['Dataclass']]: +def is_builtin_dataclass(_cls: Type[Any]) -> bool: """ - Like the python standard lib dataclasses but with type validation. - - Arguments are the same as for standard dataclasses, except for validate_assignment which has the same meaning - as Config.validate_assignment. + Whether a class is a stdlib dataclass + (useful to discriminated a pydantic dataclass that is actually a wrapper around a stdlib dataclass) + + we check that + - `_cls` is a dataclass + - `_cls` is not a processed pydantic dataclass (with a basemodel attached) + - `_cls` is not a pydantic dataclass inheriting directly from a stdlib dataclass + e.g. + ``` + @dataclasses.dataclass + class A: + x: int + + @pydantic.dataclasses.dataclass + class B(A): + y: int + ``` + In this case, when we first check `B`, we make an extra check and look at the annotations ('y'), + which won't be a superset of all the dataclass fields (only the stdlib fields i.e. 'x') """ + import dataclasses - def wrap(cls: Type[Any]) -> Type['Dataclass']: - return _process_class(cls, init, repr, eq, order, unsafe_hash, frozen, config) - - if _cls is None: - return wrap - - return wrap(_cls) + return ( + dataclasses.is_dataclass(_cls) + and not hasattr(_cls, '__pydantic_model__') + and set(_cls.__dataclass_fields__).issuperset(set(getattr(_cls, '__annotations__', {}))) + ) -def make_dataclass_validator(_cls: Type[Any], config: Type['BaseConfig']) -> 'CallableGenerator': +def make_dataclass_validator(dc_cls: Type['Dataclass'], config: Type[BaseConfig]) -> 'CallableGenerator': """ Create a pydantic.dataclass from a builtin dataclass to add type validation and yield the validators It retrieves the parameters of the dataclass and forwards them to the newly created dataclass """ - dataclass_params = _cls.__dataclass_params__ - stdlib_dataclass_parameters = {param: getattr(dataclass_params, param) for param in dataclass_params.__slots__} - cls = dataclass(_cls, config=config, **stdlib_dataclass_parameters) - yield from _get_validators(cls) + yield from _get_validators(dataclass(dc_cls, config=config, validate_on_init=False)) diff --git a/pydantic/fields.py b/pydantic/fields.py index 7b55bcf637..09c70b449f 100644 --- a/pydantic/fields.py +++ b/pydantic/fields.py @@ -1131,8 +1131,8 @@ def is_complex(self) -> bool: return ( self.shape != SHAPE_SINGLETON + or hasattr(self.type_, '__pydantic_model__') or lenient_issubclass(self.type_, (BaseModel, list, set, frozenset, dict)) - or hasattr(self.type_, '__pydantic_model__') # pydantic dataclass ) def _type_display(self) -> PyObjectStr: diff --git a/pydantic/main.py b/pydantic/main.py index 0c20d9e69d..c8fac4bf94 100644 --- a/pydantic/main.py +++ b/pydantic/main.py @@ -78,18 +78,7 @@ Model = TypeVar('Model', bound='BaseModel') - -try: - import cython # type: ignore -except ImportError: - compiled: bool = False -else: # pragma: no cover - try: - compiled = cython.compiled - except AttributeError: - compiled = False - -__all__ = 'BaseModel', 'compiled', 'create_model', 'validate_model' +__all__ = 'BaseModel', 'create_model', 'validate_model' _T = TypeVar('_T') diff --git a/pydantic/schema.py b/pydantic/schema.py index 4f29bd5c74..057fc70abb 100644 --- a/pydantic/schema.py +++ b/pydantic/schema.py @@ -381,21 +381,14 @@ def get_flat_models_from_field(field: ModelField, known_models: TypeModelSet) -> :param known_models: used to solve circular references :return: a set with the model used in the declaration for this field, if any, and all its sub-models """ - from .dataclasses import dataclass, is_builtin_dataclass from .main import BaseModel flat_models: TypeModelSet = set() - # Handle dataclass-based models - if is_builtin_dataclass(field.type_): - field.type_ = dataclass(field.type_) - was_dataclass = True - else: - was_dataclass = False field_type = field.type_ if lenient_issubclass(getattr(field_type, '__pydantic_model__', None), BaseModel): field_type = field_type.__pydantic_model__ - if field.sub_fields and (not lenient_issubclass(field_type, BaseModel) or was_dataclass): + if field.sub_fields and not lenient_issubclass(field_type, BaseModel): flat_models |= get_flat_models_from_fields(field.sub_fields, known_models=known_models) elif lenient_issubclass(field_type, BaseModel) and field_type not in known_models: flat_models |= get_flat_models_from_model(field_type, known_models=known_models) diff --git a/pydantic/types.py b/pydantic/types.py index 2d0cc18f8d..d93aad1648 100644 --- a/pydantic/types.py +++ b/pydantic/types.py @@ -117,7 +117,7 @@ from .main import BaseModel from .typing import CallableGenerator - ModelOrDc = Type[Union['BaseModel', 'Dataclass']] + ModelOrDc = Type[Union[BaseModel, Dataclass]] T = TypeVar('T') _DEFINED_TYPES: 'WeakSet[type]' = WeakSet() diff --git a/pydantic/version.py b/pydantic/version.py index 5b1ebc33e9..5a88a3d6e8 100644 --- a/pydantic/version.py +++ b/pydantic/version.py @@ -1,7 +1,17 @@ -__all__ = 'VERSION', 'version_info' +__all__ = 'compiled', 'VERSION', 'version_info' VERSION = '1.9.0' +try: + import cython # type: ignore +except ImportError: + compiled: bool = False +else: # pragma: no cover + try: + compiled = cython.compiled + except AttributeError: + compiled = False + def version_info() -> str: import platform @@ -9,8 +19,6 @@ def version_info() -> str: from importlib import import_module from pathlib import Path - from .main import compiled - optional_deps = [] for p in ('devtools', 'dotenv', 'email-validator', 'typing-extensions'): try: diff --git a/tests/mypy/outputs/plugin-fail-strict.txt b/tests/mypy/outputs/plugin-fail-strict.txt index 6dbaff7cda..75a384974c 100644 --- a/tests/mypy/outputs/plugin-fail-strict.txt +++ b/tests/mypy/outputs/plugin-fail-strict.txt @@ -32,8 +32,4 @@ 185: error: Unexpected keyword argument "x" for "AliasGeneratorModel2" [call-arg] 186: error: Unexpected keyword argument "z" for "AliasGeneratorModel2" [call-arg] 189: error: Name "Missing" is not defined [name-defined] -197: error: No overload variant of "dataclass" matches argument type "Dict[, ]" [call-overload] -197: note: Possible overload variants: -197: note: def dataclass(*, init: bool = ..., repr: bool = ..., eq: bool = ..., order: bool = ..., unsafe_hash: bool = ..., frozen: bool = ..., config: Optional[Type[Any]] = ...) -> Callable[[Type[Any]], Type[Dataclass]] -197: note: def dataclass(_cls: Type[Any], *, init: bool = ..., repr: bool = ..., eq: bool = ..., order: bool = ..., unsafe_hash: bool = ..., frozen: bool = ..., config: Optional[Type[Any]] = ...) -> Type[Dataclass] -219: error: Property "y" defined in "FrozenModel" is read-only [misc] \ No newline at end of file +219: error: Property "y" defined in "FrozenModel" is read-only [misc] diff --git a/tests/mypy/outputs/plugin-fail.txt b/tests/mypy/outputs/plugin-fail.txt index 3176bc2dce..545575a2c4 100644 --- a/tests/mypy/outputs/plugin-fail.txt +++ b/tests/mypy/outputs/plugin-fail.txt @@ -21,8 +21,4 @@ 175: error: Unused "type: ignore" comment 182: error: Unused "type: ignore" comment 189: error: Name "Missing" is not defined [name-defined] -197: error: No overload variant of "dataclass" matches argument type "Dict[, ]" [call-overload] -197: note: Possible overload variants: -197: note: def dataclass(*, init: bool = ..., repr: bool = ..., eq: bool = ..., order: bool = ..., unsafe_hash: bool = ..., frozen: bool = ..., config: Optional[Type[Any]] = ...) -> Callable[[Type[Any]], Type[Dataclass]] -197: note: def dataclass(_cls: Type[Any], *, init: bool = ..., repr: bool = ..., eq: bool = ..., order: bool = ..., unsafe_hash: bool = ..., frozen: bool = ..., config: Optional[Type[Any]] = ...) -> Type[Dataclass] -219: error: Property "y" defined in "FrozenModel" is read-only [misc] \ No newline at end of file +219: error: Property "y" defined in "FrozenModel" is read-only [misc] diff --git a/tests/test_dataclasses.py b/tests/test_dataclasses.py index 434742a001..ffcca68c6e 100644 --- a/tests/test_dataclasses.py +++ b/tests/test_dataclasses.py @@ -1,15 +1,16 @@ import dataclasses import pickle +import re from collections.abc import Hashable from datetime import datetime from pathlib import Path -from typing import Callable, ClassVar, Dict, FrozenSet, List, Optional, Union +from typing import Callable, ClassVar, Dict, FrozenSet, List, Optional, Set, Union import pytest from typing_extensions import Literal import pydantic -from pydantic import BaseModel, ValidationError, validator +from pydantic import BaseModel, Extra, ValidationError, validator def test_simple(): @@ -79,10 +80,7 @@ class MyDataclass: def test_validate_assignment_error(): - class Config: - validate_assignment = True - - @pydantic.dataclasses.dataclass(config=Config) + @pydantic.dataclasses.dataclass(config=dict(validate_assignment=True)) class MyDataclass: a: int @@ -159,6 +157,22 @@ def __post_init__(self): assert post_init_called +def test_post_init_validation(): + @dataclasses.dataclass + class DC: + a: int + + def __post_init__(self): + self.a *= 2 + + def __post_init_post_parse__(self): + self.a += 1 + + PydanticDC = pydantic.dataclasses.dataclass(DC) + assert DC(a='2').a == '22' + assert PydanticDC(a='2').a == 23 + + def test_post_init_inheritance_chain(): parent_post_init_called = False post_init_called = False @@ -659,13 +673,22 @@ class File: size: int content: Optional[bytes] = None - FileChecked = pydantic.dataclasses.dataclass(File) - f = FileChecked(hash='xxx', name=b'whatever.txt', size='456') - assert f.name == 'whatever.txt' - assert f.size == 456 + ValidFile = pydantic.dataclasses.dataclass(File) + + file = File(hash='xxx', name=b'whatever.txt', size='456') + valid_file = ValidFile(hash='xxx', name=b'whatever.txt', size='456') + + assert file.name == b'whatever.txt' + assert file.size == '456' + + assert valid_file.name == 'whatever.txt' + assert valid_file.size == 456 + + assert isinstance(valid_file, File) + assert isinstance(valid_file, ValidFile) with pytest.raises(ValidationError) as e: - FileChecked(hash=[1], name='name', size=3) + ValidFile(hash=[1], name='name', size=3) assert e.value.errors() == [{'loc': ('hash',), 'msg': 'str type expected', 'type': 'type_error.str'}] @@ -675,11 +698,15 @@ class Meta: modified_date: Optional[datetime] seen_count: int + Meta(modified_date='not-validated', seen_count=0) + @pydantic.dataclasses.dataclass @dataclasses.dataclass class File(Meta): filename: str + Meta(modified_date='still-not-validated', seen_count=0) + f = File(filename=b'thefilename', modified_date='2020-01-01T00:00', seen_count='7') assert f.filename == 'thefilename' assert f.modified_date == datetime(2020, 1, 1, 0, 0) @@ -923,6 +950,245 @@ class A2: } +def gen_2162_dataclasses(): + @dataclasses.dataclass(frozen=True) + class StdLibFoo: + a: str + b: int + + @pydantic.dataclasses.dataclass(frozen=True) + class PydanticFoo: + a: str + b: int + + @dataclasses.dataclass(frozen=True) + class StdLibBar: + c: StdLibFoo + + @pydantic.dataclasses.dataclass(frozen=True) + class PydanticBar: + c: PydanticFoo + + @dataclasses.dataclass(frozen=True) + class StdLibBaz: + c: PydanticFoo + + @pydantic.dataclasses.dataclass(frozen=True) + class PydanticBaz: + c: StdLibFoo + + foo = StdLibFoo(a='Foo', b=1) + yield foo, StdLibBar(c=foo) + + foo = PydanticFoo(a='Foo', b=1) + yield foo, PydanticBar(c=foo) + + foo = PydanticFoo(a='Foo', b=1) + yield foo, StdLibBaz(c=foo) + + foo = StdLibFoo(a='Foo', b=1) + yield foo, PydanticBaz(c=foo) + + +@pytest.mark.parametrize('foo,bar', gen_2162_dataclasses()) +def test_issue_2162(foo, bar): + assert dataclasses.asdict(foo) == dataclasses.asdict(bar.c) + assert dataclasses.astuple(foo) == dataclasses.astuple(bar.c) + assert foo == bar.c + + +def test_issue_2383(): + @dataclasses.dataclass + class A: + s: str + + def __hash__(self): + return 123 + + class B(pydantic.BaseModel): + a: A + + a = A('') + b = B(a=a) + + assert hash(a) == 123 + assert hash(b.a) == 123 + + +def test_issue_2398(): + @dataclasses.dataclass(order=True) + class DC: + num: int = 42 + + class Model(pydantic.BaseModel): + dc: DC + + real_dc = DC() + model = Model(dc=real_dc) + + # This works as expected. + assert real_dc <= real_dc + assert model.dc <= model.dc + assert real_dc <= model.dc + + +def test_issue_2424(): + @dataclasses.dataclass + class Base: + x: str + + @dataclasses.dataclass + class Thing(Base): + y: str = dataclasses.field(default_factory=str) + + assert Thing(x='hi').y == '' + + @pydantic.dataclasses.dataclass + class ValidatedThing(Base): + y: str = dataclasses.field(default_factory=str) + + assert Thing(x='hi').y == '' + assert ValidatedThing(x='hi').y == '' + + +def test_issue_2541(): + @dataclasses.dataclass(frozen=True) + class Infos: + id: int + + @dataclasses.dataclass(frozen=True) + class Item: + name: str + infos: Infos + + class Example(BaseModel): + item: Item + + e = Example.parse_obj({'item': {'name': 123, 'infos': {'id': '1'}}}) + assert e.item.name == '123' + assert e.item.infos.id == 1 + with pytest.raises(dataclasses.FrozenInstanceError): + e.item.infos.id = 2 + + +def test_issue_2555(): + @dataclasses.dataclass + class Span: + first: int + last: int + + @dataclasses.dataclass + class LabeledSpan(Span): + label: str + + @dataclasses.dataclass + class BinaryRelation: + subject: LabeledSpan + object: LabeledSpan + label: str + + @dataclasses.dataclass + class Sentence: + relations: BinaryRelation + + class M(pydantic.BaseModel): + s: Sentence + + assert M.schema() + + +def test_issue_2594(): + @dataclasses.dataclass + class Empty: + pass + + @pydantic.dataclasses.dataclass + class M: + e: Empty + + assert isinstance(M(e={}).e, Empty) + + +def test_schema_description_unset(): + @pydantic.dataclasses.dataclass + class A: + x: int + + assert 'description' not in A.__pydantic_model__.schema() + + @pydantic.dataclasses.dataclass + @dataclasses.dataclass + class B: + x: int + + assert 'description' not in B.__pydantic_model__.schema() + + +def test_schema_description_set(): + @pydantic.dataclasses.dataclass + class A: + """my description""" + + x: int + + assert A.__pydantic_model__.schema()['description'] == 'my description' + + @pydantic.dataclasses.dataclass + @dataclasses.dataclass + class B: + """my description""" + + x: int + + assert A.__pydantic_model__.schema()['description'] == 'my description' + + +def test_issue_3011(): + @dataclasses.dataclass + class A: + thing_a: str + + class B(A): + thing_b: str + + class Config: + arbitrary_types_allowed = True + + @pydantic.dataclasses.dataclass(config=Config) + class C: + thing: A + + b = B('Thing A') + c = C(thing=b) + assert c.thing.thing_a == 'Thing A' + + +def test_issue_3162(): + @dataclasses.dataclass + class User: + id: int + name: str + + class Users(BaseModel): + user: User + other_user: User + + assert Users.schema() == { + 'title': 'Users', + 'type': 'object', + 'properties': {'user': {'$ref': '#/definitions/User'}, 'other_user': {'$ref': '#/definitions/User'}}, + 'required': ['user', 'other_user'], + 'definitions': { + 'User': { + 'title': 'User', + 'type': 'object', + 'properties': {'id': {'title': 'Id', 'type': 'integer'}, 'name': {'title': 'Name', 'type': 'string'}}, + 'required': ['id', 'name'], + } + }, + } + + def test_discrimated_union_basemodel_instance_value(): @pydantic.dataclasses.dataclass class A: @@ -966,6 +1232,24 @@ class Top: } +def test_post_init_after_validation(): + @dataclasses.dataclass + class SetWrapper: + set: Set[int] + + def __post_init__(self): + assert isinstance( + self.set, set + ), f"self.set should be a set but it's {self.set!r} of type {type(self.set).__name__}" + + class Model(pydantic.BaseModel, post_init_call='after_validation'): + set_wrapper: SetWrapper + + model = Model(set_wrapper=SetWrapper({1, 2, 3})) + json_text = model.json() + assert Model.parse_raw(json_text) == model + + def test_keeps_custom_properties(): class StandardClass: """Class which modifies instance creation.""" @@ -991,6 +1275,71 @@ def __new__(cls, *args, **kwargs): assert instance.a == test_string +def test_ignore_extra(): + @pydantic.dataclasses.dataclass(config=dict(extra=Extra.ignore)) + class Foo: + x: int + + foo = Foo(**{'x': '1', 'y': '2'}) + assert foo.__dict__ == {'x': 1, '__pydantic_initialised__': True} + + +def test_ignore_extra_subclass(): + @pydantic.dataclasses.dataclass(config=dict(extra=Extra.ignore)) + class Foo: + x: int + + @pydantic.dataclasses.dataclass(config=dict(extra=Extra.ignore)) + class Bar(Foo): + y: int + + bar = Bar(**{'x': '1', 'y': '2', 'z': '3'}) + assert bar.__dict__ == {'x': 1, 'y': 2, '__pydantic_initialised__': True} + + +def test_allow_extra(): + @pydantic.dataclasses.dataclass(config=dict(extra=Extra.allow)) + class Foo: + x: int + + foo = Foo(**{'x': '1', 'y': '2'}) + assert foo.__dict__ == {'x': 1, 'y': '2', '__pydantic_initialised__': True} + + +def test_allow_extra_subclass(): + @pydantic.dataclasses.dataclass(config=dict(extra=Extra.allow)) + class Foo: + x: int + + @pydantic.dataclasses.dataclass(config=dict(extra=Extra.allow)) + class Bar(Foo): + y: int + + bar = Bar(**{'x': '1', 'y': '2', 'z': '3'}) + assert bar.__dict__ == {'x': 1, 'y': 2, 'z': '3', '__pydantic_initialised__': True} + + +def test_forbid_extra(): + @pydantic.dataclasses.dataclass(config=dict(extra=Extra.forbid)) + class Foo: + x: int + + with pytest.raises(TypeError, match=re.escape("__init__() got an unexpected keyword argument 'y'")): + Foo(**{'x': '1', 'y': '2'}) + + +def test_post_init_allow_extra(): + @pydantic.dataclasses.dataclass(config=dict(extra=Extra.allow)) + class Foobar: + a: int + b: str + + def __post_init__(self): + self.a *= 2 + + assert Foobar(a=1, b='a', c=4).__dict__ == {'a': 2, 'b': 'a', 'c': 4, '__pydantic_initialised__': True} + + def test_self_reference_dataclass(): @pydantic.dataclasses.dataclass class MyDataclass: