diff --git a/changes/1364-DBCerigo.md b/changes/1364-DBCerigo.md new file mode 100644 index 0000000000..ed31901e41 --- /dev/null +++ b/changes/1364-DBCerigo.md @@ -0,0 +1 @@ +Fix model validation to handle nested literals, e.g. `Literal['foo', Literal['bar']]`. diff --git a/pydantic/typing.py b/pydantic/typing.py index 163185f0c3..ed8c23c71f 100644 --- a/pydantic/typing.py +++ b/pydantic/typing.py @@ -179,6 +179,19 @@ def literal_values(type_: AnyType) -> Tuple[Any, ...]: return type_.__values__ +def all_literal_values(type_: Type[Any]) -> Tuple[Any, ...]: + """ + This method is used to retrieve all Literal values as + Literal can be used recursively (see https://www.python.org/dev/peps/pep-0586) + e.g. `Literal[Literal[Literal[1, 2, 3], "foo"], 5, None]` + """ + if not is_literal_type(type_): + return (type_,) + + values = literal_values(type_) + return tuple(x for value in values for x in all_literal_values(value)) + + test_type = NewType('test_type', str) diff --git a/pydantic/validators.py b/pydantic/validators.py index 7249849dc1..7149f235aa 100644 --- a/pydantic/validators.py +++ b/pydantic/validators.py @@ -1,5 +1,4 @@ import re -import sys from collections import OrderedDict from collections.abc import Hashable from datetime import date, datetime, time, timedelta @@ -26,7 +25,16 @@ from . import errors from .datetime_parse import parse_date, parse_datetime, parse_duration, parse_time -from .typing import AnyCallable, AnyType, ForwardRef, display_as_type, get_class, is_callable_type, is_literal_type +from .typing import ( + AnyCallable, + AnyType, + ForwardRef, + all_literal_values, + display_as_type, + get_class, + is_callable_type, + is_literal_type, +) from .utils import almost_equal_floats, lenient_issubclass, sequence_like if TYPE_CHECKING: @@ -394,10 +402,7 @@ def callable_validator(v: Any) -> AnyCallable: def make_literal_validator(type_: Any) -> Callable[[Any], Any]: - if sys.version_info >= (3, 7): - permitted_choices = type_.__args__ - else: - permitted_choices = type_.__values__ + permitted_choices = all_literal_values(type_) allowed_choices_set = set(permitted_choices) def literal_validator(v: Any) -> Any: diff --git a/tests/test_utils.py b/tests/test_utils.py index feb373ce60..5dba3ebc48 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -11,7 +11,7 @@ from pydantic.color import Color from pydantic.dataclasses import dataclass from pydantic.fields import Undefined -from pydantic.typing import display_as_type, is_new_type, new_type_supertype +from pydantic.typing import Literal, all_literal_values, display_as_type, is_new_type, new_type_supertype from pydantic.utils import ( ClassAttribute, ValueItems, @@ -320,3 +320,16 @@ class Foo: f = Foo() f.attr = 'not foo' assert f.attr == 'not foo' + + +@pytest.mark.skipif(not Literal, reason='typing_extensions not installed') +def test_all_literal_values(): + L1 = Literal['1'] + assert all_literal_values(L1) == ('1',) + + L2 = Literal['2'] + L12 = Literal[L1, L2] + assert sorted(all_literal_values(L12)) == sorted(('1', '2')) + + L312 = Literal['3', Literal[L1, L2]] + assert sorted(all_literal_values(L312)) == sorted(('1', '2', '3')) diff --git a/tests/test_validators.py b/tests/test_validators.py index 386fbbeb77..2ff9359765 100644 --- a/tests/test_validators.py +++ b/tests/test_validators.py @@ -6,6 +6,7 @@ from pydantic import BaseModel, ConfigError, Extra, ValidationError, errors, validator from pydantic.class_validators import make_generic_validator, root_validator +from pydantic.typing import Literal def test_simple(): @@ -1038,3 +1039,44 @@ def check_foo(cls, value): m = Model(name='hello') m.name = 'goodbye' assert validator_calls == 2 + + +@pytest.mark.skipif(not Literal, reason='typing_extensions not installed') +def test_literal_validator(): + class Model(BaseModel): + a: Literal['foo'] + + Model(a='foo') + + with pytest.raises(ValidationError) as exc_info: + Model(a='nope') + assert exc_info.value.errors() == [ + { + 'loc': ('a',), + 'msg': "unexpected value; permitted: 'foo'", + 'type': 'value_error.const', + 'ctx': {'given': 'nope', 'permitted': ('foo',)}, + } + ] + + +@pytest.mark.skipif(not Literal, reason='typing_extensions not installed') +def test_nested_literal_validator(): + L1 = Literal['foo'] + L2 = Literal['bar'] + + class Model(BaseModel): + a: Literal[L1, L2] + + Model(a='foo') + + with pytest.raises(ValidationError) as exc_info: + Model(a='nope') + assert exc_info.value.errors() == [ + { + 'loc': ('a',), + 'msg': "unexpected value; permitted: 'foo', 'bar'", + 'type': 'value_error.const', + 'ctx': {'given': 'nope', 'permitted': ('foo', 'bar')}, + } + ]