From b740c33682512184bd5bd2ef04cd4d9dee4772ef Mon Sep 17 00:00:00 2001 From: pigeon Date: Tue, 19 Apr 2022 15:09:54 +0200 Subject: [PATCH] Implement ToPyObject for [T; N] (#2313) --- CHANGELOG.md | 4 ++++ src/conversions/array.rs | 33 ++++++++++++++++++++++++++++++++- 2 files changed, 36 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 32856cdff39..ddfece05260 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Unreleased +### Added + +- Implement `ToPyObject` for `[T; N]`. [#2313](https://github.com/PyO3/pyo3/pull/2313) + ## [0.16.4] - 2022-04-14 ### Added diff --git a/src/conversions/array.rs b/src/conversions/array.rs index 1b630a8aa00..614b56c46a3 100644 --- a/src/conversions/array.rs +++ b/src/conversions/array.rs @@ -14,6 +14,15 @@ mod min_const_generics { } } + impl ToPyObject for [T; N] + where + T: ToPyObject, + { + fn to_object(&self, py: Python<'_>) -> PyObject { + self.as_ref().to_object(py) + } + } + impl<'a, T, const N: usize> FromPyObject<'a> for [T; N] where T: FromPyObject<'a>, @@ -154,6 +163,15 @@ mod array_impls { } } + impl ToPyObject for [T; $N] + where + T: ToPyObject, + { + fn to_object(&self, py: Python<'_>) -> PyObject { + self.as_ref().to_object(py) + } + } + impl<'a, T> FromPyObject<'a> for [T; $N] where T: Copy + Default + FromPyObject<'a>, @@ -200,7 +218,7 @@ fn invalid_sequence_length(expected: usize, actual: usize) -> PyErr { #[cfg(test)] mod tests { - use crate::{PyResult, Python}; + use crate::{types::PyList, PyResult, Python}; #[test] fn test_extract_small_bytearray_to_array() { @@ -213,6 +231,19 @@ mod tests { assert!(&v == b"abc"); }); } + #[test] + fn test_topyobject_array_conversion() { + use crate::ToPyObject; + Python::with_gil(|py| { + let array: [f32; 4] = [0.0, -16.0, 16.0, 42.0]; + let pyobject = array.to_object(py); + let pylist: &PyList = pyobject.extract(py).unwrap(); + assert_eq!(pylist[0].extract::().unwrap(), 0.0); + assert_eq!(pylist[1].extract::().unwrap(), -16.0); + assert_eq!(pylist[2].extract::().unwrap(), 16.0); + assert_eq!(pylist[3].extract::().unwrap(), 42.0); + }); + } #[test] fn test_extract_invalid_sequence_length() {