Skip to content

Commit

Permalink
feat: add smart Union
Browse files Browse the repository at this point in the history
  • Loading branch information
PrettyWood committed Apr 7, 2021
1 parent 14f055e commit 7b8918e
Show file tree
Hide file tree
Showing 7 changed files with 129 additions and 3 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Expand Up @@ -99,7 +99,7 @@ jobs:
CONTEXT: linux-py${{ matrix.python-version }}-compiled-yes-deps-yes

- name: uninstall deps
run: pip uninstall -y cython email-validator devtools python-dotenv
run: pip uninstall -y cython email-validator devtools python-dotenv typingx

- name: test compiled without deps
run: make test
Expand Down
35 changes: 35 additions & 0 deletions pydantic/fields.py
Expand Up @@ -68,6 +68,8 @@ def __deepcopy__(self: T, _: Any) -> T:
Undefined = UndefinedType()

if TYPE_CHECKING:
from typingx.typing_compat import OneOrManyTypes

from .class_validators import ValidatorsList # noqa: F401
from .error_wrappers import ErrorList
from .main import BaseConfig, BaseModel # noqa: F401
Expand Down Expand Up @@ -895,6 +897,39 @@ def _validate_singleton(
) -> 'ValidateReturn':
if self.sub_fields:
errors = []

if get_origin(self.type_) is Union and self.model_config.smart_union:
try:
from typingx import isinstancex

except ImportError:
import warnings

warnings.warn(
'Smart Union will not be able to work with typing types. '
'You should install `typingx` for that.',
UserWarning,
)

def isinstancex(obj: Any, tp: 'OneOrManyTypes') -> bool:
try:
return isinstance(obj, tp)
except TypeError:
return False

# 1st pass: check if the value is an exact instance of one of the Union types
# (e.g. to avoid coercing a bool into an int)
for field in self.sub_fields:
if v.__class__ is field.outer_type_:
return v, None

# 2nd pass: check if the value is an instance of any subclass of the Union types
for field in self.sub_fields:
if isinstancex(v, field.outer_type_):
return v, None

# 1st pass by default or 3rd pass with `smart_union` enabled:
# check if the value can be coerced into one of the Union types
for field in self.sub_fields:
value, error = field.validate(v, values, loc=loc, cls=cls)
if error:
Expand Down
1 change: 1 addition & 0 deletions pydantic/main.py
Expand Up @@ -136,6 +136,7 @@ class BaseConfig:
json_dumps: Callable[..., str] = json.dumps
json_encoders: Dict[Type[Any], AnyCallable] = {}
underscore_attrs_are_private: bool = False
smart_union: bool = False

# Whether or not inherited models as fields should be reconstructed as base model
copy_on_model_validation: bool = True
Expand Down
2 changes: 1 addition & 1 deletion pydantic/version.py
Expand Up @@ -12,7 +12,7 @@ def version_info() -> str:
from .main import compiled

optional_deps = []
for p in ('devtools', 'dotenv', 'email-validator', 'typing-extensions'):
for p in ('devtools', 'dotenv', 'email-validator', 'typing-extensions', 'typingx'):
try:
import_module(p.replace('-', '_'))
except ImportError:
Expand Down
1 change: 1 addition & 0 deletions setup.py
Expand Up @@ -133,6 +133,7 @@ def extra(self):
extras_require={
'email': ['email-validator>=1.0.3'],
'dotenv': ['python-dotenv>=0.10.4'],
'typingx': ['typingx>=0.5.3'],
},
ext_modules=ext_modules,
entry_points={'hypothesis': ['_ = pydantic._hypothesis_plugin']},
Expand Down
3 changes: 2 additions & 1 deletion tests/requirements-linting.txt
Expand Up @@ -6,4 +6,5 @@ isort==5.8.0
mypy==0.812
pycodestyle==2.7.0
pyflakes==2.3.1
twine==3.4.1
twine==3.4.1
typingx==0.5.3
88 changes: 88 additions & 0 deletions tests/test_types.py
Expand Up @@ -15,13 +15,15 @@
Iterable,
Iterator,
List,
Mapping,
MutableSet,
NewType,
Optional,
Pattern,
Sequence,
Set,
Tuple,
Union,
)
from uuid import UUID

Expand Down Expand Up @@ -77,6 +79,11 @@
except ImportError:
email_validator = None

try:
import typingx
except ImportError:
typingx = None


class ConBytesModel(BaseModel):
v: conbytes(max_length=10) = b'foobar'
Expand Down Expand Up @@ -2774,3 +2781,84 @@ class Model(BaseModel):
{'loc': ('my_none_dict', 'a'), 'msg': 'value is not None', 'type': 'type_error.not_none'},
{'loc': ('my_json_none',), 'msg': 'value is not None', 'type': 'type_error.not_none'},
]


def test_default_union():
class DefaultModel(BaseModel):
v: Union[int, bool, str]

assert DefaultModel(v=True).json() == '{"v": 1}'
assert DefaultModel(v=1).json() == '{"v": 1}'
assert DefaultModel(v='1').json() == '{"v": 1}'

# In 3.6, Union[int, bool, str] == Union[int, str]
allowed_json_types = ('integer', 'string') if sys.version_info[:2] == (3, 6) else ('integer', 'boolean', 'string')

assert DefaultModel.schema() == {
'title': 'DefaultModel',
'type': 'object',
'properties': {'v': {'title': 'V', 'anyOf': [{'type': t} for t in allowed_json_types]}},
'required': ['v'],
}


def test_smart_union():
class SmartModel(BaseModel):
v: Union[int, bool, str]

class Config:
smart_union = True

if typingx is None:
with pytest.warns(UserWarning, match='Smart Union will not be able to work with typing types'):
assert SmartModel(v=1).json() == '{"v": 1}'
assert SmartModel(v=True).json() == '{"v": true}'
assert SmartModel(v='1').json() == '{"v": "1"}'
else:
assert SmartModel(v=1).json() == '{"v": 1}'
assert SmartModel(v=True).json() == '{"v": true}'
assert SmartModel(v='1').json() == '{"v": "1"}'

# In 3.6, Union[int, bool, str] == Union[int, str]
allowed_json_types = ('integer', 'string') if sys.version_info[:2] == (3, 6) else ('integer', 'boolean', 'string')

assert SmartModel.schema() == {
'title': 'SmartModel',
'type': 'object',
'properties': {'v': {'title': 'V', 'anyOf': [{'type': t} for t in allowed_json_types]}},
'required': ['v'],
}


def test_default_union_complex():
class DefaultModel(BaseModel):
values: Union[Dict[str, str], List[str]]

assert DefaultModel(values={'L': '1'}).json() == '{"values": {"L": "1"}}'
assert DefaultModel(values=['L1']).json() == '{"values": {"L": "1"}}' # dict(['L1']) == {'L': '1'}


@pytest.mark.skipif(not typingx, reason='typingx is not installed')
def test_smart_union_complex():
class DefaultModel(BaseModel):
values: Union[Dict[str, str], List[str]]

class Config:
smart_union = True

assert DefaultModel(values={'L': '1'}).json() == '{"values": {"L": "1"}}'
assert DefaultModel(values=['L1']).json() == '{"values": ["L1"]}'
assert DefaultModel(values=('L1',)).json() == '{"values": {"L": "1"}}' # still coerce as tuple is not a list


@pytest.mark.skipif(not typingx, reason='typingx is not installed')
def test_smart_union_complex_2():
class DefaultModel(BaseModel):
values: Union[Mapping[str, str], Sequence[str]]

class Config:
smart_union = True

assert DefaultModel(values={'L': '1'}).json() == '{"values": {"L": "1"}}'
assert DefaultModel(values=['L1']).json() == '{"values": ["L1"]}'
assert DefaultModel(values=('L1',)).json() == '{"values": ["L1"]}'

0 comments on commit 7b8918e

Please sign in to comment.