diff --git a/changes/1579-xspirus.md b/changes/1579-xspirus.md new file mode 100644 index 0000000000..5691505e37 --- /dev/null +++ b/changes/1579-xspirus.md @@ -0,0 +1 @@ +Fix behavior of `__all__` key when used in conjunction with index keys in advanced include/exclude of fields that are sequences. diff --git a/pydantic/utils.py b/pydantic/utils.py index e640ded34e..3978228a90 100644 --- a/pydantic/utils.py +++ b/pydantic/utils.py @@ -10,6 +10,7 @@ Generator, Iterator, List, + Mapping, Optional, Set, Tuple, @@ -238,6 +239,48 @@ def unique_list(input_list: Union[List[T], Tuple[T, ...]]) -> List[T]: return result +def update_normalized_all( + item: Union['AbstractSetIntStr', 'MappingIntStrAny'], all_items: Union['AbstractSetIntStr', 'MappingIntStrAny'], +) -> Union['AbstractSetIntStr', 'MappingIntStrAny']: + """ + Update item based on what all items contains. + + The update is done based on these cases: + + - if both arguments are dicts then each key-value pair existing in ``all_items`` is merged into ``item``, + while the rest of the key-value pairs are updated recursively with this function. + - if both arguments are sets then they are just merged. + - if ``item`` is a dictionary and ``all_items`` is a set then all values of it are added to ``item`` as + ``key: ...``. + - if ``item`` is set and ``all_items`` is a dictionary, then ``item`` is converted to a dictionary and then the + key-value pairs of ``all_items`` are merged in it. + + During recursive calls, there is a case where ``all_items`` can be an Ellipsis, in which case the ``item`` is + returned as is. + """ + if not item: + return all_items + if isinstance(item, dict) and isinstance(all_items, dict): + item = dict(item) + item.update({k: update_normalized_all(item[k], v) for k, v in all_items.items() if k in item}) + item.update({k: v for k, v in all_items.items() if k not in item}) + return item + if isinstance(item, set) and isinstance(all_items, set): + item = set(item) + item.update(all_items) + return item + if isinstance(item, dict) and isinstance(all_items, set): + item = dict(item) + item.update({k: ... for k in all_items if k not in item}) + return item + if isinstance(item, set) and isinstance(all_items, dict): + item = {k: ... for k in item} + item.update({k: v for k, v in all_items.items() if k not in item}) + return item + # Case when item or all_items is ... (in recursive calls). + return item + + class PyObjectStr(str): """ String class where repr doesn't include quotes. Useful with Representation when you want to return a string @@ -375,7 +418,7 @@ def __init__(self, value: Any, items: Union['AbstractSetIntStr', 'MappingIntStrA self._type: Type[Union[set, dict]] # type: ignore # For further type checks speed-up - if isinstance(items, dict): + if isinstance(items, Mapping): self._type = dict elif isinstance(items, AbstractSet): self._type = set @@ -383,13 +426,7 @@ def __init__(self, value: Any, items: Union['AbstractSetIntStr', 'MappingIntStrA raise TypeError(f'Unexpected type of exclude value {items.__class__}') if isinstance(value, (list, tuple)): - try: - items = self._normalize_indexes(items, len(value)) - except TypeError as e: - raise TypeError( - 'Excluding fields from a sequence of sub-models or dicts must be performed index-wise: ' - 'expected integer keys or keyword "__all__"' - ) from e + items = self._normalize_indexes(items, len(value)) self._items = items @@ -440,6 +477,11 @@ def _normalize_indexes( >>> self._normalize_indexes({'__all__'}, 4) {0, 1, 2, 3} """ + if any(not isinstance(i, int) and i != '__all__' for i in items): + raise TypeError( + 'Excluding fields from a sequence of sub-models or dicts must be performed index-wise: ' + 'expected integer keys or keyword "__all__"' + ) if self._type is set: if '__all__' in items: if items != {'__all__'}: @@ -447,12 +489,25 @@ def _normalize_indexes( return {i for i in range(v_length)} return {v_length + i if i < 0 else i for i in items} else: + all_items = items.get('__all__') + for i, v in items.items(): + if not (isinstance(v, Mapping) or isinstance(v, AbstractSet) or v is ...): + raise TypeError(f'Unexpected type of exclude value for index "{i}" {v.__class__}') normalized_items = {v_length + i if i < 0 else i: v for i, v in items.items() if i != '__all__'} - all_set = items.get('__all__') - if all_set: + if all_items: + default: Type[Union[Set[Any], Dict[Any, Any]]] + if isinstance(all_items, Mapping): + default = dict + elif isinstance(all_items, AbstractSet): + default = set + else: + for i in range(v_length): + normalized_items.setdefault(i, ...) + return normalized_items for i in range(v_length): - normalized_items.setdefault(i, set()).update(all_set) - + normalized_item = normalized_items.setdefault(i, default()) + if normalized_item is not ...: + normalized_items[i] = update_normalized_all(normalized_item, all_items) return normalized_items def __repr_args__(self) -> 'ReprArgs': diff --git a/tests/test_edge_cases.py b/tests/test_edge_cases.py index 13e9e57d47..36acb17feb 100644 --- a/tests/test_edge_cases.py +++ b/tests/test_edge_cases.py @@ -509,6 +509,169 @@ class Model(BaseModel): } +@pytest.mark.parametrize( + 'exclude,expected', + [ + # Normal nested __all__ + ( + {'subs': {'__all__': {'subsubs': {'__all__': {'i'}}}}}, + {'subs': [{'k': 1, 'subsubs': [{'j': 1}, {'j': 2}]}, {'k': 2, 'subsubs': [{'j': 3}]}]}, + ), + # Merge sub dicts + ( + {'subs': {'__all__': {'subsubs': {'__all__': {'i'}}}, 0: {'subsubs': {'__all__': {'j'}}}}}, + {'subs': [{'k': 1, 'subsubs': [{}, {}]}, {'k': 2, 'subsubs': [{'j': 3}]}]}, + ), + ( + {'subs': {'__all__': {'subsubs': ...}, 0: {'subsubs': {'__all__': {'j'}}}}}, + {'subs': [{'k': 1, 'subsubs': [{'i': 1}, {'i': 2}]}, {'k': 2}]}, + ), + ( + {'subs': {'__all__': {'subsubs': {'__all__': {'j'}}}, 0: {'subsubs': ...}}}, + {'subs': [{'k': 1}, {'k': 2, 'subsubs': [{'i': 3}]}]}, + ), + # Merge sub sets + ( + {'subs': {'__all__': {'subsubs': {0}}, 0: {'subsubs': {1}}}}, + {'subs': [{'k': 1, 'subsubs': []}, {'k': 2, 'subsubs': []}]}, + ), + # Merge sub dict-set + ( + {'subs': {'__all__': {'subsubs': {0: {'i'}}}, 0: {'subsubs': {1}}}}, + {'subs': [{'k': 1, 'subsubs': [{'j': 1}]}, {'k': 2, 'subsubs': [{'j': 3}]}]}, + ), + # Different keys + ({'subs': {'__all__': {'subsubs'}, 0: {'k'}}}, {'subs': [{}, {'k': 2}]}), + ({'subs': {'__all__': {'subsubs': ...}, 0: {'k'}}}, {'subs': [{}, {'k': 2}]}), + ({'subs': {'__all__': {'subsubs'}, 0: {'k': ...}}}, {'subs': [{}, {'k': 2}]}), + # Nested different keys + ( + {'subs': {'__all__': {'subsubs': {'__all__': {'i'}, 0: {'j'}}}}}, + {'subs': [{'k': 1, 'subsubs': [{}, {'j': 2}]}, {'k': 2, 'subsubs': [{}]}]}, + ), + ( + {'subs': {'__all__': {'subsubs': {'__all__': {'i': ...}, 0: {'j'}}}}}, + {'subs': [{'k': 1, 'subsubs': [{}, {'j': 2}]}, {'k': 2, 'subsubs': [{}]}]}, + ), + ( + {'subs': {'__all__': {'subsubs': {'__all__': {'i'}, 0: {'j': ...}}}}}, + {'subs': [{'k': 1, 'subsubs': [{}, {'j': 2}]}, {'k': 2, 'subsubs': [{}]}]}, + ), + # Ignore __all__ for index with defined exclude + ( + {'subs': {'__all__': {'subsubs'}, 0: {'subsubs': {'__all__': {'j'}}}}}, + {'subs': [{'k': 1, 'subsubs': [{'i': 1}, {'i': 2}]}, {'k': 2}]}, + ), + ({'subs': {'__all__': {'subsubs': {'__all__': {'j'}}}, 0: ...}}, {'subs': [{'k': 2, 'subsubs': [{'i': 3}]}]}), + ({'subs': {'__all__': ..., 0: {'subsubs'}}}, {'subs': [{'k': 1}]}), + ], +) +def test_advanced_exclude_nested_lists(exclude, expected): + class SubSubModel(BaseModel): + i: int + j: int + + class SubModel(BaseModel): + k: int + subsubs: List[SubSubModel] + + class Model(BaseModel): + subs: List[SubModel] + + m = Model(subs=[dict(k=1, subsubs=[dict(i=1, j=1), dict(i=2, j=2)]), dict(k=2, subsubs=[dict(i=3, j=3)])]) + + assert m.dict(exclude=exclude) == expected + + +@pytest.mark.parametrize( + 'include,expected', + [ + # Normal nested __all__ + ( + {'subs': {'__all__': {'subsubs': {'__all__': {'i'}}}}}, + {'subs': [{'subsubs': [{'i': 1}, {'i': 2}]}, {'subsubs': [{'i': 3}]}]}, + ), + # Merge sub dicts + ( + {'subs': {'__all__': {'subsubs': {'__all__': {'i'}}}, 0: {'subsubs': {'__all__': {'j'}}}}}, + {'subs': [{'subsubs': [{'i': 1, 'j': 1}, {'i': 2, 'j': 2}]}, {'subsubs': [{'i': 3}]}]}, + ), + ( + {'subs': {'__all__': {'subsubs': ...}, 0: {'subsubs': {'__all__': {'j'}}}}}, + {'subs': [{'subsubs': [{'j': 1}, {'j': 2}]}, {'subsubs': [{'i': 3, 'j': 3}]}]}, + ), + ( + {'subs': {'__all__': {'subsubs': {'__all__': {'j'}}}, 0: {'subsubs': ...}}}, + {'subs': [{'subsubs': [{'i': 1, 'j': 1}, {'i': 2, 'j': 2}]}, {'subsubs': [{'j': 3}]}]}, + ), + # Merge sub sets + ( + {'subs': {'__all__': {'subsubs': {0}}, 0: {'subsubs': {1}}}}, + {'subs': [{'subsubs': [{'i': 1, 'j': 1}, {'i': 2, 'j': 2}]}, {'subsubs': [{'i': 3, 'j': 3}]}]}, + ), + # Merge sub dict-set + ( + {'subs': {'__all__': {'subsubs': {0: {'i'}}}, 0: {'subsubs': {1}}}}, + {'subs': [{'subsubs': [{'i': 1}, {'i': 2, 'j': 2}]}, {'subsubs': [{'i': 3}]}]}, + ), + # Different keys + ( + {'subs': {'__all__': {'subsubs'}, 0: {'k'}}}, + {'subs': [{'k': 1, 'subsubs': [{'i': 1, 'j': 1}, {'i': 2, 'j': 2}]}, {'subsubs': [{'i': 3, 'j': 3}]}]}, + ), + ( + {'subs': {'__all__': {'subsubs': ...}, 0: {'k'}}}, + {'subs': [{'k': 1, 'subsubs': [{'i': 1, 'j': 1}, {'i': 2, 'j': 2}]}, {'subsubs': [{'i': 3, 'j': 3}]}]}, + ), + ( + {'subs': {'__all__': {'subsubs'}, 0: {'k': ...}}}, + {'subs': [{'k': 1, 'subsubs': [{'i': 1, 'j': 1}, {'i': 2, 'j': 2}]}, {'subsubs': [{'i': 3, 'j': 3}]}]}, + ), + # Nested different keys + ( + {'subs': {'__all__': {'subsubs': {'__all__': {'i'}, 0: {'j'}}}}}, + {'subs': [{'subsubs': [{'i': 1, 'j': 1}, {'i': 2}]}, {'subsubs': [{'i': 3, 'j': 3}]}]}, + ), + ( + {'subs': {'__all__': {'subsubs': {'__all__': {'i': ...}, 0: {'j'}}}}}, + {'subs': [{'subsubs': [{'i': 1, 'j': 1}, {'i': 2}]}, {'subsubs': [{'i': 3, 'j': 3}]}]}, + ), + ( + {'subs': {'__all__': {'subsubs': {'__all__': {'i'}, 0: {'j': ...}}}}}, + {'subs': [{'subsubs': [{'i': 1, 'j': 1}, {'i': 2}]}, {'subsubs': [{'i': 3, 'j': 3}]}]}, + ), + # Ignore __all__ for index with defined include + ( + {'subs': {'__all__': {'subsubs'}, 0: {'subsubs': {'__all__': {'j'}}}}}, + {'subs': [{'subsubs': [{'j': 1}, {'j': 2}]}, {'subsubs': [{'i': 3, 'j': 3}]}]}, + ), + ( + {'subs': {'__all__': {'subsubs': {'__all__': {'j'}}}, 0: ...}}, + {'subs': [{'k': 1, 'subsubs': [{'i': 1, 'j': 1}, {'i': 2, 'j': 2}]}, {'subsubs': [{'j': 3}]}]}, + ), + ( + {'subs': {'__all__': ..., 0: {'subsubs'}}}, + {'subs': [{'subsubs': [{'i': 1, 'j': 1}, {'i': 2, 'j': 2}]}, {'k': 2, 'subsubs': [{'i': 3, 'j': 3}]}]}, + ), + ], +) +def test_advanced_include_nested_lists(include, expected): + class SubSubModel(BaseModel): + i: int + j: int + + class SubModel(BaseModel): + k: int + subsubs: List[SubSubModel] + + class Model(BaseModel): + subs: List[SubModel] + + m = Model(subs=[dict(k=1, subsubs=[dict(i=1, j=1), dict(i=2, j=2)]), dict(k=2, subsubs=[dict(i=3, j=3)])]) + + assert m.dict(include=include) == expected + + def test_field_set_ignore_extra(): class Model(BaseModel): a: int diff --git a/tests/test_main.py b/tests/test_main.py index e39d22df36..9912aa0271 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -1030,13 +1030,19 @@ class Bar(BaseModel): with pytest.raises(TypeError, match='expected integer keys'): m.dict(exclude={'foos': {'a'}}) + with pytest.raises(TypeError, match='expected integer keys'): + m.dict(exclude={'foos': {0: ..., 'a': ...}}) + with pytest.raises(TypeError, match='Unexpected type'): + m.dict(exclude={'foos': {0: 1}}) + with pytest.raises(TypeError, match='Unexpected type'): + m.dict(exclude={'foos': {'__all__': 1}}) assert m.dict(exclude={'foos': {0: {'b'}, '__all__': {'a'}}}) == {'c': 3, 'foos': [{}, {'b': 4}]} assert m.dict(exclude={'foos': {'__all__': {'a'}}}) == {'c': 3, 'foos': [{'b': 2}, {'b': 4}]} assert m.dict(exclude={'foos': {'__all__'}}) == {'c': 3, 'foos': []} with pytest.raises(ValueError, match='set with keyword "__all__" must not contain other elements'): - m.dict(exclude={'foos': {'a', '__all__'}}) + m.dict(exclude={'foos': {1, '__all__'}}) def test_model_export_dict_exclusion():