Skip to content

Commit

Permalink
Refactor: use improved all literals implementation
Browse files Browse the repository at this point in the history
From Github user PrettyWood, see PR pydantic#1364
  • Loading branch information
DBCerigo committed Apr 26, 2020
1 parent 47edf04 commit 291cc3b
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 16 deletions.
13 changes: 13 additions & 0 deletions pydantic/typing.py
Expand Up @@ -177,6 +177,19 @@ def literal_values(type_: AnyType) -> Tuple[Any, ...]:
return type_.__values__


def all_literal_values(type_: Type[Any]) -> Tuple[Any, ...]:
"""
This method is used to retrieve all Literal values as
Literal can be used recursively (see https://www.python.org/dev/peps/pep-0586)
e.g. `Literal[Literal[Literal[1, 2, 3], "foo"], 5, None]`
"""
if not is_literal_type(type_):
return (type_,)

values = literal_values(type_)
return tuple(x for value in values for x in all_literal_values(value))


test_type = NewType('test_type', str)


Expand Down
14 changes: 2 additions & 12 deletions pydantic/validators.py
Expand Up @@ -28,11 +28,11 @@
AnyCallable,
AnyType,
ForwardRef,
all_literal_values,
display_as_type,
get_class,
is_callable_type,
is_literal_type,
literal_values,
)
from .utils import almost_equal_floats, lenient_issubclass, sequence_like

Expand Down Expand Up @@ -393,18 +393,8 @@ def callable_validator(v: Any) -> AnyCallable:
raise errors.CallableError(value=v)


def flatten_literal(s: Any) -> List[Any]:
if is_literal_type(s):
s = list(literal_values(s))
if s == []:
return s
if is_literal_type(s[0]):
return flatten_literal(s[0]) + flatten_literal(s[1:])
return s[:1] + flatten_literal(s[1:])


def make_literal_validator(type_: Any) -> Callable[[Any], Any]:
permitted_choices: Tuple[Any, ...] = tuple(flatten_literal(type_))
permitted_choices: Tuple[Any, ...] = all_literal_values(type_)
allowed_choices_set = set(permitted_choices)

def literal_validator(v: Any) -> Any:
Expand Down
8 changes: 4 additions & 4 deletions tests/test_validators.py
Expand Up @@ -7,7 +7,7 @@
from pydantic import BaseModel, ConfigError, Extra, ValidationError, errors, validator
from pydantic.class_validators import make_generic_validator, root_validator
from pydantic.typing import Literal
from pydantic.validators import flatten_literal
from pydantic.validators import all_literal_values


def test_simple():
Expand Down Expand Up @@ -1086,11 +1086,11 @@ class Model(BaseModel):
@pytest.mark.skipif(not Literal, reason='typing_extensions not installed')
def test_flatten_nested_literals():
L1 = Literal['1']
assert flatten_literal(L1) == ['1']
assert all_literal_values(L1) == ('1',)

L2 = Literal['2']
L12 = Literal[L1, L2]
assert sorted(flatten_literal(L12)) == sorted(['1', '2'])
assert sorted(all_literal_values(L12)) == sorted(('1', '2'))

L312 = Literal['3', Literal[L1, L2]]
assert sorted(flatten_literal(L312)) == sorted(['1', '2', '3'])
assert sorted(all_literal_values(L312)) == sorted(('1', '2', '3'))

0 comments on commit 291cc3b

Please sign in to comment.