Skip to content

Commit

Permalink
feat: support | union operator properly
Browse files Browse the repository at this point in the history
`|` operator has origin `types.Union` (and not `typing.Union`)
  • Loading branch information
PrettyWood committed Jun 6, 2021
1 parent f45a02f commit c9e0604
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 3 deletions.
3 changes: 2 additions & 1 deletion pydantic/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
is_literal_type,
is_new_type,
is_typeddict,
is_union,
new_type_supertype,
)
from .utils import PyObjectStr, Representation, ValueItems, lenient_issubclass, sequence_like, smart_deepcopy
Expand Down Expand Up @@ -556,7 +557,7 @@ def _type_analysis(self) -> None: # noqa: C901 (ignore complexity)
return
if origin is Callable:
return
if origin is Union:
if is_union(origin):
types_ = []
for type_ in get_args(self.type_):
if type_ is NoneType:
Expand Down
3 changes: 2 additions & 1 deletion pydantic/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
get_origin,
is_classvar,
is_namedtuple,
is_union,
resolve_annotations,
update_field_forward_refs,
)
Expand Down Expand Up @@ -175,7 +176,7 @@ def is_untouched(v: Any) -> bool:
elif is_valid_field(ann_name):
validate_field_name(bases, ann_name)
value = namespace.get(ann_name, Undefined)
allowed_types = get_args(ann_type) if get_origin(ann_type) is Union else (ann_type,)
allowed_types = get_args(ann_type) if is_union(get_origin(ann_type)) else (ann_type,)
if (
is_untouched(value)
and ann_type != PyObject
Expand Down
3 changes: 2 additions & 1 deletion pydantic/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
is_callable_type,
is_literal_type,
is_namedtuple,
is_union,
)
from .utils import ROOT_KEY, get_model, lenient_issubclass, sequence_like

Expand Down Expand Up @@ -965,7 +966,7 @@ def go(type_: Any) -> Type[Any]:

if origin is Annotated:
return go(args[0])
if origin is Union:
if is_union(origin):
return Union[tuple(go(a) for a in args)] # type: ignore

if issubclass(origin, List) and (field_info.min_items is not None or field_info.max_items is not None):
Expand Down
14 changes: 14 additions & 0 deletions pydantic/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,19 @@ def get_args(tp: Type[Any]) -> Tuple[Any, ...]:
return _typing_get_args(tp) or getattr(tp, '__args__', ()) or _generic_get_args(tp)


if sys.version_info < (3, 10):

def is_union(tp: Type[Any]) -> bool:
return tp is Union


else:
import types

def is_union(tp: Type[Any]) -> bool:
return tp is Union or tp is types.Union


if TYPE_CHECKING:
from .fields import ModelField

Expand Down Expand Up @@ -238,6 +251,7 @@ def get_args(tp: Type[Any]) -> Tuple[Any, ...]:
'get_origin',
'typing_base',
'get_all_type_hints',
'is_union',
)


Expand Down
18 changes: 18 additions & 0 deletions tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2026,3 +2026,21 @@ class Model(Base, some_config='new_value'):
a: int

assert Model.__config__.some_config == 'new_value'


@pytest.mark.skipif(sys.version_info < (3, 10), reason='need 3.10 version')
def test_new_union_origin():
"""On 3.10+, origin of `int | str` is `types.Union`, not `typing.Union`"""

class Model(BaseModel):
x: int | str

assert Model(x=3).x == 3
assert Model(x='3').x == 3
assert Model(x='pika').x == 'pika'
assert Model.schema() == {
'title': 'Model',
'type': 'object',
'properties': {'x': {'title': 'X', 'anyOf': [{'type': 'integer'}, {'type': 'string'}]}},
'required': ['x'],
}

0 comments on commit c9e0604

Please sign in to comment.