From 48606c63f0fc903cfb48dc188839c16b3ddc5501 Mon Sep 17 00:00:00 2001 From: andrei-papou Date: Tue, 1 Oct 2019 16:03:05 +0300 Subject: [PATCH 1/7] Renamed `stack` to `concatenate` and added numpy-compliant `stack` function --- src/doc/ndarray_for_numpy_users/mod.rs | 5 +- src/impl_methods.rs | 4 +- src/lib.rs | 2 +- src/stacking.rs | 75 ++++++++++++++++++++++---- tests/stacking.rs | 33 +++++++++--- 5 files changed, 96 insertions(+), 23 deletions(-) diff --git a/src/doc/ndarray_for_numpy_users/mod.rs b/src/doc/ndarray_for_numpy_users/mod.rs index b061d8e67..79d1314a8 100644 --- a/src/doc/ndarray_for_numpy_users/mod.rs +++ b/src/doc/ndarray_for_numpy_users/mod.rs @@ -532,7 +532,8 @@ //! ------|-----------|------ //! `a[:] = 3.` | [`a.fill(3.)`][.fill()] | set all array elements to the same scalar value //! `a[:] = b` | [`a.assign(&b)`][.assign()] | copy the data from array `b` into array `a` -//! `np.concatenate((a,b), axis=1)` | [`stack![Axis(1), a, b]`][stack!] or [`stack(Axis(1), &[a.view(), b.view()])`][stack()] | concatenate arrays `a` and `b` along axis 1 +//! `np.concatenate((a,b), axis=1)` | [`concatenate![Axis(1), a, b]`][concatenate!] or [`concatenate(Axis(1), &[a.view(), b.view()])`][concatenate()] | concatenate arrays `a` and `b` along axis 1 +//! `np.stack((a,b), axis=1)` | [`stack![Axis(1), a, b]`][stack!] or [`stack(Axis(1), vec![a.view(), b.view()])`][stack()] | stack arrays `a` and `b` along axis 1 //! `a[:,np.newaxis]` or `np.expand_dims(a, axis=1)` | [`a.insert_axis(Axis(1))`][.insert_axis()] | create an array from `a`, inserting a new axis 1 //! `a.transpose()` or `a.T` | [`a.t()`][.t()] or [`a.reversed_axes()`][.reversed_axes()] | transpose of array `a` (view for `.t()` or by-move for `.reversed_axes()`) //! `np.diag(a)` | [`a.diag()`][.diag()] | view the diagonal of `a` @@ -640,6 +641,8 @@ //! [.slice_move()]: ../../struct.ArrayBase.html#method.slice_move //! [.slice_mut()]: ../../struct.ArrayBase.html#method.slice_mut //! [.shape()]: ../../struct.ArrayBase.html#method.shape +//! [concatenate!]: ../../macro.concatenate.html +//! [concatenate()]: ../../fn.concatenate.html //! [stack!]: ../../macro.stack.html //! [stack()]: ../../fn.stack.html //! [.strides()]: ../../struct.ArrayBase.html#method.strides diff --git a/src/impl_methods.rs b/src/impl_methods.rs index db502663f..9853a1c47 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -28,7 +28,7 @@ use crate::iter::{ IndexedIter, IndexedIterMut, Iter, IterMut, Lanes, LanesMut, Windows, }; use crate::slice::MultiSlice; -use crate::stacking::stack; +use crate::stacking::concatenate; use crate::{NdIndex, Slice, SliceInfo, SliceOrIndex}; /// # Methods For All Array Types @@ -840,7 +840,7 @@ where dim.set_axis(axis, 0); unsafe { Array::from_shape_vec_unchecked(dim, vec![]) } } else { - stack(axis, &subs).unwrap() + concatenate(axis, &subs).unwrap() } } diff --git a/src/lib.rs b/src/lib.rs index eb09d31ea..191687c02 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -131,7 +131,7 @@ use crate::iterators::{ElementsBase, ElementsBaseMut, Iter, IterMut, Lanes, Lane pub use crate::arraytraits::AsArray; pub use crate::linalg_traits::{LinalgScalar, NdFloat}; -pub use crate::stacking::stack; +pub use crate::stacking::{concatenate, stack}; pub use crate::impl_views::IndexLonger; pub use crate::shape_builder::ShapeBuilder; diff --git a/src/stacking.rs b/src/stacking.rs index e998b6d15..975a840d6 100644 --- a/src/stacking.rs +++ b/src/stacking.rs @@ -17,21 +17,21 @@ use crate::imp_prelude::*; /// if the result is larger than is possible to represent. /// /// ``` -/// use ndarray::{arr2, Axis, stack}; +/// use ndarray::{arr2, Axis, concatenate}; /// /// let a = arr2(&[[2., 2.], /// [3., 3.]]); /// assert!( -/// stack(Axis(0), &[a.view(), a.view()]) +/// concatenate(Axis(0), &[a.view(), a.view()]) /// == Ok(arr2(&[[2., 2.], /// [3., 3.], /// [2., 2.], /// [3., 3.]])) /// ); /// ``` -pub fn stack<'a, A, D>( +pub fn concatenate( axis: Axis, - arrays: &[ArrayView<'a, A, D>], + arrays: &[ArrayView], ) -> Result, ShapeError> where A: Copy, @@ -76,26 +76,45 @@ where Ok(res) } -/// Stack arrays along the given axis. +pub fn stack( + axis: Axis, + arrays: Vec>, +) -> Result, ShapeError> +where + A: Copy, + D: Dimension, + D::Larger: RemoveAxis, +{ + if let Some(ndim) = D::NDIM { + if axis.index() > ndim { + return Err(from_kind(ErrorKind::OutOfBounds)); + } + } + let arrays: Vec> = arrays.into_iter() + .map(|a| a.insert_axis(axis)).collect(); + concatenate(axis, &arrays) +} + +/// Concatenate arrays along the given axis. /// -/// Uses the [`stack`][1] function, calling `ArrayView::from(&a)` on each +/// Uses the [`concatenate`][1] function, calling `ArrayView::from(&a)` on each /// argument `a`. /// -/// [1]: fn.stack.html +/// [1]: fn.concatenate.html /// -/// ***Panics*** if the `stack` function would return an error. +/// ***Panics*** if the `concatenate` function would return an error. /// /// ``` /// extern crate ndarray; /// -/// use ndarray::{arr2, stack, Axis}; +/// use ndarray::{arr2, concatenate, Axis}; /// /// # fn main() { /// /// let a = arr2(&[[2., 2.], /// [3., 3.]]); /// assert!( -/// stack![Axis(0), a, a] +/// concatenate![Axis(0), a, a] /// == arr2(&[[2., 2.], /// [3., 3.], /// [2., 2.], @@ -104,8 +123,42 @@ where /// # } /// ``` #[macro_export] +macro_rules! concatenate { + ($axis:expr, $( $array:expr ),+ ) => { + $crate::concatenate($axis, &[ $($crate::ArrayView::from(&$array) ),* ]).unwrap() + } +} + +/// Stack arrays along the new axis. +/// +/// Uses the [`stack`][1] function, calling `ArrayView::from(&a)` on each +/// argument `a`. +/// +/// [1]: fn.concatenate.html +/// +/// ***Panics*** if the `stack` function would return an error. +/// +/// ``` +/// extern crate ndarray; +/// +/// use ndarray::{arr2, arr3, stack, Axis}; +/// +/// # fn main() { +/// +/// let a = arr2(&[[2., 2.], +/// [3., 3.]]); +/// assert!( +/// stack![Axis(0), a, a] +/// == arr3(&[[[2., 2.], +/// [3., 3.]], +/// [[2., 2.], +/// [3., 3.]]]) +/// ); +/// # } +/// ``` +#[macro_export] macro_rules! stack { ($axis:expr, $( $array:expr ),+ ) => { - $crate::stack($axis, &[ $($crate::ArrayView::from(&$array) ),* ]).unwrap() + $crate::stack($axis, vec![ $($crate::ArrayView::from(&$array) ),* ]).unwrap() } } diff --git a/tests/stacking.rs b/tests/stacking.rs index a9a031711..2baaeed04 100644 --- a/tests/stacking.rs +++ b/tests/stacking.rs @@ -1,26 +1,43 @@ -use ndarray::{arr2, aview1, stack, Array2, Axis, ErrorKind}; +use ndarray::{arr2, arr3, aview1, concatenate, Array2, Axis, ErrorKind, Ix1}; #[test] -fn stacking() { +fn concatenating() { let a = arr2(&[[2., 2.], [3., 3.]]); - let b = ndarray::stack(Axis(0), &[a.view(), a.view()]).unwrap(); + let b = ndarray::concatenate(Axis(0), &[a.view(), a.view()]).unwrap(); assert_eq!(b, arr2(&[[2., 2.], [3., 3.], [2., 2.], [3., 3.]])); - let c = stack![Axis(0), a, b]; + let c = concatenate![Axis(0), a, b]; assert_eq!( c, arr2(&[[2., 2.], [3., 3.], [2., 2.], [3., 3.], [2., 2.], [3., 3.]]) ); - let d = stack![Axis(0), a.row(0), &[9., 9.]]; + let d = concatenate![Axis(0), a.row(0), &[9., 9.]]; assert_eq!(d, aview1(&[2., 2., 9., 9.])); - let res = ndarray::stack(Axis(1), &[a.view(), c.view()]); + let res = ndarray::concatenate(Axis(1), &[a.view(), c.view()]); + assert_eq!(res.unwrap_err().kind(), ErrorKind::IncompatibleShape); + + let res = ndarray::concatenate(Axis(2), &[a.view(), c.view()]); + assert_eq!(res.unwrap_err().kind(), ErrorKind::OutOfBounds); + + let res: Result, _> = ndarray::concatenate(Axis(0), &[]); + assert_eq!(res.unwrap_err().kind(), ErrorKind::Unsupported); +} + +#[test] +fn stacking() { + let a = arr2(&[[2., 2.], [3., 3.]]); + let b = ndarray::stack(Axis(0), vec![a.view(), a.view()]).unwrap(); + assert_eq!(b, arr3(&[[[2., 2.], [3., 3.]], [[2., 2.], [3., 3.]]])); + + let c = arr2(&[[3., 2., 3.], [2., 3., 2.]]); + let res = ndarray::stack(Axis(1), vec![a.view(), c.view()]); assert_eq!(res.unwrap_err().kind(), ErrorKind::IncompatibleShape); - let res = ndarray::stack(Axis(2), &[a.view(), c.view()]); + let res = ndarray::stack(Axis(3), vec![a.view(), a.view()]); assert_eq!(res.unwrap_err().kind(), ErrorKind::OutOfBounds); - let res: Result, _> = ndarray::stack(Axis(0), &[]); + let res: Result, _> = ndarray::stack::<_, Ix1>(Axis(0), vec![]); assert_eq!(res.unwrap_err().kind(), ErrorKind::Unsupported); } From 9f1a3d08a1d7a5b1cebc6801e0dc6688fafee754 Mon Sep 17 00:00:00 2001 From: andrei-papou Date: Tue, 1 Oct 2019 16:23:32 +0300 Subject: [PATCH 2/7] Rustfmt fixes --- src/stacking.rs | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/src/stacking.rs b/src/stacking.rs index 975a840d6..d0d4812b3 100644 --- a/src/stacking.rs +++ b/src/stacking.rs @@ -29,10 +29,7 @@ use crate::imp_prelude::*; /// [3., 3.]])) /// ); /// ``` -pub fn concatenate( - axis: Axis, - arrays: &[ArrayView], -) -> Result, ShapeError> +pub fn concatenate(axis: Axis, arrays: &[ArrayView]) -> Result, ShapeError> where A: Copy, D: RemoveAxis, @@ -90,8 +87,8 @@ where return Err(from_kind(ErrorKind::OutOfBounds)); } } - let arrays: Vec> = arrays.into_iter() - .map(|a| a.insert_axis(axis)).collect(); + let arrays: Vec> = + arrays.into_iter().map(|a| a.insert_axis(axis)).collect(); concatenate(axis, &arrays) } From dee84ae025631aa1c38fb1b92bd7ad7652806b85 Mon Sep 17 00:00:00 2001 From: "andrei.papou" Date: Thu, 17 Oct 2019 01:51:49 +0300 Subject: [PATCH 3/7] Removed allocation in `stack_new_axis`, renamed `concatenate` -> `stack`, `stack` -> `stack_new_axis` --- src/doc/ndarray_for_numpy_users/mod.rs | 8 +-- src/impl_methods.rs | 4 +- src/lib.rs | 2 +- src/stacking.rs | 69 +++++++++++++++++--------- tests/stacking.rs | 22 ++++---- 5 files changed, 64 insertions(+), 41 deletions(-) diff --git a/src/doc/ndarray_for_numpy_users/mod.rs b/src/doc/ndarray_for_numpy_users/mod.rs index 79d1314a8..d3125e712 100644 --- a/src/doc/ndarray_for_numpy_users/mod.rs +++ b/src/doc/ndarray_for_numpy_users/mod.rs @@ -532,8 +532,8 @@ //! ------|-----------|------ //! `a[:] = 3.` | [`a.fill(3.)`][.fill()] | set all array elements to the same scalar value //! `a[:] = b` | [`a.assign(&b)`][.assign()] | copy the data from array `b` into array `a` -//! `np.concatenate((a,b), axis=1)` | [`concatenate![Axis(1), a, b]`][concatenate!] or [`concatenate(Axis(1), &[a.view(), b.view()])`][concatenate()] | concatenate arrays `a` and `b` along axis 1 -//! `np.stack((a,b), axis=1)` | [`stack![Axis(1), a, b]`][stack!] or [`stack(Axis(1), vec![a.view(), b.view()])`][stack()] | stack arrays `a` and `b` along axis 1 +//! `np.concatenate((a,b), axis=1)` | [`stack![Axis(1), a, b]`][stack!] or [`stack(Axis(1), &[a.view(), b.view()])`][stack()] | concatenate arrays `a` and `b` along axis 1 +//! `np.stack((a,b), axis=1)` | [`stack_new_axis![Axis(1), a, b]`][stack_new_axis!] or [`stack_new_axis(Axis(1), vec![a.view(), b.view()])`][stack_new_axis()] | stack arrays `a` and `b` along axis 1 //! `a[:,np.newaxis]` or `np.expand_dims(a, axis=1)` | [`a.insert_axis(Axis(1))`][.insert_axis()] | create an array from `a`, inserting a new axis 1 //! `a.transpose()` or `a.T` | [`a.t()`][.t()] or [`a.reversed_axes()`][.reversed_axes()] | transpose of array `a` (view for `.t()` or by-move for `.reversed_axes()`) //! `np.diag(a)` | [`a.diag()`][.diag()] | view the diagonal of `a` @@ -641,10 +641,10 @@ //! [.slice_move()]: ../../struct.ArrayBase.html#method.slice_move //! [.slice_mut()]: ../../struct.ArrayBase.html#method.slice_mut //! [.shape()]: ../../struct.ArrayBase.html#method.shape -//! [concatenate!]: ../../macro.concatenate.html -//! [concatenate()]: ../../fn.concatenate.html //! [stack!]: ../../macro.stack.html //! [stack()]: ../../fn.stack.html +//! [stack_new_axis!]: ../../macro.stack_new_axis.html +//! [stack_new_axis()]: ../../fn.stack_new_axis.html //! [.strides()]: ../../struct.ArrayBase.html#method.strides //! [.index_axis()]: ../../struct.ArrayBase.html#method.index_axis //! [.sum_axis()]: ../../struct.ArrayBase.html#method.sum_axis diff --git a/src/impl_methods.rs b/src/impl_methods.rs index 9853a1c47..db502663f 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -28,7 +28,7 @@ use crate::iter::{ IndexedIter, IndexedIterMut, Iter, IterMut, Lanes, LanesMut, Windows, }; use crate::slice::MultiSlice; -use crate::stacking::concatenate; +use crate::stacking::stack; use crate::{NdIndex, Slice, SliceInfo, SliceOrIndex}; /// # Methods For All Array Types @@ -840,7 +840,7 @@ where dim.set_axis(axis, 0); unsafe { Array::from_shape_vec_unchecked(dim, vec![]) } } else { - concatenate(axis, &subs).unwrap() + stack(axis, &subs).unwrap() } } diff --git a/src/lib.rs b/src/lib.rs index 191687c02..c12ea7a2e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -131,7 +131,7 @@ use crate::iterators::{ElementsBase, ElementsBaseMut, Iter, IterMut, Lanes, Lane pub use crate::arraytraits::AsArray; pub use crate::linalg_traits::{LinalgScalar, NdFloat}; -pub use crate::stacking::{concatenate, stack}; +pub use crate::stacking::{stack, stack_new_axis}; pub use crate::impl_views::IndexLonger; pub use crate::shape_builder::ShapeBuilder; diff --git a/src/stacking.rs b/src/stacking.rs index d0d4812b3..24b98108a 100644 --- a/src/stacking.rs +++ b/src/stacking.rs @@ -17,19 +17,19 @@ use crate::imp_prelude::*; /// if the result is larger than is possible to represent. /// /// ``` -/// use ndarray::{arr2, Axis, concatenate}; +/// use ndarray::{arr2, Axis, stack}; /// /// let a = arr2(&[[2., 2.], /// [3., 3.]]); /// assert!( -/// concatenate(Axis(0), &[a.view(), a.view()]) +/// stack(Axis(0), &[a.view(), a.view()]) /// == Ok(arr2(&[[2., 2.], /// [3., 3.], /// [2., 2.], /// [3., 3.]])) /// ); /// ``` -pub fn concatenate(axis: Axis, arrays: &[ArrayView]) -> Result, ShapeError> +pub fn stack(axis: Axis, arrays: &[ArrayView]) -> Result, ShapeError> where A: Copy, D: RemoveAxis, @@ -73,7 +73,7 @@ where Ok(res) } -pub fn stack( +pub fn stack_new_axis( axis: Axis, arrays: Vec>, ) -> Result, ShapeError> @@ -82,36 +82,59 @@ where D: Dimension, D::Larger: RemoveAxis, { - if let Some(ndim) = D::NDIM { - if axis.index() > ndim { - return Err(from_kind(ErrorKind::OutOfBounds)); - } + if arrays.is_empty() { + return Err(from_kind(ErrorKind::Unsupported)); + } + let common_dim = arrays[0].raw_dim(); + // Avoid panic on `insert_axis` call, return an Err instead of it. + if axis.index() > common_dim.ndim() { + return Err(from_kind(ErrorKind::OutOfBounds)); + } + let mut res_dim = common_dim.insert_axis(axis); + + if arrays.iter().any(|a| a.raw_dim() != common_dim) { + return Err(from_kind(ErrorKind::IncompatibleShape)); } - let arrays: Vec> = - arrays.into_iter().map(|a| a.insert_axis(axis)).collect(); - concatenate(axis, &arrays) + + res_dim.set_axis(axis, arrays.len()); + + // we can safely use uninitialized values here because they are Copy + // and we will only ever write to them + let size = res_dim.size(); + let mut v = Vec::with_capacity(size); + unsafe { + v.set_len(size); + } + let mut res = Array::from_shape_vec(res_dim, v)?; + + res.axis_iter_mut(axis).zip(arrays.into_iter()) + .for_each(|(mut assign_view, array)| { + assign_view.assign(&array); + }); + + Ok(res) } /// Concatenate arrays along the given axis. /// -/// Uses the [`concatenate`][1] function, calling `ArrayView::from(&a)` on each +/// Uses the [`stack`][1] function, calling `ArrayView::from(&a)` on each /// argument `a`. /// -/// [1]: fn.concatenate.html +/// [1]: fn.stack.html /// /// ***Panics*** if the `concatenate` function would return an error. /// /// ``` /// extern crate ndarray; /// -/// use ndarray::{arr2, concatenate, Axis}; +/// use ndarray::{arr2, stack, Axis}; /// /// # fn main() { /// /// let a = arr2(&[[2., 2.], /// [3., 3.]]); /// assert!( -/// concatenate![Axis(0), a, a] +/// stack![Axis(0), a, a] /// == arr2(&[[2., 2.], /// [3., 3.], /// [2., 2.], @@ -120,32 +143,32 @@ where /// # } /// ``` #[macro_export] -macro_rules! concatenate { +macro_rules! stack { ($axis:expr, $( $array:expr ),+ ) => { - $crate::concatenate($axis, &[ $($crate::ArrayView::from(&$array) ),* ]).unwrap() + $crate::stack($axis, &[ $($crate::ArrayView::from(&$array) ),* ]).unwrap() } } /// Stack arrays along the new axis. /// -/// Uses the [`stack`][1] function, calling `ArrayView::from(&a)` on each +/// Uses the [`stack_new_axis`][1] function, calling `ArrayView::from(&a)` on each /// argument `a`. /// -/// [1]: fn.concatenate.html +/// [1]: fn.stack_new_axis.html /// /// ***Panics*** if the `stack` function would return an error. /// /// ``` /// extern crate ndarray; /// -/// use ndarray::{arr2, arr3, stack, Axis}; +/// use ndarray::{arr2, arr3, stack_new_axis, Axis}; /// /// # fn main() { /// /// let a = arr2(&[[2., 2.], /// [3., 3.]]); /// assert!( -/// stack![Axis(0), a, a] +/// stack_new_axis![Axis(0), a, a] /// == arr3(&[[[2., 2.], /// [3., 3.]], /// [[2., 2.], @@ -154,8 +177,8 @@ macro_rules! concatenate { /// # } /// ``` #[macro_export] -macro_rules! stack { +macro_rules! stack_new_axis { ($axis:expr, $( $array:expr ),+ ) => { - $crate::stack($axis, vec![ $($crate::ArrayView::from(&$array) ),* ]).unwrap() + $crate::stack_new_axis($axis, vec![ $($crate::ArrayView::from(&$array) ),* ]).unwrap() } } diff --git a/tests/stacking.rs b/tests/stacking.rs index 2baaeed04..c07c5bb24 100644 --- a/tests/stacking.rs +++ b/tests/stacking.rs @@ -1,43 +1,43 @@ -use ndarray::{arr2, arr3, aview1, concatenate, Array2, Axis, ErrorKind, Ix1}; +use ndarray::{arr2, arr3, aview1, stack, Array2, Axis, ErrorKind, Ix1}; #[test] fn concatenating() { let a = arr2(&[[2., 2.], [3., 3.]]); - let b = ndarray::concatenate(Axis(0), &[a.view(), a.view()]).unwrap(); + let b = ndarray::stack(Axis(0), &[a.view(), a.view()]).unwrap(); assert_eq!(b, arr2(&[[2., 2.], [3., 3.], [2., 2.], [3., 3.]])); - let c = concatenate![Axis(0), a, b]; + let c = stack![Axis(0), a, b]; assert_eq!( c, arr2(&[[2., 2.], [3., 3.], [2., 2.], [3., 3.], [2., 2.], [3., 3.]]) ); - let d = concatenate![Axis(0), a.row(0), &[9., 9.]]; + let d = stack![Axis(0), a.row(0), &[9., 9.]]; assert_eq!(d, aview1(&[2., 2., 9., 9.])); - let res = ndarray::concatenate(Axis(1), &[a.view(), c.view()]); + let res = ndarray::stack(Axis(1), &[a.view(), c.view()]); assert_eq!(res.unwrap_err().kind(), ErrorKind::IncompatibleShape); - let res = ndarray::concatenate(Axis(2), &[a.view(), c.view()]); + let res = ndarray::stack(Axis(2), &[a.view(), c.view()]); assert_eq!(res.unwrap_err().kind(), ErrorKind::OutOfBounds); - let res: Result, _> = ndarray::concatenate(Axis(0), &[]); + let res: Result, _> = ndarray::stack(Axis(0), &[]); assert_eq!(res.unwrap_err().kind(), ErrorKind::Unsupported); } #[test] fn stacking() { let a = arr2(&[[2., 2.], [3., 3.]]); - let b = ndarray::stack(Axis(0), vec![a.view(), a.view()]).unwrap(); + let b = ndarray::stack_new_axis(Axis(0), vec![a.view(), a.view()]).unwrap(); assert_eq!(b, arr3(&[[[2., 2.], [3., 3.]], [[2., 2.], [3., 3.]]])); let c = arr2(&[[3., 2., 3.], [2., 3., 2.]]); - let res = ndarray::stack(Axis(1), vec![a.view(), c.view()]); + let res = ndarray::stack_new_axis(Axis(1), vec![a.view(), c.view()]); assert_eq!(res.unwrap_err().kind(), ErrorKind::IncompatibleShape); - let res = ndarray::stack(Axis(3), vec![a.view(), a.view()]); + let res = ndarray::stack_new_axis(Axis(3), vec![a.view(), a.view()]); assert_eq!(res.unwrap_err().kind(), ErrorKind::OutOfBounds); - let res: Result, _> = ndarray::stack::<_, Ix1>(Axis(0), vec![]); + let res: Result, _> = ndarray::stack_new_axis::<_, Ix1>(Axis(0), vec![]); assert_eq!(res.unwrap_err().kind(), ErrorKind::Unsupported); } From 1bb7104e43740116d00afc4471af48a99e52f325 Mon Sep 17 00:00:00 2001 From: "andrei.papou" Date: Thu, 17 Oct 2019 03:10:59 +0300 Subject: [PATCH 4/7] Rustfmt fixes --- src/stacking.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/stacking.rs b/src/stacking.rs index 24b98108a..3847c6dc6 100644 --- a/src/stacking.rs +++ b/src/stacking.rs @@ -107,7 +107,8 @@ where } let mut res = Array::from_shape_vec(res_dim, v)?; - res.axis_iter_mut(axis).zip(arrays.into_iter()) + res.axis_iter_mut(axis) + .zip(arrays.into_iter()) .for_each(|(mut assign_view, array)| { assign_view.assign(&array); }); From 14f2f576f4dab52f97cb87e70f45eea69f7d935e Mon Sep 17 00:00:00 2001 From: "andrei.papou" Date: Wed, 18 Dec 2019 21:54:09 +0300 Subject: [PATCH 5/7] Introduced concatenate function, deprecated stack function since 0.13.0 --- src/lib.rs | 2 +- src/stacking.rs | 70 +++++++++++++++++++++++++++++++++++++++++++++-- tests/stacking.rs | 24 +++++++++++++++- 3 files changed, 92 insertions(+), 4 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index c12ea7a2e..f9b90a0f8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -131,7 +131,7 @@ use crate::iterators::{ElementsBase, ElementsBaseMut, Iter, IterMut, Lanes, Lane pub use crate::arraytraits::AsArray; pub use crate::linalg_traits::{LinalgScalar, NdFloat}; -pub use crate::stacking::{stack, stack_new_axis}; +pub use crate::stacking::{concatenate, stack, stack_new_axis}; pub use crate::impl_views::IndexLonger; pub use crate::shape_builder::ShapeBuilder; diff --git a/src/stacking.rs b/src/stacking.rs index 3847c6dc6..d5d0fb94c 100644 --- a/src/stacking.rs +++ b/src/stacking.rs @@ -9,7 +9,7 @@ use crate::error::{from_kind, ErrorKind, ShapeError}; use crate::imp_prelude::*; -/// Stack arrays along the given axis. +/// Concatenate arrays along the given axis. /// /// ***Errors*** if the arrays have mismatching shapes, apart from along `axis`. /// (may be made more flexible in the future).
@@ -29,6 +29,10 @@ use crate::imp_prelude::*; /// [3., 3.]])) /// ); /// ``` +#[deprecated( + since = "0.13.0", + note = "Please use the `concatenate` function instead" +)] pub fn stack(axis: Axis, arrays: &[ArrayView]) -> Result, ShapeError> where A: Copy, @@ -73,6 +77,34 @@ where Ok(res) } +/// Concatenate arrays along the given axis. +/// +/// ***Errors*** if the arrays have mismatching shapes, apart from along `axis`. +/// (may be made more flexible in the future).
+/// ***Errors*** if `arrays` is empty, if `axis` is out of bounds, +/// if the result is larger than is possible to represent. +/// +/// ``` +/// use ndarray::{arr2, Axis, concatenate}; +/// +/// let a = arr2(&[[2., 2.], +/// [3., 3.]]); +/// assert!( +/// concatenate(Axis(0), &[a.view(), a.view()]) +/// == Ok(arr2(&[[2., 2.], +/// [3., 3.], +/// [2., 2.], +/// [3., 3.]])) +/// ); +/// ``` +pub fn concatenate(axis: Axis, arrays: &[ArrayView]) -> Result, ShapeError> +where + A: Copy, + D: RemoveAxis, +{ + stack(axis, arrays) +} + pub fn stack_new_axis( axis: Axis, arrays: Vec>, @@ -123,7 +155,7 @@ where /// /// [1]: fn.stack.html /// -/// ***Panics*** if the `concatenate` function would return an error. +/// ***Panics*** if the `stack` function would return an error. /// /// ``` /// extern crate ndarray; @@ -150,6 +182,40 @@ macro_rules! stack { } } +/// Concatenate arrays along the given axis. +/// +/// Uses the [`concatenate`][1] function, calling `ArrayView::from(&a)` on each +/// argument `a`. +/// +/// [1]: fn.concatenate.html +/// +/// ***Panics*** if the `concatenate` function would return an error. +/// +/// ``` +/// extern crate ndarray; +/// +/// use ndarray::{arr2, concatenate, Axis}; +/// +/// # fn main() { +/// +/// let a = arr2(&[[2., 2.], +/// [3., 3.]]); +/// assert!( +/// concatenate![Axis(0), a, a] +/// == arr2(&[[2., 2.], +/// [3., 3.], +/// [2., 2.], +/// [3., 3.]]) +/// ); +/// # } +/// ``` +#[macro_export] +macro_rules! concatenate { + ($axis:expr, $( $array:expr ),+ ) => { + $crate::concatenate($axis, &[ $($crate::ArrayView::from(&$array) ),* ]).unwrap() + } +} + /// Stack arrays along the new axis. /// /// Uses the [`stack_new_axis`][1] function, calling `ArrayView::from(&a)` on each diff --git a/tests/stacking.rs b/tests/stacking.rs index c07c5bb24..f723adc57 100644 --- a/tests/stacking.rs +++ b/tests/stacking.rs @@ -1,4 +1,4 @@ -use ndarray::{arr2, arr3, aview1, stack, Array2, Axis, ErrorKind, Ix1}; +use ndarray::{arr2, arr3, aview1, concatenate, stack, Array2, Axis, ErrorKind, Ix1}; #[test] fn concatenating() { @@ -23,6 +23,28 @@ fn concatenating() { let res: Result, _> = ndarray::stack(Axis(0), &[]); assert_eq!(res.unwrap_err().kind(), ErrorKind::Unsupported); + + let a = arr2(&[[2., 2.], [3., 3.]]); + let b = ndarray::concatenate(Axis(0), &[a.view(), a.view()]).unwrap(); + assert_eq!(b, arr2(&[[2., 2.], [3., 3.], [2., 2.], [3., 3.]])); + + let c = concatenate![Axis(0), a, b]; + assert_eq!( + c, + arr2(&[[2., 2.], [3., 3.], [2., 2.], [3., 3.], [2., 2.], [3., 3.]]) + ); + + let d = concatenate![Axis(0), a.row(0), &[9., 9.]]; + assert_eq!(d, aview1(&[2., 2., 9., 9.])); + + let res = ndarray::concatenate(Axis(1), &[a.view(), c.view()]); + assert_eq!(res.unwrap_err().kind(), ErrorKind::IncompatibleShape); + + let res = ndarray::concatenate(Axis(2), &[a.view(), c.view()]); + assert_eq!(res.unwrap_err().kind(), ErrorKind::OutOfBounds); + + let res: Result, _> = ndarray::concatenate(Axis(0), &[]); + assert_eq!(res.unwrap_err().kind(), ErrorKind::Unsupported); } #[test] From ef74ac900b24f455ff0db8a614f58097d469e5bd Mon Sep 17 00:00:00 2001 From: "andrei.papou" Date: Wed, 18 Dec 2019 22:22:21 +0300 Subject: [PATCH 6/7] Updated deprecation version for stack function, suppressed deprecation warnings --- src/impl_methods.rs | 4 ++-- src/lib.rs | 2 ++ src/stacking.rs | 3 ++- tests/stacking.rs | 2 ++ 4 files changed, 8 insertions(+), 3 deletions(-) diff --git a/src/impl_methods.rs b/src/impl_methods.rs index db502663f..9853a1c47 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -28,7 +28,7 @@ use crate::iter::{ IndexedIter, IndexedIterMut, Iter, IterMut, Lanes, LanesMut, Windows, }; use crate::slice::MultiSlice; -use crate::stacking::stack; +use crate::stacking::concatenate; use crate::{NdIndex, Slice, SliceInfo, SliceOrIndex}; /// # Methods For All Array Types @@ -840,7 +840,7 @@ where dim.set_axis(axis, 0); unsafe { Array::from_shape_vec_unchecked(dim, vec![]) } } else { - stack(axis, &subs).unwrap() + concatenate(axis, &subs).unwrap() } } diff --git a/src/lib.rs b/src/lib.rs index f9b90a0f8..ad7ffec7d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -131,6 +131,8 @@ use crate::iterators::{ElementsBase, ElementsBaseMut, Iter, IterMut, Lanes, Lane pub use crate::arraytraits::AsArray; pub use crate::linalg_traits::{LinalgScalar, NdFloat}; + +#[allow(deprecated)] pub use crate::stacking::{concatenate, stack, stack_new_axis}; pub use crate::impl_views::IndexLonger; diff --git a/src/stacking.rs b/src/stacking.rs index d5d0fb94c..822b4b56d 100644 --- a/src/stacking.rs +++ b/src/stacking.rs @@ -30,7 +30,7 @@ use crate::imp_prelude::*; /// ); /// ``` #[deprecated( - since = "0.13.0", + since = "0.13.1", note = "Please use the `concatenate` function instead" )] pub fn stack(axis: Axis, arrays: &[ArrayView]) -> Result, ShapeError> @@ -97,6 +97,7 @@ where /// [3., 3.]])) /// ); /// ``` +#[allow(deprecated)] pub fn concatenate(axis: Axis, arrays: &[ArrayView]) -> Result, ShapeError> where A: Copy, diff --git a/tests/stacking.rs b/tests/stacking.rs index f723adc57..27d36d0f3 100644 --- a/tests/stacking.rs +++ b/tests/stacking.rs @@ -1,3 +1,5 @@ +#![allow(deprecated)] + use ndarray::{arr2, arr3, aview1, concatenate, stack, Array2, Axis, ErrorKind, Ix1}; #[test] From 06e61458831fd41db61ee614df6102583b04d1cd Mon Sep 17 00:00:00 2001 From: andrei-papou Date: Fri, 25 Sep 2020 11:46:39 +0300 Subject: [PATCH 7/7] Updated `stack_new_axis`, deprecated `stack!` macro. --- src/stacking.rs | 34 +++++++++++++++++++++++++++++++--- tests/stacking.rs | 8 ++++---- 2 files changed, 35 insertions(+), 7 deletions(-) diff --git a/src/stacking.rs b/src/stacking.rs index 822b4b56d..e31581679 100644 --- a/src/stacking.rs +++ b/src/stacking.rs @@ -30,7 +30,7 @@ use crate::imp_prelude::*; /// ); /// ``` #[deprecated( - since = "0.13.1", + since = "0.13.2", note = "Please use the `concatenate` function instead" )] pub fn stack(axis: Axis, arrays: &[ArrayView]) -> Result, ShapeError> @@ -106,9 +106,33 @@ where stack(axis, arrays) } +/// Stack arrays along the new axis. +/// +/// ***Errors*** if the arrays have mismatching shapes. +/// ***Errors*** if `arrays` is empty, if `axis` is out of bounds, +/// if the result is larger than is possible to represent. +/// +/// ``` +/// extern crate ndarray; +/// +/// use ndarray::{arr2, arr3, stack_new_axis, Axis}; +/// +/// # fn main() { +/// +/// let a = arr2(&[[2., 2.], +/// [3., 3.]]); +/// assert!( +/// stack_new_axis(Axis(0), &[a.view(), a.view()]) +/// == Ok(arr3(&[[[2., 2.], +/// [3., 3.]], +/// [[2., 2.], +/// [3., 3.]]])) +/// ); +/// # } +/// ``` pub fn stack_new_axis( axis: Axis, - arrays: Vec>, + arrays: &[ArrayView], ) -> Result, ShapeError> where A: Copy, @@ -176,6 +200,10 @@ where /// ); /// # } /// ``` +#[deprecated( + since = "0.13.2", + note = "Please use the `concatenate!` macro instead" +)] #[macro_export] macro_rules! stack { ($axis:expr, $( $array:expr ),+ ) => { @@ -247,6 +275,6 @@ macro_rules! concatenate { #[macro_export] macro_rules! stack_new_axis { ($axis:expr, $( $array:expr ),+ ) => { - $crate::stack_new_axis($axis, vec![ $($crate::ArrayView::from(&$array) ),* ]).unwrap() + $crate::stack_new_axis($axis, &[ $($crate::ArrayView::from(&$array) ),* ]).unwrap() } } diff --git a/tests/stacking.rs b/tests/stacking.rs index 27d36d0f3..94077def2 100644 --- a/tests/stacking.rs +++ b/tests/stacking.rs @@ -52,16 +52,16 @@ fn concatenating() { #[test] fn stacking() { let a = arr2(&[[2., 2.], [3., 3.]]); - let b = ndarray::stack_new_axis(Axis(0), vec![a.view(), a.view()]).unwrap(); + let b = ndarray::stack_new_axis(Axis(0), &[a.view(), a.view()]).unwrap(); assert_eq!(b, arr3(&[[[2., 2.], [3., 3.]], [[2., 2.], [3., 3.]]])); let c = arr2(&[[3., 2., 3.], [2., 3., 2.]]); - let res = ndarray::stack_new_axis(Axis(1), vec![a.view(), c.view()]); + let res = ndarray::stack_new_axis(Axis(1), &[a.view(), c.view()]); assert_eq!(res.unwrap_err().kind(), ErrorKind::IncompatibleShape); - let res = ndarray::stack_new_axis(Axis(3), vec![a.view(), a.view()]); + let res = ndarray::stack_new_axis(Axis(3), &[a.view(), a.view()]); assert_eq!(res.unwrap_err().kind(), ErrorKind::OutOfBounds); - let res: Result, _> = ndarray::stack_new_axis::<_, Ix1>(Axis(0), vec![]); + let res: Result, _> = ndarray::stack_new_axis::<_, Ix1>(Axis(0), &[]); assert_eq!(res.unwrap_err().kind(), ErrorKind::Unsupported); }