Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add Config.smart_union option #2092

Merged
merged 14 commits into from Dec 7, 2021
1 change: 1 addition & 0 deletions changes/2092-PrettyWood.md
@@ -0,0 +1 @@
add `Config.smart_union` to prevent coercion in `Union` if possible. See [the doc](https://pydantic-docs.helpmanual.io/usage/model_config/#smart-union) for more information
19 changes: 19 additions & 0 deletions docs/examples/model_config_smart_union_off.py
@@ -0,0 +1,19 @@
from typing import Union

from pydantic import BaseModel


class Foo(BaseModel):
pass


class Bar(BaseModel):
pass


class Model(BaseModel):
x: Union[str, int]
y: Union[Foo, Bar]


print(Model(x=1, y=Bar()))
22 changes: 22 additions & 0 deletions docs/examples/model_config_smart_union_on.py
@@ -0,0 +1,22 @@
from typing import Union

from pydantic import BaseModel


class Foo(BaseModel):
pass


class Bar(BaseModel):
pass


class Model(BaseModel):
x: Union[str, int]
y: Union[Foo, Bar]

class Config:
smart_union = True


print(Model(x=1, y=Bar()))
14 changes: 14 additions & 0 deletions docs/examples/model_config_smart_union_on_edge_case.py
@@ -0,0 +1,14 @@
from typing import List, Union

from pydantic import BaseModel


class Model(BaseModel, smart_union=True):
x: Union[List[str], List[int]]


# Expected coercion
print(Model(x=[1, '2']))

# Unexpected coercion
print(Model(x=[1, 2]))
32 changes: 31 additions & 1 deletion docs/usage/model_config.md
Expand Up @@ -113,7 +113,10 @@ not be included in the model schemas. **Note**: this means that attributes on th
: whether to treat any underscore non-class var attrs as private, or leave them as is; See [Private model attributes](models.md#private-model-attributes)

**`copy_on_model_validation`**
: whether or not inherited models used as fields should be reconstructed (copied) on validation instead of being kept untouched (default: `True`)
: whether inherited models used as fields should be reconstructed (copied) on validation instead of being kept untouched (default: `True`)

**`smart_union`**
: whether _pydantic_ should try to check all types inside `Union` to prevent undesired coercion (see [the dedicated section](#smart-union)

## Change behaviour globally

Expand Down Expand Up @@ -164,3 +167,30 @@ For example:
{!.tmp_examples/model_config_alias_precedence.py!}
```
_(This script is complete, it should run "as is")_

## Smart Union

By default, as explained [here](types.md#unions), _pydantic_ tries to validate (and coerce if it can) in the order of the `Union`.
So sometimes you may have unexpected coerced data.

```py
{!.tmp_examples/model_config_smart_union_off.py!}
```
_(This script is complete, it should run "as is")_

To prevent this, you can enable `Config.smart_union`. _Pydantic_ will then check all allowed types before even trying to coerce.
Know that this is of course slower, especially if your `Union` is quite big.

```py
{!.tmp_examples/model_config_smart_union_on.py!}
```
_(This script is complete, it should run "as is")_

!!! warning
Note that this option **does not support compound types yet** (e.g. differentiate `List[int]` and `List[str]`).
This option will be improved further once a strict mode is added in _pydantic_ and will probably be the default behaviour in v2!

```py
{!.tmp_examples/model_config_smart_union_on_edge_case.py!}
```
_(This script is complete, it should run "as is")_
20 changes: 16 additions & 4 deletions docs/usage/types.md
Expand Up @@ -234,21 +234,33 @@ _(This script is complete, it should run "as is")_

The `Union` type allows a model attribute to accept different types, e.g.:

!!! warning
This script is complete, it should run "as is". However, it may not reflect the desired behavior; see below.
!!! info
You may get unexpected coercion with `Union`; see below.<br />
Know that you can also make the check slower but stricter by using [Smart Union](model_config.md#smart-union)

```py
{!.tmp_examples/types_union_incorrect.py!}
```
_(This script is complete, it should run "as is")_

However, as can be seen above, *pydantic* will attempt to 'match' any of the types defined under `Union` and will use
the first one that matches. In the above example the `id` of `user_03` was defined as a `uuid.UUID` class (which
is defined under the attribute's `Union` annotation) but as the `uuid.UUID` can be marshalled into an `int` it
chose to match against the `int` type and disregarded the other types.

!!! warning
`typing.Union` also ignores order when [defined](https://docs.python.org/3/library/typing.html#typing.Union),
so `Union[int, float] == Union[float, int]` which can lead to unexpected behaviour
when combined with matching based on the `Union` type order inside other type definitions, such as `List` and `Dict`
types (because python treats these definitions as singletons).
For example, `Dict[str, Union[int, float]] == Dict[str, Union[float, int]]` with the order based on the first time it was defined.
Please note that this can also be [affected by third party libraries](https://github.com/samuelcolvin/pydantic/issues/2835)
and their internal type definitions and the import orders.

As such, it is recommended that, when defining `Union` annotations, the most specific type is included first and
followed by less specific types. In the above example, the `UUID` class should precede the `int` and `str`
classes to preclude the unexpected representation as such:
followed by less specific types.

In the above example, the `UUID` class should precede the `int` and `str` classes to preclude the unexpected representation as such:

```py
{!.tmp_examples/types_union_correct.py!}
Expand Down
4 changes: 3 additions & 1 deletion pydantic/config.py
Expand Up @@ -63,8 +63,10 @@ class BaseConfig:
json_encoders: Dict[Type[Any], AnyCallable] = {}
underscore_attrs_are_private: bool = False

# Whether or not inherited models as fields should be reconstructed as base model
# whether inherited models as fields should be reconstructed as base model
copy_on_model_validation: bool = True
# whether `Union` should check all allowed types before even trying to coerce
smart_union: bool = False

@classmethod
def get_field_info(cls, name: str) -> Dict[str, Any]:
Expand Down
6 changes: 2 additions & 4 deletions pydantic/env_settings.py
Expand Up @@ -6,7 +6,7 @@
from .config import BaseConfig, Extra
from .fields import ModelField
from .main import BaseModel
from .typing import StrPath, display_as_type, get_origin, is_union_origin
from .typing import StrPath, display_as_type, get_origin, is_union
from .utils import deep_update, path_type, sequence_like

env_file_sentinel = str(object())
Expand Down Expand Up @@ -175,9 +175,7 @@ def __call__(self, settings: BaseSettings) -> Dict[str, Any]:
except ValueError as e:
raise SettingsError(f'error parsing JSON for "{env_name}"') from e
elif (
is_union_origin(get_origin(field.type_))
and field.sub_fields
and any(f.is_complex() for f in field.sub_fields)
is_union(get_origin(field.type_)) and field.sub_fields and any(f.is_complex() for f in field.sub_fields)
):
try:
env_val = settings.__config__.json_loads(env_val)
Expand Down
33 changes: 31 additions & 2 deletions pydantic/fields.py
Expand Up @@ -42,7 +42,7 @@
is_new_type,
is_none_type,
is_typeddict,
is_union_origin,
is_union,
new_type_supertype,
)
from .utils import PyObjectStr, Representation, ValueItems, lenient_issubclass, sequence_like, smart_deepcopy
Expand Down Expand Up @@ -560,7 +560,7 @@ def _type_analysis(self) -> None: # noqa: C901 (ignore complexity)
return
if origin is Callable:
return
if is_union_origin(origin):
if is_union(origin):
types_ = []
for type_ in get_args(self.type_):
if type_ is NoneType:
Expand Down Expand Up @@ -935,6 +935,35 @@ def _validate_singleton(
) -> 'ValidateReturn':
if self.sub_fields:
errors = []

if self.model_config.smart_union and is_union(get_origin(self.type_)):
# 1st pass: check if the value is an exact instance of one of the Union types
# (e.g. to avoid coercing a bool into an int)
for field in self.sub_fields:
if v.__class__ is field.outer_type_:
return v, None

# 2nd pass: check if the value is an instance of any subclass of the Union types
for field in self.sub_fields:
# This whole logic will be improved later on to support more complex `isinstance` checks
# It will probably be done once a strict mode is added and be something like:
# ```
# value, error = field.validate(v, values, strict=True)
# if error is None:
# return value, None
# ```
try:
if isinstance(v, field.outer_type_):
return v, None
except TypeError:
# compound type
if isinstance(v, get_origin(field.outer_type_)):
value, error = field.validate(v, values, loc=loc, cls=cls)
if not error:
return value, None

# 1st pass by default or 3rd pass with `smart_union` enabled:
# check if the value can be coerced into one of the Union types
for field in self.sub_fields:
value, error = field.validate(v, values, loc=loc, cls=cls)
if error:
Expand Down
4 changes: 2 additions & 2 deletions pydantic/main.py
Expand Up @@ -39,7 +39,7 @@
get_origin,
is_classvar,
is_namedtuple,
is_union_origin,
is_union,
resolve_annotations,
update_model_forward_refs,
)
Expand Down Expand Up @@ -191,7 +191,7 @@ def is_untouched(v: Any) -> bool:
elif is_valid_field(ann_name):
validate_field_name(bases, ann_name)
value = namespace.get(ann_name, Undefined)
allowed_types = get_args(ann_type) if is_union_origin(get_origin(ann_type)) else (ann_type,)
allowed_types = get_args(ann_type) if is_union(get_origin(ann_type)) else (ann_type,)
if (
is_untouched(value)
and ann_type != PyObject
Expand Down
4 changes: 2 additions & 2 deletions pydantic/schema.py
Expand Up @@ -72,7 +72,7 @@
is_literal_type,
is_namedtuple,
is_none_type,
is_union_origin,
is_union,
)
from .utils import ROOT_KEY, get_model, lenient_issubclass, sequence_like

Expand Down Expand Up @@ -995,7 +995,7 @@ def go(type_: Any) -> Type[Any]:

if origin is Annotated:
return go(args[0])
if is_union_origin(origin):
if is_union(origin):
return Union[tuple(go(a) for a in args)] # type: ignore

if issubclass(origin, List) and (field_info.min_items is not None or field_info.max_items is not None):
Expand Down
8 changes: 4 additions & 4 deletions pydantic/typing.py
Expand Up @@ -193,7 +193,7 @@ def get_args(tp: Type[Any]) -> Tuple[Any, ...]:

if sys.version_info < (3, 10):

def is_union_origin(tp: Type[Any]) -> bool:
def is_union(tp: Type[Any]) -> bool:
return tp is Union

WithArgsTypes = (TypingGenericAlias,)
Expand All @@ -202,8 +202,8 @@ def is_union_origin(tp: Type[Any]) -> bool:
import types
import typing

def is_union_origin(origin: Type[Any]) -> bool:
return origin is Union or origin is types.UnionType # noqa: E721
def is_union(tp: Type[Any]) -> bool:
return tp is Union or tp is types.UnionType # noqa: E721

WithArgsTypes = (typing._GenericAlias, types.GenericAlias, types.UnionType)

Expand Down Expand Up @@ -269,7 +269,7 @@ def is_union_origin(origin: Type[Any]) -> bool:
'get_origin',
'typing_base',
'get_all_type_hints',
'is_union_origin',
'is_union',
'StrPath',
)

Expand Down