Skip to content

Commit

Permalink
Fix support of sequence protocol for returned lists (#730)
Browse files Browse the repository at this point in the history
* Fix support of sequence protocol for returned lists

Add the `sequence` option to PyO3's `pyclass`, so that the `sq_length`
slot is implemented [1]. Implementing this method is required for
sequence types, or Python C API functions like `PySequence_Size` will
fail with an error.

The `reversed` built-in method relies on `PySequence_*` methods. A test
reversing `NodeIndices` is added to guard against future violations of
the sequence protocol.

Fixes #696.

[1]: PyO3/pyo3#2567

* Add release note

Co-authored-by: Matthew Treinish <mtreinish@kortar.org>
  • Loading branch information
tuxu and mtreinish committed Nov 9, 2022
1 parent 53f0e31 commit 4592bc0
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 1 deletion.
10 changes: 10 additions & 0 deletions releasenotes/notes/fix-sequence-protocol-e95246e864cc850a.yaml
@@ -0,0 +1,10 @@
---
fixes:
- |
Fixed an issue with the custom sequence return types,
:class:`~.BFSSuccessors`, :class:`~.NodeIndices`, :class:`~.EdgeList`,
:class:`~.WeightedEdgeList`, :class:`~.EdgeIndices`, and :class:`~.Chains`
where they previosuly were missing certain attributes that prevented them
being used as a sequence for certain built-in functions such as
``reversed()``.
Fixed `#696 <https://github.com/Qiskit/rustworkx/issues/696>`__.
2 changes: 1 addition & 1 deletion src/iterators.rs
Expand Up @@ -475,7 +475,7 @@ impl PyConvertToPyArray for Vec<(usize, usize, PyObject)> {
macro_rules! custom_vec_iter_impl {
($name:ident, $data:ident, $T:ty, $doc:literal) => {
#[doc = $doc]
#[pyclass(module = "rustworkx")]
#[pyclass(module = "rustworkx", sequence)]
#[derive(Clone)]
pub struct $name {
pub $data: Vec<$T>,
Expand Down
6 changes: 6 additions & 0 deletions tests/rustworkx_tests/test_custom_return_types.py
Expand Up @@ -174,6 +174,12 @@ def test_slices_negatives(self):
self.assertEqual([2, 3], slice_return)
self.assertEqual([], indices[-1:-2])

def test_reversed(self):
indices = self.dag.node_indices()
reversed_slice = indices[::-1]
reversed_elems = list(reversed(indices))
self.assertEqual(reversed_slice, reversed_elems)

def test_numpy_conversion(self):
res = self.dag.node_indexes()
np.testing.assert_array_equal(np.asarray(res, dtype=np.uintp), np.array([0, 1]))
Expand Down

0 comments on commit 4592bc0

Please sign in to comment.