diff --git a/src/impl_methods.rs b/src/impl_methods.rs index 4202ec257..c9b2d6f49 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -1766,20 +1766,11 @@ where unsafe { Some(ArrayView::new(self.ptr, dim, broadcast_strides)) } } - /// Calculate the views of two ArrayBases after broadcasting each other, if possible. + /// For two arrays or views, find their common shape if possible and + /// broadcast them as array views into that shape. /// /// Return `ShapeError` if their shapes can not be broadcast together. - /// - /// ``` - /// use ndarray::{arr1, arr2}; - /// - /// let a = arr2(&[[2], [3], [4]]); - /// let b = arr1(&[5, 6, 7]); - /// let (a1, b1) = a.broadcast_with(&b).unwrap(); - /// assert_eq!(a1, arr2(&[[2, 2, 2], [3, 3, 3], [4, 4, 4]])); - /// assert_eq!(b1, arr2(&[[5, 6, 7], [5, 6, 7], [5, 6, 7]])); - /// ``` - pub fn broadcast_with<'a, 'b, B, S2, E>(&'a self, other: &'b ArrayBase) -> + pub(crate) fn broadcast_with<'a, 'b, B, S2, E>(&'a self, other: &'b ArrayBase) -> Result<(ArrayView<'a, A, >::Output>, ArrayView<'b, B, >::Output>), ShapeError> where S: Data, diff --git a/tests/broadcast.rs b/tests/broadcast.rs index 26111c780..e3d377139 100644 --- a/tests/broadcast.rs +++ b/tests/broadcast.rs @@ -82,33 +82,3 @@ fn test_broadcast_1d() { println!("b2=\n{:?}", b2); assert_eq!(b0, b2); } - -#[test] -fn test_broadcast_with() { - let a = arr2(&[[1., 2.], [3., 4.]]); - let b = aview0(&1.); - let (a1, b1) = a.broadcast_with(&b).unwrap(); - assert_eq!(a1, arr2(&[[1.0, 2.0], [3.0, 4.0]])); - assert_eq!(b1, arr2(&[[1.0, 1.0], [1.0, 1.0]])); - - let a = arr2(&[[2], [3], [4]]); - let b = arr1(&[5, 6, 7]); - let (a1, b1) = a.broadcast_with(&b).unwrap(); - assert_eq!(a1, arr2(&[[2, 2, 2], [3, 3, 3], [4, 4, 4]])); - assert_eq!(b1, arr2(&[[5, 6, 7], [5, 6, 7], [5, 6, 7]])); - - // Negative strides and non-contiguous memory - let s = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]; - let s = Array3::from_shape_vec((2, 3, 2).strides((1, 4, 2)), s.to_vec()).unwrap(); - let a = s.slice(s![..;-1,..;2,..]); - let b = s.slice(s![..2, -1, ..]); - let (a1, b1) = a.broadcast_with(&b).unwrap(); - assert_eq!(a1, arr3(&[[[2, 4], [10, 12]], [[1, 3], [9, 11]]])); - assert_eq!(b1, arr3(&[[[9, 11], [10, 12]], [[9, 11], [10, 12]]])); - - // ShapeError - let a = arr2(&[[2, 2], [3, 3], [4, 4]]); - let b = arr1(&[5, 6, 7]); - let e = a.broadcast_with(&b); - assert_eq!(e, Err(ShapeError::from_kind(ErrorKind::IncompatibleShape))); -}