From b6232398675b00af31ae8870ce045785fc332cc2 Mon Sep 17 00:00:00 2001 From: SparrowLii Date: Thu, 18 Feb 2021 21:03:57 +0800 Subject: [PATCH] Add function broadcast_with --- src/impl_methods.rs | 34 ++++++++++++++++++++++++++++++++-- src/impl_ops.rs | 12 +++--------- tests/broadcast.rs | 31 +++++++++++++++++++++++++++++++ 3 files changed, 66 insertions(+), 11 deletions(-) diff --git a/src/impl_methods.rs b/src/impl_methods.rs index 9a1e0fac8..4202ec257 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -14,14 +14,14 @@ use rawpointer::PointerExt; use crate::imp_prelude::*; -use crate::arraytraits; +use crate::{arraytraits, BroadcastShape}; use crate::dimension; use crate::dimension::IntoDimension; use crate::dimension::{ abs_index, axes_of, do_slice, merge_axes, move_min_stride_axis_to_last, offset_from_ptr_to_memory, size_of_shape_checked, stride_offset, Axes, }; -use crate::error::{self, ErrorKind, ShapeError}; +use crate::error::{self, ErrorKind, ShapeError, from_kind}; use crate::math_cell::MathCell; use crate::itertools::zip; use crate::zip::Zip; @@ -1766,6 +1766,36 @@ where unsafe { Some(ArrayView::new(self.ptr, dim, broadcast_strides)) } } + /// Calculate the views of two ArrayBases after broadcasting each other, if possible. + /// + /// Return `ShapeError` if their shapes can not be broadcast together. + /// + /// ``` + /// use ndarray::{arr1, arr2}; + /// + /// let a = arr2(&[[2], [3], [4]]); + /// let b = arr1(&[5, 6, 7]); + /// let (a1, b1) = a.broadcast_with(&b).unwrap(); + /// assert_eq!(a1, arr2(&[[2, 2, 2], [3, 3, 3], [4, 4, 4]])); + /// assert_eq!(b1, arr2(&[[5, 6, 7], [5, 6, 7], [5, 6, 7]])); + /// ``` + pub fn broadcast_with<'a, 'b, B, S2, E>(&'a self, other: &'b ArrayBase) -> + Result<(ArrayView<'a, A, >::Output>, ArrayView<'b, B, >::Output>), ShapeError> + where + S: Data, + S2: Data, + D: Dimension + BroadcastShape, + E: Dimension, + { + let shape = self.dim.broadcast_shape(&other.dim)?; + if let Some(view1) = self.broadcast(shape.clone()) { + if let Some(view2) = other.broadcast(shape) { + return Ok((view1, view2)) + } + } + return Err(from_kind(ErrorKind::IncompatibleShape)); + } + /// Swap axes `ax` and `bx`. /// /// This does not move any data, it just adjusts the array’s dimensions diff --git a/src/impl_ops.rs b/src/impl_ops.rs index 69c6f698e..80aa8e11a 100644 --- a/src/impl_ops.rs +++ b/src/impl_ops.rs @@ -106,9 +106,7 @@ where out.zip_mut_with_same_shape(rhs, clone_iopf(A::$mth)); out } else { - let shape = self.dim.broadcast_shape(&rhs.dim).unwrap(); - let lhs = self.broadcast(shape.clone()).unwrap(); - let rhs = rhs.broadcast(shape).unwrap(); + let (lhs, rhs) = self.broadcast_with(rhs).unwrap(); Zip::from(&lhs).and(&rhs).map_collect_owned(clone_opf(A::$mth)) } } @@ -143,9 +141,7 @@ where out.zip_mut_with_same_shape(self, clone_iopf_rev(A::$mth)); out } else { - let shape = rhs.dim.broadcast_shape(&self.dim).unwrap(); - let lhs = self.broadcast(shape.clone()).unwrap(); - let rhs = rhs.broadcast(shape).unwrap(); + let (rhs, lhs) = rhs.broadcast_with(self).unwrap(); Zip::from(&lhs).and(&rhs).map_collect_owned(clone_opf(A::$mth)) } } @@ -171,9 +167,7 @@ where { type Output = Array>::Output>; fn $mth(self, rhs: &'a ArrayBase) -> Self::Output { - let shape = self.dim.broadcast_shape(&rhs.dim).unwrap(); - let lhs = self.broadcast(shape.clone()).unwrap(); - let rhs = rhs.broadcast(shape).unwrap(); + let (lhs, rhs) = self.broadcast_with(rhs).unwrap(); Zip::from(&lhs).and(&rhs).map_collect(clone_opf(A::$mth)) } } diff --git a/tests/broadcast.rs b/tests/broadcast.rs index 5416e9017..26111c780 100644 --- a/tests/broadcast.rs +++ b/tests/broadcast.rs @@ -1,4 +1,5 @@ use ndarray::prelude::*; +use ndarray::{ShapeError, ErrorKind, arr3}; #[test] #[cfg(feature = "std")] @@ -81,3 +82,33 @@ fn test_broadcast_1d() { println!("b2=\n{:?}", b2); assert_eq!(b0, b2); } + +#[test] +fn test_broadcast_with() { + let a = arr2(&[[1., 2.], [3., 4.]]); + let b = aview0(&1.); + let (a1, b1) = a.broadcast_with(&b).unwrap(); + assert_eq!(a1, arr2(&[[1.0, 2.0], [3.0, 4.0]])); + assert_eq!(b1, arr2(&[[1.0, 1.0], [1.0, 1.0]])); + + let a = arr2(&[[2], [3], [4]]); + let b = arr1(&[5, 6, 7]); + let (a1, b1) = a.broadcast_with(&b).unwrap(); + assert_eq!(a1, arr2(&[[2, 2, 2], [3, 3, 3], [4, 4, 4]])); + assert_eq!(b1, arr2(&[[5, 6, 7], [5, 6, 7], [5, 6, 7]])); + + // Negative strides and non-contiguous memory + let s = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]; + let s = Array3::from_shape_vec((2, 3, 2).strides((1, 4, 2)), s.to_vec()).unwrap(); + let a = s.slice(s![..;-1,..;2,..]); + let b = s.slice(s![..2, -1, ..]); + let (a1, b1) = a.broadcast_with(&b).unwrap(); + assert_eq!(a1, arr3(&[[[2, 4], [10, 12]], [[1, 3], [9, 11]]])); + assert_eq!(b1, arr3(&[[[9, 11], [10, 12]], [[9, 11], [10, 12]]])); + + // ShapeError + let a = arr2(&[[2, 2], [3, 3], [4, 4]]); + let b = arr1(&[5, 6, 7]); + let e = a.broadcast_with(&b); + assert_eq!(e, Err(ShapeError::from_kind(ErrorKind::IncompatibleShape))); +}