diff --git a/changes/4477-PrettyWood.md b/changes/4477-PrettyWood.md new file mode 100644 index 00000000000..b74b32380a2 --- /dev/null +++ b/changes/4477-PrettyWood.md @@ -0,0 +1 @@ +fix: dataclass wrapper was not always called diff --git a/pydantic/dataclasses.py b/pydantic/dataclasses.py index 856c96d18a0..688a2556f86 100644 --- a/pydantic/dataclasses.py +++ b/pydantic/dataclasses.py @@ -34,7 +34,20 @@ class M: import sys from contextlib import contextmanager from functools import wraps -from typing import TYPE_CHECKING, Any, Callable, ClassVar, Dict, Generator, Optional, Type, TypeVar, Union, overload +from typing import ( + TYPE_CHECKING, + Any, + Callable, + ClassVar, + Dict, + Generator, + Optional, + Set, + Type, + TypeVar, + Union, + overload, +) from typing_extensions import dataclass_transform @@ -184,7 +197,9 @@ def dataclass( def wrap(cls: Type[Any]) -> 'DataclassClassOrWrapper': import dataclasses - if is_builtin_dataclass(cls): + has_extra_dataclass_fields = _extra_dc_args(_cls) == _extra_dc_args(_cls.__bases__[0]) # type: ignore + + if is_builtin_dataclass(cls) and not has_extra_dataclass_fields: dc_cls_doc = '' dc_cls = DataclassProxy(cls) default_validate_on_init = False @@ -418,6 +433,14 @@ def _dataclass_validate_assignment_setattr(self: 'Dataclass', name: str, value: object.__setattr__(self, name, value) +def _extra_dc_args(cls: Type[Any]) -> Set[str]: + return { + x + for x in dir(cls) + if x not in getattr(cls, '__dataclass_fields__', {}) and not (x.startswith('__') and x.endswith('__')) + } + + def is_builtin_dataclass(_cls: Type[Any]) -> bool: """ Whether a class is a stdlib dataclass diff --git a/tests/test_dataclasses.py b/tests/test_dataclasses.py index f29bdfc3c80..e2027476bc0 100644 --- a/tests/test_dataclasses.py +++ b/tests/test_dataclasses.py @@ -1395,3 +1395,80 @@ class Bar: @pydantic.dataclasses.dataclass class Foo: a: List[Bar(a=1)] + + +def test_parent_post_init(): + @dataclasses.dataclass + class A: + a: float = 1 + + def __post_init__(self): + self.a *= 2 + + @pydantic.dataclasses.dataclass + class B(A): + @validator('a') + def validate_a(cls, value): + value += 3 + return value + + assert B().a == 5 # 1 * 2 + 3 + + +def test_subclass_post_init_post_parse(): + @dataclasses.dataclass + class A: + a: float = 1 + + @pydantic.dataclasses.dataclass + class B(A): + def __post_init_post_parse__(self): + self.a *= 2 + + @validator('a') + def validate_a(cls, value): + value += 3 + return value + + assert B().a == 8 # (1 + 3) * 2 + + +def test_subclass_post_init(): + @dataclasses.dataclass + class A: + a: int = 1 + + @pydantic.dataclasses.dataclass + class B(A): + def __post_init__(self): + self.a *= 2 + + @validator('a') + def validate_a(cls, value): + value += 3 + return value + + assert B().a == 5 # 1 * 2 + 3 + + +def test_subclass_post_init_inheritance(): + @dataclasses.dataclass + class A: + a: int = 1 + + @pydantic.dataclasses.dataclass + class B(A): + def __post_init__(self): + self.a *= 2 + + @validator('a') + def validate_a(cls, value): + value += 3 + return value + + @pydantic.dataclasses.dataclass + class C(B): + def __post_init__(self): + self.a *= 3 + + assert C().a == 6 # 1 * 3 + 3