Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
4ec6c52
commit f0b6777
Showing
1 changed file
with
124 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
""" | ||
The main purpose is to enhance stdlib dataclasses by adding validation | ||
We also want to keep the dataclass untouched to still support the default hashing, | ||
equality, repr, ... | ||
This means we **don't want to create a new dataclass that inherits from it** | ||
To make this happen, we first attach a `BaseModel` to the dataclass | ||
and magic methods to trigger the validation of the data. | ||
Now the problem is: for a stdlib dataclass `Item` that now has magic attributes for pydantic | ||
how can we have a new class `ValidatedItem` to trigger validation by default and keep `Item` | ||
behaviour untouched! | ||
To do this `ValidatedItem` will in fact be a wrapper around `Item`, that acts like a proxy. | ||
This wrapper will just inject an extra kwarg `__pydantic_run_validation__` for `ValidatedItem` | ||
and not for `Item`! (Note that this can always be injected "a la mano" if needed) | ||
""" | ||
from typing import Any, Callable, Dict, Optional, Type, Union | ||
|
||
from pydantic import create_model, validate_model | ||
from pydantic.class_validators import gather_all_validators | ||
from pydantic.fields import Field, FieldInfo, Required, Undefined | ||
|
||
|
||
def dataclass( | ||
_cls: Optional[Type[Any]] = None, | ||
*, | ||
init: bool = True, | ||
repr: bool = True, | ||
eq: bool = True, | ||
order: bool = False, | ||
unsafe_hash: bool = False, | ||
frozen: bool = False, | ||
config: Type[Any] = None, | ||
) -> Union[Callable[[Type[Any]], Type['Dataclass']], Type['Dataclass']]: | ||
""" | ||
Like the python standard lib dataclasses but with type validation. | ||
Arguments are the same as for standard dataclasses, except for `validate_assignment`, which | ||
has the same meaning as `Config.validate_assignment`. | ||
""" | ||
|
||
def wrap(cls: Type[Any]) -> Type['Dataclass']: | ||
import dataclasses | ||
|
||
if not dataclasses.is_dataclass(cls): | ||
cls = dataclasses.dataclass( | ||
cls, init=init, repr=repr, eq=eq, order=order, unsafe_hash=unsafe_hash, frozen=frozen | ||
) | ||
return WithValidationWrapper(cls, config) | ||
|
||
if _cls is None: | ||
return wrap | ||
|
||
return wrap(_cls) | ||
|
||
|
||
class WithValidationWrapper: | ||
def __init__(self, dc, config): | ||
if not hasattr(dc, '__pydantic_model__'): | ||
add_pydantic_validation_attributes(dc, config) | ||
self.dc = dc | ||
|
||
def __getattr__(self, attr): | ||
return getattr(self.dc, attr) | ||
|
||
def __call__(self, *args, **kwargs): | ||
# By default we run the validation with the wrapper but can still be overwritten | ||
kwargs.setdefault('__pydantic_run_validation__', True) | ||
return self.dc(*args, **kwargs) | ||
|
||
|
||
def add_pydantic_validation_attributes(cls, config): | ||
# We need to replace the right method. If no `__post_init__` has been set in the stdlib dataclass | ||
# it won't even exist (code is generated on the fly by `dataclasses`) | ||
init_or_post_init_name = '__post_init__' if hasattr(cls, '__post_init__') else '__init__' | ||
init_or_post_init = getattr(cls, init_or_post_init_name) | ||
|
||
def new_init_or_post_init(self, *args, __pydantic_run_validation__: bool = False, **kwargs): | ||
init_or_post_init(self, *args, **kwargs) | ||
if __pydantic_run_validation__: | ||
self.__pydantic_validate_values__() | ||
|
||
setattr(cls, init_or_post_init_name, new_init_or_post_init) | ||
setattr(cls, '__pydantic_model__', create_pydantic_model_from_dataclass(cls, config)) | ||
setattr(cls, '__pydantic_validate_values__', dataclass_validate_values) | ||
|
||
|
||
def create_pydantic_model_from_dataclass(cls, config=None): | ||
import dataclasses | ||
|
||
field_definitions: Dict[str, Any] = {} | ||
for field in dataclasses.fields(cls): | ||
default: Any = Undefined | ||
default_factory = None | ||
field_info: FieldInfo | ||
|
||
if field.default is not dataclasses.MISSING: | ||
default = field.default | ||
# mypy issue 7020 and 708 | ||
elif field.default_factory is not dataclasses.MISSING: # type: ignore | ||
default_factory = field.default_factory # type: ignore | ||
else: | ||
default = Required | ||
|
||
if isinstance(default, FieldInfo): | ||
field_info = default | ||
cls.__has_field_info_default__ = True | ||
else: | ||
field_info = Field(default=default, default_factory=default_factory, **field.metadata) | ||
|
||
field_definitions[field.name] = (field.type, field_info) | ||
|
||
validators = gather_all_validators(cls) | ||
return create_model( | ||
cls.__name__, __config__=config, __module__=cls.__module__, __validators__=validators, **field_definitions | ||
) | ||
|
||
|
||
def dataclass_validate_values(self): | ||
d, _, validation_error = validate_model(self.__pydantic_model__, self.__dict__, cls=self.__class__) | ||
if validation_error: | ||
raise validation_error | ||
object.__setattr__(self, '__dict__', d) |