Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
4ec6c52
commit 2f82543
Showing
3 changed files
with
192 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,188 @@ | ||
""" | ||
The main purpose is to enhance stdlib dataclasses by adding validation | ||
We also want to keep the dataclass untouched to still support the default hashing, | ||
equality, repr, ... | ||
This means we **don't want to create a new dataclass that inherits from it** | ||
To make this happen, we first attach a `BaseModel` to the dataclass | ||
and magic methods to trigger the validation of the data. | ||
Now the problem is: for a stdlib dataclass `Item` that now has magic attributes for pydantic | ||
how can we have a new class `ValidatedItem` to trigger validation by default and keep `Item` | ||
behaviour untouched! | ||
To do this `ValidatedItem` will in fact be an instance of `DataclassWrapper`, a simple wrapper | ||
around `Item` that acts like a proxy. | ||
This wrapper will just inject an extra kwarg `__pydantic_run_validation__` for `ValidatedItem` | ||
and not for `Item`! (Note that this can always be injected "a la mano" if needed) | ||
""" | ||
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type, Union, overload | ||
|
||
from .class_validators import gather_all_validators | ||
from .error_wrappers import ValidationError | ||
from .fields import Field, FieldInfo, Required, Undefined | ||
from .main import create_model, validate_model | ||
|
||
if TYPE_CHECKING: | ||
from .main import BaseConfig, BaseModel # noqa: F401 | ||
|
||
class Dataclass: | ||
# stdlib attributes | ||
__dataclass_params__: Any # in reality `dataclasses._DataclassParams` | ||
|
||
# Added by pydantic | ||
__pydantic_model__: Type[BaseModel] | ||
__pydantic_validate_values__: Callable[['Dataclass'], None] | ||
__has_field_info_default__: bool # whether or not a `pydantic.Field` is used as default value | ||
|
||
def __init__(self, *args: Any, **kwargs: Any) -> None: | ||
pass | ||
|
||
|
||
@overload | ||
def dataclass( | ||
*, | ||
init: bool = True, | ||
repr: bool = True, | ||
eq: bool = True, | ||
order: bool = False, | ||
unsafe_hash: bool = False, | ||
frozen: bool = False, | ||
config: Type[Any] = None, | ||
) -> Callable[[Type[Any]], 'DataclassWrapper']: | ||
... | ||
|
||
|
||
@overload | ||
def dataclass( | ||
_cls: Type[Any], | ||
*, | ||
init: bool = True, | ||
repr: bool = True, | ||
eq: bool = True, | ||
order: bool = False, | ||
unsafe_hash: bool = False, | ||
frozen: bool = False, | ||
config: Type[Any] = None, | ||
) -> 'DataclassWrapper': | ||
... | ||
|
||
|
||
def dataclass( | ||
_cls: Optional[Type[Any]] = None, | ||
*, | ||
init: bool = True, | ||
repr: bool = True, | ||
eq: bool = True, | ||
order: bool = False, | ||
unsafe_hash: bool = False, | ||
frozen: bool = False, | ||
config: Optional[Type['BaseConfig']] = None, | ||
) -> Union[Callable[[Type[Any]], 'DataclassWrapper'], 'DataclassWrapper']: | ||
""" | ||
Like the python standard lib dataclasses but with type validation. | ||
Arguments are the same as for standard dataclasses, except for `validate_assignment`, which | ||
has the same meaning as `Config.validate_assignment`. | ||
""" | ||
|
||
def wrap(cls: Type[Any]) -> DataclassWrapper: | ||
import dataclasses | ||
|
||
if not dataclasses.is_dataclass(cls): | ||
cls = dataclasses.dataclass( # type: ignore | ||
cls, init=init, repr=repr, eq=eq, order=order, unsafe_hash=unsafe_hash, frozen=frozen | ||
) | ||
return DataclassWrapper(cls, config) | ||
|
||
if _cls is None: | ||
return wrap | ||
|
||
return wrap(_cls) | ||
|
||
|
||
class DataclassWrapper: | ||
def __init__(self, dc_cls: Type['Dataclass'], config: Optional[Type['BaseConfig']]) -> None: | ||
if not hasattr(dc_cls, '__pydantic_model__'): | ||
add_pydantic_validation_attributes(dc_cls, config) | ||
self.dc_cls = dc_cls | ||
|
||
def __getattr__(self, attr: str) -> Any: | ||
return getattr(self.dc_cls, attr) | ||
|
||
def __call__(self, *args: Any, **kwargs: Any) -> Any: | ||
# By default we run the validation with the wrapper but can still be overwritten | ||
kwargs.setdefault('__pydantic_run_validation__', True) | ||
return self.dc_cls(*args, **kwargs) | ||
|
||
|
||
def add_pydantic_validation_attributes(dc_cls: Type['Dataclass'], config: Optional[Type['BaseConfig']]) -> None: | ||
# We need to replace the right method. If no `__post_init__` has been set in the stdlib dataclass | ||
# it won't even exist (code is generated on the fly by `dataclasses`) | ||
init_or_post_init_name = '__post_init__' if hasattr(dc_cls, '__post_init__') else '__init__' | ||
init_or_post_init = getattr(dc_cls, init_or_post_init_name) | ||
|
||
def new_init_or_post_init( | ||
self: 'Dataclass', *args: Any, __pydantic_run_validation__: bool = False, **kwargs: Any | ||
) -> None: | ||
init_or_post_init(self, *args, **kwargs) | ||
if __pydantic_run_validation__: | ||
self.__pydantic_validate_values__() | ||
|
||
setattr(dc_cls, init_or_post_init_name, new_init_or_post_init) | ||
setattr(dc_cls, '__pydantic_model__', create_pydantic_model_from_dataclass(dc_cls, config)) | ||
setattr(dc_cls, '__pydantic_validate_values__', dataclass_validate_values) | ||
if dc_cls.__pydantic_model__.__config__.validate_assignment and not dc_cls.__dataclass_params__.frozen: | ||
setattr(dc_cls, '__setattr__', dataclass_validate_assignment_setattr) | ||
|
||
|
||
def create_pydantic_model_from_dataclass( | ||
dc_cls: Type['Dataclass'], config: Optional[Type['BaseConfig']] = None | ||
) -> Type['BaseModel']: | ||
import dataclasses | ||
|
||
field_definitions: Dict[str, Any] = {} | ||
for field in dataclasses.fields(dc_cls): | ||
default: Any = Undefined | ||
default_factory = None | ||
field_info: FieldInfo | ||
|
||
if field.default is not dataclasses.MISSING: | ||
default = field.default | ||
# mypy issue 7020 and 708 | ||
elif field.default_factory is not dataclasses.MISSING: # type: ignore | ||
default_factory = field.default_factory # type: ignore | ||
else: | ||
default = Required | ||
|
||
if isinstance(default, FieldInfo): | ||
field_info = default | ||
dc_cls.__has_field_info_default__ = True | ||
else: | ||
field_info = Field(default=default, default_factory=default_factory, **field.metadata) | ||
|
||
field_definitions[field.name] = (field.type, field_info) | ||
|
||
validators = gather_all_validators(dc_cls) | ||
return create_model( | ||
dc_cls.__name__, __config__=config, __module__=dc_cls.__module__, __validators__=validators, **field_definitions | ||
) | ||
|
||
|
||
def dataclass_validate_values(self: 'Dataclass') -> None: | ||
d, _, validation_error = validate_model(self.__pydantic_model__, self.__dict__, cls=self.__class__) | ||
if validation_error: | ||
raise validation_error | ||
object.__setattr__(self, '__dict__', d) | ||
|
||
|
||
def dataclass_validate_assignment_setattr(self: 'Dataclass', name: str, value: Any) -> None: | ||
d = dict(self.__dict__) | ||
d.pop(name, None) | ||
known_field = self.__pydantic_model__.__fields__.get(name, None) | ||
if known_field: | ||
value, error_ = known_field.validate(value, d, loc=name, cls=self.__class__) | ||
if error_: | ||
raise ValidationError([error_], self.__class__) | ||
|
||
object.__setattr__(self, name, value) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters