From 33b056b5bef2fe7dab5f5019ad7e50e78d168aca Mon Sep 17 00:00:00 2001 From: Aviram Hassan Date: Thu, 22 Jul 2021 20:22:38 +0300 Subject: [PATCH] - `PyList`, `PyTuple` and `PySequence`'s `get_item` now accepts only `usize` indices instead of `isize`. - `PyList` and `PyTuple`'s `get_item` now always uses the safe API. See `get_item_unchecked` for retrieving index without checks. --- CHANGELOG.md | 3 + benches/bench_list.rs | 18 ++- benches/bench_tuple.rs | 18 ++- pyo3-macros-backend/src/from_pyobject.rs | 2 +- src/conversions/array.rs | 2 +- src/types/dict.rs | 4 +- src/types/list.rs | 162 ++++++++++++++--------- src/types/sequence.rs | 9 +- src/types/tuple.rs | 82 ++++++++++-- 9 files changed, 214 insertions(+), 86 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 047412978da..1edcdb946a2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,11 +10,14 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. ### Added +- Add `PyList::get_item_unchecked()` and `PyTuple::get_item_unchecked()` to get items without bounds checks. [#1733](https://github.com/PyO3/pyo3/pull/1733) - Add `PyAny::py()` as a convenience for `PyNativeType::py()`. [#1751](https://github.com/PyO3/pyo3/pull/1751) ### Changed - Change `PyErr::fetch()` to return `Option`. [#1717](https://github.com/PyO3/pyo3/pull/1717) +- `PyList`, `PyTuple` and `PySequence`'s `get_item` now accepts only `usize` indices instead of `isize`. [#1733](https://github.com/PyO3/pyo3/pull/1733) +- `PyList` and `PyTuple`'s `get_item` now returns `PyResult<&PyAny>` instead of panicking. 
[#1733](https://github.com/PyO3/pyo3/pull/1733) ### Fixed diff --git a/benches/bench_list.rs b/benches/bench_list.rs index a778a0de262..6bc1911a967 100644 --- a/benches/bench_list.rs +++ b/benches/bench_list.rs @@ -32,7 +32,22 @@ fn list_get_item(b: &mut Bencher) { let mut sum = 0; b.iter(|| { for i in 0..LEN { - sum += list.get_item(i as isize).extract::().unwrap(); + sum += list.get_item(i).unwrap().extract::().unwrap(); + } + }); +} + +fn list_get_item_unchecked(b: &mut Bencher) { + let gil = Python::acquire_gil(); + let py = gil.python(); + const LEN: usize = 50_000; + let list = PyList::new(py, 0..LEN); + let mut sum = 0; + b.iter(|| { + for i in 0..LEN { + unsafe { + sum += list.get_item_unchecked(i).extract::().unwrap(); + } } }); } @@ -41,6 +56,7 @@ fn criterion_benchmark(c: &mut Criterion) { c.bench_function("iter_list", iter_list); c.bench_function("list_new", list_new); c.bench_function("list_get_item", list_get_item); + c.bench_function("list_get_item_unchecked", list_get_item_unchecked); } criterion_group!(benches, criterion_benchmark); diff --git a/benches/bench_tuple.rs b/benches/bench_tuple.rs index ee5caf3260d..e6359e8eb58 100644 --- a/benches/bench_tuple.rs +++ b/benches/bench_tuple.rs @@ -32,7 +32,22 @@ fn tuple_get_item(b: &mut Bencher) { let mut sum = 0; b.iter(|| { for i in 0..LEN { - sum += tuple.get_item(i).extract::().unwrap(); + sum += tuple.get_item(i).unwrap().extract::().unwrap(); + } + }); +} + +fn tuple_get_item_unchecked(b: &mut Bencher) { + let gil = Python::acquire_gil(); + let py = gil.python(); + const LEN: usize = 50_000; + let tuple = PyTuple::new(py, 0..LEN); + let mut sum = 0; + b.iter(|| { + for i in 0..LEN { + unsafe { + sum += tuple.get_item_unchecked(i).extract::().unwrap(); + } } }); } @@ -41,6 +56,7 @@ fn criterion_benchmark(c: &mut Criterion) { c.bench_function("iter_tuple", iter_tuple); c.bench_function("tuple_new", tuple_new); c.bench_function("tuple_get_item", tuple_get_item); + 
c.bench_function("tuple_get_item_unchecked", tuple_get_item_unchecked); } criterion_group!(benches, criterion_benchmark); diff --git a/pyo3-macros-backend/src/from_pyobject.rs b/pyo3-macros-backend/src/from_pyobject.rs index da1144253c1..460349dfc9f 100644 --- a/pyo3-macros-backend/src/from_pyobject.rs +++ b/pyo3-macros-backend/src/from_pyobject.rs @@ -243,7 +243,7 @@ impl<'a> Container<'a> { for i in 0..len { let error_msg = format!("failed to extract field {}.{}", quote!(#self_ty), i); fields.push(quote!( - s.get_item(#i).extract().map_err(|inner| { + s.get_item(#i).and_then(PyAny::extract).map_err(|inner| { let py = pyo3::PyNativeType::py(obj); let new_err = pyo3::exceptions::PyTypeError::new_err(#error_msg); new_err.set_cause(py, Some(inner)); diff --git a/src/conversions/array.rs b/src/conversions/array.rs index c586a5d725a..92fded84f14 100644 --- a/src/conversions/array.rs +++ b/src/conversions/array.rs @@ -60,7 +60,7 @@ mod min_const_generics { if seq_len != N { return Err(invalid_sequence_length(N, seq_len)); } - array_try_from_fn(|idx| seq.get_item(idx as isize).and_then(PyAny::extract)) + array_try_from_fn(|idx| seq.get_item(idx).and_then(PyAny::extract)) } // TODO use std::array::try_from_fn, if that stabilises: diff --git a/src/types/dict.rs b/src/types/dict.rs index 13a0b34cefb..d4297584be5 100644 --- a/src/types/dict.rs +++ b/src/types/dict.rs @@ -639,8 +639,8 @@ mod tests { let mut value_sum = 0; for el in dict.items().iter() { let tuple = el.cast_as::().unwrap(); - key_sum += tuple.get_item(0).extract::().unwrap(); - value_sum += tuple.get_item(1).extract::().unwrap(); + key_sum += tuple.get_item(0).unwrap().extract::().unwrap(); + value_sum += tuple.get_item(1).unwrap().extract::().unwrap(); } assert_eq!(7 + 8 + 9, key_sum); assert_eq!(32 + 42 + 123, value_sum); diff --git a/src/types/list.rs b/src/types/list.rs index abdd89dd9f3..6bc724121eb 100644 --- a/src/types/list.rs +++ b/src/types/list.rs @@ -66,23 +66,39 @@ impl PyList { self.len() == 0 } 
- /// Gets the item at the specified index. - /// - /// Panics if the index is out of range. - pub fn get_item(&self, index: isize) -> &PyAny { - assert!(index >= 0 && index < self.len() as isize); + /// Gets the list item at the specified index. + /// # Example + /// ``` + /// use pyo3::{prelude::*, types::PyList}; + /// Python::with_gil(|py| { + /// let list = PyList::new(py, &[2, 3, 5, 7]); + /// let obj = list.get_item(0); + /// assert_eq!(obj.unwrap().extract::().unwrap(), 2); + /// }); + /// ``` + pub fn get_item(&self, index: usize) -> PyResult<&PyAny> { unsafe { - #[cfg(not(Py_LIMITED_API))] - let ptr = ffi::PyList_GET_ITEM(self.as_ptr(), index as Py_ssize_t); - #[cfg(Py_LIMITED_API)] - let ptr = ffi::PyList_GetItem(self.as_ptr(), index as Py_ssize_t); - + let item = ffi::PyList_GetItem(self.as_ptr(), index as Py_ssize_t); // PyList_GetItem return borrowed ptr; must make owned for safety (see #890). - ffi::Py_INCREF(ptr); - self.py().from_owned_ptr(ptr) + ffi::Py_XINCREF(item); + self.py().from_owned_ptr_or_err(item) } } + /// Gets the list item at the specified index. Undefined behavior on bad index. Use with caution. + /// + /// # Safety + /// + /// Caller must verify that the index is within the bounds of the list. + #[cfg(not(any(Py_LIMITED_API, PyPy)))] + #[cfg_attr(docsrs, doc(cfg(not(any(Py_LIMITED_API, PyPy)))))] + pub unsafe fn get_item_unchecked(&self, index: usize) -> &PyAny { + let item = ffi::PyList_GET_ITEM(self.as_ptr(), index as Py_ssize_t); + // PyList_GET_ITEM return borrowed ptr; must make owned for safety (see #890). + ffi::Py_XINCREF(item); + self.py().from_owned_ptr(item) + } + /// Takes the slice `self[low:high]` and returns it as a new list. /// /// Indices must be nonnegative, and out-of-range indices are clipped to @@ -163,7 +179,7 @@ impl PyList { /// Used by `PyList::iter()`. 
pub struct PyListIterator<'a> { list: &'a PyList, - index: isize, + index: usize, } impl<'a> Iterator for PyListIterator<'a> { @@ -171,8 +187,11 @@ impl<'a> Iterator for PyListIterator<'a> { #[inline] fn next(&mut self) -> Option<&'a PyAny> { - if self.index < self.list.len() as isize { - let item = self.list.get_item(self.index); + if self.index < self.list.len() { + #[cfg(any(Py_LIMITED_API, PyPy))] + let item = self.list.get_item(self.index).expect("list.get failed"); + #[cfg(not(any(Py_LIMITED_API, PyPy)))] + let item = unsafe { self.list.get_item_unchecked(self.index) }; self.index += 1; Some(item) } else { @@ -185,8 +204,8 @@ impl<'a> Iterator for PyListIterator<'a> { let len = self.list.len(); ( - len.saturating_sub(self.index as usize), - Some(len.saturating_sub(self.index as usize)), + len.saturating_sub(self.index), + Some(len.saturating_sub(self.index)), ) } } @@ -237,10 +256,10 @@ mod tests { fn test_new() { Python::with_gil(|py| { let list = PyList::new(py, &[2, 3, 5, 7]); - assert_eq!(2, list.get_item(0).extract::().unwrap()); - assert_eq!(3, list.get_item(1).extract::().unwrap()); - assert_eq!(5, list.get_item(2).extract::().unwrap()); - assert_eq!(7, list.get_item(3).extract::().unwrap()); + assert_eq!(2, list.get_item(0).unwrap().extract::().unwrap()); + assert_eq!(3, list.get_item(1).unwrap().extract::().unwrap()); + assert_eq!(5, list.get_item(2).unwrap().extract::().unwrap()); + assert_eq!(7, list.get_item(3).unwrap().extract::().unwrap()); }); } @@ -256,19 +275,10 @@ mod tests { fn test_get_item() { Python::with_gil(|py| { let list = PyList::new(py, &[2, 3, 5, 7]); - assert_eq!(2, list.get_item(0).extract::().unwrap()); - assert_eq!(3, list.get_item(1).extract::().unwrap()); - assert_eq!(5, list.get_item(2).extract::().unwrap()); - assert_eq!(7, list.get_item(3).extract::().unwrap()); - }); - } - - #[test] - #[should_panic] - fn test_get_item_invalid() { - Python::with_gil(|py| { - let list = PyList::new(py, &[2, 3, 5, 7]); - 
list.get_item(-1); + assert_eq!(2, list.get_item(0).unwrap().extract::().unwrap()); + assert_eq!(3, list.get_item(1).unwrap().extract::().unwrap()); + assert_eq!(5, list.get_item(2).unwrap().extract::().unwrap()); + assert_eq!(7, list.get_item(3).unwrap().extract::().unwrap()); }); } @@ -289,9 +299,9 @@ mod tests { let list = PyList::new(py, &[2, 3, 5, 7]); let val = 42i32.to_object(py); let val2 = 42i32.to_object(py); - assert_eq!(2, list.get_item(0).extract::().unwrap()); + assert_eq!(2, list.get_item(0).unwrap().extract::().unwrap()); list.set_item(0, val).unwrap(); - assert_eq!(42, list.get_item(0).extract::().unwrap()); + assert_eq!(42, list.get_item(0).unwrap().extract::().unwrap()); assert!(list.set_item(10, val2).is_err()); }); } @@ -321,13 +331,13 @@ mod tests { let val = 42i32.to_object(py); let val2 = 43i32.to_object(py); assert_eq!(4, list.len()); - assert_eq!(2, list.get_item(0).extract::().unwrap()); + assert_eq!(2, list.get_item(0).unwrap().extract::().unwrap()); list.insert(0, val).unwrap(); list.insert(1000, val2).unwrap(); assert_eq!(6, list.len()); - assert_eq!(42, list.get_item(0).extract::().unwrap()); - assert_eq!(2, list.get_item(1).extract::().unwrap()); - assert_eq!(43, list.get_item(5).extract::().unwrap()); + assert_eq!(42, list.get_item(0).unwrap().extract::().unwrap()); + assert_eq!(2, list.get_item(1).unwrap().extract::().unwrap()); + assert_eq!(43, list.get_item(5).unwrap().extract::().unwrap()); }); } @@ -352,8 +362,8 @@ mod tests { Python::with_gil(|py| { let list = PyList::new(py, &[2]); list.append(3).unwrap(); - assert_eq!(2, list.get_item(0).extract::().unwrap()); - assert_eq!(3, list.get_item(1).extract::().unwrap()); + assert_eq!(2, list.get_item(0).unwrap().extract::().unwrap()); + assert_eq!(3, list.get_item(1).unwrap().extract::().unwrap()); }); } @@ -430,15 +440,15 @@ mod tests { Python::with_gil(|py| { let v = vec![7, 3, 2, 5]; let list = PyList::new(py, &v); - assert_eq!(7, list.get_item(0).extract::().unwrap()); - 
assert_eq!(3, list.get_item(1).extract::().unwrap()); - assert_eq!(2, list.get_item(2).extract::().unwrap()); - assert_eq!(5, list.get_item(3).extract::().unwrap()); + assert_eq!(7, list.get_item(0).unwrap().extract::().unwrap()); + assert_eq!(3, list.get_item(1).unwrap().extract::().unwrap()); + assert_eq!(2, list.get_item(2).unwrap().extract::().unwrap()); + assert_eq!(5, list.get_item(3).unwrap().extract::().unwrap()); list.sort().unwrap(); - assert_eq!(2, list.get_item(0).extract::().unwrap()); - assert_eq!(3, list.get_item(1).extract::().unwrap()); - assert_eq!(5, list.get_item(2).extract::().unwrap()); - assert_eq!(7, list.get_item(3).extract::().unwrap()); + assert_eq!(2, list.get_item(0).unwrap().extract::().unwrap()); + assert_eq!(3, list.get_item(1).unwrap().extract::().unwrap()); + assert_eq!(5, list.get_item(2).unwrap().extract::().unwrap()); + assert_eq!(7, list.get_item(3).unwrap().extract::().unwrap()); }); } @@ -447,15 +457,15 @@ mod tests { Python::with_gil(|py| { let v = vec![2, 3, 5, 7]; let list = PyList::new(py, &v); - assert_eq!(2, list.get_item(0).extract::().unwrap()); - assert_eq!(3, list.get_item(1).extract::().unwrap()); - assert_eq!(5, list.get_item(2).extract::().unwrap()); - assert_eq!(7, list.get_item(3).extract::().unwrap()); + assert_eq!(2, list.get_item(0).unwrap().extract::().unwrap()); + assert_eq!(3, list.get_item(1).unwrap().extract::().unwrap()); + assert_eq!(5, list.get_item(2).unwrap().extract::().unwrap()); + assert_eq!(7, list.get_item(3).unwrap().extract::().unwrap()); list.reverse().unwrap(); - assert_eq!(7, list.get_item(0).extract::().unwrap()); - assert_eq!(5, list.get_item(1).extract::().unwrap()); - assert_eq!(3, list.get_item(2).extract::().unwrap()); - assert_eq!(2, list.get_item(3).extract::().unwrap()); + assert_eq!(7, list.get_item(0).unwrap().extract::().unwrap()); + assert_eq!(5, list.get_item(1).unwrap().extract::().unwrap()); + assert_eq!(3, list.get_item(2).unwrap().extract::().unwrap()); + assert_eq!(2, 
list.get_item(3).unwrap().extract::().unwrap()); }); } @@ -464,8 +474,40 @@ mod tests { Python::with_gil(|py| { let array: PyObject = [1, 2].into_py(py); let list = ::try_from(array.as_ref(py)).unwrap(); - assert_eq!(1, list.get_item(0).extract::().unwrap()); - assert_eq!(2, list.get_item(1).extract::().unwrap()); + assert_eq!(1, list.get_item(0).unwrap().extract::().unwrap()); + assert_eq!(2, list.get_item(1).unwrap().extract::().unwrap()); + }); + } + + #[test] + fn test_list_get_item_invalid_index() { + Python::with_gil(|py| { + let list = PyList::new(py, &[2, 3, 5, 7]); + let obj = list.get_item(5); + assert!(obj.is_err()); + assert_eq!( + obj.unwrap_err().to_string(), + "IndexError: list index out of range" + ); + }); + } + + #[test] + fn test_list_get_item_sanity() { + Python::with_gil(|py| { + let list = PyList::new(py, &[2, 3, 5, 7]); + let obj = list.get_item(0); + assert_eq!(obj.unwrap().extract::().unwrap(), 2); + }); + } + + #[cfg(not(any(Py_LIMITED_API, PyPy)))] + #[test] + fn test_list_get_item_unchecked_sanity() { + Python::with_gil(|py| { + let list = PyList::new(py, &[2, 3, 5, 7]); + let obj = unsafe { list.get_item_unchecked(0) }; + assert_eq!(obj.extract::().unwrap(), 2); }); } } diff --git a/src/types/sequence.rs b/src/types/sequence.rs index fc95d32b6aa..0b70da0ee40 100644 --- a/src/types/sequence.rs +++ b/src/types/sequence.rs @@ -99,9 +99,9 @@ impl PySequence { /// Returns the `index`th element of the Sequence. /// - /// This is equivalent to the Python expression `self[index]`. + /// This is equivalent to the Python expression `self[index]` without support of negative indices. 
#[inline] - pub fn get_item(&self, index: isize) -> PyResult<&PyAny> { + pub fn get_item(&self, index: usize) -> PyResult<&PyAny> { unsafe { self.py() .from_owned_ptr_or_err(ffi::PySequence_GetItem(self.as_ptr(), index as Py_ssize_t)) @@ -403,11 +403,6 @@ mod tests { assert_eq!(3, seq.get_item(3).unwrap().extract::().unwrap()); assert_eq!(5, seq.get_item(4).unwrap().extract::().unwrap()); assert_eq!(8, seq.get_item(5).unwrap().extract::().unwrap()); - assert_eq!(8, seq.get_item(-1).unwrap().extract::().unwrap()); - assert_eq!(5, seq.get_item(-2).unwrap().extract::().unwrap()); - assert_eq!(3, seq.get_item(-3).unwrap().extract::().unwrap()); - assert_eq!(2, seq.get_item(-4).unwrap().extract::().unwrap()); - assert_eq!(1, seq.get_item(-5).unwrap().extract::().unwrap()); assert!(seq.get_item(10).is_err()); }); } diff --git a/src/types/tuple.rs b/src/types/tuple.rs index 29bdc49229e..3264aedf262 100644 --- a/src/types/tuple.rs +++ b/src/types/tuple.rs @@ -85,20 +85,36 @@ impl PyTuple { } /// Gets the tuple item at the specified index. - /// - /// Panics if the index is out of range. - pub fn get_item(&self, index: usize) -> &PyAny { - assert!(index < self.len()); + /// # Example + /// ``` + /// use pyo3::{prelude::*, types::PyTuple}; + /// Python::with_gil(|py| -> PyResult<()> { + /// let ob = (1, 2, 3).to_object(py); + /// let tuple = ::try_from(ob.as_ref(py)).unwrap(); + /// let obj = tuple.get_item(0); + /// assert_eq!(obj.unwrap().extract::().unwrap(), 1); + /// Ok(()) + /// }); + /// ``` + pub fn get_item(&self, index: usize) -> PyResult<&PyAny> { unsafe { - #[cfg(not(any(Py_LIMITED_API, PyPy)))] - let item = ffi::PyTuple_GET_ITEM(self.as_ptr(), index as Py_ssize_t); - #[cfg(any(Py_LIMITED_API, PyPy))] let item = ffi::PyTuple_GetItem(self.as_ptr(), index as Py_ssize_t); - - self.py().from_borrowed_ptr(item) + self.py().from_borrowed_ptr_or_err(item) } } + /// Gets the tuple item at the specified index. Undefined behavior on bad index. Use with caution. 
+ /// + /// # Safety + /// + /// Caller must verify that the index is within the bounds of the tuple. + #[cfg(not(any(Py_LIMITED_API, PyPy)))] + #[cfg_attr(docsrs, doc(cfg(not(any(Py_LIMITED_API, PyPy)))))] + pub unsafe fn get_item_unchecked(&self, index: usize) -> &PyAny { + let item = ffi::PyTuple_GET_ITEM(self.as_ptr(), index as Py_ssize_t); + self.py().from_borrowed_ptr(item) + } + /// Returns `self` as a slice of objects. #[cfg(not(Py_LIMITED_API))] #[cfg_attr(docsrs, doc(cfg(not(Py_LIMITED_API))))] @@ -135,7 +151,10 @@ impl<'a> Iterator for PyTupleIterator<'a> { #[inline] fn next(&mut self) -> Option<&'a PyAny> { if self.index < self.length { - let item = self.tuple.get_item(self.index); + #[cfg(any(Py_LIMITED_API, PyPy))] + let item = self.tuple.get_item(self.index).expect("tuple.get failed"); + #[cfg(not(any(Py_LIMITED_API, PyPy)))] + let item = unsafe { self.tuple.get_item_unchecked(self.index) }; self.index += 1; Some(item) } else { @@ -211,9 +230,11 @@ macro_rules! tuple_conversion ({$length:expr,$(($refN:ident, $n:tt, $T:ident)),+ { let t = ::try_from(obj)?; if t.len() == $length { - Ok(( - $(t.get_item($n).extract::<$T>()?,)+ - )) + #[cfg(any(Py_LIMITED_API, PyPy))] + return Ok(($(t.get_item($n)?.extract::<$T>()?,)+)); + + #[cfg(not(any(Py_LIMITED_API, PyPy)))] + unsafe {return Ok(($(t.get_item_unchecked($n).extract::<$T>()?,)+));} } else { Err(wrong_tuple_length(t, $length)) } @@ -480,4 +501,39 @@ mod tests { ); }) } + + #[test] + fn test_tuple_get_item_invalid_index() { + Python::with_gil(|py| { + let ob = (1, 2, 3).to_object(py); + let tuple = ::try_from(ob.as_ref(py)).unwrap(); + let obj = tuple.get_item(5); + assert!(obj.is_err()); + assert_eq!( + obj.unwrap_err().to_string(), + "IndexError: tuple index out of range" + ); + }); + } + + #[test] + fn test_tuple_get_item_sanity() { + Python::with_gil(|py| { + let ob = (1, 2, 3).to_object(py); + let tuple = ::try_from(ob.as_ref(py)).unwrap(); + let obj = tuple.get_item(0); + 
assert_eq!(obj.unwrap().extract::().unwrap(), 1); + }); + } + + #[cfg(not(any(Py_LIMITED_API, PyPy)))] + #[test] + fn test_tuple_get_item_unchecked_sanity() { + Python::with_gil(|py| { + let ob = (1, 2, 3).to_object(py); + let tuple = ::try_from(ob.as_ref(py)).unwrap(); + let obj = unsafe { tuple.get_item_unchecked(0) }; + assert_eq!(obj.extract::().unwrap(), 1); + }); + } }