Skip to content

Commit

Permalink
Properly retain types of Mapping subclasses (#2325)
Browse files Browse the repository at this point in the history
* Properly retain types of Mapping subclasses

* Create 2325-ofek.md

* update with feedback

Co-Authored-By: Eric Jolibois <eric.jolibois@toucantoco.com>

* satisfy mypy?

* Update fields.py

Co-Authored-By: Eric Jolibois <eric.jolibois@toucantoco.com>

* show uncovered line numbers

* fix coverage

* update

* address feedback

* try

* update

Co-Authored-By: Eric Jolibois <eric.jolibois@toucantoco.com>

* rename test

* address feedback

Co-authored-by: Eric Jolibois <eric.jolibois@toucantoco.com>
Co-authored-by: Samuel Colvin <s@muelcolvin.com>
  • Loading branch information
3 people committed Feb 25, 2021
1 parent aa92db5 commit 4ddf4f1
Show file tree
Hide file tree
Showing 5 changed files with 115 additions and 11 deletions.
1 change: 1 addition & 0 deletions changes/2325-ofek.md
@@ -0,0 +1 @@
Prevent `Mapping` subclasses from always being coerced to `dict`
49 changes: 43 additions & 6 deletions pydantic/fields.py
@@ -1,9 +1,10 @@
import warnings
from collections import deque
from collections import defaultdict, deque
from collections.abc import Iterable as CollectionsIterable
from typing import (
TYPE_CHECKING,
Any,
DefaultDict,
Deque,
Dict,
FrozenSet,
Expand Down Expand Up @@ -249,6 +250,8 @@ def Schema(default: Any, **kwargs: Any) -> Any:
SHAPE_ITERABLE = 9
SHAPE_GENERIC = 10
SHAPE_DEQUE = 11
SHAPE_DICT = 12
SHAPE_DEFAULTDICT = 13
SHAPE_NAME_LOOKUP = {
SHAPE_LIST: 'List[{}]',
SHAPE_SET: 'Set[{}]',
Expand All @@ -257,8 +260,12 @@ def Schema(default: Any, **kwargs: Any) -> Any:
SHAPE_FROZENSET: 'FrozenSet[{}]',
SHAPE_ITERABLE: 'Iterable[{}]',
SHAPE_DEQUE: 'Deque[{}]',
SHAPE_DICT: 'Dict[{}]',
SHAPE_DEFAULTDICT: 'DefaultDict[{}]',
}

MAPPING_LIKE_SHAPES: Set[int] = {SHAPE_DEFAULTDICT, SHAPE_DICT, SHAPE_MAPPING}


class ModelField(Representation):
__slots__ = (
Expand Down Expand Up @@ -572,6 +579,14 @@ def _type_analysis(self) -> None: # noqa: C901 (ignore complexity)
elif issubclass(origin, Sequence):
self.type_ = get_args(self.type_)[0]
self.shape = SHAPE_SEQUENCE
elif issubclass(origin, DefaultDict):
self.key_field = self._create_sub_type(get_args(self.type_)[0], 'key_' + self.name, for_keys=True)
self.type_ = get_args(self.type_)[1]
self.shape = SHAPE_DEFAULTDICT
elif issubclass(origin, Dict):
self.key_field = self._create_sub_type(get_args(self.type_)[0], 'key_' + self.name, for_keys=True)
self.type_ = get_args(self.type_)[1]
self.shape = SHAPE_DICT
elif issubclass(origin, Mapping):
self.key_field = self._create_sub_type(get_args(self.type_)[0], 'key_' + self.name, for_keys=True)
self.type_ = get_args(self.type_)[1]
Expand Down Expand Up @@ -688,8 +703,8 @@ def validate(

if self.shape == SHAPE_SINGLETON:
v, errors = self._validate_singleton(v, values, loc, cls)
elif self.shape == SHAPE_MAPPING:
v, errors = self._validate_mapping(v, values, loc, cls)
elif self.shape in MAPPING_LIKE_SHAPES:
v, errors = self._validate_mapping_like(v, values, loc, cls)
elif self.shape == SHAPE_TUPLE:
v, errors = self._validate_tuple(v, values, loc, cls)
elif self.shape == SHAPE_ITERABLE:
Expand Down Expand Up @@ -806,7 +821,7 @@ def _validate_tuple(
else:
return tuple(result), None

def _validate_mapping(
def _validate_mapping_like(
self, v: Any, values: Dict[str, Any], loc: 'LocStr', cls: Optional['ModelOrDc']
) -> 'ValidateReturn':
try:
Expand All @@ -832,8 +847,30 @@ def _validate_mapping(
result[key_result] = value_result
if errors:
return v, errors
else:
elif self.shape == SHAPE_DICT:
return result, None
elif self.shape == SHAPE_DEFAULTDICT:
return defaultdict(self.type_, result), None
else:
return self._get_mapping_value(v, result), None

def _get_mapping_value(self, original: T, converted: Dict[Any, Any]) -> Union[T, Dict[Any, Any]]:
"""
When type is `Mapping[KT, KV]` (or another unsupported mapping), we try to avoid
coercing to `dict` unwillingly.
"""
original_cls = original.__class__

if original_cls == dict or original_cls == Dict:
return converted
elif original_cls in {defaultdict, DefaultDict}:
return defaultdict(self.type_, converted)
else:
try:
# Counter, OrderedDict, UserDict, ...
return original_cls(converted) # type: ignore
except TypeError:
raise RuntimeError(f'Could not convert dictionary to {original_cls.__name__!r}') from None

def _validate_singleton(
self, v: Any, values: Dict[str, Any], loc: 'LocStr', cls: Optional['ModelOrDc']
Expand Down Expand Up @@ -876,7 +913,7 @@ def _type_display(self) -> PyObjectStr:
t = display_as_type(self.type_)

# have to do this since display_as_type(self.outer_type_) is different (and wrong) on python 3.6
if self.shape == SHAPE_MAPPING:
if self.shape in MAPPING_LIKE_SHAPES:
t = f'Mapping[{display_as_type(self.key_field.type_)}, {t}]' # type: ignore
elif self.shape == SHAPE_TUPLE:
t = 'Tuple[{}]'.format(', '.join(display_as_type(f.type_) for f in self.sub_fields)) # type: ignore
Expand Down
5 changes: 3 additions & 2 deletions pydantic/main.py
Expand Up @@ -28,7 +28,7 @@
from .class_validators import ValidatorGroup, extract_root_validators, extract_validators, inherit_validators
from .error_wrappers import ErrorWrapper, ValidationError
from .errors import ConfigError, DictError, ExtraError, MissingError
from .fields import SHAPE_MAPPING, ModelField, ModelPrivateAttr, PrivateAttr, Undefined
from .fields import MAPPING_LIKE_SHAPES, ModelField, ModelPrivateAttr, PrivateAttr, Undefined
from .json import custom_pydantic_encoder, pydantic_encoder
from .parse import Protocol, load_file, load_str_bytes
from .schema import default_ref_template, model_schema
Expand Down Expand Up @@ -559,7 +559,8 @@ def json(
@classmethod
def _enforce_dict_if_root(cls, obj: Any) -> Any:
if cls.__custom_root_type__ and (
not (isinstance(obj, dict) and obj.keys() == {ROOT_KEY}) or cls.__fields__[ROOT_KEY].shape == SHAPE_MAPPING
not (isinstance(obj, dict) and obj.keys() == {ROOT_KEY})
or cls.__fields__[ROOT_KEY].shape in MAPPING_LIKE_SHAPES
):
return {ROOT_KEY: obj}
else:
Expand Down
4 changes: 2 additions & 2 deletions pydantic/schema.py
Expand Up @@ -29,11 +29,11 @@
from typing_extensions import Annotated, Literal

from .fields import (
MAPPING_LIKE_SHAPES,
SHAPE_FROZENSET,
SHAPE_GENERIC,
SHAPE_ITERABLE,
SHAPE_LIST,
SHAPE_MAPPING,
SHAPE_SEQUENCE,
SHAPE_SET,
SHAPE_SINGLETON,
Expand Down Expand Up @@ -450,7 +450,7 @@ def field_type_schema(
if field.shape in {SHAPE_SET, SHAPE_FROZENSET}:
f_schema['uniqueItems'] = True

elif field.shape == SHAPE_MAPPING:
elif field.shape in MAPPING_LIKE_SHAPES:
f_schema = {'type': 'object'}
key_field = cast(ModelField, field.key_field)
regex = getattr(key_field.type_, 'regex', None)
Expand Down
67 changes: 66 additions & 1 deletion tests/test_main.py
@@ -1,6 +1,7 @@
import sys
from collections import defaultdict
from enum import Enum
from typing import Any, Callable, ClassVar, Dict, List, Mapping, Optional, Type, get_type_hints
from typing import Any, Callable, ClassVar, DefaultDict, Dict, List, Mapping, Optional, Type, get_type_hints
from uuid import UUID, uuid4

import pytest
Expand Down Expand Up @@ -1611,6 +1612,70 @@ class Item(BaseModel):
assert id(image_2) == id(item.images[1])


def test_mapping_retains_type_subclass():
class CustomMap(dict):
pass

class Model(BaseModel):
x: Mapping[str, Mapping[str, int]]

m = Model(x=CustomMap(outer=CustomMap(inner=42)))
assert isinstance(m.x, CustomMap)
assert isinstance(m.x['outer'], CustomMap)
assert m.x['outer']['inner'] == 42


def test_mapping_retains_type_defaultdict():
class Model(BaseModel):
x: Mapping[str, int]

d = defaultdict(int)
d[1] = '2'
d['3']

m = Model(x=d)
assert isinstance(m.x, defaultdict)
assert m.x['1'] == 2
assert m.x['3'] == 0


def test_mapping_retains_type_fallback_error():
class CustomMap(dict):
def __init__(self, *args, **kwargs):
if args or kwargs:
raise TypeError('test')
super().__init__(*args, **kwargs)

class Model(BaseModel):
x: Mapping[str, int]

d = CustomMap()
d['one'] = 1
d['two'] = 2

with pytest.raises(RuntimeError, match="Could not convert dictionary to 'CustomMap'"):
Model(x=d)


def test_typing_coercion_dict():
class Model(BaseModel):
x: Dict[str, int]

m = Model(x={'one': 1, 'two': 2})
assert repr(m) == "Model(x={'one': 1, 'two': 2})"


def test_typing_coercion_defaultdict():
class Model(BaseModel):
x: DefaultDict[int, str]

d = defaultdict(str)
d['1']
m = Model(x=d)
m.x['a']
assert repr(m) == "Model(x=defaultdict(<class 'str'>, {1: '', 'a': ''}))"


def test_class_kwargs_config():
class Base(BaseModel, extra='forbid', alias_generator=str.upper):
a: int
Expand Down

0 comments on commit 4ddf4f1

Please sign in to comment.