Skip to content

Commit

Permalink
Avoid side effect on default factory by calling it only once
Browse files Browse the repository at this point in the history
  • Loading branch information
PrettyWood committed May 11, 2020
1 parent 5067508 commit 1d5e1dd
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 4 deletions.
1 change: 1 addition & 0 deletions changes/1491-PrettyWood.md
@@ -0,0 +1 @@
Avoid side effects with `default_factory` by not calling it multiple times
8 changes: 4 additions & 4 deletions pydantic/fields.py
Expand Up @@ -271,9 +271,9 @@ def __init__(
self.model_config.prepare_field(self)
self.prepare()

def get_default(self) -> Any:
def get_default(self, *, copy_factory: bool = False) -> Any:
if self.default_factory is not None:
value = self.default_factory()
value = self.default_factory() if not copy_factory else deepcopy(self.default_factory)()
elif self.default is None:
# deepcopy is quite slow on None
value = None
Expand All @@ -296,7 +296,7 @@ def infer(

if isinstance(value, FieldInfo):
field_info = value
value = field_info.default_factory() if field_info.default_factory is not None else field_info.default
value = None if field_info.default_factory is not None else field_info.default
else:
field_info = FieldInfo(value, **field_info_from_config)
required: 'BoolUndefined' = Undefined
Expand Down Expand Up @@ -341,7 +341,7 @@ def prepare(self) -> None:
Note: this method is **not** idempotent (because _type_analysis is not idempotent),
e.g. calling it it multiple times may modify the field and configure it incorrectly.
"""
default_value = self.get_default()
default_value = self.get_default(copy_factory=True)
if default_value is not None and self.type_ is None:
self.type_ = default_value.__class__
self.outer_type_ = self.type_
Expand Down
21 changes: 21 additions & 0 deletions tests/test_main.py
Expand Up @@ -1088,6 +1088,27 @@ class FunctionModel(BaseModel):
assert m.uid is uuid4


def test_default_factory_side_effect():
"""It should call only once the given factory"""

class Seq:
def __init__(self):
self.v = 0

def __call__(self):
self.v += 1
return self.v

class MyModel(BaseModel):
id: int = Field(default_factory=Seq())

m1 = MyModel()
assert m1.id == 1
m2 = MyModel()
assert m2.id == 2
assert m1.id == 1


@pytest.mark.skipif(sys.version_info < (3, 7), reason='field constraints are set but not enforced with python 3.6')
def test_none_min_max_items():
# None default
Expand Down

0 comments on commit 1d5e1dd

Please sign in to comment.