diff --git a/changes/660-daviskirk.md b/changes/660-daviskirk.md new file mode 100644 index 0000000000..8076b02621 --- /dev/null +++ b/changes/660-daviskirk.md @@ -0,0 +1 @@ +Add "exclude" as a field parameter so that it can be configured using model config instead of purely at `.dict` / `.json` export time. diff --git a/docs/examples/exporting_models_exclude1.py b/docs/examples/exporting_models_exclude1.py index 386c1cf468..ca14abefff 100644 --- a/docs/examples/exporting_models_exclude1.py +++ b/docs/examples/exporting_models_exclude1.py @@ -27,6 +27,6 @@ class Transaction(BaseModel): print(t.dict(exclude={'user', 'value'})) # using a dict: -print(t.dict(exclude={'user': {'username', 'password'}, 'value': ...})) +print(t.dict(exclude={'user': {'username', 'password'}, 'value': True})) -print(t.dict(include={'id': ..., 'user': {'id'}})) +print(t.dict(include={'id': True, 'user': {'id'}})) diff --git a/docs/examples/exporting_models_exclude2.py b/docs/examples/exporting_models_exclude2.py index 9c5ce06cb9..90d0447990 100644 --- a/docs/examples/exporting_models_exclude2.py +++ b/docs/examples/exporting_models_exclude2.py @@ -53,17 +53,17 @@ class User(BaseModel): ) exclude_keys = { - 'second_name': ..., - 'address': {'post_code': ..., 'country': {'phone_code'}}, - 'card_details': ..., + 'second_name': True, + 'address': {'post_code': True, 'country': {'phone_code'}}, + 'card_details': True, # You can exclude fields from specific members of a tuple/list by index: 'hobbies': {-1: {'info'}}, } include_keys = { - 'first_name': ..., + 'first_name': True, 'address': {'country': {'name'}}, - 'hobbies': {0: ..., -1: {'name'}}, + 'hobbies': {0: True, -1: {'name'}}, } # would be the same as user.dict(exclude=exclude_keys) in this case: diff --git a/docs/examples/exporting_models_exclude3.py b/docs/examples/exporting_models_exclude3.py new file mode 100644 index 0000000000..d43e75199b --- /dev/null +++ b/docs/examples/exporting_models_exclude3.py @@ -0,0 +1,29 @@ +from pydantic import BaseModel, Field, SecretStr + + +class User(BaseModel): + id: int + username: str + password: SecretStr = Field(..., exclude=True) + + +class Transaction(BaseModel): + id: str + user: User = Field(..., exclude={'username'}) + value: int + + class Config: + fields = {'value': {'exclude': True}} + + +t = Transaction( + id='1234567890', + user=User( + id=42, + username='JohnDoe', + password='hashedpassword' + ), + value=9876543210, +) + +print(t.dict()) diff --git a/docs/examples/exporting_models_exclude4.py b/docs/examples/exporting_models_exclude4.py new file mode 100644 index 0000000000..713ddf6944 --- /dev/null +++ b/docs/examples/exporting_models_exclude4.py @@ -0,0 +1,26 @@ +from pydantic import BaseModel, Field, SecretStr + + +class User(BaseModel): + id: int + username: str # overridden by explicit exclude + password: SecretStr = Field(exclude=True) + + +class Transaction(BaseModel): + id: str + user: User + value: int + + +t = Transaction( + id='1234567890', + user=User( + id=42, + username='JohnDoe', + password='hashedpassword' + ), + value=9876543210, +) + +print(t.dict(exclude={'value': True, 'user': {'username'}})) diff --git a/docs/examples/exporting_models_exclude5.py b/docs/examples/exporting_models_exclude5.py new file mode 100644 index 0000000000..c6262cb1b9 --- /dev/null +++ b/docs/examples/exporting_models_exclude5.py @@ -0,0 +1,26 @@ +from pydantic import BaseModel, Field, SecretStr + + +class User(BaseModel): + id: int = Field(..., include=True) + username: str = Field(..., include=True) # overridden by explicit include + password: SecretStr + + +class Transaction(BaseModel): + id: str + user: User + value: int + + +t = Transaction( + id='1234567890', + user=User( + id=42, + username='JohnDoe', + password='hashedpassword' + ), + value=9876543210, +) + +print(t.dict(include={'id': True, 'user': {'id'}})) diff --git a/docs/usage/exporting_models.md b/docs/usage/exporting_models.md index 15cbd206b3..7c21d520cd 100644 --- a/docs/usage/exporting_models.md +++ b/docs/usage/exporting_models.md @@ -162,7 +162,7 @@ sets or dictionaries. This allows nested selection of which fields to export: {!.tmp_examples/exporting_models_exclude1.py!} ``` -The ellipsis (``...``) indicates that we want to exclude or include an entire key, just as if we included it in a set. +The `True` indicates that we want to exclude or include an entire key, just as if we included it in a set. Of course, the same can be done at any depth level. Special care must be taken when including or excluding fields from a list or tuple of submodels or dictionaries. In this scenario, @@ -174,3 +174,30 @@ member of a list or tuple, the dictionary key `'__all__'` can be used as follows ``` The same holds for the `json` and `copy` methods. + +### Model and field level include and exclude + +In addition to the explicit arguments `exclude` and `include` passed to `dict`, `json` and `copy` methods, we can also pass the `include`/`exclude` arguments directly to the `Field` constructor or the equivalent `field` entry in the models `Config` class: + +```py +{!.tmp_examples/exporting_models_exclude3.py!} +``` + +In the case where multiple strategies are used, `exclude`/`include` fields are merged according to the following rules: + +* First, model config level settings (via `"fields"` entry) are merged per field with the field constructor settings (i.e. `Field(..., exclude=True)`), with the field constructor taking priority. +* The resulting settings are merged per class with the explicit settings on `dict`, `json`, `copy` calls with the explicit settings taking priority. + +Note that while merging settings, `exclude` entries are merged by computing the "union" of keys, while `include` entries are merged by computing the "intersection" of keys. + +The resulting merged exclude settings: + +```py +{!.tmp_examples/exporting_models_exclude4.py!} +``` + +are the same as using merged include settings as follows: + +```py +{!.tmp_examples/exporting_models_exclude5.py!} +``` diff --git a/docs/usage/schema.md b/docs/usage/schema.md index 83ffd6d109..1a4c5f06f1 100644 --- a/docs/usage/schema.md +++ b/docs/usage/schema.md @@ -52,6 +52,8 @@ It has the following arguments: * `title`: if omitted, `field_name.title()` is used * `description`: if omitted and the annotation is a sub-model, the docstring of the sub-model will be used +* `exclude`: exclude this field when dumping (`.dict` and `.json`) the instance. The exact syntax and configuration options are described in details in the [exporting models section](exporting_models.md#advanced-include-and-exclude). +* `include`: include (only) this field when dumping (`.dict` and `.json`) the instance. The exact syntax and configuration options are described in details in the [exporting models section](exporting_models.md#advanced-include-and-exclude). * `const`: this argument *must* be the same as the field's default value if present. * `gt`: for numeric values (``int``, `float`, `Decimal`), adds a validation of "greater than" and an annotation of `exclusiveMinimum` to the JSON Schema diff --git a/pydantic/fields.py b/pydantic/fields.py index 449dd55d4c..75652a84e7 100644 --- a/pydantic/fields.py +++ b/pydantic/fields.py @@ -43,7 +43,7 @@ is_typeddict, new_type_supertype, ) -from .utils import PyObjectStr, Representation, lenient_issubclass, sequence_like, smart_deepcopy +from .utils import PyObjectStr, Representation, ValueItems, lenient_issubclass, sequence_like, smart_deepcopy from .validators import constant_validator, dict_validator, find_validators, validate_json Required: Any = Ellipsis @@ -72,7 +72,7 @@ def __deepcopy__(self: T, _: Any) -> T: from .error_wrappers import ErrorList from .main import BaseConfig, BaseModel # noqa: F401 from .types import ModelOrDc # noqa: F401 - from .typing import ReprArgs # noqa: F401 + from .typing import AbstractSetIntStr, MappingIntStrAny, ReprArgs # noqa: F401 ValidateReturn = Tuple[Optional[Any], Optional[ErrorList]] LocStr = Union[Tuple[Union[int, str], ...], str] @@ -91,6 +91,8 @@ class FieldInfo(Representation): 'alias_priority', 'title', 'description', + 'exclude', + 'include', 'const', 'gt', 'ge', @@ -128,6 +130,8 @@ def __init__(self, default: Any = Undefined, **kwargs: Any) -> None: self.alias_priority = kwargs.pop('alias_priority', 2 if self.alias else None) self.title = kwargs.pop('title', None) self.description = kwargs.pop('description', None) + self.exclude = kwargs.pop('exclude', None) + self.include = kwargs.pop('include', None) self.const = kwargs.pop('const', None) self.gt = kwargs.pop('gt', None) self.ge = kwargs.pop('ge', None) @@ -167,6 +171,10 @@ def update_from_config(self, from_config: Dict[str, Any]) -> None: else: if current_value is self.__field_constraints__.get(attr_name, None): setattr(self, attr_name, value) + elif attr_name == 'exclude': + self.exclude = ValueItems.merge(value, current_value) + elif attr_name == 'include': + self.include = ValueItems.merge(value, current_value, intersect=True) def _validate(self) -> None: if self.default not in (Undefined, Ellipsis) and self.default_factory is not None: @@ -180,6 +188,8 @@ def Field( alias: str = None, title: str = None, description: str = None, + exclude: Union['AbstractSetIntStr', 'MappingIntStrAny', Any] = None, + include: Union['AbstractSetIntStr', 'MappingIntStrAny', Any] = None, const: bool = None, gt: float = None, ge: float = None, @@ -205,6 +215,10 @@ def Field( :param alias: the public name of the field :param title: can be any string, used in the schema :param description: can be any string, used in the schema + :param exclude: exclude this field while dumping. + Takes same values as the ``include`` and ``exclude`` arguments on the ``.dict`` method. + :param include: include this field while dumping. + Takes same values as the ``include`` and ``exclude`` arguments on the ``.dict`` method. :param const: this field is required and *must* take it's default value :param gt: only applies to numbers, requires the field to be "greater than". The schema will have an ``exclusiveMinimum`` validation keyword @@ -232,6 +246,8 @@ def Field( alias=alias, title=title, description=description, + exclude=exclude, + include=include, const=const, gt=gt, ge=ge, @@ -382,7 +398,6 @@ def _get_field_info( field_info.update_from_config(field_info_from_config) elif field_info is None: field_info = FieldInfo(value, **field_info_from_config) - value = None if field_info.default_factory is not None else field_info.default field_info._validate() return field_info, value @@ -407,6 +422,7 @@ def infer( elif value is not Undefined: required = False annotation = get_annotation_from_field_info(annotation, field_info, name, config.validate_assignment) + return cls( name=name, type_=annotation, @@ -429,6 +445,12 @@ def set_config(self, config: Type['BaseConfig']) -> None: self.field_info.alias = new_alias self.field_info.alias_priority = new_alias_priority self.alias = new_alias + new_exclude = info_from_config.get('exclude') + if new_exclude is not None: + self.field_info.exclude = ValueItems.merge(self.field_info.exclude, new_exclude) + new_include = info_from_config.get('include') + if new_include is not None: + self.field_info.include = ValueItems.merge(self.field_info.include, new_include, intersect=True) @property def alt_alias(self) -> bool: diff --git a/pydantic/main.py b/pydantic/main.py index f6aca41048..b0d58cae71 100644 --- a/pydantic/main.py +++ b/pydantic/main.py @@ -345,6 +345,14 @@ def is_untouched(v: Any) -> bool: new_namespace = { '__config__': config, '__fields__': fields, + '__exclude_fields__': { + name: field.field_info.exclude for name, field in fields.items() if field.field_info.exclude is not None + } + or None, + '__include_fields__': { + name: field.field_info.include for name, field in fields.items() if field.field_info.include is not None + } + or None, '__validators__': vg.validators, '__pre_root_validators__': unique_list(pre_root_validators + pre_rv_new), '__post_root_validators__': unique_list(post_root_validators + post_rv_new), @@ -371,6 +379,8 @@ class BaseModel(Representation, metaclass=ModelMetaclass): if TYPE_CHECKING: # populated by the metaclass, defined here to help IDEs only __fields__: Dict[str, ModelField] = {} + __include_fields__: Optional[Mapping[str, Any]] = None + __exclude_fields__: Optional[Mapping[str, Any]] = None __validators__: Dict[str, AnyCallable] = {} __pre_root_validators__: List[AnyCallable] __post_root_validators__: List[Tuple[bool, AnyCallable]] @@ -842,14 +852,24 @@ def _iter( exclude_none: bool = False, ) -> 'TupleGenerator': - allowed_keys = self._calculate_keys(include=include, exclude=exclude, exclude_unset=exclude_unset) + # Merge field set excludes with explicit exclude parameter with explicit overriding field set options. + # The extra "is not None" guards are not logically necessary but optimizes performance for the simple case. + if exclude is not None or self.__exclude_fields__ is not None: + exclude = ValueItems.merge(self.__exclude_fields__, exclude) + + if include is not None or self.__include_fields__ is not None: + include = ValueItems.merge(self.__include_fields__, include, intersect=True) + + allowed_keys = self._calculate_keys( + include=include, exclude=exclude, exclude_unset=exclude_unset # type: ignore + ) if allowed_keys is None and not (to_dict or by_alias or exclude_unset or exclude_defaults or exclude_none): # huge boost for plain _iter() yield from self.__dict__.items() return - value_exclude = ValueItems(self, exclude) if exclude else None - value_include = ValueItems(self, include) if include else None + value_exclude = ValueItems(self, exclude) if exclude is not None else None + value_include = ValueItems(self, include) if include is not None else None for field_key, v in self.__dict__.items(): if (allowed_keys is not None and field_key not in allowed_keys) or (exclude_none and v is None): @@ -880,8 +900,8 @@ def _iter( def _calculate_keys( self, - include: Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']], - exclude: Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']], + include: Optional['MappingIntStrAny'], + exclude: Optional['MappingIntStrAny'], exclude_unset: bool, update: Optional['DictStrAny'] = None, ) -> Optional[AbstractSet[str]]: @@ -895,19 +915,13 @@ def _calculate_keys( keys = self.__dict__.keys() if include is not None: - if isinstance(include, Mapping): - keys &= include.keys() - else: - keys &= include + keys &= include.keys() if update: keys -= update.keys() if exclude: - if isinstance(exclude, Mapping): - keys -= {k for k, v in exclude.items() if v is ...} - else: - keys -= exclude + keys -= {k for k, v in exclude.items() if ValueItems.is_true(v)} return keys diff --git a/pydantic/utils.py b/pydantic/utils.py index 8a8351c6e5..c2a6e514f0 100644 --- a/pydantic/utils.py +++ b/pydantic/utils.py @@ -21,7 +21,6 @@ Type, TypeVar, Union, - no_type_check, ) from .typing import GenericAlias, NoneType, display_as_type @@ -291,49 +290,6 @@ def unique_list(input_list: Union[List[T], Tuple[T, ...]]) -> List[T]: return result -def update_normalized_all( - item: Union['AbstractSetIntStr', 'MappingIntStrAny'], - all_items: Union['AbstractSetIntStr', 'MappingIntStrAny'], -) -> Union['AbstractSetIntStr', 'MappingIntStrAny']: - """ - Update item based on what all items contains. - - The update is done based on these cases: - - - if both arguments are dicts then each key-value pair existing in ``all_items`` is merged into ``item``, - while the rest of the key-value pairs are updated recursively with this function. - - if both arguments are sets then they are just merged. - - if ``item`` is a dictionary and ``all_items`` is a set then all values of it are added to ``item`` as - ``key: ...``. - - if ``item`` is set and ``all_items`` is a dictionary, then ``item`` is converted to a dictionary and then the - key-value pairs of ``all_items`` are merged in it. - - During recursive calls, there is a case where ``all_items`` can be an Ellipsis, in which case the ``item`` is - returned as is. - """ - if not item: - return all_items - if isinstance(item, dict) and isinstance(all_items, dict): - item = dict(item) - item.update({k: update_normalized_all(item[k], v) for k, v in all_items.items() if k in item}) - item.update({k: v for k, v in all_items.items() if k not in item}) - return item - if isinstance(item, set) and isinstance(all_items, set): - item = set(item) - item.update(all_items) - return item - if isinstance(item, dict) and isinstance(all_items, set): - item = dict(item) - item.update({k: ... for k in all_items if k not in item}) - return item - if isinstance(item, set) and isinstance(all_items, dict): - item = {k: ... for k in item} - item.update({k: v for k, v in all_items.items() if k not in item}) - return item - # Case when item or all_items is ... (in recursive calls). - return item - - class PyObjectStr(str): """ String class where repr doesn't include quotes. Useful with Representation when you want to return a string @@ -466,37 +422,21 @@ class ValueItems(Representation): __slots__ = ('_items', '_type') def __init__(self, value: Any, items: Union['AbstractSetIntStr', 'MappingIntStrAny']) -> None: - if TYPE_CHECKING: - self._items: Union['AbstractSetIntStr', 'MappingIntStrAny'] - self._type: Type[Union[set, dict]] # type: ignore - - # For further type checks speed-up - if isinstance(items, Mapping): - self._type = dict - elif isinstance(items, AbstractSet): - self._type = set - else: - raise TypeError(f'Unexpected type of exclude value {items.__class__}') + items = self._coerce_items(items) if isinstance(value, (list, tuple)): items = self._normalize_indexes(items, len(value)) - self._items = items + self._items: 'MappingIntStrAny' = items - @no_type_check def is_excluded(self, item: Any) -> bool: """ - Check if item is fully excluded - (value considered excluded if self._type is set and item contained in self._items - or self._type is dict and self._items.get(item) is ... + Check if item is fully excluded. :param item: key or index of a value """ - if self._type is set: - return item in self._items - return self._items.get(item) is ... + return self.is_true(self._items.get(item)) - @no_type_check def is_included(self, item: Any) -> bool: """ Check if value is contained in self._items @@ -505,63 +445,112 @@ def is_included(self, item: Any) -> bool: """ return item in self._items - @no_type_check def for_element(self, e: 'IntStr') -> Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']]: """ :param e: key or index of element on value :return: raw values for elemet if self._items is dict and contain needed element """ - if self._type is dict: - item = self._items.get(e) - return item if item is not ... else None - return None + item = self._items.get(e) + return item if not self.is_true(item) else None - @no_type_check - def _normalize_indexes( - self, items: Union['AbstractSetIntStr', 'MappingIntStrAny'], v_length: int - ) -> Union['AbstractSetIntStr', 'DictIntStrAny']: + def _normalize_indexes(self, items: 'MappingIntStrAny', v_length: int) -> 'DictIntStrAny': """ :param items: dict or set of indexes which will be normalized :param v_length: length of sequence indexes of which will be - >>> self._normalize_indexes({0, -2, -1}, 4) - {0, 2, 3} - >>> self._normalize_indexes({'__all__'}, 4) - {0, 1, 2, 3} + >>> self._normalize_indexes({0: True, -2: True, -1: True}, 4) + {0: True, 2: True, 3: True} + >>> self._normalize_indexes({'__all__': True}, 4) + {0: True, 1: True, 2: True, 3: True} """ - if any(not isinstance(i, int) and i != '__all__' for i in items): - raise TypeError( - 'Excluding fields from a sequence of sub-models or dicts must be performed index-wise: ' - 'expected integer keys or keyword "__all__"' - ) - if self._type is set: - if '__all__' in items: - if items != {'__all__'}: - raise ValueError('set with keyword "__all__" must not contain other elements') - return {i for i in range(v_length)} - return {v_length + i if i < 0 else i for i in items} - else: - all_items = items.get('__all__') - for i, v in items.items(): - if not (isinstance(v, Mapping) or isinstance(v, AbstractSet) or v is ...): - raise TypeError(f'Unexpected type of exclude value for index "{i}" {v.__class__}') - normalized_items = {v_length + i if i < 0 else i: v for i, v in items.items() if i != '__all__'} - if all_items: - default: Type[Union[Set[Any], Dict[Any, Any]]] - if isinstance(all_items, Mapping): - default = dict - elif isinstance(all_items, AbstractSet): - default = set - else: - for i in range(v_length): - normalized_items.setdefault(i, ...) - return normalized_items - for i in range(v_length): - normalized_item = normalized_items.setdefault(i, default()) - if normalized_item is not ...: - normalized_items[i] = update_normalized_all(normalized_item, all_items) + + normalized_items: 'DictIntStrAny' = {} + all_items = None + for i, v in items.items(): + if not (isinstance(v, Mapping) or isinstance(v, AbstractSet) or self.is_true(v)): + raise TypeError(f'Unexpected type of exclude value for index "{i}" {v.__class__}') + if i == '__all__': + all_items = self._coerce_value(v) + continue + if not isinstance(i, int): + raise TypeError( + 'Excluding fields from a sequence of sub-models or dicts must be performed index-wise: ' + 'expected integer keys or keyword "__all__"' + ) + normalized_i = v_length + i if i < 0 else i + normalized_items[normalized_i] = self.merge(v, normalized_items.get(normalized_i)) + + if not all_items: return normalized_items + if self.is_true(all_items): + for i in range(v_length): + normalized_items.setdefault(i, ...) + return normalized_items + for i in range(v_length): + normalized_item = normalized_items.setdefault(i, {}) + if not self.is_true(normalized_item): + normalized_items[i] = self.merge(all_items, normalized_item) + return normalized_items + + @classmethod + def merge(cls, base: Any, override: Any, intersect: bool = False) -> Any: + """ + Merge a ``base`` item with an ``override`` item. + + Both ``base`` and ``override`` are converted to dictionaries if possible. + Sets are converted to dictionaries with the sets entries as keys and + Ellipsis as values. + + Each key-value pair existing in ``base`` is merged with ``override``, + while the rest of the key-value pairs are updated recursively with this function. + + Merging takes place based on the "union" of keys if ``intersect`` is + set to ``False`` (default) and on the intersection of keys if + ``intersect`` is set to ``True``. + """ + override = cls._coerce_value(override) + base = cls._coerce_value(base) + if override is None: + return base + if cls.is_true(base) or base is None: + return override + if cls.is_true(override): + return base if intersect else override + + # intersection or union of keys while preserving ordering: + if intersect: + merge_keys = [k for k in base if k in override] + [k for k in override if k in base] + else: + merge_keys = list(base) + [k for k in override if k not in base] + + merged: 'DictIntStrAny' = {} + for k in merge_keys: + merged_item = cls.merge(base.get(k), override.get(k), intersect=intersect) + if merged_item is not None: + merged[k] = merged_item + + return merged + + @staticmethod + def _coerce_items(items: Union['AbstractSetIntStr', 'MappingIntStrAny']) -> 'MappingIntStrAny': + if isinstance(items, Mapping): + pass + elif isinstance(items, AbstractSet): + items = dict.fromkeys(items, ...) + else: + raise TypeError(f'Unexpected type of exclude value {items.__class__}') + return items + + @classmethod + def _coerce_value(cls, value: Any) -> Any: + if value is None or cls.is_true(value): + return value + return cls._coerce_items(value) + + @staticmethod + def is_true(v: Any) -> bool: + return v is True or v is ... def __repr_args__(self) -> 'ReprArgs': return [(None, self._items)] diff --git a/tests/test_construction.py b/tests/test_construction.py index 8a8130c489..19e912b398 100644 --- a/tests/test_construction.py +++ b/tests/test_construction.py @@ -329,8 +329,8 @@ class Model(BaseModel): assert m.copy(exclude={'c'}).dict() == {'d': {'a': 'ax', 'b': 'bx'}} assert m.copy(exclude={'c'}, update={'c': 42}).dict() == {'c': 42, 'd': {'a': 'ax', 'b': 'bx'}} - assert m._calculate_keys(exclude={'x'}, include=None, exclude_unset=False) == {'c', 'd'} - assert m._calculate_keys(exclude={'x'}, include=None, exclude_unset=False, update={'c': 42}) == {'d'} + assert m._calculate_keys(exclude={'x': ...}, include=None, exclude_unset=False) == {'c', 'd'} + assert m._calculate_keys(exclude={'x': ...}, include=None, exclude_unset=False, update={'c': 42}) == {'d'} def test_shallow_copy_modify(): diff --git a/tests/test_main.py b/tests/test_main.py index 328e89a2b3..9e2ae43f17 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -1,10 +1,12 @@ import sys from collections import defaultdict +from copy import deepcopy from enum import Enum from typing import Any, Callable, ClassVar, DefaultDict, Dict, List, Mapping, Optional, Type, get_type_hints from uuid import UUID, uuid4 import pytest +from pytest import param from pydantic import ( BaseModel, @@ -1348,7 +1350,72 @@ class Bar(BaseModel): assert dict(m) == {'c': 3, 'd': Foo()} -def test_model_export_nested_list(): +@pytest.mark.parametrize( + 'exclude,expected,raises_match', + [ + param( + {'foos': {0: {'a'}, 1: {'a'}}}, + {'c': 3, 'foos': [{'b': 2}, {'b': 4}]}, + None, + id='excluding fields of indexed list items', + ), + param( + {'foos': {'a'}}, + TypeError, + 'expected integer keys', + id='should fail trying to exclude string keys on list field (1).', + ), + param( + {'foos': {0: ..., 'a': ...}}, + TypeError, + 'expected integer keys', + id='should fail trying to exclude string keys on list field (2).', + ), + param( + {'foos': {0: 1}}, + TypeError, + 'Unexpected type', + id='should fail using integer key to specify list item field name (1)', + ), + param( + {'foos': {'__all__': 1}}, + TypeError, + 'Unexpected type', + id='should fail using integer key to specify list item field name (2)', + ), + param( + {'foos': {'__all__': {'a'}}}, + {'c': 3, 'foos': [{'b': 2}, {'b': 4}]}, + None, + id='using "__all__" to exclude specific nested field', + ), + param( + {'foos': {0: {'b'}, '__all__': {'a'}}}, + {'c': 3, 'foos': [{}, {'b': 4}]}, + None, + id='using "__all__" to exclude specific nested field in combination with more specific exclude', + ), + param( + {'foos': {'__all__'}}, + {'c': 3, 'foos': []}, + None, + id='using "__all__" to exclude all list items', + ), + param( + {'foos': {1, '__all__'}}, + {'c': 3, 'foos': []}, + None, + id='using "__all__" and other items should get merged together, still excluding all list items', + ), + param( + {'foos': {1: {'a'}, -1: {'b'}}}, + {'c': 3, 'foos': [{'a': 1, 'b': 2}, {}]}, + None, + id='using negative and positive indexes, referencing the same items should merge excludes', + ), + ], +) +def test_model_export_nested_list(exclude, expected, raises_match): class Foo(BaseModel): a: int = 1 b: int = 2 @@ -1359,41 +1426,234 @@ class Bar(BaseModel): m = Bar(c=3, foos=[Foo(a=1, b=2), Foo(a=3, b=4)]) - assert m.dict(exclude={'foos': {0: {'a'}, 1: {'a'}}}) == {'c': 3, 'foos': [{'b': 2}, {'b': 4}]} + if isinstance(expected, type) and issubclass(expected, Exception): + with pytest.raises(expected, match=raises_match): + m.dict(exclude=exclude) + else: + original_exclude = deepcopy(exclude) + assert m.dict(exclude=exclude) == expected + assert exclude == original_exclude + + +@pytest.mark.parametrize( + 'excludes,expected', + [ + param( + {'bars': {0}}, + {'a': 1, 'bars': [{'y': 2}, {'w': -1, 'z': 3}]}, + id='excluding first item from list field using index', + ), + param({'bars': {'__all__'}}, {'a': 1, 'bars': []}, id='using "__all__" to exclude all list items'), + param( + {'bars': {'__all__': {'w'}}}, + {'a': 1, 'bars': [{'x': 1}, {'y': 2}, {'z': 3}]}, + id='exclude single dict key from all list items', + ), + ], +) +def test_model_export_dict_exclusion(excludes, expected): + class Foo(BaseModel): + a: int = 1 + bars: List[Dict[str, int]] - with pytest.raises(TypeError, match='expected integer keys'): - m.dict(exclude={'foos': {'a'}}) - with pytest.raises(TypeError, match='expected integer keys'): - m.dict(exclude={'foos': {0: ..., 'a': ...}}) - with pytest.raises(TypeError, match='Unexpected type'): - m.dict(exclude={'foos': {0: 1}}) - with pytest.raises(TypeError, match='Unexpected type'): - m.dict(exclude={'foos': {'__all__': 1}}) + m = Foo(a=1, bars=[{'w': 0, 'x': 1}, {'y': 2}, {'w': -1, 'z': 3}]) - assert m.dict(exclude={'foos': {0: {'b'}, '__all__': {'a'}}}) == {'c': 3, 'foos': [{}, {'b': 4}]} - assert m.dict(exclude={'foos': {'__all__': {'a'}}}) == {'c': 3, 'foos': [{'b': 2}, {'b': 4}]} - assert m.dict(exclude={'foos': {'__all__'}}) == {'c': 3, 'foos': []} + original_excludes = deepcopy(excludes) + assert m.dict(exclude=excludes) == expected + assert excludes == original_excludes - with pytest.raises(ValueError, match='set with keyword "__all__" must not contain other elements'): - m.dict(exclude={'foos': {1, '__all__'}}) +def test_model_exclude_config_field_merging(): + """Test merging field exclude values from config.""" -def test_model_export_dict_exclusion(): - class Foo(BaseModel): - a: int = 1 - bars: List[Dict[str, int]] + class Model(BaseModel): + b: int = Field(2, exclude=...) - m = Foo(a=1, bars=[{'w': 0, 'x': 1}, {'y': 2}, {'w': -1, 'z': 3}]) + class Config: + fields = { + 'b': {'exclude': ...}, + } + + assert Model.__fields__['b'].field_info.exclude is ... + + class Model(BaseModel): + b: int = Field(2, exclude={'a': {'test'}}) + + class Config: + fields = { + 'b': {'exclude': ...}, + } + + assert Model.__fields__['b'].field_info.exclude == {'a': {'test'}} + + class Model(BaseModel): + b: int = Field(2, exclude={'foo'}) + + class Config: + fields = { + 'b': {'exclude': {'bar'}}, + } + + assert Model.__fields__['b'].field_info.exclude == {'foo': ..., 'bar': ...} + + +@pytest.mark.parametrize( + 'kinds', + [ + {'sub_fields', 'model_fields', 'model_config', 'sub_config', 'combined_config'}, + {'sub_fields', 'model_fields', 'combined_config'}, + {'sub_fields', 'model_fields'}, + {'combined_config'}, + {'model_config', 'sub_config'}, + {'model_config', 'sub_fields'}, + {'model_fields', 'sub_config'}, + ], +) +@pytest.mark.parametrize( + 'exclude,expected', + [ + (None, {'a': 0, 'c': {'a': [3, 5], 'c': 'foobar'}, 'd': {'c': 'foobar'}}), + ({'c', 'd'}, {'a': 0}), + ({'a': ..., 'c': ..., 'd': {'a': ..., 'c': ...}}, {'d': {}}), + ], +) +def test_model_export_exclusion_with_fields_and_config(kinds, exclude, expected): + """Test that exporting models with fields using the export parameter works.""" + + class ChildConfig: + pass + + if 'sub_config' in kinds: + ChildConfig.fields = {'b': {'exclude': ...}, 'a': {'exclude': {1}}} + + class ParentConfig: + pass + + if 'combined_config' in kinds: + ParentConfig.fields = { + 'b': {'exclude': ...}, + 'c': {'exclude': {'b': ..., 'a': {1}}}, + 'd': {'exclude': {'a': ..., 'b': ...}}, + } + + elif 'model_config' in kinds: + ParentConfig.fields = {'b': {'exclude': ...}, 'd': {'exclude': {'a'}}} + + class Sub(BaseModel): + a: List[int] = Field([3, 4, 5], exclude={1} if 'sub_fields' in kinds else None) + b: int = Field(4, exclude=... if 'sub_fields' in kinds else None) + c: str = 'foobar' + + Config = ChildConfig + + class Model(BaseModel): + a: int = 0 + b: int = Field(2, exclude=... if 'model_fields' in kinds else None) + c: Sub = Sub() + d: Sub = Field(Sub(), exclude={'a'} if 'model_fields' in kinds else None) - excludes = {'bars': {0}} - assert m.dict(exclude=excludes) == {'a': 1, 'bars': [{'y': 2}, {'w': -1, 'z': 3}]} - assert excludes == {'bars': {0}} - excludes = {'bars': {'__all__'}} - assert m.dict(exclude=excludes) == {'a': 1, 'bars': []} - assert excludes == {'bars': {'__all__'}} - excludes = {'bars': {'__all__': {'w'}}} - assert m.dict(exclude=excludes) == {'a': 1, 'bars': [{'x': 1}, {'y': 2}, {'z': 3}]} - assert excludes == {'bars': {'__all__': {'w'}}} + Config = ParentConfig + + m = Model() + assert m.dict(exclude=exclude) == expected, 'Unexpected model export result' + + +def test_model_export_exclusion_inheritance(): + class Sub(BaseModel): + s1: str = 'v1' + s2: str = 'v2' + s3: str = 'v3' + s4: str = Field('v4', exclude=...) + + class Parent(BaseModel): + a: int + b: int = Field(..., exclude=...) + c: int + d: int + s: Sub = Sub() + + class Config: + fields = {'a': {'exclude': ...}, 's': {'exclude': {'s1'}}} + + class Child(Parent): + class Config: + fields = {'c': {'exclude': ...}, 's': {'exclude': {'s2'}}} + + actual = Child(a=0, b=1, c=2, d=3).dict() + expected = {'d': 3, 's': {'s3': 'v3'}} + assert actual == expected, 'Unexpected model export result' + + +def test_model_export_with_true_instead_of_ellipsis(): + class Sub(BaseModel): + s1: int = 1 + + class Model(BaseModel): + a: int = 2 + b: int = Field(3, exclude=True) + c: int = Field(4) + s: Sub = Sub() + + class Config: + fields = {'c': {'exclude': True}} + + m = Model() + assert m.dict(exclude={'s': True}) == {'a': 2} + + +def test_model_export_inclusion(): + class Sub(BaseModel): + s1: str = 'v1' + s2: str = 'v2' + s3: str = 'v3' + s4: str = 'v4' + + class Model(BaseModel): + a: Sub = Sub() + b: Sub = Field(Sub(), include={'s1'}) + c: Sub = Field(Sub(), include={'s1', 's2'}) + + class Config: + fields = {'a': {'include': {'s2', 's1', 's3'}}, 'b': {'include': {'s1', 's2', 's3', 's4'}}} + + Model.__fields__['a'].field_info.include == {'s1': ..., 's2': ..., 's3': ...} + Model.__fields__['b'].field_info.include == {'s1': ...} + Model.__fields__['c'].field_info.include == {'s1': ..., 's2': ...} + + actual = Model().dict(include={'a': {'s3', 's4'}, 'b': ..., 'c': ...}) + # s1 included via field, s2 via config and s3 via .dict call: + expected = {'a': {'s3': 'v3'}, 'b': {'s1': 'v1'}, 'c': {'s1': 'v1', 's2': 'v2'}} + + assert actual == expected, 'Unexpected model export result' + + +def test_model_export_inclusion_inheritance(): + class Sub(BaseModel): + s1: str = Field('v1', include=...) + s2: str = Field('v2', include=...) + s3: str = Field('v3', include=...) + s4: str = 'v4' + + class Parent(BaseModel): + a: int + b: int + c: int + s: Sub = Field(Sub(), include={'s1', 's2'}) # overrides includes set in Sub model + + class Config: + # b will be included since fields are set idependently + fields = {'b': {'include': ...}} + + class Child(Parent): + class Config: + # b is still included even if it doesn't occur here since fields + # are still considered separately. + # s however, is merged, resulting in only s1 being included. + fields = {'a': {'include': ...}, 's': {'include': {'s1'}}} + + actual = Child(a=0, b=1, c=2).dict() + expected = {'a': 0, 'b': 1, 's': {'s1': 'v1'}} + assert actual == expected, 'Unexpected model export result' def test_custom_init_subclass_params(): diff --git a/tests/test_utils.py b/tests/test_utils.py index d3c9922865..fc1a3814f7 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -177,7 +177,7 @@ def test_value_items(): sub_v = included['a'] sub_vi = ValueItems(sub_v, vi.for_element('a')) - assert repr(sub_vi) == 'ValueItems({0, 2})' + assert repr(sub_vi) == 'ValueItems({0: Ellipsis, 2: Ellipsis})' assert sub_vi.is_excluded(2) assert [v_ for i, v_ in enumerate(sub_v) if not sub_vi.is_excluded(i)] == ['b'] @@ -186,6 +186,55 @@ def test_value_items(): assert [v_ for i, v_ in enumerate(sub_v) if sub_vi.is_included(i)] == ['a', 'c'] +@pytest.mark.parametrize( + 'base,override,intersect,expected', + [ + # Check in default (union) mode + (..., ..., False, ...), + (None, None, False, None), + ({}, {}, False, {}), + (..., None, False, ...), + (None, ..., False, ...), + (None, {}, False, {}), + ({}, None, False, {}), + (..., {}, False, {}), + ({}, ..., False, ...), + ({'a': None}, {'a': None}, False, {}), + ({'a'}, ..., False, ...), + ({'a'}, {}, False, {'a': ...}), + ({'a'}, {'b'}, False, {'a': ..., 'b': ...}), + ({'a': ...}, {'b': {'c'}}, False, {'a': ..., 'b': {'c': ...}}), + ({'a': ...}, {'a': {'c'}}, False, {'a': {'c': ...}}), + ({'a': {'c': ...}, 'b': {'d'}}, {'a': ...}, False, {'a': ..., 'b': {'d': ...}}), + # Check in intersection mode + (..., ..., True, ...), + (None, None, True, None), + ({}, {}, True, {}), + (..., None, True, ...), + (None, ..., True, ...), + (None, {}, True, {}), + ({}, None, True, {}), + (..., {}, True, {}), + ({}, ..., True, {}), + ({'a': None}, {'a': None}, True, {}), + ({'a'}, ..., True, {'a': ...}), + ({'a'}, {}, True, {}), + ({'a'}, {'b'}, True, {}), + ({'a': ...}, {'b': {'c'}}, True, {}), + ({'a': ...}, {'a': {'c'}}, True, {'a': {'c': ...}}), + ({'a': {'c': ...}, 'b': {'d'}}, {'a': ...}, True, {'a': {'c': ...}}), + # Check usage of `True` instead of `...` + (..., True, False, True), + (True, ..., False, ...), + (True, None, False, True), + ({'a': {'c': True}, 'b': {'d'}}, {'a': True}, False, {'a': True, 'b': {'d': ...}}), + ], +) +def test_value_items_merge(base, override, intersect, expected): + actual = ValueItems.merge(base, override, intersect=intersect) + assert actual == expected + + def test_value_items_error(): with pytest.raises(TypeError) as e: ValueItems(1, (1, 2, 3))