Skip to content

Commit

Permalink
Add accumulate_axis_inplace method
Browse files Browse the repository at this point in the history
  • Loading branch information
jturner314 committed Mar 27, 2019
1 parent 26f7762 commit 3faaa7e
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 0 deletions.
56 changes: 56 additions & 0 deletions src/impl_methods.rs
Expand Up @@ -2138,4 +2138,60 @@ where
}
})
}

/// Iterates over pairs of consecutive elements along the axis.
///
/// The first argument to the closure is an element, and the second
/// argument is the next element along the axis. Iteration is guaranteed to
/// proceed in order along the specified axis, but in all other respects
/// the iteration order is unspecified.
///
/// # Example
///
/// For example, this can be used to compute the cumulative sum along an
/// axis:
///
/// ```
/// use ndarray::{array, Axis};
///
/// let mut arr = array![
/// [[1, 2], [3, 4], [5, 6]],
/// [[7, 8], [9, 10], [11, 12]],
/// ];
/// arr.accumulate_axis_inplace(Axis(1), |&prev, curr| *curr += prev);
/// assert_eq!(
/// arr,
/// array![
/// [[1, 2], [4, 6], [9, 12]],
/// [[7, 8], [16, 18], [27, 30]],
/// ],
/// );
/// ```
pub fn accumulate_axis_inplace<F>(&mut self, axis: Axis, mut f: F)
where
F: FnMut(&A, &mut A),
S: DataMut,
{
if self.len_of(axis) <= 1 {
return;
}
let mut prev = self.raw_view();
prev.slice_axis_inplace(axis, Slice::from(..-1));
let mut curr = self.raw_view_mut();
curr.slice_axis_inplace(axis, Slice::from(1..));
// This implementation relies on `Zip` iterating along `axis` in order.
Zip::from(prev).and(curr).apply(|prev, curr| unsafe {
// These pointer dereferences and borrows are safe because:
//
// 1. They're pointers to elements in the array.
//
// 2. `S: DataMut` guarantees that elements are safe to borrow
// mutably and that they don't alias.
//
// 3. The lifetimes of the borrows last only for the duration
// of the call to `f`, so aliasing across calls to `f`
// cannot occur.
f(&*prev, &mut *curr)
});
}
}
41 changes: 41 additions & 0 deletions tests/array.rs
Expand Up @@ -2004,6 +2004,47 @@ fn test_map_axis() {
assert_eq!(c, answer2);
}

#[test]
fn test_accumulate_axis_inplace_noop() {
let mut a = Array2::<u8>::zeros((0, 3));
a.accumulate_axis_inplace(Axis(0), |&prev, curr| *curr += prev);
assert_eq!(a, Array2::zeros((0, 3)));

let mut a = Array2::<u8>::zeros((3, 1));
a.accumulate_axis_inplace(Axis(1), |&prev, curr| *curr += prev);
assert_eq!(a, Array2::zeros((3, 1)));
}

#[test]
fn test_accumulate_axis_inplace_nonstandard_layout() {
let a = arr2(&[[1, 2, 3],
[4, 5, 6],
[7, 8, 9],
[10,11,12]]);

let mut a_t = a.clone().reversed_axes();
a_t.accumulate_axis_inplace(Axis(0), |&prev, curr| *curr += prev);
assert_eq!(a_t, aview2(&[[1, 4, 7, 10],
[3, 9, 15, 21],
[6, 15, 24, 33]]));

let mut a0 = a.clone();
a0.invert_axis(Axis(0));
a0.accumulate_axis_inplace(Axis(0), |&prev, curr| *curr += prev);
assert_eq!(a0, aview2(&[[10, 11, 12],
[17, 19, 21],
[21, 24, 27],
[22, 26, 30]]));

let mut a1 = a.clone();
a1.invert_axis(Axis(1));
a1.accumulate_axis_inplace(Axis(1), |&prev, curr| *curr += prev);
assert_eq!(a1, aview2(&[[3, 5, 6],
[6, 11, 15],
[9, 17, 24],
[12, 23, 33]]));
}

#[test]
fn test_to_vec() {
let mut a = arr2(&[[1, 2, 3],
Expand Down

0 comments on commit 3faaa7e

Please sign in to comment.