Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: validate and parse nested models properly with default_factory #1712

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 3 additions & 3 deletions benchmarks/run.py
Expand Up @@ -274,6 +274,6 @@ def diff():
else:
main()

if None in other_tests:
print('not all libraries could be imported!')
sys.exit(1)
# if None in other_tests:
# print('not all libraries could be imported!')
# sys.exit(1)
1 change: 1 addition & 0 deletions changes/1710-PrettyWood.md
@@ -0,0 +1 @@
fix validation and parsing of nested models with `default_factory`
26 changes: 16 additions & 10 deletions pydantic/fields.py
Expand Up @@ -341,8 +341,22 @@ def prepare(self) -> None:
e.g. calling it it multiple times may modify the field and configure it incorrectly.
"""

# To prevent side effects by calling the `default_factory` for nothing, we only call it
# when we want to validate the default value i.e. when `validate_all` is set to True.
self._set_default_and_type()
self._type_analysis()
if self.required is Undefined:
self.required = True
self.field_info.default = Required
if self.default is Undefined and self.default_factory is None:
self.default = None
self.populate_validators()

def _set_default_and_type(self) -> None:
"""
Set the default value, infer the type if needed and check if `None` value is valid.

Note: to prevent side effects by calling the `default_factory` for nothing, we only call it
when we want to validate the default value i.e. when `validate_all` is set to True.
"""
if self.default_factory is not None:
if self.type_ is None:
raise errors_.ConfigError(
Expand All @@ -368,14 +382,6 @@ def prepare(self) -> None:
if self.required is False and default_value is None:
self.allow_none = True

self._type_analysis()
if self.required is Undefined:
self.required = True
self.field_info.default = Required
if self.default is Undefined and self.default_factory is None:
self.default = None
self.populate_validators()

def _type_analysis(self) -> None: # noqa: C901 (ignore complexity)
# typing interface is horrible, we have to do some ugly checks
if lenient_issubclass(self.type_, JsonWrapper):
Expand Down
16 changes: 16 additions & 0 deletions tests/test_edge_cases.py
Expand Up @@ -1714,3 +1714,19 @@ class Config:

m1 = MyModel()
assert m1.id == 2 # instead of 1


def test_default_factory_validator_child():
class Parent(BaseModel):
foo: List[str] = Field(default_factory=list)

@validator('foo', pre=True, each_item=True)
def mutate_foo(cls, v):
return f'{v}-1'

assert Parent(foo=['a', 'b']).foo == ['a-1', 'b-1']

class Child(Parent):
pass

assert Child(foo=['a', 'b']).foo == ['a-1', 'b-1']
30 changes: 30 additions & 0 deletions tests/test_main.py
Expand Up @@ -1174,6 +1174,36 @@ class MyModel(BaseModel):
assert m2.id == 2


def test_default_factory_validate_children():
class Child(BaseModel):
x: int

class Parent(BaseModel):
children: List[Child] = Field(default_factory=list)

Parent(children=[{'x': 1}, {'x': 2}])
with pytest.raises(ValidationError) as exc_info:
Parent(children=[{'x': 1}, {'y': 2}])

assert exc_info.value.errors() == [
{'loc': ('children', 1, 'x'), 'msg': 'field required', 'type': 'value_error.missing'},
]


def test_default_factory_parse():
class Inner(BaseModel):
val: int = Field(0)

class Outer(BaseModel):
inner_1: Inner = Field(default_factory=Inner)
inner_2: Inner = Field(Inner())

default = Outer().dict()
parsed = Outer.parse_obj(default)
assert parsed.dict() == {'inner_1': {'val': 0}, 'inner_2': {'val': 0}}
assert repr(parsed) == 'Outer(inner_1=Inner(val=0), inner_2=Inner(val=0))'


@pytest.mark.skipif(sys.version_info < (3, 7), reason='field constraints are set but not enforced with python 3.6')
def test_none_min_max_items():
# None default
Expand Down