From 566b53138a1e054ad0605282dc552077ff45f600 Mon Sep 17 00:00:00 2001 From: Jim Turner Date: Fri, 5 Feb 2021 17:23:32 -0500 Subject: [PATCH] Add slice_each_axis/_mut/_inplace methods --- src/impl_methods.rs | 58 +++++++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 31 ++++++++++++++++++++++-- 2 files changed, 87 insertions(+), 2 deletions(-) diff --git a/src/impl_methods.rs b/src/impl_methods.rs index 9eea42d15..a6143c28a 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -25,6 +25,7 @@ use crate::error::{self, ErrorKind, ShapeError}; use crate::math_cell::MathCell; use crate::itertools::zip; use crate::zip::Zip; +use crate::AxisDescription; use crate::iter::{ AxisChunksIter, AxisChunksIterMut, AxisIter, AxisIterMut, ExactChunks, ExactChunksMut, @@ -511,6 +512,63 @@ where debug_assert!(self.pointer_is_inbounds()); } + /// Return a view of a slice of the array, with a closure specifying the + /// slice for each axis. + /// + /// This is especially useful for code which is generic over the + /// dimensionality of the array. + /// + /// **Panics** if an index is out of bounds or step size is zero. + pub fn slice_each_axis(&self, f: F) -> ArrayView<'_, A, D> + where + F: FnMut(AxisDescription) -> Slice, + S: Data, + { + let mut view = self.view(); + view.slice_each_axis_inplace(f); + view + } + + /// Return a mutable view of a slice of the array, with a closure + /// specifying the slice for each axis. + /// + /// This is especially useful for code which is generic over the + /// dimensionality of the array. + /// + /// **Panics** if an index is out of bounds or step size is zero. + pub fn slice_each_axis_mut(&mut self, f: F) -> ArrayViewMut<'_, A, D> + where + F: FnMut(AxisDescription) -> Slice, + S: DataMut, + { + let mut view = self.view_mut(); + view.slice_each_axis_inplace(f); + view + } + + /// Slice the array in place, with a closure specifying the slice for each + /// axis. + /// + /// This is especially useful for code which is generic over the + /// dimensionality of the array. + /// + /// **Panics** if an index is out of bounds or step size is zero. + pub fn slice_each_axis_inplace(&mut self, mut f: F) + where + F: FnMut(AxisDescription) -> Slice, + { + (0..self.ndim()).for_each(|ax| { + self.slice_axis_inplace( + Axis(ax), + f(AxisDescription( + Axis(ax), + self.dim[ax], + self.strides[ax] as isize, + )), + ) + }) + } + /// Return a reference to the element at `index`, or return `None` /// if the index is out of bounds. /// diff --git a/src/lib.rs b/src/lib.rs index 0c7caf735..03fd96253 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -503,6 +503,19 @@ pub type Ixs = isize; /// [`.slice_move()`]: #method.slice_move /// [`.slice_collapse()`]: #method.slice_collapse /// +/// When slicing arrays with generic dimensionality, creating an instance of +/// [`&SliceInfo`] to pass to the multi-axis slicing methods like [`.slice()`] +/// is awkward. In these cases, it's usually more convenient to use +/// [`.slice_each_axis()`]/[`.slice_each_axis_mut()`]/[`.slice_each_axis_inplace()`] +/// or to create a view and then slice individual axes of the view using +/// methods such as [`.slice_axis_inplace()`] and [`.collapse_axis()`]. +/// +/// [`.slice_each_axis()`]: #method.slice_each_axis +/// [`.slice_each_axis_mut()`]: #method.slice_each_axis_mut +/// [`.slice_each_axis_inplace()`]: #method.slice_each_axis_inplace +/// [`.slice_axis_inplace()`]: #method.slice_axis_inplace +/// [`.collapse_axis()`]: #method.collapse_axis +/// /// It's possible to take multiple simultaneous *mutable* slices with /// [`.multi_slice_mut()`] or (for [`ArrayViewMut`] only) /// [`.multi_slice_move()`]. @@ -511,8 +524,7 @@ pub type Ixs = isize; /// [`.multi_slice_move()`]: type.ArrayViewMut.html#method.multi_slice_move /// /// ``` -/// -/// use ndarray::{arr2, arr3, s}; +/// use ndarray::{arr2, arr3, s, ArrayBase, DataMut, Dimension, Slice}; /// /// // 2 submatrices of 2 rows with 3 elements per row, means a shape of `[2, 2, 3]`. /// @@ -571,6 +583,21 @@ pub type Ixs = isize; /// [5, 7]]); /// assert_eq!(s0, i); /// assert_eq!(s1, j); +/// +/// // Generic function which assigns the specified value to the elements which +/// // have indices in the lower half along all axes. +/// fn fill_lower(arr: &mut ArrayBase, x: S::Elem) +/// where +/// S: DataMut, +/// S::Elem: Clone, +/// D: Dimension, +/// { +/// arr.slice_each_axis_mut(|ax| Slice::from(0..ax.len() / 2)).fill(x); +/// } +/// fill_lower(&mut h, 9); +/// let k = arr2(&[[9, 9, 2, 3], +/// [4, 5, 6, 7]]); +/// assert_eq!(h, k); /// ``` /// /// ## Subviews