diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 43b1780956..f0949e78bb 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -8,6 +8,7 @@ on: - pydantic-v2-blog tags: - '**' + pull_request: {} jobs: lint: diff --git a/changes/4093-timkpaine.md b/changes/4093-timkpaine.md new file mode 100644 index 0000000000..8f22eb7119 --- /dev/null +++ b/changes/4093-timkpaine.md @@ -0,0 +1,3 @@ +Allow for shallow copies of attributes, adjusting the behavior of #3642 +`Config.copy_on_model_validation` is now a str enum of `["none", "deep", "shallow"]` corresponding to +not copying, deep copy, shallow copy, default `"shallow"`. diff --git a/docs/requirements.txt b/docs/requirements.txt index 348505f4d9..6df8b4ee69 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,12 +1,12 @@ -ansi2html==1.6.0 -flake8==4.0.1 +ansi2html==1.8.0 +flake8==5.0.4 flake8-quotes==3.3.1 -hypothesis==6.46.3 -markdown-include==0.6.0 -mdx-truly-sane-lists==1.2 -mkdocs==1.3.0 +hypothesis==6.54.1 +markdown-include==0.7.0 +mdx-truly-sane-lists==1.3 +mkdocs==1.3.1 mkdocs-exclude==1.0.2 -mkdocs-material==8.2.14 +mkdocs-material==8.3.9 sqlalchemy orjson ujson diff --git a/pydantic/config.py b/pydantic/config.py index b37cd98ff1..b9d9bebccb 100644 --- a/pydantic/config.py +++ b/pydantic/config.py @@ -2,20 +2,20 @@ from enum import Enum from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Type, Union +from typing_extensions import Literal, Protocol + from .typing import AnyCallable from .utils import GetterDict if TYPE_CHECKING: from typing import overload - import typing_extensions - from .fields import ModelField from .main import BaseModel ConfigType = Type['BaseConfig'] - class SchemaExtraCallable(typing_extensions.Protocol): + class SchemaExtraCallable(Protocol): @overload def __call__(self, schema: Dict[str, Any]) -> None: pass @@ -63,8 +63,10 @@ class BaseConfig: json_encoders: Dict[Union[Type[Any], str], AnyCallable] = {} underscore_attrs_are_private: bool = False - # whether inherited models as fields should be reconstructed as base model - copy_on_model_validation: bool = True + # whether inherited models as fields should be reconstructed as base model, + # and whether such a copy should be shallow or deep + copy_on_model_validation: Literal['none', 'deep', 'shallow'] = 'shallow' + # whether `Union` should check all allowed types before even trying to coerce smart_union: bool = False diff --git a/pydantic/main.py b/pydantic/main.py index 0c20d9e69d..c57d48400c 100644 --- a/pydantic/main.py +++ b/pydantic/main.py @@ -675,10 +675,28 @@ def __get_validators__(cls) -> 'CallableGenerator': @classmethod def validate(cls: Type['Model'], value: Any) -> 'Model': if isinstance(value, cls): - if cls.__config__.copy_on_model_validation: - return value._copy_and_set_values(value.__dict__, value.__fields_set__, deep=True) - else: + copy_on_model_validation = cls.__config__.copy_on_model_validation + # whether to deep or shallow copy the model on validation, None means do not copy + deep_copy: Optional[bool] = None + if copy_on_model_validation not in {'deep', 'shallow', 'none'}: + # Warn about deprecated behavior + warnings.warn( + "`copy_on_model_validation` should be a string: 'deep', 'shallow' or 'none'", DeprecationWarning + ) + if copy_on_model_validation: + deep_copy = False + + if copy_on_model_validation == 'shallow': + # shallow copy + deep_copy = False + elif copy_on_model_validation == 'deep': + # deep copy + deep_copy = True + + if deep_copy is None: return value + else: + return value._copy_and_set_values(value.__dict__, value.__fields_set__, deep=deep_copy) value = cls._enforce_dict_if_root(value) diff --git a/tests/test_main.py b/tests/test_main.py index 8cf290cbeb..0b96f96a21 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -1561,18 +1561,66 @@ class Config: assert t.user is not my_user assert t.user.hobbies == ['scuba diving'] - assert t.user.hobbies is not my_user.hobbies # `Config.copy_on_model_validation` does a deep copy + assert t.user.hobbies is my_user.hobbies # `Config.copy_on_model_validation` does a shallow copy assert t.user._priv == 13 assert t.user.password.get_secret_value() == 'hashedpassword' assert t.dict() == {'id': '1234567890', 'user': {'id': 42, 'hobbies': ['scuba diving']}} +def test_model_exclude_copy_on_model_validation_shallow(): + """When `Config.copy_on_model_validation` is set and `Config.copy_on_model_validation_shallow` is set, + do the same as the previous test but perform a shallow copy""" + + class User(BaseModel): + class Config: + copy_on_model_validation = 'shallow' + + hobbies: List[str] + + my_user = User(hobbies=['scuba diving']) + + class Transaction(BaseModel): + user: User = Field(...) + + t = Transaction(user=my_user) + + assert t.user is not my_user + assert t.user.hobbies is my_user.hobbies # unlike above, this should be a shallow copy + + +@pytest.mark.parametrize('comv_value', [True, False]) +def test_copy_on_model_validation_warning(comv_value): + class User(BaseModel): + class Config: + # True interpreted as 'shallow', False interpreted as 'none' + copy_on_model_validation = comv_value + + hobbies: List[str] + + my_user = User(hobbies=['scuba diving']) + + class Transaction(BaseModel): + user: User + + with pytest.warns(DeprecationWarning, match="`copy_on_model_validation` should be a string: 'deep', 'shallow' or"): + t = Transaction(user=my_user) + + if comv_value: + assert t.user is not my_user + else: + assert t.user is my_user + assert t.user.hobbies is my_user.hobbies + + def test_validation_deep_copy(): """By default, Config.copy_on_model_validation should do a deep copy""" class A(BaseModel): name: str + class Config: + copy_on_model_validation = 'deep' + class B(BaseModel): list_a: List[A] @@ -1987,7 +2035,7 @@ def __hash__(self): return id(self) class Config: - copy_on_model_validation = False + copy_on_model_validation = 'none' class Item(BaseModel): images: List[Image]