Skip to content

Commit

Permalink
Fix nested all include exclude
Browse files Browse the repository at this point in the history
  • Loading branch information
xspirus committed Jun 2, 2020
1 parent f89e372 commit 7ee0a64
Show file tree
Hide file tree
Showing 4 changed files with 183 additions and 13 deletions.
2 changes: 2 additions & 0 deletions changes/1579-xspirus.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Fix behavior of `__all__` key when used in conjunction with index keys in advanced include/exclude of fields that are
sequences.
60 changes: 48 additions & 12 deletions pydantic/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
Generator,
Iterator,
List,
Mapping,
Optional,
Set,
Tuple,
Expand Down Expand Up @@ -221,6 +222,28 @@ def to_camel(string: str) -> str:
return ''.join(word.capitalize() for word in string.split('_'))


def update_normalized_all(
item: Union['AbstractSetIntStr', 'MappingIntStrAny'], all_items: Union['AbstractSetIntStr', 'MappingIntStrAny'],
) -> Union['AbstractSetIntStr', 'MappingIntStrAny']:
if not item:
return all_items
if isinstance(item, dict) and isinstance(all_items, dict):
item.update({k: update_normalized_all(item[k], v) for k, v in all_items.items() if k in item})
item.update({k: v for k, v in all_items.items() if k not in item})
return item
if isinstance(item, set) and isinstance(all_items, set):
item.update(all_items)
return item
if isinstance(item, dict) and isinstance(all_items, set):
item.update({k: ... for k in all_items if k not in item})
return item
if isinstance(item, set) and isinstance(all_items, dict):
new_item = {k: ... for k in item}
new_item.update({k: v for k, v in all_items.items() if k not in item})
return new_item
return all_items # pragma: no cover


class PyObjectStr(str):
"""
String class where repr doesn't include quotes. Useful with Representation when you want to return a string
Expand Down Expand Up @@ -358,21 +381,15 @@ def __init__(self, value: Any, items: Union['AbstractSetIntStr', 'MappingIntStrA
self._type: Type[Union[set, dict]] # type: ignore

# For further type checks speed-up
if isinstance(items, dict):
if isinstance(items, Mapping):
self._type = dict
elif isinstance(items, AbstractSet):
self._type = set
else:
raise TypeError(f'Unexpected type of exclude value {items.__class__}')

if isinstance(value, (list, tuple)):
try:
items = self._normalize_indexes(items, len(value))
except TypeError as e:
raise TypeError(
'Excluding fields from a sequence of sub-models or dicts must be performed index-wise: '
'expected integer keys or keyword "__all__"'
) from e
items = self._normalize_indexes(items, len(value))

self._items = items

Expand Down Expand Up @@ -423,19 +440,38 @@ def _normalize_indexes(
>>> self._normalize_indexes({'__all__'}, 4)
{0, 1, 2, 3}
"""
if any(not isinstance(i, int) and i != '__all__' for i in items):
raise TypeError(
'Excluding fields from a sequence of sub-models or dicts must be performed index-wise: '
'expected integer keys or keyword "__all__"'
)
if self._type is set:
if '__all__' in items:
if items != {'__all__'}:
raise ValueError('set with keyword "__all__" must not contain other elements')
return {i for i in range(v_length)}
return {v_length + i if i < 0 else i for i in items}
else:
all_items = items.get('__all__')
for i, v in items.items():
if not (isinstance(v, Mapping) or isinstance(v, AbstractSet) or v is ...):
raise TypeError(f'Unexpected type of exclude value for index "{i}" {v.__class__}')
normalized_items = {v_length + i if i < 0 else i: v for i, v in items.items() if i != '__all__'}
all_set = items.get('__all__')
if all_set:
if all_items:
default: Type[Union[set, dict]] # type: ignore
if isinstance(all_items, Mapping):
default = dict
elif isinstance(all_items, AbstractSet):
default = set
else:
default = ...
for i in range(v_length):
normalized_items.setdefault(i, set()).update(all_set)

if default is ...:
normalized_items.setdefault(i, ...)
continue
normalized_item = normalized_items.setdefault(i, default())
if normalized_item is not ...:
normalized_items[i] = update_normalized_all(normalized_item, all_items)
return normalized_items

def __repr_args__(self) -> 'ReprArgs':
Expand Down
126 changes: 126 additions & 0 deletions tests/test_edge_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,132 @@ class Model(BaseModel):
}


def test_advanced_exclude_nested_lists():
class SubSubModel(BaseModel):
i: int
j: int

class SubModel(BaseModel):
k: int
subsubs: List[SubSubModel]

class Model(BaseModel):
subs: List[SubModel]

m = Model(
subs=[
SubModel(k=1, subsubs=[SubSubModel(i=1, j=1), SubSubModel(i=2, j=2)]),
SubModel(k=2, subsubs=[SubSubModel(i=3, j=3)]),
]
)

# Normal nested __all__
assert m.dict(exclude={'subs': {'__all__': {'subsubs': {'__all__': {'i'}}}}}) == {
'subs': [{'k': 1, 'subsubs': [{'j': 1}, {'j': 2}]}, {'k': 2, 'subsubs': [{'j': 3}]}]
}
# Merge sub dicts
assert m.dict(
exclude={'subs': {'__all__': {'subsubs': {'__all__': {'i'}}}, 0: {'subsubs': {'__all__': {'j'}}}}}
) == {'subs': [{'k': 1, 'subsubs': [{}, {}]}, {'k': 2, 'subsubs': [{'j': 3}]}]}
# Merge sub sets
assert m.dict(exclude={'subs': {'__all__': {'subsubs': {0}}, 0: {'subsubs': {1}}}}) == {
'subs': [{'k': 1, 'subsubs': []}, {'k': 2, 'subsubs': []}]
}
# Merge sub dict-set
assert m.dict(exclude={'subs': {'__all__': {'subsubs': {0: {'i'}}}, 0: {'subsubs': {1}}}}) == {
'subs': [{'k': 1, 'subsubs': [{'j': 1}]}, {'k': 2, 'subsubs': [{'j': 3}]}]
}
# Different keys
assert m.dict(exclude={'subs': {'__all__': {'subsubs'}, 0: {'k'}}}) == {'subs': [{}, {'k': 2}]}
assert m.dict(exclude={'subs': {'__all__': {'subsubs': ...}, 0: {'k'}}}) == {'subs': [{}, {'k': 2}]}
assert m.dict(exclude={'subs': {'__all__': {'subsubs'}, 0: {'k': ...}}}) == {'subs': [{}, {'k': 2}]}
# Nested different keys
assert m.dict(exclude={'subs': {'__all__': {'subsubs': {'__all__': {'i'}, 0: {'j'}}}}}) == {
'subs': [{'k': 1, 'subsubs': [{}, {'j': 2}]}, {'k': 2, 'subsubs': [{}]}]
}
assert m.dict(exclude={'subs': {'__all__': {'subsubs': {'__all__': {'i': ...}, 0: {'j'}}}}}) == {
'subs': [{'k': 1, 'subsubs': [{}, {'j': 2}]}, {'k': 2, 'subsubs': [{}]}]
}
assert m.dict(exclude={'subs': {'__all__': {'subsubs': {'__all__': {'i'}, 0: {'j': ...}}}}}) == {
'subs': [{'k': 1, 'subsubs': [{}, {'j': 2}]}, {'k': 2, 'subsubs': [{}]}]
}
# Ignore __all__ for index with defined exclude.
assert m.dict(exclude={'subs': {'__all__': {'subsubs'}, 0: {'subsubs': {'__all__': {'j'}}}}}) == {
'subs': [{'k': 1, 'subsubs': [{'i': 1}, {'i': 2}]}, {'k': 2}]
}
assert m.dict(exclude={'subs': {'__all__': {'subsubs': {'__all__': {'j'}}}, 0: ...}}) == {
'subs': [{'k': 2, 'subsubs': [{'i': 3}]}]
}
assert m.dict(exclude={'subs': {'__all__': ..., 0: {'subsubs'}}}) == {'subs': [{'k': 1}]}


def test_advanced_include_nested_lists():
class SubSubModel(BaseModel):
i: int
j: int

class SubModel(BaseModel):
k: int
subsubs: List[SubSubModel]

class Model(BaseModel):
subs: List[SubModel]

m = Model(
subs=[
SubModel(k=1, subsubs=[SubSubModel(i=1, j=1), SubSubModel(i=2, j=2)]),
SubModel(k=2, subsubs=[SubSubModel(i=3, j=3)]),
]
)

# Normal nested __all__
assert m.dict(include={'subs': {'__all__': {'subsubs': {'__all__': {'i'}}}}}) == {
'subs': [{'subsubs': [{'i': 1}, {'i': 2}]}, {'subsubs': [{'i': 3}]}]
}
# Merge sub dicts
assert m.dict(
include={'subs': {'__all__': {'subsubs': {'__all__': {'i'}}}, 0: {'subsubs': {'__all__': {'j'}}}}}
) == {'subs': [{'subsubs': [{'i': 1, 'j': 1}, {'i': 2, 'j': 2}]}, {'subsubs': [{'i': 3}]}]}
# Merge sub sets
assert m.dict(include={'subs': {'__all__': {'subsubs': {0}}, 0: {'subsubs': {1}}}}) == {
'subs': [{'subsubs': [{'i': 1, 'j': 1}, {'i': 2, 'j': 2}]}, {'subsubs': [{'i': 3, 'j': 3}]}]
}
# Merge sub dict-set
assert m.dict(include={'subs': {'__all__': {'subsubs': {0: {'i'}}}, 0: {'subsubs': {1}}}}) == {
'subs': [{'subsubs': [{'i': 1}, {'i': 2, 'j': 2}]}, {'subsubs': [{'i': 3}]}]
}
# Different keys
assert m.dict(include={'subs': {'__all__': {'subsubs'}, 0: {'k'}}}) == {
'subs': [{'k': 1, 'subsubs': [{'i': 1, 'j': 1}, {'i': 2, 'j': 2}]}, {'subsubs': [{'i': 3, 'j': 3}]}]
}
assert m.dict(include={'subs': {'__all__': {'subsubs': ...}, 0: {'k'}}}) == {
'subs': [{'k': 1, 'subsubs': [{'i': 1, 'j': 1}, {'i': 2, 'j': 2}]}, {'subsubs': [{'i': 3, 'j': 3}]}]
}
assert m.dict(include={'subs': {'__all__': {'subsubs'}, 0: {'k': ...}}}) == {
'subs': [{'k': 1, 'subsubs': [{'i': 1, 'j': 1}, {'i': 2, 'j': 2}]}, {'subsubs': [{'i': 3, 'j': 3}]}]
}
# Nested different keys
assert m.dict(include={'subs': {'__all__': {'subsubs': {'__all__': {'i'}, 0: {'j'}}}}}) == {
'subs': [{'subsubs': [{'i': 1, 'j': 1}, {'i': 2}]}, {'subsubs': [{'i': 3, 'j': 3}]}]
}
assert m.dict(include={'subs': {'__all__': {'subsubs': {'__all__': {'i': ...}, 0: {'j'}}}}}) == {
'subs': [{'subsubs': [{'i': 1, 'j': 1}, {'i': 2}]}, {'subsubs': [{'i': 3, 'j': 3}]}]
}
assert m.dict(include={'subs': {'__all__': {'subsubs': {'__all__': {'i'}, 0: {'j': ...}}}}}) == {
'subs': [{'subsubs': [{'i': 1, 'j': 1}, {'i': 2}]}, {'subsubs': [{'i': 3, 'j': 3}]}]
}
# Ignore __all__ for index with defined include.
assert m.dict(include={'subs': {'__all__': {'subsubs'}, 0: {'subsubs': {'__all__': {'j'}}}}}) == {
'subs': [{'subsubs': [{'j': 1}, {'j': 2}]}, {'subsubs': [{'i': 3, 'j': 3}]}]
}
assert m.dict(include={'subs': {'__all__': {'subsubs': {'__all__': {'j'}}}, 0: ...}}) == {
'subs': [{'k': 1, 'subsubs': [{'i': 1, 'j': 1}, {'i': 2, 'j': 2}]}, {'subsubs': [{'j': 3}]}]
}
assert m.dict(include={'subs': {'__all__': ..., 0: {'subsubs'}}}) == {
'subs': [{'subsubs': [{'i': 1, 'j': 1}, {'i': 2, 'j': 2}]}, {'k': 2, 'subsubs': [{'i': 3, 'j': 3}]}]
}


def test_field_set_ignore_extra():
class Model(BaseModel):
a: int
Expand Down
8 changes: 7 additions & 1 deletion tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -1004,13 +1004,19 @@ class Bar(BaseModel):

with pytest.raises(TypeError, match='expected integer keys'):
m.dict(exclude={'foos': {'a'}})
with pytest.raises(TypeError, match='expected integer keys'):
m.dict(exclude={'foos': {0: ..., 'a': ...}})
with pytest.raises(TypeError, match='Unexpected type'):
m.dict(exclude={'foos': {0: 1}})
with pytest.raises(TypeError, match='Unexpected type'):
m.dict(exclude={'foos': {'__all__': 1}})

assert m.dict(exclude={'foos': {0: {'b'}, '__all__': {'a'}}}) == {'c': 3, 'foos': [{}, {'b': 4}]}
assert m.dict(exclude={'foos': {'__all__': {'a'}}}) == {'c': 3, 'foos': [{'b': 2}, {'b': 4}]}
assert m.dict(exclude={'foos': {'__all__'}}) == {'c': 3, 'foos': []}

with pytest.raises(ValueError, match='set with keyword "__all__" must not contain other elements'):
m.dict(exclude={'foos': {'a', '__all__'}})
m.dict(exclude={'foos': {1, '__all__'}})


def test_model_export_dict_exclusion():
Expand Down

0 comments on commit 7ee0a64

Please sign in to comment.