From 3faaa7e4fdde14ed9194d3bebc27009b51b09024 Mon Sep 17 00:00:00 2001 From: Jim Turner Date: Tue, 26 Mar 2019 23:59:11 -0400 Subject: [PATCH] Add accumulate_axis_inplace method --- src/impl_methods.rs | 56 +++++++++++++++++++++++++++++++++++++++++++++ tests/array.rs | 41 +++++++++++++++++++++++++++++++++ 2 files changed, 97 insertions(+) diff --git a/src/impl_methods.rs b/src/impl_methods.rs index 444d98525..6ea8d07ac 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -2138,4 +2138,60 @@ where } }) } + + /// Iterates over pairs of consecutive elements along the axis. + /// + /// The first argument to the closure is an element, and the second + /// argument is the next element along the axis. Iteration is guaranteed to + /// proceed in order along the specified axis, but in all other respects + /// the iteration order is unspecified. + /// + /// # Example + /// + /// For example, this can be used to compute the cumulative sum along an + /// axis: + /// + /// ``` + /// use ndarray::{array, Axis}; + /// + /// let mut arr = array![ + /// [[1, 2], [3, 4], [5, 6]], + /// [[7, 8], [9, 10], [11, 12]], + /// ]; + /// arr.accumulate_axis_inplace(Axis(1), |&prev, curr| *curr += prev); + /// assert_eq!( + /// arr, + /// array![ + /// [[1, 2], [4, 6], [9, 12]], + /// [[7, 8], [16, 18], [27, 30]], + /// ], + /// ); + /// ``` + pub fn accumulate_axis_inplace(&mut self, axis: Axis, mut f: F) + where + F: FnMut(&A, &mut A), + S: DataMut, + { + if self.len_of(axis) <= 1 { + return; + } + let mut prev = self.raw_view(); + prev.slice_axis_inplace(axis, Slice::from(..-1)); + let mut curr = self.raw_view_mut(); + curr.slice_axis_inplace(axis, Slice::from(1..)); + // This implementation relies on `Zip` iterating along `axis` in order. + Zip::from(prev).and(curr).apply(|prev, curr| unsafe { + // These pointer dereferences and borrows are safe because: + // + // 1. They're pointers to elements in the array. + // + // 2. `S: DataMut` guarantees that elements are safe to borrow + // mutably and that they don't alias. + // + // 3. The lifetimes of the borrows last only for the duration + // of the call to `f`, so aliasing across calls to `f` + // cannot occur. + f(&*prev, &mut *curr) + }); + } } diff --git a/tests/array.rs b/tests/array.rs index 28e2e7fbc..c1e0c5b24 100644 --- a/tests/array.rs +++ b/tests/array.rs @@ -2004,6 +2004,47 @@ fn test_map_axis() { assert_eq!(c, answer2); } +#[test] +fn test_accumulate_axis_inplace_noop() { + let mut a = Array2::::zeros((0, 3)); + a.accumulate_axis_inplace(Axis(0), |&prev, curr| *curr += prev); + assert_eq!(a, Array2::zeros((0, 3))); + + let mut a = Array2::::zeros((3, 1)); + a.accumulate_axis_inplace(Axis(1), |&prev, curr| *curr += prev); + assert_eq!(a, Array2::zeros((3, 1))); +} + +#[test] +fn test_accumulate_axis_inplace_nonstandard_layout() { + let a = arr2(&[[1, 2, 3], + [4, 5, 6], + [7, 8, 9], + [10,11,12]]); + + let mut a_t = a.clone().reversed_axes(); + a_t.accumulate_axis_inplace(Axis(0), |&prev, curr| *curr += prev); + assert_eq!(a_t, aview2(&[[1, 4, 7, 10], + [3, 9, 15, 21], + [6, 15, 24, 33]])); + + let mut a0 = a.clone(); + a0.invert_axis(Axis(0)); + a0.accumulate_axis_inplace(Axis(0), |&prev, curr| *curr += prev); + assert_eq!(a0, aview2(&[[10, 11, 12], + [17, 19, 21], + [21, 24, 27], + [22, 26, 30]])); + + let mut a1 = a.clone(); + a1.invert_axis(Axis(1)); + a1.accumulate_axis_inplace(Axis(1), |&prev, curr| *curr += prev); + assert_eq!(a1, aview2(&[[3, 5, 6], + [6, 11, 15], + [9, 17, 24], + [12, 23, 33]])); +} + #[test] fn test_to_vec() { let mut a = arr2(&[[1, 2, 3],