Skip to content

Commit

Permalink
fix: ensure to always return one of the values in Literal field type (
Browse files Browse the repository at this point in the history
#2181)

* fix: ensure to always return one of the values in `Literal` field type

closes #2166

* perf: improve `literal_validator` speed

Thanks to @yobiscus

* fix: when more options in Literal

switch from `set` to `dict` to still have a O(1) complexity
Thanks @layday :)
  • Loading branch information
PrettyWood committed Jan 1, 2021
1 parent 191647c commit 80175f3
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 3 deletions.
1 change: 1 addition & 0 deletions changes/2166-PrettyWood.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
fix: ensure to always return one of the values in `Literal` field type
11 changes: 8 additions & 3 deletions pydantic/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,12 +442,17 @@ def int_enum_validator(v: Any) -> IntEnum:

def make_literal_validator(type_: Any) -> Callable[[Any], Any]:
permitted_choices = all_literal_values(type_)
allowed_choices_set = set(permitted_choices)

# To have a O(1) complexity and still return one of the values set inside the `Literal`,
# we create a dict with the set values (a set causes some problems with the way intersection works).
# In some cases the set value and checked value can indeed be different (see `test_literal_validator_str_enum`)
allowed_choices = {v: v for v in permitted_choices}

def literal_validator(v: Any) -> Any:
if v not in allowed_choices_set:
try:
return allowed_choices[v]
except KeyError:
raise errors.WrongConstantError(given=v, permitted=permitted_choices)
return v

return literal_validator

Expand Down
23 changes: 23 additions & 0 deletions tests/test_validators.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from collections import deque
from datetime import datetime
from enum import Enum
from itertools import product
from typing import Dict, List, Optional, Tuple

Expand Down Expand Up @@ -1135,6 +1136,28 @@ class Model(BaseModel):
]


@pytest.mark.skipif(not Literal, reason='typing_extensions not installed')
def test_literal_validator_str_enum():
class Bar(str, Enum):
FIZ = 'fiz'
FUZ = 'fuz'

class Foo(BaseModel):
bar: Bar
barfiz: Literal[Bar.FIZ]
fizfuz: Literal[Bar.FIZ, Bar.FUZ]

my_foo = Foo.parse_obj({'bar': 'fiz', 'barfiz': 'fiz', 'fizfuz': 'fiz'})
assert my_foo.bar is Bar.FIZ
assert my_foo.barfiz is Bar.FIZ
assert my_foo.fizfuz is Bar.FIZ

my_foo = Foo.parse_obj({'bar': 'fiz', 'barfiz': 'fiz', 'fizfuz': 'fuz'})
assert my_foo.bar is Bar.FIZ
assert my_foo.barfiz is Bar.FIZ
assert my_foo.fizfuz is Bar.FUZ


@pytest.mark.skipif(not Literal, reason='typing_extensions not installed')
def test_nested_literal_validator():
L1 = Literal['foo']
Expand Down

0 comments on commit 80175f3

Please sign in to comment.