wip
PrettyWood committed Mar 21, 2021
1 parent 4ec6c52 commit 2f82543
Showing 3 changed files with 192 additions and 2 deletions.
3 changes: 2 additions & 1 deletion pydantic/__init__.py
@@ -1,5 +1,5 @@
# flake8: noqa
-from . import dataclasses
+from . import dataclasses, dataclasses2
from .annotated_types import create_model_from_namedtuple, create_model_from_typeddict
from .class_validators import root_validator, validator
from .decorator import validate_arguments
@@ -22,6 +22,7 @@
    'create_model_from_typeddict',
    # dataclasses
    'dataclasses',
+    'dataclasses2',
    # class_validators
    'root_validator',
    'validator',
188 changes: 188 additions & 0 deletions pydantic/dataclasses2.py
@@ -0,0 +1,188 @@
"""
The main purpose is to enhance stdlib dataclasses by adding validation
We also want to keep the dataclass untouched to still support the default hashing,
equality, repr, ...
This means we **don't want to create a new dataclass that inherits from it**
To make this happen, we first attach a `BaseModel` to the dataclass
and magic methods to trigger the validation of the data.
Now the problem is: for a stdlib dataclass `Item` that now has magic attributes for pydantic
how can we have a new class `ValidatedItem` to trigger validation by default and keep `Item`
behaviour untouched!
To do this `ValidatedItem` will in fact be an instance of `DataclassWrapper`, a simple wrapper
around `Item` that acts like a proxy.
This wrapper will just inject an extra kwarg `__pydantic_run_validation__` for `ValidatedItem`
and not for `Item`! (Note that this can always be injected "a la mano" if needed)
"""
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type, Union, overload

from .class_validators import gather_all_validators
from .error_wrappers import ValidationError
from .fields import Field, FieldInfo, Required, Undefined
from .main import create_model, validate_model

if TYPE_CHECKING:
    from .main import BaseConfig, BaseModel  # noqa: F401

    class Dataclass:
        # stdlib attributes
        __dataclass_params__: Any  # in reality `dataclasses._DataclassParams`

        # Added by pydantic
        __pydantic_model__: Type[BaseModel]
        __pydantic_validate_values__: Callable[['Dataclass'], None]
        __has_field_info_default__: bool  # whether or not a `pydantic.Field` is used as default value

        def __init__(self, *args: Any, **kwargs: Any) -> None:
            pass


@overload
def dataclass(
    *,
    init: bool = True,
    repr: bool = True,
    eq: bool = True,
    order: bool = False,
    unsafe_hash: bool = False,
    frozen: bool = False,
    config: Type[Any] = None,
) -> Callable[[Type[Any]], 'DataclassWrapper']:
    ...


@overload
def dataclass(
    _cls: Type[Any],
    *,
    init: bool = True,
    repr: bool = True,
    eq: bool = True,
    order: bool = False,
    unsafe_hash: bool = False,
    frozen: bool = False,
    config: Type[Any] = None,
) -> 'DataclassWrapper':
    ...


def dataclass(
    _cls: Optional[Type[Any]] = None,
    *,
    init: bool = True,
    repr: bool = True,
    eq: bool = True,
    order: bool = False,
    unsafe_hash: bool = False,
    frozen: bool = False,
    config: Optional[Type['BaseConfig']] = None,
) -> Union[Callable[[Type[Any]], 'DataclassWrapper'], 'DataclassWrapper']:
    """
    Like the python standard lib dataclasses but with type validation.
    Arguments are the same as for standard dataclasses, except for `validate_assignment`, which
    has the same meaning as `Config.validate_assignment`.
    """

    def wrap(cls: Type[Any]) -> DataclassWrapper:
        import dataclasses

        if not dataclasses.is_dataclass(cls):
            cls = dataclasses.dataclass(  # type: ignore
                cls, init=init, repr=repr, eq=eq, order=order, unsafe_hash=unsafe_hash, frozen=frozen
            )
        return DataclassWrapper(cls, config)

    if _cls is None:
        return wrap

    return wrap(_cls)


class DataclassWrapper:
    def __init__(self, dc_cls: Type['Dataclass'], config: Optional[Type['BaseConfig']]) -> None:
        if not hasattr(dc_cls, '__pydantic_model__'):
            add_pydantic_validation_attributes(dc_cls, config)
        self.dc_cls = dc_cls

    def __getattr__(self, attr: str) -> Any:
        return getattr(self.dc_cls, attr)

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        # By default we run the validation with the wrapper, but it can still be overridden
        kwargs.setdefault('__pydantic_run_validation__', True)
        return self.dc_cls(*args, **kwargs)


def add_pydantic_validation_attributes(dc_cls: Type['Dataclass'], config: Optional[Type['BaseConfig']]) -> None:
    # We need to replace the right method. If no `__post_init__` has been set in the stdlib dataclass
    # it won't even exist (code is generated on the fly by `dataclasses`)
    init_or_post_init_name = '__post_init__' if hasattr(dc_cls, '__post_init__') else '__init__'
    init_or_post_init = getattr(dc_cls, init_or_post_init_name)

    def new_init_or_post_init(
        self: 'Dataclass', *args: Any, __pydantic_run_validation__: bool = False, **kwargs: Any
    ) -> None:
        init_or_post_init(self, *args, **kwargs)
        if __pydantic_run_validation__:
            self.__pydantic_validate_values__()

    setattr(dc_cls, init_or_post_init_name, new_init_or_post_init)
    setattr(dc_cls, '__pydantic_model__', create_pydantic_model_from_dataclass(dc_cls, config))
    setattr(dc_cls, '__pydantic_validate_values__', dataclass_validate_values)
    if dc_cls.__pydantic_model__.__config__.validate_assignment and not dc_cls.__dataclass_params__.frozen:
        setattr(dc_cls, '__setattr__', dataclass_validate_assignment_setattr)


def create_pydantic_model_from_dataclass(
    dc_cls: Type['Dataclass'], config: Optional[Type['BaseConfig']] = None
) -> Type['BaseModel']:
    import dataclasses

    field_definitions: Dict[str, Any] = {}
    for field in dataclasses.fields(dc_cls):
        default: Any = Undefined
        default_factory = None
        field_info: FieldInfo

        if field.default is not dataclasses.MISSING:
            default = field.default
        # mypy issue 7020 and 708
        elif field.default_factory is not dataclasses.MISSING:  # type: ignore
            default_factory = field.default_factory  # type: ignore
        else:
            default = Required

        if isinstance(default, FieldInfo):
            field_info = default
            dc_cls.__has_field_info_default__ = True
        else:
            field_info = Field(default=default, default_factory=default_factory, **field.metadata)

        field_definitions[field.name] = (field.type, field_info)

    validators = gather_all_validators(dc_cls)
    return create_model(
        dc_cls.__name__, __config__=config, __module__=dc_cls.__module__, __validators__=validators, **field_definitions
    )


def dataclass_validate_values(self: 'Dataclass') -> None:
    d, _, validation_error = validate_model(self.__pydantic_model__, self.__dict__, cls=self.__class__)
    if validation_error:
        raise validation_error
    object.__setattr__(self, '__dict__', d)


def dataclass_validate_assignment_setattr(self: 'Dataclass', name: str, value: Any) -> None:
    d = dict(self.__dict__)
    d.pop(name, None)
    known_field = self.__pydantic_model__.__fields__.get(name, None)
    if known_field:
        value, error_ = known_field.validate(value, d, loc=name, cls=self.__class__)
        if error_:
            raise ValidationError([error_], self.__class__)

    object.__setattr__(self, name, value)
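
A minimal usage sketch of the design described in the module docstring above, assuming the module ships as `pydantic.dataclasses2`; the `Item`/`ValidatedItem` names come from that docstring, and the exact behaviour is inferred from this work-in-progress diff and may change:

import dataclasses

from pydantic import dataclasses2


@dataclasses.dataclass
class Item:  # hypothetical example class, as in the module docstring
    name: str
    price: float = 1.0


# `ValidatedItem` is an instance of `DataclassWrapper` proxying `Item`; `Item` is not subclassed
ValidatedItem = dataclasses2.dataclass(Item)

Item(name='a', price='not a float')   # stdlib behaviour kept: no validation runs
ValidatedItem(name='a', price='1.5')  # wrapper injects `__pydantic_run_validation__=True`; price coerced to 1.5
Item(name='a', price='1.5', __pydantic_run_validation__=True)  # validation triggered "by hand"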
3 changes: 2 additions & 1 deletion pydantic/types.py
@@ -106,10 +106,11 @@

if TYPE_CHECKING:
    from .dataclasses import Dataclass  # noqa: F401
+    from .dataclasses2 import Dataclass as Dataclass2  # noqa: F401
    from .main import BaseConfig, BaseModel  # noqa: F401
    from .typing import CallableGenerator

-ModelOrDc = Type[Union['BaseModel', 'Dataclass']]
+ModelOrDc = Type[Union['BaseModel', 'Dataclass', 'Dataclass2']]

T = TypeVar('T')
_DEFINED_TYPES: 'WeakSet[type]' = WeakSet()
