Skip to content

Commit

Permalink
Support zero-length axis in .map_axis/_mut() (#612)
Browse files Browse the repository at this point in the history
  • Loading branch information
Andrew authored and jturner314 committed Apr 9, 2019
1 parent 47b2691 commit 2924f2e
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 15 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ serde = { version = "1.0", optional = true }
defmac = "0.2"
quickcheck = { version = "0.7.2", default-features = false }
rawpointer = "0.1"
itertools = { version = "0.7.0", default-features = false, features = ["use_std"] }
approx = "0.3"

[features]
Expand Down
38 changes: 24 additions & 14 deletions src/impl_methods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2102,13 +2102,18 @@ where
{
let view_len = self.len_of(axis);
let view_stride = self.strides.axis(axis);
// use the 0th subview as a map to each 1d array view extended from
// the 0th element.
self.index_axis(axis, 0).map(|first_elt| {
unsafe {
mapping(ArrayView::new_(first_elt, Ix1(view_len), Ix1(view_stride)))
}
})
if view_len == 0 {
let new_dim = self.dim.remove_axis(axis);
Array::from_shape_fn(new_dim, move |_| mapping(ArrayView::from(&[])))
} else {
// use the 0th subview as a map to each 1d array view extended from
// the 0th element.
self.index_axis(axis, 0).map(|first_elt| {
unsafe {
mapping(ArrayView::new_(first_elt, Ix1(view_len), Ix1(view_stride)))
}
})
}
}

/// Reduce the values along an axis into just one value, producing a new
Expand All @@ -2130,12 +2135,17 @@ where
{
let view_len = self.len_of(axis);
let view_stride = self.strides.axis(axis);
// use the 0th subview as a map to each 1d array view extended from
// the 0th element.
self.index_axis_mut(axis, 0).map_mut(|first_elt: &mut A| {
unsafe {
mapping(ArrayViewMut::new_(first_elt, Ix1(view_len), Ix1(view_stride)))
}
})
if view_len == 0 {
let new_dim = self.dim.remove_axis(axis);
Array::from_shape_fn(new_dim, move |_| mapping(ArrayViewMut::from(&mut [])))
} else {
// use the 0th subview as a map to each 1d array view extended from
// the 0th element.
self.index_axis_mut(axis, 0).map_mut(|first_elt| {
unsafe {
mapping(ArrayViewMut::new_(first_elt, Ix1(view_len), Ix1(view_stride)))
}
})
}
}
}
23 changes: 22 additions & 1 deletion tests/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use ndarray::{
};
use ndarray::indices;
use defmac::defmac;
use itertools::{enumerate, zip};
use itertools::{enumerate, zip, Itertools};

macro_rules! assert_panics {
($body:expr) => {
Expand Down Expand Up @@ -1833,6 +1833,27 @@ fn test_map_axis() {
let c = a.map_axis(Axis(1), |view| view.sum());
let answer2 = arr1(&[6, 15, 24, 33]);
assert_eq!(c, answer2);

// Test zero-length axis case
let arr = Array3::<f32>::zeros((3, 0, 4));
let mut counter = 0;
let result = arr.map_axis(Axis(1), |x| {
assert_eq!(x.shape(), &[0]);
counter += 1;
counter
});
assert_eq!(result.shape(), &[3, 4]);
itertools::assert_equal(result.iter().cloned().sorted(), 1..=3 * 4);

let mut arr = Array3::<f32>::zeros((3, 0, 4));
let mut counter = 0;
let result = arr.map_axis_mut(Axis(1), |x| {
assert_eq!(x.shape(), &[0]);
counter += 1;
counter
});
assert_eq!(result.shape(), &[3, 4]);
itertools::assert_equal(result.iter().cloned().sorted(), 1..=3 * 4);
}

#[test]
Expand Down

0 comments on commit 2924f2e

Please sign in to comment.