From 4ddf4f14cd8f75e62e77f01f9bab42ba6377031d Mon Sep 17 00:00:00 2001 From: Ofek Lev Date: Thu, 25 Feb 2021 12:27:50 -0500 Subject: [PATCH] Properly retain types of Mapping subclasses (#2325) * Properly retain types of Mapping subclasses * Create 2325-ofek.md * update with feedback Co-Authored-By: Eric Jolibois * satisfy mypy? * Update fields.py Co-Authored-By: Eric Jolibois * show uncovered line numbers * fix coverage * update * address feedback * try * update Co-Authored-By: Eric Jolibois * rename test * address feedback Co-authored-by: Eric Jolibois Co-authored-by: Samuel Colvin --- changes/2325-ofek.md | 1 + pydantic/fields.py | 49 ++++++++++++++++++++++++++++---- pydantic/main.py | 5 ++-- pydantic/schema.py | 4 +-- tests/test_main.py | 67 +++++++++++++++++++++++++++++++++++++++++++- 5 files changed, 115 insertions(+), 11 deletions(-) create mode 100644 changes/2325-ofek.md diff --git a/changes/2325-ofek.md b/changes/2325-ofek.md new file mode 100644 index 0000000000..8a9cfe1af3 --- /dev/null +++ b/changes/2325-ofek.md @@ -0,0 +1 @@ +Prevent `Mapping` subclasses from always being coerced to `dict` diff --git a/pydantic/fields.py b/pydantic/fields.py index a59c598f47..79c49a9366 100644 --- a/pydantic/fields.py +++ b/pydantic/fields.py @@ -1,9 +1,10 @@ import warnings -from collections import deque +from collections import defaultdict, deque from collections.abc import Iterable as CollectionsIterable from typing import ( TYPE_CHECKING, Any, + DefaultDict, Deque, Dict, FrozenSet, @@ -249,6 +250,8 @@ def Schema(default: Any, **kwargs: Any) -> Any: SHAPE_ITERABLE = 9 SHAPE_GENERIC = 10 SHAPE_DEQUE = 11 +SHAPE_DICT = 12 +SHAPE_DEFAULTDICT = 13 SHAPE_NAME_LOOKUP = { SHAPE_LIST: 'List[{}]', SHAPE_SET: 'Set[{}]', @@ -257,8 +260,12 @@ def Schema(default: Any, **kwargs: Any) -> Any: SHAPE_FROZENSET: 'FrozenSet[{}]', SHAPE_ITERABLE: 'Iterable[{}]', SHAPE_DEQUE: 'Deque[{}]', + SHAPE_DICT: 'Dict[{}]', + SHAPE_DEFAULTDICT: 'DefaultDict[{}]', } +MAPPING_LIKE_SHAPES: Set[int] 
= {SHAPE_DEFAULTDICT, SHAPE_DICT, SHAPE_MAPPING} + class ModelField(Representation): __slots__ = ( @@ -572,6 +579,14 @@ def _type_analysis(self) -> None: # noqa: C901 (ignore complexity) elif issubclass(origin, Sequence): self.type_ = get_args(self.type_)[0] self.shape = SHAPE_SEQUENCE + elif issubclass(origin, DefaultDict): + self.key_field = self._create_sub_type(get_args(self.type_)[0], 'key_' + self.name, for_keys=True) + self.type_ = get_args(self.type_)[1] + self.shape = SHAPE_DEFAULTDICT + elif issubclass(origin, Dict): + self.key_field = self._create_sub_type(get_args(self.type_)[0], 'key_' + self.name, for_keys=True) + self.type_ = get_args(self.type_)[1] + self.shape = SHAPE_DICT elif issubclass(origin, Mapping): self.key_field = self._create_sub_type(get_args(self.type_)[0], 'key_' + self.name, for_keys=True) self.type_ = get_args(self.type_)[1] @@ -688,8 +703,8 @@ def validate( if self.shape == SHAPE_SINGLETON: v, errors = self._validate_singleton(v, values, loc, cls) - elif self.shape == SHAPE_MAPPING: - v, errors = self._validate_mapping(v, values, loc, cls) + elif self.shape in MAPPING_LIKE_SHAPES: + v, errors = self._validate_mapping_like(v, values, loc, cls) elif self.shape == SHAPE_TUPLE: v, errors = self._validate_tuple(v, values, loc, cls) elif self.shape == SHAPE_ITERABLE: @@ -806,7 +821,7 @@ def _validate_tuple( else: return tuple(result), None - def _validate_mapping( + def _validate_mapping_like( self, v: Any, values: Dict[str, Any], loc: 'LocStr', cls: Optional['ModelOrDc'] ) -> 'ValidateReturn': try: @@ -832,8 +847,30 @@ def _validate_mapping( result[key_result] = value_result if errors: return v, errors - else: + elif self.shape == SHAPE_DICT: return result, None + elif self.shape == SHAPE_DEFAULTDICT: + return defaultdict(self.type_, result), None + else: + return self._get_mapping_value(v, result), None + + def _get_mapping_value(self, original: T, converted: Dict[Any, Any]) -> Union[T, Dict[Any, Any]]: + """ + When type is `Mapping[KT, 
KV]` (or another unsupported mapping), we try to avoid + coercing to `dict` unwillingly. + """ + original_cls = original.__class__ + + if original_cls == dict or original_cls == Dict: + return converted + elif original_cls in {defaultdict, DefaultDict}: + return defaultdict(self.type_, converted) + else: + try: + # Counter, OrderedDict, UserDict, ... + return original_cls(converted) # type: ignore + except TypeError: + raise RuntimeError(f'Could not convert dictionary to {original_cls.__name__!r}') from None def _validate_singleton( self, v: Any, values: Dict[str, Any], loc: 'LocStr', cls: Optional['ModelOrDc'] @@ -876,7 +913,7 @@ def _type_display(self) -> PyObjectStr: t = display_as_type(self.type_) # have to do this since display_as_type(self.outer_type_) is different (and wrong) on python 3.6 - if self.shape == SHAPE_MAPPING: + if self.shape in MAPPING_LIKE_SHAPES: t = f'Mapping[{display_as_type(self.key_field.type_)}, {t}]' # type: ignore elif self.shape == SHAPE_TUPLE: t = 'Tuple[{}]'.format(', '.join(display_as_type(f.type_) for f in self.sub_fields)) # type: ignore diff --git a/pydantic/main.py b/pydantic/main.py index e0155d39db..f1cb71f2d9 100644 --- a/pydantic/main.py +++ b/pydantic/main.py @@ -28,7 +28,7 @@ from .class_validators import ValidatorGroup, extract_root_validators, extract_validators, inherit_validators from .error_wrappers import ErrorWrapper, ValidationError from .errors import ConfigError, DictError, ExtraError, MissingError -from .fields import SHAPE_MAPPING, ModelField, ModelPrivateAttr, PrivateAttr, Undefined +from .fields import MAPPING_LIKE_SHAPES, ModelField, ModelPrivateAttr, PrivateAttr, Undefined from .json import custom_pydantic_encoder, pydantic_encoder from .parse import Protocol, load_file, load_str_bytes from .schema import default_ref_template, model_schema @@ -559,7 +559,8 @@ def json( @classmethod def _enforce_dict_if_root(cls, obj: Any) -> Any: if cls.__custom_root_type__ and ( - not (isinstance(obj, dict) and obj.keys() 
== {ROOT_KEY}) or cls.__fields__[ROOT_KEY].shape == SHAPE_MAPPING + not (isinstance(obj, dict) and obj.keys() == {ROOT_KEY}) + or cls.__fields__[ROOT_KEY].shape in MAPPING_LIKE_SHAPES ): return {ROOT_KEY: obj} else: diff --git a/pydantic/schema.py b/pydantic/schema.py index 582674755e..b38488e118 100644 --- a/pydantic/schema.py +++ b/pydantic/schema.py @@ -29,11 +29,11 @@ from typing_extensions import Annotated, Literal from .fields import ( + MAPPING_LIKE_SHAPES, SHAPE_FROZENSET, SHAPE_GENERIC, SHAPE_ITERABLE, SHAPE_LIST, - SHAPE_MAPPING, SHAPE_SEQUENCE, SHAPE_SET, SHAPE_SINGLETON, @@ -450,7 +450,7 @@ def field_type_schema( if field.shape in {SHAPE_SET, SHAPE_FROZENSET}: f_schema['uniqueItems'] = True - elif field.shape == SHAPE_MAPPING: + elif field.shape in MAPPING_LIKE_SHAPES: f_schema = {'type': 'object'} key_field = cast(ModelField, field.key_field) regex = getattr(key_field.type_, 'regex', None) diff --git a/tests/test_main.py b/tests/test_main.py index 46fc16ee4b..08c926c40b 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -1,6 +1,7 @@ import sys +from collections import defaultdict from enum import Enum -from typing import Any, Callable, ClassVar, Dict, List, Mapping, Optional, Type, get_type_hints +from typing import Any, Callable, ClassVar, DefaultDict, Dict, List, Mapping, Optional, Type, get_type_hints from uuid import UUID, uuid4 import pytest @@ -1611,6 +1612,70 @@ class Item(BaseModel): assert id(image_2) == id(item.images[1]) +def test_mapping_retains_type_subclass(): + class CustomMap(dict): + pass + + class Model(BaseModel): + x: Mapping[str, Mapping[str, int]] + + m = Model(x=CustomMap(outer=CustomMap(inner=42))) + assert isinstance(m.x, CustomMap) + assert isinstance(m.x['outer'], CustomMap) + assert m.x['outer']['inner'] == 42 + + +def test_mapping_retains_type_defaultdict(): + class Model(BaseModel): + x: Mapping[str, int] + + d = defaultdict(int) + d[1] = '2' + d['3'] + + m = Model(x=d) + assert isinstance(m.x, defaultdict) + 
assert m.x['1'] == 2 + assert m.x['3'] == 0 + + +def test_mapping_retains_type_fallback_error(): + class CustomMap(dict): + def __init__(self, *args, **kwargs): + if args or kwargs: + raise TypeError('test') + super().__init__(*args, **kwargs) + + class Model(BaseModel): + x: Mapping[str, int] + + d = CustomMap() + d['one'] = 1 + d['two'] = 2 + + with pytest.raises(RuntimeError, match="Could not convert dictionary to 'CustomMap'"): + Model(x=d) + + +def test_typing_coercion_dict(): + class Model(BaseModel): + x: Dict[str, int] + + m = Model(x={'one': 1, 'two': 2}) + assert repr(m) == "Model(x={'one': 1, 'two': 2})" + + +def test_typing_coercion_defaultdict(): + class Model(BaseModel): + x: DefaultDict[int, str] + + d = defaultdict(str) + d['1'] + m = Model(x=d) + m.x['a'] + assert repr(m) == "Model(x=defaultdict(<class 'str'>, {1: '', 'a': ''}))" + + def test_class_kwargs_config(): class Base(BaseModel, extra='forbid', alias_generator=str.upper): a: int