Skip to content

Commit

Permalink
Merge pull request #911 from jturner314/improve-map_inplace
Browse files Browse the repository at this point in the history
Improve `map_inplace`, and use it to replace `unordered_foreach_mut`
  • Loading branch information
jturner314 committed Feb 6, 2021
2 parents a66f364 + 0dce73a commit f7b9816
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 63 deletions.
30 changes: 30 additions & 0 deletions src/dimension/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -678,6 +678,36 @@ where
}
}

/// Move the axis which has the smallest absolute stride and a length
/// greater than one to be the last axis.
pub fn move_min_stride_axis_to_last<D>(dim: &mut D, strides: &mut D)
where
D: Dimension,
{
debug_assert_eq!(dim.ndim(), strides.ndim());
match dim.ndim() {
0 | 1 => {}
2 => {
if dim[1] <= 1
|| dim[0] > 1 && (strides[0] as isize).abs() < (strides[1] as isize).abs()
{
dim.slice_mut().swap(0, 1);
strides.slice_mut().swap(0, 1);
}
}
n => {
if let Some(min_stride_axis) = (0..n)
.filter(|&ax| dim[ax] > 1)
.min_by_key(|&ax| (strides[ax] as isize).abs())
{
let last = n - 1;
dim.slice_mut().swap(last, min_stride_axis);
strides.slice_mut().swap(last, min_stride_axis);
}
}
}
}

#[cfg(test)]
mod test {
use super::{
Expand Down
61 changes: 29 additions & 32 deletions src/impl_methods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ use crate::arraytraits;
use crate::dimension;
use crate::dimension::IntoDimension;
use crate::dimension::{
abs_index, axes_of, do_slice, merge_axes, offset_from_ptr_to_memory, size_of_shape_checked,
stride_offset, Axes,
abs_index, axes_of, do_slice, merge_axes, move_min_stride_axis_to_last,
offset_from_ptr_to_memory, size_of_shape_checked, stride_offset, Axes,
};
use crate::error::{self, ErrorKind, ShapeError};
use crate::math_cell::MathCell;
Expand Down Expand Up @@ -1448,20 +1448,29 @@ where
/// Return the array’s data as a slice if it is contiguous,
/// return `None` otherwise.
pub fn as_slice_memory_order_mut(&mut self) -> Option<&mut [A]>
where
S: DataMut,
{
self.try_as_slice_memory_order_mut().ok()
}

/// Return the array’s data as a slice if it is contiguous, otherwise
/// return `self` in the `Err` variant.
pub(crate) fn try_as_slice_memory_order_mut(&mut self) -> Result<&mut [A], &mut Self>
where
S: DataMut,
{
if self.is_contiguous() {
self.ensure_unique();
let offset = offset_from_ptr_to_memory(&self.dim, &self.strides);
unsafe {
Some(slice::from_raw_parts_mut(
Ok(slice::from_raw_parts_mut(
self.ptr.offset(offset).as_ptr(),
self.len(),
))
}
} else {
None
Err(self)
}
}

Expand Down Expand Up @@ -1943,7 +1952,7 @@ where
S: DataMut,
A: Clone,
{
self.unordered_foreach_mut(move |elt| *elt = x.clone());
self.map_inplace(move |elt| *elt = x.clone());
}

fn zip_mut_with_same_shape<B, S2, E, F>(&mut self, rhs: &ArrayBase<S2, E>, mut f: F)
Expand Down Expand Up @@ -1995,7 +2004,7 @@ where
S: DataMut,
F: FnMut(&mut A, &B),
{
self.unordered_foreach_mut(move |elt| f(elt, rhs_elem));
self.map_inplace(move |elt| f(elt, rhs_elem));
}

/// Traverse two arrays in unspecified order, in lock step,
Expand Down Expand Up @@ -2037,27 +2046,7 @@ where
slc.iter().fold(init, f)
} else {
let mut v = self.view();
// put the narrowest axis at the last position
match v.ndim() {
0 | 1 => {}
2 => {
if self.len_of(Axis(1)) <= 1
|| self.len_of(Axis(0)) > 1
&& self.stride_of(Axis(0)).abs() < self.stride_of(Axis(1)).abs()
{
v.swap_axes(0, 1);
}
}
n => {
let last = n - 1;
let narrow_axis = v
.axes()
.filter(|ax| ax.len() > 1)
.min_by_key(|ax| ax.stride().abs())
.map_or(last, |ax| ax.axis().index());
v.swap_axes(last, narrow_axis);
}
}
move_min_stride_axis_to_last(&mut v.dim, &mut v.strides);
v.into_elements_base().fold(init, f)
}
}
Expand Down Expand Up @@ -2167,12 +2156,20 @@ where
/// Modify the array in place by calling `f` by mutable reference on each element.
///
/// Elements are visited in arbitrary order.
pub fn map_inplace<F>(&mut self, f: F)
pub fn map_inplace<'a, F>(&'a mut self, f: F)
where
S: DataMut,
F: FnMut(&mut A),
{
self.unordered_foreach_mut(f);
A: 'a,
F: FnMut(&'a mut A),
{
match self.try_as_slice_memory_order_mut() {
Ok(slc) => slc.iter_mut().for_each(f),
Err(arr) => {
let mut v = arr.view_mut();
move_min_stride_axis_to_last(&mut v.dim, &mut v.strides);
v.into_elements_base().for_each(f);
}
}
}

/// Modify the array in place by calling `f` by **v**alue on each element.
Expand Down Expand Up @@ -2202,7 +2199,7 @@ where
F: FnMut(A) -> A,
A: Clone,
{
self.unordered_foreach_mut(move |x| *x = f(x.clone()));
self.map_inplace(move |x| *x = f(x.clone()));
}

/// Call `f` for each element in the array.
Expand Down
10 changes: 5 additions & 5 deletions src/impl_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ impl<A, S, D, B> $trt<B> for ArrayBase<S, D>
{
type Output = ArrayBase<S, D>;
fn $mth(mut self, x: B) -> ArrayBase<S, D> {
self.unordered_foreach_mut(move |elt| {
self.map_inplace(move |elt| {
*elt = elt.clone() $operator x.clone();
});
self
Expand Down Expand Up @@ -194,7 +194,7 @@ impl<S, D> $trt<ArrayBase<S, D>> for $scalar
rhs.$mth(self)
} or {{
let mut rhs = rhs;
rhs.unordered_foreach_mut(move |elt| {
rhs.map_inplace(move |elt| {
*elt = self $operator *elt;
});
rhs
Expand Down Expand Up @@ -299,7 +299,7 @@ mod arithmetic_ops {
type Output = Self;
/// Perform an elementwise negation of `self` and return the result.
fn neg(mut self) -> Self {
self.unordered_foreach_mut(|elt| {
self.map_inplace(|elt| {
*elt = -elt.clone();
});
self
Expand Down Expand Up @@ -329,7 +329,7 @@ mod arithmetic_ops {
type Output = Self;
/// Perform an elementwise unary not of `self` and return the result.
fn not(mut self) -> Self {
self.unordered_foreach_mut(|elt| {
self.map_inplace(|elt| {
*elt = !elt.clone();
});
self
Expand Down Expand Up @@ -386,7 +386,7 @@ mod assign_ops {
D: Dimension,
{
fn $method(&mut self, rhs: A) {
self.unordered_foreach_mut(move |elt| {
self.map_inplace(move |elt| {
elt.$method(rhs.clone());
});
}
Expand Down
27 changes: 1 addition & 26 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ pub use crate::indexes::{indices, indices_of};
pub use crate::slice::{Slice, SliceInfo, SliceNextDim, SliceOrIndex};

use crate::iterators::Baseiter;
use crate::iterators::{ElementsBase, ElementsBaseMut, Iter, IterMut, Lanes, LanesMut};
use crate::iterators::{ElementsBase, ElementsBaseMut, Iter, IterMut, Lanes};

pub use crate::arraytraits::AsArray;
#[cfg(feature = "std")]
Expand Down Expand Up @@ -1545,22 +1545,6 @@ where
self.strides.clone()
}

/// Apply closure `f` to each element in the array, in whatever
/// order is the fastest to visit.
fn unordered_foreach_mut<F>(&mut self, mut f: F)
where
S: DataMut,
F: FnMut(&mut A),
{
if let Some(slc) = self.as_slice_memory_order_mut() {
slc.iter_mut().for_each(f);
} else {
for row in self.inner_rows_mut() {
row.into_iter_().fold((), |(), elt| f(elt));
}
}
}

/// Remove array axis `axis` and return the result.
fn try_remove_axis(self, axis: Axis) -> ArrayBase<S, D::Smaller> {
let d = self.dim.try_remove_axis(axis);
Expand All @@ -1576,15 +1560,6 @@ where
let n = self.ndim();
Lanes::new(self.view(), Axis(n.saturating_sub(1)))
}

/// n-d generalization of rows, just like inner iter
fn inner_rows_mut(&mut self) -> iterators::LanesMut<'_, A, D::Smaller>
where
S: DataMut,
{
let n = self.ndim();
LanesMut::new(self.view_mut(), Axis(n.saturating_sub(1)))
}
}

// parallel methods
Expand Down

0 comments on commit f7b9816

Please sign in to comment.