Skip to content

Commit

Permalink
revalidate models (#177)
Browse files Browse the repository at this point in the history
* revaliate models

* fields_set and tests
  • Loading branch information
samuelcolvin committed Jul 18, 2022
1 parent cf04609 commit f29eec1
Show file tree
Hide file tree
Showing 8 changed files with 146 additions and 13 deletions.
1 change: 1 addition & 0 deletions pydantic_core/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class Config(TypedDict, total=False):
typed_dict_populate_by_name: bool # replaces `allow_population_by_field_name` in pydantic v1
# used on typed-dicts and tagged union keys
from_attributes: bool
revalidate_models: bool
# fields related to string fields only
str_max_length: int
str_min_length: int
Expand Down
23 changes: 23 additions & 0 deletions src/build_tools.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ impl<'py> SchemaDict<'py> for PyDict {
}
}

#[cfg_attr(has_no_coverage, no_coverage)]
fn get_as_req<T>(&'py self, key: &str) -> PyResult<T>
where
T: FromPyObject<'py>,
Expand All @@ -40,6 +41,28 @@ impl<'py> SchemaDict<'py> for PyDict {
}
}

impl<'py> SchemaDict<'py> for Option<&PyDict> {
fn get_as<T>(&'py self, key: &str) -> PyResult<Option<T>>
where
T: FromPyObject<'py>,
{
match self {
Some(d) => d.get_as(key),
None => Ok(None),
}
}

fn get_as_req<T>(&'py self, key: &str) -> PyResult<T>
where
T: FromPyObject<'py>,
{
match self {
Some(d) => d.get_as_req(key),
None => py_error!(PyKeyError; "{}", key),
}
}
}

pub fn schema_or_config<'py, T>(
schema: &'py PyDict,
config: Option<&'py PyDict>,
Expand Down
7 changes: 6 additions & 1 deletion src/input/input_abstract.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::fmt;

use pyo3::prelude::*;
use pyo3::types::PyType;
use pyo3::types::{PyString, PyType};

use crate::errors::{InputValue, LocItem, ValResult};
use crate::input::datetime::EitherTime;
Expand Down Expand Up @@ -29,6 +29,11 @@ pub trait Input<'a>: fmt::Debug + ToPyObject {
Ok(false)
}

#[cfg_attr(has_no_coverage, no_coverage)]
fn get_attr(&self, _name: &PyString) -> Option<&PyAny> {
None
}

fn is_instance(&self, _class: &PyType) -> PyResult<bool> {
Ok(false)
}
Expand Down
4 changes: 4 additions & 0 deletions src/input/input_python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ impl<'a> Input<'a> for PyAny {
Ok(self.get_type().eq(class)?)
}

fn get_attr(&self, name: &PyString) -> Option<&PyAny> {
self.getattr(name).ok()
}

fn is_instance(&self, class: &PyType) -> PyResult<bool> {
self.is_instance(class)
}
Expand Down
2 changes: 1 addition & 1 deletion src/validators/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,7 @@ impl<'a> Extra<'a> {
#[enum_dispatch]
pub enum CombinedValidator {
// typed dict e.g. heterogeneous dicts or simply a model
Model(typed_dict::TypedDictValidator),
TypedDict(typed_dict::TypedDictValidator),
// unions
Union(union::UnionValidator),
TaggedUnion(union::TaggedUnionValidator),
Expand Down
33 changes: 24 additions & 9 deletions src/validators/model_class.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,14 @@ use crate::errors::{ErrorKind, ValError, ValResult};
use crate::input::Input;
use crate::recursion_guard::RecursionGuard;

use super::typed_dict::TypedDictValidator;
use super::{build_validator, BuildContext, BuildValidator, CombinedValidator, Extra, Validator};

#[derive(Debug, Clone)]
pub struct ModelClassValidator {
strict: bool,
validator: Box<CombinedValidator>,
revalidate: bool,
validator: TypedDictValidator,
class: Py<PyType>,
name: String,
}
Expand All @@ -36,17 +38,23 @@ impl BuildValidator for ModelClassValidator {

let class: &PyType = schema.get_as_req("class_type")?;
let sub_schema: &PyAny = schema.get_as_req("schema")?;
let (validator, td_schema) = build_validator(sub_schema, config, build_context)?;
let (comb_validator, td_schema) = build_validator(sub_schema, config, build_context)?;

if !td_schema.get_as("return_fields_set")?.unwrap_or(false) {
return py_error!(r#"model-class inner schema must have "return_fields_set" set to True"#);
return py_error!("model-class inner schema must have 'return_fields_set' set to True");
}

let validator = match comb_validator {
CombinedValidator::TypedDict(tdv) => tdv,
_ => return py_error!("Wrong validator type, expected 'typed-dict' validator"),
};

Ok(Self {
// we don't use is_strict here since we don't want validation to be strict in this case if
// `config.strict` is set, only if this specific field is strict
strict: schema.get_as("strict")?.unwrap_or(false),
validator: Box::new(validator),
revalidate: config.get_as("revalidate_models")?.unwrap_or(false),
validator,
class: class.into(),
// Get the class's `__name__`, not using `class.name()` since it uses `__qualname__`
// which is not what we want here
Expand All @@ -67,7 +75,15 @@ impl Validator for ModelClassValidator {
) -> ValResult<'data, PyObject> {
let class = self.class.as_ref(py);
if input.is_type(class)? {
Ok(input.to_object(py))
if self.revalidate {
let fields_set = input.get_attr(intern!(py, "__fields_set__"));
let output = self.validator.validate(py, input, extra, slots, recursion_guard)?;
let (model_dict, validation_fields_set): (&PyAny, &PyAny) = output.extract(py)?;
let fields_set = fields_set.unwrap_or(validation_fields_set);
Ok(self.create_class(py, model_dict, fields_set)?)
} else {
Ok(input.to_object(py))
}
} else if extra.strict.unwrap_or(self.strict) {
Err(ValError::new(
ErrorKind::ModelClassType {
Expand All @@ -77,7 +93,8 @@ impl Validator for ModelClassValidator {
))
} else {
let output = self.validator.validate(py, input, extra, slots, recursion_guard)?;
Ok(self.create_class(py, output)?)
let (model_dict, fields_set): (&PyAny, &PyAny) = output.extract(py)?;
Ok(self.create_class(py, model_dict, fields_set)?)
}
}

Expand All @@ -87,9 +104,7 @@ impl Validator for ModelClassValidator {
}

impl ModelClassValidator {
fn create_class(&self, py: Python, output: PyObject) -> PyResult<PyObject> {
let (model_dict, fields_set): (&PyAny, &PyAny) = output.extract(py)?;

fn create_class(&self, py: Python, model_dict: &PyAny, fields_set: &PyAny) -> PyResult<PyObject> {
// based on the following but with the second argument of new_func set to an empty tuple as required
// https://github.com/PyO3/pyo3/blob/d2caa056e9aacc46374139ef491d112cb8af1a25/src/pyclass_init.rs#L35-L77
let args = PyTuple::empty(py);
Expand Down
87 changes: 86 additions & 1 deletion tests/validators/test_model_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def test_not_return_fields_set():
class MyModel:
pass

with pytest.raises(SchemaError, match='model-class inner schema must have "return_fields_set" set to True'):
with pytest.raises(SchemaError, match="model-class inner schema must have 'return_fields_set' set to True"):
SchemaValidator(
{
'type': 'model-class',
Expand Down Expand Up @@ -212,6 +212,7 @@ def __init__(self):
},
}
)
assert re.search(r'revalidate: \w+', repr(v)).group(0) == 'revalidate: false'
m = MyModel()
m2 = v.validate_python(m)
assert isinstance(m, MyModel)
Expand Down Expand Up @@ -243,3 +244,87 @@ def test_internal_error():
)
with pytest.raises(AttributeError, match=re.escape("'int' object has no attribute '__dict__'")):
v.validate_python({'f': 123})


def test_revalidate():
class MyModel:
__slots__ = '__dict__', '__fields_set__'

def __init__(self, a, b, fields_set):
self.field_a = a
self.field_b = b
self.__fields_set__ = fields_set

v = SchemaValidator(
{
'type': 'model-class',
'class_type': MyModel,
'schema': {
'type': 'typed-dict',
'return_fields_set': True,
'from_attributes': True,
'fields': {'field_a': {'schema': {'type': 'str'}}, 'field_b': {'schema': {'type': 'int'}}},
},
'config': {'revalidate_models': True},
}
)
assert re.search(r'revalidate: \w+', repr(v)).group(0) == 'revalidate: true'

m = v.validate_python({'field_a': 'test', 'field_b': 12})
assert isinstance(m, MyModel)
assert m.__dict__ == {'field_a': 'test', 'field_b': 12}
assert m.__fields_set__ == {'field_a', 'field_b'}

m2 = MyModel('x', 42, {'field_a'})
m3 = v.validate_python(m2)
assert isinstance(m3, MyModel)
assert m3 is not m2
assert m3.__dict__ == {'field_a': 'x', 'field_b': 42}
assert m3.__fields_set__ == {'field_a'}

m4 = MyModel('x', 'not int', {'field_a'})
with pytest.raises(ValidationError) as exc_info:
v.validate_python(m4)
assert exc_info.value.errors() == [
{
'kind': 'int_parsing',
'loc': ['field_b'],
'message': 'Value must be a valid integer, unable to parse string as an integer',
'input_value': 'not int',
}
]


def test_revalidate_extra():
class MyModel:
__slots__ = '__dict__', '__fields_set__'

def __init__(self, **kwargs):
self.__dict__.update(kwargs)

v = SchemaValidator(
{
'type': 'model-class',
'class_type': MyModel,
'schema': {
'type': 'typed-dict',
'return_fields_set': True,
'from_attributes': True,
'extra_behavior': 'allow',
'fields': {'field_a': {'schema': {'type': 'str'}}, 'field_b': {'schema': {'type': 'int'}}},
},
'config': {'revalidate_models': True},
}
)

m = v.validate_python({'field_a': 'test', 'field_b': 12, 'more': (1, 2, 3)})
assert isinstance(m, MyModel)
assert m.__dict__ == {'field_a': 'test', 'field_b': 12, 'more': (1, 2, 3)}
assert m.__fields_set__ == {'field_a', 'field_b', 'more'}

m2 = MyModel(field_a='x', field_b=42, another=42.5)
m3 = v.validate_python(m2)
assert isinstance(m3, MyModel)
assert m3 is not m2
assert m3.__dict__ == {'field_a': 'x', 'field_b': 42, 'another': 42.5}
assert m3.__fields_set__ == {'field_a', 'field_b', 'another'}
2 changes: 1 addition & 1 deletion tests/validators/test_typed_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,7 +534,7 @@ def test_aliases_debug():
v = SchemaValidator(
{'type': 'typed-dict', 'fields': {'field_a': {'alias': [['foo', 'bar', 'bat'], ['foo', 3]], 'schema': 'int'}}}
)
assert repr(v).startswith('SchemaValidator(name="typed-dict", validator=Model(')
assert repr(v).startswith('SchemaValidator(name="typed-dict", validator=TypedDict(')
assert 'PathChoices(' in repr(v)


Expand Down

0 comments on commit f29eec1

Please sign in to comment.