From 0ba27b882ea016c9b8c589fcd2527e0949044028 Mon Sep 17 00:00:00 2001 From: Jim Turner Date: Thu, 15 Aug 2019 20:35:35 -0400 Subject: [PATCH 1/2] Implement .nth_back() for iterators The `.nth_back()` method was added to the `DoubleEndedIterator` trait in Rust 1.37. Providing an implementation for `Baseiter` and forwarding it for `Iter/Mut` improves performance. --- .travis.yml | 2 +- benches/iter.rs | 21 ++++++++++++++++ src/iterators/mod.rs | 24 ++++++++++++++++++ tests/iterators.rs | 59 ++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 105 insertions(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index 4ffb25ad7..c46a2c790 100644 --- a/.travis.yml +++ b/.travis.yml @@ -4,7 +4,7 @@ sudo: required dist: trusty matrix: include: - - rust: 1.32.0 + - rust: 1.37.0 env: - FEATURES='test docs' - RUSTFLAGS='-D warnings' diff --git a/benches/iter.rs b/benches/iter.rs index 5cc04293f..c7aeae48a 100644 --- a/benches/iter.rs +++ b/benches/iter.rs @@ -74,6 +74,27 @@ fn iter_filter_sum_2d_stride_f32(bench: &mut Bencher) { bench.iter(|| b.iter().filter(|&&x| x < 75.).sum::()); } +#[bench] +fn iter_rev_step_by_contiguous(bench: &mut Bencher) { + let a = Array::linspace(0., 1., 512); + bench.iter(|| { + a.iter().rev().step_by(2).for_each(|x| { + black_box(x); + }) + }); +} + +#[bench] +fn iter_rev_step_by_discontiguous(bench: &mut Bencher) { + let mut a = Array::linspace(0., 1., 1024); + a.slice_axis_inplace(Axis(0), Slice::new(0, None, 2)); + bench.iter(|| { + a.iter().rev().step_by(2).for_each(|x| { + black_box(x); + }) + }); +} + const ZIPSZ: usize = 10_000; #[bench] diff --git a/src/iterators/mod.rs b/src/iterators/mod.rs index b98d944f8..ca827787a 100644 --- a/src/iterators/mod.rs +++ b/src/iterators/mod.rs @@ -131,6 +131,22 @@ impl DoubleEndedIterator for Baseiter { unsafe { Some(self.ptr.offset(offset)) } } + fn nth_back(&mut self, n: usize) -> Option<*mut A> { + let index = self.index?; + let len = self.dim[0] - index[0]; + if n < len { + self.dim[0] -= n + 1; + let offset = <_>::stride_offset(&self.dim, &self.strides); + if index == self.dim { + self.index = None; + } + unsafe { Some(self.ptr.offset(offset)) } + } else { + self.index = None; + None + } + } + fn rfold(mut self, init: Acc, mut g: G) -> Acc where G: FnMut(Acc, *mut A) -> Acc, @@ -437,6 +453,10 @@ impl<'a, A> DoubleEndedIterator for Iter<'a, A, Ix1> { either_mut!(self.inner, iter => iter.next_back()) } + fn nth_back(&mut self, n: usize) -> Option<&'a A> { + either_mut!(self.inner, iter => iter.nth_back(n)) + } + fn rfold(self, init: Acc, g: G) -> Acc where G: FnMut(Acc, Self::Item) -> Acc, @@ -561,6 +581,10 @@ impl<'a, A> DoubleEndedIterator for IterMut<'a, A, Ix1> { either_mut!(self.inner, iter => iter.next_back()) } + fn nth_back(&mut self, n: usize) -> Option<&'a mut A> { + either_mut!(self.inner, iter => iter.nth_back(n)) + } + fn rfold(self, init: Acc, g: G) -> Acc where G: FnMut(Acc, Self::Item) -> Acc, diff --git a/tests/iterators.rs b/tests/iterators.rs index a1c19c39a..fec988a3d 100644 --- a/tests/iterators.rs +++ b/tests/iterators.rs @@ -564,6 +564,65 @@ fn test_fold() { assert_eq!(a.iter().fold(0, |acc, &x| acc + x), 1); } +#[test] +fn test_nth_back() { + let mut a: Array1 = (0..256).collect(); + a.slice_axis_inplace(Axis(0), Slice::new(0, None, 2)); + + { + assert_eq!(a.iter().nth_back(0), Some(&a[a.len() - 1])); + assert_eq!(a.iter().nth_back(1), Some(&a[a.len() - 2])); + assert_eq!(a.iter().nth_back(a.len() - 2), Some(&a[1])); + assert_eq!(a.iter().nth_back(a.len() - 1), Some(&a[0])); + assert_eq!(a.iter().nth_back(a.len()), None); + assert_eq!(a.iter().nth_back(a.len() + 1), None); + assert_eq!(a.iter().nth_back(a.len() + 2), None); + } + + { + let mut iter1 = a.iter(); + let mut iter2 = a.iter(); + for _ in 0..(a.len() + 1) { + assert_eq!(iter1.nth_back(0), iter2.next_back()); + assert_eq!(iter1.len(), iter2.len()); + } + } + + { + let mut iter1 = a.iter(); + let mut iter2 = a.iter(); + for _ in 0..(a.len() / 3 + 1) { + assert_eq!(iter1.nth_back(2), { + iter2.next_back(); + iter2.next_back(); + iter2.next_back() + }); + assert_eq!(iter1.len(), iter2.len()); + } + } + + { + let mut iter = a.iter(); + assert_eq!(iter.nth_back(a.len()), None); + assert_eq!(iter.next(), None); + } + + { + let mut iter = a.iter(); + iter.next(); + iter.next_back(); + assert_eq!(iter.len(), a.len() - 2); + assert_eq!(iter.nth_back(1), Some(&a[a.len() - 3])); + assert_eq!(iter.len(), a.len() - 4); + assert_eq!(iter.nth_back(a.len() - 6), Some(&a[2])); + assert_eq!(iter.len(), 1); + assert_eq!(iter.next(), Some(&a[1])); + assert_eq!(iter.len(), 0); + assert_eq!(iter.next(), None); + assert_eq!(iter.next_back(), None); + } +} + #[test] fn test_rfold() { { From 43e03f26bc7bfa9f10d7ed4dd0940c9e5b929c43 Mon Sep 17 00:00:00 2001 From: Jim Turner Date: Tue, 3 Sep 2019 19:13:12 -0400 Subject: [PATCH 2/2] Split test_nth_back into multiple tests --- tests/iterators.rs | 105 ++++++++++++++++++++++++--------------------- 1 file changed, 57 insertions(+), 48 deletions(-) diff --git a/tests/iterators.rs b/tests/iterators.rs index fec988a3d..5449d8a26 100644 --- a/tests/iterators.rs +++ b/tests/iterators.rs @@ -565,62 +565,71 @@ fn test_fold() { } #[test] -fn test_nth_back() { +fn nth_back_examples() { let mut a: Array1 = (0..256).collect(); a.slice_axis_inplace(Axis(0), Slice::new(0, None, 2)); + assert_eq!(a.iter().nth_back(0), Some(&a[a.len() - 1])); + assert_eq!(a.iter().nth_back(1), Some(&a[a.len() - 2])); + assert_eq!(a.iter().nth_back(a.len() - 2), Some(&a[1])); + assert_eq!(a.iter().nth_back(a.len() - 1), Some(&a[0])); + assert_eq!(a.iter().nth_back(a.len()), None); + assert_eq!(a.iter().nth_back(a.len() + 1), None); + assert_eq!(a.iter().nth_back(a.len() + 2), None); +} - { - assert_eq!(a.iter().nth_back(0), Some(&a[a.len() - 1])); - assert_eq!(a.iter().nth_back(1), Some(&a[a.len() - 2])); - assert_eq!(a.iter().nth_back(a.len() - 2), Some(&a[1])); - assert_eq!(a.iter().nth_back(a.len() - 1), Some(&a[0])); - assert_eq!(a.iter().nth_back(a.len()), None); - assert_eq!(a.iter().nth_back(a.len() + 1), None); - assert_eq!(a.iter().nth_back(a.len() + 2), None); - } - - { - let mut iter1 = a.iter(); - let mut iter2 = a.iter(); - for _ in 0..(a.len() + 1) { - assert_eq!(iter1.nth_back(0), iter2.next_back()); - assert_eq!(iter1.len(), iter2.len()); - } +#[test] +fn nth_back_zero_n() { + let mut a: Array1 = (0..256).collect(); + a.slice_axis_inplace(Axis(0), Slice::new(0, None, 2)); + let mut iter1 = a.iter(); + let mut iter2 = a.iter(); + for _ in 0..(a.len() + 1) { + assert_eq!(iter1.nth_back(0), iter2.next_back()); + assert_eq!(iter1.len(), iter2.len()); } +} - { - let mut iter1 = a.iter(); - let mut iter2 = a.iter(); - for _ in 0..(a.len() / 3 + 1) { - assert_eq!(iter1.nth_back(2), { - iter2.next_back(); - iter2.next_back(); - iter2.next_back() - }); - assert_eq!(iter1.len(), iter2.len()); - } +#[test] +fn nth_back_nonzero_n() { + let mut a: Array1 = (0..256).collect(); + a.slice_axis_inplace(Axis(0), Slice::new(0, None, 2)); + let mut iter1 = a.iter(); + let mut iter2 = a.iter(); + for _ in 0..(a.len() / 3 + 1) { + assert_eq!(iter1.nth_back(2), { + iter2.next_back(); + iter2.next_back(); + iter2.next_back() + }); + assert_eq!(iter1.len(), iter2.len()); } +} - { - let mut iter = a.iter(); - assert_eq!(iter.nth_back(a.len()), None); - assert_eq!(iter.next(), None); - } +#[test] +fn nth_back_past_end() { + let mut a: Array1 = (0..256).collect(); + a.slice_axis_inplace(Axis(0), Slice::new(0, None, 2)); + let mut iter = a.iter(); + assert_eq!(iter.nth_back(a.len()), None); + assert_eq!(iter.next(), None); +} - { - let mut iter = a.iter(); - iter.next(); - iter.next_back(); - assert_eq!(iter.len(), a.len() - 2); - assert_eq!(iter.nth_back(1), Some(&a[a.len() - 3])); - assert_eq!(iter.len(), a.len() - 4); - assert_eq!(iter.nth_back(a.len() - 6), Some(&a[2])); - assert_eq!(iter.len(), 1); - assert_eq!(iter.next(), Some(&a[1])); - assert_eq!(iter.len(), 0); - assert_eq!(iter.next(), None); - assert_eq!(iter.next_back(), None); - } +#[test] +fn nth_back_partially_consumed() { + let mut a: Array1 = (0..256).collect(); + a.slice_axis_inplace(Axis(0), Slice::new(0, None, 2)); + let mut iter = a.iter(); + iter.next(); + iter.next_back(); + assert_eq!(iter.len(), a.len() - 2); + assert_eq!(iter.nth_back(1), Some(&a[a.len() - 3])); + assert_eq!(iter.len(), a.len() - 4); + assert_eq!(iter.nth_back(a.len() - 6), Some(&a[2])); + assert_eq!(iter.len(), 1); + assert_eq!(iter.next(), Some(&a[1])); + assert_eq!(iter.len(), 0); + assert_eq!(iter.next(), None); + assert_eq!(iter.next_back(), None); } #[test]