Skip to content

Commit

Permalink
fix: support custom root type with nested models in parse_obj (#2238)
Browse files Browse the repository at this point in the history
  • Loading branch information
PrettyWood committed Feb 13, 2021
1 parent b076567 commit b21da6f
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 5 deletions.
1 change: 1 addition & 0 deletions changes/2238-PrettyWood.md
@@ -0,0 +1 @@
Support custom root type (aka `__root__`) when using `parse_obj()` with nested models
15 changes: 10 additions & 5 deletions pydantic/main.py
Expand Up @@ -530,12 +530,18 @@ def json(
return self.__config__.json_dumps(data, default=encoder, **dumps_kwargs)

@classmethod
def parse_obj(cls: Type['Model'], obj: Any) -> 'Model':
def _enforce_dict_if_root(cls, obj: Any) -> Any:
if cls.__custom_root_type__ and (
not (isinstance(obj, dict) and obj.keys() == {ROOT_KEY}) or cls.__fields__[ROOT_KEY].shape == SHAPE_MAPPING
):
obj = {ROOT_KEY: obj}
elif not isinstance(obj, dict):
return {ROOT_KEY: obj}
else:
return obj

@classmethod
def parse_obj(cls: Type['Model'], obj: Any) -> 'Model':
obj = cls._enforce_dict_if_root(obj)
if not isinstance(obj, dict):
try:
obj = dict(obj)
except (TypeError, ValueError) as e:
Expand Down Expand Up @@ -683,14 +689,13 @@ def __get_validators__(cls) -> 'CallableGenerator':

@classmethod
def validate(cls: Type['Model'], value: Any) -> 'Model':
value = cls._enforce_dict_if_root(value)
if isinstance(value, dict):
return cls(**value)
elif isinstance(value, cls):
return value.copy() if cls.__config__.copy_on_model_validation else value
elif cls.__config__.orm_mode:
return cls.from_orm(value)
elif cls.__custom_root_type__:
return cls.parse_obj(value)
else:
try:
value_as_dict = dict(value)
Expand Down
54 changes: 54 additions & 0 deletions tests/test_main.py
Expand Up @@ -1120,6 +1120,60 @@ class MyModel(BaseModel):
]


def test_parse_obj_nested_root():
class Pokemon(BaseModel):
name: str
level: int

class Pokemons(BaseModel):
__root__: List[Pokemon]

class Player(BaseModel):
rank: int
pokemons: Pokemons

class Players(BaseModel):
__root__: Dict[str, Player]

class Tournament(BaseModel):
players: Players
city: str

payload = {
'players': {
'Jane': {
'rank': 1,
'pokemons': [
{
'name': 'Pikachu',
'level': 100,
},
{
'name': 'Bulbasaur',
'level': 13,
},
],
},
'Tarzan': {
'rank': 2,
'pokemons': [
{
'name': 'Jigglypuff',
'level': 7,
},
],
},
},
'city': 'Qwerty',
}

tournament = Tournament.parse_obj(payload)
assert tournament.city == 'Qwerty'
assert len(tournament.players.__root__) == 2
assert len(tournament.players.__root__['Jane'].pokemons.__root__) == 2
assert tournament.players.__root__['Jane'].pokemons.__root__[0].name == 'Pikachu'


def test_untouched_types():
from pydantic import BaseModel

Expand Down

0 comments on commit b21da6f

Please sign in to comment.