Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make stack and concatenate compliant with numpy naming. #850

Merged
merged 2 commits into from Nov 28, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 0 additions & 1 deletion src/lib.rs
Expand Up @@ -132,7 +132,6 @@ use crate::iterators::{ElementsBase, ElementsBaseMut, Iter, IterMut, Lanes, Lane
pub use crate::arraytraits::AsArray;
pub use crate::linalg_traits::{LinalgScalar, NdFloat};

#[allow(deprecated)]
pub use crate::stacking::{concatenate, stack, stack_new_axis};

pub use crate::impl_views::IndexLonger;
Expand Down
93 changes: 47 additions & 46 deletions src/stacking.rs
Expand Up @@ -9,6 +9,42 @@
use crate::error::{from_kind, ErrorKind, ShapeError};
use crate::imp_prelude::*;

/// Stack arrays along the new axis.
///
/// ***Errors*** if the arrays have mismatching shapes.
/// ***Errors*** if `arrays` is empty, if `axis` is out of bounds,
/// if the result is larger than is possible to represent.
///
/// ```
/// extern crate ndarray;
///
/// use ndarray::{arr2, arr3, stack, Axis};
///
/// # fn main() {
///
/// let a = arr2(&[[2., 2.],
/// [3., 3.]]);
/// assert!(
/// stack(Axis(0), &[a.view(), a.view()])
/// == Ok(arr3(&[[[2., 2.],
/// [3., 3.]],
/// [[2., 2.],
/// [3., 3.]]]))
/// );
/// # }
/// ```
pub fn stack<A, D>(
axis: Axis,
arrays: &[ArrayView<A, D>],
) -> Result<Array<A, D::Larger>, ShapeError>
where
A: Copy,
D: Dimension,
D::Larger: RemoveAxis,
{
stack_new_axis(axis, arrays)
}

/// Concatenate arrays along the given axis.
///
/// ***Errors*** if the arrays have mismatching shapes, apart from along `axis`.
Expand All @@ -17,23 +53,19 @@ use crate::imp_prelude::*;
/// if the result is larger than is possible to represent.
///
/// ```
/// use ndarray::{arr2, Axis, stack};
/// use ndarray::{arr2, Axis, concatenate};
///
/// let a = arr2(&[[2., 2.],
/// [3., 3.]]);
/// assert!(
/// stack(Axis(0), &[a.view(), a.view()])
/// concatenate(Axis(0), &[a.view(), a.view()])
/// == Ok(arr2(&[[2., 2.],
/// [3., 3.],
/// [2., 2.],
/// [3., 3.]]))
/// );
/// ```
#[deprecated(
since = "0.13.2",
note = "Please use the `concatenate` function instead"
)]
pub fn stack<A, D>(axis: Axis, arrays: &[ArrayView<A, D>]) -> Result<Array<A, D>, ShapeError>
pub fn concatenate<A, D>(axis: Axis, arrays: &[ArrayView<A, D>]) -> Result<Array<A, D>, ShapeError>
where
A: Copy,
D: RemoveAxis,
Expand Down Expand Up @@ -77,35 +109,6 @@ where
Ok(res)
}

/// Concatenate arrays along the given axis.
///
/// ***Errors*** if the arrays have mismatching shapes, apart from along `axis`.
/// (may be made more flexible in the future).<br>
/// ***Errors*** if `arrays` is empty, if `axis` is out of bounds,
/// if the result is larger than is possible to represent.
///
/// ```
/// use ndarray::{arr2, Axis, concatenate};
///
/// let a = arr2(&[[2., 2.],
/// [3., 3.]]);
/// assert!(
/// concatenate(Axis(0), &[a.view(), a.view()])
/// == Ok(arr2(&[[2., 2.],
/// [3., 3.],
/// [2., 2.],
/// [3., 3.]]))
/// );
/// ```
#[allow(deprecated)]
pub fn concatenate<A, D>(axis: Axis, arrays: &[ArrayView<A, D>]) -> Result<Array<A, D>, ShapeError>
where
A: Copy,
D: RemoveAxis,
{
stack(axis, arrays)
}

/// Stack arrays along the new axis.
///
/// ***Errors*** if the arrays have mismatching shapes.
Expand Down Expand Up @@ -173,7 +176,7 @@ where
Ok(res)
}

/// Concatenate arrays along the given axis.
/// Stack arrays along the new axis.
///
/// Uses the [`stack`][1] function, calling `ArrayView::from(&a)` on each
/// argument `a`.
Expand All @@ -183,25 +186,23 @@ where
/// ***Panics*** if the `stack` function would return an error.
///
/// ```
/// use ndarray::{arr2, stack, Axis};
/// extern crate ndarray;
///
/// use ndarray::{arr2, arr3, stack, Axis};
///
/// # fn main() {
///
/// let a = arr2(&[[2., 2.],
/// [3., 3.]]);
/// assert!(
/// stack![Axis(0), a, a]
/// == arr2(&[[2., 2.],
/// [3., 3.],
/// [2., 2.],
/// [3., 3.]])
/// == arr3(&[[[2., 2.],
/// [3., 3.]],
/// [[2., 2.],
/// [3., 3.]]])
/// );
/// # }
/// ```
#[deprecated(
since = "0.13.2",
note = "Please use the `concatenate!` macro instead"
)]
#[macro_export]
macro_rules! stack {
($axis:expr, $( $array:expr ),+ ) => {
Expand Down
35 changes: 7 additions & 28 deletions tests/stacking.rs
@@ -1,31 +1,7 @@
#![allow(deprecated)]

use ndarray::{arr2, arr3, aview1, concatenate, stack, Array2, Axis, ErrorKind, Ix1};

#[test]
fn concatenating() {
let a = arr2(&[[2., 2.], [3., 3.]]);
let b = ndarray::stack(Axis(0), &[a.view(), a.view()]).unwrap();
assert_eq!(b, arr2(&[[2., 2.], [3., 3.], [2., 2.], [3., 3.]]));

let c = stack![Axis(0), a, b];
assert_eq!(
c,
arr2(&[[2., 2.], [3., 3.], [2., 2.], [3., 3.], [2., 2.], [3., 3.]])
);

let d = stack![Axis(0), a.row(0), &[9., 9.]];
assert_eq!(d, aview1(&[2., 2., 9., 9.]));

let res = ndarray::stack(Axis(1), &[a.view(), c.view()]);
assert_eq!(res.unwrap_err().kind(), ErrorKind::IncompatibleShape);

let res = ndarray::stack(Axis(2), &[a.view(), c.view()]);
assert_eq!(res.unwrap_err().kind(), ErrorKind::OutOfBounds);

let res: Result<Array2<f64>, _> = ndarray::stack(Axis(0), &[]);
assert_eq!(res.unwrap_err().kind(), ErrorKind::Unsupported);

let a = arr2(&[[2., 2.], [3., 3.]]);
let b = ndarray::concatenate(Axis(0), &[a.view(), a.view()]).unwrap();
assert_eq!(b, arr2(&[[2., 2.], [3., 3.], [2., 2.], [3., 3.]]));
Expand All @@ -52,16 +28,19 @@ fn concatenating() {
#[test]
fn stacking() {
let a = arr2(&[[2., 2.], [3., 3.]]);
let b = ndarray::stack_new_axis(Axis(0), &[a.view(), a.view()]).unwrap();
let b = ndarray::stack(Axis(0), &[a.view(), a.view()]).unwrap();
assert_eq!(b, arr3(&[[[2., 2.], [3., 3.]], [[2., 2.], [3., 3.]]]));

let c = stack![Axis(0), a, a];
assert_eq!(c, arr3(&[[[2., 2.], [3., 3.]], [[2., 2.], [3., 3.]]]));

let c = arr2(&[[3., 2., 3.], [2., 3., 2.]]);
let res = ndarray::stack_new_axis(Axis(1), &[a.view(), c.view()]);
let res = ndarray::stack(Axis(1), &[a.view(), c.view()]);
assert_eq!(res.unwrap_err().kind(), ErrorKind::IncompatibleShape);

let res = ndarray::stack_new_axis(Axis(3), &[a.view(), a.view()]);
let res = ndarray::stack(Axis(3), &[a.view(), a.view()]);
assert_eq!(res.unwrap_err().kind(), ErrorKind::OutOfBounds);

let res: Result<Array2<f64>, _> = ndarray::stack_new_axis::<_, Ix1>(Axis(0), &[]);
let res: Result<Array2<f64>, _> = ndarray::stack::<_, Ix1>(Axis(0), &[]);
assert_eq!(res.unwrap_err().kind(), ErrorKind::Unsupported);
}