Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: support arbitrary types with custom __eq__ and Annotated with python 3.9 #2502

Merged
merged 1 commit into from May 9, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions changes/2483-PrettyWood.md
@@ -0,0 +1,2 @@
- support arbitrary types with custom `__eq__`
- support `Annotated` in `validate_arguments` and in generic models with python 3.9
23 changes: 5 additions & 18 deletions pydantic/decorator.py
@@ -1,23 +1,10 @@
from functools import wraps
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Mapping,
Optional,
Tuple,
Type,
TypeVar,
Union,
get_type_hints,
overload,
)
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional, Tuple, Type, TypeVar, Union, overload

from . import validator
from .errors import ConfigError
from .main import BaseModel, Extra, create_model
from .typing import get_all_type_hints
from .utils import to_camel

__all__ = ('validate_arguments',)
Expand Down Expand Up @@ -87,17 +74,17 @@ def __init__(self, function: 'AnyCallableT', config: 'ConfigType'): # noqa C901
self.v_args_name = 'args'
self.v_kwargs_name = 'kwargs'

type_hints = get_type_hints(function)
type_hints = get_all_type_hints(function)
takes_args = False
takes_kwargs = False
fields: Dict[str, Tuple[Any, Any]] = {}
for i, (name, p) in enumerate(parameters.items()):
if p.annotation == p.empty:
if p.annotation is p.empty:
annotation = Any
else:
annotation = type_hints[name]

default = ... if p.default == p.empty else p.default
default = ... if p.default is p.empty else p.default
if p.kind == Parameter.POSITIONAL_ONLY:
self.arg_mapping[i] = name
fields[name] = annotation, default
Expand Down
8 changes: 4 additions & 4 deletions pydantic/fields.py
Expand Up @@ -169,7 +169,7 @@ def update_from_config(self, from_config: Dict[str, Any]) -> None:
setattr(self, attr_name, value)

def _validate(self) -> None:
if self.default not in (Undefined, Ellipsis) and self.default_factory is not None:
if self.default is not Undefined and self.default_factory is not None:
raise ValueError('cannot specify both default and default_factory')


Expand Down Expand Up @@ -370,9 +370,10 @@ def _get_field_info(
field_info = next(iter(field_infos), None)
if field_info is not None:
field_info.update_from_config(field_info_from_config)
if field_info.default not in (Undefined, Ellipsis):
if field_info.default is not Undefined:
raise ValueError(f'`Field` default cannot be set in `Annotated` for {field_name!r}')
if value not in (Undefined, Ellipsis):
if value is not Undefined and value is not Required:
# check also `Required` because of `validate_arguments` that sets `...` as default value
field_info.default = value

if isinstance(value, FieldInfo):
Expand Down Expand Up @@ -450,7 +451,6 @@ def prepare(self) -> None:
self._type_analysis()
if self.required is Undefined:
self.required = True
self.field_info.default = Required
if self.default is Undefined and self.default_factory is None:
self.default = None
self.populate_validators()
Expand Down
11 changes: 8 additions & 3 deletions pydantic/generics.py
Expand Up @@ -15,13 +15,14 @@
TypeVar,
Union,
cast,
get_type_hints,
)

from typing_extensions import Annotated

from .class_validators import gather_all_validators
from .fields import DeferredType
from .main import BaseModel, create_model
from .typing import display_as_type, get_args, get_origin, typing_base
from .typing import display_as_type, get_all_type_hints, get_args, get_origin, typing_base
from .utils import all_identical, lenient_issubclass

_generic_types_cache: Dict[Tuple[Type[Any], Union[Any, Tuple[Any, ...]]], Type[BaseModel]] = {}
Expand Down Expand Up @@ -73,7 +74,7 @@ def __class_getitem__(cls: Type[GenericModelT], params: Union[Type[Any], Tuple[T
model_name = cls.__concrete_name__(params)
validators = gather_all_validators(cls)

type_hints = get_type_hints(cls).items()
type_hints = get_all_type_hints(cls).items()
instance_type_hints = {k: v for k, v in type_hints if get_origin(v) is not ClassVar}

fields = {k: (DeferredType(), cls.__fields__[k].field_info) for k in instance_type_hints if k in cls.__fields__}
Expand Down Expand Up @@ -159,6 +160,10 @@ def replace_types(type_: Any, type_map: Mapping[Any, Any]) -> Any:
type_args = get_args(type_)
origin_type = get_origin(type_)

if origin_type is Annotated:
annotated_type, *annotations = type_args
return Annotated[replace_types(annotated_type, type_map), tuple(annotations)]

# Having type args is a good indicator that this is a typing module
# class instantiation or a generic alias of some sort.
if type_args:
Expand Down
14 changes: 14 additions & 0 deletions pydantic/typing.py
Expand Up @@ -18,6 +18,7 @@
Union,
_eval_type,
cast,
get_type_hints,
)

from typing_extensions import Annotated, Literal
Expand Down Expand Up @@ -70,6 +71,18 @@ def evaluate_forwardref(type_: ForwardRef, globalns: Any, localns: Any) -> Any:
return cast(Any, type_)._evaluate(globalns, localns, set())


if sys.version_info < (3, 9):
# Ensure we always get all the whole `Annotated` hint, not just the annotated type.
# For 3.6 to 3.8, `get_type_hints` doesn't recognize `typing_extensions.Annotated`,
# so it already returns the full annotation
get_all_type_hints = get_type_hints

else:

def get_all_type_hints(obj: Any, globalns: Any = None, localns: Any = None) -> Any:
return get_type_hints(obj, globalns, localns, include_extras=True)


if sys.version_info < (3, 7):
from typing import Callable as Callable

Expand Down Expand Up @@ -225,6 +238,7 @@ def get_args(tp: Type[Any]) -> Tuple[Any, ...]:
'get_args',
'get_origin',
'typing_base',
'get_all_type_hints',
)


Expand Down
14 changes: 2 additions & 12 deletions tests/test_annotated.py
@@ -1,11 +1,9 @@
import sys
from typing import get_type_hints

import pytest
from typing_extensions import Annotated

from pydantic import BaseModel, Field
from pydantic.fields import Undefined
from pydantic.typing import get_all_type_hints


@pytest.mark.parametrize(
Expand Down Expand Up @@ -43,15 +41,7 @@ class M(BaseModel):

assert M().x == 5
assert M(x=10).x == 10

# get_type_hints doesn't recognize typing_extensions.Annotated, so will return the full
# annotation. 3.9 w/ stock Annotated will return the wrapped type by default, but return the
# full thing with the new include_extras flag.
if sys.version_info >= (3, 9):
assert get_type_hints(M)['x'] is int
assert get_type_hints(M, include_extras=True)['x'] == hint
else:
assert get_type_hints(M)['x'] == hint
assert get_all_type_hints(M)['x'] == hint


@pytest.mark.parametrize(
Expand Down
8 changes: 4 additions & 4 deletions tests/test_decorator.py
Expand Up @@ -3,14 +3,13 @@
import sys
from pathlib import Path
from typing import List
from unittest.mock import ANY

import pytest
from typing_extensions import Annotated

from pydantic import BaseModel, Field, ValidationError, validate_arguments
from pydantic.decorator import ValidatedFunction
from pydantic.errors import ConfigError
from pydantic.typing import Annotated

skip_pre_38 = pytest.mark.skipif(sys.version_info < (3, 8), reason='testing >= 3.8 behaviour only')

Expand Down Expand Up @@ -154,13 +153,14 @@ def foo(a: int, b: int = Field(default_factory=lambda: 99), *args: int) -> int:
assert foo(1, 2, 3) == 6


@pytest.mark.skipif(not Annotated, reason='typing_extensions not installed')
def test_annotated_field_can_provide_factory() -> None:
@validate_arguments
def foo2(a: int, b: Annotated[int, Field(default_factory=lambda: 99)] = ANY, *args: int) -> int:
def foo2(a: int, b: Annotated[int, Field(default_factory=lambda: 99)], *args: int) -> int:
"""mypy reports Incompatible default for argument "b" if we don't supply ANY as default"""
return a + b + sum(args)

assert foo2(1) == 100


@skip_pre_38
def test_positional_only(create_module):
Expand Down
16 changes: 16 additions & 0 deletions tests/test_edge_cases.py
Expand Up @@ -1839,3 +1839,19 @@ class Config:
with pytest.raises(TypeError):
b.a = 'y'
assert b.dict() == {'a': 'x'}


def test_arbitrary_types_allowed_custom_eq():
class Foo:
def __eq__(self, other):
if other.__class__ is not Foo:
raise TypeError(f'Cannot interpret {other.__class__.__name__!r} as a valid type')
return True

class Model(BaseModel):
x: Foo = Foo()

class Config:
arbitrary_types_allowed = True

assert Model().x == Foo()
12 changes: 11 additions & 1 deletion tests/test_generics.py
Expand Up @@ -3,7 +3,7 @@
from typing import Any, Callable, ClassVar, Dict, Generic, List, Optional, Sequence, Tuple, Type, TypeVar, Union

import pytest
from typing_extensions import Literal
from typing_extensions import Annotated, Literal

from pydantic import BaseModel, Field, ValidationError, root_validator, validator
from pydantic.generics import GenericModel, _generic_types_cache, iter_contained_typevars, replace_types
Expand Down Expand Up @@ -1071,3 +1071,13 @@ class GModel(GenericModel, Generic[FieldType, ValueType]):
Fields = Literal['foo', 'bar']
m = GModel[Fields, str](field={'foo': 'x'})
assert m.dict() == {'field': {'foo': 'x'}}


@skip_36
def test_generic_annotated():
T = TypeVar('T')

class SomeGenericModel(GenericModel, Generic[T]):
some_field: Annotated[T, Field(alias='the_alias')]

SomeGenericModel[str](the_alias='qwe')