Skip to content

Commit

Permalink
Merge pull request #850 from andrei-papou/stack-concatenate-renaming
Browse files Browse the repository at this point in the history
Make stack and concatenate compliant with numpy naming.
  • Loading branch information
bluss committed Nov 28, 2020
2 parents 2bae4cc + 13dbaf1 commit 716221c
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 75 deletions.
1 change: 0 additions & 1 deletion src/lib.rs
Expand Up @@ -135,7 +135,6 @@ use crate::iterators::{ElementsBase, ElementsBaseMut, Iter, IterMut, Lanes, Lane
pub use crate::arraytraits::AsArray;
pub use crate::linalg_traits::{LinalgScalar, NdFloat};

#[allow(deprecated)]
pub use crate::stacking::{concatenate, stack, stack_new_axis};

pub use crate::impl_views::IndexLonger;
Expand Down
93 changes: 47 additions & 46 deletions src/stacking.rs
Expand Up @@ -9,6 +9,42 @@
use crate::error::{from_kind, ErrorKind, ShapeError};
use crate::imp_prelude::*;

/// Stack arrays along the new axis.
///
/// ***Errors*** if the arrays have mismatching shapes.
/// ***Errors*** if `arrays` is empty, if `axis` is out of bounds,
/// if the result is larger than is possible to represent.
///
/// ```
/// extern crate ndarray;
///
/// use ndarray::{arr2, arr3, stack, Axis};
///
/// # fn main() {
///
/// let a = arr2(&[[2., 2.],
/// [3., 3.]]);
/// assert!(
/// stack(Axis(0), &[a.view(), a.view()])
/// == Ok(arr3(&[[[2., 2.],
/// [3., 3.]],
/// [[2., 2.],
/// [3., 3.]]]))
/// );
/// # }
/// ```
pub fn stack<A, D>(
axis: Axis,
arrays: &[ArrayView<A, D>],
) -> Result<Array<A, D::Larger>, ShapeError>
where
A: Copy,
D: Dimension,
D::Larger: RemoveAxis,
{
stack_new_axis(axis, arrays)
}

/// Concatenate arrays along the given axis.
///
/// ***Errors*** if the arrays have mismatching shapes, apart from along `axis`.
Expand All @@ -17,23 +53,19 @@ use crate::imp_prelude::*;
/// if the result is larger than is possible to represent.
///
/// ```
/// use ndarray::{arr2, Axis, stack};
/// use ndarray::{arr2, Axis, concatenate};
///
/// let a = arr2(&[[2., 2.],
/// [3., 3.]]);
/// assert!(
/// stack(Axis(0), &[a.view(), a.view()])
/// concatenate(Axis(0), &[a.view(), a.view()])
/// == Ok(arr2(&[[2., 2.],
/// [3., 3.],
/// [2., 2.],
/// [3., 3.]]))
/// );
/// ```
#[deprecated(
since = "0.13.2",
note = "Please use the `concatenate` function instead"
)]
pub fn stack<A, D>(axis: Axis, arrays: &[ArrayView<A, D>]) -> Result<Array<A, D>, ShapeError>
pub fn concatenate<A, D>(axis: Axis, arrays: &[ArrayView<A, D>]) -> Result<Array<A, D>, ShapeError>
where
A: Copy,
D: RemoveAxis,
Expand Down Expand Up @@ -77,35 +109,6 @@ where
Ok(res)
}

/// Concatenate arrays along the given axis.
///
/// ***Errors*** if the arrays have mismatching shapes, apart from along `axis`.
/// (may be made more flexible in the future).<br>
/// ***Errors*** if `arrays` is empty, if `axis` is out of bounds,
/// if the result is larger than is possible to represent.
///
/// ```
/// use ndarray::{arr2, Axis, concatenate};
///
/// let a = arr2(&[[2., 2.],
/// [3., 3.]]);
/// assert!(
/// concatenate(Axis(0), &[a.view(), a.view()])
/// == Ok(arr2(&[[2., 2.],
/// [3., 3.],
/// [2., 2.],
/// [3., 3.]]))
/// );
/// ```
#[allow(deprecated)]
pub fn concatenate<A, D>(axis: Axis, arrays: &[ArrayView<A, D>]) -> Result<Array<A, D>, ShapeError>
where
A: Copy,
D: RemoveAxis,
{
stack(axis, arrays)
}

/// Stack arrays along the new axis.
///
/// ***Errors*** if the arrays have mismatching shapes.
Expand Down Expand Up @@ -173,7 +176,7 @@ where
Ok(res)
}

/// Concatenate arrays along the given axis.
/// Stack arrays along the new axis.
///
/// Uses the [`stack`][1] function, calling `ArrayView::from(&a)` on each
/// argument `a`.
Expand All @@ -183,25 +186,23 @@ where
/// ***Panics*** if the `stack` function would return an error.
///
/// ```
/// use ndarray::{arr2, stack, Axis};
/// extern crate ndarray;
///
/// use ndarray::{arr2, arr3, stack, Axis};
///
/// # fn main() {
///
/// let a = arr2(&[[2., 2.],
/// [3., 3.]]);
/// assert!(
/// stack![Axis(0), a, a]
/// == arr2(&[[2., 2.],
/// [3., 3.],
/// [2., 2.],
/// [3., 3.]])
/// == arr3(&[[[2., 2.],
/// [3., 3.]],
/// [[2., 2.],
/// [3., 3.]]])
/// );
/// # }
/// ```
#[deprecated(
since = "0.13.2",
note = "Please use the `concatenate!` macro instead"
)]
#[macro_export]
macro_rules! stack {
($axis:expr, $( $array:expr ),+ ) => {
Expand Down
35 changes: 7 additions & 28 deletions tests/stacking.rs
@@ -1,31 +1,7 @@
#![allow(deprecated)]

use ndarray::{arr2, arr3, aview1, concatenate, stack, Array2, Axis, ErrorKind, Ix1};

#[test]
fn concatenating() {
let a = arr2(&[[2., 2.], [3., 3.]]);
let b = ndarray::stack(Axis(0), &[a.view(), a.view()]).unwrap();
assert_eq!(b, arr2(&[[2., 2.], [3., 3.], [2., 2.], [3., 3.]]));

let c = stack![Axis(0), a, b];
assert_eq!(
c,
arr2(&[[2., 2.], [3., 3.], [2., 2.], [3., 3.], [2., 2.], [3., 3.]])
);

let d = stack![Axis(0), a.row(0), &[9., 9.]];
assert_eq!(d, aview1(&[2., 2., 9., 9.]));

let res = ndarray::stack(Axis(1), &[a.view(), c.view()]);
assert_eq!(res.unwrap_err().kind(), ErrorKind::IncompatibleShape);

let res = ndarray::stack(Axis(2), &[a.view(), c.view()]);
assert_eq!(res.unwrap_err().kind(), ErrorKind::OutOfBounds);

let res: Result<Array2<f64>, _> = ndarray::stack(Axis(0), &[]);
assert_eq!(res.unwrap_err().kind(), ErrorKind::Unsupported);

let a = arr2(&[[2., 2.], [3., 3.]]);
let b = ndarray::concatenate(Axis(0), &[a.view(), a.view()]).unwrap();
assert_eq!(b, arr2(&[[2., 2.], [3., 3.], [2., 2.], [3., 3.]]));
Expand All @@ -52,16 +28,19 @@ fn concatenating() {
#[test]
fn stacking() {
let a = arr2(&[[2., 2.], [3., 3.]]);
let b = ndarray::stack_new_axis(Axis(0), &[a.view(), a.view()]).unwrap();
let b = ndarray::stack(Axis(0), &[a.view(), a.view()]).unwrap();
assert_eq!(b, arr3(&[[[2., 2.], [3., 3.]], [[2., 2.], [3., 3.]]]));

let c = stack![Axis(0), a, a];
assert_eq!(c, arr3(&[[[2., 2.], [3., 3.]], [[2., 2.], [3., 3.]]]));

let c = arr2(&[[3., 2., 3.], [2., 3., 2.]]);
let res = ndarray::stack_new_axis(Axis(1), &[a.view(), c.view()]);
let res = ndarray::stack(Axis(1), &[a.view(), c.view()]);
assert_eq!(res.unwrap_err().kind(), ErrorKind::IncompatibleShape);

let res = ndarray::stack_new_axis(Axis(3), &[a.view(), a.view()]);
let res = ndarray::stack(Axis(3), &[a.view(), a.view()]);
assert_eq!(res.unwrap_err().kind(), ErrorKind::OutOfBounds);

let res: Result<Array2<f64>, _> = ndarray::stack_new_axis::<_, Ix1>(Axis(0), &[]);
let res: Result<Array2<f64>, _> = ndarray::stack::<_, Ix1>(Axis(0), &[]);
assert_eq!(res.unwrap_err().kind(), ErrorKind::Unsupported);
}

0 comments on commit 716221c

Please sign in to comment.