Skip to content

Commit

Permalink
fix: models copied via Config.copy_on_model_validation always have …
Browse files Browse the repository at this point in the history
…all fields (#3201)

Small regression in #2231.
The shallow copy done with `Config.copy_on_model_validation = True` (default behaviour)
was using excluded / included fields when it should just copy everything

closes #3195
  • Loading branch information
PrettyWood committed Dec 19, 2021
1 parent da916f3 commit 5ad73d0
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 17 deletions.
41 changes: 24 additions & 17 deletions pydantic/main.py
Expand Up @@ -586,6 +586,24 @@ def construct(cls: Type['Model'], _fields_set: Optional['SetStr'] = None, **valu
m._init_private_attributes()
return m

def _copy_and_set_values(self: 'Model', values: 'DictStrAny', fields_set: 'SetStr', *, deep: bool) -> 'Model':
if deep:
# chances of having empty dict here are quite low for using smart_deepcopy
values = deepcopy(values)

cls = self.__class__
m = cls.__new__(cls)
object_setattr(m, '__dict__', values)
object_setattr(m, '__fields_set__', fields_set)
for name in self.__private_attributes__:
value = getattr(self, name, Undefined)
if value is not Undefined:
if deep:
value = deepcopy(value)
object_setattr(m, name, value)

return m

def copy(
self: 'Model',
*,
Expand All @@ -605,32 +623,18 @@ def copy(
:return: new model instance
"""

v = dict(
values = dict(
self._iter(to_dict=False, by_alias=False, include=include, exclude=exclude, exclude_unset=False),
**(update or {}),
)

if deep:
# chances of having empty dict here are quite low for using smart_deepcopy
v = deepcopy(v)

cls = self.__class__
m = cls.__new__(cls)
object_setattr(m, '__dict__', v)
# new `__fields_set__` can have unset optional fields with a set value in `update` kwarg
if update:
fields_set = self.__fields_set__ | update.keys()
else:
fields_set = set(self.__fields_set__)
object_setattr(m, '__fields_set__', fields_set)
for name in self.__private_attributes__:
value = getattr(self, name, Undefined)
if value is not Undefined:
if deep:
value = deepcopy(value)
object_setattr(m, name, value)

return m
return self._copy_and_set_values(values, fields_set, deep=deep)

@classmethod
def schema(cls, by_alias: bool = True, ref_template: str = default_ref_template) -> 'DictStrAny':
Expand Down Expand Up @@ -658,7 +662,10 @@ def __get_validators__(cls) -> 'CallableGenerator':
@classmethod
def validate(cls: Type['Model'], value: Any) -> 'Model':
if isinstance(value, cls):
return value.copy() if cls.__config__.copy_on_model_validation else value
if cls.__config__.copy_on_model_validation:
return value._copy_and_set_values(value.__dict__, value.__fields_set__, deep=False)
else:
return value

value = cls._enforce_dict_if_root(value)

Expand Down
41 changes: 41 additions & 0 deletions tests/test_main.py
Expand Up @@ -28,7 +28,9 @@
Field,
NoneBytes,
NoneStr,
PrivateAttr,
Required,
SecretStr,
ValidationError,
constr,
root_validator,
Expand Down Expand Up @@ -1516,6 +1518,45 @@ class Config:
assert Model.__fields__['b'].field_info.exclude == {'foo': ..., 'bar': ...}


def test_model_exclude_copy_on_model_validation():
"""When `Config.copy_on_model_validation` is set, it should keep private attributes and excluded fields"""

class User(BaseModel):
_priv: int = PrivateAttr()
id: int
username: str
password: SecretStr = Field(exclude=True)
hobbies: List[str]

my_user = User(id=42, username='JohnDoe', password='hashedpassword', hobbies=['scuba diving'])

my_user._priv = 13
assert my_user.id == 42
assert my_user.password.get_secret_value() == 'hashedpassword'
assert my_user.dict() == {'id': 42, 'username': 'JohnDoe', 'hobbies': ['scuba diving']}

class Transaction(BaseModel):
id: str
user: User = Field(..., exclude={'username'})
value: int

class Config:
fields = {'value': {'exclude': True}}

t = Transaction(
id='1234567890',
user=my_user,
value=9876543210,
)

assert t.user is not my_user
assert t.user.hobbies == ['scuba diving']
assert t.user.hobbies is my_user.hobbies # `Config.copy_on_model_validation` only does a shallow copy
assert t.user._priv == 13
assert t.user.password.get_secret_value() == 'hashedpassword'
assert t.dict() == {'id': '1234567890', 'user': {'id': 42, 'hobbies': ['scuba diving']}}


@pytest.mark.parametrize(
'kinds',
[
Expand Down

0 comments on commit 5ad73d0

Please sign in to comment.