Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FIX: validation on model attribute with nested Literal breaks #1364

Merged
merged 9 commits into from May 23, 2020
1 change: 1 addition & 0 deletions changes/1364-DBCerigo.md
@@ -0,0 +1 @@
Fix model validation to handle nested literals, e.g. `Literal['foo', Literal['bar']]`.
13 changes: 13 additions & 0 deletions pydantic/typing.py
Expand Up @@ -179,6 +179,19 @@ def literal_values(type_: AnyType) -> Tuple[Any, ...]:
return type_.__values__


def all_literal_values(type_: Type[Any]) -> Tuple[Any, ...]:
"""
This method is used to retrieve all Literal values as
Literal can be used recursively (see https://www.python.org/dev/peps/pep-0586)
e.g. `Literal[Literal[Literal[1, 2, 3], "foo"], 5, None]`
"""
if not is_literal_type(type_):
return (type_,)

values = literal_values(type_)
return tuple(x for value in values for x in all_literal_values(value))


test_type = NewType('test_type', str)


Expand Down
17 changes: 11 additions & 6 deletions pydantic/validators.py
@@ -1,5 +1,4 @@
import re
import sys
from collections import OrderedDict
from collections.abc import Hashable
from datetime import date, datetime, time, timedelta
Expand All @@ -26,7 +25,16 @@

from . import errors
from .datetime_parse import parse_date, parse_datetime, parse_duration, parse_time
from .typing import AnyCallable, AnyType, ForwardRef, display_as_type, get_class, is_callable_type, is_literal_type
from .typing import (
AnyCallable,
AnyType,
ForwardRef,
all_literal_values,
display_as_type,
get_class,
is_callable_type,
is_literal_type,
)
from .utils import almost_equal_floats, lenient_issubclass, sequence_like

if TYPE_CHECKING:
Expand Down Expand Up @@ -394,10 +402,7 @@ def callable_validator(v: Any) -> AnyCallable:


def make_literal_validator(type_: Any) -> Callable[[Any], Any]:
if sys.version_info >= (3, 7):
permitted_choices = type_.__args__
else:
permitted_choices = type_.__values__
permitted_choices = all_literal_values(type_)
allowed_choices_set = set(permitted_choices)

def literal_validator(v: Any) -> Any:
Expand Down
15 changes: 14 additions & 1 deletion tests/test_utils.py
Expand Up @@ -11,7 +11,7 @@
from pydantic.color import Color
from pydantic.dataclasses import dataclass
from pydantic.fields import Undefined
from pydantic.typing import display_as_type, is_new_type, new_type_supertype
from pydantic.typing import Literal, all_literal_values, display_as_type, is_new_type, new_type_supertype
from pydantic.utils import (
ClassAttribute,
ValueItems,
Expand Down Expand Up @@ -320,3 +320,16 @@ class Foo:
f = Foo()
f.attr = 'not foo'
assert f.attr == 'not foo'


@pytest.mark.skipif(not Literal, reason='typing_extensions not installed')
def test_all_literal_values():
L1 = Literal['1']
assert all_literal_values(L1) == ('1',)

L2 = Literal['2']
L12 = Literal[L1, L2]
assert sorted(all_literal_values(L12)) == sorted(('1', '2'))

L312 = Literal['3', Literal[L1, L2]]
assert sorted(all_literal_values(L312)) == sorted(('1', '2', '3'))
42 changes: 42 additions & 0 deletions tests/test_validators.py
Expand Up @@ -6,6 +6,7 @@

from pydantic import BaseModel, ConfigError, Extra, ValidationError, errors, validator
from pydantic.class_validators import make_generic_validator, root_validator
from pydantic.typing import Literal


def test_simple():
Expand Down Expand Up @@ -1038,3 +1039,44 @@ def check_foo(cls, value):
m = Model(name='hello')
m.name = 'goodbye'
assert validator_calls == 2


@pytest.mark.skipif(not Literal, reason='typing_extensions not installed')
def test_literal_validator():
class Model(BaseModel):
a: Literal['foo']

Model(a='foo')

with pytest.raises(ValidationError) as exc_info:
Model(a='nope')
assert exc_info.value.errors() == [
{
'loc': ('a',),
'msg': "unexpected value; permitted: 'foo'",
'type': 'value_error.const',
'ctx': {'given': 'nope', 'permitted': ('foo',)},
}
]


@pytest.mark.skipif(not Literal, reason='typing_extensions not installed')
def test_nested_literal_validator():
L1 = Literal['foo']
L2 = Literal['bar']

class Model(BaseModel):
a: Literal[L1, L2]

Model(a='foo')

with pytest.raises(ValidationError) as exc_info:
Model(a='nope')
assert exc_info.value.errors() == [
{
'loc': ('a',),
'msg': "unexpected value; permitted: 'foo', 'bar'",
'type': 'value_error.const',
'ctx': {'given': 'nope', 'permitted': ('foo', 'bar')},
}
]