Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Stack and concatenate #844

Merged
merged 7 commits into from Oct 23, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/doc/ndarray_for_numpy_users/mod.rs
Expand Up @@ -533,6 +533,7 @@
//! `a[:] = 3.` | [`a.fill(3.)`][.fill()] | set all array elements to the same scalar value
//! `a[:] = b` | [`a.assign(&b)`][.assign()] | copy the data from array `b` into array `a`
//! `np.concatenate((a,b), axis=1)` | [`stack![Axis(1), a, b]`][stack!] or [`stack(Axis(1), &[a.view(), b.view()])`][stack()] | concatenate arrays `a` and `b` along axis 1
//! `np.stack((a,b), axis=1)` | [`stack_new_axis![Axis(1), a, b]`][stack_new_axis!] or [`stack_new_axis(Axis(1), vec![a.view(), b.view()])`][stack_new_axis()] | stack arrays `a` and `b` along axis 1
//! `a[:,np.newaxis]` or `np.expand_dims(a, axis=1)` | [`a.insert_axis(Axis(1))`][.insert_axis()] | create an array from `a`, inserting a new axis 1
//! `a.transpose()` or `a.T` | [`a.t()`][.t()] or [`a.reversed_axes()`][.reversed_axes()] | transpose of array `a` (view for `.t()` or by-move for `.reversed_axes()`)
//! `np.diag(a)` | [`a.diag()`][.diag()] | view the diagonal of `a`
Expand Down Expand Up @@ -642,6 +643,8 @@
//! [.shape()]: ../../struct.ArrayBase.html#method.shape
//! [stack!]: ../../macro.stack.html
//! [stack()]: ../../fn.stack.html
//! [stack_new_axis!]: ../../macro.stack_new_axis.html
//! [stack_new_axis()]: ../../fn.stack_new_axis.html
//! [.strides()]: ../../struct.ArrayBase.html#method.strides
//! [.index_axis()]: ../../struct.ArrayBase.html#method.index_axis
//! [.sum_axis()]: ../../struct.ArrayBase.html#method.sum_axis
Expand Down
4 changes: 2 additions & 2 deletions src/impl_methods.rs
Expand Up @@ -28,7 +28,7 @@ use crate::iter::{
IndexedIter, IndexedIterMut, Iter, IterMut, Lanes, LanesMut, Windows,
};
use crate::slice::MultiSlice;
use crate::stacking::stack;
use crate::stacking::concatenate;
use crate::{NdIndex, Slice, SliceInfo, SliceOrIndex};

/// # Methods For All Array Types
Expand Down Expand Up @@ -840,7 +840,7 @@ where
dim.set_axis(axis, 0);
unsafe { Array::from_shape_vec_unchecked(dim, vec![]) }
} else {
stack(axis, &subs).unwrap()
concatenate(axis, &subs).unwrap()
}
}

Expand Down
4 changes: 3 additions & 1 deletion src/lib.rs
Expand Up @@ -131,7 +131,9 @@ use crate::iterators::{ElementsBase, ElementsBaseMut, Iter, IterMut, Lanes, Lane

pub use crate::arraytraits::AsArray;
pub use crate::linalg_traits::{LinalgScalar, NdFloat};
pub use crate::stacking::stack;

#[allow(deprecated)]
pub use crate::stacking::{concatenate, stack, stack_new_axis};

pub use crate::impl_views::IndexLonger;
pub use crate::shape_builder::ShapeBuilder;
Expand Down
181 changes: 175 additions & 6 deletions src/stacking.rs
Expand Up @@ -9,7 +9,7 @@
use crate::error::{from_kind, ErrorKind, ShapeError};
use crate::imp_prelude::*;

/// Stack arrays along the given axis.
/// Concatenate arrays along the given axis.
///
/// ***Errors*** if the arrays have mismatching shapes, apart from along `axis`.
/// (may be made more flexible in the future).<br>
Expand All @@ -29,10 +29,11 @@ use crate::imp_prelude::*;
/// [3., 3.]]))
/// );
/// ```
pub fn stack<'a, A, D>(
axis: Axis,
arrays: &[ArrayView<'a, A, D>],
) -> Result<Array<A, D>, ShapeError>
#[deprecated(
since = "0.13.2",
note = "Please use the `concatenate` function instead"
)]
pub fn stack<A, D>(axis: Axis, arrays: &[ArrayView<A, D>]) -> Result<Array<A, D>, ShapeError>
where
A: Copy,
D: RemoveAxis,
Expand Down Expand Up @@ -76,7 +77,103 @@ where
Ok(res)
}

/// Stack arrays along the given axis.
/// Concatenate arrays along the given axis.
///
/// ***Errors*** if the arrays have mismatching shapes, apart from along `axis`.
/// (may be made more flexible in the future).<br>
/// ***Errors*** if `arrays` is empty, if `axis` is out of bounds,
/// if the result is larger than is possible to represent.
///
/// ```
/// use ndarray::{arr2, Axis, concatenate};
///
/// let a = arr2(&[[2., 2.],
/// [3., 3.]]);
/// assert!(
/// concatenate(Axis(0), &[a.view(), a.view()])
/// == Ok(arr2(&[[2., 2.],
/// [3., 3.],
/// [2., 2.],
/// [3., 3.]]))
/// );
/// ```
#[allow(deprecated)]
pub fn concatenate<A, D>(axis: Axis, arrays: &[ArrayView<A, D>]) -> Result<Array<A, D>, ShapeError>
where
A: Copy,
D: RemoveAxis,
{
stack(axis, arrays)
}

/// Stack arrays along the new axis.
///
/// ***Errors*** if the arrays have mismatching shapes.
/// ***Errors*** if `arrays` is empty, if `axis` is out of bounds,
/// if the result is larger than is possible to represent.
///
/// ```
/// extern crate ndarray;
///
/// use ndarray::{arr2, arr3, stack_new_axis, Axis};
///
/// # fn main() {
///
/// let a = arr2(&[[2., 2.],
/// [3., 3.]]);
/// assert!(
/// stack_new_axis(Axis(0), &[a.view(), a.view()])
/// == Ok(arr3(&[[[2., 2.],
/// [3., 3.]],
/// [[2., 2.],
/// [3., 3.]]]))
/// );
/// # }
/// ```
pub fn stack_new_axis<A, D>(
axis: Axis,
arrays: &[ArrayView<A, D>],
) -> Result<Array<A, D::Larger>, ShapeError>
where
A: Copy,
D: Dimension,
D::Larger: RemoveAxis,
{
if arrays.is_empty() {
return Err(from_kind(ErrorKind::Unsupported));
}
let common_dim = arrays[0].raw_dim();
// Avoid panic on `insert_axis` call, return an Err instead of it.
if axis.index() > common_dim.ndim() {
return Err(from_kind(ErrorKind::OutOfBounds));
}
let mut res_dim = common_dim.insert_axis(axis);

if arrays.iter().any(|a| a.raw_dim() != common_dim) {
return Err(from_kind(ErrorKind::IncompatibleShape));
}

res_dim.set_axis(axis, arrays.len());

// we can safely use uninitialized values here because they are Copy
// and we will only ever write to them
let size = res_dim.size();
let mut v = Vec::with_capacity(size);
unsafe {
v.set_len(size);
}
let mut res = Array::from_shape_vec(res_dim, v)?;

res.axis_iter_mut(axis)
.zip(arrays.into_iter())
.for_each(|(mut assign_view, array)| {
assign_view.assign(&array);
});

Ok(res)
}

/// Concatenate arrays along the given axis.
///
/// Uses the [`stack`][1] function, calling `ArrayView::from(&a)` on each
/// argument `a`.
Expand All @@ -103,9 +200,81 @@ where
/// );
/// # }
/// ```
#[deprecated(
since = "0.13.2",
note = "Please use the `concatenate!` macro instead"
)]
#[macro_export]
macro_rules! stack {
($axis:expr, $( $array:expr ),+ ) => {
$crate::stack($axis, &[ $($crate::ArrayView::from(&$array) ),* ]).unwrap()
}
}

/// Concatenate arrays along the given axis.
///
/// Uses the [`concatenate`][1] function, calling `ArrayView::from(&a)` on each
/// argument `a`.
///
/// [1]: fn.concatenate.html
///
/// ***Panics*** if the `concatenate` function would return an error.
///
/// ```
/// extern crate ndarray;
///
/// use ndarray::{arr2, concatenate, Axis};
///
/// # fn main() {
///
/// let a = arr2(&[[2., 2.],
/// [3., 3.]]);
/// assert!(
/// concatenate![Axis(0), a, a]
/// == arr2(&[[2., 2.],
/// [3., 3.],
/// [2., 2.],
/// [3., 3.]])
/// );
/// # }
/// ```
#[macro_export]
macro_rules! concatenate {
($axis:expr, $( $array:expr ),+ ) => {
$crate::concatenate($axis, &[ $($crate::ArrayView::from(&$array) ),* ]).unwrap()
}
}

/// Stack arrays along the new axis.
///
/// Uses the [`stack_new_axis`][1] function, calling `ArrayView::from(&a)` on each
/// argument `a`.
///
/// [1]: fn.stack_new_axis.html
///
/// ***Panics*** if the `stack` function would return an error.
///
/// ```
/// extern crate ndarray;
///
/// use ndarray::{arr2, arr3, stack_new_axis, Axis};
///
/// # fn main() {
///
/// let a = arr2(&[[2., 2.],
/// [3., 3.]]);
/// assert!(
/// stack_new_axis![Axis(0), a, a]
/// == arr3(&[[[2., 2.],
/// [3., 3.]],
/// [[2., 2.],
/// [3., 3.]]])
/// );
/// # }
/// ```
#[macro_export]
macro_rules! stack_new_axis {
($axis:expr, $( $array:expr ),+ ) => {
$crate::stack_new_axis($axis, &[ $($crate::ArrayView::from(&$array) ),* ]).unwrap()
}
}
45 changes: 43 additions & 2 deletions tests/stacking.rs
@@ -1,7 +1,9 @@
use ndarray::{arr2, aview1, stack, Array2, Axis, ErrorKind};
#![allow(deprecated)]

use ndarray::{arr2, arr3, aview1, concatenate, stack, Array2, Axis, ErrorKind, Ix1};

#[test]
fn stacking() {
fn concatenating() {
let a = arr2(&[[2., 2.], [3., 3.]]);
let b = ndarray::stack(Axis(0), &[a.view(), a.view()]).unwrap();
assert_eq!(b, arr2(&[[2., 2.], [3., 3.], [2., 2.], [3., 3.]]));
Expand All @@ -23,4 +25,43 @@ fn stacking() {

let res: Result<Array2<f64>, _> = ndarray::stack(Axis(0), &[]);
assert_eq!(res.unwrap_err().kind(), ErrorKind::Unsupported);

let a = arr2(&[[2., 2.], [3., 3.]]);
let b = ndarray::concatenate(Axis(0), &[a.view(), a.view()]).unwrap();
assert_eq!(b, arr2(&[[2., 2.], [3., 3.], [2., 2.], [3., 3.]]));

let c = concatenate![Axis(0), a, b];
assert_eq!(
c,
arr2(&[[2., 2.], [3., 3.], [2., 2.], [3., 3.], [2., 2.], [3., 3.]])
);

let d = concatenate![Axis(0), a.row(0), &[9., 9.]];
assert_eq!(d, aview1(&[2., 2., 9., 9.]));

let res = ndarray::concatenate(Axis(1), &[a.view(), c.view()]);
assert_eq!(res.unwrap_err().kind(), ErrorKind::IncompatibleShape);

let res = ndarray::concatenate(Axis(2), &[a.view(), c.view()]);
assert_eq!(res.unwrap_err().kind(), ErrorKind::OutOfBounds);

let res: Result<Array2<f64>, _> = ndarray::concatenate(Axis(0), &[]);
assert_eq!(res.unwrap_err().kind(), ErrorKind::Unsupported);
}

#[test]
fn stacking() {
let a = arr2(&[[2., 2.], [3., 3.]]);
let b = ndarray::stack_new_axis(Axis(0), &[a.view(), a.view()]).unwrap();
assert_eq!(b, arr3(&[[[2., 2.], [3., 3.]], [[2., 2.], [3., 3.]]]));

let c = arr2(&[[3., 2., 3.], [2., 3., 2.]]);
let res = ndarray::stack_new_axis(Axis(1), &[a.view(), c.view()]);
assert_eq!(res.unwrap_err().kind(), ErrorKind::IncompatibleShape);

let res = ndarray::stack_new_axis(Axis(3), &[a.view(), a.view()]);
assert_eq!(res.unwrap_err().kind(), ErrorKind::OutOfBounds);

let res: Result<Array2<f64>, _> = ndarray::stack_new_axis::<_, Ix1>(Axis(0), &[]);
assert_eq!(res.unwrap_err().kind(), ErrorKind::Unsupported);
}