Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Properly retain types of Mapping subclasses #2325

Merged
merged 15 commits into from Feb 25, 2021
1 change: 1 addition & 0 deletions changes/2325-ofek.md
@@ -0,0 +1 @@
Prevent `Mapping` subclasses from always being coerced to `dict`
23 changes: 21 additions & 2 deletions pydantic/fields.py
@@ -1,5 +1,5 @@
import warnings
from collections import deque
from collections import defaultdict, deque
from collections.abc import Iterable as CollectionsIterable
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -37,6 +37,7 @@
get_args,
get_origin,
is_literal_type,
is_mapping_type,
is_new_type,
new_type_supertype,
)
Expand Down Expand Up @@ -737,7 +738,25 @@ def _validate_mapping(
if errors:
return v, errors
else:
return result, None
return self._get_mapping_value(v, result), None

def _get_mapping_value(self, original: T, converted: Dict[Any, Any]) -> Union[T, Dict[Any, Any]]:
target_type = get_origin(self.outer_type_)
original_type = type(original)

if is_mapping_type(target_type):
target_type = original_type

if target_type is dict:
ofek marked this conversation as resolved.
Show resolved Hide resolved
return converted
elif target_type is defaultdict:
ofek marked this conversation as resolved.
Show resolved Hide resolved
return defaultdict(getattr(original, 'default_factory', None), **converted)
ofek marked this conversation as resolved.
Show resolved Hide resolved
else:
try:
# Counter, OrderedDict, UserDict, ...
return target_type(**converted)
except TypeError:
return converted

def _validate_singleton(
self, v: Any, values: Dict[str, Any], loc: 'LocStr', cls: Optional['ModelOrDc']
Expand Down
8 changes: 8 additions & 0 deletions pydantic/typing.py
Expand Up @@ -186,6 +186,7 @@ def get_args(tp: Type[Any]) -> Tuple[Any, ...]:
'resolve_annotations',
'is_callable_type',
'is_literal_type',
'is_mapping_type',
'literal_values',
'Literal',
'is_new_type',
Expand Down Expand Up @@ -268,6 +269,10 @@ def is_callable_type(type_: Type[Any]) -> bool:


if sys.version_info >= (3, 7):
from collections.abc import Mapping as CollectionsMapping

def is_mapping_type(type_: Type[Any]) -> bool:
return type_ is CollectionsMapping

def is_literal_type(type_: Type[Any]) -> bool:
return Literal is not None and get_origin(type_) is Literal
Expand All @@ -278,6 +283,9 @@ def literal_values(type_: Type[Any]) -> Tuple[Any, ...]:

else:

def is_mapping_type(type_: Type[Any]) -> bool:
return type_ is Mapping

def is_literal_type(type_: Type[Any]) -> bool:
return Literal is not None and hasattr(type_, '__values__') and type_ == Literal[type_.__values__]

Expand Down
38 changes: 38 additions & 0 deletions tests/test_main.py
@@ -1,4 +1,5 @@
import sys
from collections import ChainMap, defaultdict
from enum import Enum
from typing import Any, Callable, ClassVar, Dict, List, Mapping, Optional, Type, get_type_hints
from uuid import UUID, uuid4
Expand Down Expand Up @@ -1425,3 +1426,40 @@ class M(BaseModel):
a: int

get_type_hints(M.__config__)


def test_mapping_retains_type_subclass():
class Map(dict):
ofek marked this conversation as resolved.
Show resolved Hide resolved
pass

class Model(BaseModel):
field: Mapping[str, Mapping[str, int]]
ofek marked this conversation as resolved.
Show resolved Hide resolved

m = Model(field=Map(outer=Map(inner=42)))
assert isinstance(m.field, Map)
assert isinstance(m.field['outer'], Map)
assert m.field['outer']['inner'] == 42


def test_mapping_retains_type_defaultdict():
class Model(BaseModel):
field: Mapping[str, int]

d = defaultdict(int)
d[1] = '2'
d['3']

m = Model(field=d)
assert isinstance(m.field, defaultdict)
assert m.field['1'] == 2
assert m.field['3'] == 0


def test_mapping_retains_type_dict_fallback():
class Model(BaseModel):
field: Mapping[str, int]

m = Model(field=ChainMap({'one': 1}, {'two': 2}))
assert isinstance(m.field, dict)
assert m.field['one'] == 1
assert m.field['two'] == 2