From b100e90119f3a648cb0f5fb08fa9291bd4ccb3c9 Mon Sep 17 00:00:00 2001 From: Eric Jolibois Date: Mon, 5 Sep 2022 15:09:31 +0200 Subject: [PATCH] fix: always call `__post_init__` when set in dataclasses --- changes/4477-PrettyWood.md | 1 + pydantic/dataclasses.py | 4 +- tests/test_dataclasses.py | 77 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 81 insertions(+), 1 deletion(-) create mode 100644 changes/4477-PrettyWood.md diff --git a/changes/4477-PrettyWood.md b/changes/4477-PrettyWood.md new file mode 100644 index 00000000000..b74b32380a2 --- /dev/null +++ b/changes/4477-PrettyWood.md @@ -0,0 +1 @@ +fix: dataclass wrapper was not always called diff --git a/pydantic/dataclasses.py b/pydantic/dataclasses.py index 856c96d18a0..99567182072 100644 --- a/pydantic/dataclasses.py +++ b/pydantic/dataclasses.py @@ -184,7 +184,9 @@ def dataclass( def wrap(cls: Type[Any]) -> 'DataclassClassOrWrapper': import dataclasses - if is_builtin_dataclass(cls): + if is_builtin_dataclass(cls) and not ( + is_builtin_dataclass(cls.__bases__[0]) and not hasattr(cls.__bases__[0], '__post_init__') + ): dc_cls_doc = '' dc_cls = DataclassProxy(cls) default_validate_on_init = False diff --git a/tests/test_dataclasses.py b/tests/test_dataclasses.py index f29bdfc3c80..e2027476bc0 100644 --- a/tests/test_dataclasses.py +++ b/tests/test_dataclasses.py @@ -1395,3 +1395,80 @@ class Bar: @pydantic.dataclasses.dataclass class Foo: a: List[Bar(a=1)] + + +def test_parent_post_init(): + @dataclasses.dataclass + class A: + a: float = 1 + + def __post_init__(self): + self.a *= 2 + + @pydantic.dataclasses.dataclass + class B(A): + @validator('a') + def validate_a(cls, value): + value += 3 + return value + + assert B().a == 5 # 1 * 2 + 3 + + +def test_subclass_post_init_post_parse(): + @dataclasses.dataclass + class A: + a: float = 1 + + @pydantic.dataclasses.dataclass + class B(A): + def __post_init_post_parse__(self): + self.a *= 2 + + @validator('a') + def validate_a(cls, value): + value += 3 + return value + + assert B().a == 8 # (1 + 3) * 2 + + +def test_subclass_post_init(): + @dataclasses.dataclass + class A: + a: int = 1 + + @pydantic.dataclasses.dataclass + class B(A): + def __post_init__(self): + self.a *= 2 + + @validator('a') + def validate_a(cls, value): + value += 3 + return value + + assert B().a == 5 # 1 * 2 + 3 + + +def test_subclass_post_init_inheritance(): + @dataclasses.dataclass + class A: + a: int = 1 + + @pydantic.dataclasses.dataclass + class B(A): + def __post_init__(self): + self.a *= 2 + + @validator('a') + def validate_a(cls, value): + value += 3 + return value + + @pydantic.dataclasses.dataclass + class C(B): + def __post_init__(self): + self.a *= 3 + + assert C().a == 6 # 1 * 3 + 3