Skip to content

Commit

Permalink
fix: use dataclasses proxy for frozen or empty dataclasses (#4878)
Browse files Browse the repository at this point in the history
* add tests

* fix: dataclasses

* chore: add change file

* refactor: remove useless kwarg

* test: add new test

* keep old kwarg to avoid breaking change
  • Loading branch information
PrettyWood committed Dec 29, 2022
1 parent a220f87 commit bf4b5ce
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 24 deletions.
1 change: 1 addition & 0 deletions changes/4878-PrettyWood.md
@@ -0,0 +1 @@
fix: use dataclass proxy for frozen or empty dataclasses
40 changes: 16 additions & 24 deletions pydantic/dataclasses.py
Expand Up @@ -34,20 +34,7 @@ class M:
import sys
from contextlib import contextmanager
from functools import wraps
from typing import (
TYPE_CHECKING,
Any,
Callable,
ClassVar,
Dict,
Generator,
Optional,
Set,
Type,
TypeVar,
Union,
overload,
)
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Dict, Generator, Optional, Type, TypeVar, Union, overload

from typing_extensions import dataclass_transform

Expand Down Expand Up @@ -117,6 +104,7 @@ def dataclass(
frozen: bool = False,
config: Union[ConfigDict, Type[object], None] = None,
validate_on_init: Optional[bool] = None,
use_proxy: Optional[bool] = None,
kw_only: bool = ...,
) -> Callable[[Type[_T]], 'DataclassClassOrWrapper']:
...
Expand All @@ -134,6 +122,7 @@ def dataclass(
frozen: bool = False,
config: Union[ConfigDict, Type[object], None] = None,
validate_on_init: Optional[bool] = None,
use_proxy: Optional[bool] = None,
kw_only: bool = ...,
) -> 'DataclassClassOrWrapper':
...
Expand All @@ -152,6 +141,7 @@ def dataclass(
frozen: bool = False,
config: Union[ConfigDict, Type[object], None] = None,
validate_on_init: Optional[bool] = None,
use_proxy: Optional[bool] = None,
) -> Callable[[Type[_T]], 'DataclassClassOrWrapper']:
...

Expand All @@ -168,6 +158,7 @@ def dataclass(
frozen: bool = False,
config: Union[ConfigDict, Type[object], None] = None,
validate_on_init: Optional[bool] = None,
use_proxy: Optional[bool] = None,
) -> 'DataclassClassOrWrapper':
...

Expand All @@ -184,6 +175,7 @@ def dataclass(
frozen: bool = False,
config: Union[ConfigDict, Type[object], None] = None,
validate_on_init: Optional[bool] = None,
use_proxy: Optional[bool] = None,
kw_only: bool = False,
) -> Union[Callable[[Type[_T]], 'DataclassClassOrWrapper'], 'DataclassClassOrWrapper']:
"""
Expand All @@ -197,7 +189,15 @@ def dataclass(
def wrap(cls: Type[Any]) -> 'DataclassClassOrWrapper':
import dataclasses

if is_builtin_dataclass(cls) and _extra_dc_args(_cls) == _extra_dc_args(_cls.__bases__[0]): # type: ignore
should_use_proxy = (
use_proxy
if use_proxy is not None
else (
is_builtin_dataclass(cls)
and (cls.__bases__[0] is object or set(dir(cls)) == set(dir(cls.__bases__[0])))
)
)
if should_use_proxy:
dc_cls_doc = ''
dc_cls = DataclassProxy(cls)
default_validate_on_init = False
Expand Down Expand Up @@ -437,14 +437,6 @@ def _dataclass_validate_assignment_setattr(self: 'Dataclass', name: str, value:
object.__setattr__(self, name, value)


def _extra_dc_args(cls: Type[Any]) -> Set[str]:
return {
x
for x in dir(cls)
if x not in getattr(cls, '__dataclass_fields__', {}) and not (x.startswith('__') and x.endswith('__'))
}


def is_builtin_dataclass(_cls: Type[Any]) -> bool:
"""
Whether a class is a stdlib dataclass
Expand Down Expand Up @@ -482,4 +474,4 @@ def make_dataclass_validator(dc_cls: Type['Dataclass'], config: Type[BaseConfig]
and yield the validators
It retrieves the parameters of the dataclass and forwards them to the newly created dataclass
"""
yield from _get_validators(dataclass(dc_cls, config=config, validate_on_init=False))
yield from _get_validators(dataclass(dc_cls, config=config, use_proxy=True))
95 changes: 95 additions & 0 deletions tests/test_dataclasses.py
Expand Up @@ -1513,3 +1513,98 @@ class Foo:
assert config.bar == 'cat'
setattr(config, 'bar', 'dog')
assert config.bar == 'dog'


def test_frozen_dataclasses():
@dataclasses.dataclass(frozen=True)
class First:
a: int

@dataclasses.dataclass(frozen=True)
class Second(First):
@property
def b(self):
return self.a

class My(BaseModel):
my: Second

assert My(my=Second(a='1')).my.b == 1


def test_empty_dataclass():
"""should be able to inherit without adding a field"""

@dataclasses.dataclass
class UnvalidatedDataclass:
a: int = 0

@pydantic.dataclasses.dataclass
class ValidatedDerivedA(UnvalidatedDataclass):
...

@pydantic.dataclasses.dataclass()
class ValidatedDerivedB(UnvalidatedDataclass):
b: int = 0

@pydantic.dataclasses.dataclass()
class ValidatedDerivedC(UnvalidatedDataclass):
...


def test_proxy_dataclass():
@dataclasses.dataclass
class Foo:
a: Optional[int] = dataclasses.field(default=42)
b: List = dataclasses.field(default_factory=list)

@dataclasses.dataclass
class Bar:
pass

@dataclasses.dataclass
class Model1:
foo: Foo

class Model2(BaseModel):
foo: Foo

m1 = Model1(foo=Foo())
m2 = Model2(foo=Foo())

assert m1.foo.a == m2.foo.a == 42
assert m1.foo.b == m2.foo.b == []
assert m1.foo.Bar() is not None
assert m2.foo.Bar() is not None


def test_proxy_dataclass_2():
@dataclasses.dataclass
class M1:
a: int
b: str = 'b'
c: float = dataclasses.field(init=False)

def __post_init__(self):
self.c = float(self.a)

@dataclasses.dataclass
class M2:
a: int
b: str = 'b'
c: float = dataclasses.field(init=False)

def __post_init__(self):
self.c = float(self.a)

@pydantic.validator('b')
def check_b(cls, v):
if not v:
raise ValueError('b should not be empty')
return v

m1 = pydantic.parse_obj_as(M1, {'a': 3})
m2 = pydantic.parse_obj_as(M2, {'a': 3})
assert m1.a == m2.a == 3
assert m1.b == m2.b == 'b'
assert m1.c == m2.c == 3.0

0 comments on commit bf4b5ce

Please sign in to comment.