From 7b8918e5080add32ac0229d71589931fcca8d669 Mon Sep 17 00:00:00 2001 From: PrettyWood Date: Wed, 20 Jan 2021 13:37:41 +0100 Subject: [PATCH] feat: add smart Union --- .github/workflows/ci.yml | 2 +- pydantic/fields.py | 35 ++++++++++++++ pydantic/main.py | 1 + pydantic/version.py | 2 +- setup.py | 1 + tests/requirements-linting.txt | 3 +- tests/test_types.py | 88 ++++++++++++++++++++++++++++++++++ 7 files changed, 129 insertions(+), 3 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 5432bef3462..65545666ba3 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -99,7 +99,7 @@ jobs: CONTEXT: linux-py${{ matrix.python-version }}-compiled-yes-deps-yes - name: uninstall deps - run: pip uninstall -y cython email-validator devtools python-dotenv + run: pip uninstall -y cython email-validator devtools python-dotenv typingx - name: test compiled without deps run: make test diff --git a/pydantic/fields.py b/pydantic/fields.py index 449dd55d4c2..6dcb95803cd 100644 --- a/pydantic/fields.py +++ b/pydantic/fields.py @@ -68,6 +68,8 @@ def __deepcopy__(self: T, _: Any) -> T: Undefined = UndefinedType() if TYPE_CHECKING: + from typingx.typing_compat import OneOrManyTypes + from .class_validators import ValidatorsList # noqa: F401 from .error_wrappers import ErrorList from .main import BaseConfig, BaseModel # noqa: F401 @@ -895,6 +897,39 @@ def _validate_singleton( ) -> 'ValidateReturn': if self.sub_fields: errors = [] + + if get_origin(self.type_) is Union and self.model_config.smart_union: + try: + from typingx import isinstancex + + except ImportError: + import warnings + + warnings.warn( + 'Smart Union will not be able to work with typing types. ' + 'You should install `typingx` for that.', + UserWarning, + ) + + def isinstancex(obj: Any, tp: 'OneOrManyTypes') -> bool: + try: + return isinstance(obj, tp) + except TypeError: + return False + + # 1st pass: check if the value is an exact instance of one of the Union types + # (e.g. to avoid coercing a bool into an int) + for field in self.sub_fields: + if v.__class__ is field.outer_type_: + return v, None + + # 2nd pass: check if the value is an instance of any subclass of the Union types + for field in self.sub_fields: + if isinstancex(v, field.outer_type_): + return v, None + + # 1st pass by default or 3rd pass with `smart_union` enabled: + # check if the value can be coerced into one of the Union types for field in self.sub_fields: value, error = field.validate(v, values, loc=loc, cls=cls) if error: diff --git a/pydantic/main.py b/pydantic/main.py index f6aca41048f..d1d175221da 100644 --- a/pydantic/main.py +++ b/pydantic/main.py @@ -136,6 +136,7 @@ class BaseConfig: json_dumps: Callable[..., str] = json.dumps json_encoders: Dict[Type[Any], AnyCallable] = {} underscore_attrs_are_private: bool = False + smart_union: bool = False # Whether or not inherited models as fields should be reconstructed as base model copy_on_model_validation: bool = True diff --git a/pydantic/version.py b/pydantic/version.py index 0614cbe3627..9cfad87748f 100644 --- a/pydantic/version.py +++ b/pydantic/version.py @@ -12,7 +12,7 @@ def version_info() -> str: from .main import compiled optional_deps = [] - for p in ('devtools', 'dotenv', 'email-validator', 'typing-extensions'): + for p in ('devtools', 'dotenv', 'email-validator', 'typing-extensions', 'typingx'): try: import_module(p.replace('-', '_')) except ImportError: diff --git a/setup.py b/setup.py index 52baae27895..6a19bb76dc8 100644 --- a/setup.py +++ b/setup.py @@ -133,6 +133,7 @@ def extra(self): extras_require={ 'email': ['email-validator>=1.0.3'], 'dotenv': ['python-dotenv>=0.10.4'], + 'typingx': ['typingx>=0.5.3'], }, ext_modules=ext_modules, entry_points={'hypothesis': ['_ = pydantic._hypothesis_plugin']}, diff --git a/tests/requirements-linting.txt b/tests/requirements-linting.txt index 57221b6a5b9..182e012bd6e 100644 --- a/tests/requirements-linting.txt +++ b/tests/requirements-linting.txt @@ -6,4 +6,5 @@ isort==5.8.0 mypy==0.812 pycodestyle==2.7.0 pyflakes==2.3.1 -twine==3.4.1 \ No newline at end of file +twine==3.4.1 +typingx==0.5.3 \ No newline at end of file diff --git a/tests/test_types.py b/tests/test_types.py index 4b6ef724094..e881de7b19c 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -15,6 +15,7 @@ Iterable, Iterator, List, + Mapping, MutableSet, NewType, Optional, @@ -22,6 +23,7 @@ Sequence, Set, Tuple, + Union, ) from uuid import UUID @@ -77,6 +79,11 @@ except ImportError: email_validator = None +try: + import typingx +except ImportError: + typingx = None + class ConBytesModel(BaseModel): v: conbytes(max_length=10) = b'foobar' @@ -2774,3 +2781,84 @@ class Model(BaseModel): {'loc': ('my_none_dict', 'a'), 'msg': 'value is not None', 'type': 'type_error.not_none'}, {'loc': ('my_json_none',), 'msg': 'value is not None', 'type': 'type_error.not_none'}, ] + + +def test_default_union(): + class DefaultModel(BaseModel): + v: Union[int, bool, str] + + assert DefaultModel(v=True).json() == '{"v": 1}' + assert DefaultModel(v=1).json() == '{"v": 1}' + assert DefaultModel(v='1').json() == '{"v": 1}' + + # In 3.6, Union[int, bool, str] == Union[int, str] + allowed_json_types = ('integer', 'string') if sys.version_info[:2] == (3, 6) else ('integer', 'boolean', 'string') + + assert DefaultModel.schema() == { + 'title': 'DefaultModel', + 'type': 'object', + 'properties': {'v': {'title': 'V', 'anyOf': [{'type': t} for t in allowed_json_types]}}, + 'required': ['v'], + } + + +def test_smart_union(): + class SmartModel(BaseModel): + v: Union[int, bool, str] + + class Config: + smart_union = True + + if typingx is None: + with pytest.warns(UserWarning, match='Smart Union will not be able to work with typing types'): + assert SmartModel(v=1).json() == '{"v": 1}' + assert SmartModel(v=True).json() == '{"v": true}' + assert SmartModel(v='1').json() == '{"v": "1"}' + else: + assert SmartModel(v=1).json() == '{"v": 1}' + assert SmartModel(v=True).json() == '{"v": true}' + assert SmartModel(v='1').json() == '{"v": "1"}' + + # In 3.6, Union[int, bool, str] == Union[int, str] + allowed_json_types = ('integer', 'string') if sys.version_info[:2] == (3, 6) else ('integer', 'boolean', 'string') + + assert SmartModel.schema() == { + 'title': 'SmartModel', + 'type': 'object', + 'properties': {'v': {'title': 'V', 'anyOf': [{'type': t} for t in allowed_json_types]}}, + 'required': ['v'], + } + + +def test_default_union_complex(): + class DefaultModel(BaseModel): + values: Union[Dict[str, str], List[str]] + + assert DefaultModel(values={'L': '1'}).json() == '{"values": {"L": "1"}}' + assert DefaultModel(values=['L1']).json() == '{"values": {"L": "1"}}' # dict(['L1']) == {'L': '1'} + + +@pytest.mark.skipif(not typingx, reason='typingx is not installed') +def test_smart_union_complex(): + class DefaultModel(BaseModel): + values: Union[Dict[str, str], List[str]] + + class Config: + smart_union = True + + assert DefaultModel(values={'L': '1'}).json() == '{"values": {"L": "1"}}' + assert DefaultModel(values=['L1']).json() == '{"values": ["L1"]}' + assert DefaultModel(values=('L1',)).json() == '{"values": {"L": "1"}}' # still coerce as tuple is not a list + + +@pytest.mark.skipif(not typingx, reason='typingx is not installed') +def test_smart_union_complex_2(): + class DefaultModel(BaseModel): + values: Union[Mapping[str, str], Sequence[str]] + + class Config: + smart_union = True + + assert DefaultModel(values={'L': '1'}).json() == '{"values": {"L": "1"}}' + assert DefaultModel(values=['L1']).json() == '{"values": ["L1"]}' + assert DefaultModel(values=('L1',)).json() == '{"values": ["L1"]}'