diff --git a/src/dimension/broadcast.rs b/src/dimension/broadcast.rs index 6da4375d0..6b13de65d 100644 --- a/src/dimension/broadcast.rs +++ b/src/dimension/broadcast.rs @@ -6,11 +6,11 @@ use crate::{Dimension, Ix0, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn}; /// /// Uses the [NumPy broadcasting rules] // (https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html#general-broadcasting-rules). -fn co_broadcast(shape1: &D1, shape2: &D2) -> Result - where - D1: Dimension, - D2: Dimension, - Output: Dimension, +pub(crate) fn co_broadcast(shape1: &D1, shape2: &D2) -> Result +where + D1: Dimension, + D2: Dimension, + Output: Dimension, { let (k, overflow) = shape1.ndim().overflowing_sub(shape2.ndim()); // Swap the order if d2 is longer. @@ -37,11 +37,6 @@ fn co_broadcast(shape1: &D1, shape2: &D2) -> Result { /// The resulting dimension type after broadcasting. type Output: Dimension; - - /// Determines the shape after broadcasting the shapes together. - /// - /// If the shapes are not compatible, returns `Err`. - fn broadcast_shape(&self, other: &Other) -> Result; } /// Dimensions of the same type remain unchanged when co_broadcast. @@ -49,28 +44,16 @@ pub trait DimMax { /// (Instead of >::BroadcastOutput) impl DimMax for D { type Output = D; - - fn broadcast_shape(&self, other: &D) -> Result { - co_broadcast::(self, other) - } } macro_rules! impl_broadcast_distinct_fixed { ($smaller:ty, $larger:ty) => { impl DimMax<$larger> for $smaller { type Output = $larger; - - fn broadcast_shape(&self, other: &$larger) -> Result { - co_broadcast::(self, other) - } } impl DimMax<$smaller> for $larger { type Output = $larger; - - fn broadcast_shape(&self, other: &$smaller) -> Result { - co_broadcast::(self, other) - } } }; } @@ -103,3 +86,57 @@ impl_broadcast_distinct_fixed!(Ix3, IxDyn); impl_broadcast_distinct_fixed!(Ix4, IxDyn); impl_broadcast_distinct_fixed!(Ix5, IxDyn); impl_broadcast_distinct_fixed!(Ix6, IxDyn); + + +#[cfg(test)] +mod tests { + use super::co_broadcast; + use crate::{Dimension, Dim, DimMax, ShapeError, Ix0, IxDynImpl, ErrorKind}; + + #[test] + fn test_broadcast_shape() { + fn test_co( + d1: &D1, + d2: &D2, + r: Result<>::Output, ShapeError>, + ) where + D1: Dimension + DimMax, + D2: Dimension, + { + let d = co_broadcast::>::Output>(&d1, d2); + assert_eq!(d, r); + } + test_co(&Dim([2, 3]), &Dim([4, 1, 3]), Ok(Dim([4, 2, 3]))); + test_co( + &Dim([1, 2, 2]), + &Dim([1, 3, 4]), + Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)), + ); + test_co(&Dim([3, 4, 5]), &Ix0(), Ok(Dim([3, 4, 5]))); + let v = vec![1, 2, 3, 4, 5, 6, 7]; + test_co( + &Dim(vec![1, 1, 3, 1, 5, 1, 7]), + &Dim([2, 1, 4, 1, 6, 1]), + Ok(Dim(IxDynImpl::from(v.as_slice()))), + ); + let d = Dim([1, 2, 1, 3]); + test_co(&d, &d, Ok(d)); + test_co( + &Dim([2, 1, 2]).into_dyn(), + &Dim(0), + Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)), + ); + test_co( + &Dim([2, 1, 1]), + &Dim([0, 0, 1, 3, 4]), + Ok(Dim([0, 0, 2, 3, 4])), + ); + test_co(&Dim([0]), &Dim([0, 0, 0]), Ok(Dim([0, 0, 0]))); + test_co(&Dim(1), &Dim([1, 0, 0]), Ok(Dim([1, 0, 0]))); + test_co( + &Dim([1, 3, 0, 1, 1]), + &Dim([1, 2, 3, 1]), + Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)), + ); + } +} diff --git a/src/dimension/mod.rs b/src/dimension/mod.rs index 7f5eeeaf7..2505681b5 100644 --- a/src/dimension/mod.rs +++ b/src/dimension/mod.rs @@ -29,7 +29,7 @@ use std::mem; mod macros; mod axes; mod axis; -mod broadcast; +pub(crate) mod broadcast; mod conversion; pub mod dim; mod dimension_trait; diff --git a/src/impl_methods.rs b/src/impl_methods.rs index 3079c9d2a..40c7fe1f2 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -21,6 +21,7 @@ use crate::dimension::{ abs_index, axes_of, do_slice, merge_axes, move_min_stride_axis_to_last, offset_from_ptr_to_memory, size_of_shape_checked, stride_offset, Axes, }; +use crate::dimension::broadcast::co_broadcast; use crate::error::{self, ErrorKind, ShapeError, from_kind}; use crate::math_cell::MathCell; use crate::itertools::zip; @@ -1778,7 +1779,7 @@ where D: Dimension + DimMax, E: Dimension, { - let shape = self.dim.broadcast_shape(&other.dim)?; + let shape = co_broadcast::>::Output>(&self.dim, &other.dim)?; if let Some(view1) = self.broadcast(shape.clone()) { if let Some(view2) = other.broadcast(shape) { return Ok((view1, view2)) diff --git a/tests/dimension.rs b/tests/dimension.rs index 2bbc50a68..939b4f0e3 100644 --- a/tests/dimension.rs +++ b/tests/dimension.rs @@ -2,8 +2,7 @@ use defmac::defmac; -use ndarray::{arr2, ArcArray, Array, Axis, Dim, Dimension, Ix0, IxDyn, IxDynImpl, RemoveAxis, - ErrorKind, ShapeError, DimMax}; +use ndarray::{arr2, ArcArray, Array, Axis, Dim, Dimension, IxDyn, RemoveAxis}; use std::hash::{Hash, Hasher}; @@ -341,50 +340,3 @@ fn test_all_ndindex() { ndindex!(10, 4, 3, 2, 2); ndindex!(10, 4, 3, 2, 2, 2); } - -#[test] -fn test_broadcast_shape() { - fn test_co( - d1: &D1, - d2: &D2, - r: Result<>::Output, ShapeError>, - ) where - D1: Dimension + DimMax, - D2: Dimension, - { - let d = d1.broadcast_shape(d2); - assert_eq!(d, r); - } - test_co(&Dim([2, 3]), &Dim([4, 1, 3]), Ok(Dim([4, 2, 3]))); - test_co( - &Dim([1, 2, 2]), - &Dim([1, 3, 4]), - Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)), - ); - test_co(&Dim([3, 4, 5]), &Ix0(), Ok(Dim([3, 4, 5]))); - let v = vec![1, 2, 3, 4, 5, 6, 7]; - test_co( - &Dim(vec![1, 1, 3, 1, 5, 1, 7]), - &Dim([2, 1, 4, 1, 6, 1]), - Ok(Dim(IxDynImpl::from(v.as_slice()))), - ); - let d = Dim([1, 2, 1, 3]); - test_co(&d, &d, Ok(d)); - test_co( - &Dim([2, 1, 2]).into_dyn(), - &Dim(0), - Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)), - ); - test_co( - &Dim([2, 1, 1]), - &Dim([0, 0, 1, 3, 4]), - Ok(Dim([0, 0, 2, 3, 4])), - ); - test_co(&Dim([0]), &Dim([0, 0, 0]), Ok(Dim([0, 0, 0]))); - test_co(&Dim(1), &Dim([1, 0, 0]), Ok(Dim([1, 0, 0]))); - test_co( - &Dim([1, 3, 0, 1, 1]), - &Dim([1, 2, 3, 1]), - Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)), - ); -}