Skip to content

Commit

Permalink
Add ability to return from __next__ / __anext__
Browse files Browse the repository at this point in the history
  • Loading branch information
davidhewitt committed Jun 23, 2020
1 parent a5e3d4e commit c67e94d
Show file tree
Hide file tree
Showing 7 changed files with 150 additions and 18 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Expand Up @@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
- Add `PyByteArray::data`, `PyByteArray::as_bytes`, and `PyByteArray::as_bytes_mut`. [#967](https://github.com/PyO3/pyo3/pull/967)
- Add `GILOnceCell` to use in situations where `lazy_static` or `once_cell` can deadlock. [#975](https://github.com/PyO3/pyo3/pull/975)
- Add `Py::borrow`, `Py::borrow_mut`, `Py::try_borrow`, and `Py::try_borrow_mut` for accessing `#[pyclass]` values. [#976](https://github.com/PyO3/pyo3/pull/976)
- Add `IterNextOutput` and `IterANextOutput` for returning from `__next__` / `__anext__`. [#997](https://github.com/PyO3/pyo3/pull/997)

### Changed
- Simplify internals of `#[pyo3(get)]` attribute. (Remove the hidden API `GetPropertyValue`.) [#934](https://github.com/PyO3/pyo3/pull/934)
Expand Down
1 change: 1 addition & 0 deletions examples/rustapi_module/setup.py
Expand Up @@ -99,6 +99,7 @@ def make_rust_extension(module_name):
make_rust_extension("rustapi_module.othermod"),
make_rust_extension("rustapi_module.subclassing"),
make_rust_extension("rustapi_module.test_dict"),
make_rust_extension("rustapi_module.pyclass_iter"),
],
install_requires=install_requires,
tests_require=tests_require,
Expand Down
1 change: 1 addition & 0 deletions examples/rustapi_module/src/lib.rs
Expand Up @@ -3,4 +3,5 @@ pub mod datetime;
pub mod dict_iter;
pub mod objstore;
pub mod othermod;
pub mod pyclass_iter;
pub mod subclassing;
34 changes: 34 additions & 0 deletions examples/rustapi_module/src/pyclass_iter.rs
@@ -0,0 +1,34 @@
use pyo3::class::iter::{IterNextOutput, PyIterProtocol};
use pyo3::prelude::*;

/// This is for demonstrating how to return a value from __next__
#[pyclass]
struct PyClassIter {
count: usize,
}

#[pymethods]
impl PyClassIter {
#[new]
pub fn new() -> Self {
PyClassIter { count: 0 }
}
}

#[pyproto]
impl PyIterProtocol for PyClassIter {
fn __next__(mut slf: PyRefMut<Self>) -> IterNextOutput<usize, &'static str> {
if slf.count < 5 {
slf.count += 1;
IterNextOutput::Yield(slf.count)
} else {
IterNextOutput::Return("Ended")
}
}
}

#[pymodule]
pub fn pyclass_iter(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<PyClassIter>()?;
Ok(())
}
15 changes: 15 additions & 0 deletions examples/rustapi_module/tests/test_pyclass_iter.py
@@ -0,0 +1,15 @@
import pytest
from rustapi_module import pyclass_iter


def test_iter():
i = pyclass_iter.PyClassIter()
assert next(i) == 1
assert next(i) == 2
assert next(i) == 3
assert next(i) == 4
assert next(i) == 5

with pytest.raises(StopIteration) as excinfo:
next(i)
assert excinfo.value.value == "Ended"
77 changes: 68 additions & 9 deletions src/class/iter.rs
Expand Up @@ -9,6 +9,40 @@ use crate::{ffi, IntoPy, IntoPyPointer, PyClass, PyObject, Python};

/// Python Iterator Interface.
///
/// # Example
///
/// The following example shows how to implement a simple Python iterator in Rust which yields
/// the integers 1 to 5, before raising `StopIteration('Ended')`.
///
/// ```rust
/// use pyo3::prelude::*;
/// use pyo3::PyIterProtocol;
/// use pyo3::class::iter::IterNextOutput;
///
/// #[pyclass]
/// struct Iter {
/// count: usize
/// }
///
/// #[pyproto]
/// impl PyIterProtocol for Iter {
/// fn __next__(mut slf: PyRefMut<Self>) -> IterNextOutput<usize, &'static str> {
/// if slf.count < 5 {
/// slf.count += 1;
/// IterNextOutput::Yield(slf.count)
/// } else {
/// IterNextOutput::Return("Ended")
/// }
/// }
/// }
///
/// # let gil = Python::acquire_gil();
/// # let py = gil.python();
/// # let inst = Py::new(py, Iter { count: 0 }).unwrap();
/// # pyo3::py_run!(py, inst, "assert next(inst) == 1");
/// # // test of StopIteration is done in examples/rustapi_module/pyclass_iter.rs
/// ```
///
/// Check [CPython doc](https://docs.python.org/3/c-api/typeobj.html#c.PyTypeObject.tp_iter)
/// for more.
#[allow(unused_variables)]
Expand All @@ -35,7 +69,7 @@ pub trait PyIterIterProtocol<'p>: PyIterProtocol<'p> {

pub trait PyIterNextProtocol<'p>: PyIterProtocol<'p> {
type Receiver: TryFromPyCell<'p, Self>;
type Result: IntoPyCallbackOutput<IterNextOutput>;
type Result: IntoPyCallbackOutput<PyIterNextOutput>;
}

#[derive(Default)]
Expand Down Expand Up @@ -64,22 +98,47 @@ impl PyIterMethods {
}
}

pub struct IterNextOutput(Option<PyObject>);
/// Output of `__next__` which can either `yield` the next value in the iteration, or
/// `return` a value to raise `StopIteration` in Python.
///
/// See [`PyIterProtocol`] for an example.
pub enum IterNextOutput<T, U> {
Yield(T),
Return(U),
}

pub type PyIterNextOutput = IterNextOutput<PyObject, PyObject>;

impl IntoPyCallbackOutput<*mut ffi::PyObject> for IterNextOutput {
impl IntoPyCallbackOutput<*mut ffi::PyObject> for PyIterNextOutput {
fn convert(self, _py: Python) -> PyResult<*mut ffi::PyObject> {
match self.0 {
Some(o) => Ok(o.into_ptr()),
None => Err(crate::exceptions::StopIteration::py_err(())),
match self {
IterNextOutput::Yield(o) => Ok(o.into_ptr()),
IterNextOutput::Return(opt) => Err(crate::exceptions::StopIteration::py_err((opt,))),
}
}
}

impl<T> IntoPyCallbackOutput<IterNextOutput> for Option<T>
impl<T, U> IntoPyCallbackOutput<PyIterNextOutput> for IterNextOutput<T, U>
where
T: IntoPy<PyObject>,
U: IntoPy<PyObject>,
{
fn convert(self, py: Python) -> PyResult<IterNextOutput> {
Ok(IterNextOutput(self.map(|o| o.into_py(py))))
fn convert(self, py: Python) -> PyResult<PyIterNextOutput> {
match self {
IterNextOutput::Yield(o) => Ok(IterNextOutput::Yield(o.into_py(py))),
IterNextOutput::Return(o) => Ok(IterNextOutput::Return(o.into_py(py))),
}
}
}

impl<T> IntoPyCallbackOutput<PyIterNextOutput> for Option<T>
where
T: IntoPy<PyObject>,
{
fn convert(self, py: Python) -> PyResult<PyIterNextOutput> {
match self {
Some(o) => Ok(PyIterNextOutput::Yield(o.into_py(py))),
None => Ok(PyIterNextOutput::Return(py.None())),
}
}
}
39 changes: 30 additions & 9 deletions src/class/pyasync.rs
Expand Up @@ -71,7 +71,7 @@ pub trait PyAsyncAiterProtocol<'p>: PyAsyncProtocol<'p> {

pub trait PyAsyncAnextProtocol<'p>: PyAsyncProtocol<'p> {
type Receiver: TryFromPyCell<'p, Self>;
type Result: IntoPyCallbackOutput<IterANextOutput>;
type Result: IntoPyCallbackOutput<PyIterANextOutput>;
}

pub trait PyAsyncAenterProtocol<'p>: PyAsyncProtocol<'p> {
Expand Down Expand Up @@ -107,23 +107,44 @@ impl ffi::PyAsyncMethods {
}
}

pub struct IterANextOutput(Option<PyObject>);
pub enum IterANextOutput<T, U> {
Yield(T),
Return(U),
}

pub type PyIterANextOutput = IterANextOutput<PyObject, PyObject>;

impl IntoPyCallbackOutput<*mut ffi::PyObject> for IterANextOutput {
impl IntoPyCallbackOutput<*mut ffi::PyObject> for PyIterANextOutput {
fn convert(self, _py: Python) -> PyResult<*mut ffi::PyObject> {
match self.0 {
Some(o) => Ok(o.into_ptr()),
None => Err(crate::exceptions::StopAsyncIteration::py_err(())),
match self {
IterANextOutput::Yield(o) => Ok(o.into_ptr()),
IterANextOutput::Return(opt) => Err(crate::exceptions::StopAsyncIteration::py_err((opt,))),
}
}
}

impl<T, U> IntoPyCallbackOutput<PyIterANextOutput> for IterANextOutput<T, U>
where
T: IntoPy<PyObject>,
U: IntoPy<PyObject>,
{
fn convert(self, py: Python) -> PyResult<PyIterANextOutput> {
match self {
IterANextOutput::Yield(o) => Ok(IterANextOutput::Yield(o.into_py(py))),
IterANextOutput::Return(o) => Ok(IterANextOutput::Return(o.into_py(py))),
}
}
}

impl<T> IntoPyCallbackOutput<IterANextOutput> for Option<T>
impl<T> IntoPyCallbackOutput<PyIterANextOutput> for Option<T>
where
T: IntoPy<PyObject>,
{
fn convert(self, py: Python) -> PyResult<IterANextOutput> {
Ok(IterANextOutput(self.map(|o| o.into_py(py))))
fn convert(self, py: Python) -> PyResult<PyIterANextOutput> {
match self {
Some(o) => Ok(PyIterANextOutput::Yield(o.into_py(py))),
None => Ok(PyIterANextOutput::Return(py.None())),
}
}
}

Expand Down

0 comments on commit c67e94d

Please sign in to comment.