Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Properly retain types of Mapping subclasses #2325

Merged
merged 15 commits into from Feb 25, 2021
1 change: 1 addition & 0 deletions changes/2325-ofek.md
@@ -0,0 +1 @@
Prevent `Mapping` subclasses from always being coerced to `dict`
50 changes: 44 additions & 6 deletions pydantic/fields.py
@@ -1,9 +1,10 @@
import warnings
from collections import deque
from collections import defaultdict, deque
from collections.abc import Iterable as CollectionsIterable
from typing import (
TYPE_CHECKING,
Any,
DefaultDict,
Deque,
Dict,
FrozenSet,
Expand Down Expand Up @@ -211,6 +212,8 @@ def Schema(default: Any, **kwargs: Any) -> Any:
SHAPE_ITERABLE = 9
SHAPE_GENERIC = 10
SHAPE_DEQUE = 11
SHAPE_DICT = 12
SHAPE_DEFAULTDICT = 13
SHAPE_NAME_LOOKUP = {
SHAPE_LIST: 'List[{}]',
SHAPE_SET: 'Set[{}]',
Expand All @@ -219,8 +222,12 @@ def Schema(default: Any, **kwargs: Any) -> Any:
SHAPE_FROZENSET: 'FrozenSet[{}]',
SHAPE_ITERABLE: 'Iterable[{}]',
SHAPE_DEQUE: 'Deque[{}]',
SHAPE_DICT: 'Dict[{}]',
SHAPE_DEFAULTDICT: 'DefaultDict[{}]',
}

MAPPING_LIKE_SHAPES: Set[int] = {SHAPE_DEFAULTDICT, SHAPE_DICT, SHAPE_MAPPING}


class ModelField(Representation):
__slots__ = (
Expand Down Expand Up @@ -492,6 +499,14 @@ def _type_analysis(self) -> None: # noqa: C901 (ignore complexity)
elif issubclass(origin, Sequence):
self.type_ = get_args(self.type_)[0]
self.shape = SHAPE_SEQUENCE
elif issubclass(origin, DefaultDict):
self.key_field = self._create_sub_type(get_args(self.type_)[0], 'key_' + self.name, for_keys=True)
self.type_ = get_args(self.type_)[1]
self.shape = SHAPE_DEFAULTDICT
elif issubclass(origin, Dict):
self.key_field = self._create_sub_type(get_args(self.type_)[0], 'key_' + self.name, for_keys=True)
self.type_ = get_args(self.type_)[1]
self.shape = SHAPE_DICT
elif issubclass(origin, Mapping):
self.key_field = self._create_sub_type(get_args(self.type_)[0], 'key_' + self.name, for_keys=True)
self.type_ = get_args(self.type_)[1]
Expand Down Expand Up @@ -608,8 +623,8 @@ def validate(

if self.shape == SHAPE_SINGLETON:
v, errors = self._validate_singleton(v, values, loc, cls)
elif self.shape == SHAPE_MAPPING:
v, errors = self._validate_mapping(v, values, loc, cls)
elif self.shape in MAPPING_LIKE_SHAPES:
v, errors = self._validate_mapping_like(v, values, loc, cls)
elif self.shape == SHAPE_TUPLE:
v, errors = self._validate_tuple(v, values, loc, cls)
elif self.shape == SHAPE_ITERABLE:
Expand Down Expand Up @@ -726,7 +741,7 @@ def _validate_tuple(
else:
return tuple(result), None

def _validate_mapping(
def _validate_mapping_like(
self, v: Any, values: Dict[str, Any], loc: 'LocStr', cls: Optional['ModelOrDc']
) -> 'ValidateReturn':
try:
Expand All @@ -752,8 +767,31 @@ def _validate_mapping(
result[key_result] = value_result
if errors:
return v, errors
else:
elif self.shape == SHAPE_DICT:
return result, None
elif self.shape == SHAPE_DEFAULTDICT:
return defaultdict(self.type_, result), None
else:
return self._get_mapping_value(v, result), None

def _get_mapping_value(self, original: T, converted: Dict[Any, Any]) -> Union[T, Dict[Any, Any]]:
"""
When type is `Mapping[KT, KV]` (or another unsupported mapping), we try to avoid
coercing to `dict` unwillingly.
"""
original_cls = original.__class__

if original_cls == dict or original_cls == Dict:
return converted
elif original_cls in {defaultdict, DefaultDict}:
return defaultdict(self.type_, converted)
else:
try:
# Counter, OrderedDict, UserDict, ...
return original_cls(converted) # type: ignore
except TypeError:
warnings.warn(f'Could not convert dictionary to {original_cls.__name__!r}', UserWarning)
ofek marked this conversation as resolved.
Show resolved Hide resolved
return converted

def _validate_singleton(
self, v: Any, values: Dict[str, Any], loc: 'LocStr', cls: Optional['ModelOrDc']
Expand Down Expand Up @@ -796,7 +834,7 @@ def _type_display(self) -> PyObjectStr:
t = display_as_type(self.type_)

# have to do this since display_as_type(self.outer_type_) is different (and wrong) on python 3.6
if self.shape == SHAPE_MAPPING:
if self.shape in MAPPING_LIKE_SHAPES:
t = f'Mapping[{display_as_type(self.key_field.type_)}, {t}]' # type: ignore
elif self.shape == SHAPE_TUPLE:
t = 'Tuple[{}]'.format(', '.join(display_as_type(f.type_) for f in self.sub_fields)) # type: ignore
Expand Down
5 changes: 3 additions & 2 deletions pydantic/main.py
Expand Up @@ -29,7 +29,7 @@
from .class_validators import ValidatorGroup, extract_root_validators, extract_validators, inherit_validators
from .error_wrappers import ErrorWrapper, ValidationError
from .errors import ConfigError, DictError, ExtraError, MissingError
from .fields import SHAPE_MAPPING, ModelField, ModelPrivateAttr, PrivateAttr, Undefined
from .fields import MAPPING_LIKE_SHAPES, ModelField, ModelPrivateAttr, PrivateAttr, Undefined
from .json import custom_pydantic_encoder, pydantic_encoder
from .parse import Protocol, load_file, load_str_bytes
from .schema import default_ref_template, model_schema
Expand Down Expand Up @@ -524,7 +524,8 @@ def json(
@classmethod
def parse_obj(cls: Type['Model'], obj: Any) -> 'Model':
if cls.__custom_root_type__ and (
not (isinstance(obj, dict) and obj.keys() == {ROOT_KEY}) or cls.__fields__[ROOT_KEY].shape == SHAPE_MAPPING
not (isinstance(obj, dict) and obj.keys() == {ROOT_KEY})
or cls.__fields__[ROOT_KEY].shape in MAPPING_LIKE_SHAPES
):
obj = {ROOT_KEY: obj}
elif not isinstance(obj, dict):
Expand Down
4 changes: 2 additions & 2 deletions pydantic/schema.py
Expand Up @@ -26,10 +26,10 @@
from uuid import UUID

from .fields import (
MAPPING_LIKE_SHAPES,
SHAPE_FROZENSET,
SHAPE_ITERABLE,
SHAPE_LIST,
SHAPE_MAPPING,
SHAPE_SEQUENCE,
SHAPE_SET,
SHAPE_SINGLETON,
Expand Down Expand Up @@ -446,7 +446,7 @@ def field_type_schema(
if field.shape in {SHAPE_SET, SHAPE_FROZENSET}:
f_schema['uniqueItems'] = True

elif field.shape == SHAPE_MAPPING:
elif field.shape in MAPPING_LIKE_SHAPES:
f_schema = {'type': 'object'}
key_field = cast(ModelField, field.key_field)
regex = getattr(key_field.type_, 'regex', None)
Expand Down
71 changes: 70 additions & 1 deletion tests/test_main.py
@@ -1,6 +1,7 @@
import sys
from collections import defaultdict
from enum import Enum
from typing import Any, Callable, ClassVar, Dict, List, Mapping, Optional, Type, get_type_hints
from typing import Any, Callable, ClassVar, DefaultDict, Dict, List, Mapping, Optional, Type, get_type_hints
from uuid import UUID, uuid4

import pytest
Expand Down Expand Up @@ -1472,3 +1473,71 @@ class Item(BaseModel):

assert id(image_1) == id(item.images[0])
assert id(image_2) == id(item.images[1])


def test_mapping_retains_type_subclass():
class Map(dict):
ofek marked this conversation as resolved.
Show resolved Hide resolved
pass

class Model(BaseModel):
field: Mapping[str, Mapping[str, int]]
ofek marked this conversation as resolved.
Show resolved Hide resolved

m = Model(field=Map(outer=Map(inner=42)))
assert isinstance(m.field, Map)
assert isinstance(m.field['outer'], Map)
assert m.field['outer']['inner'] == 42


def test_mapping_retains_type_defaultdict():
class Model(BaseModel):
field: Mapping[str, int]

d = defaultdict(int)
d[1] = '2'
d['3']

m = Model(field=d)
assert isinstance(m.field, defaultdict)
assert m.field['1'] == 2
assert m.field['3'] == 0


def test_mapping_retains_type_dict_fallback():
class Map(dict):
def __init__(self, *args, **kwargs):
if args or kwargs:
raise TypeError('test')
super().__init__(*args, **kwargs)

class Model(BaseModel):
field: Mapping[str, int]

d = Map()
d['one'] = 1
d['two'] = 2

with pytest.warns(UserWarning, match="Could not convert dictionary to 'Map'"):
m = Model(field=d)

assert isinstance(m.field, dict)
assert m.field['one'] == 1
assert m.field['two'] == 2


def test_typing_coercion_dict():
class Model(BaseModel):
x: Dict[str, int]

m = Model(x={'one': 1, 'two': 2})
assert repr(m) == "Model(x={'one': 1, 'two': 2})"


def test_typing_coercion_defaultdict():
class Model(BaseModel):
x: DefaultDict[int, str]

d = defaultdict(str)
d['1']
m = Model(x=d)
m.x['a']
assert repr(m) == "Model(x=defaultdict(<class 'str'>, {1: '', 'a': ''}))"