Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for inserting new axes while slicing #570

Merged
merged 28 commits into from Mar 13, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
24a3299
Rename SliceOrIndex to AxisSliceInfo
jturner314 Dec 9, 2018
6a16b88
Switch from Dimension::SliceArg to CanSlice trait
jturner314 Dec 9, 2018
546b69c
Add support for inserting new axes while slicing
jturner314 Dec 9, 2018
6e335ca
Rename SliceInfo generic params to Din and Dout
jturner314 Dec 17, 2018
d6b9cb0
Improve code style
jturner314 Dec 17, 2018
438d69a
Derive Clone, Copy, and Debug for NewAxis
jturner314 Dec 17, 2018
6050df3
Use stringify! for string literal of type name
jturner314 Dec 18, 2018
8d45268
Make step_by panic for variants other than Slice
jturner314 Dec 18, 2018
1d15275
Add DimAdd trait
jturner314 Dec 18, 2018
41cc4a1
Replace SliceNextIn/OutDim with SliceArg trait
jturner314 Dec 18, 2018
c66ad8c
Combine DimAdd impls for Ix0
jturner314 Feb 7, 2021
7776bfc
Implement CanSlice<IxDyn> for [AxisSliceInfo]
jturner314 Feb 14, 2021
ab79d28
Change SliceInfo to be repr(transparent)
jturner314 Feb 15, 2021
615113e
Add debug assertions to SliceInfo::new_unchecked
jturner314 Feb 15, 2021
e66e3c8
Fix safety of SliceInfo::new
jturner314 Feb 15, 2021
3ba6ceb
Add some impls of TryFrom for SliceInfo
jturner314 Feb 15, 2021
815e708
Make slice_move not call slice_collapse
jturner314 Feb 16, 2021
25a7bb0
Make slice_collapse return Err(_) for NewAxis
jturner314 Feb 16, 2021
5202a50
Expose CanSlice trait in public API
jturner314 Feb 16, 2021
319701d
Expose MultiSlice trait in public API
jturner314 Feb 16, 2021
d5d6482
Add DimAdd bounds to Dimension trait
jturner314 Feb 16, 2021
9614b13
Revert "Make slice_collapse return Err(_) for NewAxis"
jturner314 Feb 17, 2021
61cf7c0
Make slice_collapse panic on NewAxis
jturner314 Feb 17, 2021
91dbf3f
Rename DimAdd::Out to DimAdd::Output
jturner314 Feb 17, 2021
5dc77bd
Rename SliceArg to SliceNextDim
jturner314 Feb 17, 2021
87515c6
Rename CanSlice to SliceArg
jturner314 Feb 17, 2021
c4efbbf
Rename MultiSlice to MultiSliceArg
jturner314 Feb 17, 2021
7506f90
Clarify docs of .slice_collapse()
jturner314 Feb 17, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
10 changes: 5 additions & 5 deletions blas-tests/tests/oper.rs
Expand Up @@ -6,8 +6,8 @@ extern crate num_traits;
use ndarray::linalg::general_mat_mul;
use ndarray::linalg::general_mat_vec_mul;
use ndarray::prelude::*;
use ndarray::{AxisSliceInfo, Ix, Ixs};
use ndarray::{Data, LinalgScalar};
use ndarray::{Ix, Ixs, SliceInfo, SliceOrIndex};

use approx::{assert_abs_diff_eq, assert_relative_eq};
use defmac::defmac;
Expand Down Expand Up @@ -420,19 +420,19 @@ fn scaled_add_3() {
let mut answer = a.clone();
let cdim = if n == 1 { vec![q] } else { vec![n, q] };
let cslice = if n == 1 {
vec![SliceOrIndex::from(..).step_by(s2)]
vec![AxisSliceInfo::from(..).step_by(s2)]
} else {
vec![
SliceOrIndex::from(..).step_by(s1),
SliceOrIndex::from(..).step_by(s2),
AxisSliceInfo::from(..).step_by(s1),
AxisSliceInfo::from(..).step_by(s2),
]
};

let c = range_mat64(n, q).into_shape(cdim).unwrap();

{
let mut av = a.slice_mut(s![..;s1, ..;s2]);
let c = c.slice(SliceInfo::<_, IxDyn>::new(cslice).unwrap().as_ref());
let c = c.slice(&*cslice);

let mut answerv = answer.slice_mut(s![..;s1, ..;s2]);
answerv += &(beta * &c);
Expand Down
30 changes: 8 additions & 22 deletions src/dimension/dimension_trait.rs
Expand Up @@ -13,13 +13,14 @@ use alloc::vec::Vec;

use super::axes_of;
use super::conversion::Convert;
use super::ops::DimAdd;
use super::{stride_offset, stride_offset_checked};
use crate::itertools::{enumerate, zip};
use crate::{Axis, DimMax};
use crate::IntoDimension;
use crate::RemoveAxis;
use crate::{ArrayView1, ArrayViewMut1};
use crate::{Dim, Ix, Ix0, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn, IxDynImpl, Ixs, SliceOrIndex};
use crate::{Dim, Ix, Ix0, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn, IxDynImpl, Ixs};

/// Array shape and index trait.
///
Expand Down Expand Up @@ -51,26 +52,17 @@ pub trait Dimension:
+ DimMax<IxDyn, Output=IxDyn>
+ DimMax<<Self as Dimension>::Smaller, Output=Self>
+ DimMax<<Self as Dimension>::Larger, Output=<Self as Dimension>::Larger>
+ DimAdd<Self>
+ DimAdd<<Self as Dimension>::Smaller>
+ DimAdd<<Self as Dimension>::Larger>
+ DimAdd<Ix0, Output = Self>
+ DimAdd<Ix1, Output = <Self as Dimension>::Larger>
+ DimAdd<IxDyn, Output = IxDyn>
{
/// For fixed-size dimension representations (e.g. `Ix2`), this should be
/// `Some(ndim)`, and for variable-size dimension representations (e.g.
/// `IxDyn`), this should be `None`.
const NDIM: Option<usize>;
/// `SliceArg` is the type which is used to specify slicing for this
/// dimension.
///
/// For the fixed size dimensions it is a fixed size array of the correct
/// size, which you pass by reference. For the dynamic dimension it is
/// a slice.
///
/// - For `Ix1`: `[SliceOrIndex; 1]`
/// - For `Ix2`: `[SliceOrIndex; 2]`
/// - and so on..
/// - For `IxDyn`: `[SliceOrIndex]`
///
/// The easiest way to create a `&SliceInfo<SliceArg, Do>` is using the
/// [`s![]`](macro.s!.html) macro.
type SliceArg: ?Sized + AsRef<[SliceOrIndex]>;
/// Pattern matching friendly form of the dimension value.
///
/// - For `Ix1`: `usize`,
Expand Down Expand Up @@ -399,7 +391,6 @@ macro_rules! impl_insert_axis_array(

impl Dimension for Dim<[Ix; 0]> {
const NDIM: Option<usize> = Some(0);
type SliceArg = [SliceOrIndex; 0];
type Pattern = ();
type Smaller = Self;
type Larger = Ix1;
Expand Down Expand Up @@ -443,7 +434,6 @@ impl Dimension for Dim<[Ix; 0]> {

impl Dimension for Dim<[Ix; 1]> {
const NDIM: Option<usize> = Some(1);
type SliceArg = [SliceOrIndex; 1];
type Pattern = Ix;
type Smaller = Ix0;
type Larger = Ix2;
Expand Down Expand Up @@ -559,7 +549,6 @@ impl Dimension for Dim<[Ix; 1]> {

impl Dimension for Dim<[Ix; 2]> {
const NDIM: Option<usize> = Some(2);
type SliceArg = [SliceOrIndex; 2];
type Pattern = (Ix, Ix);
type Smaller = Ix1;
type Larger = Ix3;
Expand Down Expand Up @@ -716,7 +705,6 @@ impl Dimension for Dim<[Ix; 2]> {

impl Dimension for Dim<[Ix; 3]> {
const NDIM: Option<usize> = Some(3);
type SliceArg = [SliceOrIndex; 3];
type Pattern = (Ix, Ix, Ix);
type Smaller = Ix2;
type Larger = Ix4;
Expand Down Expand Up @@ -839,7 +827,6 @@ macro_rules! large_dim {
($n:expr, $name:ident, $pattern:ty, $larger:ty, { $($insert_axis:tt)* }) => (
impl Dimension for Dim<[Ix; $n]> {
const NDIM: Option<usize> = Some($n);
type SliceArg = [SliceOrIndex; $n];
type Pattern = $pattern;
type Smaller = Dim<[Ix; $n - 1]>;
type Larger = $larger;
Expand Down Expand Up @@ -890,7 +877,6 @@ large_dim!(6, Ix6, (Ix, Ix, Ix, Ix, Ix, Ix), IxDyn, {
/// and memory wasteful, but it allows an arbitrary and dynamic number of axes.
impl Dimension for IxDyn {
const NDIM: Option<usize> = None;
type SliceArg = [SliceOrIndex];
type Pattern = Self;
type Smaller = Self;
type Larger = Self;
Expand Down
74 changes: 55 additions & 19 deletions src/dimension/mod.rs
Expand Up @@ -7,7 +7,8 @@
// except according to those terms.

use crate::error::{from_kind, ErrorKind, ShapeError};
use crate::{Ix, Ixs, Slice, SliceOrIndex};
use crate::slice::SliceArg;
use crate::{AxisSliceInfo, Ix, Ixs, Slice};
use num_integer::div_floor;

pub use self::axes::{axes_of, Axes, AxisDescription};
Expand All @@ -18,6 +19,7 @@ pub use self::dim::*;
pub use self::dimension_trait::Dimension;
pub use self::dynindeximpl::IxDynImpl;
pub use self::ndindex::NdIndex;
pub use self::ops::DimAdd;
pub use self::remove_axis::RemoveAxis;

use crate::shape_builder::Strides;
Expand All @@ -35,6 +37,7 @@ pub mod dim;
mod dimension_trait;
mod dynindeximpl;
mod ndindex;
mod ops;
mod remove_axis;

/// Calculate offset from `Ix` stride converting sign properly
Expand Down Expand Up @@ -596,20 +599,24 @@ fn slice_min_max(axis_len: usize, slice: Slice) -> Option<(usize, usize)> {
/// Returns `true` iff the slices intersect.
pub fn slices_intersect<D: Dimension>(
dim: &D,
indices1: &D::SliceArg,
indices2: &D::SliceArg,
indices1: &impl SliceArg<D>,
indices2: &impl SliceArg<D>,
) -> bool {
debug_assert_eq!(indices1.as_ref().len(), indices2.as_ref().len());
for (&axis_len, &si1, &si2) in izip!(dim.slice(), indices1.as_ref(), indices2.as_ref()) {
// The slices do not intersect iff any pair of `SliceOrIndex` does not intersect.
debug_assert_eq!(indices1.in_ndim(), indices2.in_ndim());
for (&axis_len, &si1, &si2) in izip!(
dim.slice(),
indices1.as_ref().iter().filter(|si| !si.is_new_axis()),
indices2.as_ref().iter().filter(|si| !si.is_new_axis()),
) {
// The slices do not intersect iff any pair of `AxisSliceInfo` does not intersect.
match (si1, si2) {
(
SliceOrIndex::Slice {
AxisSliceInfo::Slice {
start: start1,
end: end1,
step: step1,
},
SliceOrIndex::Slice {
AxisSliceInfo::Slice {
start: start2,
end: end2,
step: step2,
Expand All @@ -630,8 +637,8 @@ pub fn slices_intersect<D: Dimension>(
return false;
}
}
(SliceOrIndex::Slice { start, end, step }, SliceOrIndex::Index(ind))
| (SliceOrIndex::Index(ind), SliceOrIndex::Slice { start, end, step }) => {
(AxisSliceInfo::Slice { start, end, step }, AxisSliceInfo::Index(ind))
| (AxisSliceInfo::Index(ind), AxisSliceInfo::Slice { start, end, step }) => {
let ind = abs_index(axis_len, ind);
let (min, max) = match slice_min_max(axis_len, Slice::new(start, end, step)) {
Some(m) => m,
Expand All @@ -641,13 +648,14 @@ pub fn slices_intersect<D: Dimension>(
return false;
}
}
(SliceOrIndex::Index(ind1), SliceOrIndex::Index(ind2)) => {
(AxisSliceInfo::Index(ind1), AxisSliceInfo::Index(ind2)) => {
let ind1 = abs_index(axis_len, ind1);
let ind2 = abs_index(axis_len, ind2);
if ind1 != ind2 {
return false;
}
}
(AxisSliceInfo::NewAxis, _) | (_, AxisSliceInfo::NewAxis) => unreachable!(),
}
}
true
Expand Down Expand Up @@ -719,7 +727,7 @@ mod test {
};
use crate::error::{from_kind, ErrorKind};
use crate::slice::Slice;
use crate::{Dim, Dimension, Ix0, Ix1, Ix2, Ix3, IxDyn};
use crate::{Dim, Dimension, Ix0, Ix1, Ix2, Ix3, IxDyn, NewAxis};
use num_integer::gcd;
use quickcheck::{quickcheck, TestResult};

Expand Down Expand Up @@ -993,17 +1001,45 @@ mod test {

#[test]
fn slices_intersect_true() {
assert!(slices_intersect(&Dim([4, 5]), s![.., ..], s![.., ..]));
assert!(slices_intersect(&Dim([4, 5]), s![0, ..], s![0, ..]));
assert!(slices_intersect(&Dim([4, 5]), s![..;2, ..], s![..;3, ..]));
assert!(slices_intersect(&Dim([4, 5]), s![.., ..;2], s![.., 1..;3]));
assert!(slices_intersect(
&Dim([4, 5]),
s![NewAxis, .., NewAxis, ..],
s![.., NewAxis, .., NewAxis]
));
assert!(slices_intersect(
&Dim([4, 5]),
s![NewAxis, 0, ..],
s![0, ..]
));
assert!(slices_intersect(
&Dim([4, 5]),
s![..;2, ..],
s![..;3, NewAxis, ..]
));
assert!(slices_intersect(
&Dim([4, 5]),
s![.., ..;2],
s![.., 1..;3, NewAxis]
));
assert!(slices_intersect(&Dim([4, 10]), s![.., ..;9], s![.., 3..;6]));
}

#[test]
fn slices_intersect_false() {
assert!(!slices_intersect(&Dim([4, 5]), s![..;2, ..], s![1..;2, ..]));
assert!(!slices_intersect(&Dim([4, 5]), s![..;2, ..], s![1..;3, ..]));
assert!(!slices_intersect(&Dim([4, 5]), s![.., ..;9], s![.., 3..;6]));
assert!(!slices_intersect(
&Dim([4, 5]),
s![..;2, ..],
s![NewAxis, 1..;2, ..]
));
assert!(!slices_intersect(
&Dim([4, 5]),
s![..;2, NewAxis, ..],
s![1..;3, ..]
));
assert!(!slices_intersect(
&Dim([4, 5]),
s![.., ..;9],
s![.., 3..;6, NewAxis]
));
}
}
90 changes: 90 additions & 0 deletions src/dimension/ops.rs
@@ -0,0 +1,90 @@
use crate::imp_prelude::*;

/// Adds the two dimensions at compile time.
pub trait DimAdd<D: Dimension> {
/// The sum of the two dimensions.
type Output: Dimension;
}

macro_rules! impl_dimadd_const_out_const {
($lhs:expr, $rhs:expr) => {
impl DimAdd<Dim<[usize; $rhs]>> for Dim<[usize; $lhs]> {
type Output = Dim<[usize; $lhs + $rhs]>;
}
};
}

macro_rules! impl_dimadd_const_out_dyn {
($lhs:expr, IxDyn) => {
impl DimAdd<IxDyn> for Dim<[usize; $lhs]> {
type Output = IxDyn;
}
};
($lhs:expr, $rhs:expr) => {
impl DimAdd<Dim<[usize; $rhs]>> for Dim<[usize; $lhs]> {
type Output = IxDyn;
}
};
}

impl<D: Dimension> DimAdd<D> for Ix0 {
type Output = D;
}

impl_dimadd_const_out_const!(1, 0);
impl_dimadd_const_out_const!(1, 1);
impl_dimadd_const_out_const!(1, 2);
impl_dimadd_const_out_const!(1, 3);
impl_dimadd_const_out_const!(1, 4);
impl_dimadd_const_out_const!(1, 5);
impl_dimadd_const_out_dyn!(1, 6);
impl_dimadd_const_out_dyn!(1, IxDyn);

impl_dimadd_const_out_const!(2, 0);
impl_dimadd_const_out_const!(2, 1);
impl_dimadd_const_out_const!(2, 2);
impl_dimadd_const_out_const!(2, 3);
impl_dimadd_const_out_const!(2, 4);
impl_dimadd_const_out_dyn!(2, 5);
impl_dimadd_const_out_dyn!(2, 6);
impl_dimadd_const_out_dyn!(2, IxDyn);

impl_dimadd_const_out_const!(3, 0);
impl_dimadd_const_out_const!(3, 1);
impl_dimadd_const_out_const!(3, 2);
impl_dimadd_const_out_const!(3, 3);
impl_dimadd_const_out_dyn!(3, 4);
impl_dimadd_const_out_dyn!(3, 5);
impl_dimadd_const_out_dyn!(3, 6);
impl_dimadd_const_out_dyn!(3, IxDyn);

impl_dimadd_const_out_const!(4, 0);
impl_dimadd_const_out_const!(4, 1);
impl_dimadd_const_out_const!(4, 2);
impl_dimadd_const_out_dyn!(4, 3);
impl_dimadd_const_out_dyn!(4, 4);
impl_dimadd_const_out_dyn!(4, 5);
impl_dimadd_const_out_dyn!(4, 6);
impl_dimadd_const_out_dyn!(4, IxDyn);

impl_dimadd_const_out_const!(5, 0);
impl_dimadd_const_out_const!(5, 1);
impl_dimadd_const_out_dyn!(5, 2);
impl_dimadd_const_out_dyn!(5, 3);
impl_dimadd_const_out_dyn!(5, 4);
impl_dimadd_const_out_dyn!(5, 5);
impl_dimadd_const_out_dyn!(5, 6);
impl_dimadd_const_out_dyn!(5, IxDyn);

impl_dimadd_const_out_const!(6, 0);
impl_dimadd_const_out_dyn!(6, 1);
impl_dimadd_const_out_dyn!(6, 2);
impl_dimadd_const_out_dyn!(6, 3);
impl_dimadd_const_out_dyn!(6, 4);
impl_dimadd_const_out_dyn!(6, 5);
impl_dimadd_const_out_dyn!(6, 6);
impl_dimadd_const_out_dyn!(6, IxDyn);

impl<D: Dimension> DimAdd<D> for IxDyn {
type Output = IxDyn;
}
2 changes: 1 addition & 1 deletion src/doc/ndarray_for_numpy_users/mod.rs
Expand Up @@ -532,7 +532,7 @@
//! `a[:] = b` | [`a.assign(&b)`][.assign()] | copy the data from array `b` into array `a`
//! `np.concatenate((a,b), axis=1)` | [`concatenate![Axis(1), a, b]`][concatenate!] or [`concatenate(Axis(1), &[a.view(), b.view()])`][concatenate()] | concatenate arrays `a` and `b` along axis 1
//! `np.stack((a,b), axis=1)` | [`stack![Axis(1), a, b]`][stack!] or [`stack(Axis(1), vec![a.view(), b.view()])`][stack()] | stack arrays `a` and `b` along axis 1
//! `a[:,np.newaxis]` or `np.expand_dims(a, axis=1)` | [`a.insert_axis(Axis(1))`][.insert_axis()] | create an array from `a`, inserting a new axis 1
//! `a[:,np.newaxis]` or `np.expand_dims(a, axis=1)` | [`a.slice(s![.., NewAxis])`][.slice()] or [`a.insert_axis(Axis(1))`][.insert_axis()] | create an view of 1-D array `a`, inserting a new axis 1
//! `a.transpose()` or `a.T` | [`a.t()`][.t()] or [`a.reversed_axes()`][.reversed_axes()] | transpose of array `a` (view for `.t()` or by-move for `.reversed_axes()`)
//! `np.diag(a)` | [`a.diag()`][.diag()] | view the diagonal of `a`
//! `a.flatten()` | [`use std::iter::FromIterator; Array::from_iter(a.iter().cloned())`][::from_iter()] | create a 1-D array by flattening `a`
Expand Down