Add exclude/include as field parameters
- Add "exclude" / "include" as a field parameter so that it can be
  configured using model config (or fields) instead of purely at
  `.dict` / `.json` export time.
- Unify merging logic of advanced include/exclude fields
- Add tests for merging logic and field/config exclude/include params
- Closes pydantic#660
daviskirk committed Mar 12, 2021
1 parent 619ff26 commit 7bfa756
Showing 8 changed files with 411 additions and 131 deletions.
1 change: 1 addition & 0 deletions changes/660-daviskirk.md
@@ -0,0 +1 @@
Add "exclude" as a field parameter so that it can be configured using model config instead of purely at `.dict` / `.json` export time.
1 change: 1 addition & 0 deletions docs/usage/schema.md
@@ -52,6 +52,7 @@ It has the following arguments:
* `title`: if omitted, `field_name.title()` is used
* `description`: if omitted and the annotation is a sub-model,
the docstring of the sub-model will be used
* `exclude`: exclude this field when dumping the instance (`.dict` and `.json`). The exact syntax and configuration options are described in detail in the [exporting models section](exporting_models.md#advanced-include-and-exclude).
* `const`: this argument *must* be the same as the field's default value if present.
* `gt`: for numeric values (``int``, `float`, `Decimal`), adds a validation of "greater than" and an annotation
of `exclusiveMinimum` to the JSON Schema
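The "advanced include and exclude" syntax referenced in the docs above also works when the value is set on the field itself, for example to strip a single key of a nested model. A hedged sketch (illustrative names, not taken from this PR):

```python
from pydantic import BaseModel, Field


class User(BaseModel):
    username: str
    password: str


class Account(BaseModel):
    # only the nested `password` key is stripped when dumping `user`
    user: User = Field(..., exclude={'password'})
    balance: int


a = Account(user=User(username='anna', password='s3cret'), balance=10)
print(a.dict())  # expected: {'user': {'username': 'anna'}, 'balance': 10}
```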
26 changes: 23 additions & 3 deletions pydantic/fields.py
@@ -43,7 +43,7 @@
is_typeddict,
new_type_supertype,
)
from .utils import PyObjectStr, Representation, lenient_issubclass, sequence_like, smart_deepcopy
from .utils import PyObjectStr, Representation, ValueItems, lenient_issubclass, sequence_like, smart_deepcopy
from .validators import constant_validator, dict_validator, find_validators, validate_json

Required: Any = Ellipsis
@@ -72,7 +72,7 @@ def __deepcopy__(self: T, _: Any) -> T:
from .error_wrappers import ErrorList
from .main import BaseConfig, BaseModel # noqa: F401
from .types import ModelOrDc # noqa: F401
from .typing import ReprArgs # noqa: F401
from .typing import AbstractSetIntStr, MappingIntStrAny, ReprArgs # noqa: F401

ValidateReturn = Tuple[Optional[Any], Optional[ErrorList]]
LocStr = Union[Tuple[Union[int, str], ...], str]
@@ -91,6 +91,8 @@ class FieldInfo(Representation):
'alias_priority',
'title',
'description',
'exclude',
'include',
'const',
'gt',
'ge',
@@ -128,6 +130,8 @@ def __init__(self, default: Any = Undefined, **kwargs: Any) -> None:
self.alias_priority = kwargs.pop('alias_priority', 2 if self.alias else None)
self.title = kwargs.pop('title', None)
self.description = kwargs.pop('description', None)
self.exclude = kwargs.pop('exclude', None)
self.include = kwargs.pop('include', None)
self.const = kwargs.pop('const', None)
self.gt = kwargs.pop('gt', None)
self.ge = kwargs.pop('ge', None)
@@ -180,6 +184,8 @@ def Field(
alias: str = None,
title: str = None,
description: str = None,
exclude: Union['AbstractSetIntStr', 'MappingIntStrAny', Any] = None,
include: Union['AbstractSetIntStr', 'MappingIntStrAny', Any] = None,
const: bool = None,
gt: float = None,
ge: float = None,
@@ -205,6 +211,10 @@ def Field(
:param alias: the public name of the field
:param title: can be any string, used in the schema
:param description: can be any string, used in the schema
:param exclude: exclude this field while dumping.
Takes the same values as the ``include`` and ``exclude`` arguments on the ``.dict`` method.
:param include: include this field while dumping.
Takes the same values as the ``include`` and ``exclude`` arguments on the ``.dict`` method.
:param const: this field is required and *must* take its default value
:param gt: only applies to numbers, requires the field to be "greater than". The schema
will have an ``exclusiveMinimum`` validation keyword
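For the `include` counterpart documented above, a short hedged sketch (assumed behaviour of this change, illustrative names): once a field declares `include`, only the included fields end up in the export.

```python
from pydantic import BaseModel, Field


class ApiKey(BaseModel):
    key_id: str = Field(..., include=...)  # whitelisted for export
    secret: str


print(ApiKey(key_id='k1', secret='s3cret').dict())  # expected: {'key_id': 'k1'}
```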
@@ -232,6 +242,8 @@ def Field(
alias=alias,
title=title,
description=description,
exclude=exclude,
include=include,
const=const,
gt=gt,
ge=ge,
@@ -382,7 +394,8 @@ def _get_field_info(
field_info.update_from_config(field_info_from_config)
elif field_info is None:
field_info = FieldInfo(value, **field_info_from_config)

field_info.exclude = ValueItems.merge(field_info_from_config.get('exclude'), field_info.exclude)
field_info.include = ValueItems.merge(field_info_from_config.get('include'), field_info.include, intersect=True)
value = None if field_info.default_factory is not None else field_info.default
field_info._validate()
return field_info, value
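An illustrative sketch of the merging done here (not from the diff; names are hypothetical): an `exclude` passed to `Field()` and one declared in `Config.fields` for the same field are union-merged.

```python
from pydantic import BaseModel, Field


class Payload(BaseModel):
    data: dict = Field(default_factory=dict, exclude={'password'})

    class Config:
        fields = {'data': {'exclude': {'token'}}}


p = Payload(data={'a': 1, 'password': 'x', 'token': 'y'})
print(p.dict())  # expected: {'data': {'a': 1}} -- both excludes apply
```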
@@ -407,6 +420,7 @@ def infer(
elif value is not Undefined:
required = False
annotation = get_annotation_from_field_info(annotation, field_info, name, config.validate_assignment)

return cls(
name=name,
type_=annotation,
@@ -429,6 +443,12 @@ def set_config(self, config: Type['BaseConfig']) -> None:
self.field_info.alias = new_alias
self.field_info.alias_priority = new_alias_priority
self.alias = new_alias
new_exclude = info_from_config.get('exclude')
if new_exclude is not None:
self.field_info.exclude = ValueItems.merge(self.field_info.exclude, new_exclude)
new_include = info_from_config.get('include')
if new_include is not None:
self.field_info.include = ValueItems.merge(self.field_info.include, new_include, intersect=True)

@property
def alt_alias(self) -> bool:
10 changes: 10 additions & 0 deletions pydantic/main.py
@@ -842,6 +842,16 @@ def _iter(
exclude_none: bool = False,
) -> 'TupleGenerator':

# Merge field-level excludes with the explicit exclude parameter; the explicit parameter overrides the field-level options.
field_exclude = {
k: v.field_info.exclude for k, v in self.__fields__.items() if v.field_info.exclude is not None
} or None
exclude = ValueItems.merge(field_exclude, exclude)
field_include = {
k: v.field_info.include for k, v in self.__fields__.items() if v.field_info.include is not None
} or None
include = ValueItems.merge(field_include, include, intersect=True)

allowed_keys = self._calculate_keys(include=include, exclude=exclude, exclude_unset=exclude_unset)
if allowed_keys is None and not (to_dict or by_alias or exclude_unset or exclude_defaults or exclude_none):
# huge boost for plain _iter()
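In user terms, the merging above means field-level excludes are union-merged with the `exclude` argument passed to `.dict()`, while includes are intersected. A hedged sketch (illustrative names, not part of the commit):

```python
from pydantic import BaseModel, Field


class Person(BaseModel):
    name: str
    ssn: str = Field(..., exclude=...)
    age: int


p = Person(name='anna', ssn='123-45-6789', age=30)
print(p.dict())                         # {'name': 'anna', 'age': 30}
print(p.dict(exclude={'age'}))          # {'name': 'anna'} -- union of both excludes
print(p.dict(include={'name', 'ssn'}))  # {'name': 'anna'} -- field-level exclude still applies
```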
186 changes: 89 additions & 97 deletions pydantic/utils.py
@@ -21,7 +21,6 @@
Type,
TypeVar,
Union,
no_type_check,
)

from .typing import GenericAlias, NoneType, display_as_type
@@ -291,49 +290,6 @@ def unique_list(input_list: Union[List[T], Tuple[T, ...]]) -> List[T]:
return result


def update_normalized_all(
item: Union['AbstractSetIntStr', 'MappingIntStrAny'],
all_items: Union['AbstractSetIntStr', 'MappingIntStrAny'],
) -> Union['AbstractSetIntStr', 'MappingIntStrAny']:
"""
Update item based on what all items contains.
The update is done based on these cases:
- if both arguments are dicts then each key-value pair existing in ``all_items`` is merged into ``item``,
while the rest of the key-value pairs are updated recursively with this function.
- if both arguments are sets then they are just merged.
- if ``item`` is a dictionary and ``all_items`` is a set then all values of it are added to ``item`` as
``key: ...``.
- if ``item`` is set and ``all_items`` is a dictionary, then ``item`` is converted to a dictionary and then the
key-value pairs of ``all_items`` are merged in it.
During recursive calls, there is a case where ``all_items`` can be an Ellipsis, in which case the ``item`` is
returned as is.
"""
if not item:
return all_items
if isinstance(item, dict) and isinstance(all_items, dict):
item = dict(item)
item.update({k: update_normalized_all(item[k], v) for k, v in all_items.items() if k in item})
item.update({k: v for k, v in all_items.items() if k not in item})
return item
if isinstance(item, set) and isinstance(all_items, set):
item = set(item)
item.update(all_items)
return item
if isinstance(item, dict) and isinstance(all_items, set):
item = dict(item)
item.update({k: ... for k in all_items if k not in item})
return item
if isinstance(item, set) and isinstance(all_items, dict):
item = {k: ... for k in item}
item.update({k: v for k, v in all_items.items() if k not in item})
return item
# Case when item or all_items is ... (in recursive calls).
return item


class PyObjectStr(str):
"""
String class where repr doesn't include quotes. Useful with Representation when you want to return a string
Expand Down Expand Up @@ -467,23 +423,31 @@ class ValueItems(Representation):

def __init__(self, value: Any, items: Union['AbstractSetIntStr', 'MappingIntStrAny']) -> None:
if TYPE_CHECKING:
self._items: Union['AbstractSetIntStr', 'MappingIntStrAny']
self._type: Type[Union[set, dict]] # type: ignore
self._items: 'MappingIntStrAny'

# For further type checks speed-up
if isinstance(items, Mapping):
self._type = dict
elif isinstance(items, AbstractSet):
self._type = set
else:
raise TypeError(f'Unexpected type of exclude value {items.__class__}')
items = self._coerce_items(items)

if isinstance(value, (list, tuple)):
items = self._normalize_indexes(items, len(value))

self._items = items

@no_type_check
@classmethod
def _coerce_items(cls, items: Union['AbstractSetIntStr', 'MappingIntStrAny']) -> 'MappingIntStrAny':
if isinstance(items, Mapping):
pass
elif isinstance(items, AbstractSet):
items = dict.fromkeys(items, ...)
else:
raise TypeError(f'Unexpected type of exclude value {items.__class__}')
return items

@classmethod
def _coerce_value(cls, value: Any) -> Any:
if value is None or value is ...:
return value
return cls._coerce_items(value)

def is_excluded(self, item: Any) -> bool:
"""
Check if item is fully excluded
@@ -492,11 +456,8 @@ def is_excluded(self, item: Any) -> bool:
:param item: key or index of a value
"""
if self._type is set:
return item in self._items
return self._items.get(item) is ...

@no_type_check
def is_included(self, item: Any) -> bool:
"""
Check if value is contained in self._items
@@ -505,22 +466,16 @@ def is_included(self, item: Any) -> bool:
"""
return item in self._items

@no_type_check
def for_element(self, e: 'IntStr') -> Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']]:
"""
:param e: key or index of element on value
:return: raw values for element if self._items is a dict and contains the needed element
"""

if self._type is dict:
item = self._items.get(e)
return item if item is not ... else None
return None
item = self._items.get(e)
return item if item is not ... else None

@no_type_check
def _normalize_indexes(
self, items: Union['AbstractSetIntStr', 'MappingIntStrAny'], v_length: int
) -> Union['AbstractSetIntStr', 'DictIntStrAny']:
def _normalize_indexes(self, items: 'MappingIntStrAny', v_length: int) -> 'DictIntStrAny':
"""
:param items: dict or set of indexes which will be normalized
:param v_length: length of sequence indexes of which will be
@@ -530,38 +485,75 @@ def _normalize_indexes(
>>> self._normalize_indexes({'__all__': ...}, 4)
{0: Ellipsis, 1: Ellipsis, 2: Ellipsis, 3: Ellipsis}
"""
if any(not isinstance(i, int) and i != '__all__' for i in items):
raise TypeError(
'Excluding fields from a sequence of sub-models or dicts must be performed index-wise: '
'expected integer keys or keyword "__all__"'
)
if self._type is set:
if '__all__' in items:
if items != {'__all__'}:
raise ValueError('set with keyword "__all__" must not contain other elements')
return {i for i in range(v_length)}
return {v_length + i if i < 0 else i for i in items}
else:
all_items = items.get('__all__')
for i, v in items.items():
if not (isinstance(v, Mapping) or isinstance(v, AbstractSet) or v is ...):
raise TypeError(f'Unexpected type of exclude value for index "{i}" {v.__class__}')
normalized_items = {v_length + i if i < 0 else i: v for i, v in items.items() if i != '__all__'}
if all_items:
default: Type[Union[Set[Any], Dict[Any, Any]]]
if isinstance(all_items, Mapping):
default = dict
elif isinstance(all_items, AbstractSet):
default = set
else:
for i in range(v_length):
normalized_items.setdefault(i, ...)
return normalized_items
for i in range(v_length):
normalized_item = normalized_items.setdefault(i, default())
if normalized_item is not ...:
normalized_items[i] = update_normalized_all(normalized_item, all_items)

if TYPE_CHECKING:
normalized_items: 'DictIntStrAny'
all_items = None
normalized_items = {}
for i, v in items.items():
if not (isinstance(v, Mapping) or isinstance(v, AbstractSet) or v is ...):
raise TypeError(f'Unexpected type of exclude value for index "{i}" {v.__class__}')
if i == '__all__':
all_items = self._coerce_value(v)
continue
if not isinstance(i, int):
raise TypeError(
'Excluding fields from a sequence of sub-models or dicts must be performed index-wise: '
'expected integer keys or keyword "__all__"'
)
if v is not None:
normalized_i = v_length + i if i < 0 else i
normalized_items[normalized_i] = self.merge(v, normalized_items.get(normalized_i))

if not all_items:
return normalized_items
if all_items is ...:
for i in range(v_length):
normalized_items.setdefault(i, ...)
return normalized_items
for i in range(v_length):
normalized_item = normalized_items.setdefault(i, {})
if normalized_item is not ...:
normalized_items[i] = self.merge(all_items, normalized_item)
return normalized_items
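A hedged example of what the index normalisation above means at export time (illustrative names, not from the PR): sequence items can be excluded per index (negative indexes are normalised against the length) or via `__all__`.

```python
from typing import List

from pydantic import BaseModel


class Item(BaseModel):
    sku: str
    note: str


class Order(BaseModel):
    items: List[Item]


o = Order(items=[Item(sku='a', note='x'), Item(sku='b', note='y')])
print(o.dict(exclude={'items': {'__all__': {'note'}}}))
# expected: {'items': [{'sku': 'a'}, {'sku': 'b'}]}
print(o.dict(exclude={'items': {-1: ...}}))
# expected: {'items': [{'sku': 'a', 'note': 'x'}]}
```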

@classmethod
def merge(cls, base: Any, override: Any, intersect: bool = False) -> Any:
"""
Merge a ``base`` item with an ``override`` item.
Both ``base`` and ``override`` are converted to dictionaries if possible.
Sets are converted to dictionaries with the set's entries as keys and
Ellipsis as values.
Each key-value pair existing in ``base`` is merged with ``override``,
while the rest of the key-value pairs are updated recursively with this function.
Merging takes place based on the "union" of keys if ``intersect`` is
set to ``False`` (default) and on the intersection of keys if
``intersect`` is set to ``True``.
"""
override = cls._coerce_value(override)
base = cls._coerce_value(base)
if override is None:
return base
if base is ... or base is None:
return override
elif override is ...:
return base if intersect else override

if intersect:
merge_keys = override.keys() & base.keys()
else:
merge_keys = override.keys() | base.keys()

merged = {}
for k in merge_keys:
merged_item = cls.merge(base.get(k), override.get(k), intersect=intersect)
if merged_item is not None:
merged[k] = merged_item

return merged
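A small worked example of `merge` as defined above (a sketch against this commit's internal `ValueItems` helper; sets are coerced to dicts with `...` values, and `intersect=True` is the mode used for `include`):

```python
from pydantic.utils import ValueItems

# Union merge (the default), used for `exclude`:
print(ValueItems.merge({'a'}, {'b': {'c'}}))
# expected (key order may differ): {'a': Ellipsis, 'b': {'c': Ellipsis}}

# Intersection merge, used for `include`: only keys present on both sides survive.
print(ValueItems.merge({'a': ..., 'b': ...}, {'b', 'c'}, intersect=True))
# expected: {'b': Ellipsis}
```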

def __repr_args__(self) -> 'ReprArgs':
return [(None, self._items)]
2 changes: 1 addition & 1 deletion tests/conftest.py
@@ -15,7 +15,7 @@
except ImportError:
pytest_plugins = []
else:
pytest_plugins = ['hypothesis.extra.pytestplugin']
pytest_plugins = []


def _extract_source_code_from_function(function):
