Skip to content

Commit

Permalink
Merge pull request #913 from jturner314/slice_each_axis_inplace
Browse files Browse the repository at this point in the history
Add slice_each_axis/_mut/_inplace methods
  • Loading branch information
jturner314 committed Feb 7, 2021
2 parents f7b9816 + 566b531 commit 8dd1509
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 2 deletions.
58 changes: 58 additions & 0 deletions src/impl_methods.rs
Expand Up @@ -25,6 +25,7 @@ use crate::error::{self, ErrorKind, ShapeError};
use crate::math_cell::MathCell;
use crate::itertools::zip;
use crate::zip::Zip;
use crate::AxisDescription;

use crate::iter::{
AxisChunksIter, AxisChunksIterMut, AxisIter, AxisIterMut, ExactChunks, ExactChunksMut,
Expand Down Expand Up @@ -511,6 +512,63 @@ where
debug_assert!(self.pointer_is_inbounds());
}

/// Return a view of a slice of the array, with a closure specifying the
/// slice for each axis.
///
/// This is especially useful for code which is generic over the
/// dimensionality of the array.
///
/// **Panics** if an index is out of bounds or step size is zero.
pub fn slice_each_axis<F>(&self, f: F) -> ArrayView<'_, A, D>
where
F: FnMut(AxisDescription) -> Slice,
S: Data,
{
let mut view = self.view();
view.slice_each_axis_inplace(f);
view
}

/// Return a mutable view of a slice of the array, with a closure
/// specifying the slice for each axis.
///
/// This is especially useful for code which is generic over the
/// dimensionality of the array.
///
/// **Panics** if an index is out of bounds or step size is zero.
pub fn slice_each_axis_mut<F>(&mut self, f: F) -> ArrayViewMut<'_, A, D>
where
F: FnMut(AxisDescription) -> Slice,
S: DataMut,
{
let mut view = self.view_mut();
view.slice_each_axis_inplace(f);
view
}

/// Slice the array in place, with a closure specifying the slice for each
/// axis.
///
/// This is especially useful for code which is generic over the
/// dimensionality of the array.
///
/// **Panics** if an index is out of bounds or step size is zero.
pub fn slice_each_axis_inplace<F>(&mut self, mut f: F)
where
F: FnMut(AxisDescription) -> Slice,
{
(0..self.ndim()).for_each(|ax| {
self.slice_axis_inplace(
Axis(ax),
f(AxisDescription(
Axis(ax),
self.dim[ax],
self.strides[ax] as isize,
)),
)
})
}

/// Return a reference to the element at `index`, or return `None`
/// if the index is out of bounds.
///
Expand Down
31 changes: 29 additions & 2 deletions src/lib.rs
Expand Up @@ -503,6 +503,19 @@ pub type Ixs = isize;
/// [`.slice_move()`]: #method.slice_move
/// [`.slice_collapse()`]: #method.slice_collapse
///
/// When slicing arrays with generic dimensionality, creating an instance of
/// [`&SliceInfo`] to pass to the multi-axis slicing methods like [`.slice()`]
/// is awkward. In these cases, it's usually more convenient to use
/// [`.slice_each_axis()`]/[`.slice_each_axis_mut()`]/[`.slice_each_axis_inplace()`]
/// or to create a view and then slice individual axes of the view using
/// methods such as [`.slice_axis_inplace()`] and [`.collapse_axis()`].
///
/// [`.slice_each_axis()`]: #method.slice_each_axis
/// [`.slice_each_axis_mut()`]: #method.slice_each_axis_mut
/// [`.slice_each_axis_inplace()`]: #method.slice_each_axis_inplace
/// [`.slice_axis_inplace()`]: #method.slice_axis_inplace
/// [`.collapse_axis()`]: #method.collapse_axis
///
/// It's possible to take multiple simultaneous *mutable* slices with
/// [`.multi_slice_mut()`] or (for [`ArrayViewMut`] only)
/// [`.multi_slice_move()`].
Expand All @@ -511,8 +524,7 @@ pub type Ixs = isize;
/// [`.multi_slice_move()`]: type.ArrayViewMut.html#method.multi_slice_move
///
/// ```
///
/// use ndarray::{arr2, arr3, s};
/// use ndarray::{arr2, arr3, s, ArrayBase, DataMut, Dimension, Slice};
///
/// // 2 submatrices of 2 rows with 3 elements per row, means a shape of `[2, 2, 3]`.
///
Expand Down Expand Up @@ -571,6 +583,21 @@ pub type Ixs = isize;
/// [5, 7]]);
/// assert_eq!(s0, i);
/// assert_eq!(s1, j);
///
/// // Generic function which assigns the specified value to the elements which
/// // have indices in the lower half along all axes.
/// fn fill_lower<S, D>(arr: &mut ArrayBase<S, D>, x: S::Elem)
/// where
/// S: DataMut,
/// S::Elem: Clone,
/// D: Dimension,
/// {
/// arr.slice_each_axis_mut(|ax| Slice::from(0..ax.len() / 2)).fill(x);
/// }
/// fill_lower(&mut h, 9);
/// let k = arr2(&[[9, 9, 2, 3],
/// [4, 5, 6, 7]]);
/// assert_eq!(h, k);
/// ```
///
/// ## Subviews
Expand Down

0 comments on commit 8dd1509

Please sign in to comment.