From c67e94dd855b3e3043fc0d23b27816a740da06a1 Mon Sep 17 00:00:00 2001 From: David Hewitt <1939362+davidhewitt@users.noreply.github.com> Date: Tue, 23 Jun 2020 11:35:02 +0100 Subject: [PATCH] Add ability to return from `__next__` / `__anext__` --- CHANGELOG.md | 1 + examples/rustapi_module/setup.py | 1 + examples/rustapi_module/src/lib.rs | 1 + examples/rustapi_module/src/pyclass_iter.rs | 34 ++++++++ .../rustapi_module/tests/test_pyclass_iter.py | 15 ++++ src/class/iter.rs | 77 ++++++++++++++++--- src/class/pyasync.rs | 39 +++++++--- 7 files changed, 150 insertions(+), 18 deletions(-) create mode 100644 examples/rustapi_module/src/pyclass_iter.rs create mode 100644 examples/rustapi_module/tests/test_pyclass_iter.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 5c4cc464629..d7c15ab3c83 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. - Add `PyByteArray::data`, `PyByteArray::as_bytes`, and `PyByteArray::as_bytes_mut`. [#967](https://github.com/PyO3/pyo3/pull/967) - Add `GILOnceCell` to use in situations where `lazy_static` or `once_cell` can deadlock. [#975](https://github.com/PyO3/pyo3/pull/975) - Add `Py::borrow`, `Py::borrow_mut`, `Py::try_borrow`, and `Py::try_borrow_mut` for accessing `#[pyclass]` values. [#976](https://github.com/PyO3/pyo3/pull/976) +- Add `IterNextOutput` and `IterANextOutput` for returning from `__next__` / `__anext__`. [#997](https://github.com/PyO3/pyo3/pull/997) ### Changed - Simplify internals of `#[pyo3(get)]` attribute. (Remove the hidden API `GetPropertyValue`.) [#934](https://github.com/PyO3/pyo3/pull/934) diff --git a/examples/rustapi_module/setup.py b/examples/rustapi_module/setup.py index c90755f5c71..f1fe9002891 100644 --- a/examples/rustapi_module/setup.py +++ b/examples/rustapi_module/setup.py @@ -99,6 +99,7 @@ def make_rust_extension(module_name): make_rust_extension("rustapi_module.othermod"), make_rust_extension("rustapi_module.subclassing"), make_rust_extension("rustapi_module.test_dict"), + make_rust_extension("rustapi_module.pyclass_iter"), ], install_requires=install_requires, tests_require=tests_require, diff --git a/examples/rustapi_module/src/lib.rs b/examples/rustapi_module/src/lib.rs index 588ffa7239c..ce63565a494 100644 --- a/examples/rustapi_module/src/lib.rs +++ b/examples/rustapi_module/src/lib.rs @@ -3,4 +3,5 @@ pub mod datetime; pub mod dict_iter; pub mod objstore; pub mod othermod; +pub mod pyclass_iter; pub mod subclassing; diff --git a/examples/rustapi_module/src/pyclass_iter.rs b/examples/rustapi_module/src/pyclass_iter.rs new file mode 100644 index 00000000000..bb09e260699 --- /dev/null +++ b/examples/rustapi_module/src/pyclass_iter.rs @@ -0,0 +1,34 @@ +use pyo3::class::iter::{IterNextOutput, PyIterProtocol}; +use pyo3::prelude::*; + +/// This is for demonstrating how to return a value from __next__ +#[pyclass] +struct PyClassIter { + count: usize, +} + +#[pymethods] +impl PyClassIter { + #[new] + pub fn new() -> Self { + PyClassIter { count: 0 } + } +} + +#[pyproto] +impl PyIterProtocol for PyClassIter { + fn __next__(mut slf: PyRefMut) -> IterNextOutput { + if slf.count < 5 { + slf.count += 1; + IterNextOutput::Yield(slf.count) + } else { + IterNextOutput::Return("Ended") + } + } +} + +#[pymodule] +pub fn pyclass_iter(_py: Python, m: &PyModule) -> PyResult<()> { + m.add_class::()?; + Ok(()) +} diff --git a/examples/rustapi_module/tests/test_pyclass_iter.py b/examples/rustapi_module/tests/test_pyclass_iter.py new file mode 100644 index 00000000000..f69eab361fb --- /dev/null +++ b/examples/rustapi_module/tests/test_pyclass_iter.py @@ -0,0 +1,15 @@ +import pytest +from rustapi_module import pyclass_iter + + +def test_iter(): + i = pyclass_iter.PyClassIter() + assert next(i) == 1 + assert next(i) == 2 + assert next(i) == 3 + assert next(i) == 4 + assert next(i) == 5 + + with pytest.raises(StopIteration) as excinfo: + next(i) + assert excinfo.value.value == "Ended" diff --git a/src/class/iter.rs b/src/class/iter.rs index bb3e3a9d10d..4062845abec 100644 --- a/src/class/iter.rs +++ b/src/class/iter.rs @@ -9,6 +9,40 @@ use crate::{ffi, IntoPy, IntoPyPointer, PyClass, PyObject, Python}; /// Python Iterator Interface. /// +/// # Example +/// +/// The following example shows how to implement a simple Python iterator in Rust which yields +/// the integers 1 to 5, before raising `StopIteration('Ended')`. +/// +/// ```rust +/// use pyo3::prelude::*; +/// use pyo3::PyIterProtocol; +/// use pyo3::class::iter::IterNextOutput; +/// +/// #[pyclass] +/// struct Iter { +/// count: usize +/// } +/// +/// #[pyproto] +/// impl PyIterProtocol for Iter { +/// fn __next__(mut slf: PyRefMut) -> IterNextOutput { +/// if slf.count < 5 { +/// slf.count += 1; +/// IterNextOutput::Yield(slf.count) +/// } else { +/// IterNextOutput::Return("Ended") +/// } +/// } +/// } +/// +/// # let gil = Python::acquire_gil(); +/// # let py = gil.python(); +/// # let inst = Py::new(py, Iter { count: 0 }).unwrap(); +/// # pyo3::py_run!(py, inst, "assert next(inst) == 1"); +/// # // test of StopIteration is done in examples/rustapi_module/pyclass_iter.rs +/// ``` +/// /// Check [CPython doc](https://docs.python.org/3/c-api/typeobj.html#c.PyTypeObject.tp_iter) /// for more. #[allow(unused_variables)] @@ -35,7 +69,7 @@ pub trait PyIterIterProtocol<'p>: PyIterProtocol<'p> { pub trait PyIterNextProtocol<'p>: PyIterProtocol<'p> { type Receiver: TryFromPyCell<'p, Self>; - type Result: IntoPyCallbackOutput; + type Result: IntoPyCallbackOutput; } #[derive(Default)] @@ -64,22 +98,47 @@ impl PyIterMethods { } } -pub struct IterNextOutput(Option); +/// Output of `__next__` which can either `yield` the next value in the iteration, or +/// `return` a value to raise `StopIteration` in Python. +/// +/// See [`PyIterProtocol`] for an example. +pub enum IterNextOutput { + Yield(T), + Return(U), +} + +pub type PyIterNextOutput = IterNextOutput; -impl IntoPyCallbackOutput<*mut ffi::PyObject> for IterNextOutput { +impl IntoPyCallbackOutput<*mut ffi::PyObject> for PyIterNextOutput { fn convert(self, _py: Python) -> PyResult<*mut ffi::PyObject> { - match self.0 { - Some(o) => Ok(o.into_ptr()), - None => Err(crate::exceptions::StopIteration::py_err(())), + match self { + IterNextOutput::Yield(o) => Ok(o.into_ptr()), + IterNextOutput::Return(opt) => Err(crate::exceptions::StopIteration::py_err((opt,))), } } } -impl IntoPyCallbackOutput for Option +impl IntoPyCallbackOutput for IterNextOutput where T: IntoPy, + U: IntoPy, { - fn convert(self, py: Python) -> PyResult { - Ok(IterNextOutput(self.map(|o| o.into_py(py)))) + fn convert(self, py: Python) -> PyResult { + match self { + IterNextOutput::Yield(o) => Ok(IterNextOutput::Yield(o.into_py(py))), + IterNextOutput::Return(o) => Ok(IterNextOutput::Return(o.into_py(py))), + } + } +} + +impl IntoPyCallbackOutput for Option +where + T: IntoPy, +{ + fn convert(self, py: Python) -> PyResult { + match self { + Some(o) => Ok(PyIterNextOutput::Yield(o.into_py(py))), + None => Ok(PyIterNextOutput::Return(py.None())), + } } } diff --git a/src/class/pyasync.rs b/src/class/pyasync.rs index 83df141053d..a748b60ed44 100644 --- a/src/class/pyasync.rs +++ b/src/class/pyasync.rs @@ -71,7 +71,7 @@ pub trait PyAsyncAiterProtocol<'p>: PyAsyncProtocol<'p> { pub trait PyAsyncAnextProtocol<'p>: PyAsyncProtocol<'p> { type Receiver: TryFromPyCell<'p, Self>; - type Result: IntoPyCallbackOutput; + type Result: IntoPyCallbackOutput; } pub trait PyAsyncAenterProtocol<'p>: PyAsyncProtocol<'p> { @@ -107,23 +107,44 @@ impl ffi::PyAsyncMethods { } } -pub struct IterANextOutput(Option); +pub enum IterANextOutput { + Yield(T), + Return(U), +} + +pub type PyIterANextOutput = IterANextOutput; -impl IntoPyCallbackOutput<*mut ffi::PyObject> for IterANextOutput { +impl IntoPyCallbackOutput<*mut ffi::PyObject> for PyIterANextOutput { fn convert(self, _py: Python) -> PyResult<*mut ffi::PyObject> { - match self.0 { - Some(o) => Ok(o.into_ptr()), - None => Err(crate::exceptions::StopAsyncIteration::py_err(())), + match self { + IterANextOutput::Yield(o) => Ok(o.into_ptr()), + IterANextOutput::Return(opt) => Err(crate::exceptions::StopAsyncIteration::py_err((opt,))), + } + } +} + +impl IntoPyCallbackOutput for IterANextOutput +where + T: IntoPy, + U: IntoPy, +{ + fn convert(self, py: Python) -> PyResult { + match self { + IterANextOutput::Yield(o) => Ok(IterANextOutput::Yield(o.into_py(py))), + IterANextOutput::Return(o) => Ok(IterANextOutput::Return(o.into_py(py))), } } } -impl IntoPyCallbackOutput for Option +impl IntoPyCallbackOutput for Option where T: IntoPy, { - fn convert(self, py: Python) -> PyResult { - Ok(IterANextOutput(self.map(|o| o.into_py(py)))) + fn convert(self, py: Python) -> PyResult { + match self { + Some(o) => Ok(PyIterANextOutput::Yield(o.into_py(py))), + None => Ok(PyIterANextOutput::Return(py.None())), + } } }