feat: add Config.smart_union option (#2092)
* feat: add `Config.smart_union` to prevent coercion in `Union` if possible

* docs: write some documentation

* Update docs/usage/model_config.md

Thanks @djpugh

Co-authored-by: David J Pugh <6003255+djpugh@users.noreply.github.com>

* improve doc

* support 3.10

* improve smart_union

* Update docs/usage/types.md

Co-authored-by: David J Pugh <6003255+djpugh@users.noreply.github.com>

* put new sentence inside warning block

* docs: reorder

* rename is_union_origin into is_union

* inverse and condition for perf

* fix doc

Co-authored-by: David J Pugh <6003255+djpugh@users.noreply.github.com>
PrettyWood and djpugh committed Dec 7, 2021
1 parent eef4ac5 commit c38c463
Showing 13 changed files with 285 additions and 20 deletions.
1 change: 1 addition & 0 deletions changes/2092-PrettyWood.md
@@ -0,0 +1 @@
add `Config.smart_union` to prevent coercion in `Union` if possible. See [the doc](https://pydantic-docs.helpmanual.io/usage/model_config/#smart-union) for more information
19 changes: 19 additions & 0 deletions docs/examples/model_config_smart_union_off.py
@@ -0,0 +1,19 @@
from typing import Union

from pydantic import BaseModel


class Foo(BaseModel):
    pass


class Bar(BaseModel):
    pass


class Model(BaseModel):
    x: Union[str, int]
    y: Union[Foo, Bar]


print(Model(x=1, y=Bar()))
22 changes: 22 additions & 0 deletions docs/examples/model_config_smart_union_on.py
@@ -0,0 +1,22 @@
from typing import Union

from pydantic import BaseModel


class Foo(BaseModel):
    pass


class Bar(BaseModel):
    pass


class Model(BaseModel):
    x: Union[str, int]
    y: Union[Foo, Bar]

    class Config:
        smart_union = True


print(Model(x=1, y=Bar()))
14 changes: 14 additions & 0 deletions docs/examples/model_config_smart_union_on_edge_case.py
@@ -0,0 +1,14 @@
from typing import List, Union

from pydantic import BaseModel


class Model(BaseModel, smart_union=True):
    x: Union[List[str], List[int]]


# Expected coercion
print(Model(x=[1, '2']))

# Unexpected coercion
print(Model(x=[1, 2]))
32 changes: 31 additions & 1 deletion docs/usage/model_config.md
@@ -113,7 +113,10 @@ not be included in the model schemas. **Note**: this means that attributes on th
: whether to treat any underscore non-class var attrs as private, or leave them as is; See [Private model attributes](models.md#private-model-attributes)

**`copy_on_model_validation`**
: whether or not inherited models used as fields should be reconstructed (copied) on validation instead of being kept untouched (default: `True`)
: whether inherited models used as fields should be reconstructed (copied) on validation instead of being kept untouched (default: `True`)

**`smart_union`**
: whether _pydantic_ should try to check all types inside `Union` to prevent undesired coercion (see [the dedicated section](#smart-union))

## Change behaviour globally

@@ -164,3 +167,30 @@ For example:
{!.tmp_examples/model_config_alias_precedence.py!}
```
_(This script is complete, it should run "as is")_

## Smart Union

By default, as explained [here](types.md#unions), _pydantic_ validates a value against the members of a `Union` in the order they are declared, coercing it if it can.
As a result, a value may be coerced into an earlier member even though a later member matches it exactly.

```py
{!.tmp_examples/model_config_smart_union_off.py!}
```
_(This script is complete, it should run "as is")_
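
The left-to-right behaviour can be approximated without _pydantic_. The sketch below is an illustration only: `coerce_in_order` is a hypothetical helper, and the plain constructors `str` and `int` stand in for _pydantic_'s validators, which are stricter about what they accept.

```python
from typing import Any, Callable, Sequence


def coerce_in_order(value: Any, coercers: Sequence[Callable[[Any], Any]]) -> Any:
    """Return the result of the first coercer that accepts the value."""
    for coerce in coercers:
        try:
            return coerce(value)
        except (TypeError, ValueError):
            # this member rejected the value, try the next one
            continue
    raise ValueError(f'no member of the union accepts {value!r}')


# Stand-in for Union[str, int]: `str` is declared first, so it wins for an int input.
result = coerce_in_order(1, [str, int])
print(repr(result))  # '1'
```

In this sketch, swapping the order to `[int, str]` would keep the integer, which is why the order of declaration matters when `smart_union` is off.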

To prevent this, you can enable `Config.smart_union`. _Pydantic_ will then check all allowed types before even trying to coerce.
Note that this is slower, especially if your `Union` has many members.

```py
{!.tmp_examples/model_config_smart_union_on.py!}
```
_(This script is complete, it should run "as is")_
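
The passes that `smart_union` adds can be sketched in plain Python. This is a simplified model of the logic in `pydantic/fields.py` from this commit, not the implementation itself: `smart_pick` is a hypothetical name, it handles only simple (non-generic) types, and plain constructors stand in for the real validators in the last pass.

```python
from typing import Any, Sequence, Type


def smart_pick(value: Any, types_: Sequence[Type[Any]]) -> Any:
    # 1st pass: exact class match (e.g. so a bool is not taken for an int)
    for tp in types_:
        if value.__class__ is tp:
            return value

    # 2nd pass: the value is an instance of a subclass of one of the types
    for tp in types_:
        if isinstance(value, tp):
            return value

    # 3rd pass: fall back to coercion, in declaration order
    for tp in types_:
        try:
            return tp(value)
        except (TypeError, ValueError):
            continue
    raise ValueError(f'no member of the union accepts {value!r}')


kept_int = smart_pick(1, [str, int])      # exact match on `int`, no coercion
kept_bool = smart_pick(True, [int, bool])  # exact match on `bool`, not `int`
coerced = smart_pick(1.5, [str, int])      # no match, so coerced by `str` first
```

Only the last call reaches the coercion pass, which is also where the extra cost of checking every member first becomes visible for large unions.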

!!! warning
    Note that this option **does not support compound types yet** (e.g. it cannot differentiate `List[int]` from `List[str]`).
    This option will be improved further once a strict mode is added in _pydantic_ and will probably be the default behaviour in v2!

```py
{!.tmp_examples/model_config_smart_union_on_edge_case.py!}
```
_(This script is complete, it should run "as is")_
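
Why compound types fall through can be seen with the standard library alone. `isinstance` refuses subscripted generics such as `List[str]`, so only the container type can be checked before the items are validated (and coerced). The sketch below assumes that flow; `validate_list` and `smart_pick_compound` are hypothetical helpers, and the item constructors stand in for _pydantic_'s validators.

```python
from typing import Any, List, Sequence, get_args, get_origin


def validate_list(value: List[Any], annotation: Any) -> List[Any]:
    """Coerce every item to the item type of e.g. List[str]."""
    (item_type,) = get_args(annotation)
    return [item_type(item) for item in value]


def smart_pick_compound(value: Any, annotations: Sequence[Any]) -> Any:
    for annotation in annotations:
        try:
            if isinstance(value, annotation):
                return value
        except TypeError:
            # subscripted generic: only the origin (here `list`) can be checked
            if isinstance(value, get_origin(annotation)):
                try:
                    return validate_list(value, annotation)
                except (TypeError, ValueError):
                    continue
    raise ValueError(f'no member of the union accepts {value!r}')


str_first = smart_pick_compound([1, 2], [List[str], List[int]])
int_first = smart_pick_compound([1, 2], [List[int], List[str]])
print(str_first, int_first)
```

In this sketch the first member whose container type matches wins, so `[1, 2]` becomes a list of strings when `List[str]` is declared first and stays a list of integers when `List[int]` is.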
20 changes: 16 additions & 4 deletions docs/usage/types.md
@@ -234,21 +234,33 @@ _(This script is complete, it should run "as is")_

The `Union` type allows a model attribute to accept different types, e.g.:

!!! warning
    This script is complete, it should run "as is". However, it may not reflect the desired behavior; see below.
!!! info
    You may get unexpected coercion with `Union`; see below.<br />
    You can also make the check stricter, at the cost of speed, by using [Smart Union](model_config.md#smart-union).

```py
{!.tmp_examples/types_union_incorrect.py!}
```
_(This script is complete, it should run "as is")_

However, as can be seen above, *pydantic* will attempt to 'match' any of the types defined under `Union` and will use
the first one that matches. In the above example the `id` of `user_03` was defined as a `uuid.UUID` class (which
is defined under the attribute's `Union` annotation) but as the `uuid.UUID` can be marshalled into an `int` it
chose to match against the `int` type and disregarded the other types.
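
That a `uuid.UUID` can be turned into an `int` is visible with the standard library alone. The snippet below is only an illustration; the UUID value is an arbitrary example, and it assumes the `int` member of the `Union` ends up calling something equivalent to `int(value)`.

```python
import uuid

# Arbitrary example value, not taken from the documentation script
user_id = uuid.UUID('cf57432e-809e-4353-adbd-9d5c0d733868')

# `uuid.UUID` implements `__int__`, so the built-in conversion succeeds
as_int = int(user_id)
print(as_int == user_id.int)  # True
```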

!!! warning
    `typing.Union` also ignores order when [defined](https://docs.python.org/3/library/typing.html#typing.Union),
    so `Union[int, float] == Union[float, int]`. This can lead to unexpected behaviour
    when combined with matching based on the `Union` type order inside other type definitions, such as `List` and `Dict`
    types (because Python treats these definitions as singletons).
    For example, `Dict[str, Union[int, float]] == Dict[str, Union[float, int]]`, with the order taken from the first time it was defined.
    Please note that this can also be [affected by third-party libraries](https://github.com/samuelcolvin/pydantic/issues/2835)
    and their internal type definitions and import order.
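
The two equalities mentioned in this warning can be checked directly with the `typing` module. This is a minimal check and deliberately says nothing about which order a cached definition keeps, since that depends on what was defined first in the running process.

```python
from typing import Dict, Union

# Order is ignored when comparing unions...
unions_equal = Union[int, float] == Union[float, int]

# ...and therefore also when they are nested inside other generic types.
dicts_equal = Dict[str, Union[int, float]] == Dict[str, Union[float, int]]

print(unions_equal, dicts_equal)  # True True
```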

As such, it is recommended that, when defining `Union` annotations, the most specific type is included first and
followed by less specific types. In the above example, the `UUID` class should precede the `int` and `str`
classes to preclude the unexpected representation as such:
followed by less specific types.

In the above example, the `UUID` class should precede the `int` and `str` classes to preclude the unexpected representation as such:

```py
{!.tmp_examples/types_union_correct.py!}
4 changes: 3 additions & 1 deletion pydantic/config.py
@@ -62,8 +62,10 @@ class BaseConfig:
    json_encoders: Dict[Type[Any], AnyCallable] = {}
    underscore_attrs_are_private: bool = False

    # Whether or not inherited models as fields should be reconstructed as base model
    # whether inherited models as fields should be reconstructed as base model
    copy_on_model_validation: bool = True
    # whether `Union` should check all allowed types before even trying to coerce
    smart_union: bool = False

    @classmethod
    def get_field_info(cls, name: str) -> Dict[str, Any]:
6 changes: 2 additions & 4 deletions pydantic/env_settings.py
@@ -6,7 +6,7 @@
from .config import BaseConfig, Extra
from .fields import ModelField
from .main import BaseModel
from .typing import StrPath, display_as_type, get_origin, is_union_origin
from .typing import StrPath, display_as_type, get_origin, is_union
from .utils import deep_update, path_type, sequence_like

env_file_sentinel = str(object())
@@ -175,9 +175,7 @@ def __call__(self, settings: BaseSettings) -> Dict[str, Any]:
except ValueError as e:
raise SettingsError(f'error parsing JSON for "{env_name}"') from e
elif (
is_union_origin(get_origin(field.type_))
and field.sub_fields
and any(f.is_complex() for f in field.sub_fields)
is_union(get_origin(field.type_)) and field.sub_fields and any(f.is_complex() for f in field.sub_fields)
):
try:
env_val = settings.__config__.json_loads(env_val)
33 changes: 31 additions & 2 deletions pydantic/fields.py
@@ -42,7 +42,7 @@
is_new_type,
is_none_type,
is_typeddict,
is_union_origin,
is_union,
new_type_supertype,
)
from .utils import PyObjectStr, Representation, ValueItems, lenient_issubclass, sequence_like, smart_deepcopy
@@ -560,7 +560,7 @@ def _type_analysis(self) -> None: # noqa: C901 (ignore complexity)
return
if origin is Callable:
return
if is_union_origin(origin):
if is_union(origin):
types_ = []
for type_ in get_args(self.type_):
if type_ is NoneType:
@@ -935,6 +935,35 @@ def _validate_singleton(
    ) -> 'ValidateReturn':
        if self.sub_fields:
            errors = []

            if self.model_config.smart_union and is_union(get_origin(self.type_)):
                # 1st pass: check if the value is an exact instance of one of the Union types
                # (e.g. to avoid coercing a bool into an int)
                for field in self.sub_fields:
                    if v.__class__ is field.outer_type_:
                        return v, None

                # 2nd pass: check if the value is an instance of any subclass of the Union types
                for field in self.sub_fields:
                    # This whole logic will be improved later on to support more complex `isinstance` checks
                    # It will probably be done once a strict mode is added and be something like:
                    # ```
                    #     value, error = field.validate(v, values, strict=True)
                    #     if error is None:
                    #         return value, None
                    # ```
                    try:
                        if isinstance(v, field.outer_type_):
                            return v, None
                    except TypeError:
                        # compound type
                        if isinstance(v, get_origin(field.outer_type_)):
                            value, error = field.validate(v, values, loc=loc, cls=cls)
                            if not error:
                                return value, None

            # 1st pass by default or 3rd pass with `smart_union` enabled:
            # check if the value can be coerced into one of the Union types
            for field in self.sub_fields:
                value, error = field.validate(v, values, loc=loc, cls=cls)
                if error:
4 changes: 2 additions & 2 deletions pydantic/main.py
@@ -39,7 +39,7 @@
get_origin,
is_classvar,
is_namedtuple,
is_union_origin,
is_union,
resolve_annotations,
update_model_forward_refs,
)
@@ -191,7 +191,7 @@ def is_untouched(v: Any) -> bool:
elif is_valid_field(ann_name):
validate_field_name(bases, ann_name)
value = namespace.get(ann_name, Undefined)
allowed_types = get_args(ann_type) if is_union_origin(get_origin(ann_type)) else (ann_type,)
allowed_types = get_args(ann_type) if is_union(get_origin(ann_type)) else (ann_type,)
if (
is_untouched(value)
and ann_type != PyObject
4 changes: 2 additions & 2 deletions pydantic/schema.py
@@ -72,7 +72,7 @@
is_literal_type,
is_namedtuple,
is_none_type,
is_union_origin,
is_union,
)
from .utils import ROOT_KEY, get_model, lenient_issubclass, sequence_like

@@ -994,7 +994,7 @@ def go(type_: Any) -> Type[Any]:

if origin is Annotated:
return go(args[0])
if is_union_origin(origin):
if is_union(origin):
return Union[tuple(go(a) for a in args)] # type: ignore

if issubclass(origin, List) and (field_info.min_items is not None or field_info.max_items is not None):
8 changes: 4 additions & 4 deletions pydantic/typing.py
@@ -188,7 +188,7 @@ def get_args(tp: Type[Any]) -> Tuple[Any, ...]:

if sys.version_info < (3, 10):

def is_union_origin(tp: Type[Any]) -> bool:
def is_union(tp: Type[Any]) -> bool:
return tp is Union

WithArgsTypes = (TypingGenericAlias,)
@@ -197,8 +197,8 @@ def is_union_origin(tp: Type[Any]) -> bool:
import types
import typing

def is_union_origin(origin: Type[Any]) -> bool:
return origin is Union or origin is types.UnionType # noqa: E721
def is_union(tp: Type[Any]) -> bool:
return tp is Union or tp is types.UnionType # noqa: E721

WithArgsTypes = (typing._GenericAlias, types.GenericAlias, types.UnionType)

@@ -264,7 +264,7 @@ def is_union_origin(origin: Type[Any]) -> bool:
'get_origin',
'typing_base',
'get_all_type_hints',
'is_union_origin',
'is_union',
'StrPath',
)
