Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support discriminated union #2336

Merged
merged 57 commits into from
Dec 18, 2021
Merged
Show file tree
Hide file tree
Changes from 24 commits
Commits
Show all changes
57 commits
Select commit Hold shift + click to select a range
22d8458
feat: add discriminated union
PrettyWood Feb 9, 2021
6dd3b73
feat: add OpenAPI spec schema
PrettyWood Feb 9, 2021
315ad3a
test: add basic example for generated schema
PrettyWood Feb 20, 2021
50d6d02
test: add validation tests
PrettyWood Feb 9, 2021
4be10c7
docs: add basic documentation
PrettyWood Feb 11, 2021
9bad44d
fix: support ForwardRef
PrettyWood Feb 20, 2021
1d6e263
test: add ForwardRef case
PrettyWood Feb 20, 2021
682bec0
fix: false positive lint error
PrettyWood Feb 20, 2021
bbb1998
improve error
PrettyWood Feb 23, 2021
c77027f
add schema/schema_json utils
PrettyWood Feb 23, 2021
6a8363f
fix tests after merge
PrettyWood Mar 13, 2021
8121c9f
refactor: add `discriminator` attribute to `FieldInfo`
PrettyWood Mar 13, 2021
d98624e
refactor: @cybojenix remarks
PrettyWood Mar 13, 2021
1c2586b
fix schema with forward ref
PrettyWood Mar 13, 2021
d7408c8
start nested
PrettyWood Mar 14, 2021
0ddb984
feat: add allowed values in error message
PrettyWood Mar 14, 2021
eb5e517
fix wrong check
PrettyWood Mar 14, 2021
1170e15
test: add nested examples
PrettyWood Mar 14, 2021
0dfa998
remove uncovered code as we don't need it
PrettyWood Mar 14, 2021
e39bdc5
docs: add nested example
PrettyWood Mar 14, 2021
dba3129
fix: support properly Annotated Field syntax
PrettyWood Mar 14, 2021
c2ec4f9
support naked annotated
PrettyWood Mar 15, 2021
4fa7bf5
fix: handle TypeError
PrettyWood Apr 7, 2021
052bfa5
make error loc more explicit
PrettyWood Apr 7, 2021
80614b8
fix behaviour with basemodel instance as value
PrettyWood Apr 19, 2021
5bd6b1d
support schema for dataclasses
PrettyWood Apr 19, 2021
8a8588f
Merge branch 'master' into PrettyWood-f/discriminated-union
samuelcolvin May 1, 2021
51e945c
tweak examples
samuelcolvin May 1, 2021
0d90c40
refactor: context manager just around code that fails
PrettyWood May 1, 2021
8b38d16
refactor: add docstring + tweak on `get_sub_types`
PrettyWood May 1, 2021
57d13fd
refactor: move `get_discriminator_values` in `utils.py`
PrettyWood May 1, 2021
8de1cbd
refactor: create `MissingDiscriminator` and `InvalidDiscriminator`
PrettyWood May 1, 2021
4aa037a
refactor: move logic in `_validate_discriminated_union`
PrettyWood May 2, 2021
e0090a3
refactor: remove `DiscriminatedUnionConfig`
PrettyWood May 2, 2021
e323de9
docs: schema/schema_json
PrettyWood May 2, 2021
3be7c09
tests: add tests with other `Literal` types
PrettyWood May 3, 2021
aefcdf1
Merge branch 'master' into f/discriminated-union
PrettyWood May 12, 2021
417601a
Merge branch 'master' into f/discriminated-union
PrettyWood Sep 6, 2021
44c6222
update 3.10
PrettyWood Sep 6, 2021
a5fe444
add schema docstring
PrettyWood Sep 6, 2021
ea1f32a
weird bug on 3.8 with `Literal[None]`
PrettyWood Sep 5, 2021
259255a
bump to view docs & coverage
samuelcolvin Dec 5, 2021
6d98976
bump to prompt tests
samuelcolvin Dec 5, 2021
6f4d437
Merge branch 'master' into f/discriminated-union
PrettyWood Dec 8, 2021
d3d3e0a
Merge branch 'master' into f/discriminated-union
PrettyWood Dec 9, 2021
242717a
move tests in dedicated file
PrettyWood Dec 9, 2021
b8b0ba2
chore: rewording
PrettyWood Dec 9, 2021
901aa01
refactor: replace property by direct slot
PrettyWood Dec 9, 2021
a41b403
refactor: faster check
PrettyWood Dec 9, 2021
d52e777
refactor: missing discriminator
PrettyWood Dec 9, 2021
f20fc57
refactor: change error to ConfigError
PrettyWood Dec 9, 2021
a73225b
refactor: use display_as_type
PrettyWood Dec 9, 2021
4d0c134
fix: mypy
PrettyWood Dec 9, 2021
f8cc585
fix: duplicate
PrettyWood Dec 9, 2021
2b0c378
feat: handle alias
PrettyWood Dec 11, 2021
4f86219
feat: handle nested unions
PrettyWood Dec 11, 2021
b974edb
tweak first example
samuelcolvin Dec 18, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/619-PrettyWood.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add a discriminated union. See [the doc](https://pydantic-docs.helpmanual.io/usage/types/#discriminated-unions) for more information.
30 changes: 30 additions & 0 deletions docs/examples/types_union_discriminated.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from typing import Literal, Union

from pydantic import BaseModel, Field, ValidationError


class Cat(BaseModel):
pet_type: Literal['cat']
name: str


class Dog(BaseModel):
pet_type: Literal['dog']
name: str


class Lizard(BaseModel):
pet_type: Literal['reptile', 'lizard']
name: str


class Model(BaseModel):
pet: Union[Cat, Dog, Lizard] = Field(..., discriminator='pet_type')
n: int


print(Model.parse_obj({'pet': {'pet_type': 'dog', 'name': 'woof'}, 'n': '1'}))
try:
Model.parse_obj({'pet': {'pet_type': 'dog'}, 'n': '1'})
except ValidationError as e:
print(e)
56 changes: 56 additions & 0 deletions docs/examples/types_union_discriminated_nested.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from typing import Literal, Union

from typing_extensions import Annotated

from pydantic import BaseModel, Field, ValidationError


class BlackCat(BaseModel):
pet_type: Literal['cat']
color: Literal['black']
black_name: str


class WhiteCat(BaseModel):
pet_type: Literal['cat']
color: Literal['white']
white_name: str


# Can also be written with a custom root type
#
# class Cat(BaseModel):
# __root__: Annotated[Union[BlackCat, WhiteCat], Field(discriminator='color')]

Cat = Annotated[Union[BlackCat, WhiteCat], Field(discriminator='color')]


class Dog(BaseModel):
pet_type: Literal['dog']
name: str


Pet = Annotated[Union[Cat, Dog], Field(discriminator='pet_type')]


class Model(BaseModel):
pet: Pet
n: int


print(
Model.parse_obj(
{
'pet': {'pet_type': 'cat', 'color': 'black', 'black_name': 'felix'},
'n': '1',
}
)
)
try:
Model.parse_obj({'pet': {'pet_type': 'cat', 'color': 'red'}, 'n': '1'})
except ValidationError as e:
print(e)
try:
Model.parse_obj({'pet': {'pet_type': 'cat', 'color': 'black'}, 'n': '1'})
except ValidationError as e:
print(e)
33 changes: 33 additions & 0 deletions docs/usage/types.md
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,39 @@ _(This script is complete, it should run "as is")_

See more details in [Required Fields](models.md#required-fields).

#### Discriminated Unions (a.k.a. Tagged Unions)

When `Union` is used with multiple submodels, you sometimes know exactly which submodel needs to
be checked and validated and want to enforce this.
To do that you can set the same field - let's call it `my_discriminator` - in each of the submodels
with a discriminated value, which is one (or many) `Literal` value(s).
For your `Union`, you can set the discriminator in its value: `Field(discriminator='my_discriminator')`.

Setting a discriminated union has many benefits:

- validation is faster since it is only attempted against one model
- only one explicit error is raised in case of failure
- the generated JSON schema implements the [associated OpenAPI specification](https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#discriminatorObject)

```py
{!.tmp_examples/types_union_discriminated.py!}
```
_(This script is complete, it should run "as is")_

!!! note
Using the [Annotated Fields syntax](../schema/#typingannotated-fields) can be handy to regroup
the `Union` and `discriminator` information. See below for an example!

#### Nested Discriminated Unions

Only one discriminator can be set for a field but sometimes you want to combine multiple discriminators.
In this case you can always create "intermediate" models with `__root__` and add your discriminator.

```py
{!.tmp_examples/types_union_discriminated_nested.py!}
```
_(This script is complete, it should run "as is")_

### Enums and Choices

*pydantic* uses python's standard `enum` classes to define choices.
Expand Down
2 changes: 2 additions & 0 deletions pydantic/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@
'parse_file_as',
'parse_obj_as',
'parse_raw_as',
'schema',
'schema_json',
# types
'NoneStr',
'NoneBytes',
Expand Down
115 changes: 114 additions & 1 deletion pydantic/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
ForwardRef,
NoArgAnyCallable,
NoneType,
all_literal_values,
display_as_type,
get_args,
get_origin,
Expand All @@ -43,7 +44,7 @@
is_typeddict,
new_type_supertype,
)
from .utils import PyObjectStr, Representation, lenient_issubclass, sequence_like, smart_deepcopy
from .utils import ROOT_KEY, PyObjectStr, Representation, lenient_issubclass, sequence_like, smart_deepcopy
from .validators import constant_validator, dict_validator, find_validators, validate_json

Required: Any = Ellipsis
Expand Down Expand Up @@ -103,6 +104,7 @@ class FieldInfo(Representation):
'max_length',
'allow_mutation',
'regex',
'discriminator',
'extra',
)

Expand Down Expand Up @@ -140,6 +142,7 @@ def __init__(self, default: Any = Undefined, **kwargs: Any) -> None:
self.max_length = kwargs.pop('max_length', None)
self.allow_mutation = kwargs.pop('allow_mutation', True)
self.regex = kwargs.pop('regex', None)
self.discriminator = kwargs.pop('discriminator', None)
self.extra = kwargs

def __repr_args__(self) -> 'ReprArgs':
Expand Down Expand Up @@ -192,6 +195,7 @@ def Field(
max_length: int = None,
allow_mutation: bool = True,
regex: str = None,
discriminator: str = None,
**extra: Any,
) -> Any:
"""
Expand Down Expand Up @@ -224,6 +228,8 @@ def Field(
assigned on an instance. The BaseModel Config must set validate_assignment to True
:param regex: only applies to strings, requires the field match agains a regular expression
pattern string. The schema will have a ``pattern`` validation keyword
:param discriminator: only useful with a (discriminated a.k.a. tagged) `Union` of sub models with a common field.
The `discriminator` is the name of this common field to shorten validation and improve generated schema
:param **extra: any additional keyword arguments will be added as is to the schema
"""
field_info = FieldInfo(
Expand All @@ -244,6 +250,7 @@ def Field(
max_length=max_length,
allow_mutation=allow_mutation,
regex=regex,
discriminator=discriminator,
**extra,
)
field_info._validate()
Expand Down Expand Up @@ -279,11 +286,23 @@ def Field(
MAPPING_LIKE_SHAPES: Set[int] = {SHAPE_DEFAULTDICT, SHAPE_DICT, SHAPE_MAPPING}


class DiscriminatedUnionConfig(Representation):
PrettyWood marked this conversation as resolved.
Show resolved Hide resolved
__slots__ = (
'discriminator_key',
'sub_fields_mapping',
)

def __init__(self, *, discriminator_key: str) -> None:
self.discriminator_key: str = discriminator_key
self.sub_fields_mapping: Dict[str, 'ModelField'] = {}


class ModelField(Representation):
__slots__ = (
'type_',
'outer_type_',
'sub_fields',
'discriminated_union_config',
'key_field',
'validators',
'pre_validators',
Expand Down Expand Up @@ -332,6 +351,7 @@ def __init__(
self.allow_none: bool = False
self.validate_always: bool = False
self.sub_fields: Optional[List[ModelField]] = None
self.discriminated_union_config: Optional[DiscriminatedUnionConfig] = None
self.key_field: Optional[ModelField] = None
self.validators: 'ValidatorsList' = []
self.pre_validators: Optional['ValidatorsList'] = None
Expand Down Expand Up @@ -545,6 +565,12 @@ def _type_analysis(self) -> None: # noqa: C901 (ignore complexity)
self._type_analysis()
else:
self.sub_fields = [self._create_sub_type(t, f'{self.name}_{display_as_type(t)}') for t in types_]

if self.field_info.discriminator:
self.discriminated_union_config = DiscriminatedUnionConfig(
PrettyWood marked this conversation as resolved.
Show resolved Hide resolved
discriminator_key=self.field_info.discriminator,
)
self.prepare_discriminated_union_sub_fields()
return

if issubclass(origin, Tuple): # type: ignore
Expand Down Expand Up @@ -628,6 +654,28 @@ def _type_analysis(self) -> None: # noqa: C901 (ignore complexity)
# type_ has been refined eg. as the type of a List and sub_fields needs to be populated
self.sub_fields = [self._create_sub_type(self.type_, '_' + self.name)]

def prepare_discriminated_union_sub_fields(self) -> None:
"""
Prepare the mapping <discriminator key> -> <ModelField> and update `sub_fields`
Note that this process can be aborted if a `ForwardRef` is encountered
"""
assert self.discriminated_union_config is not None
assert self.sub_fields is not None

discriminator_key = self.discriminated_union_config.discriminator_key
sub_fields_mapping: Dict[str, 'ModelField'] = {}

for sub_field in self.sub_fields:
t = sub_field.type_
if t.__class__ is ForwardRef:
# Stopping everything...will need to call `update_forward_refs`
return

for discriminator_value in _get_discriminator_values(t, discriminator_key):
sub_fields_mapping[discriminator_value] = sub_field

self.discriminated_union_config.sub_fields_mapping = sub_fields_mapping
PrettyWood marked this conversation as resolved.
Show resolved Hide resolved

def _create_sub_type(self, type_: Type[Any], name: str, *, for_keys: bool = False) -> 'ModelField':
if for_keys:
class_validators = None
Expand All @@ -645,11 +693,15 @@ def _create_sub_type(self, type_: Type[Any], name: str, *, for_keys: bool = Fals
for k, v in self.class_validators.items()
if v.each_item
}

field_info, _ = self._get_field_info(name, type_, None, self.model_config)

return self.__class__(
type_=type_,
name=name,
class_validators=class_validators,
model_config=self.model_config,
field_info=field_info,
)

def populate_validators(self) -> None:
Expand Down Expand Up @@ -895,6 +947,36 @@ def _validate_singleton(
) -> 'ValidateReturn':
if self.sub_fields:
errors = []

if get_origin(self.type_) is Union and self.discriminated_union_config is not None:
PrettyWood marked this conversation as resolved.
Show resolved Hide resolved
if not self.discriminated_union_config.sub_fields_mapping:
assert cls is not None
raise ConfigError(
f'field "{self.name}" not yet prepared so type is still a ForwardRef, '
f'you might need to call {cls.__name__}.update_forward_refs().'
)

discriminator_key = self.discriminated_union_config.discriminator_key

try:
discriminator_value = v[discriminator_key]
samuelcolvin marked this conversation as resolved.
Show resolved Hide resolved
except (KeyError, TypeError):
return v, ErrorWrapper(ValueError(f'Discriminator {discriminator_key!r} is missing in value'), loc)

try:
sub_field = self.discriminated_union_config.sub_fields_mapping[discriminator_value]
except KeyError:
allowed_values = ', '.join(map(repr, self.discriminated_union_config.sub_fields_mapping.keys()))
msg_err = (
f'No match for discriminator {discriminator_key!r} and value {discriminator_value!r} '
f'(allowed values: {allowed_values})'
)
return v, ErrorWrapper(ValueError(msg_err), loc)
PrettyWood marked this conversation as resolved.
Show resolved Hide resolved
else:
if not isinstance(loc, tuple):
loc = (loc,)
return sub_field.validate(v, values, loc=(*loc, display_as_type(sub_field.type_)), cls=cls)

for field in self.sub_fields:
value, error = field.validate(v, values, loc=loc, cls=cls)
if error:
Expand Down Expand Up @@ -1007,3 +1089,34 @@ class DeferredType:
"""
Used to postpone field preparation, while creating recursive generic models.
"""


def _get_discriminator_values(tp: Any, discriminator_key: str) -> Tuple[str, ...]:
PrettyWood marked this conversation as resolved.
Show resolved Hide resolved
"""
Get all valid values in the `Literal` type of the discriminator field
`tp` can be a `BaseModel` class or directly an `Annotated` `Union` of many.
"""
is_root_model = getattr(tp, '__custom_root_type__', False)

if get_origin(tp) is Annotated:
tp = get_args(tp)[0]

if is_root_model or get_origin(tp) is Union:
union_type = tp.__fields__[ROOT_KEY].type_ if is_root_model else tp

all_values = [_get_discriminator_values(t, discriminator_key) for t in get_args(union_type)]
if len(set(all_values)) > 1:
raise TypeError(f'Field {discriminator_key!r} is not the same for all submodels of {tp.__name__!r}')

return all_values[0]

else:
try:
t_discriminator_type = tp.__fields__[discriminator_key].type_
except KeyError:
raise KeyError(f'Model {tp.__name__!r} needs a discriminator field for key {discriminator_key!r}')

if not is_literal_type(t_discriminator_type):
raise TypeError(f'Field {discriminator_key!r} of model {tp.__name__!r} needs to be a `Literal`')

return all_literal_values(t_discriminator_type)
26 changes: 26 additions & 0 deletions pydantic/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
all_literal_values,
get_args,
get_origin,
get_sub_types,
is_callable_type,
is_literal_type,
is_namedtuple,
Expand Down Expand Up @@ -247,6 +248,31 @@ def field_schema(
ref_template=ref_template,
known_models=known_models or set(),
)

# https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#discriminator-object
if field.discriminated_union_config is not None:
discriminator_models_refs: Dict[str, Union[str, Dict[str, Any]]] = {}

for discriminator_value, sub_field in field.discriminated_union_config.sub_fields_mapping.items():
# sub_field is either a `BaseModel` or directly an `Annotated` `Union` of many
if get_origin(sub_field.type_) is Union:
sub_models = get_sub_types(sub_field.type_)
discriminator_models_refs[discriminator_value] = {
model_name_map[sub_model]: get_schema_ref(
model_name_map[sub_model], ref_prefix, ref_template, False
)
for sub_model in sub_models
}
else:
discriminator_model_name = model_name_map[sub_field.type_]
discriminator_model_ref = get_schema_ref(discriminator_model_name, ref_prefix, ref_template, False)
discriminator_models_refs[discriminator_value] = discriminator_model_ref['$ref']

s['discriminator'] = {
'propertyName': field.discriminated_union_config.discriminator_key,
'mapping': discriminator_models_refs,
}

# $ref will only be returned when there are no schema_overrides
if '$ref' in f_schema:
return f_schema, f_definitions, f_nested_models
Expand Down