From 71c8c8fd6403a7ec24eadf2da658439f1054d414 Mon Sep 17 00:00:00 2001 From: Jim Turner Date: Tue, 26 Mar 2019 23:33:47 -0400 Subject: [PATCH 1/4] Implement NdProducer for RawArrayView/Mut --- src/zip/mod.rs | 118 ++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 116 insertions(+), 2 deletions(-) diff --git a/src/zip/mod.rs b/src/zip/mod.rs index 940d87cc3..ccc123db7 100644 --- a/src/zip/mod.rs +++ b/src/zip/mod.rs @@ -47,7 +47,7 @@ where impl ArrayBase where - S: Data, + S: RawData, D: Dimension, { pub(crate) fn layout_impl(&self) -> Layout { @@ -57,7 +57,7 @@ where } else { CORDER } - } else if self.ndim() > 1 && self.t().is_standard_layout() { + } else if self.ndim() > 1 && self.raw_view().reversed_axes().is_standard_layout() { FORDER } else { 0 @@ -192,6 +192,14 @@ pub trait Offset: Copy { private_decl! {} } +impl Offset for *const T { + type Stride = isize; + unsafe fn stride_offset(self, s: Self::Stride, index: usize) -> Self { + self.offset(s * (index as isize)) + } + private_impl! {} +} + impl Offset for *mut T { type Stride = isize; unsafe fn stride_offset(self, s: Self::Stride, index: usize) -> Self { @@ -389,6 +397,112 @@ impl<'a, A, D: Dimension> NdProducer for ArrayViewMut<'a, A, D> { } } +impl NdProducer for RawArrayView { + type Item = *const A; + type Dim = D; + type Ptr = *const A; + type Stride = isize; + + private_impl! {} + #[doc(hidden)] + fn raw_dim(&self) -> Self::Dim { + self.raw_dim() + } + + #[doc(hidden)] + fn equal_dim(&self, dim: &Self::Dim) -> bool { + self.dim.equal(dim) + } + + #[doc(hidden)] + fn as_ptr(&self) -> *const A { + self.as_ptr() + } + + #[doc(hidden)] + fn layout(&self) -> Layout { + self.layout_impl() + } + + #[doc(hidden)] + unsafe fn as_ref(&self, ptr: *const A) -> *const A { + ptr + } + + #[doc(hidden)] + unsafe fn uget_ptr(&self, i: &Self::Dim) -> *const A { + self.ptr.as_ptr().offset(i.index_unchecked(&self.strides)) + } + + #[doc(hidden)] + fn stride_of(&self, axis: Axis) -> isize { + self.stride_of(axis) + } + + #[inline(always)] + fn contiguous_stride(&self) -> Self::Stride { + 1 + } + + #[doc(hidden)] + fn split_at(self, axis: Axis, index: usize) -> (Self, Self) { + self.split_at(axis, index) + } +} + +impl NdProducer for RawArrayViewMut { + type Item = *mut A; + type Dim = D; + type Ptr = *mut A; + type Stride = isize; + + private_impl! {} + #[doc(hidden)] + fn raw_dim(&self) -> Self::Dim { + self.raw_dim() + } + + #[doc(hidden)] + fn equal_dim(&self, dim: &Self::Dim) -> bool { + self.dim.equal(dim) + } + + #[doc(hidden)] + fn as_ptr(&self) -> *mut A { + self.as_ptr() as _ + } + + #[doc(hidden)] + fn layout(&self) -> Layout { + self.layout_impl() + } + + #[doc(hidden)] + unsafe fn as_ref(&self, ptr: *mut A) -> *mut A { + ptr + } + + #[doc(hidden)] + unsafe fn uget_ptr(&self, i: &Self::Dim) -> *mut A { + self.ptr.as_ptr().offset(i.index_unchecked(&self.strides)) + } + + #[doc(hidden)] + fn stride_of(&self, axis: Axis) -> isize { + self.stride_of(axis) + } + + #[inline(always)] + fn contiguous_stride(&self) -> Self::Stride { + 1 + } + + #[doc(hidden)] + fn split_at(self, axis: Axis, index: usize) -> (Self, Self) { + self.split_at(axis, index) + } +} + /// Lock step function application across several arrays or other producers. /// /// Zip allows matching several producers to each other elementwise and applying From bcd7078d50509beb38f1e2aef4ee48d4c4d70015 Mon Sep 17 00:00:00 2001 From: Jim Turner Date: Tue, 26 Mar 2019 23:59:11 -0400 Subject: [PATCH 2/4] Add accumulate_axis_inplace method --- src/impl_methods.rs | 56 +++++++++++++++++++++++++++++++++++++++++++++ tests/array.rs | 41 +++++++++++++++++++++++++++++++++ 2 files changed, 97 insertions(+) diff --git a/src/impl_methods.rs b/src/impl_methods.rs index c005ff22d..8e11fab3f 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -2231,4 +2231,60 @@ where }) } } + + /// Iterates over pairs of consecutive elements along the axis. + /// + /// The first argument to the closure is an element, and the second + /// argument is the next element along the axis. Iteration is guaranteed to + /// proceed in order along the specified axis, but in all other respects + /// the iteration order is unspecified. + /// + /// # Example + /// + /// For example, this can be used to compute the cumulative sum along an + /// axis: + /// + /// ``` + /// use ndarray::{array, Axis}; + /// + /// let mut arr = array![ + /// [[1, 2], [3, 4], [5, 6]], + /// [[7, 8], [9, 10], [11, 12]], + /// ]; + /// arr.accumulate_axis_inplace(Axis(1), |&prev, curr| *curr += prev); + /// assert_eq!( + /// arr, + /// array![ + /// [[1, 2], [4, 6], [9, 12]], + /// [[7, 8], [16, 18], [27, 30]], + /// ], + /// ); + /// ``` + pub fn accumulate_axis_inplace(&mut self, axis: Axis, mut f: F) + where + F: FnMut(&A, &mut A), + S: DataMut, + { + if self.len_of(axis) <= 1 { + return; + } + let mut prev = self.raw_view(); + prev.slice_axis_inplace(axis, Slice::from(..-1)); + let mut curr = self.raw_view_mut(); + curr.slice_axis_inplace(axis, Slice::from(1..)); + // This implementation relies on `Zip` iterating along `axis` in order. + Zip::from(prev).and(curr).apply(|prev, curr| unsafe { + // These pointer dereferences and borrows are safe because: + // + // 1. They're pointers to elements in the array. + // + // 2. `S: DataMut` guarantees that elements are safe to borrow + // mutably and that they don't alias. + // + // 3. The lifetimes of the borrows last only for the duration + // of the call to `f`, so aliasing across calls to `f` + // cannot occur. + f(&*prev, &mut *curr) + }); + } } diff --git a/tests/array.rs b/tests/array.rs index ea5c1e82a..65a7d4f06 100644 --- a/tests/array.rs +++ b/tests/array.rs @@ -1952,6 +1952,47 @@ fn test_map_axis() { itertools::assert_equal(result.iter().cloned().sorted(), 1..=3 * 4); } +#[test] +fn test_accumulate_axis_inplace_noop() { + let mut a = Array2::::zeros((0, 3)); + a.accumulate_axis_inplace(Axis(0), |&prev, curr| *curr += prev); + assert_eq!(a, Array2::zeros((0, 3))); + + let mut a = Array2::::zeros((3, 1)); + a.accumulate_axis_inplace(Axis(1), |&prev, curr| *curr += prev); + assert_eq!(a, Array2::zeros((3, 1))); +} + +#[test] +fn test_accumulate_axis_inplace_nonstandard_layout() { + let a = arr2(&[[1, 2, 3], + [4, 5, 6], + [7, 8, 9], + [10,11,12]]); + + let mut a_t = a.clone().reversed_axes(); + a_t.accumulate_axis_inplace(Axis(0), |&prev, curr| *curr += prev); + assert_eq!(a_t, aview2(&[[1, 4, 7, 10], + [3, 9, 15, 21], + [6, 15, 24, 33]])); + + let mut a0 = a.clone(); + a0.invert_axis(Axis(0)); + a0.accumulate_axis_inplace(Axis(0), |&prev, curr| *curr += prev); + assert_eq!(a0, aview2(&[[10, 11, 12], + [17, 19, 21], + [21, 24, 27], + [22, 26, 30]])); + + let mut a1 = a.clone(); + a1.invert_axis(Axis(1)); + a1.accumulate_axis_inplace(Axis(1), |&prev, curr| *curr += prev); + assert_eq!(a1, aview2(&[[3, 5, 6], + [6, 11, 15], + [9, 17, 24], + [12, 23, 33]])); +} + #[test] fn test_to_vec() { let mut a = arr2(&[[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]); From 9d197a3fc84b17bd6516d99ffb348ce9499e2b94 Mon Sep 17 00:00:00 2001 From: bluss Date: Wed, 11 Sep 2019 20:44:26 +0200 Subject: [PATCH 3/4] Tweak how we create raw views in accumulate_axis_inplace We had: 1. let ptr1 = self.raw_view(); // Borrow &self 2. let ptr2 = self.raw_view_mut(); // Borrow &mut self 3. Use ptr1 and ptr2 While I'm not an expert at the unsafe coding guidelines for Rust, and there are more places in ndarray to revisit, I think it's best to change change 1 and 2 - they don't pass my internalized borrow checker. It seems as though the steps 1, 2, 3 could be against the rules as ptr1 is borrowed from the array data, and its scope straddles the mut borrow of the array data in 2. For this reason, I think this would be better: 1. let ptr2 = self.raw_view_mut() // Borrow &mut self 2. let ptr1 = derive from ptr2 3. use ptr1 and ptr2 RawView should hopefully be our ally in making a better ndarray from the foundation. --- src/impl_methods.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/impl_methods.rs b/src/impl_methods.rs index 8e11fab3f..4fbfcd98a 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -2268,9 +2268,9 @@ where if self.len_of(axis) <= 1 { return; } - let mut prev = self.raw_view(); + let mut curr = self.raw_view_mut(); // mut borrow of the array here + let mut prev = curr.raw_view(); // derive further raw views from the same borrow prev.slice_axis_inplace(axis, Slice::from(..-1)); - let mut curr = self.raw_view_mut(); curr.slice_axis_inplace(axis, Slice::from(1..)); // This implementation relies on `Zip` iterating along `axis` in order. Zip::from(prev).and(curr).apply(|prev, curr| unsafe { From 886b6a1ca98a658ea23a8c3dc1cd925a8cfa1391 Mon Sep 17 00:00:00 2001 From: bluss Date: Thu, 12 Sep 2019 09:06:18 +0200 Subject: [PATCH 4/4] Rustfmt updates for tests/array.rs --- tests/array.rs | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/array.rs b/tests/array.rs index 65a7d4f06..807253104 100644 --- a/tests/array.rs +++ b/tests/array.rs @@ -1098,7 +1098,7 @@ fn owned_array_with_stride() { #[test] fn owned_array_discontiguous() { - use ::std::iter::repeat; + use std::iter::repeat; let v: Vec<_> = (0..12).flat_map(|x| repeat(x).take(2)).collect(); let dim = (3, 2, 2); let strides = (8, 4, 2); @@ -1111,9 +1111,9 @@ fn owned_array_discontiguous() { #[test] fn owned_array_discontiguous_drop() { - use ::std::cell::RefCell; - use ::std::collections::BTreeSet; - use ::std::rc::Rc; + use std::cell::RefCell; + use std::collections::BTreeSet; + use std::rc::Rc; struct InsertOnDrop(Rc>>, Option); impl Drop for InsertOnDrop { @@ -1963,6 +1963,7 @@ fn test_accumulate_axis_inplace_noop() { assert_eq!(a, Array2::zeros((3, 1))); } +#[rustfmt::skip] // Allow block array formatting #[test] fn test_accumulate_axis_inplace_nonstandard_layout() { let a = arr2(&[[1, 2, 3],