Skip to content

Commit

Permalink
add iterator support
Browse files Browse the repository at this point in the history
  • Loading branch information
PrettyWood committed Aug 1, 2022
1 parent d1f4216 commit acc1483
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 2 deletions.
16 changes: 14 additions & 2 deletions src/input/input_python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ use std::str::from_utf8;
use pyo3::exceptions::PyAttributeError;
use pyo3::prelude::*;
use pyo3::types::{
PyBool, PyByteArray, PyBytes, PyDate, PyDateTime, PyDelta, PyDict, PyFrozenSet, PyInt, PyList, PyMapping,
PySequence, PySet, PyString, PyTime, PyTuple, PyType,
PyBool, PyByteArray, PyBytes, PyDate, PyDateTime, PyDelta, PyDict, PyFrozenSet, PyInt, PyIterator, PyList,
PyMapping, PySequence, PySet, PyString, PyTime, PyTuple, PyType,
};
use pyo3::{intern, AsPyPointer};

Expand Down Expand Up @@ -279,6 +279,9 @@ impl<'a> Input<'a> for PyAny {
Ok(list.into())
} else if let Ok(tuple) = self.cast_as::<PyTuple>() {
Ok(tuple.into())
} else if let Ok(iterator) = self.cast_as::<PyIterator>() {
let tuple = PyTuple::new(self.py(), iterator.iter()?.flatten().collect::<Vec<_>>());
Ok(tuple.into())
} else if let Ok(dict_keys) = self.cast_as::<PyDictKeys>() {
Ok(dict_keys.as_sequence().tuple()?.into())
} else if let Ok(dict_values) = self.cast_as::<PyDictValues>() {
Expand Down Expand Up @@ -313,6 +316,9 @@ impl<'a> Input<'a> for PyAny {
Ok(tuple.into())
} else if let Ok(list) = self.cast_as::<PyList>() {
Ok(list.into())
} else if let Ok(iterator) = self.cast_as::<PyIterator>() {
let tuple = PyTuple::new(self.py(), iterator.iter()?.flatten().collect::<Vec<_>>());
Ok(tuple.into())
} else if let Ok(dict_keys) = self.cast_as::<PyDictKeys>() {
Ok(dict_keys.as_sequence().tuple()?.into())
} else if let Ok(dict_values) = self.cast_as::<PyDictValues>() {
Expand Down Expand Up @@ -351,6 +357,9 @@ impl<'a> Input<'a> for PyAny {
Ok(tuple.into())
} else if let Ok(frozen_set) = self.cast_as::<PyFrozenSet>() {
Ok(frozen_set.into())
} else if let Ok(iterator) = self.cast_as::<PyIterator>() {
let tuple = PyTuple::new(self.py(), iterator.iter()?.flatten().collect::<Vec<_>>());
Ok(tuple.into())
} else if let Ok(dict_keys) = self.cast_as::<PyDictKeys>() {
Ok(dict_keys.as_sequence().tuple()?.into())
} else if let Ok(dict_values) = self.cast_as::<PyDictValues>() {
Expand Down Expand Up @@ -393,6 +402,9 @@ impl<'a> Input<'a> for PyAny {
Ok(list.into())
} else if let Ok(tuple) = self.cast_as::<PyTuple>() {
Ok(tuple.into())
} else if let Ok(iterator) = self.cast_as::<PyIterator>() {
let tuple = PyTuple::new(self.py(), iterator.iter()?.flatten().collect::<Vec<_>>());
Ok(tuple.into())
} else if let Ok(dict_keys) = self.cast_as::<PyDictKeys>() {
Ok(dict_keys.as_sequence().tuple()?.into())
} else if let Ok(dict_values) = self.cast_as::<PyDictValues>() {
Expand Down
1 change: 1 addition & 0 deletions tests/validators/test_frozenset.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def test_frozenset_no_validators_both(py_and_json: PyAndJson, input_value, expec
platform.python_implementation() == 'PyPy', reason='dict views not implemented in pyo3 for pypy'
),
),
((x for x in [1, 2, '3']), frozenset({1, 2, 3})),
({'abc'}, Err('0\n Input should be a valid integer')),
({1, 2, 'wrong'}, Err('Input should be a valid integer')),
({1: 2}, Err('1 validation error for frozenset[int]\n Input should be a valid frozenset')),
Expand Down
1 change: 1 addition & 0 deletions tests/validators/test_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def test_list_strict():
platform.python_implementation() == 'PyPy', reason='dict views not implemented in pyo3 for pypy'
),
),
((x for x in [1, 2, '3']), [1, 2, 3]),
],
)
def test_list_int(input_value, expected):
Expand Down
1 change: 1 addition & 0 deletions tests/validators/test_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def test_frozenset_no_validators_both(py_and_json: PyAndJson, input_value, expec
platform.python_implementation() == 'PyPy', reason='dict views not implemented in pyo3 for pypy'
),
),
((x for x in [1, 2, '3']), {1, 2, 3}),
({'abc'}, Err('0\n Input should be a valid integer')),
({1: 2}, Err('1 validation error for set[int]\n Input should be a valid set')),
('abc', Err('Input should be a valid set')),
Expand Down
11 changes: 11 additions & 0 deletions tests/validators/test_tuple.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,17 @@ def test_tuple_validate(input_value, expected, mode, items):
assert v.validate_python(input_value) == expected


# Since `test_tuple_validate` is parametrized above, the generator is consumed
# on the first test run. This is a workaround to make sure the generator is
# always recreated.
@pytest.mark.parametrize(
'mode,items', [('variable', {'type': 'int'}), ('positional', [{'type': 'int'}, {'type': 'int'}, {'type': 'int'}])]
)
def test_tuple_validate_iterator(mode, items):
v = SchemaValidator({'type': 'tuple', 'mode': mode, 'items_schema': items})
assert v.validate_python((x for x in [1, 2, '3'])) == (1, 2, 3)


@pytest.mark.parametrize(
'input_value,index',
[
Expand Down

0 comments on commit acc1483

Please sign in to comment.