diff --git a/.travis.yml b/.travis.yml index 4ffb25ad7..c46a2c790 100644 --- a/.travis.yml +++ b/.travis.yml @@ -4,7 +4,7 @@ sudo: required dist: trusty matrix: include: - - rust: 1.32.0 + - rust: 1.37.0 env: - FEATURES='test docs' - RUSTFLAGS='-D warnings' diff --git a/benches/iter.rs b/benches/iter.rs index 5cc04293f..c7aeae48a 100644 --- a/benches/iter.rs +++ b/benches/iter.rs @@ -74,6 +74,27 @@ fn iter_filter_sum_2d_stride_f32(bench: &mut Bencher) { bench.iter(|| b.iter().filter(|&&x| x < 75.).sum::()); } +#[bench] +fn iter_rev_step_by_contiguous(bench: &mut Bencher) { + let a = Array::linspace(0., 1., 512); + bench.iter(|| { + a.iter().rev().step_by(2).for_each(|x| { + black_box(x); + }) + }); +} + +#[bench] +fn iter_rev_step_by_discontiguous(bench: &mut Bencher) { + let mut a = Array::linspace(0., 1., 1024); + a.slice_axis_inplace(Axis(0), Slice::new(0, None, 2)); + bench.iter(|| { + a.iter().rev().step_by(2).for_each(|x| { + black_box(x); + }) + }); +} + const ZIPSZ: usize = 10_000; #[bench] diff --git a/src/iterators/mod.rs b/src/iterators/mod.rs index 882eaaa77..fb17f406e 100644 --- a/src/iterators/mod.rs +++ b/src/iterators/mod.rs @@ -131,6 +131,22 @@ impl DoubleEndedIterator for Baseiter { unsafe { Some(self.ptr.offset(offset)) } } + fn nth_back(&mut self, n: usize) -> Option<*mut A> { + let index = self.index?; + let len = self.dim[0] - index[0]; + if n < len { + self.dim[0] -= n + 1; + let offset = <_>::stride_offset(&self.dim, &self.strides); + if index == self.dim { + self.index = None; + } + unsafe { Some(self.ptr.offset(offset)) } + } else { + self.index = None; + None + } + } + fn rfold(mut self, init: Acc, mut g: G) -> Acc where G: FnMut(Acc, *mut A) -> Acc, @@ -437,6 +453,10 @@ impl<'a, A> DoubleEndedIterator for Iter<'a, A, Ix1> { either_mut!(self.inner, iter => iter.next_back()) } + fn nth_back(&mut self, n: usize) -> Option<&'a A> { + either_mut!(self.inner, iter => iter.nth_back(n)) + } + fn rfold(self, init: Acc, g: G) -> Acc where G: FnMut(Acc, Self::Item) -> Acc, @@ -561,6 +581,10 @@ impl<'a, A> DoubleEndedIterator for IterMut<'a, A, Ix1> { either_mut!(self.inner, iter => iter.next_back()) } + fn nth_back(&mut self, n: usize) -> Option<&'a mut A> { + either_mut!(self.inner, iter => iter.nth_back(n)) + } + fn rfold(self, init: Acc, g: G) -> Acc where G: FnMut(Acc, Self::Item) -> Acc, diff --git a/tests/iterators.rs b/tests/iterators.rs index 325aa9797..371339b96 100644 --- a/tests/iterators.rs +++ b/tests/iterators.rs @@ -776,6 +776,74 @@ fn test_fold() { assert_eq!(a.iter().fold(0, |acc, &x| acc + x), 1); } +#[test] +fn nth_back_examples() { + let mut a: Array1 = (0..256).collect(); + a.slice_axis_inplace(Axis(0), Slice::new(0, None, 2)); + assert_eq!(a.iter().nth_back(0), Some(&a[a.len() - 1])); + assert_eq!(a.iter().nth_back(1), Some(&a[a.len() - 2])); + assert_eq!(a.iter().nth_back(a.len() - 2), Some(&a[1])); + assert_eq!(a.iter().nth_back(a.len() - 1), Some(&a[0])); + assert_eq!(a.iter().nth_back(a.len()), None); + assert_eq!(a.iter().nth_back(a.len() + 1), None); + assert_eq!(a.iter().nth_back(a.len() + 2), None); +} + +#[test] +fn nth_back_zero_n() { + let mut a: Array1 = (0..256).collect(); + a.slice_axis_inplace(Axis(0), Slice::new(0, None, 2)); + let mut iter1 = a.iter(); + let mut iter2 = a.iter(); + for _ in 0..(a.len() + 1) { + assert_eq!(iter1.nth_back(0), iter2.next_back()); + assert_eq!(iter1.len(), iter2.len()); + } +} + +#[test] +fn nth_back_nonzero_n() { + let mut a: Array1 = (0..256).collect(); + a.slice_axis_inplace(Axis(0), Slice::new(0, None, 2)); + let mut iter1 = a.iter(); + let mut iter2 = a.iter(); + for _ in 0..(a.len() / 3 + 1) { + assert_eq!(iter1.nth_back(2), { + iter2.next_back(); + iter2.next_back(); + iter2.next_back() + }); + assert_eq!(iter1.len(), iter2.len()); + } +} + +#[test] +fn nth_back_past_end() { + let mut a: Array1 = (0..256).collect(); + a.slice_axis_inplace(Axis(0), Slice::new(0, None, 2)); + let mut iter = a.iter(); + assert_eq!(iter.nth_back(a.len()), None); + assert_eq!(iter.next(), None); +} + +#[test] +fn nth_back_partially_consumed() { + let mut a: Array1 = (0..256).collect(); + a.slice_axis_inplace(Axis(0), Slice::new(0, None, 2)); + let mut iter = a.iter(); + iter.next(); + iter.next_back(); + assert_eq!(iter.len(), a.len() - 2); + assert_eq!(iter.nth_back(1), Some(&a[a.len() - 3])); + assert_eq!(iter.len(), a.len() - 4); + assert_eq!(iter.nth_back(a.len() - 6), Some(&a[2])); + assert_eq!(iter.len(), 1); + assert_eq!(iter.next(), Some(&a[1])); + assert_eq!(iter.len(), 0); + assert_eq!(iter.next(), None); + assert_eq!(iter.next_back(), None); +} + #[test] fn test_rfold() { {