Skip to content

Commit

Permalink
Also relax the PySequence check when extracting fixed-sized arrays.
Browse files Browse the repository at this point in the history
  • Loading branch information
adamreichold committed Oct 11, 2022
1 parent c9b26f5 commit 7863dc6
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 9 deletions.
1 change: 1 addition & 0 deletions newsfragments/2675.fixed.md
@@ -0,0 +1 @@
`impl FromPyObject for [T; N]` will accept anything passing `PySequence_Check`, e.g. NumPy arrays, in the same way that `impl FromPyObject for Vec<T>` was modified after [#2477](https://github.com/PyO3/pyo3/pull/2477).
6 changes: 6 additions & 0 deletions pytests/src/sequence.rs
Expand Up @@ -6,6 +6,11 @@ fn vec_to_vec_i32(vec: Vec<i32>) -> Vec<i32> {
vec
}

#[pyfunction]
fn array_to_array_i32(arr: [i32; 3]) -> [i32; 3] {
arr
}

#[pyfunction]
fn vec_to_vec_pystring(vec: Vec<&PyString>) -> Vec<&PyString> {
vec
Expand All @@ -14,6 +19,7 @@ fn vec_to_vec_pystring(vec: Vec<&PyString>) -> Vec<&PyString> {
#[pymodule]
pub fn sequence(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(vec_to_vec_i32, m)?)?;
m.add_function(wrap_pyfunction!(array_to_array_i32, m)?)?;
m.add_function(wrap_pyfunction!(vec_to_vec_pystring, m)?)?;
Ok(())
}
6 changes: 6 additions & 0 deletions pytests/tests/test_sequence.py
Expand Up @@ -29,3 +29,9 @@ def test_vec_from_array():
import numpy

assert sequence.vec_to_vec_i32(numpy.array([1, 2, 3])) == [1, 2, 3]


def test_rust_array_from_array():
import numpy

assert sequence.array_to_array_i32(numpy.array([1, 2, 3])) == [1, 2, 3]
36 changes: 28 additions & 8 deletions src/conversions/array.rs
Expand Up @@ -3,9 +3,11 @@ use crate::{exceptions, PyErr};
#[cfg(min_const_generics)]
mod min_const_generics {
use super::invalid_sequence_length;
use crate::conversion::IntoPyPointer;
use crate::conversion::{AsPyPointer, IntoPyPointer};
use crate::types::PySequence;
use crate::{
ffi, FromPyObject, IntoPy, Py, PyAny, PyObject, PyResult, PyTryFrom, Python, ToPyObject,
ffi, FromPyObject, IntoPy, Py, PyAny, PyDowncastError, PyObject, PyResult, PyTryFrom,
Python, ToPyObject,
};

impl<T, const N: usize> IntoPy<PyObject> for [T; N]
Expand Down Expand Up @@ -61,8 +63,16 @@ mod min_const_generics {
where
T: FromPyObject<'s>,
{
let seq = <crate::types::PySequence as PyTryFrom>::try_from(obj)?;
let seq_len = seq.len()? as usize;
// Types that pass `PySequence_Check` usually implement enough of the sequence protocol
// to support this function and if not, we will only fail extraction safely.
let seq = unsafe {
if ffi::PySequence_Check(obj.as_ptr()) != 0 {
<PySequence as PyTryFrom>::try_from_unchecked(obj)
} else {
return Err(PyDowncastError::new(obj, "Sequence").into());
}
};
let seq_len = seq.len()?;
if seq_len != N {
return Err(invalid_sequence_length(N, seq_len));
}
Expand Down Expand Up @@ -174,9 +184,11 @@ mod min_const_generics {
#[cfg(not(min_const_generics))]
mod array_impls {
use super::invalid_sequence_length;
use crate::conversion::IntoPyPointer;
use crate::conversion::{AsPyPointer, IntoPyPointer};
use crate::types::PySequence;
use crate::{
ffi, FromPyObject, IntoPy, Py, PyAny, PyObject, PyResult, PyTryFrom, Python, ToPyObject,
ffi, FromPyObject, IntoPy, Py, PyAny, PyDowncastError, PyObject, PyResult, PyTryFrom,
Python, ToPyObject,
};
use std::mem::{transmute_copy, ManuallyDrop};

Expand Down Expand Up @@ -274,8 +286,16 @@ mod array_impls {
where
T: FromPyObject<'s>,
{
let seq = <crate::types::PySequence as PyTryFrom>::try_from(obj)?;
let seq_len = seq.len()? as usize;
// Types that pass `PySequence_Check` usually implement enough of the sequence protocol
// to support this function and if not, we will only fail extraction safely.
let seq = unsafe {
if ffi::PySequence_Check(obj.as_ptr()) != 0 {
<PySequence as PyTryFrom>::try_from_unchecked(obj)
} else {
return Err(PyDowncastError::new(obj, "Sequence").into());
}
};
let seq_len = seq.len()?;
if seq_len != slice.len() {
return Err(invalid_sequence_length(slice.len(), seq_len));
}
Expand Down
2 changes: 1 addition & 1 deletion src/types/sequence.rs
Expand Up @@ -309,7 +309,7 @@ where
}
};

let mut v = Vec::with_capacity(seq.len().unwrap_or(0) as usize);
let mut v = Vec::with_capacity(seq.len().unwrap_or(0));
for item in seq.iter()? {
v.push(item?.extract::<T>()?);
}
Expand Down

0 comments on commit 7863dc6

Please sign in to comment.