diff --git a/newsfragments/2675.fixed.md b/newsfragments/2675.fixed.md new file mode 100644 index 00000000000..133432bb9d6 --- /dev/null +++ b/newsfragments/2675.fixed.md @@ -0,0 +1 @@ +`impl FromPyObject for [T; N]` will accept anything passing `PySequence_Check`, e.g. NumPy arrays, in the same way that `impl FromPyObject for Vec` was modified after [#2477](https://github.com/PyO3/pyo3/pull/2477). diff --git a/pytests/src/sequence.rs b/pytests/src/sequence.rs index d6936f54a56..5916414ee8f 100644 --- a/pytests/src/sequence.rs +++ b/pytests/src/sequence.rs @@ -6,6 +6,11 @@ fn vec_to_vec_i32(vec: Vec) -> Vec { vec } +#[pyfunction] +fn array_to_array_i32(arr: [i32; 3]) -> [i32; 3] { + arr +} + #[pyfunction] fn vec_to_vec_pystring(vec: Vec<&PyString>) -> Vec<&PyString> { vec @@ -14,6 +19,7 @@ fn vec_to_vec_pystring(vec: Vec<&PyString>) -> Vec<&PyString> { #[pymodule] pub fn sequence(_py: Python<'_>, m: &PyModule) -> PyResult<()> { m.add_function(wrap_pyfunction!(vec_to_vec_i32, m)?)?; + m.add_function(wrap_pyfunction!(array_to_array_i32, m)?)?; m.add_function(wrap_pyfunction!(vec_to_vec_pystring, m)?)?; Ok(()) } diff --git a/pytests/tests/test_sequence.py b/pytests/tests/test_sequence.py index a7a60e0ea7f..ca571e41293 100644 --- a/pytests/tests/test_sequence.py +++ b/pytests/tests/test_sequence.py @@ -29,3 +29,9 @@ def test_vec_from_array(): import numpy assert sequence.vec_to_vec_i32(numpy.array([1, 2, 3])) == [1, 2, 3] + + +def test_rust_array_from_array(): + import numpy + + assert sequence.array_to_array_i32(numpy.array([1, 2, 3])) == [1, 2, 3] diff --git a/src/conversions/array.rs b/src/conversions/array.rs index c6e9e70c15c..271b00d29d6 100644 --- a/src/conversions/array.rs +++ b/src/conversions/array.rs @@ -3,9 +3,11 @@ use crate::{exceptions, PyErr}; #[cfg(min_const_generics)] mod min_const_generics { use super::invalid_sequence_length; - use crate::conversion::IntoPyPointer; + use crate::conversion::{AsPyPointer, IntoPyPointer}; + use crate::types::PySequence; use crate::{ - ffi, FromPyObject, IntoPy, Py, PyAny, PyObject, PyResult, PyTryFrom, Python, ToPyObject, + ffi, FromPyObject, IntoPy, Py, PyAny, PyDowncastError, PyObject, PyResult, PyTryFrom, + Python, ToPyObject, }; impl IntoPy for [T; N] @@ -61,8 +63,16 @@ mod min_const_generics { where T: FromPyObject<'s>, { - let seq = ::try_from(obj)?; - let seq_len = seq.len()? as usize; + // Types that pass `PySequence_Check` usually implement enough of the sequence protocol + // to support this function and if not, we will only fail extraction safely. + let seq = unsafe { + if ffi::PySequence_Check(obj.as_ptr()) != 0 { + ::try_from_unchecked(obj) + } else { + return Err(PyDowncastError::new(obj, "Sequence").into()); + } + }; + let seq_len = seq.len()?; if seq_len != N { return Err(invalid_sequence_length(N, seq_len)); } @@ -174,9 +184,11 @@ mod min_const_generics { #[cfg(not(min_const_generics))] mod array_impls { use super::invalid_sequence_length; - use crate::conversion::IntoPyPointer; + use crate::conversion::{AsPyPointer, IntoPyPointer}; + use crate::types::PySequence; use crate::{ - ffi, FromPyObject, IntoPy, Py, PyAny, PyObject, PyResult, PyTryFrom, Python, ToPyObject, + ffi, FromPyObject, IntoPy, Py, PyAny, PyDowncastError, PyObject, PyResult, PyTryFrom, + Python, ToPyObject, }; use std::mem::{transmute_copy, ManuallyDrop}; @@ -274,8 +286,16 @@ mod array_impls { where T: FromPyObject<'s>, { - let seq = ::try_from(obj)?; - let seq_len = seq.len()? as usize; + // Types that pass `PySequence_Check` usually implement enough of the sequence protocol + // to support this function and if not, we will only fail extraction safely. + let seq = unsafe { + if ffi::PySequence_Check(obj.as_ptr()) != 0 { + ::try_from_unchecked(obj) + } else { + return Err(PyDowncastError::new(obj, "Sequence").into()); + } + }; + let seq_len = seq.len()?; if seq_len != slice.len() { return Err(invalid_sequence_length(slice.len(), seq_len)); } diff --git a/src/types/sequence.rs b/src/types/sequence.rs index af7cd0c5d19..422a2db2e10 100644 --- a/src/types/sequence.rs +++ b/src/types/sequence.rs @@ -309,7 +309,7 @@ where } }; - let mut v = Vec::with_capacity(seq.len().unwrap_or(0) as usize); + let mut v = Vec::with_capacity(seq.len().unwrap_or(0)); for item in seq.iter()? { v.push(item?.extract::()?); }