diff --git a/changes/2483-PrettyWood.md b/changes/2483-PrettyWood.md new file mode 100644 index 0000000000..6ca72df4bb --- /dev/null +++ b/changes/2483-PrettyWood.md @@ -0,0 +1,2 @@ +- support arbitrary types with custom `__eq__` +- support `Annotated` in `validate_arguments` and in generic models with python 3.9 \ No newline at end of file diff --git a/pydantic/decorator.py b/pydantic/decorator.py index 933a4d1a45..266195ce50 100644 --- a/pydantic/decorator.py +++ b/pydantic/decorator.py @@ -1,23 +1,10 @@ from functools import wraps -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Dict, - List, - Mapping, - Optional, - Tuple, - Type, - TypeVar, - Union, - get_type_hints, - overload, -) +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional, Tuple, Type, TypeVar, Union, overload from . import validator from .errors import ConfigError from .main import BaseModel, Extra, create_model +from .typing import get_all_type_hints from .utils import to_camel __all__ = ('validate_arguments',) @@ -87,17 +74,17 @@ def __init__(self, function: 'AnyCallableT', config: 'ConfigType'): # noqa C901 self.v_args_name = 'args' self.v_kwargs_name = 'kwargs' - type_hints = get_type_hints(function) + type_hints = get_all_type_hints(function) takes_args = False takes_kwargs = False fields: Dict[str, Tuple[Any, Any]] = {} for i, (name, p) in enumerate(parameters.items()): - if p.annotation == p.empty: + if p.annotation is p.empty: annotation = Any else: annotation = type_hints[name] - default = ... if p.default == p.empty else p.default + default = ... if p.default is p.empty else p.default if p.kind == Parameter.POSITIONAL_ONLY: self.arg_mapping[i] = name fields[name] = annotation, default diff --git a/pydantic/fields.py b/pydantic/fields.py index 449dd55d4c..0c95d8ae3d 100644 --- a/pydantic/fields.py +++ b/pydantic/fields.py @@ -169,7 +169,7 @@ def update_from_config(self, from_config: Dict[str, Any]) -> None: setattr(self, attr_name, value) def _validate(self) -> None: - if self.default not in (Undefined, Ellipsis) and self.default_factory is not None: + if self.default is not Undefined and self.default_factory is not None: raise ValueError('cannot specify both default and default_factory') @@ -370,9 +370,10 @@ def _get_field_info( field_info = next(iter(field_infos), None) if field_info is not None: field_info.update_from_config(field_info_from_config) - if field_info.default not in (Undefined, Ellipsis): + if field_info.default is not Undefined: raise ValueError(f'`Field` default cannot be set in `Annotated` for {field_name!r}') - if value not in (Undefined, Ellipsis): + if value is not Undefined and value is not Required: + # check also `Required` because of `validate_arguments` that sets `...` as default value field_info.default = value if isinstance(value, FieldInfo): @@ -450,7 +451,6 @@ def prepare(self) -> None: self._type_analysis() if self.required is Undefined: self.required = True - self.field_info.default = Required if self.default is Undefined and self.default_factory is None: self.default = None self.populate_validators() diff --git a/pydantic/generics.py b/pydantic/generics.py index ad224a477a..fc60a2b464 100644 --- a/pydantic/generics.py +++ b/pydantic/generics.py @@ -15,13 +15,14 @@ TypeVar, Union, cast, - get_type_hints, ) +from typing_extensions import Annotated + from .class_validators import gather_all_validators from .fields import DeferredType from .main import BaseModel, create_model -from .typing import display_as_type, get_args, get_origin, typing_base +from .typing import display_as_type, get_all_type_hints, get_args, get_origin, typing_base from .utils import all_identical, lenient_issubclass _generic_types_cache: Dict[Tuple[Type[Any], Union[Any, Tuple[Any, ...]]], Type[BaseModel]] = {} @@ -73,7 +74,7 @@ def __class_getitem__(cls: Type[GenericModelT], params: Union[Type[Any], Tuple[T model_name = cls.__concrete_name__(params) validators = gather_all_validators(cls) - type_hints = get_type_hints(cls).items() + type_hints = get_all_type_hints(cls).items() instance_type_hints = {k: v for k, v in type_hints if get_origin(v) is not ClassVar} fields = {k: (DeferredType(), cls.__fields__[k].field_info) for k in instance_type_hints if k in cls.__fields__} @@ -159,6 +160,10 @@ def replace_types(type_: Any, type_map: Mapping[Any, Any]) -> Any: type_args = get_args(type_) origin_type = get_origin(type_) + if origin_type is Annotated: + annotated_type, *annotations = type_args + return Annotated[replace_types(annotated_type, type_map), tuple(annotations)] + # Having type args is a good indicator that this is a typing module # class instantiation or a generic alias of some sort. if type_args: diff --git a/pydantic/typing.py b/pydantic/typing.py index bef7b12e10..d0e747e84b 100644 --- a/pydantic/typing.py +++ b/pydantic/typing.py @@ -18,6 +18,7 @@ Union, _eval_type, cast, + get_type_hints, ) from typing_extensions import Annotated, Literal @@ -70,6 +71,18 @@ def evaluate_forwardref(type_: ForwardRef, globalns: Any, localns: Any) -> Any: return cast(Any, type_)._evaluate(globalns, localns, set()) +if sys.version_info < (3, 9): + # Ensure we always get all the whole `Annotated` hint, not just the annotated type. + # For 3.6 to 3.8, `get_type_hints` doesn't recognize `typing_extensions.Annotated`, + # so it already returns the full annotation + get_all_type_hints = get_type_hints + +else: + + def get_all_type_hints(obj: Any, globalns: Any = None, localns: Any = None) -> Any: + return get_type_hints(obj, globalns, localns, include_extras=True) + + if sys.version_info < (3, 7): from typing import Callable as Callable @@ -225,6 +238,7 @@ def get_args(tp: Type[Any]) -> Tuple[Any, ...]: 'get_args', 'get_origin', 'typing_base', + 'get_all_type_hints', ) diff --git a/tests/test_annotated.py b/tests/test_annotated.py index 6fe90bb852..b5c27a8c52 100644 --- a/tests/test_annotated.py +++ b/tests/test_annotated.py @@ -1,11 +1,9 @@ -import sys -from typing import get_type_hints - import pytest from typing_extensions import Annotated from pydantic import BaseModel, Field from pydantic.fields import Undefined +from pydantic.typing import get_all_type_hints @pytest.mark.parametrize( @@ -43,15 +41,7 @@ class M(BaseModel): assert M().x == 5 assert M(x=10).x == 10 - - # get_type_hints doesn't recognize typing_extensions.Annotated, so will return the full - # annotation. 3.9 w/ stock Annotated will return the wrapped type by default, but return the - # full thing with the new include_extras flag. - if sys.version_info >= (3, 9): - assert get_type_hints(M)['x'] is int - assert get_type_hints(M, include_extras=True)['x'] == hint - else: - assert get_type_hints(M)['x'] == hint + assert get_all_type_hints(M)['x'] == hint @pytest.mark.parametrize( diff --git a/tests/test_decorator.py b/tests/test_decorator.py index bf8ed83deb..6b11fb285c 100644 --- a/tests/test_decorator.py +++ b/tests/test_decorator.py @@ -3,14 +3,13 @@ import sys from pathlib import Path from typing import List -from unittest.mock import ANY import pytest +from typing_extensions import Annotated from pydantic import BaseModel, Field, ValidationError, validate_arguments from pydantic.decorator import ValidatedFunction from pydantic.errors import ConfigError -from pydantic.typing import Annotated skip_pre_38 = pytest.mark.skipif(sys.version_info < (3, 8), reason='testing >= 3.8 behaviour only') @@ -154,13 +153,14 @@ def foo(a: int, b: int = Field(default_factory=lambda: 99), *args: int) -> int: assert foo(1, 2, 3) == 6 -@pytest.mark.skipif(not Annotated, reason='typing_extensions not installed') def test_annotated_field_can_provide_factory() -> None: @validate_arguments - def foo2(a: int, b: Annotated[int, Field(default_factory=lambda: 99)] = ANY, *args: int) -> int: + def foo2(a: int, b: Annotated[int, Field(default_factory=lambda: 99)], *args: int) -> int: """mypy reports Incompatible default for argument "b" if we don't supply ANY as default""" return a + b + sum(args) + assert foo2(1) == 100 + @skip_pre_38 def test_positional_only(create_module): diff --git a/tests/test_edge_cases.py b/tests/test_edge_cases.py index 2ce66334e8..6e77f26f75 100644 --- a/tests/test_edge_cases.py +++ b/tests/test_edge_cases.py @@ -1839,3 +1839,19 @@ class Config: with pytest.raises(TypeError): b.a = 'y' assert b.dict() == {'a': 'x'} + + +def test_arbitrary_types_allowed_custom_eq(): + class Foo: + def __eq__(self, other): + if other.__class__ is not Foo: + raise TypeError(f'Cannot interpret {other.__class__.__name__!r} as a valid type') + return True + + class Model(BaseModel): + x: Foo = Foo() + + class Config: + arbitrary_types_allowed = True + + assert Model().x == Foo() diff --git a/tests/test_generics.py b/tests/test_generics.py index d1e42d8666..0b0b5cd749 100644 --- a/tests/test_generics.py +++ b/tests/test_generics.py @@ -3,7 +3,7 @@ from typing import Any, Callable, ClassVar, Dict, Generic, List, Optional, Sequence, Tuple, Type, TypeVar, Union import pytest -from typing_extensions import Literal +from typing_extensions import Annotated, Literal from pydantic import BaseModel, Field, ValidationError, root_validator, validator from pydantic.generics import GenericModel, _generic_types_cache, iter_contained_typevars, replace_types @@ -1071,3 +1071,13 @@ class GModel(GenericModel, Generic[FieldType, ValueType]): Fields = Literal['foo', 'bar'] m = GModel[Fields, str](field={'foo': 'x'}) assert m.dict() == {'field': {'foo': 'x'}} + + +@skip_36 +def test_generic_annotated(): + T = TypeVar('T') + + class SomeGenericModel(GenericModel, Generic[T]): + some_field: Annotated[T, Field(alias='the_alias')] + + SomeGenericModel[str](the_alias='qwe')