/
dataclasses.py
254 lines (208 loc) · 8.71 KB
/
dataclasses.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type, TypeVar, Union, overload
from .class_validators import gather_all_validators
from .error_wrappers import ValidationError
from .errors import DataclassTypeError
from .fields import Field, FieldInfo, Required, Undefined
from .main import create_model, validate_model
from .typing import resolve_annotations
from .utils import ClassAttribute
if TYPE_CHECKING:
    # This branch is only evaluated by static type checkers: it provides the
    # names used in the string annotations of this module.
    from .main import BaseConfig, BaseModel  # noqa: F401
    from .typing import CallableGenerator, NoArgAnyCallable

    # Type variable bound to the stub below, so helpers can express
    # "takes a dataclass type, returns an instance of that same type".
    DataclassT = TypeVar('DataclassT', bound='Dataclass')

    class Dataclass:
        """
        Typing stub describing the attributes and methods that `_process_class`
        attaches to a processed dataclass.
        """

        # pydantic model generated from the dataclass fields (assigned in `_process_class`)
        __pydantic_model__: Type[BaseModel]
        # False on the class; set to True on an instance at the end of validation in `__post_init__`
        __initialised__: bool
        # the `__post_init__` found on the class before it was replaced, if any
        __post_init_original__: Optional[Callable[..., None]]
        # marker set by `_process_class`; `is_builtin_dataclass` checks for its presence
        __processed__: Optional[ClassAttribute]

        def __init__(self, *args: Any, **kwargs: Any) -> None:
            pass

        @classmethod
        def __get_validators__(cls: Type['Dataclass']) -> 'CallableGenerator':
            pass

        @classmethod
        def __validate__(cls: Type['DataclassT'], v: Any) -> 'DataclassT':
            pass

        def __call__(self: 'DataclassT', *args: Any, **kwargs: Any) -> 'DataclassT':
            pass
def _validate_dataclass(cls: Type['DataclassT'], v: Any) -> 'DataclassT':
if isinstance(v, cls):
return v
elif isinstance(v, (list, tuple)):
return cls(*v)
elif isinstance(v, dict):
return cls(**v)
# In nested dataclasses, v can be of type `dataclasses.dataclass`.
# But to validate fields `cls` will be in fact a `pydantic.dataclasses.dataclass`,
# which inherits directly from the class of `v`.
elif is_builtin_dataclass(v) and cls.__bases__[0] is type(v):
import dataclasses
return cls(**dataclasses.asdict(v))
else:
raise DataclassTypeError(class_name=cls.__name__)
def _get_validators(cls: Type['Dataclass']) -> 'CallableGenerator':
yield cls.__validate__
def setattr_validate_assignment(self: 'Dataclass', name: str, value: Any) -> None:
    """
    `__setattr__` replacement that validates assignments to known fields.

    Once the instance is initialised, a value assigned to a field of the pydantic model
    is run through that field's validator (which receives the other current attribute
    values); a reported error is raised as `ValidationError`. The resulting value is then
    stored with `object.__setattr__`. Unknown names, and any assignment made before
    initialisation, are stored as-is.
    """
    if self.__initialised__:
        # every current attribute except the one being assigned
        other_values = {key: val for key, val in self.__dict__.items() if key != name}
        model_field = self.__pydantic_model__.__fields__.get(name)
        if model_field:
            value, failure = model_field.validate(value, other_values, loc=name, cls=self.__class__)
            if failure:
                raise ValidationError([failure], self.__class__)

    object.__setattr__(self, name, value)
def is_builtin_dataclass(_cls: Type[Any]) -> bool:
    """
    Tell whether `_cls` is a 'direct' stdlib dataclass.

    `dataclasses.is_dataclass` is also True when only a parent class is a dataclass, so
    anything carrying the `__processed__` class attribute is excluded first.
    """
    import dataclasses

    if hasattr(_cls, '__processed__'):
        return False
    return dataclasses.is_dataclass(_cls)
def _process_class(
    _cls: Type[Any],
    init: bool,
    repr: bool,
    eq: bool,
    order: bool,
    unsafe_hash: bool,
    frozen: bool,
    config: Optional[Type[Any]],
) -> Type['Dataclass']:
    """
    Build a validating dataclass from `_cls`.

    Installs `_pydantic_post_init` as `__post_init__`, applies `dataclasses.dataclass`
    with the given options and attaches a pydantic model, generated from the dataclass
    fields, as `__pydantic_model__`.

    `init`, `repr`, `eq`, `order`, `unsafe_hash` and `frozen` are forwarded unchanged to
    `dataclasses.dataclass`; `config` is passed to `create_model` as `__config__`.

    When `_cls` is a plain stdlib dataclass, a new subclass is created and returned
    (and registered in this module's globals); otherwise `_cls` itself is modified.
    """
    import dataclasses

    # Look for a user-defined `__post_init__`. If the one found is our own hook
    # (recognised by its function name), ignore it and fall back to the original
    # one remembered in `__post_init_original__`, if any.
    post_init_original = getattr(_cls, '__post_init__', None)
    if post_init_original and post_init_original.__name__ == '_pydantic_post_init':
        post_init_original = None
    if not post_init_original:
        post_init_original = getattr(_cls, '__post_init_original__', None)

    # Optional hook called after validation has completed.
    post_init_post_parse = getattr(_cls, '__post_init_post_parse__', None)

    def _pydantic_post_init(self: 'Dataclass', *initvars: Any) -> None:
        # Run the user's own `__post_init__` first, on the not-yet-validated values.
        if post_init_original is not None:
            post_init_original(self, *initvars)
        # We need to remove `FieldInfo` values since they are not valid as input
        # It's ok since they are obviously the default values
        input_data = {k: v for k, v in self.__dict__.items() if not isinstance(v, FieldInfo)}
        d, _, validation_error = validate_model(self.__pydantic_model__, input_data, cls=self.__class__)
        if validation_error:
            raise validation_error
        # Replace the instance state with the values returned by `validate_model`.
        # `object.__setattr__` is used so that any custom `__setattr__` on the class is bypassed.
        object.__setattr__(self, '__dict__', d)
        object.__setattr__(self, '__initialised__', True)
        if post_init_post_parse is not None:
            post_init_post_parse(self, *initvars)

    # If the class is already a dataclass, __post_init__ will not be called automatically
    # so no validation will be added.
    # We hence create dynamically a new dataclass:
    # ```
    # @dataclasses.dataclass
    # class NewClass(_cls):
    #   __post_init__ = _pydantic_post_init
    # ```
    # with the exact same fields as the base dataclass
    # and register it on module level to address pickle problem:
    # https://github.com/samuelcolvin/pydantic/issues/2111
    if is_builtin_dataclass(_cls):
        # `id(_cls)` keeps the registered name distinct for same-named classes
        uniq_class_name = f'_Pydantic_{_cls.__name__}_{id(_cls)}'
        _cls = type(
            # for pretty output new class will have the name as original
            _cls.__name__,
            (_cls,),
            {
                '__annotations__': resolve_annotations(_cls.__annotations__, _cls.__module__),
                '__post_init__': _pydantic_post_init,
                # attrs for pickle to find this class
                '__module__': __name__,
                '__qualname__': uniq_class_name,
            },
        )
        globals()[uniq_class_name] = _cls
    else:
        _cls.__post_init__ = _pydantic_post_init
    cls: Type['Dataclass'] = dataclasses.dataclass(  # type: ignore
        _cls, init=init, repr=repr, eq=eq, order=order, unsafe_hash=unsafe_hash, frozen=frozen
    )
    # Mark the class as processed so `is_builtin_dataclass` no longer matches it.
    cls.__processed__ = ClassAttribute('__processed__', True)

    # Translate every dataclass field into a `(type, FieldInfo)` definition for `create_model`.
    field_definitions: Dict[str, Any] = {}
    for field in dataclasses.fields(cls):
        default: Any = Undefined
        default_factory: Optional['NoArgAnyCallable'] = None
        field_info: FieldInfo

        if field.default is not dataclasses.MISSING:
            default = field.default
        # mypy issue 7020 and 708
        elif field.default_factory is not dataclasses.MISSING:  # type: ignore
            default_factory = field.default_factory  # type: ignore
        else:
            # neither a default nor a default factory: the field is required
            default = Required

        if isinstance(default, FieldInfo):
            # the default is already a `FieldInfo`: use it directly
            field_info = default
        else:
            # the dataclass field metadata is forwarded as extra keyword arguments of `Field`
            field_info = Field(default=default, default_factory=default_factory, **field.metadata)

        field_definitions[field.name] = (field.type, field_info)

    validators = gather_all_validators(cls)
    cls.__pydantic_model__ = create_model(
        cls.__name__, __config__=config, __module__=_cls.__module__, __validators__=validators, **field_definitions
    )

    # Class-level default; `_pydantic_post_init` sets it to True on each instance.
    cls.__initialised__ = False
    cls.__validate__ = classmethod(_validate_dataclass)  # type: ignore[assignment]
    cls.__get_validators__ = classmethod(_get_validators)  # type: ignore[assignment]
    if post_init_original:
        # remember the user's hook so it can be recovered if this class is processed again
        cls.__post_init_original__ = post_init_original

    # Validation on assignment is only installed when the config enables it
    # and the dataclass is not frozen.
    if cls.__pydantic_model__.__config__.validate_assignment and not frozen:
        cls.__setattr__ = setattr_validate_assignment  # type: ignore[assignment]

    return cls
@overload
def dataclass(
    *,
    init: bool = True,
    repr: bool = True,
    eq: bool = True,
    order: bool = False,
    unsafe_hash: bool = False,
    frozen: bool = False,
    config: Optional[Type[Any]] = None,
) -> Callable[[Type[Any]], Type['Dataclass']]:
    ...


@overload
def dataclass(
    _cls: Type[Any],
    *,
    init: bool = True,
    repr: bool = True,
    eq: bool = True,
    order: bool = False,
    unsafe_hash: bool = False,
    frozen: bool = False,
    config: Optional[Type[Any]] = None,
) -> Type['Dataclass']:
    ...


def dataclass(
    _cls: Optional[Type[Any]] = None,
    *,
    init: bool = True,
    repr: bool = True,
    eq: bool = True,
    order: bool = False,
    unsafe_hash: bool = False,
    frozen: bool = False,
    config: Optional[Type[Any]] = None,
) -> Union[Callable[[Type[Any]], Type['Dataclass']], Type['Dataclass']]:
    """
    Like the python standard lib dataclasses but with type validation.

    `init`, `repr`, `eq`, `order`, `unsafe_hash` and `frozen` have the same meaning as for
    standard dataclasses. The additional keyword-only argument `config` is an optional
    config class used for the pydantic model built for the dataclass; `None` (the default)
    means no custom config is supplied.

    Can be used both as `@dataclass` and as `@dataclass(...)`: called without a class it
    returns a decorator, called with a class it returns the processed class.
    """

    def wrap(cls: Type[Any]) -> Type['Dataclass']:
        return _process_class(cls, init, repr, eq, order, unsafe_hash, frozen, config)

    # `@dataclass(...)` form: no class yet, hand back the decorator
    if _cls is None:
        return wrap

    # bare `@dataclass` form: process the class right away
    return wrap(_cls)
def make_dataclass_validator(_cls: Type[Any], config: Type['BaseConfig']) -> 'CallableGenerator':
    """
    Create a pydantic.dataclass from a builtin dataclass to add type validation
    and yield the validators.

    It retrieves the parameters of the builtin dataclass and forwards those that the
    `dataclass` decorator of this module accepts to the newly created dataclass.
    """
    # Options understood by the `dataclass` decorator defined in this module.
    # `__dataclass_params__` may expose more slots than these on newer Python versions
    # (e.g. `match_args`, `kw_only`, `slots`); forwarding them blindly would make the
    # call below fail with `TypeError: unexpected keyword argument`.
    supported_parameters = ('init', 'repr', 'eq', 'order', 'unsafe_hash', 'frozen')
    dataclass_params = _cls.__dataclass_params__
    stdlib_dataclass_parameters = {
        param: getattr(dataclass_params, param)
        for param in dataclass_params.__slots__
        if param in supported_parameters
    }
    cls = dataclass(_cls, config=config, **stdlib_dataclass_parameters)
    yield from _get_validators(cls)