Skip to content

Commit

Permalink
FIX: validation on model attribute with nested Literal breaks (#1364)
Browse files Browse the repository at this point in the history
* Add tests for nested literals validator

* Implement flatten literal in validator

* Add test for flatten literal

* Add changelog entry

* Add test skip markers if not Literal

* Refactor: use improved all literals implementation

From Github user PrettyWood, see PR #1364

* Add testing for typing module

Includes moving corresponding tests.

* Remove unnecessary type hint

* Move all literals test to test_utils
  • Loading branch information
DBCerigo committed May 23, 2020
1 parent 58b95b7 commit 913025a
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 7 deletions.
1 change: 1 addition & 0 deletions changes/1364-DBCerigo.md
@@ -0,0 +1 @@
Fix model validation to handle nested literals, e.g. `Literal['foo', Literal['bar']]`.
13 changes: 13 additions & 0 deletions pydantic/typing.py
Expand Up @@ -179,6 +179,19 @@ def literal_values(type_: AnyType) -> Tuple[Any, ...]:
return type_.__values__


def all_literal_values(type_: Type[Any]) -> Tuple[Any, ...]:
"""
This method is used to retrieve all Literal values as
Literal can be used recursively (see https://www.python.org/dev/peps/pep-0586)
e.g. `Literal[Literal[Literal[1, 2, 3], "foo"], 5, None]`
"""
if not is_literal_type(type_):
return (type_,)

values = literal_values(type_)
return tuple(x for value in values for x in all_literal_values(value))


test_type = NewType('test_type', str)


Expand Down
17 changes: 11 additions & 6 deletions pydantic/validators.py
@@ -1,5 +1,4 @@
import re
import sys
from collections import OrderedDict
from collections.abc import Hashable
from datetime import date, datetime, time, timedelta
Expand All @@ -26,7 +25,16 @@

from . import errors
from .datetime_parse import parse_date, parse_datetime, parse_duration, parse_time
from .typing import AnyCallable, AnyType, ForwardRef, display_as_type, get_class, is_callable_type, is_literal_type
from .typing import (
AnyCallable,
AnyType,
ForwardRef,
all_literal_values,
display_as_type,
get_class,
is_callable_type,
is_literal_type,
)
from .utils import almost_equal_floats, lenient_issubclass, sequence_like

if TYPE_CHECKING:
Expand Down Expand Up @@ -394,10 +402,7 @@ def callable_validator(v: Any) -> AnyCallable:


def make_literal_validator(type_: Any) -> Callable[[Any], Any]:
if sys.version_info >= (3, 7):
permitted_choices = type_.__args__
else:
permitted_choices = type_.__values__
permitted_choices = all_literal_values(type_)
allowed_choices_set = set(permitted_choices)

def literal_validator(v: Any) -> Any:
Expand Down
15 changes: 14 additions & 1 deletion tests/test_utils.py
Expand Up @@ -11,7 +11,7 @@
from pydantic.color import Color
from pydantic.dataclasses import dataclass
from pydantic.fields import Undefined
from pydantic.typing import display_as_type, is_new_type, new_type_supertype
from pydantic.typing import Literal, all_literal_values, display_as_type, is_new_type, new_type_supertype
from pydantic.utils import (
ClassAttribute,
ValueItems,
Expand Down Expand Up @@ -320,3 +320,16 @@ class Foo:
f = Foo()
f.attr = 'not foo'
assert f.attr == 'not foo'


@pytest.mark.skipif(not Literal, reason='typing_extensions not installed')
def test_all_literal_values():
L1 = Literal['1']
assert all_literal_values(L1) == ('1',)

L2 = Literal['2']
L12 = Literal[L1, L2]
assert sorted(all_literal_values(L12)) == sorted(('1', '2'))

L312 = Literal['3', Literal[L1, L2]]
assert sorted(all_literal_values(L312)) == sorted(('1', '2', '3'))
42 changes: 42 additions & 0 deletions tests/test_validators.py
Expand Up @@ -6,6 +6,7 @@

from pydantic import BaseModel, ConfigError, Extra, ValidationError, errors, validator
from pydantic.class_validators import make_generic_validator, root_validator
from pydantic.typing import Literal


def test_simple():
Expand Down Expand Up @@ -1038,3 +1039,44 @@ def check_foo(cls, value):
m = Model(name='hello')
m.name = 'goodbye'
assert validator_calls == 2


@pytest.mark.skipif(not Literal, reason='typing_extensions not installed')
def test_literal_validator():
class Model(BaseModel):
a: Literal['foo']

Model(a='foo')

with pytest.raises(ValidationError) as exc_info:
Model(a='nope')
assert exc_info.value.errors() == [
{
'loc': ('a',),
'msg': "unexpected value; permitted: 'foo'",
'type': 'value_error.const',
'ctx': {'given': 'nope', 'permitted': ('foo',)},
}
]


@pytest.mark.skipif(not Literal, reason='typing_extensions not installed')
def test_nested_literal_validator():
L1 = Literal['foo']
L2 = Literal['bar']

class Model(BaseModel):
a: Literal[L1, L2]

Model(a='foo')

with pytest.raises(ValidationError) as exc_info:
Model(a='nope')
assert exc_info.value.errors() == [
{
'loc': ('a',),
'msg': "unexpected value; permitted: 'foo', 'bar'",
'type': 'value_error.const',
'ctx': {'given': 'nope', 'permitted': ('foo', 'bar')},
}
]

0 comments on commit 913025a

Please sign in to comment.