Skip to content

Commit

Permalink
Implement .nth_back() for iterators (#686)
Browse files Browse the repository at this point in the history
* Implement .nth_back() for iterators

The `.nth_back()` method was added to the `DoubleEndedIterator` trait
in Rust 1.37. Providing an implementation for `Baseiter` and
forwarding it for `Iter/Mut` improves performance.

* Split test_nth_back into multiple tests
  • Loading branch information
jturner314 authored and LukeMathWalker committed Sep 5, 2019
1 parent c916203 commit 4976d96
Show file tree
Hide file tree
Showing 4 changed files with 114 additions and 1 deletion.
2 changes: 1 addition & 1 deletion .travis.yml
Expand Up @@ -4,7 +4,7 @@ sudo: required
dist: trusty
matrix:
include:
- rust: 1.32.0
- rust: 1.37.0
env:
- FEATURES='test docs'
- RUSTFLAGS='-D warnings'
Expand Down
21 changes: 21 additions & 0 deletions benches/iter.rs
Expand Up @@ -74,6 +74,27 @@ fn iter_filter_sum_2d_stride_f32(bench: &mut Bencher) {
bench.iter(|| b.iter().filter(|&&x| x < 75.).sum::<f32>());
}

#[bench]
fn iter_rev_step_by_contiguous(bench: &mut Bencher) {
let a = Array::linspace(0., 1., 512);
bench.iter(|| {
a.iter().rev().step_by(2).for_each(|x| {
black_box(x);
})
});
}

#[bench]
fn iter_rev_step_by_discontiguous(bench: &mut Bencher) {
let mut a = Array::linspace(0., 1., 1024);
a.slice_axis_inplace(Axis(0), Slice::new(0, None, 2));
bench.iter(|| {
a.iter().rev().step_by(2).for_each(|x| {
black_box(x);
})
});
}

const ZIPSZ: usize = 10_000;

#[bench]
Expand Down
24 changes: 24 additions & 0 deletions src/iterators/mod.rs
Expand Up @@ -131,6 +131,22 @@ impl<A> DoubleEndedIterator for Baseiter<A, Ix1> {
unsafe { Some(self.ptr.offset(offset)) }
}

fn nth_back(&mut self, n: usize) -> Option<*mut A> {
let index = self.index?;
let len = self.dim[0] - index[0];
if n < len {
self.dim[0] -= n + 1;
let offset = <_>::stride_offset(&self.dim, &self.strides);
if index == self.dim {
self.index = None;
}
unsafe { Some(self.ptr.offset(offset)) }
} else {
self.index = None;
None
}
}

fn rfold<Acc, G>(mut self, init: Acc, mut g: G) -> Acc
where
G: FnMut(Acc, *mut A) -> Acc,
Expand Down Expand Up @@ -437,6 +453,10 @@ impl<'a, A> DoubleEndedIterator for Iter<'a, A, Ix1> {
either_mut!(self.inner, iter => iter.next_back())
}

fn nth_back(&mut self, n: usize) -> Option<&'a A> {
either_mut!(self.inner, iter => iter.nth_back(n))
}

fn rfold<Acc, G>(self, init: Acc, g: G) -> Acc
where
G: FnMut(Acc, Self::Item) -> Acc,
Expand Down Expand Up @@ -561,6 +581,10 @@ impl<'a, A> DoubleEndedIterator for IterMut<'a, A, Ix1> {
either_mut!(self.inner, iter => iter.next_back())
}

fn nth_back(&mut self, n: usize) -> Option<&'a mut A> {
either_mut!(self.inner, iter => iter.nth_back(n))
}

fn rfold<Acc, G>(self, init: Acc, g: G) -> Acc
where
G: FnMut(Acc, Self::Item) -> Acc,
Expand Down
68 changes: 68 additions & 0 deletions tests/iterators.rs
Expand Up @@ -776,6 +776,74 @@ fn test_fold() {
assert_eq!(a.iter().fold(0, |acc, &x| acc + x), 1);
}

#[test]
fn nth_back_examples() {
let mut a: Array1<i32> = (0..256).collect();
a.slice_axis_inplace(Axis(0), Slice::new(0, None, 2));
assert_eq!(a.iter().nth_back(0), Some(&a[a.len() - 1]));
assert_eq!(a.iter().nth_back(1), Some(&a[a.len() - 2]));
assert_eq!(a.iter().nth_back(a.len() - 2), Some(&a[1]));
assert_eq!(a.iter().nth_back(a.len() - 1), Some(&a[0]));
assert_eq!(a.iter().nth_back(a.len()), None);
assert_eq!(a.iter().nth_back(a.len() + 1), None);
assert_eq!(a.iter().nth_back(a.len() + 2), None);
}

#[test]
fn nth_back_zero_n() {
let mut a: Array1<i32> = (0..256).collect();
a.slice_axis_inplace(Axis(0), Slice::new(0, None, 2));
let mut iter1 = a.iter();
let mut iter2 = a.iter();
for _ in 0..(a.len() + 1) {
assert_eq!(iter1.nth_back(0), iter2.next_back());
assert_eq!(iter1.len(), iter2.len());
}
}

#[test]
fn nth_back_nonzero_n() {
let mut a: Array1<i32> = (0..256).collect();
a.slice_axis_inplace(Axis(0), Slice::new(0, None, 2));
let mut iter1 = a.iter();
let mut iter2 = a.iter();
for _ in 0..(a.len() / 3 + 1) {
assert_eq!(iter1.nth_back(2), {
iter2.next_back();
iter2.next_back();
iter2.next_back()
});
assert_eq!(iter1.len(), iter2.len());
}
}

#[test]
fn nth_back_past_end() {
let mut a: Array1<i32> = (0..256).collect();
a.slice_axis_inplace(Axis(0), Slice::new(0, None, 2));
let mut iter = a.iter();
assert_eq!(iter.nth_back(a.len()), None);
assert_eq!(iter.next(), None);
}

#[test]
fn nth_back_partially_consumed() {
let mut a: Array1<i32> = (0..256).collect();
a.slice_axis_inplace(Axis(0), Slice::new(0, None, 2));
let mut iter = a.iter();
iter.next();
iter.next_back();
assert_eq!(iter.len(), a.len() - 2);
assert_eq!(iter.nth_back(1), Some(&a[a.len() - 3]));
assert_eq!(iter.len(), a.len() - 4);
assert_eq!(iter.nth_back(a.len() - 6), Some(&a[2]));
assert_eq!(iter.len(), 1);
assert_eq!(iter.next(), Some(&a[1]));
assert_eq!(iter.len(), 0);
assert_eq!(iter.next(), None);
assert_eq!(iter.next_back(), None);
}

#[test]
fn test_rfold() {
{
Expand Down

0 comments on commit 4976d96

Please sign in to comment.