Skip to content

Commit

Permalink
changes to some coercions (#208)
Browse files Browse the repository at this point in the history
* stop coercing `set / frozenset` to `list / tuple`

* add `dict_key` and `dict_value`

* add iterator support

* more tests

* use PyList directly for lax_list

* catch errors in generator evaluation

* fix mypy and more tests

* remove unused dict_items code

* make format

Co-authored-by: Samuel Colvin <s@muelcolvin.com>
  • Loading branch information
PrettyWood and samuelcolvin committed Aug 3, 2022
1 parent 42a465a commit cfd5da7
Show file tree
Hide file tree
Showing 8 changed files with 332 additions and 43 deletions.
4 changes: 4 additions & 0 deletions src/errors/kinds.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,10 @@ pub enum ErrorKind {
error: String,
},
// ---------------------
// generic list-list errors
#[strum(message = "Error iterating over object")]
IterationError,
// ---------------------
// list errors
#[strum(message = "Input should be a valid list/array")]
ListType,
Expand Down
28 changes: 28 additions & 0 deletions src/input/_pyo3_dict.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// TODO: remove this file once a new pyo3 version is released
// with https://github.com/PyO3/pyo3/pull/2358

use pyo3::{ffi, pyobject_native_type_core, PyAny};

/// Represents a Python `dict_keys`.
#[cfg(not(PyPy))]
#[repr(transparent)]
pub struct PyDictKeys(PyAny);

#[cfg(not(PyPy))]
pyobject_native_type_core!(
PyDictKeys,
ffi::PyDictKeys_Type,
#checkfunction=ffi::PyDictKeys_Check
);

/// Represents a Python `dict_values`.
#[cfg(not(PyPy))]
#[repr(transparent)]
pub struct PyDictValues(PyAny);

#[cfg(not(PyPy))]
pyobject_native_type_core!(
PyDictValues,
ffi::PyDictValues_Type,
#checkfunction=ffi::PyDictValues_Check
);
117 changes: 107 additions & 10 deletions src/input/input_python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@ use std::str::from_utf8;
use pyo3::exceptions::PyAttributeError;
use pyo3::prelude::*;
use pyo3::types::{
PyBool, PyByteArray, PyBytes, PyDate, PyDateTime, PyDelta, PyDict, PyFrozenSet, PyList, PyMapping, PySequence,
PySet, PyString, PyTime, PyTuple, PyType,
PyBool, PyByteArray, PyBytes, PyDate, PyDateTime, PyDelta, PyDict, PyFrozenSet, PyIterator, PyList, PyMapping,
PySequence, PySet, PyString, PyTime, PyTuple, PyType,
};
use pyo3::{intern, AsPyPointer};

use crate::errors::{py_err_string, ErrorKind, InputValue, LocItem, ValError, ValResult};

#[cfg(not(PyPy))]
use super::_pyo3_dict::{PyDictKeys, PyDictValues};
use super::datetime::{
bytes_as_date, bytes_as_datetime, bytes_as_time, bytes_as_timedelta, date_as_datetime, float_as_datetime,
float_as_duration, float_as_time, int_as_datetime, int_as_duration, int_as_time, EitherDate, EitherDateTime,
Expand All @@ -22,6 +24,25 @@ use super::{
GenericMapping, Input, PyArgs,
};

#[cfg(not(PyPy))]
macro_rules! extract_gen_dict {
($type:ty, $obj:ident) => {{
let map_err = |_| ValError::new(ErrorKind::IterationError, $obj);
if let Ok(iterator) = $obj.cast_as::<PyIterator>() {
let vec = iterator.collect::<PyResult<Vec<_>>>().map_err(map_err)?;
Some(<$type>::new($obj.py(), vec))
} else if let Ok(dict_keys) = $obj.cast_as::<PyDictKeys>() {
let vec = dict_keys.iter()?.collect::<PyResult<Vec<_>>>().map_err(map_err)?;
Some(<$type>::new($obj.py(), vec))
} else if let Ok(dict_values) = $obj.cast_as::<PyDictValues>() {
let vec = dict_values.iter()?.collect::<PyResult<Vec<_>>>().map_err(map_err)?;
Some(<$type>::new($obj.py(), vec))
} else {
None
}
}};
}

impl<'a> Input<'a> for PyAny {
fn as_loc_item(&self) -> LocItem {
if let Ok(py_str) = self.cast_as::<PyString>() {
Expand Down Expand Up @@ -261,15 +282,30 @@ impl<'a> Input<'a> for PyAny {
}
}

#[cfg(not(PyPy))]
fn lax_list(&'a self) -> ValResult<GenericListLike<'a>> {
if let Ok(list) = self.cast_as::<PyList>() {
Ok(list.into())
} else if let Ok(tuple) = self.cast_as::<PyTuple>() {
Ok(tuple.into())
} else if let Ok(set) = self.cast_as::<PySet>() {
Ok(set.into())
} else if let Ok(frozen_set) = self.cast_as::<PyFrozenSet>() {
Ok(frozen_set.into())
} else if let Some(list) = extract_gen_dict!(PyList, self) {
Ok(list.into())
} else {
Err(ValError::new(ErrorKind::ListType, self))
}
}

#[cfg(PyPy)]
fn lax_list(&'a self) -> ValResult<GenericListLike<'a>> {
if let Ok(list) = self.cast_as::<PyList>() {
Ok(list.into())
} else if let Ok(tuple) = self.cast_as::<PyTuple>() {
Ok(tuple.into())
} else if let Ok(iterator) = self.cast_as::<PyIterator>() {
let vec = iterator
.collect::<PyResult<Vec<_>>>()
.map_err(|_| ValError::new(ErrorKind::IterationError, self))?;
Ok(PyList::new(self.py(), vec).into())
} else {
Err(ValError::new(ErrorKind::ListType, self))
}
Expand All @@ -283,15 +319,30 @@ impl<'a> Input<'a> for PyAny {
}
}

#[cfg(not(PyPy))]
fn lax_tuple(&'a self) -> ValResult<GenericListLike<'a>> {
if let Ok(tuple) = self.cast_as::<PyTuple>() {
Ok(tuple.into())
} else if let Ok(list) = self.cast_as::<PyList>() {
Ok(list.into())
} else if let Ok(set) = self.cast_as::<PySet>() {
Ok(set.into())
} else if let Ok(frozen_set) = self.cast_as::<PyFrozenSet>() {
Ok(frozen_set.into())
} else if let Some(tuple) = extract_gen_dict!(PyTuple, self) {
Ok(tuple.into())
} else {
Err(ValError::new(ErrorKind::TupleType, self))
}
}

#[cfg(PyPy)]
fn lax_tuple(&'a self) -> ValResult<GenericListLike<'a>> {
if let Ok(tuple) = self.cast_as::<PyTuple>() {
Ok(tuple.into())
} else if let Ok(list) = self.cast_as::<PyList>() {
Ok(list.into())
} else if let Ok(iterator) = self.cast_as::<PyIterator>() {
let vec = iterator
.collect::<PyResult<Vec<_>>>()
.map_err(|_| ValError::new(ErrorKind::IterationError, self))?;
Ok(PyTuple::new(self.py(), vec).into())
} else {
Err(ValError::new(ErrorKind::TupleType, self))
}
Expand All @@ -305,6 +356,24 @@ impl<'a> Input<'a> for PyAny {
}
}

#[cfg(not(PyPy))]
fn lax_set(&'a self) -> ValResult<GenericListLike<'a>> {
if let Ok(set) = self.cast_as::<PySet>() {
Ok(set.into())
} else if let Ok(list) = self.cast_as::<PyList>() {
Ok(list.into())
} else if let Ok(tuple) = self.cast_as::<PyTuple>() {
Ok(tuple.into())
} else if let Ok(frozen_set) = self.cast_as::<PyFrozenSet>() {
Ok(frozen_set.into())
} else if let Some(tuple) = extract_gen_dict!(PyTuple, self) {
Ok(tuple.into())
} else {
Err(ValError::new(ErrorKind::SetType, self))
}
}

#[cfg(PyPy)]
fn lax_set(&'a self) -> ValResult<GenericListLike<'a>> {
if let Ok(set) = self.cast_as::<PySet>() {
Ok(set.into())
Expand All @@ -314,6 +383,11 @@ impl<'a> Input<'a> for PyAny {
Ok(tuple.into())
} else if let Ok(frozen_set) = self.cast_as::<PyFrozenSet>() {
Ok(frozen_set.into())
} else if let Ok(iterator) = self.cast_as::<PyIterator>() {
let vec = iterator
.collect::<PyResult<Vec<_>>>()
.map_err(|_| ValError::new(ErrorKind::IterationError, self))?;
Ok(PyTuple::new(self.py(), vec).into())
} else {
Err(ValError::new(ErrorKind::SetType, self))
}
Expand All @@ -327,6 +401,24 @@ impl<'a> Input<'a> for PyAny {
}
}

#[cfg(not(PyPy))]
fn lax_frozenset(&'a self) -> ValResult<GenericListLike<'a>> {
if let Ok(frozen_set) = self.cast_as::<PyFrozenSet>() {
Ok(frozen_set.into())
} else if let Ok(set) = self.cast_as::<PySet>() {
Ok(set.into())
} else if let Ok(list) = self.cast_as::<PyList>() {
Ok(list.into())
} else if let Ok(tuple) = self.cast_as::<PyTuple>() {
Ok(tuple.into())
} else if let Some(tuple) = extract_gen_dict!(PyTuple, self) {
Ok(tuple.into())
} else {
Err(ValError::new(ErrorKind::FrozenSetType, self))
}
}

#[cfg(PyPy)]
fn lax_frozenset(&'a self) -> ValResult<GenericListLike<'a>> {
if let Ok(frozen_set) = self.cast_as::<PyFrozenSet>() {
Ok(frozen_set.into())
Expand All @@ -336,6 +428,11 @@ impl<'a> Input<'a> for PyAny {
Ok(list.into())
} else if let Ok(tuple) = self.cast_as::<PyTuple>() {
Ok(tuple.into())
} else if let Ok(iterator) = self.cast_as::<PyIterator>() {
let vec = iterator
.collect::<PyResult<Vec<_>>>()
.map_err(|_| ValError::new(ErrorKind::IterationError, self))?;
Ok(PyTuple::new(self.py(), vec).into())
} else {
Err(ValError::new(ErrorKind::FrozenSetType, self))
}
Expand Down
2 changes: 2 additions & 0 deletions src/input/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use pyo3::prelude::*;

#[cfg(not(PyPy))]
mod _pyo3_dict;
mod datetime;
mod input_abstract;
mod input_json;
Expand Down
51 changes: 44 additions & 7 deletions tests/validators/test_frozenset.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import platform
import re
from typing import Any, Dict

Expand Down Expand Up @@ -63,19 +64,35 @@ def test_frozenset_no_validators_both(py_and_json: PyAndJson, input_value, expec
@pytest.mark.parametrize(
'input_value,expected',
[
({1, 2, 3}, {1, 2, 3}),
({1, 2, 3}, frozenset({1, 2, 3})),
(frozenset(), frozenset()),
([1, 2, 3, 2, 3], {1, 2, 3}),
([1, 2, 3, 2, 3], frozenset({1, 2, 3})),
([], frozenset()),
((1, 2, 3, 2, 3), {1, 2, 3}),
((1, 2, 3, 2, 3), frozenset({1, 2, 3})),
((), frozenset()),
(frozenset([1, 2, 3, 2, 3]), {1, 2, 3}),
(frozenset([1, 2, 3, 2, 3]), frozenset({1, 2, 3})),
pytest.param(
{1: 10, 2: 20, '3': '30'}.keys(),
frozenset({1, 2, 3}),
marks=pytest.mark.skipif(
platform.python_implementation() == 'PyPy', reason='dict views not implemented in pyo3 for pypy'
),
),
pytest.param(
{1: 10, 2: 20, '3': '30'}.values(),
frozenset({10, 20, 30}),
marks=pytest.mark.skipif(
platform.python_implementation() == 'PyPy', reason='dict views not implemented in pyo3 for pypy'
),
),
({1: 10, 2: 20, '3': '30'}, Err('Input should be a valid frozenset [kind=frozen_set_type,')),
# https://github.com/samuelcolvin/pydantic-core/issues/211
({1: 10, 2: 20, '3': '30'}.items(), Err('Input should be a valid frozenset [kind=frozen_set_type,')),
((x for x in [1, 2, '3']), frozenset({1, 2, 3})),
({'abc'}, Err('0\n Input should be a valid integer')),
({1, 2, 'wrong'}, Err('Input should be a valid integer')),
({1: 2}, Err('1 validation error for frozenset[int]\n Input should be a valid frozenset')),
('abc', Err('Input should be a valid frozenset')),
# Technically correct, but does anyone actually need this? I think needs a new type in pyo3
pytest.param({1: 10, 2: 20, 3: 30}.keys(), {1, 2, 3}, marks=pytest.mark.xfail(raises=ValidationError)),
],
)
def test_frozenset_ints_python(input_value, expected):
Expand All @@ -89,7 +106,10 @@ def test_frozenset_ints_python(input_value, expected):
assert isinstance(output, frozenset)


@pytest.mark.parametrize('input_value,expected', [([1, 2.5, '3'], {1, 2.5, '3'}), ([(1, 2), (3, 4)], {(1, 2), (3, 4)})])
@pytest.mark.parametrize(
'input_value,expected',
[(frozenset([1, 2.5, '3']), {1, 2.5, '3'}), ([1, 2.5, '3'], {1, 2.5, '3'}), ([(1, 2), (3, 4)], {(1, 2), (3, 4)})],
)
def test_frozenset_no_validators_python(input_value, expected):
v = SchemaValidator({'type': 'frozenset'})
output = v.validate_python(input_value)
Expand Down Expand Up @@ -216,3 +236,20 @@ def test_repr():
'strict:true,item_validator:None,size_range:Some((Some(42),None)),name:"frozenset[any]"'
'}))'
)


def test_generator_error():
def gen(error: bool):
yield 1
yield 2
if error:
raise RuntimeError('error')
yield 3

v = SchemaValidator({'type': 'frozenset', 'items_schema': 'int'})
r = v.validate_python(gen(False))
assert r == {1, 2, 3}
assert isinstance(r, frozenset)

with pytest.raises(ValidationError, match=r'Error iterating over object \[kind=iteration_error,'):
v.validate_python(gen(True))

0 comments on commit cfd5da7

Please sign in to comment.