diff --git a/changes/4093-timkpaine.md b/changes/4093-timkpaine.md new file mode 100644 index 0000000000..8f22eb7119 --- /dev/null +++ b/changes/4093-timkpaine.md @@ -0,0 +1,3 @@ +Allow for shallow copies of attributes, adjusting the behavior of #3642 +`Config.copy_on_model_validation` is now a str enum of `["none", "deep", "shallow"]` corresponding to +not copying, deep copy, shallow copy, default `"shallow"`. diff --git a/pydantic/config.py b/pydantic/config.py index adb9fd4b21..6ceb28c01d 100644 --- a/pydantic/config.py +++ b/pydantic/config.py @@ -2,6 +2,8 @@ from enum import Enum from typing import TYPE_CHECKING, Any, Callable, Dict, ForwardRef, Optional, Tuple, Type, Union +from typing_extensions import Literal, Protocol + from .typing import AnyCallable from .utils import GetterDict from .version import compiled @@ -9,14 +11,12 @@ if TYPE_CHECKING: from typing import overload - import typing_extensions - from .fields import ModelField from .main import BaseModel ConfigType = Type['BaseConfig'] - class SchemaExtraCallable(typing_extensions.Protocol): + class SchemaExtraCallable(Protocol): @overload def __call__(self, schema: Dict[str, Any]) -> None: pass @@ -103,8 +103,10 @@ class BaseConfig: json_encoders: Dict[Union[Type[Any], str, ForwardRef], AnyCallable] = {} underscore_attrs_are_private: bool = False - # whether inherited models as fields should be reconstructed as base model - copy_on_model_validation: bool = True + # whether inherited models as fields should be reconstructed as base model, + # and whether such a copy should be shallow or deep + copy_on_model_validation: Literal['none', 'deep', 'shallow'] = 'shallow' + # whether `Union` should check all allowed types before even trying to coerce smart_union: bool = False # whether dataclass `__post_init__` should be run before or after validation diff --git a/pydantic/main.py b/pydantic/main.py index 4b8daec309..f8b48355e8 100644 --- a/pydantic/main.py +++ b/pydantic/main.py @@ -679,10 +679,28 @@ def __get_validators__(cls) -> 'CallableGenerator': @classmethod def validate(cls: Type['Model'], value: Any) -> 'Model': if isinstance(value, cls): - if cls.__config__.copy_on_model_validation: - return value._copy_and_set_values(value.__dict__, value.__fields_set__, deep=True) - else: + copy_on_model_validation = cls.__config__.copy_on_model_validation + # whether to deep or shallow copy the model on validation, None means do not copy + deep_copy: Optional[bool] = None + if copy_on_model_validation not in {'deep', 'shallow', 'none'}: + # Warn about deprecated behavior + warnings.warn( + "`copy_on_model_validation` should be a string: 'deep', 'shallow' or 'none'", DeprecationWarning + ) + if copy_on_model_validation: + deep_copy = False + + if copy_on_model_validation == 'shallow': + # shallow copy + deep_copy = False + elif copy_on_model_validation == 'deep': + # deep copy + deep_copy = True + + if deep_copy is None: return value + else: + return value._copy_and_set_values(value.__dict__, value.__fields_set__, deep=deep_copy) value = cls._enforce_dict_if_root(value) diff --git a/tests/test_main.py b/tests/test_main.py index ceac222264..41e46952dc 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -1561,18 +1561,66 @@ class Config: assert t.user is not my_user assert t.user.hobbies == ['scuba diving'] - assert t.user.hobbies is not my_user.hobbies # `Config.copy_on_model_validation` does a deep copy + assert t.user.hobbies is my_user.hobbies # `Config.copy_on_model_validation` does a shallow copy assert t.user._priv == 13 assert t.user.password.get_secret_value() == 'hashedpassword' assert t.dict() == {'id': '1234567890', 'user': {'id': 42, 'hobbies': ['scuba diving']}} +def test_model_exclude_copy_on_model_validation_shallow(): + """When `Config.copy_on_model_validation` is set and `Config.copy_on_model_validation_shallow` is set, + do the same as the previous test but perform a shallow copy""" + + class User(BaseModel): + class Config: + copy_on_model_validation = 'shallow' + + hobbies: List[str] + + my_user = User(hobbies=['scuba diving']) + + class Transaction(BaseModel): + user: User = Field(...) + + t = Transaction(user=my_user) + + assert t.user is not my_user + assert t.user.hobbies is my_user.hobbies # unlike above, this should be a shallow copy + + +@pytest.mark.parametrize('comv_value', [True, False]) +def test_copy_on_model_validation_warning(comv_value): + class User(BaseModel): + class Config: + # True interpreted as 'shallow', False interpreted as 'none' + copy_on_model_validation = comv_value + + hobbies: List[str] + + my_user = User(hobbies=['scuba diving']) + + class Transaction(BaseModel): + user: User + + with pytest.warns(DeprecationWarning, match="`copy_on_model_validation` should be a string: 'deep', 'shallow' or"): + t = Transaction(user=my_user) + + if comv_value: + assert t.user is not my_user + else: + assert t.user is my_user + assert t.user.hobbies is my_user.hobbies + + def test_validation_deep_copy(): """By default, Config.copy_on_model_validation should do a deep copy""" class A(BaseModel): name: str + class Config: + copy_on_model_validation = 'deep' + class B(BaseModel): list_a: List[A] @@ -1986,7 +2034,7 @@ def __hash__(self): return id(self) class Config: - copy_on_model_validation = False + copy_on_model_validation = 'none' class Item(BaseModel): images: List[Image]