diff --git a/src/conversions/array.rs b/src/conversions/array.rs index 3071d147cd3..3dbe855cef1 100644 --- a/src/conversions/array.rs +++ b/src/conversions/array.rs @@ -3,70 +3,6 @@ use crate::{ ToPyObject, }; -#[cfg(not(min_const_generics))] -macro_rules! array_impls { - ($($N:expr),+) => { - $( - impl IntoPy for [T; $N] - where - T: ToPyObject - { - fn into_py(self, py: Python) -> PyObject { - self.as_ref().to_object(py) - } - } - - impl<'a, T> FromPyObject<'a> for [T; $N] - where - T: Copy + Default + FromPyObject<'a>, - { - #[cfg(not(feature = "nightly"))] - fn extract(obj: &'a PyAny) -> PyResult { - let mut array = [T::default(); $N]; - _extract_sequence_into_slice(obj, &mut array)?; - Ok(array) - } - - #[cfg(feature = "nightly")] - default fn extract(obj: &'a PyAny) -> PyResult { - let mut array = [T::default(); $N]; - _extract_sequence_into_slice(obj, &mut array)?; - Ok(array) - } - } - - #[cfg(feature = "nightly")] - impl<'source, T> FromPyObject<'source> for [T; $N] - where - for<'a> T: Default + FromPyObject<'a> + crate::buffer::Element, - { - fn extract(obj: &'source PyAny) -> PyResult { - let mut array = [T::default(); $N]; - // first try buffer protocol - if unsafe { crate::ffi::PyObject_CheckBuffer(obj.as_ptr()) } == 1 { - if let Ok(buf) = crate::buffer::PyBuffer::get(obj) { - if buf.dimensions() == 1 && buf.copy_to_slice(obj.py(), &mut array).is_ok() { - buf.release(obj.py()); - return Ok(array); - } - buf.release(obj.py()); - } - } - // fall back to sequence protocol - _extract_sequence_into_slice(obj, &mut array)?; - Ok(array) - } - } - )+ - } -} - -#[cfg(not(min_const_generics))] -array_impls!( - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, - 26, 27, 28, 29, 30, 31, 32 -); - #[cfg(min_const_generics)] impl IntoPy for [T; N] where @@ -100,10 +36,10 @@ where { fn extract(obj: &'source PyAny) -> PyResult { use crate::{AsPyPointer, PyNativeType}; - let mut array = [T::default(); N]; // first try buffer protocol if unsafe { crate::ffi::PyObject_CheckBuffer(obj.as_ptr()) } == 1 { if let Ok(buf) = crate::buffer::PyBuffer::get(obj) { + let mut array = [T::default(); N]; if buf.dimensions() == 1 && buf.copy_to_slice(obj.py(), &mut array).is_ok() { buf.release(obj.py()); return Ok(array); @@ -111,9 +47,7 @@ where buf.release(obj.py()); } } - // fall back to sequence protocol - _extract_sequence_into_slice(obj, &mut array)?; - Ok(array) + create_array_from_obj(obj) } } @@ -123,12 +57,11 @@ where T: FromPyObject<'s>, { let seq = ::try_from(obj)?; - let expected_len = seq.len()? as usize; - array_try_from_fn(|idx| { - seq.get_item(idx as isize) - .map_err(|_| invalid_sequence_length(expected_len, idx + 1))? - .extract::() - }) + let seq_len = seq.len()? as usize; + if seq_len != N { + return Err(invalid_sequence_length(N, seq_len)); + } + array_try_from_fn(|idx| seq.get_item(idx as isize).and_then(PyAny::extract)) } // TODO use std::array::try_from_fn, if that stabilises: @@ -174,7 +107,72 @@ where } } -fn _extract_sequence_into_slice<'s, T>(obj: &'s PyAny, slice: &mut [T]) -> PyResult<()> +#[cfg(not(min_const_generics))] +macro_rules! array_impls { + ($($N:expr),+) => { + $( + impl IntoPy for [T; $N] + where + T: ToPyObject + { + fn into_py(self, py: Python) -> PyObject { + self.as_ref().to_object(py) + } + } + + impl<'a, T> FromPyObject<'a> for [T; $N] + where + T: Copy + Default + FromPyObject<'a>, + { + #[cfg(not(feature = "nightly"))] + fn extract(obj: &'a PyAny) -> PyResult { + let mut array = [T::default(); $N]; + extract_sequence_into_slice(obj, &mut array)?; + Ok(array) + } + + #[cfg(feature = "nightly")] + default fn extract(obj: &'a PyAny) -> PyResult { + let mut array = [T::default(); $N]; + extract_sequence_into_slice(obj, &mut array)?; + Ok(array) + } + } + + #[cfg(feature = "nightly")] + impl<'source, T> FromPyObject<'source> for [T; $N] + where + for<'a> T: Default + FromPyObject<'a> + crate::buffer::Element, + { + fn extract(obj: &'source PyAny) -> PyResult { + let mut array = [T::default(); $N]; + // first try buffer protocol + if unsafe { crate::ffi::PyObject_CheckBuffer(obj.as_ptr()) } == 1 { + if let Ok(buf) = crate::buffer::PyBuffer::get(obj) { + if buf.dimensions() == 1 && buf.copy_to_slice(obj.py(), &mut array).is_ok() { + buf.release(obj.py()); + return Ok(array); + } + buf.release(obj.py()); + } + } + // fall back to sequence protocol + extract_sequence_into_slice(obj, &mut array)?; + Ok(array) + } + } + )+ + } +} + +#[cfg(not(min_const_generics))] +array_impls!( + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, + 26, 27, 28, 29, 30, 31, 32 +); + +#[cfg(not(min_const_generics))] +fn extract_sequence_into_slice<'s, T>(obj: &'s PyAny, slice: &mut [T]) -> PyResult<()> where T: FromPyObject<'s>, { @@ -189,7 +187,7 @@ where Ok(()) } -pub fn invalid_sequence_length(expected: usize, actual: usize) -> PyErr { +fn invalid_sequence_length(expected: usize, actual: usize) -> PyErr { exceptions::PyValueError::new_err(format!( "expected a sequence of length {} (got {})", expected, actual @@ -198,7 +196,7 @@ pub fn invalid_sequence_length(expected: usize, actual: usize) -> PyErr { #[cfg(test)] mod test { - use crate::Python; + use crate::{Python, PyResult}; #[cfg(min_const_generics)] use std::{ panic, @@ -238,6 +236,17 @@ mod test { assert!(&v == b"abc"); } + #[test] + fn test_extract_invalid_sequence_length() { + let gil = Python::acquire_gil(); + let py = gil.python(); + let v: PyResult<[u8; 3]> = py + .eval("bytearray(b'abcdefg')", None, None) + .unwrap() + .extract(); + assert_eq!(v.unwrap_err().to_string(), "ValueError: expected a sequence of length 3 (got 7)"); + } + #[cfg(min_const_generics)] #[test] fn test_extract_bytearray_to_array() { diff --git a/src/conversions/mod.rs b/src/conversions/mod.rs index 9be3ba9fb4a..60c57fc96a0 100644 --- a/src/conversions/mod.rs +++ b/src/conversions/mod.rs @@ -1,5 +1,4 @@ -//! This module contains conversions between non-String Rust object and their string representation -//! in Python +//! This module contains conversions between various Rust object and their representation in Python. mod array; mod osstr;