Skip to content

Commit

Permalink
Merge pull request #611 from jturner314/accumulate-axis-inplace
Browse files Browse the repository at this point in the history
Add accumulate_axis_inplace method and implement NdProducer for RawArrayView/Mut
  • Loading branch information
bluss committed Sep 18, 2019
2 parents d4dd6f5 + 886b6a1 commit 642a44c
Show file tree
Hide file tree
Showing 3 changed files with 218 additions and 6 deletions.
56 changes: 56 additions & 0 deletions src/impl_methods.rs
Expand Up @@ -2231,4 +2231,60 @@ where
})
}
}

/// Iterates over pairs of consecutive elements along the axis.
///
/// The first argument to the closure is an element, and the second
/// argument is the next element along the axis. Iteration is guaranteed to
/// proceed in order along the specified axis, but in all other respects
/// the iteration order is unspecified.
///
/// # Example
///
/// For example, this can be used to compute the cumulative sum along an
/// axis:
///
/// ```
/// use ndarray::{array, Axis};
///
/// let mut arr = array![
/// [[1, 2], [3, 4], [5, 6]],
/// [[7, 8], [9, 10], [11, 12]],
/// ];
/// arr.accumulate_axis_inplace(Axis(1), |&prev, curr| *curr += prev);
/// assert_eq!(
/// arr,
/// array![
/// [[1, 2], [4, 6], [9, 12]],
/// [[7, 8], [16, 18], [27, 30]],
/// ],
/// );
/// ```
pub fn accumulate_axis_inplace<F>(&mut self, axis: Axis, mut f: F)
where
F: FnMut(&A, &mut A),
S: DataMut,
{
if self.len_of(axis) <= 1 {
return;
}
let mut curr = self.raw_view_mut(); // mut borrow of the array here
let mut prev = curr.raw_view(); // derive further raw views from the same borrow
prev.slice_axis_inplace(axis, Slice::from(..-1));
curr.slice_axis_inplace(axis, Slice::from(1..));
// This implementation relies on `Zip` iterating along `axis` in order.
Zip::from(prev).and(curr).apply(|prev, curr| unsafe {
// These pointer dereferences and borrows are safe because:
//
// 1. They're pointers to elements in the array.
//
// 2. `S: DataMut` guarantees that elements are safe to borrow
// mutably and that they don't alias.
//
// 3. The lifetimes of the borrows last only for the duration
// of the call to `f`, so aliasing across calls to `f`
// cannot occur.
f(&*prev, &mut *curr)
});
}
}
118 changes: 116 additions & 2 deletions src/zip/mod.rs
Expand Up @@ -47,7 +47,7 @@ where

impl<S, D> ArrayBase<S, D>
where
S: Data,
S: RawData,
D: Dimension,
{
pub(crate) fn layout_impl(&self) -> Layout {
Expand All @@ -57,7 +57,7 @@ where
} else {
CORDER
}
} else if self.ndim() > 1 && self.t().is_standard_layout() {
} else if self.ndim() > 1 && self.raw_view().reversed_axes().is_standard_layout() {
FORDER
} else {
0
Expand Down Expand Up @@ -192,6 +192,14 @@ pub trait Offset: Copy {
private_decl! {}
}

impl<T> Offset for *const T {
type Stride = isize;
unsafe fn stride_offset(self, s: Self::Stride, index: usize) -> Self {
self.offset(s * (index as isize))
}
private_impl! {}
}

impl<T> Offset for *mut T {
type Stride = isize;
unsafe fn stride_offset(self, s: Self::Stride, index: usize) -> Self {
Expand Down Expand Up @@ -389,6 +397,112 @@ impl<'a, A, D: Dimension> NdProducer for ArrayViewMut<'a, A, D> {
}
}

impl<A, D: Dimension> NdProducer for RawArrayView<A, D> {
type Item = *const A;
type Dim = D;
type Ptr = *const A;
type Stride = isize;

private_impl! {}
#[doc(hidden)]
fn raw_dim(&self) -> Self::Dim {
self.raw_dim()
}

#[doc(hidden)]
fn equal_dim(&self, dim: &Self::Dim) -> bool {
self.dim.equal(dim)
}

#[doc(hidden)]
fn as_ptr(&self) -> *const A {
self.as_ptr()
}

#[doc(hidden)]
fn layout(&self) -> Layout {
self.layout_impl()
}

#[doc(hidden)]
unsafe fn as_ref(&self, ptr: *const A) -> *const A {
ptr
}

#[doc(hidden)]
unsafe fn uget_ptr(&self, i: &Self::Dim) -> *const A {
self.ptr.as_ptr().offset(i.index_unchecked(&self.strides))
}

#[doc(hidden)]
fn stride_of(&self, axis: Axis) -> isize {
self.stride_of(axis)
}

#[inline(always)]
fn contiguous_stride(&self) -> Self::Stride {
1
}

#[doc(hidden)]
fn split_at(self, axis: Axis, index: usize) -> (Self, Self) {
self.split_at(axis, index)
}
}

impl<A, D: Dimension> NdProducer for RawArrayViewMut<A, D> {
type Item = *mut A;
type Dim = D;
type Ptr = *mut A;
type Stride = isize;

private_impl! {}
#[doc(hidden)]
fn raw_dim(&self) -> Self::Dim {
self.raw_dim()
}

#[doc(hidden)]
fn equal_dim(&self, dim: &Self::Dim) -> bool {
self.dim.equal(dim)
}

#[doc(hidden)]
fn as_ptr(&self) -> *mut A {
self.as_ptr() as _
}

#[doc(hidden)]
fn layout(&self) -> Layout {
self.layout_impl()
}

#[doc(hidden)]
unsafe fn as_ref(&self, ptr: *mut A) -> *mut A {
ptr
}

#[doc(hidden)]
unsafe fn uget_ptr(&self, i: &Self::Dim) -> *mut A {
self.ptr.as_ptr().offset(i.index_unchecked(&self.strides))
}

#[doc(hidden)]
fn stride_of(&self, axis: Axis) -> isize {
self.stride_of(axis)
}

#[inline(always)]
fn contiguous_stride(&self) -> Self::Stride {
1
}

#[doc(hidden)]
fn split_at(self, axis: Axis, index: usize) -> (Self, Self) {
self.split_at(axis, index)
}
}

/// Lock step function application across several arrays or other producers.
///
/// Zip allows matching several producers to each other elementwise and applying
Expand Down
50 changes: 46 additions & 4 deletions tests/array.rs
Expand Up @@ -1105,7 +1105,7 @@ fn owned_array_with_stride() {

#[test]
fn owned_array_discontiguous() {
use ::std::iter::repeat;
use std::iter::repeat;
let v: Vec<_> = (0..12).flat_map(|x| repeat(x).take(2)).collect();
let dim = (3, 2, 2);
let strides = (8, 4, 2);
Expand All @@ -1118,9 +1118,9 @@ fn owned_array_discontiguous() {

#[test]
fn owned_array_discontiguous_drop() {
use ::std::cell::RefCell;
use ::std::collections::BTreeSet;
use ::std::rc::Rc;
use std::cell::RefCell;
use std::collections::BTreeSet;
use std::rc::Rc;

struct InsertOnDrop<T: Ord>(Rc<RefCell<BTreeSet<T>>>, Option<T>);
impl<T: Ord> Drop for InsertOnDrop<T> {
Expand Down Expand Up @@ -1959,6 +1959,48 @@ fn test_map_axis() {
itertools::assert_equal(result.iter().cloned().sorted(), 1..=3 * 4);
}

#[test]
fn test_accumulate_axis_inplace_noop() {
let mut a = Array2::<u8>::zeros((0, 3));
a.accumulate_axis_inplace(Axis(0), |&prev, curr| *curr += prev);
assert_eq!(a, Array2::zeros((0, 3)));

let mut a = Array2::<u8>::zeros((3, 1));
a.accumulate_axis_inplace(Axis(1), |&prev, curr| *curr += prev);
assert_eq!(a, Array2::zeros((3, 1)));
}

#[rustfmt::skip] // Allow block array formatting
#[test]
fn test_accumulate_axis_inplace_nonstandard_layout() {
let a = arr2(&[[1, 2, 3],
[4, 5, 6],
[7, 8, 9],
[10,11,12]]);

let mut a_t = a.clone().reversed_axes();
a_t.accumulate_axis_inplace(Axis(0), |&prev, curr| *curr += prev);
assert_eq!(a_t, aview2(&[[1, 4, 7, 10],
[3, 9, 15, 21],
[6, 15, 24, 33]]));

let mut a0 = a.clone();
a0.invert_axis(Axis(0));
a0.accumulate_axis_inplace(Axis(0), |&prev, curr| *curr += prev);
assert_eq!(a0, aview2(&[[10, 11, 12],
[17, 19, 21],
[21, 24, 27],
[22, 26, 30]]));

let mut a1 = a.clone();
a1.invert_axis(Axis(1));
a1.accumulate_axis_inplace(Axis(1), |&prev, curr| *curr += prev);
assert_eq!(a1, aview2(&[[3, 5, 6],
[6, 11, 15],
[9, 17, 24],
[12, 23, 33]]));
}

#[test]
fn test_to_vec() {
let mut a = arr2(&[[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]);
Expand Down

0 comments on commit 642a44c

Please sign in to comment.