Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

changes to some coercions #208

Merged
merged 10 commits into from
Aug 3, 2022
4 changes: 4 additions & 0 deletions src/errors/kinds.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,10 @@ pub enum ErrorKind {
error: String,
},
// ---------------------
// generic list-list errors
#[strum(message = "Error iterating over object")]
IterationError,
// ---------------------
// list errors
#[strum(message = "Input should be a valid list/array")]
ListType,
Expand Down
28 changes: 28 additions & 0 deletions src/input/_pyo3_dict.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// TODO: remove this file once a new pyo3 version is released
// with https://github.com/PyO3/pyo3/pull/2358

use pyo3::{ffi, pyobject_native_type_core, PyAny};

/// Represents a Python `dict_keys`.
#[cfg(not(PyPy))]
#[repr(transparent)]
pub struct PyDictKeys(PyAny);

#[cfg(not(PyPy))]
pyobject_native_type_core!(
PyDictKeys,
ffi::PyDictKeys_Type,
#checkfunction=ffi::PyDictKeys_Check
);

/// Represents a Python `dict_values`.
#[cfg(not(PyPy))]
#[repr(transparent)]
pub struct PyDictValues(PyAny);

#[cfg(not(PyPy))]
pyobject_native_type_core!(
PyDictValues,
ffi::PyDictValues_Type,
#checkfunction=ffi::PyDictValues_Check
);
117 changes: 107 additions & 10 deletions src/input/input_python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@ use std::str::from_utf8;
use pyo3::exceptions::PyAttributeError;
use pyo3::prelude::*;
use pyo3::types::{
PyBool, PyByteArray, PyBytes, PyDate, PyDateTime, PyDelta, PyDict, PyFrozenSet, PyList, PyMapping, PySequence,
PySet, PyString, PyTime, PyTuple, PyType,
PyBool, PyByteArray, PyBytes, PyDate, PyDateTime, PyDelta, PyDict, PyFrozenSet, PyIterator, PyList, PyMapping,
PySequence, PySet, PyString, PyTime, PyTuple, PyType,
};
use pyo3::{intern, AsPyPointer};

use crate::errors::{py_err_string, ErrorKind, InputValue, LocItem, ValError, ValResult};

#[cfg(not(PyPy))]
use super::_pyo3_dict::{PyDictKeys, PyDictValues};
use super::datetime::{
bytes_as_date, bytes_as_datetime, bytes_as_time, bytes_as_timedelta, date_as_datetime, float_as_datetime,
float_as_duration, float_as_time, int_as_datetime, int_as_duration, int_as_time, EitherDate, EitherDateTime,
Expand All @@ -22,6 +24,25 @@ use super::{
GenericMapping, Input, PyArgs,
};

#[cfg(not(PyPy))]
macro_rules! extract_gen_dict {
($type:ty, $obj:ident) => {{
let map_err = |_| ValError::new(ErrorKind::IterationError, $obj);
if let Ok(iterator) = $obj.cast_as::<PyIterator>() {
let vec = iterator.collect::<PyResult<Vec<_>>>().map_err(map_err)?;
Some(<$type>::new($obj.py(), vec))
} else if let Ok(dict_keys) = $obj.cast_as::<PyDictKeys>() {
let vec = dict_keys.iter()?.collect::<PyResult<Vec<_>>>().map_err(map_err)?;
Some(<$type>::new($obj.py(), vec))
} else if let Ok(dict_values) = $obj.cast_as::<PyDictValues>() {
let vec = dict_values.iter()?.collect::<PyResult<Vec<_>>>().map_err(map_err)?;
Some(<$type>::new($obj.py(), vec))
} else {
None
}
}};
}

impl<'a> Input<'a> for PyAny {
fn as_loc_item(&self) -> LocItem {
if let Ok(py_str) = self.cast_as::<PyString>() {
Expand Down Expand Up @@ -261,15 +282,30 @@ impl<'a> Input<'a> for PyAny {
}
}

#[cfg(not(PyPy))]
fn lax_list(&'a self) -> ValResult<GenericListLike<'a>> {
if let Ok(list) = self.cast_as::<PyList>() {
Ok(list.into())
} else if let Ok(tuple) = self.cast_as::<PyTuple>() {
Ok(tuple.into())
} else if let Ok(set) = self.cast_as::<PySet>() {
Ok(set.into())
} else if let Ok(frozen_set) = self.cast_as::<PyFrozenSet>() {
Ok(frozen_set.into())
} else if let Some(list) = extract_gen_dict!(PyList, self) {
Ok(list.into())
} else {
Err(ValError::new(ErrorKind::ListType, self))
}
}

#[cfg(PyPy)]
fn lax_list(&'a self) -> ValResult<GenericListLike<'a>> {
if let Ok(list) = self.cast_as::<PyList>() {
Ok(list.into())
} else if let Ok(tuple) = self.cast_as::<PyTuple>() {
Ok(tuple.into())
} else if let Ok(iterator) = self.cast_as::<PyIterator>() {
let vec = iterator
.collect::<PyResult<Vec<_>>>()
.map_err(|_| ValError::new(ErrorKind::IterationError, self))?;
Ok(PyList::new(self.py(), vec).into())
} else {
Err(ValError::new(ErrorKind::ListType, self))
}
Expand All @@ -283,15 +319,30 @@ impl<'a> Input<'a> for PyAny {
}
}

#[cfg(not(PyPy))]
fn lax_tuple(&'a self) -> ValResult<GenericListLike<'a>> {
if let Ok(tuple) = self.cast_as::<PyTuple>() {
Ok(tuple.into())
} else if let Ok(list) = self.cast_as::<PyList>() {
Ok(list.into())
} else if let Ok(set) = self.cast_as::<PySet>() {
Ok(set.into())
} else if let Ok(frozen_set) = self.cast_as::<PyFrozenSet>() {
Ok(frozen_set.into())
} else if let Some(tuple) = extract_gen_dict!(PyTuple, self) {
Ok(tuple.into())
} else {
Err(ValError::new(ErrorKind::TupleType, self))
}
}

#[cfg(PyPy)]
fn lax_tuple(&'a self) -> ValResult<GenericListLike<'a>> {
if let Ok(tuple) = self.cast_as::<PyTuple>() {
Ok(tuple.into())
} else if let Ok(list) = self.cast_as::<PyList>() {
Ok(list.into())
} else if let Ok(iterator) = self.cast_as::<PyIterator>() {
let vec = iterator
.collect::<PyResult<Vec<_>>>()
.map_err(|_| ValError::new(ErrorKind::IterationError, self))?;
Ok(PyTuple::new(self.py(), vec).into())
} else {
Err(ValError::new(ErrorKind::TupleType, self))
}
Expand All @@ -305,6 +356,24 @@ impl<'a> Input<'a> for PyAny {
}
}

#[cfg(not(PyPy))]
fn lax_set(&'a self) -> ValResult<GenericListLike<'a>> {
if let Ok(set) = self.cast_as::<PySet>() {
Ok(set.into())
} else if let Ok(list) = self.cast_as::<PyList>() {
Ok(list.into())
} else if let Ok(tuple) = self.cast_as::<PyTuple>() {
Ok(tuple.into())
} else if let Ok(frozen_set) = self.cast_as::<PyFrozenSet>() {
Ok(frozen_set.into())
} else if let Some(tuple) = extract_gen_dict!(PyTuple, self) {
Ok(tuple.into())
} else {
Err(ValError::new(ErrorKind::SetType, self))
}
}

#[cfg(PyPy)]
fn lax_set(&'a self) -> ValResult<GenericListLike<'a>> {
if let Ok(set) = self.cast_as::<PySet>() {
Ok(set.into())
Expand All @@ -314,6 +383,11 @@ impl<'a> Input<'a> for PyAny {
Ok(tuple.into())
} else if let Ok(frozen_set) = self.cast_as::<PyFrozenSet>() {
Ok(frozen_set.into())
} else if let Ok(iterator) = self.cast_as::<PyIterator>() {
let vec = iterator
.collect::<PyResult<Vec<_>>>()
.map_err(|_| ValError::new(ErrorKind::IterationError, self))?;
Ok(PyTuple::new(self.py(), vec).into())
} else {
Err(ValError::new(ErrorKind::SetType, self))
}
Expand All @@ -327,6 +401,24 @@ impl<'a> Input<'a> for PyAny {
}
}

#[cfg(not(PyPy))]
fn lax_frozenset(&'a self) -> ValResult<GenericListLike<'a>> {
if let Ok(frozen_set) = self.cast_as::<PyFrozenSet>() {
Ok(frozen_set.into())
} else if let Ok(set) = self.cast_as::<PySet>() {
Ok(set.into())
} else if let Ok(list) = self.cast_as::<PyList>() {
Ok(list.into())
} else if let Ok(tuple) = self.cast_as::<PyTuple>() {
Ok(tuple.into())
} else if let Some(tuple) = extract_gen_dict!(PyTuple, self) {
Ok(tuple.into())
} else {
Err(ValError::new(ErrorKind::FrozenSetType, self))
}
}

#[cfg(PyPy)]
fn lax_frozenset(&'a self) -> ValResult<GenericListLike<'a>> {
if let Ok(frozen_set) = self.cast_as::<PyFrozenSet>() {
Ok(frozen_set.into())
Expand All @@ -336,6 +428,11 @@ impl<'a> Input<'a> for PyAny {
Ok(list.into())
} else if let Ok(tuple) = self.cast_as::<PyTuple>() {
Ok(tuple.into())
} else if let Ok(iterator) = self.cast_as::<PyIterator>() {
let vec = iterator
.collect::<PyResult<Vec<_>>>()
.map_err(|_| ValError::new(ErrorKind::IterationError, self))?;
Ok(PyTuple::new(self.py(), vec).into())
} else {
Err(ValError::new(ErrorKind::FrozenSetType, self))
}
Expand Down
2 changes: 2 additions & 0 deletions src/input/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use pyo3::prelude::*;

#[cfg(not(PyPy))]
mod _pyo3_dict;
mod datetime;
mod input_abstract;
mod input_json;
Expand Down
51 changes: 44 additions & 7 deletions tests/validators/test_frozenset.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import platform
import re
from typing import Any, Dict

Expand Down Expand Up @@ -63,19 +64,35 @@ def test_frozenset_no_validators_both(py_and_json: PyAndJson, input_value, expec
@pytest.mark.parametrize(
'input_value,expected',
[
({1, 2, 3}, {1, 2, 3}),
({1, 2, 3}, frozenset({1, 2, 3})),
(frozenset(), frozenset()),
([1, 2, 3, 2, 3], {1, 2, 3}),
([1, 2, 3, 2, 3], frozenset({1, 2, 3})),
([], frozenset()),
((1, 2, 3, 2, 3), {1, 2, 3}),
((1, 2, 3, 2, 3), frozenset({1, 2, 3})),
((), frozenset()),
(frozenset([1, 2, 3, 2, 3]), {1, 2, 3}),
(frozenset([1, 2, 3, 2, 3]), frozenset({1, 2, 3})),
pytest.param(
{1: 10, 2: 20, '3': '30'}.keys(),
frozenset({1, 2, 3}),
marks=pytest.mark.skipif(
platform.python_implementation() == 'PyPy', reason='dict views not implemented in pyo3 for pypy'
),
),
pytest.param(
{1: 10, 2: 20, '3': '30'}.values(),
frozenset({10, 20, 30}),
marks=pytest.mark.skipif(
platform.python_implementation() == 'PyPy', reason='dict views not implemented in pyo3 for pypy'
),
),
({1: 10, 2: 20, '3': '30'}, Err('Input should be a valid frozenset [kind=frozen_set_type,')),
# https://github.com/samuelcolvin/pydantic-core/issues/211
({1: 10, 2: 20, '3': '30'}.items(), Err('Input should be a valid frozenset [kind=frozen_set_type,')),
((x for x in [1, 2, '3']), frozenset({1, 2, 3})),
({'abc'}, Err('0\n Input should be a valid integer')),
({1, 2, 'wrong'}, Err('Input should be a valid integer')),
({1: 2}, Err('1 validation error for frozenset[int]\n Input should be a valid frozenset')),
('abc', Err('Input should be a valid frozenset')),
# Technically correct, but does anyone actually need this? I think needs a new type in pyo3
pytest.param({1: 10, 2: 20, 3: 30}.keys(), {1, 2, 3}, marks=pytest.mark.xfail(raises=ValidationError)),
],
)
def test_frozenset_ints_python(input_value, expected):
Expand All @@ -89,7 +106,10 @@ def test_frozenset_ints_python(input_value, expected):
assert isinstance(output, frozenset)


@pytest.mark.parametrize('input_value,expected', [([1, 2.5, '3'], {1, 2.5, '3'}), ([(1, 2), (3, 4)], {(1, 2), (3, 4)})])
@pytest.mark.parametrize(
'input_value,expected',
[(frozenset([1, 2.5, '3']), {1, 2.5, '3'}), ([1, 2.5, '3'], {1, 2.5, '3'}), ([(1, 2), (3, 4)], {(1, 2), (3, 4)})],
)
def test_frozenset_no_validators_python(input_value, expected):
v = SchemaValidator({'type': 'frozenset'})
output = v.validate_python(input_value)
Expand Down Expand Up @@ -216,3 +236,20 @@ def test_repr():
'strict:true,item_validator:None,size_range:Some((Some(42),None)),name:"frozenset[any]"'
'}))'
)


def test_generator_error():
def gen(error: bool):
yield 1
yield 2
if error:
raise RuntimeError('error')
yield 3

v = SchemaValidator({'type': 'frozenset', 'items_schema': 'int'})
r = v.validate_python(gen(False))
assert r == {1, 2, 3}
assert isinstance(r, frozenset)

with pytest.raises(ValidationError, match=r'Error iterating over object \[kind=iteration_error,'):
v.validate_python(gen(True))