Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

revalidate models #177

Merged
merged 2 commits into from
Jul 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions pydantic_core/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class Config(TypedDict, total=False):
typed_dict_populate_by_name: bool # replaces `allow_population_by_field_name` in pydantic v1
# used on typed-dicts and tagged union keys
from_attributes: bool
revalidate_models: bool
# fields related to string fields only
str_max_length: int
str_min_length: int
Expand Down
23 changes: 23 additions & 0 deletions src/build_tools.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ impl<'py> SchemaDict<'py> for PyDict {
}
}

#[cfg_attr(has_no_coverage, no_coverage)]
fn get_as_req<T>(&'py self, key: &str) -> PyResult<T>
where
T: FromPyObject<'py>,
Expand All @@ -40,6 +41,28 @@ impl<'py> SchemaDict<'py> for PyDict {
}
}

impl<'py> SchemaDict<'py> for Option<&PyDict> {
fn get_as<T>(&'py self, key: &str) -> PyResult<Option<T>>
where
T: FromPyObject<'py>,
{
match self {
Some(d) => d.get_as(key),
None => Ok(None),
}
}

fn get_as_req<T>(&'py self, key: &str) -> PyResult<T>
where
T: FromPyObject<'py>,
{
match self {
Some(d) => d.get_as_req(key),
None => py_error!(PyKeyError; "{}", key),
}
}
}

pub fn schema_or_config<'py, T>(
schema: &'py PyDict,
config: Option<&'py PyDict>,
Expand Down
7 changes: 6 additions & 1 deletion src/input/input_abstract.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::fmt;

use pyo3::prelude::*;
use pyo3::types::PyType;
use pyo3::types::{PyString, PyType};

use crate::errors::{InputValue, LocItem, ValResult};
use crate::input::datetime::EitherTime;
Expand Down Expand Up @@ -29,6 +29,11 @@ pub trait Input<'a>: fmt::Debug + ToPyObject {
Ok(false)
}

#[cfg_attr(has_no_coverage, no_coverage)]
fn get_attr(&self, _name: &PyString) -> Option<&PyAny> {
None
}

fn is_instance(&self, _class: &PyType) -> PyResult<bool> {
Ok(false)
}
Expand Down
4 changes: 4 additions & 0 deletions src/input/input_python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ impl<'a> Input<'a> for PyAny {
Ok(self.get_type().eq(class)?)
}

fn get_attr(&self, name: &PyString) -> Option<&PyAny> {
self.getattr(name).ok()
}

fn is_instance(&self, class: &PyType) -> PyResult<bool> {
self.is_instance(class)
}
Expand Down
2 changes: 1 addition & 1 deletion src/validators/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,7 @@ impl<'a> Extra<'a> {
#[enum_dispatch]
pub enum CombinedValidator {
// typed dict e.g. heterogeneous dicts or simply a model
Model(typed_dict::TypedDictValidator),
TypedDict(typed_dict::TypedDictValidator),
// unions
Union(union::UnionValidator),
TaggedUnion(union::TaggedUnionValidator),
Expand Down
33 changes: 24 additions & 9 deletions src/validators/model_class.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,14 @@ use crate::errors::{ErrorKind, ValError, ValResult};
use crate::input::Input;
use crate::recursion_guard::RecursionGuard;

use super::typed_dict::TypedDictValidator;
use super::{build_validator, BuildContext, BuildValidator, CombinedValidator, Extra, Validator};

#[derive(Debug, Clone)]
pub struct ModelClassValidator {
strict: bool,
validator: Box<CombinedValidator>,
revalidate: bool,
validator: TypedDictValidator,
class: Py<PyType>,
name: String,
}
Expand All @@ -36,17 +38,23 @@ impl BuildValidator for ModelClassValidator {

let class: &PyType = schema.get_as_req("class_type")?;
let sub_schema: &PyAny = schema.get_as_req("schema")?;
let (validator, td_schema) = build_validator(sub_schema, config, build_context)?;
let (comb_validator, td_schema) = build_validator(sub_schema, config, build_context)?;

if !td_schema.get_as("return_fields_set")?.unwrap_or(false) {
return py_error!(r#"model-class inner schema must have "return_fields_set" set to True"#);
return py_error!("model-class inner schema must have 'return_fields_set' set to True");
}

let validator = match comb_validator {
CombinedValidator::TypedDict(tdv) => tdv,
_ => return py_error!("Wrong validator type, expected 'typed-dict' validator"),
};

Ok(Self {
// we don't use is_strict here since we don't want validation to be strict in this case if
// `config.strict` is set, only if this specific field is strict
strict: schema.get_as("strict")?.unwrap_or(false),
validator: Box::new(validator),
revalidate: config.get_as("revalidate_models")?.unwrap_or(false),
validator,
class: class.into(),
// Get the class's `__name__`, not using `class.name()` since it uses `__qualname__`
// which is not what we want here
Expand All @@ -67,7 +75,15 @@ impl Validator for ModelClassValidator {
) -> ValResult<'data, PyObject> {
let class = self.class.as_ref(py);
if input.is_type(class)? {
Ok(input.to_object(py))
if self.revalidate {
let fields_set = input.get_attr(intern!(py, "__fields_set__"));
let output = self.validator.validate(py, input, extra, slots, recursion_guard)?;
let (model_dict, validation_fields_set): (&PyAny, &PyAny) = output.extract(py)?;
let fields_set = fields_set.unwrap_or(validation_fields_set);
Ok(self.create_class(py, model_dict, fields_set)?)
} else {
Ok(input.to_object(py))
}
} else if extra.strict.unwrap_or(self.strict) {
Err(ValError::new(
ErrorKind::ModelClassType {
Expand All @@ -77,7 +93,8 @@ impl Validator for ModelClassValidator {
))
} else {
let output = self.validator.validate(py, input, extra, slots, recursion_guard)?;
Ok(self.create_class(py, output)?)
let (model_dict, fields_set): (&PyAny, &PyAny) = output.extract(py)?;
Ok(self.create_class(py, model_dict, fields_set)?)
}
}

Expand All @@ -87,9 +104,7 @@ impl Validator for ModelClassValidator {
}

impl ModelClassValidator {
fn create_class(&self, py: Python, output: PyObject) -> PyResult<PyObject> {
let (model_dict, fields_set): (&PyAny, &PyAny) = output.extract(py)?;

fn create_class(&self, py: Python, model_dict: &PyAny, fields_set: &PyAny) -> PyResult<PyObject> {
// based on the following but with the second argument of new_func set to an empty tuple as required
// https://github.com/PyO3/pyo3/blob/d2caa056e9aacc46374139ef491d112cb8af1a25/src/pyclass_init.rs#L35-L77
let args = PyTuple::empty(py);
Expand Down
87 changes: 86 additions & 1 deletion tests/validators/test_model_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def test_not_return_fields_set():
class MyModel:
pass

with pytest.raises(SchemaError, match='model-class inner schema must have "return_fields_set" set to True'):
with pytest.raises(SchemaError, match="model-class inner schema must have 'return_fields_set' set to True"):
SchemaValidator(
{
'type': 'model-class',
Expand Down Expand Up @@ -212,6 +212,7 @@ def __init__(self):
},
}
)
assert re.search(r'revalidate: \w+', repr(v)).group(0) == 'revalidate: false'
m = MyModel()
m2 = v.validate_python(m)
assert isinstance(m, MyModel)
Expand Down Expand Up @@ -243,3 +244,87 @@ def test_internal_error():
)
with pytest.raises(AttributeError, match=re.escape("'int' object has no attribute '__dict__'")):
v.validate_python({'f': 123})


def test_revalidate():
class MyModel:
__slots__ = '__dict__', '__fields_set__'

def __init__(self, a, b, fields_set):
self.field_a = a
self.field_b = b
self.__fields_set__ = fields_set

v = SchemaValidator(
{
'type': 'model-class',
'class_type': MyModel,
'schema': {
'type': 'typed-dict',
'return_fields_set': True,
'from_attributes': True,
'fields': {'field_a': {'schema': {'type': 'str'}}, 'field_b': {'schema': {'type': 'int'}}},
},
'config': {'revalidate_models': True},
}
)
assert re.search(r'revalidate: \w+', repr(v)).group(0) == 'revalidate: true'

m = v.validate_python({'field_a': 'test', 'field_b': 12})
assert isinstance(m, MyModel)
assert m.__dict__ == {'field_a': 'test', 'field_b': 12}
assert m.__fields_set__ == {'field_a', 'field_b'}

m2 = MyModel('x', 42, {'field_a'})
m3 = v.validate_python(m2)
assert isinstance(m3, MyModel)
assert m3 is not m2
assert m3.__dict__ == {'field_a': 'x', 'field_b': 42}
assert m3.__fields_set__ == {'field_a'}

m4 = MyModel('x', 'not int', {'field_a'})
with pytest.raises(ValidationError) as exc_info:
v.validate_python(m4)
assert exc_info.value.errors() == [
{
'kind': 'int_parsing',
'loc': ['field_b'],
'message': 'Value must be a valid integer, unable to parse string as an integer',
'input_value': 'not int',
}
]


def test_revalidate_extra():
class MyModel:
__slots__ = '__dict__', '__fields_set__'

def __init__(self, **kwargs):
self.__dict__.update(kwargs)

v = SchemaValidator(
{
'type': 'model-class',
'class_type': MyModel,
'schema': {
'type': 'typed-dict',
'return_fields_set': True,
'from_attributes': True,
'extra_behavior': 'allow',
'fields': {'field_a': {'schema': {'type': 'str'}}, 'field_b': {'schema': {'type': 'int'}}},
},
'config': {'revalidate_models': True},
}
)

m = v.validate_python({'field_a': 'test', 'field_b': 12, 'more': (1, 2, 3)})
assert isinstance(m, MyModel)
assert m.__dict__ == {'field_a': 'test', 'field_b': 12, 'more': (1, 2, 3)}
assert m.__fields_set__ == {'field_a', 'field_b', 'more'}

m2 = MyModel(field_a='x', field_b=42, another=42.5)
m3 = v.validate_python(m2)
assert isinstance(m3, MyModel)
assert m3 is not m2
assert m3.__dict__ == {'field_a': 'x', 'field_b': 42, 'another': 42.5}
assert m3.__fields_set__ == {'field_a', 'field_b', 'another'}
2 changes: 1 addition & 1 deletion tests/validators/test_typed_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,7 +534,7 @@ def test_aliases_debug():
v = SchemaValidator(
{'type': 'typed-dict', 'fields': {'field_a': {'alias': [['foo', 'bar', 'bat'], ['foo', 3]], 'schema': 'int'}}}
)
assert repr(v).startswith('SchemaValidator(name="typed-dict", validator=Model(')
assert repr(v).startswith('SchemaValidator(name="typed-dict", validator=TypedDict(')
assert 'PathChoices(' in repr(v)


Expand Down