diff --git a/pydantic/typing.py b/pydantic/typing.py index 684c6fd02b3..31c8f74ea21 100644 --- a/pydantic/typing.py +++ b/pydantic/typing.py @@ -177,6 +177,19 @@ def literal_values(type_: AnyType) -> Tuple[Any, ...]: return type_.__values__ +def all_literal_values(type_: Type[Any]) -> Tuple[Any, ...]: + """ + This method is used to retrieve all Literal values as + Literal can be used recursively (see https://www.python.org/dev/peps/pep-0586) + e.g. `Literal[Literal[Literal[1, 2, 3], "foo"], 5, None]` + """ + if not is_literal_type(type_): + return (type_,) + + values = literal_values(type_) + return tuple(x for value in values for x in all_literal_values(value)) + + test_type = NewType('test_type', str) diff --git a/pydantic/validators.py b/pydantic/validators.py index d36e16de7c4..0a45cbab4be 100644 --- a/pydantic/validators.py +++ b/pydantic/validators.py @@ -28,11 +28,11 @@ AnyCallable, AnyType, ForwardRef, + all_literal_values, display_as_type, get_class, is_callable_type, is_literal_type, - literal_values, ) from .utils import almost_equal_floats, lenient_issubclass, sequence_like @@ -393,18 +393,8 @@ def callable_validator(v: Any) -> AnyCallable: raise errors.CallableError(value=v) -def flatten_literal(s: Any) -> List[Any]: - if is_literal_type(s): - s = list(literal_values(s)) - if s == []: - return s - if is_literal_type(s[0]): - return flatten_literal(s[0]) + flatten_literal(s[1:]) - return s[:1] + flatten_literal(s[1:]) - - def make_literal_validator(type_: Any) -> Callable[[Any], Any]: - permitted_choices: Tuple[Any, ...] = tuple(flatten_literal(type_)) + permitted_choices: Tuple[Any, ...] = all_literal_values(type_) allowed_choices_set = set(permitted_choices) def literal_validator(v: Any) -> Any: diff --git a/tests/test_validators.py b/tests/test_validators.py index 470ed4a2bb9..4b2fe35b0a6 100644 --- a/tests/test_validators.py +++ b/tests/test_validators.py @@ -7,7 +7,7 @@ from pydantic import BaseModel, ConfigError, Extra, ValidationError, errors, validator from pydantic.class_validators import make_generic_validator, root_validator from pydantic.typing import Literal -from pydantic.validators import flatten_literal +from pydantic.validators import all_literal_values def test_simple(): @@ -1086,11 +1086,11 @@ class Model(BaseModel): @pytest.mark.skipif(not Literal, reason='typing_extensions not installed') def test_flatten_nested_literals(): L1 = Literal['1'] - assert flatten_literal(L1) == ['1'] + assert all_literal_values(L1) == ('1',) L2 = Literal['2'] L12 = Literal[L1, L2] - assert sorted(flatten_literal(L12)) == sorted(['1', '2']) + assert sorted(all_literal_values(L12)) == sorted(('1', '2')) L312 = Literal['3', Literal[L1, L2]] - assert sorted(flatten_literal(L312)) == sorted(['1', '2', '3']) + assert sorted(all_literal_values(L312)) == sorted(('1', '2', '3'))