Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Properly retain types of Mapping subclasses #2325

Merged
merged 15 commits into from Feb 25, 2021
1 change: 1 addition & 0 deletions changes/2325-ofek.md
@@ -0,0 +1 @@
Prevent `Mapping` subclasses from always being coerced to `dict`
26 changes: 25 additions & 1 deletion pydantic/fields.py
@@ -1,5 +1,5 @@
import warnings
from collections import deque
from collections import defaultdict, deque
from collections.abc import Iterable as CollectionsIterable
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -736,6 +736,13 @@ def _validate_mapping(
result[key_result] = value_result
if errors:
return v, errors
elif isinstance(v, Mapping):
same_mapping_type_res = _get_same_mapping_type_res(v, result)
if same_mapping_type_res is not None:
return same_mapping_type_res, None
else:
warnings.warn(f'Could not keep {v.__class__.__name__} when validating. Fallback done on dict...')
return result, None
else:
return result, None

Expand Down Expand Up @@ -850,3 +857,20 @@ def PrivateAttr(
default,
default_factory=default_factory,
)


def _get_same_mapping_type_res(mapping: T, converted: Dict[Any, Any]) -> Optional[T]:
"""
Try to return the same object as `mapping` but with `converted` values
"""
mapping_type = type(mapping)
if mapping_type is dict:
return converted # type: ignore
elif mapping_type is defaultdict:
return defaultdict(mapping.default_factory, **converted) # type: ignore
else:
try:
# Counter, OrderedDict, UserDict, ...
return mapping_type(**converted) # type: ignore
except TypeError:
return None
40 changes: 40 additions & 0 deletions tests/test_main.py
@@ -1,4 +1,5 @@
import sys
from collections import ChainMap, defaultdict
from enum import Enum
from typing import Any, Callable, ClassVar, Dict, List, Mapping, Optional, Type, get_type_hints
from uuid import UUID, uuid4
Expand Down Expand Up @@ -1425,3 +1426,42 @@ class M(BaseModel):
a: int

get_type_hints(M.__config__)


def test_mapping_retains_type_subclass():
class Map(dict):
ofek marked this conversation as resolved.
Show resolved Hide resolved
pass

class Model(BaseModel):
field: Mapping[str, Mapping[str, int]]
ofek marked this conversation as resolved.
Show resolved Hide resolved

m = Model(field=Map(outer=Map(inner=42)))
assert isinstance(m.field, Map)
assert isinstance(m.field['outer'], Map)
assert m.field['outer']['inner'] == 42


def test_mapping_retains_type_defaultdict():
class Model(BaseModel):
field: Mapping[str, int]

d = defaultdict(int)
d[1] = '2'
d['3']

m = Model(field=d)
assert isinstance(m.field, defaultdict)
assert m.field['1'] == 2
assert m.field['3'] == 0


def test_mapping_retains_type_dict_fallback():
class Model(BaseModel):
field: Mapping[str, int]

with pytest.warns(UserWarning, match='Could not keep ChainMap when validating. Fallback done on dict...'):
m = Model(field=ChainMap({'one': 1}, {'two': 2}))

assert isinstance(m.field, dict)
assert m.field['one'] == 1
assert m.field['two'] == 2