Skip to content

Commit

Permalink
fix: support properly custom root type with from_orm() (#2237)
Browse files Browse the repository at this point in the history
* fix: support custom root type with `from_orm()`

* add other example

* chore: add change file

* refactor(main): use ROOT_KEY instead of __root__
  • Loading branch information
PrettyWood committed Feb 13, 2021
1 parent 7bef40b commit b076567
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 5 deletions.
1 change: 1 addition & 0 deletions changes/2237-PrettyWood.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Support custom root type (aka `__root__`) with `from_orm()`
8 changes: 4 additions & 4 deletions pydantic/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def prepare_config(config: Type[BaseConfig], cls_name: str) -> None:

def validate_custom_root_type(fields: Dict[str, ModelField]) -> None:
if len(fields) > 1:
raise ValueError('__root__ cannot be mixed with other fields')
raise ValueError(f'{ROOT_KEY} cannot be mixed with other fields')


# Annotated fields can have many types like `str`, `int`, `List[str]`, `Callable`...
Expand Down Expand Up @@ -590,7 +590,7 @@ def parse_file(
def from_orm(cls: Type['Model'], obj: Any) -> 'Model':
if not cls.__config__.orm_mode:
raise ConfigError('You must have the config attribute orm_mode=True to use from_orm')
obj = cls._decompose_class(obj)
obj = {ROOT_KEY: obj} if cls.__custom_root_type__ else cls._decompose_class(obj)
m = cls.__new__(cls)
values, fields_set, validation_error = validate_model(cls, obj)
if validation_error:
Expand Down Expand Up @@ -726,8 +726,8 @@ def _get_value(
exclude=exclude,
exclude_none=exclude_none,
)
if '__root__' in v_dict:
return v_dict['__root__']
if ROOT_KEY in v_dict:
return v_dict[ROOT_KEY]
return v_dict
else:
return v.copy(include=include, exclude=exclude)
Expand Down
43 changes: 42 additions & 1 deletion tests/test_orm_mode.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, List
from typing import Any, Dict, List

import pytest

Expand Down Expand Up @@ -47,6 +47,47 @@ def __getattr__(self, key):
assert repr(gd) == "GetterDict[TestCls]({'a': 1, 'c': 3, 'd': 4})"


def test_orm_mode_root():
class PokemonCls:
def __init__(self, *, en_name: str, jp_name: str):
self.en_name = en_name
self.jp_name = jp_name

class Pokemon(BaseModel):
en_name: str
jp_name: str

class Config:
orm_mode = True

class PokemonList(BaseModel):
__root__: List[Pokemon]

class Config:
orm_mode = True

pika = PokemonCls(en_name='Pikachu', jp_name='ピカチュウ')
bulbi = PokemonCls(en_name='Bulbasaur', jp_name='フシギダネ')

pokemons = PokemonList.from_orm([pika, bulbi])
assert pokemons.__root__ == [
Pokemon(en_name='Pikachu', jp_name='ピカチュウ'),
Pokemon(en_name='Bulbasaur', jp_name='フシギダネ'),
]

class PokemonDict(BaseModel):
__root__: Dict[str, Pokemon]

class Config:
orm_mode = True

pokemons = PokemonDict.from_orm({'pika': pika, 'bulbi': bulbi})
assert pokemons.__root__ == {
'pika': Pokemon(en_name='Pikachu', jp_name='ピカチュウ'),
'bulbi': Pokemon(en_name='Bulbasaur', jp_name='フシギダネ'),
}


def test_orm_mode():
class PetCls:
def __init__(self, *, name: str, species: str):
Expand Down

0 comments on commit b076567

Please sign in to comment.