From bf4b5ceb49043c7165499a51b41e340928ddfd8e Mon Sep 17 00:00:00 2001 From: Eric Jolibois Date: Thu, 29 Dec 2022 12:15:12 +0100 Subject: [PATCH] fix: use dataclasses proxy for frozen or empty dataclasses (#4878) * add tests * fix: dataclasses * chore: add change file * refactor: remove useless kwarg * test: add new test * keep old kwarg to avoid breaking change --- changes/4878-PrettyWood.md | 1 + pydantic/dataclasses.py | 40 +++++++--------- tests/test_dataclasses.py | 95 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 112 insertions(+), 24 deletions(-) create mode 100644 changes/4878-PrettyWood.md diff --git a/changes/4878-PrettyWood.md b/changes/4878-PrettyWood.md new file mode 100644 index 0000000000..bd05c96d15 --- /dev/null +++ b/changes/4878-PrettyWood.md @@ -0,0 +1 @@ +fix: use dataclass proxy for frozen or empty dataclasses diff --git a/pydantic/dataclasses.py b/pydantic/dataclasses.py index 1856a1203c..913d8cc693 100644 --- a/pydantic/dataclasses.py +++ b/pydantic/dataclasses.py @@ -34,20 +34,7 @@ class M: import sys from contextlib import contextmanager from functools import wraps -from typing import ( - TYPE_CHECKING, - Any, - Callable, - ClassVar, - Dict, - Generator, - Optional, - Set, - Type, - TypeVar, - Union, - overload, -) +from typing import TYPE_CHECKING, Any, Callable, ClassVar, Dict, Generator, Optional, Type, TypeVar, Union, overload from typing_extensions import dataclass_transform @@ -117,6 +104,7 @@ def dataclass( frozen: bool = False, config: Union[ConfigDict, Type[object], None] = None, validate_on_init: Optional[bool] = None, + use_proxy: Optional[bool] = None, kw_only: bool = ..., ) -> Callable[[Type[_T]], 'DataclassClassOrWrapper']: ... @@ -134,6 +122,7 @@ def dataclass( frozen: bool = False, config: Union[ConfigDict, Type[object], None] = None, validate_on_init: Optional[bool] = None, + use_proxy: Optional[bool] = None, kw_only: bool = ..., ) -> 'DataclassClassOrWrapper': ... @@ -152,6 +141,7 @@ def dataclass( frozen: bool = False, config: Union[ConfigDict, Type[object], None] = None, validate_on_init: Optional[bool] = None, + use_proxy: Optional[bool] = None, ) -> Callable[[Type[_T]], 'DataclassClassOrWrapper']: ... @@ -168,6 +158,7 @@ def dataclass( frozen: bool = False, config: Union[ConfigDict, Type[object], None] = None, validate_on_init: Optional[bool] = None, + use_proxy: Optional[bool] = None, ) -> 'DataclassClassOrWrapper': ... @@ -184,6 +175,7 @@ def dataclass( frozen: bool = False, config: Union[ConfigDict, Type[object], None] = None, validate_on_init: Optional[bool] = None, + use_proxy: Optional[bool] = None, kw_only: bool = False, ) -> Union[Callable[[Type[_T]], 'DataclassClassOrWrapper'], 'DataclassClassOrWrapper']: """ @@ -197,7 +189,15 @@ def dataclass( def wrap(cls: Type[Any]) -> 'DataclassClassOrWrapper': import dataclasses - if is_builtin_dataclass(cls) and _extra_dc_args(_cls) == _extra_dc_args(_cls.__bases__[0]): # type: ignore + should_use_proxy = ( + use_proxy + if use_proxy is not None + else ( + is_builtin_dataclass(cls) + and (cls.__bases__[0] is object or set(dir(cls)) == set(dir(cls.__bases__[0]))) + ) + ) + if should_use_proxy: dc_cls_doc = '' dc_cls = DataclassProxy(cls) default_validate_on_init = False @@ -437,14 +437,6 @@ def _dataclass_validate_assignment_setattr(self: 'Dataclass', name: str, value: object.__setattr__(self, name, value) -def _extra_dc_args(cls: Type[Any]) -> Set[str]: - return { - x - for x in dir(cls) - if x not in getattr(cls, '__dataclass_fields__', {}) and not (x.startswith('__') and x.endswith('__')) - } - - def is_builtin_dataclass(_cls: Type[Any]) -> bool: """ Whether a class is a stdlib dataclass @@ -482,4 +474,4 @@ def make_dataclass_validator(dc_cls: Type['Dataclass'], config: Type[BaseConfig] and yield the validators It retrieves the parameters of the dataclass and forwards them to the newly created dataclass """ - yield from _get_validators(dataclass(dc_cls, config=config, validate_on_init=False)) + yield from _get_validators(dataclass(dc_cls, config=config, use_proxy=True)) diff --git a/tests/test_dataclasses.py b/tests/test_dataclasses.py index 65151801cc..1686ff2672 100644 --- a/tests/test_dataclasses.py +++ b/tests/test_dataclasses.py @@ -1513,3 +1513,98 @@ class Foo: assert config.bar == 'cat' setattr(config, 'bar', 'dog') assert config.bar == 'dog' + + +def test_frozen_dataclasses(): + @dataclasses.dataclass(frozen=True) + class First: + a: int + + @dataclasses.dataclass(frozen=True) + class Second(First): + @property + def b(self): + return self.a + + class My(BaseModel): + my: Second + + assert My(my=Second(a='1')).my.b == 1 + + +def test_empty_dataclass(): + """should be able to inherit without adding a field""" + + @dataclasses.dataclass + class UnvalidatedDataclass: + a: int = 0 + + @pydantic.dataclasses.dataclass + class ValidatedDerivedA(UnvalidatedDataclass): + ... + + @pydantic.dataclasses.dataclass() + class ValidatedDerivedB(UnvalidatedDataclass): + b: int = 0 + + @pydantic.dataclasses.dataclass() + class ValidatedDerivedC(UnvalidatedDataclass): + ... + + +def test_proxy_dataclass(): + @dataclasses.dataclass + class Foo: + a: Optional[int] = dataclasses.field(default=42) + b: List = dataclasses.field(default_factory=list) + + @dataclasses.dataclass + class Bar: + pass + + @dataclasses.dataclass + class Model1: + foo: Foo + + class Model2(BaseModel): + foo: Foo + + m1 = Model1(foo=Foo()) + m2 = Model2(foo=Foo()) + + assert m1.foo.a == m2.foo.a == 42 + assert m1.foo.b == m2.foo.b == [] + assert m1.foo.Bar() is not None + assert m2.foo.Bar() is not None + + +def test_proxy_dataclass_2(): + @dataclasses.dataclass + class M1: + a: int + b: str = 'b' + c: float = dataclasses.field(init=False) + + def __post_init__(self): + self.c = float(self.a) + + @dataclasses.dataclass + class M2: + a: int + b: str = 'b' + c: float = dataclasses.field(init=False) + + def __post_init__(self): + self.c = float(self.a) + + @pydantic.validator('b') + def check_b(cls, v): + if not v: + raise ValueError('b should not be empty') + return v + + m1 = pydantic.parse_obj_as(M1, {'a': 3}) + m2 = pydantic.parse_obj_as(M2, {'a': 3}) + assert m1.a == m2.a == 3 + assert m1.b == m2.b == 'b' + assert m1.c == m2.c == 3.0