diff --git a/changes/3375-dosisod.md b/changes/3375-dosisod.md new file mode 100644 index 0000000000..a7eca2040e --- /dev/null +++ b/changes/3375-dosisod.md @@ -0,0 +1 @@ +Allow for passing keyword arguments to `from_orm` diff --git a/docs/examples/models_orm_mode_kwargs.py b/docs/examples/models_orm_mode_kwargs.py new file mode 100644 index 0000000000..2610c08f09 --- /dev/null +++ b/docs/examples/models_orm_mode_kwargs.py @@ -0,0 +1,30 @@ +from pydantic import BaseModel +import sqlalchemy as sa +from sqlalchemy.ext.declarative import declarative_base + + +class MyModel(BaseModel): + foo: str + bar: int + spam: bytes + + class Config: + orm_mode = True + + +BaseModel = declarative_base() + + +class SQLModel(BaseModel): + __tablename__ = 'my_table' + id = sa.Column('id', sa.Integer, primary_key=True) + foo = sa.Column('foo', sa.String(32)) + bar = sa.Column('bar', sa.Integer) + + +sql_model = SQLModel(id=1, foo='hello world', bar=123) + +pydantic_model = MyModel.from_orm(sql_model, bar=456, spam=b'placeholder') + +print(pydantic_model.dict()) +print(pydantic_model.dict(by_alias=True)) diff --git a/docs/usage/models.md b/docs/usage/models.md index bddb2ec19d..c321538fbc 100644 --- a/docs/usage/models.md +++ b/docs/usage/models.md @@ -152,6 +152,14 @@ _(This script is complete, it should run "as is")_ The example above works because aliases have priority over field names for field population. Accessing `SQLModel`'s `metadata` attribute would lead to a `ValidationError`. +You can also achieve the same thing by passing keyword arguments to `from_orm`, instead of +using Field aliases: + +```py +{!.tmp_examples/models_orm_mode_kwargs.py!} +``` +_(This script is complete, it should run "as is")_ + ### Recursive ORM models ORM instances will be parsed with `from_orm` recursively as well as at the top level. diff --git a/pydantic/main.py b/pydantic/main.py index 4b8daec309..48c22dd421 100644 --- a/pydantic/main.py +++ b/pydantic/main.py @@ -568,10 +568,10 @@ def parse_file( return cls.parse_obj(obj) @classmethod - def from_orm(cls: Type['Model'], obj: Any) -> 'Model': + def from_orm(cls: Type['Model'], obj: Any, **kwargs: Any) -> 'Model': if not cls.__config__.orm_mode: raise ConfigError('You must have the config attribute orm_mode=True to use from_orm') - obj = {ROOT_KEY: obj} if cls.__custom_root_type__ else cls._decompose_class(obj) + obj = {ROOT_KEY: obj} if cls.__custom_root_type__ else cls._decompose_class(obj, kwargs) m = cls.__new__(cls) values, fields_set, validation_error = validate_model(cls, obj) if validation_error: @@ -698,10 +698,10 @@ def validate(cls: Type['Model'], value: Any) -> 'Model': return cls(**value_as_dict) @classmethod - def _decompose_class(cls: Type['Model'], obj: Any) -> GetterDict: + def _decompose_class(cls: Type['Model'], obj: Any, kwargs: Dict[str, Any]) -> GetterDict: if isinstance(obj, GetterDict): return obj - return cls.__config__.getter_dict(obj) + return cls.__config__.getter_dict(obj, kwargs) @classmethod @no_type_check diff --git a/pydantic/utils.py b/pydantic/utils.py index 972f2e20ca..d407fd7798 100644 --- a/pydantic/utils.py +++ b/pydantic/utils.py @@ -413,18 +413,26 @@ class GetterDict(Representation): We can't inherit from Mapping[str, Any] because it upsets cython so we have to implement all methods ourselves. """ - __slots__ = ('_obj',) + __slots__ = ('_obj', '_kwargs') - def __init__(self, obj: Any): + def __init__(self, obj: Any, kwargs: Optional[Dict[str, Any]] = None): self._obj = obj + self._kwargs = kwargs or None def __getitem__(self, key: str) -> Any: try: + if self._kwargs and key in self._kwargs: + return self._kwargs[key] return getattr(self._obj, key) except AttributeError as e: raise KeyError(key) from e def get(self, key: Any, default: Any = None) -> Any: + if self._kwargs: + try: + return self._kwargs[key] + except KeyError: + pass return getattr(self._obj, key, default) def extra_keys(self) -> Set[Any]: diff --git a/tests/test_orm_mode.py b/tests/test_orm_mode.py index af31fd2509..ae837118c1 100644 --- a/tests/test_orm_mode.py +++ b/tests/test_orm_mode.py @@ -26,7 +26,7 @@ def __getattr__(self, key): raise AttributeError() t = TestCls() - gd = GetterDict(t) + gd = GetterDict(t, {'extra': 42}) assert gd.keys() == ['a', 'c', 'd'] assert gd.get('a') == 1 assert gd['a'] == 1 @@ -46,6 +46,7 @@ def __getattr__(self, key): assert len(gd) == 3 assert str(gd) == "{'a': 1, 'c': 3, 'd': 4}" assert repr(gd) == "GetterDict[TestCls]({'a': 1, 'c': 3, 'd': 4})" + assert gd['extra'] == 42 def test_orm_mode_root(): @@ -253,7 +254,7 @@ class TestCls: x = 1 y = 2 - def custom_getter_dict(obj): + def custom_getter_dict(obj, _): assert isinstance(obj, TestCls) return {'x': 42, 'y': 24} @@ -357,3 +358,48 @@ class Config: # Pass dictionary data directly State(**{'user': {'first_name': 'John', 'last_name': 'Appleseed'}}) + + +def test_orm_mode_with_kwargs(): + class MyModel(BaseModel): + foo: str + bar: int + spam: bytes + + class Config: + orm_mode = True + + class SQLModel(BaseModel): + id: int + foo: str + bar: int + + sql_model = SQLModel(id=1, foo='hello world', bar=123) + pydantic_model = MyModel.from_orm(sql_model, bar=456, spam=b'placeholder') + + assert pydantic_model.foo == 'hello world' + assert pydantic_model.bar == 456 + assert pydantic_model.spam == b'placeholder' + + +def test_orm_mode_with_kwargs_property_wont_be_acessed(): + class Model(BaseModel): + a: int + + class Config: + orm_mode = True + + property_was_accessed = False + + class Test: + a = 1 + + @property + def test(self): + nonlocal property_was_accessed + property_was_accessed = True + + test = Test() + Model.from_orm(test, extra=42) + + assert property_was_accessed is False