From f7a94916e7e496c99c00e8535077d8ce18bd0e72 Mon Sep 17 00:00:00 2001 From: SparrowLii Date: Wed, 20 Jan 2021 11:29:12 +0800 Subject: [PATCH 01/16] Fix co_broadcast in operator overloading --- src/dimension/broadcast.rs | 105 ++++++++++++++++++++++++++++++++++++ src/dimension/mod.rs | 2 + src/impl_ops.rs | 85 +++++++++++++++++++++-------- src/lib.rs | 1 + src/numeric/impl_numeric.rs | 5 +- tests/array.rs | 94 +++++++++++++++++++++++++++++++- 6 files changed, 266 insertions(+), 26 deletions(-) create mode 100644 src/dimension/broadcast.rs diff --git a/src/dimension/broadcast.rs b/src/dimension/broadcast.rs new file mode 100644 index 000000000..f29936cd1 --- /dev/null +++ b/src/dimension/broadcast.rs @@ -0,0 +1,105 @@ +use crate::error::*; +use crate::{Dimension, Ix0, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn}; + +/// Calculate the co_broadcast shape of two dimensions. Return error if shapes are +/// not compatible. +fn broadcast_shape(shape1: &D1, shape2: &D2) -> Result +where + D1: Dimension, + D2: Dimension, + Output: Dimension, +{ + let (k, overflow) = shape1.ndim().overflowing_sub(shape2.ndim()); + // Swap the order if d2 is longer. + if overflow { + return broadcast_shape::(shape2, shape1); + } + // The output should be the same length as shape1. + let mut out = Output::zeros(shape1.ndim()); + let out_slice = out.slice_mut(); + let s1 = shape1.slice(); + let s2 = shape2.slice(); + // Uses the [NumPy broadcasting rules] + // (https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html#general-broadcasting-rules). + // + // Zero dimension element is not in the original rules of broadcasting. + // We currently treat it as the same as 1. Especially, when one side is + // zero with one side is empty, or both sides are zero, the result will + // remain zero. + for i in 0..shape1.ndim() { + out_slice[i] = s1[i]; + } + for i in 0..shape2.ndim() { + if out_slice[i + k] != s2[i] && s2[i] != 0 { + if out_slice[i + k] <= 1 { + out_slice[i + k] = s2[i] + } else if s2[i] != 1 { + return Err(from_kind(ErrorKind::IncompatibleShape)); + } + } + } + Ok(out) +} + +pub trait BroadcastShape: Dimension { + /// The resulting dimension type after broadcasting. + type BroadcastOutput: Dimension; + + /// Determines the shape after broadcasting the dimensions together. + /// + /// If the dimensions are not compatible, returns `Err`. + /// + /// Uses the [NumPy broadcasting rules] + /// (https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html#general-broadcasting-rules). + fn broadcast_shape(&self, other: &Other) -> Result { + broadcast_shape::(self, other) + } +} + +/// Dimensions of the same type remain unchanged when co_broadcast. +/// So you can directly use D as the resulting type. +/// (Instead of >::BroadcastOutput) +impl BroadcastShape for D { + type BroadcastOutput = D; +} + +macro_rules! impl_broadcast_distinct_fixed { + ($smaller:ty, $larger:ty) => { + impl BroadcastShape<$larger> for $smaller { + type BroadcastOutput = $larger; + } + + impl BroadcastShape<$smaller> for $larger { + type BroadcastOutput = $larger; + } + }; +} + +impl_broadcast_distinct_fixed!(Ix0, Ix1); +impl_broadcast_distinct_fixed!(Ix0, Ix2); +impl_broadcast_distinct_fixed!(Ix0, Ix3); +impl_broadcast_distinct_fixed!(Ix0, Ix4); +impl_broadcast_distinct_fixed!(Ix0, Ix5); +impl_broadcast_distinct_fixed!(Ix0, Ix6); +impl_broadcast_distinct_fixed!(Ix1, Ix2); +impl_broadcast_distinct_fixed!(Ix1, Ix3); +impl_broadcast_distinct_fixed!(Ix1, Ix4); +impl_broadcast_distinct_fixed!(Ix1, Ix5); +impl_broadcast_distinct_fixed!(Ix1, Ix6); +impl_broadcast_distinct_fixed!(Ix2, Ix3); +impl_broadcast_distinct_fixed!(Ix2, Ix4); +impl_broadcast_distinct_fixed!(Ix2, Ix5); +impl_broadcast_distinct_fixed!(Ix2, Ix6); +impl_broadcast_distinct_fixed!(Ix3, Ix4); +impl_broadcast_distinct_fixed!(Ix3, Ix5); +impl_broadcast_distinct_fixed!(Ix3, Ix6); +impl_broadcast_distinct_fixed!(Ix4, Ix5); +impl_broadcast_distinct_fixed!(Ix4, Ix6); +impl_broadcast_distinct_fixed!(Ix5, Ix6); +impl_broadcast_distinct_fixed!(Ix0, IxDyn); +impl_broadcast_distinct_fixed!(Ix1, IxDyn); +impl_broadcast_distinct_fixed!(Ix2, IxDyn); +impl_broadcast_distinct_fixed!(Ix3, IxDyn); +impl_broadcast_distinct_fixed!(Ix4, IxDyn); +impl_broadcast_distinct_fixed!(Ix5, IxDyn); +impl_broadcast_distinct_fixed!(Ix6, IxDyn); diff --git a/src/dimension/mod.rs b/src/dimension/mod.rs index 1359b8f39..98572ac59 100644 --- a/src/dimension/mod.rs +++ b/src/dimension/mod.rs @@ -12,6 +12,7 @@ use num_integer::div_floor; pub use self::axes::{axes_of, Axes, AxisDescription}; pub use self::axis::Axis; +pub use self::broadcast::BroadcastShape; pub use self::conversion::IntoDimension; pub use self::dim::*; pub use self::dimension_trait::Dimension; @@ -28,6 +29,7 @@ use std::mem; mod macros; mod axes; mod axis; +mod broadcast; mod conversion; pub mod dim; mod dimension_trait; diff --git a/src/impl_ops.rs b/src/impl_ops.rs index 256bee3e5..d7645b8cb 100644 --- a/src/impl_ops.rs +++ b/src/impl_ops.rs @@ -6,6 +6,7 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. +use crate::dimension::BroadcastShape; use num_complex::Complex; /// Elements that can be used as direct operands in arithmetic with arrays. @@ -53,24 +54,48 @@ macro_rules! impl_binary_op( /// Perform elementwise #[doc=$doc] /// between `self` and `rhs`, -/// and return the result (based on `self`). -/// -/// `self` must be an `Array` or `ArcArray`. +/// and return the result. /// -/// If their shapes disagree, `rhs` is broadcast to the shape of `self`. +/// If their shapes disagree, `self` is broadcast to their broadcast shape, +/// cloning the data if needed. /// /// **Panics** if broadcasting isn’t possible. impl $trt> for ArrayBase where A: Clone + $trt, B: Clone, - S: DataOwned + DataMut, + S: Data, S2: Data, - D: Dimension, + D: Dimension + BroadcastShape, E: Dimension, { - type Output = ArrayBase; - fn $mth(self, rhs: ArrayBase) -> ArrayBase + type Output = Array>::BroadcastOutput>; + fn $mth(self, rhs: ArrayBase) -> Self::Output + { + self.$mth(&rhs) + } +} + +/// Perform elementwise +#[doc=$doc] +/// between reference `self` and `rhs`, +/// and return the result as a new `Array`. +/// +/// If their shapes disagree, `self` is broadcast to their broadcast shape, +/// cloning the data if needed. +/// +/// **Panics** if broadcasting isn’t possible. +impl<'a, A, B, S, S2, D, E> $trt> for &'a ArrayBase +where + A: Clone + $trt, + B: Clone, + S: Data, + S2: Data, + D: Dimension + BroadcastShape, + E: Dimension, +{ + type Output = Array>::BroadcastOutput>; + fn $mth(self, rhs: ArrayBase) -> Self::Output { self.$mth(&rhs) } @@ -79,27 +104,34 @@ where /// Perform elementwise #[doc=$doc] /// between `self` and reference `rhs`, -/// and return the result (based on `self`). +/// and return the result. /// -/// If their shapes disagree, `rhs` is broadcast to the shape of `self`. +/// If their shapes disagree, `self` is broadcast to their broadcast shape, +/// cloning the data if needed. /// /// **Panics** if broadcasting isn’t possible. impl<'a, A, B, S, S2, D, E> $trt<&'a ArrayBase> for ArrayBase where A: Clone + $trt, B: Clone, - S: DataOwned + DataMut, + S: Data, S2: Data, - D: Dimension, + D: Dimension + BroadcastShape, E: Dimension, { - type Output = ArrayBase; - fn $mth(mut self, rhs: &ArrayBase) -> ArrayBase + type Output = Array>::BroadcastOutput>; + fn $mth(self, rhs: &ArrayBase) -> Self::Output { - self.zip_mut_with(rhs, |x, y| { + let shape = self.dim.broadcast_shape(&rhs.dim).unwrap(); + let mut self_ = if shape.slice() == self.dim.slice() { + self.into_owned().into_dimensionality::<>::BroadcastOutput>().unwrap() + } else { + self.broadcast(shape).unwrap().to_owned() + }; + self_.zip_mut_with(rhs, |x, y| { *x = x.clone() $operator y.clone(); }); - self + self_ } } @@ -108,7 +140,8 @@ where /// between references `self` and `rhs`, /// and return the result as a new `Array`. /// -/// If their shapes disagree, `rhs` is broadcast to the shape of `self`. +/// If their shapes disagree, `self` is broadcast to their broadcast shape, +/// cloning the data if needed. /// /// **Panics** if broadcasting isn’t possible. impl<'a, A, B, S, S2, D, E> $trt<&'a ArrayBase> for &'a ArrayBase @@ -117,13 +150,21 @@ where B: Clone, S: Data, S2: Data, - D: Dimension, + D: Dimension + BroadcastShape, E: Dimension, { - type Output = Array; - fn $mth(self, rhs: &'a ArrayBase) -> Array { - // FIXME: Can we co-broadcast arrays here? And how? - self.to_owned().$mth(rhs) + type Output = Array>::BroadcastOutput>; + fn $mth(self, rhs: &'a ArrayBase) -> Self::Output { + let shape = self.dim.broadcast_shape(&rhs.dim).unwrap(); + let mut self_ = if shape.slice() == self.dim.slice() { + self.to_owned().into_dimensionality::<>::BroadcastOutput>().unwrap() + } else { + self.broadcast(shape).unwrap().to_owned() + }; + self_.zip_mut_with(rhs, |x, y| { + *x = x.clone() $operator y.clone(); + }); + self_ } } diff --git a/src/lib.rs b/src/lib.rs index 9cd7dc3f3..1a760bf31 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -134,6 +134,7 @@ use std::marker::PhantomData; use alloc::sync::Arc; pub use crate::dimension::dim::*; +pub use crate::dimension::BroadcastShape; pub use crate::dimension::{Axis, AxisDescription, Dimension, IntoDimension, RemoveAxis}; pub use crate::dimension::IxDynImpl; diff --git a/src/numeric/impl_numeric.rs b/src/numeric/impl_numeric.rs index fbec80418..5485a774e 100644 --- a/src/numeric/impl_numeric.rs +++ b/src/numeric/impl_numeric.rs @@ -13,7 +13,7 @@ use std::ops::{Add, Div, Mul}; use crate::imp_prelude::*; use crate::itertools::enumerate; -use crate::numeric_util; +use crate::{numeric_util, BroadcastShape}; /// # Numerical Methods for Arrays impl ArrayBase @@ -283,10 +283,11 @@ where /// a.mean_axis(Axis(0)).unwrap().mean_axis(Axis(0)).unwrap() == aview0(&3.5) /// ); /// ``` - pub fn mean_axis(&self, axis: Axis) -> Option> + pub fn mean_axis(&self, axis: Axis) -> Option>::BroadcastOutput>> where A: Clone + Zero + FromPrimitive + Add + Div, D: RemoveAxis, + D::Smaller: BroadcastShape, { let axis_length = self.len_of(axis); if axis_length == 0 { diff --git a/tests/array.rs b/tests/array.rs index 6b72bb5c4..599504d3d 100644 --- a/tests/array.rs +++ b/tests/array.rs @@ -9,9 +9,9 @@ use defmac::defmac; use itertools::{enumerate, zip, Itertools}; -use ndarray::indices; use ndarray::prelude::*; use ndarray::{arr3, rcarr2}; +use ndarray::{indices, BroadcastShape, ErrorKind, IxDynImpl, ShapeError}; use ndarray::{Slice, SliceInfo, SliceOrIndex}; macro_rules! assert_panics { @@ -484,7 +484,7 @@ fn test_add() { } let B = A.clone(); - A = A + &B; + let A = A + &B; assert_eq!(A[[0, 0]], 0); assert_eq!(A[[0, 1]], 2); assert_eq!(A[[1, 0]], 4); @@ -1557,6 +1557,53 @@ fn insert_axis_view() { ); } +#[test] +fn test_broadcast_shape() { + fn test_co( + d1: &D1, + d2: &D2, + r: Result<>::BroadcastOutput, ShapeError>, + ) where + D1: Dimension + BroadcastShape, + D2: Dimension, + { + let d = d1.broadcast_shape(&d2); + assert_eq!(d, r); + } + test_co(&Dim([2, 3]), &Dim([4, 1, 3]), Ok(Dim([4, 2, 3]))); + test_co( + &Dim([1, 2, 2]), + &Dim([1, 3, 4]), + Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)), + ); + test_co(&Dim([3, 4, 5]), &Ix0(), Ok(Dim([3, 4, 5]))); + let v = vec![1, 2, 3, 4, 5, 6, 7]; + test_co( + &Dim(vec![1, 1, 3, 1, 5, 1, 7]), + &Dim([2, 1, 4, 1, 6, 1]), + Ok(Dim(IxDynImpl::from(v.as_slice()))), + ); + let d = Dim([1, 2, 1, 3]); + test_co(&d, &d, Ok(d)); + test_co( + &Dim([2, 1, 2]).into_dyn(), + &Dim(0), + Ok(Dim([2, 1, 2]).into_dyn()), + ); + test_co( + &Dim([2, 1, 1]), + &Dim([0, 0, 0, 3, 4]), + Ok(Dim([0, 0, 2, 3, 4])), + ); + test_co(&Dim([0]), &Dim([0, 0, 0]), Ok(Dim([0, 0, 0]))); + test_co(&Dim(1), &Dim([1, 0, 0]), Ok(Dim([1, 0, 1]))); + test_co( + &Dim([1, 3, 0, 1, 1]), + &Dim([1, 2, 3, 1]), + Ok(Dim([1, 3, 2, 3, 1])), + ); +} + #[test] fn arithmetic_broadcast() { let mut a = arr2(&[[1., 2.], [3., 4.]]); @@ -1565,6 +1612,49 @@ fn arithmetic_broadcast() { a.swap_axes(0, 1); let b = a.clone() / aview0(&1.); assert_eq!(a, b); + + // reference + let a = arr2(&[[2], [3], [4]]); + let b = arr1(&[5, 6, 7]); + assert_eq!(&a + &b, arr2(&[[7, 8, 9], [8, 9, 10], [9, 10, 11]])); + assert_eq!( + a.clone() - &b, + arr2(&[[-3, -4, -5], [-2, -3, -4], [-1, -2, -3]]) + ); + assert_eq!( + a.clone() * b.clone(), + arr2(&[[10, 12, 14], [15, 18, 21], [20, 24, 28]]) + ); + assert_eq!(&b / a, arr2(&[[2, 3, 3], [1, 2, 2], [1, 1, 1]])); + + // Negative strides and non-contiguous memory + let s = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]; + let s = Array3::from_shape_vec((2, 3, 2).strides((1, 4, 2)), s.to_vec()).unwrap(); + let a = s.slice(s![..;-1,..;2,..]); + let b = s.slice(s![..2, -1, ..]); + let mut c = s.clone(); + c.collapse_axis(Axis(2), 1); + let c = c.slice(s![1,..;2,..]); + assert_eq!( + &a.to_owned() + &b, + arr3(&[[[11, 15], [20, 24]], [[10, 14], [19, 23]]]) + ); + assert_eq!( + &a + b + c.into_owned(), + arr3(&[[[15, 19], [32, 36]], [[14, 18], [31, 35]]]) + ); + + // shared array + let sa = a.to_shared(); + let sa2 = sa.to_shared(); + let sb = b.to_shared(); + let sb2 = sb.to_shared(); + let sc = c.to_shared(); + let sc2 = sc.into_shared(); + assert_eq!( + sa2 + sb2 + sc2.into_owned(), + arr3(&[[[15, 19], [32, 36]], [[14, 18], [31, 35]]]) + ); } #[test] From 5e77eed8701b27c5edab6d9791af0f359de390c7 Mon Sep 17 00:00:00 2001 From: SparrowLii Date: Sat, 23 Jan 2021 19:19:04 +0800 Subject: [PATCH 02/16] Add BroadCastShape to the Dimension definition --- src/dimension/broadcast.rs | 18 ++++++++++++++---- src/dimension/dimension_trait.rs | 5 ++++- src/numeric/impl_numeric.rs | 5 ++--- tests/array.rs | 2 +- 4 files changed, 21 insertions(+), 9 deletions(-) diff --git a/src/dimension/broadcast.rs b/src/dimension/broadcast.rs index f29936cd1..15d656f7f 100644 --- a/src/dimension/broadcast.rs +++ b/src/dimension/broadcast.rs @@ -41,7 +41,7 @@ where Ok(out) } -pub trait BroadcastShape: Dimension { +pub trait BroadcastShape { /// The resulting dimension type after broadcasting. type BroadcastOutput: Dimension; @@ -51,9 +51,7 @@ pub trait BroadcastShape: Dimension { /// /// Uses the [NumPy broadcasting rules] /// (https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html#general-broadcasting-rules). - fn broadcast_shape(&self, other: &Other) -> Result { - broadcast_shape::(self, other) - } + fn broadcast_shape(&self, other: &Other) -> Result; } /// Dimensions of the same type remain unchanged when co_broadcast. @@ -61,16 +59,28 @@ pub trait BroadcastShape: Dimension { /// (Instead of >::BroadcastOutput) impl BroadcastShape for D { type BroadcastOutput = D; + + fn broadcast_shape(&self, other: &D) -> Result { + broadcast_shape::(self, other) + } } macro_rules! impl_broadcast_distinct_fixed { ($smaller:ty, $larger:ty) => { impl BroadcastShape<$larger> for $smaller { type BroadcastOutput = $larger; + + fn broadcast_shape(&self, other: &$larger) -> Result { + broadcast_shape::(self, other) + } } impl BroadcastShape<$smaller> for $larger { type BroadcastOutput = $larger; + + fn broadcast_shape(&self, other: &$smaller) -> Result { + broadcast_shape::(self, other) + } } }; } diff --git a/src/dimension/dimension_trait.rs b/src/dimension/dimension_trait.rs index 92f241189..6c2893e4d 100644 --- a/src/dimension/dimension_trait.rs +++ b/src/dimension/dimension_trait.rs @@ -15,7 +15,7 @@ use super::axes_of; use super::conversion::Convert; use super::{stride_offset, stride_offset_checked}; use crate::itertools::{enumerate, zip}; -use crate::Axis; +use crate::{Axis, BroadcastShape}; use crate::IntoDimension; use crate::RemoveAxis; use crate::{ArrayView1, ArrayViewMut1}; @@ -46,6 +46,9 @@ pub trait Dimension: + MulAssign + for<'x> MulAssign<&'x Self> + MulAssign + + BroadcastShape + + BroadcastShape + + BroadcastShape { /// For fixed-size dimension representations (e.g. `Ix2`), this should be /// `Some(ndim)`, and for variable-size dimension representations (e.g. diff --git a/src/numeric/impl_numeric.rs b/src/numeric/impl_numeric.rs index 5485a774e..fbec80418 100644 --- a/src/numeric/impl_numeric.rs +++ b/src/numeric/impl_numeric.rs @@ -13,7 +13,7 @@ use std::ops::{Add, Div, Mul}; use crate::imp_prelude::*; use crate::itertools::enumerate; -use crate::{numeric_util, BroadcastShape}; +use crate::numeric_util; /// # Numerical Methods for Arrays impl ArrayBase @@ -283,11 +283,10 @@ where /// a.mean_axis(Axis(0)).unwrap().mean_axis(Axis(0)).unwrap() == aview0(&3.5) /// ); /// ``` - pub fn mean_axis(&self, axis: Axis) -> Option>::BroadcastOutput>> + pub fn mean_axis(&self, axis: Axis) -> Option> where A: Clone + Zero + FromPrimitive + Add + Div, D: RemoveAxis, - D::Smaller: BroadcastShape, { let axis_length = self.len_of(axis); if axis_length == 0 { diff --git a/tests/array.rs b/tests/array.rs index 599504d3d..a49ae65b3 100644 --- a/tests/array.rs +++ b/tests/array.rs @@ -1567,7 +1567,7 @@ fn test_broadcast_shape() { D1: Dimension + BroadcastShape, D2: Dimension, { - let d = d1.broadcast_shape(&d2); + let d = d1.broadcast_shape(d2); assert_eq!(d, r); } test_co(&Dim([2, 3]), &Dim([4, 1, 3]), Ok(Dim([4, 2, 3]))); From e8b29e1c3d7299bea394021cb6f2bf4ab9bb36a7 Mon Sep 17 00:00:00 2001 From: SparrowLii Date: Sat, 23 Jan 2021 23:15:48 +0800 Subject: [PATCH 03/16] add BroadCastShape<::Smaller> --- src/dimension/dimension_trait.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/dimension/dimension_trait.rs b/src/dimension/dimension_trait.rs index 6c2893e4d..89d106521 100644 --- a/src/dimension/dimension_trait.rs +++ b/src/dimension/dimension_trait.rs @@ -49,6 +49,8 @@ pub trait Dimension: + BroadcastShape + BroadcastShape + BroadcastShape + + BroadcastShape<::Smaller, BroadcastOutput=Self> + + BroadcastShape<::Larger, BroadcastOutput=::Larger> { /// For fixed-size dimension representations (e.g. `Ix2`), this should be /// `Some(ndim)`, and for variable-size dimension representations (e.g. From f7c9da16b483d6b4909f43eb055acf1a896e96ab Mon Sep 17 00:00:00 2001 From: SparrowLii Date: Mon, 25 Jan 2021 16:47:42 +0800 Subject: [PATCH 04/16] use uninitialized to avoid traversing or cloning --- src/impl_methods.rs | 40 +++++++++++++++ src/impl_ops.rs | 99 +++++++++++++++++++++++++------------ src/numeric/impl_numeric.rs | 4 +- tests/array.rs | 6 +-- 4 files changed, 113 insertions(+), 36 deletions(-) diff --git a/src/impl_methods.rs b/src/impl_methods.rs index 9943e7c8e..e7d5e7328 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -2037,6 +2037,46 @@ where self.zip_mut_with_by_rows(rhs, f); } + /// Traverse two arrays in unspecified order, in lock step, + /// calling the closure `f` on each element pair, and put + /// the result into the corresponding element of self. + pub fn zip_mut_from_pair(&mut self, lhs: &ArrayBase, rhs: &ArrayBase, f: F, ) + where + S: DataMut, + S1: Data, + S2: Data, + F: Fn(&B, &C) -> A, + { + debug_assert_eq!(self.shape(), lhs.shape()); + debug_assert_eq!(self.shape(), rhs.shape()); + + if self.dim.strides_equivalent(&self.strides, &lhs.strides) + && self.dim.strides_equivalent(&self.strides, &rhs.strides) + { + if let Some(self_s) = self.as_slice_memory_order_mut() { + if let Some(lhs_s) = lhs.as_slice_memory_order() { + if let Some(rhs_s) = rhs.as_slice_memory_order() { + for (s, (l, r)) in + self_s.iter_mut().zip(lhs_s.iter().zip(rhs_s)) { + *s = f(&l, &r); + } + return; + } + } + } + } + + // Otherwise, fall back to the outer iter + let n = self.ndim(); + let dim = self.raw_dim(); + Zip::from(LanesMut::new(self.view_mut(), Axis(n - 1))) + .and(Lanes::new(lhs.broadcast_assume(dim.clone()), Axis(n - 1))) + .and(Lanes::new(rhs.broadcast_assume(dim), Axis(n - 1))) + .for_each(move |s_row, l_row, r_row| { + Zip::from(s_row).and(l_row).and(r_row).for_each(|s, a, b| *s = f(a, b)) + }); + } + // zip two arrays where they have different layout or strides #[inline(always)] fn zip_mut_with_by_rows(&mut self, rhs: &ArrayBase, mut f: F) diff --git a/src/impl_ops.rs b/src/impl_ops.rs index d7645b8cb..7e28d48cf 100644 --- a/src/impl_ops.rs +++ b/src/impl_ops.rs @@ -56,20 +56,22 @@ macro_rules! impl_binary_op( /// between `self` and `rhs`, /// and return the result. /// +/// `self` must be an `Array` or `ArcArray`. +/// /// If their shapes disagree, `self` is broadcast to their broadcast shape, /// cloning the data if needed. /// /// **Panics** if broadcasting isn’t possible. impl $trt> for ArrayBase where - A: Clone + $trt, + A: Copy + $trt, B: Clone, - S: Data, + S: DataOwned + DataMut, S2: Data, D: Dimension + BroadcastShape, E: Dimension, { - type Output = Array>::BroadcastOutput>; + type Output = ArrayBase>::BroadcastOutput>; fn $mth(self, rhs: ArrayBase) -> Self::Output { self.$mth(&rhs) @@ -79,7 +81,9 @@ where /// Perform elementwise #[doc=$doc] /// between reference `self` and `rhs`, -/// and return the result as a new `Array`. +/// and return the result. +/// +/// `rhs` must be an `Array` or `ArcArray`. /// /// If their shapes disagree, `self` is broadcast to their broadcast shape, /// cloning the data if needed. @@ -87,17 +91,36 @@ where /// **Panics** if broadcasting isn’t possible. impl<'a, A, B, S, S2, D, E> $trt> for &'a ArrayBase where - A: Clone + $trt, - B: Clone, + A: Clone + $trt, + B: Copy, S: Data, - S2: Data, - D: Dimension + BroadcastShape, - E: Dimension, + S2: DataOwned + DataMut, + D: Dimension, + E: Dimension + BroadcastShape, { - type Output = Array>::BroadcastOutput>; + type Output = ArrayBase>::BroadcastOutput>; fn $mth(self, rhs: ArrayBase) -> Self::Output { - self.$mth(&rhs) + let shape = rhs.dim.broadcast_shape(&self.dim).unwrap(); + if shape.slice() == rhs.dim.slice() { + let mut out = rhs.into_dimensionality::<>::BroadcastOutput>().unwrap(); + out.zip_mut_with(self, |x, y| { + *x = y.clone() $operator x.clone(); + }); + out + } else { + // SAFETY: Overwrite all the elements in the array after + // it is created via `zip_mut_from_pair`. + let mut out = unsafe { + Self::Output::uninitialized(shape.clone().into_pattern()) + }; + let lhs = self.broadcast(shape.clone()).unwrap(); + let rhs = rhs.broadcast(shape).unwrap(); + out.zip_mut_from_pair(&lhs, &rhs, |x, y| { + x.clone() $operator y.clone() + }); + out + } } } @@ -106,32 +129,44 @@ where /// between `self` and reference `rhs`, /// and return the result. /// +/// `rhs` must be an `Array` or `ArcArray`. +/// /// If their shapes disagree, `self` is broadcast to their broadcast shape, /// cloning the data if needed. /// /// **Panics** if broadcasting isn’t possible. impl<'a, A, B, S, S2, D, E> $trt<&'a ArrayBase> for ArrayBase where - A: Clone + $trt, + A: Copy + $trt, B: Clone, - S: Data, + S: DataOwned + DataMut, S2: Data, D: Dimension + BroadcastShape, E: Dimension, { - type Output = Array>::BroadcastOutput>; + type Output = ArrayBase>::BroadcastOutput>; fn $mth(self, rhs: &ArrayBase) -> Self::Output { let shape = self.dim.broadcast_shape(&rhs.dim).unwrap(); - let mut self_ = if shape.slice() == self.dim.slice() { - self.into_owned().into_dimensionality::<>::BroadcastOutput>().unwrap() + if shape.slice() == self.dim.slice() { + let mut out = self.into_dimensionality::<>::BroadcastOutput>().unwrap(); + out.zip_mut_with(rhs, |x, y| { + *x = x.clone() $operator y.clone(); + }); + out } else { - self.broadcast(shape).unwrap().to_owned() - }; - self_.zip_mut_with(rhs, |x, y| { - *x = x.clone() $operator y.clone(); - }); - self_ + // SAFETY: Overwrite all the elements in the array after + // it is created via `zip_mut_from_pair`. + let mut out = unsafe { + Self::Output::uninitialized(shape.clone().into_pattern()) + }; + let lhs = self.broadcast(shape.clone()).unwrap(); + let rhs = rhs.broadcast(shape).unwrap(); + out.zip_mut_from_pair(&lhs, &rhs, |x, y| { + x.clone() $operator y.clone() + }); + out + } } } @@ -140,13 +175,13 @@ where /// between references `self` and `rhs`, /// and return the result as a new `Array`. /// -/// If their shapes disagree, `self` is broadcast to their broadcast shape, +/// If their shapes disagree, `self` and `rhs` is broadcast to their broadcast shape, /// cloning the data if needed. /// /// **Panics** if broadcasting isn’t possible. impl<'a, A, B, S, S2, D, E> $trt<&'a ArrayBase> for &'a ArrayBase where - A: Clone + $trt, + A: Copy + $trt, B: Clone, S: Data, S2: Data, @@ -156,15 +191,17 @@ where type Output = Array>::BroadcastOutput>; fn $mth(self, rhs: &'a ArrayBase) -> Self::Output { let shape = self.dim.broadcast_shape(&rhs.dim).unwrap(); - let mut self_ = if shape.slice() == self.dim.slice() { - self.to_owned().into_dimensionality::<>::BroadcastOutput>().unwrap() - } else { - self.broadcast(shape).unwrap().to_owned() + // SAFETY: Overwrite all the elements in the array after + // it is created via `zip_mut_from_pair`. + let mut out = unsafe { + Self::Output::uninitialized(shape.clone().into_pattern()) }; - self_.zip_mut_with(rhs, |x, y| { - *x = x.clone() $operator y.clone(); + let lhs = self.broadcast(shape.clone()).unwrap(); + let rhs = rhs.broadcast(shape).unwrap(); + out.zip_mut_from_pair(&lhs, &rhs, |x, y| { + x.clone() $operator y.clone() }); - self_ + out } } diff --git a/src/numeric/impl_numeric.rs b/src/numeric/impl_numeric.rs index fbec80418..0c5b4820d 100644 --- a/src/numeric/impl_numeric.rs +++ b/src/numeric/impl_numeric.rs @@ -243,7 +243,7 @@ where /// **Panics** if `axis` is out of bounds. pub fn sum_axis(&self, axis: Axis) -> Array where - A: Clone + Zero + Add, + A: Copy + Zero + Add, D: RemoveAxis, { let n = self.len_of(axis); @@ -285,7 +285,7 @@ where /// ``` pub fn mean_axis(&self, axis: Axis) -> Option> where - A: Clone + Zero + FromPrimitive + Add + Div, + A: Copy + Zero + FromPrimitive + Add + Div, D: RemoveAxis, { let axis_length = self.len_of(axis); diff --git a/tests/array.rs b/tests/array.rs index a49ae65b3..370600bd3 100644 --- a/tests/array.rs +++ b/tests/array.rs @@ -484,7 +484,7 @@ fn test_add() { } let B = A.clone(); - let A = A + &B; + A = A + &B; assert_eq!(A[[0, 0]], 0); assert_eq!(A[[0, 1]], 2); assert_eq!(A[[1, 0]], 4); @@ -1640,7 +1640,7 @@ fn arithmetic_broadcast() { arr3(&[[[11, 15], [20, 24]], [[10, 14], [19, 23]]]) ); assert_eq!( - &a + b + c.into_owned(), + &a + b.into_owned() + c, arr3(&[[[15, 19], [32, 36]], [[14, 18], [31, 35]]]) ); @@ -1652,7 +1652,7 @@ fn arithmetic_broadcast() { let sc = c.to_shared(); let sc2 = sc.into_shared(); assert_eq!( - sa2 + sb2 + sc2.into_owned(), + sa2 + &sb2 + sc2.into_owned(), arr3(&[[[15, 19], [32, 36]], [[14, 18], [31, 35]]]) ); } From 0561c1325c88086ad7239a5fe596cf66870c6c91 Mon Sep 17 00:00:00 2001 From: SparrowLii Date: Tue, 26 Jan 2021 16:35:25 +0800 Subject: [PATCH 05/16] Use MaybeUninitSubst and Zip to avoid uninitialized(), rename BroadCastOutput to Output, remove zip_mut_from_pair() --- src/data_traits.rs | 28 +++++-- src/dimension/broadcast.rs | 22 +++--- src/dimension/dimension_trait.rs | 10 +-- src/impl_methods.rs | 40 ---------- src/impl_ops.rs | 129 ++++++++++++++++--------------- src/lib.rs | 2 +- tests/array.rs | 2 +- 7 files changed, 108 insertions(+), 125 deletions(-) diff --git a/src/data_traits.rs b/src/data_traits.rs index 1e191c468..eabe1795b 100644 --- a/src/data_traits.rs +++ b/src/data_traits.rs @@ -10,15 +10,12 @@ use rawpointer::PointerExt; -use std::mem::{self, size_of}; -use std::mem::MaybeUninit; +use std::mem::{self, size_of};use std::mem::MaybeUninit; use std::ptr::NonNull; use alloc::sync::Arc; use alloc::vec::Vec; -use crate::{ - ArrayBase, CowRepr, Dimension, OwnedArcRepr, OwnedRepr, RawViewRepr, ViewRepr, -}; +use crate::{ArrayBase, CowRepr, Dimension, OwnedArcRepr, OwnedRepr, RawViewRepr, ViewRepr}; /// Array representation trait. /// @@ -414,7 +411,6 @@ pub unsafe trait DataOwned: Data { /// Corresponding owned data with MaybeUninit elements type MaybeUninit: DataOwned> + RawDataSubst; - #[doc(hidden)] fn new(elements: Vec) -> Self; @@ -440,6 +436,7 @@ unsafe impl DataOwned for OwnedRepr { fn new(elements: Vec) -> Self { OwnedRepr::from(elements) } + fn into_shared(self) -> OwnedArcRepr { OwnedArcRepr(Arc::new(self)) } @@ -622,3 +619,22 @@ impl<'a, A: 'a, B: 'a> RawDataSubst for ViewRepr<&'a mut A> { ViewRepr::new() } } + +/// Array representation trait. +/// +/// The MaybeUninitSubst trait maps the MaybeUninit type of element, while +/// mapping the MaybeUninit type back to origin element type. +/// +/// For example, `MaybeUninitSubst` can map the type `OwnedRepr` to `OwnedRepr>`, +/// and use `Output as RawDataSubst` to map `OwnedRepr>` back to `OwnedRepr`. +pub trait MaybeUninitSubst: DataOwned { + type Output: DataOwned> + RawDataSubst>; +} + +impl MaybeUninitSubst for OwnedRepr { + type Output = OwnedRepr>; +} + +impl MaybeUninitSubst for OwnedArcRepr { + type Output = OwnedArcRepr>; +} diff --git a/src/dimension/broadcast.rs b/src/dimension/broadcast.rs index 15d656f7f..d13e2ac3a 100644 --- a/src/dimension/broadcast.rs +++ b/src/dimension/broadcast.rs @@ -43,7 +43,7 @@ where pub trait BroadcastShape { /// The resulting dimension type after broadcasting. - type BroadcastOutput: Dimension; + type Output: Dimension; /// Determines the shape after broadcasting the dimensions together. /// @@ -51,35 +51,35 @@ pub trait BroadcastShape { /// /// Uses the [NumPy broadcasting rules] /// (https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html#general-broadcasting-rules). - fn broadcast_shape(&self, other: &Other) -> Result; + fn broadcast_shape(&self, other: &Other) -> Result; } /// Dimensions of the same type remain unchanged when co_broadcast. /// So you can directly use D as the resulting type. /// (Instead of >::BroadcastOutput) impl BroadcastShape for D { - type BroadcastOutput = D; + type Output = D; - fn broadcast_shape(&self, other: &D) -> Result { - broadcast_shape::(self, other) + fn broadcast_shape(&self, other: &D) -> Result { + broadcast_shape::(self, other) } } macro_rules! impl_broadcast_distinct_fixed { ($smaller:ty, $larger:ty) => { impl BroadcastShape<$larger> for $smaller { - type BroadcastOutput = $larger; + type Output = $larger; - fn broadcast_shape(&self, other: &$larger) -> Result { - broadcast_shape::(self, other) + fn broadcast_shape(&self, other: &$larger) -> Result { + broadcast_shape::(self, other) } } impl BroadcastShape<$smaller> for $larger { - type BroadcastOutput = $larger; + type Output = $larger; - fn broadcast_shape(&self, other: &$smaller) -> Result { - broadcast_shape::(self, other) + fn broadcast_shape(&self, other: &$smaller) -> Result { + broadcast_shape::(self, other) } } }; diff --git a/src/dimension/dimension_trait.rs b/src/dimension/dimension_trait.rs index 89d106521..e3a1ebd09 100644 --- a/src/dimension/dimension_trait.rs +++ b/src/dimension/dimension_trait.rs @@ -46,11 +46,11 @@ pub trait Dimension: + MulAssign + for<'x> MulAssign<&'x Self> + MulAssign - + BroadcastShape - + BroadcastShape - + BroadcastShape - + BroadcastShape<::Smaller, BroadcastOutput=Self> - + BroadcastShape<::Larger, BroadcastOutput=::Larger> + + BroadcastShape + + BroadcastShape + + BroadcastShape + + BroadcastShape<::Smaller, Output=Self> + + BroadcastShape<::Larger, Output=::Larger> { /// For fixed-size dimension representations (e.g. `Ix2`), this should be /// `Some(ndim)`, and for variable-size dimension representations (e.g. diff --git a/src/impl_methods.rs b/src/impl_methods.rs index e7d5e7328..9943e7c8e 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -2037,46 +2037,6 @@ where self.zip_mut_with_by_rows(rhs, f); } - /// Traverse two arrays in unspecified order, in lock step, - /// calling the closure `f` on each element pair, and put - /// the result into the corresponding element of self. - pub fn zip_mut_from_pair(&mut self, lhs: &ArrayBase, rhs: &ArrayBase, f: F, ) - where - S: DataMut, - S1: Data, - S2: Data, - F: Fn(&B, &C) -> A, - { - debug_assert_eq!(self.shape(), lhs.shape()); - debug_assert_eq!(self.shape(), rhs.shape()); - - if self.dim.strides_equivalent(&self.strides, &lhs.strides) - && self.dim.strides_equivalent(&self.strides, &rhs.strides) - { - if let Some(self_s) = self.as_slice_memory_order_mut() { - if let Some(lhs_s) = lhs.as_slice_memory_order() { - if let Some(rhs_s) = rhs.as_slice_memory_order() { - for (s, (l, r)) in - self_s.iter_mut().zip(lhs_s.iter().zip(rhs_s)) { - *s = f(&l, &r); - } - return; - } - } - } - } - - // Otherwise, fall back to the outer iter - let n = self.ndim(); - let dim = self.raw_dim(); - Zip::from(LanesMut::new(self.view_mut(), Axis(n - 1))) - .and(Lanes::new(lhs.broadcast_assume(dim.clone()), Axis(n - 1))) - .and(Lanes::new(rhs.broadcast_assume(dim), Axis(n - 1))) - .for_each(move |s_row, l_row, r_row| { - Zip::from(s_row).and(l_row).and(r_row).for_each(|s, a, b| *s = f(a, b)) - }); - } - // zip two arrays where they have different layout or strides #[inline(always)] fn zip_mut_with_by_rows(&mut self, rhs: &ArrayBase, mut f: F) diff --git a/src/impl_ops.rs b/src/impl_ops.rs index 7e28d48cf..065edbad8 100644 --- a/src/impl_ops.rs +++ b/src/impl_ops.rs @@ -7,6 +7,8 @@ // except according to those terms. use crate::dimension::BroadcastShape; +use crate::data_traits::MaybeUninitSubst; +use crate::Zip; use num_complex::Complex; /// Elements that can be used as direct operands in arithmetic with arrays. @@ -64,14 +66,15 @@ macro_rules! impl_binary_op( /// **Panics** if broadcasting isn’t possible. impl $trt> for ArrayBase where - A: Copy + $trt, + A: Clone + $trt, B: Clone, - S: DataOwned + DataMut, + S: DataOwned + DataMut + MaybeUninitSubst, + >::Output: DataMut, S2: Data, D: Dimension + BroadcastShape, E: Dimension, { - type Output = ArrayBase>::BroadcastOutput>; + type Output = ArrayBase>::Output>; fn $mth(self, rhs: ArrayBase) -> Self::Output { self.$mth(&rhs) @@ -80,7 +83,7 @@ where /// Perform elementwise #[doc=$doc] -/// between reference `self` and `rhs`, +/// between `self` and reference `rhs`, /// and return the result. /// /// `rhs` must be an `Array` or `ArcArray`. @@ -89,44 +92,49 @@ where /// cloning the data if needed. /// /// **Panics** if broadcasting isn’t possible. -impl<'a, A, B, S, S2, D, E> $trt> for &'a ArrayBase +impl<'a, A, B, S, S2, D, E> $trt<&'a ArrayBase> for ArrayBase where - A: Clone + $trt, - B: Copy, - S: Data, - S2: DataOwned + DataMut, - D: Dimension, - E: Dimension + BroadcastShape, + A: Clone + $trt, + B: Clone, + S: DataOwned + DataMut + MaybeUninitSubst, + >::Output: DataMut, + S2: Data, + D: Dimension + BroadcastShape, + E: Dimension, { - type Output = ArrayBase>::BroadcastOutput>; - fn $mth(self, rhs: ArrayBase) -> Self::Output + type Output = ArrayBase>::Output>; + fn $mth(self, rhs: &ArrayBase) -> Self::Output { - let shape = rhs.dim.broadcast_shape(&self.dim).unwrap(); - if shape.slice() == rhs.dim.slice() { - let mut out = rhs.into_dimensionality::<>::BroadcastOutput>().unwrap(); - out.zip_mut_with(self, |x, y| { - *x = y.clone() $operator x.clone(); + let shape = self.dim.broadcast_shape(&rhs.dim).unwrap(); + if shape.slice() == self.dim.slice() { + let mut out = self.into_dimensionality::<>::Output>().unwrap(); + out.zip_mut_with(rhs, |x, y| { + *x = x.clone() $operator y.clone(); }); out } else { - // SAFETY: Overwrite all the elements in the array after - // it is created via `zip_mut_from_pair`. - let mut out = unsafe { - Self::Output::uninitialized(shape.clone().into_pattern()) - }; let lhs = self.broadcast(shape.clone()).unwrap(); - let rhs = rhs.broadcast(shape).unwrap(); - out.zip_mut_from_pair(&lhs, &rhs, |x, y| { - x.clone() $operator y.clone() - }); - out + let rhs = rhs.broadcast(shape.clone()).unwrap(); + // SAFETY: Overwrite all the elements in the array after + // it is created via `raw_view_mut`. + unsafe { + let mut out =ArrayBase::<>::Output, >::Output>::maybe_uninit(shape.into_pattern()); + let output_view = out.raw_view_mut().cast::(); + Zip::from(&lhs).and(&rhs) + .and(output_view) + .collect_with_partial(|x, y| { + x.clone() $operator y.clone() + }) + .release_ownership(); + out.assume_init() + } } } } /// Perform elementwise #[doc=$doc] -/// between `self` and reference `rhs`, +/// between reference `self` and `rhs`, /// and return the result. /// /// `rhs` must be an `Array` or `ArcArray`. @@ -135,37 +143,43 @@ where /// cloning the data if needed. /// /// **Panics** if broadcasting isn’t possible. -impl<'a, A, B, S, S2, D, E> $trt<&'a ArrayBase> for ArrayBase +impl<'a, A, B, S, S2, D, E> $trt> for &'a ArrayBase where - A: Copy + $trt, + A: Clone + $trt, B: Clone, - S: DataOwned + DataMut, - S2: Data, - D: Dimension + BroadcastShape, - E: Dimension, + S: Data, + S2: DataOwned + DataMut + MaybeUninitSubst, + >::Output: DataMut, + D: Dimension, + E: Dimension + BroadcastShape, { - type Output = ArrayBase>::BroadcastOutput>; - fn $mth(self, rhs: &ArrayBase) -> Self::Output + type Output = ArrayBase>::Output>; + fn $mth(self, rhs: ArrayBase) -> Self::Output + where { - let shape = self.dim.broadcast_shape(&rhs.dim).unwrap(); - if shape.slice() == self.dim.slice() { - let mut out = self.into_dimensionality::<>::BroadcastOutput>().unwrap(); - out.zip_mut_with(rhs, |x, y| { - *x = x.clone() $operator y.clone(); + let shape = rhs.dim.broadcast_shape(&self.dim).unwrap(); + if shape.slice() == rhs.dim.slice() { + let mut out = rhs.into_dimensionality::<>::Output>().unwrap(); + out.zip_mut_with(self, |x, y| { + *x = y.clone() $operator x.clone(); }); out } else { - // SAFETY: Overwrite all the elements in the array after - // it is created via `zip_mut_from_pair`. - let mut out = unsafe { - Self::Output::uninitialized(shape.clone().into_pattern()) - }; let lhs = self.broadcast(shape.clone()).unwrap(); - let rhs = rhs.broadcast(shape).unwrap(); - out.zip_mut_from_pair(&lhs, &rhs, |x, y| { - x.clone() $operator y.clone() - }); - out + let rhs = rhs.broadcast(shape.clone()).unwrap(); + // SAFETY: Overwrite all the elements in the array after + // it is created via `raw_view_mut`. + unsafe { + let mut out =ArrayBase::<>::Output, >::Output>::maybe_uninit(shape.into_pattern()); + let output_view = out.raw_view_mut().cast::(); + Zip::from(&lhs).and(&rhs) + .and(output_view) + .collect_with_partial(|x, y| { + x.clone() $operator y.clone() + }) + .release_ownership(); + out.assume_init() + } } } } @@ -188,19 +202,12 @@ where D: Dimension + BroadcastShape, E: Dimension, { - type Output = Array>::BroadcastOutput>; + type Output = Array>::Output>; fn $mth(self, rhs: &'a ArrayBase) -> Self::Output { let shape = self.dim.broadcast_shape(&rhs.dim).unwrap(); - // SAFETY: Overwrite all the elements in the array after - // it is created via `zip_mut_from_pair`. - let mut out = unsafe { - Self::Output::uninitialized(shape.clone().into_pattern()) - }; let lhs = self.broadcast(shape.clone()).unwrap(); let rhs = rhs.broadcast(shape).unwrap(); - out.zip_mut_from_pair(&lhs, &rhs, |x, y| { - x.clone() $operator y.clone() - }); + let out = Zip::from(&lhs).and(&rhs).map_collect(|x, y| x.clone() $operator y.clone()); out } } diff --git a/src/lib.rs b/src/lib.rs index 1a760bf31..844226ed9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -179,7 +179,7 @@ pub use crate::aliases::*; pub use crate::data_traits::{ Data, DataMut, DataOwned, DataShared, RawData, RawDataClone, RawDataMut, - RawDataSubst, + RawDataSubst, MaybeUninitSubst, }; mod free_functions; diff --git a/tests/array.rs b/tests/array.rs index 370600bd3..6f07baff8 100644 --- a/tests/array.rs +++ b/tests/array.rs @@ -1562,7 +1562,7 @@ fn test_broadcast_shape() { fn test_co( d1: &D1, d2: &D2, - r: Result<>::BroadcastOutput, ShapeError>, + r: Result<>::Output, ShapeError>, ) where D1: Dimension + BroadcastShape, D2: Dimension, From fc969bdadc830f3e7c4f624a3245b4051e5c9e7d Mon Sep 17 00:00:00 2001 From: SparrowLii Date: Wed, 27 Jan 2021 11:16:14 +0800 Subject: [PATCH 06/16] treat zero dimension like numpy does --- src/dimension/broadcast.rs | 8 +++----- src/numeric/impl_numeric.rs | 4 ++-- tests/array.rs | 8 ++++---- 3 files changed, 9 insertions(+), 11 deletions(-) diff --git a/src/dimension/broadcast.rs b/src/dimension/broadcast.rs index d13e2ac3a..7c158f5b2 100644 --- a/src/dimension/broadcast.rs +++ b/src/dimension/broadcast.rs @@ -23,15 +23,13 @@ where // (https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html#general-broadcasting-rules). // // Zero dimension element is not in the original rules of broadcasting. - // We currently treat it as the same as 1. Especially, when one side is - // zero with one side is empty, or both sides are zero, the result will - // remain zero. + // We currently treat it like any other number greater than 1. As numpy does. for i in 0..shape1.ndim() { out_slice[i] = s1[i]; } for i in 0..shape2.ndim() { - if out_slice[i + k] != s2[i] && s2[i] != 0 { - if out_slice[i + k] <= 1 { + if out_slice[i + k] != s2[i] { + if out_slice[i + k] == 1 { out_slice[i + k] = s2[i] } else if s2[i] != 1 { return Err(from_kind(ErrorKind::IncompatibleShape)); diff --git a/src/numeric/impl_numeric.rs b/src/numeric/impl_numeric.rs index 0c5b4820d..fbec80418 100644 --- a/src/numeric/impl_numeric.rs +++ b/src/numeric/impl_numeric.rs @@ -243,7 +243,7 @@ where /// **Panics** if `axis` is out of bounds. pub fn sum_axis(&self, axis: Axis) -> Array where - A: Copy + Zero + Add, + A: Clone + Zero + Add, D: RemoveAxis, { let n = self.len_of(axis); @@ -285,7 +285,7 @@ where /// ``` pub fn mean_axis(&self, axis: Axis) -> Option> where - A: Copy + Zero + FromPrimitive + Add + Div, + A: Clone + Zero + FromPrimitive + Add + Div, D: RemoveAxis, { let axis_length = self.len_of(axis); diff --git a/tests/array.rs b/tests/array.rs index 6f07baff8..a380423d7 100644 --- a/tests/array.rs +++ b/tests/array.rs @@ -1588,19 +1588,19 @@ fn test_broadcast_shape() { test_co( &Dim([2, 1, 2]).into_dyn(), &Dim(0), - Ok(Dim([2, 1, 2]).into_dyn()), + Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)), ); test_co( &Dim([2, 1, 1]), - &Dim([0, 0, 0, 3, 4]), + &Dim([0, 0, 1, 3, 4]), Ok(Dim([0, 0, 2, 3, 4])), ); test_co(&Dim([0]), &Dim([0, 0, 0]), Ok(Dim([0, 0, 0]))); - test_co(&Dim(1), &Dim([1, 0, 0]), Ok(Dim([1, 0, 1]))); + test_co(&Dim(1), &Dim([1, 0, 0]), Ok(Dim([1, 0, 0]))); test_co( &Dim([1, 3, 0, 1, 1]), &Dim([1, 2, 3, 1]), - Ok(Dim([1, 3, 2, 3, 1])), + Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)), ); } From a19149b73b93d227f5eaec14fe7e4da398869c2f Mon Sep 17 00:00:00 2001 From: SparrowLii Date: Mon, 1 Feb 2021 12:20:05 +0800 Subject: [PATCH 07/16] rebase and use map_collect_owned in impl_ops.rs --- src/data_traits.rs | 21 ++----------- src/impl_methods.rs | 2 +- src/impl_ops.rs | 76 +++++++++++++++++---------------------------- src/lib.rs | 2 +- 4 files changed, 33 insertions(+), 68 deletions(-) diff --git a/src/data_traits.rs b/src/data_traits.rs index eabe1795b..7ac63d54e 100644 --- a/src/data_traits.rs +++ b/src/data_traits.rs @@ -10,7 +10,8 @@ use rawpointer::PointerExt; -use std::mem::{self, size_of};use std::mem::MaybeUninit; +use std::mem::{self, size_of}; +use std::mem::MaybeUninit; use std::ptr::NonNull; use alloc::sync::Arc; use alloc::vec::Vec; @@ -620,21 +621,3 @@ impl<'a, A: 'a, B: 'a> RawDataSubst for ViewRepr<&'a mut A> { } } -/// Array representation trait. -/// -/// The MaybeUninitSubst trait maps the MaybeUninit type of element, while -/// mapping the MaybeUninit type back to origin element type. -/// -/// For example, `MaybeUninitSubst` can map the type `OwnedRepr` to `OwnedRepr>`, -/// and use `Output as RawDataSubst` to map `OwnedRepr>` back to `OwnedRepr`. -pub trait MaybeUninitSubst: DataOwned { - type Output: DataOwned> + RawDataSubst>; -} - -impl MaybeUninitSubst for OwnedRepr { - type Output = OwnedRepr>; -} - -impl MaybeUninitSubst for OwnedArcRepr { - type Output = OwnedArcRepr>; -} diff --git a/src/impl_methods.rs b/src/impl_methods.rs index 9943e7c8e..9a1e0fac8 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -2013,7 +2013,7 @@ where self.map_inplace(move |elt| *elt = x.clone()); } - fn zip_mut_with_same_shape(&mut self, rhs: &ArrayBase, mut f: F) + pub(crate) fn zip_mut_with_same_shape(&mut self, rhs: &ArrayBase, mut f: F) where S: DataMut, S2: Data, diff --git a/src/impl_ops.rs b/src/impl_ops.rs index 065edbad8..6e45ea629 100644 --- a/src/impl_ops.rs +++ b/src/impl_ops.rs @@ -7,7 +7,6 @@ // except according to those terms. use crate::dimension::BroadcastShape; -use crate::data_traits::MaybeUninitSubst; use crate::Zip; use num_complex::Complex; @@ -68,8 +67,8 @@ impl $trt> for ArrayBase where A: Clone + $trt, B: Clone, - S: DataOwned + DataMut + MaybeUninitSubst, - >::Output: DataMut, + S: DataOwned + DataMut, + S::MaybeUninit: DataMut, S2: Data, D: Dimension + BroadcastShape, E: Dimension, @@ -96,8 +95,8 @@ impl<'a, A, B, S, S2, D, E> $trt<&'a ArrayBase> for ArrayBase where A: Clone + $trt, B: Clone, - S: DataOwned + DataMut + MaybeUninitSubst, - >::Output: DataMut, + S: DataOwned + DataMut, + S::MaybeUninit: DataMut, S2: Data, D: Dimension + BroadcastShape, E: Dimension, @@ -105,29 +104,15 @@ where type Output = ArrayBase>::Output>; fn $mth(self, rhs: &ArrayBase) -> Self::Output { - let shape = self.dim.broadcast_shape(&rhs.dim).unwrap(); - if shape.slice() == self.dim.slice() { + if self.ndim() == rhs.ndim() && self.shape() == rhs.shape() { let mut out = self.into_dimensionality::<>::Output>().unwrap(); - out.zip_mut_with(rhs, |x, y| { - *x = x.clone() $operator y.clone(); - }); + out.zip_mut_with_same_shape(rhs, clone_iopf(A::$mth)); out } else { + let shape = self.dim.broadcast_shape(&rhs.dim).unwrap(); let lhs = self.broadcast(shape.clone()).unwrap(); - let rhs = rhs.broadcast(shape.clone()).unwrap(); - // SAFETY: Overwrite all the elements in the array after - // it is created via `raw_view_mut`. - unsafe { - let mut out =ArrayBase::<>::Output, >::Output>::maybe_uninit(shape.into_pattern()); - let output_view = out.raw_view_mut().cast::(); - Zip::from(&lhs).and(&rhs) - .and(output_view) - .collect_with_partial(|x, y| { - x.clone() $operator y.clone() - }) - .release_ownership(); - out.assume_init() - } + let rhs = rhs.broadcast(shape).unwrap(); + Zip::from(&lhs).and(&rhs).map_collect_owned(clone_opf(A::$mth)) } } } @@ -148,8 +133,8 @@ where A: Clone + $trt, B: Clone, S: Data, - S2: DataOwned + DataMut + MaybeUninitSubst, - >::Output: DataMut, + S2: DataOwned + DataMut, + S2::MaybeUninit: DataMut, D: Dimension, E: Dimension + BroadcastShape, { @@ -157,29 +142,15 @@ where fn $mth(self, rhs: ArrayBase) -> Self::Output where { - let shape = rhs.dim.broadcast_shape(&self.dim).unwrap(); - if shape.slice() == rhs.dim.slice() { + if self.ndim() == rhs.ndim() && self.shape() == rhs.shape() { let mut out = rhs.into_dimensionality::<>::Output>().unwrap(); - out.zip_mut_with(self, |x, y| { - *x = y.clone() $operator x.clone(); - }); + out.zip_mut_with_same_shape(self, clone_iopf_rev(A::$mth)); out } else { + let shape = rhs.dim.broadcast_shape(&self.dim).unwrap(); let lhs = self.broadcast(shape.clone()).unwrap(); - let rhs = rhs.broadcast(shape.clone()).unwrap(); - // SAFETY: Overwrite all the elements in the array after - // it is created via `raw_view_mut`. - unsafe { - let mut out =ArrayBase::<>::Output, >::Output>::maybe_uninit(shape.into_pattern()); - let output_view = out.raw_view_mut().cast::(); - Zip::from(&lhs).and(&rhs) - .and(output_view) - .collect_with_partial(|x, y| { - x.clone() $operator y.clone() - }) - .release_ownership(); - out.assume_init() - } + let rhs = rhs.broadcast(shape).unwrap(); + Zip::from(&lhs).and(&rhs).map_collect_owned(clone_opf(A::$mth)) } } } @@ -207,8 +178,7 @@ where let shape = self.dim.broadcast_shape(&rhs.dim).unwrap(); let lhs = self.broadcast(shape.clone()).unwrap(); let rhs = rhs.broadcast(shape).unwrap(); - let out = Zip::from(&lhs).and(&rhs).map_collect(|x, y| x.clone() $operator y.clone()); - out + Zip::from(&lhs).and(&rhs).map_collect(clone_opf(A::$mth)) } } @@ -313,6 +283,18 @@ mod arithmetic_ops { use num_complex::Complex; use std::ops::*; + fn clone_opf(f: impl Fn(A, B) -> C) -> impl FnMut(&A, &B) -> C { + move |x, y| f(x.clone(), y.clone()) + } + + fn clone_iopf(f: impl Fn(A, B) -> A) -> impl FnMut(&mut A, &B) { + move |x, y| *x = f(x.clone(), y.clone()) + } + + fn clone_iopf_rev(f: impl Fn(A, B) -> B) -> impl FnMut(&mut B, &A) { + move |x, y| *x = f(y.clone(), x.clone()) + } + impl_binary_op!(Add, +, add, +=, "addition"); impl_binary_op!(Sub, -, sub, -=, "subtraction"); impl_binary_op!(Mul, *, mul, *=, "multiplication"); diff --git a/src/lib.rs b/src/lib.rs index 844226ed9..1a760bf31 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -179,7 +179,7 @@ pub use crate::aliases::*; pub use crate::data_traits::{ Data, DataMut, DataOwned, DataShared, RawData, RawDataClone, RawDataMut, - RawDataSubst, MaybeUninitSubst, + RawDataSubst, }; mod free_functions; From b1f4f9526e264609e54b8d3e2ab2053cdef8915c Mon Sep 17 00:00:00 2001 From: SparrowLii Date: Wed, 3 Feb 2021 21:49:22 +0800 Subject: [PATCH 08/16] use izip in index loop in broadcast.rs --- src/dimension/broadcast.rs | 17 +++++++---------- src/impl_ops.rs | 2 +- 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/src/dimension/broadcast.rs b/src/dimension/broadcast.rs index 7c158f5b2..3f4b8bdbe 100644 --- a/src/dimension/broadcast.rs +++ b/src/dimension/broadcast.rs @@ -16,22 +16,19 @@ where } // The output should be the same length as shape1. let mut out = Output::zeros(shape1.ndim()); - let out_slice = out.slice_mut(); - let s1 = shape1.slice(); - let s2 = shape2.slice(); // Uses the [NumPy broadcasting rules] // (https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html#general-broadcasting-rules). // // Zero dimension element is not in the original rules of broadcasting. // We currently treat it like any other number greater than 1. As numpy does. - for i in 0..shape1.ndim() { - out_slice[i] = s1[i]; + for (out, s) in izip!(out.slice_mut(), shape1.slice()) { + *out = *s; } - for i in 0..shape2.ndim() { - if out_slice[i + k] != s2[i] { - if out_slice[i + k] == 1 { - out_slice[i + k] = s2[i] - } else if s2[i] != 1 { + for (out, s2) in izip!(&mut out.slice_mut()[k..], shape2.slice()) { + if *out != *s2 { + if *out == 1 { + *out = *s2 + } else if *s2 != 1 { return Err(from_kind(ErrorKind::IncompatibleShape)); } } diff --git a/src/impl_ops.rs b/src/impl_ops.rs index 6e45ea629..585b3ed4c 100644 --- a/src/impl_ops.rs +++ b/src/impl_ops.rs @@ -166,7 +166,7 @@ where /// **Panics** if broadcasting isn’t possible. impl<'a, A, B, S, S2, D, E> $trt<&'a ArrayBase> for &'a ArrayBase where - A: Copy + $trt, + A: Clone + $trt, B: Clone, S: Data, S2: Data, From c9eb88e5a22c91be0ca2b90a4b50e5bcba1a7688 Mon Sep 17 00:00:00 2001 From: SparrowLii Date: Fri, 5 Feb 2021 20:40:33 +0800 Subject: [PATCH 09/16] Update documentation and function names --- src/dimension/broadcast.rs | 35 ++++++++++++-------------- src/impl_ops.rs | 3 --- tests/array.rs | 49 +------------------------------------ tests/dimension.rs | 50 +++++++++++++++++++++++++++++++++++++- 4 files changed, 65 insertions(+), 72 deletions(-) diff --git a/src/dimension/broadcast.rs b/src/dimension/broadcast.rs index 3f4b8bdbe..74ac9e634 100644 --- a/src/dimension/broadcast.rs +++ b/src/dimension/broadcast.rs @@ -1,26 +1,24 @@ use crate::error::*; use crate::{Dimension, Ix0, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn}; -/// Calculate the co_broadcast shape of two dimensions. Return error if shapes are -/// not compatible. -fn broadcast_shape(shape1: &D1, shape2: &D2) -> Result -where - D1: Dimension, - D2: Dimension, - Output: Dimension, +/// Calculate the common shape for a pair of array shapes, which can be broadcasted +/// to each other. Return an error if shapes are not compatible. +/// +/// Uses the [NumPy broadcasting rules] +// (https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html#general-broadcasting-rules). +fn co_broadcasting(shape1: &D1, shape2: &D2) -> Result + where + D1: Dimension, + D2: Dimension, + Output: Dimension, { let (k, overflow) = shape1.ndim().overflowing_sub(shape2.ndim()); // Swap the order if d2 is longer. if overflow { - return broadcast_shape::(shape2, shape1); + return co_broadcasting::(shape2, shape1); } // The output should be the same length as shape1. let mut out = Output::zeros(shape1.ndim()); - // Uses the [NumPy broadcasting rules] - // (https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html#general-broadcasting-rules). - // - // Zero dimension element is not in the original rules of broadcasting. - // We currently treat it like any other number greater than 1. As numpy does. for (out, s) in izip!(out.slice_mut(), shape1.slice()) { *out = *s; } @@ -42,10 +40,7 @@ pub trait BroadcastShape { /// Determines the shape after broadcasting the dimensions together. /// - /// If the dimensions are not compatible, returns `Err`. - /// - /// Uses the [NumPy broadcasting rules] - /// (https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html#general-broadcasting-rules). + /// If the shapes are not compatible, returns `Err`. fn broadcast_shape(&self, other: &Other) -> Result; } @@ -56,7 +51,7 @@ impl BroadcastShape for D { type Output = D; fn broadcast_shape(&self, other: &D) -> Result { - broadcast_shape::(self, other) + co_broadcasting::(self, other) } } @@ -66,7 +61,7 @@ macro_rules! impl_broadcast_distinct_fixed { type Output = $larger; fn broadcast_shape(&self, other: &$larger) -> Result { - broadcast_shape::(self, other) + co_broadcasting::(self, other) } } @@ -74,7 +69,7 @@ macro_rules! impl_broadcast_distinct_fixed { type Output = $larger; fn broadcast_shape(&self, other: &$smaller) -> Result { - broadcast_shape::(self, other) + co_broadcasting::(self, other) } } }; diff --git a/src/impl_ops.rs b/src/impl_ops.rs index 585b3ed4c..5bd518a00 100644 --- a/src/impl_ops.rs +++ b/src/impl_ops.rs @@ -68,7 +68,6 @@ where A: Clone + $trt, B: Clone, S: DataOwned + DataMut, - S::MaybeUninit: DataMut, S2: Data, D: Dimension + BroadcastShape, E: Dimension, @@ -96,7 +95,6 @@ where A: Clone + $trt, B: Clone, S: DataOwned + DataMut, - S::MaybeUninit: DataMut, S2: Data, D: Dimension + BroadcastShape, E: Dimension, @@ -134,7 +132,6 @@ where B: Clone, S: Data, S2: DataOwned + DataMut, - S2::MaybeUninit: DataMut, D: Dimension, E: Dimension + BroadcastShape, { diff --git a/tests/array.rs b/tests/array.rs index a380423d7..b0a28ca41 100644 --- a/tests/array.rs +++ b/tests/array.rs @@ -11,7 +11,7 @@ use defmac::defmac; use itertools::{enumerate, zip, Itertools}; use ndarray::prelude::*; use ndarray::{arr3, rcarr2}; -use ndarray::{indices, BroadcastShape, ErrorKind, IxDynImpl, ShapeError}; +use ndarray::indices; use ndarray::{Slice, SliceInfo, SliceOrIndex}; macro_rules! assert_panics { @@ -1557,53 +1557,6 @@ fn insert_axis_view() { ); } -#[test] -fn test_broadcast_shape() { - fn test_co( - d1: &D1, - d2: &D2, - r: Result<>::Output, ShapeError>, - ) where - D1: Dimension + BroadcastShape, - D2: Dimension, - { - let d = d1.broadcast_shape(d2); - assert_eq!(d, r); - } - test_co(&Dim([2, 3]), &Dim([4, 1, 3]), Ok(Dim([4, 2, 3]))); - test_co( - &Dim([1, 2, 2]), - &Dim([1, 3, 4]), - Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)), - ); - test_co(&Dim([3, 4, 5]), &Ix0(), Ok(Dim([3, 4, 5]))); - let v = vec![1, 2, 3, 4, 5, 6, 7]; - test_co( - &Dim(vec![1, 1, 3, 1, 5, 1, 7]), - &Dim([2, 1, 4, 1, 6, 1]), - Ok(Dim(IxDynImpl::from(v.as_slice()))), - ); - let d = Dim([1, 2, 1, 3]); - test_co(&d, &d, Ok(d)); - test_co( - &Dim([2, 1, 2]).into_dyn(), - &Dim(0), - Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)), - ); - test_co( - &Dim([2, 1, 1]), - &Dim([0, 0, 1, 3, 4]), - Ok(Dim([0, 0, 2, 3, 4])), - ); - test_co(&Dim([0]), &Dim([0, 0, 0]), Ok(Dim([0, 0, 0]))); - test_co(&Dim(1), &Dim([1, 0, 0]), Ok(Dim([1, 0, 0]))); - test_co( - &Dim([1, 3, 0, 1, 1]), - &Dim([1, 2, 3, 1]), - Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)), - ); -} - #[test] fn arithmetic_broadcast() { let mut a = arr2(&[[1., 2.], [3., 4.]]); diff --git a/tests/dimension.rs b/tests/dimension.rs index 939b4f0e3..ede0dc7d2 100644 --- a/tests/dimension.rs +++ b/tests/dimension.rs @@ -2,7 +2,8 @@ use defmac::defmac; -use ndarray::{arr2, ArcArray, Array, Axis, Dim, Dimension, IxDyn, RemoveAxis}; +use ndarray::{arr2, ArcArray, Array, Axis, Dim, Dimension, Ix0, IxDyn, IxDynImpl, RemoveAxis, + ErrorKind, ShapeError, BroadcastShape}; use std::hash::{Hash, Hasher}; @@ -340,3 +341,50 @@ fn test_all_ndindex() { ndindex!(10, 4, 3, 2, 2); ndindex!(10, 4, 3, 2, 2, 2); } + +#[test] +fn test_broadcast_shape() { + fn test_co( + d1: &D1, + d2: &D2, + r: Result<>::Output, ShapeError>, + ) where + D1: Dimension + BroadcastShape, + D2: Dimension, + { + let d = d1.broadcast_shape(d2); + assert_eq!(d, r); + } + test_co(&Dim([2, 3]), &Dim([4, 1, 3]), Ok(Dim([4, 2, 3]))); + test_co( + &Dim([1, 2, 2]), + &Dim([1, 3, 4]), + Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)), + ); + test_co(&Dim([3, 4, 5]), &Ix0(), Ok(Dim([3, 4, 5]))); + let v = vec![1, 2, 3, 4, 5, 6, 7]; + test_co( + &Dim(vec![1, 1, 3, 1, 5, 1, 7]), + &Dim([2, 1, 4, 1, 6, 1]), + Ok(Dim(IxDynImpl::from(v.as_slice()))), + ); + let d = Dim([1, 2, 1, 3]); + test_co(&d, &d, Ok(d)); + test_co( + &Dim([2, 1, 2]).into_dyn(), + &Dim(0), + Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)), + ); + test_co( + &Dim([2, 1, 1]), + &Dim([0, 0, 1, 3, 4]), + Ok(Dim([0, 0, 2, 3, 4])), + ); + test_co(&Dim([0]), &Dim([0, 0, 0]), Ok(Dim([0, 0, 0]))); + test_co(&Dim(1), &Dim([1, 0, 0]), Ok(Dim([1, 0, 0]))); + test_co( + &Dim([1, 3, 0, 1, 1]), + &Dim([1, 2, 3, 1]), + Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)), + ); +} From e3b73cc50330743d735872e3ecfb0a880dcbe252 Mon Sep 17 00:00:00 2001 From: SparrowLii Date: Thu, 18 Feb 2021 19:40:43 +0800 Subject: [PATCH 10/16] Update documentation and function names --- src/dimension/broadcast.rs | 16 ++++++++-------- src/impl_ops.rs | 3 +-- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/src/dimension/broadcast.rs b/src/dimension/broadcast.rs index 74ac9e634..377c8509e 100644 --- a/src/dimension/broadcast.rs +++ b/src/dimension/broadcast.rs @@ -1,12 +1,12 @@ use crate::error::*; use crate::{Dimension, Ix0, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn}; -/// Calculate the common shape for a pair of array shapes, which can be broadcasted -/// to each other. Return an error if shapes are not compatible. +/// Calculate the common shape for a pair of array shapes, that they can be broadcasted +/// to. Return an error if the shapes are not compatible. /// /// Uses the [NumPy broadcasting rules] // (https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html#general-broadcasting-rules). -fn co_broadcasting(shape1: &D1, shape2: &D2) -> Result +fn co_broadcast(shape1: &D1, shape2: &D2) -> Result where D1: Dimension, D2: Dimension, @@ -15,7 +15,7 @@ fn co_broadcasting(shape1: &D1, shape2: &D2) -> Result(shape2, shape1); + return co_broadcast::(shape2, shape1); } // The output should be the same length as shape1. let mut out = Output::zeros(shape1.ndim()); @@ -38,7 +38,7 @@ pub trait BroadcastShape { /// The resulting dimension type after broadcasting. type Output: Dimension; - /// Determines the shape after broadcasting the dimensions together. + /// Determines the shape after broadcasting the shapes together. /// /// If the shapes are not compatible, returns `Err`. fn broadcast_shape(&self, other: &Other) -> Result; @@ -51,7 +51,7 @@ impl BroadcastShape for D { type Output = D; fn broadcast_shape(&self, other: &D) -> Result { - co_broadcasting::(self, other) + co_broadcast::(self, other) } } @@ -61,7 +61,7 @@ macro_rules! impl_broadcast_distinct_fixed { type Output = $larger; fn broadcast_shape(&self, other: &$larger) -> Result { - co_broadcasting::(self, other) + co_broadcast::(self, other) } } @@ -69,7 +69,7 @@ macro_rules! impl_broadcast_distinct_fixed { type Output = $larger; fn broadcast_shape(&self, other: &$smaller) -> Result { - co_broadcasting::(self, other) + co_broadcast::(self, other) } } }; diff --git a/src/impl_ops.rs b/src/impl_ops.rs index 5bd518a00..69c6f698e 100644 --- a/src/impl_ops.rs +++ b/src/impl_ops.rs @@ -59,8 +59,7 @@ macro_rules! impl_binary_op( /// /// `self` must be an `Array` or `ArcArray`. /// -/// If their shapes disagree, `self` is broadcast to their broadcast shape, -/// cloning the data if needed. +/// If their shapes disagree, `self` is broadcast to their broadcast shape. /// /// **Panics** if broadcasting isn’t possible. impl $trt> for ArrayBase From b6232398675b00af31ae8870ce045785fc332cc2 Mon Sep 17 00:00:00 2001 From: SparrowLii Date: Thu, 18 Feb 2021 21:03:57 +0800 Subject: [PATCH 11/16] Add function broadcast_with --- src/impl_methods.rs | 34 ++++++++++++++++++++++++++++++++-- src/impl_ops.rs | 12 +++--------- tests/broadcast.rs | 31 +++++++++++++++++++++++++++++++ 3 files changed, 66 insertions(+), 11 deletions(-) diff --git a/src/impl_methods.rs b/src/impl_methods.rs index 9a1e0fac8..4202ec257 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -14,14 +14,14 @@ use rawpointer::PointerExt; use crate::imp_prelude::*; -use crate::arraytraits; +use crate::{arraytraits, BroadcastShape}; use crate::dimension; use crate::dimension::IntoDimension; use crate::dimension::{ abs_index, axes_of, do_slice, merge_axes, move_min_stride_axis_to_last, offset_from_ptr_to_memory, size_of_shape_checked, stride_offset, Axes, }; -use crate::error::{self, ErrorKind, ShapeError}; +use crate::error::{self, ErrorKind, ShapeError, from_kind}; use crate::math_cell::MathCell; use crate::itertools::zip; use crate::zip::Zip; @@ -1766,6 +1766,36 @@ where unsafe { Some(ArrayView::new(self.ptr, dim, broadcast_strides)) } } + /// Calculate the views of two ArrayBases after broadcasting each other, if possible. + /// + /// Return `ShapeError` if their shapes can not be broadcast together. + /// + /// ``` + /// use ndarray::{arr1, arr2}; + /// + /// let a = arr2(&[[2], [3], [4]]); + /// let b = arr1(&[5, 6, 7]); + /// let (a1, b1) = a.broadcast_with(&b).unwrap(); + /// assert_eq!(a1, arr2(&[[2, 2, 2], [3, 3, 3], [4, 4, 4]])); + /// assert_eq!(b1, arr2(&[[5, 6, 7], [5, 6, 7], [5, 6, 7]])); + /// ``` + pub fn broadcast_with<'a, 'b, B, S2, E>(&'a self, other: &'b ArrayBase) -> + Result<(ArrayView<'a, A, >::Output>, ArrayView<'b, B, >::Output>), ShapeError> + where + S: Data, + S2: Data, + D: Dimension + BroadcastShape, + E: Dimension, + { + let shape = self.dim.broadcast_shape(&other.dim)?; + if let Some(view1) = self.broadcast(shape.clone()) { + if let Some(view2) = other.broadcast(shape) { + return Ok((view1, view2)) + } + } + return Err(from_kind(ErrorKind::IncompatibleShape)); + } + /// Swap axes `ax` and `bx`. /// /// This does not move any data, it just adjusts the array’s dimensions diff --git a/src/impl_ops.rs b/src/impl_ops.rs index 69c6f698e..80aa8e11a 100644 --- a/src/impl_ops.rs +++ b/src/impl_ops.rs @@ -106,9 +106,7 @@ where out.zip_mut_with_same_shape(rhs, clone_iopf(A::$mth)); out } else { - let shape = self.dim.broadcast_shape(&rhs.dim).unwrap(); - let lhs = self.broadcast(shape.clone()).unwrap(); - let rhs = rhs.broadcast(shape).unwrap(); + let (lhs, rhs) = self.broadcast_with(rhs).unwrap(); Zip::from(&lhs).and(&rhs).map_collect_owned(clone_opf(A::$mth)) } } @@ -143,9 +141,7 @@ where out.zip_mut_with_same_shape(self, clone_iopf_rev(A::$mth)); out } else { - let shape = rhs.dim.broadcast_shape(&self.dim).unwrap(); - let lhs = self.broadcast(shape.clone()).unwrap(); - let rhs = rhs.broadcast(shape).unwrap(); + let (rhs, lhs) = rhs.broadcast_with(self).unwrap(); Zip::from(&lhs).and(&rhs).map_collect_owned(clone_opf(A::$mth)) } } @@ -171,9 +167,7 @@ where { type Output = Array>::Output>; fn $mth(self, rhs: &'a ArrayBase) -> Self::Output { - let shape = self.dim.broadcast_shape(&rhs.dim).unwrap(); - let lhs = self.broadcast(shape.clone()).unwrap(); - let rhs = rhs.broadcast(shape).unwrap(); + let (lhs, rhs) = self.broadcast_with(rhs).unwrap(); Zip::from(&lhs).and(&rhs).map_collect(clone_opf(A::$mth)) } } diff --git a/tests/broadcast.rs b/tests/broadcast.rs index 5416e9017..26111c780 100644 --- a/tests/broadcast.rs +++ b/tests/broadcast.rs @@ -1,4 +1,5 @@ use ndarray::prelude::*; +use ndarray::{ShapeError, ErrorKind, arr3}; #[test] #[cfg(feature = "std")] @@ -81,3 +82,33 @@ fn test_broadcast_1d() { println!("b2=\n{:?}", b2); assert_eq!(b0, b2); } + +#[test] +fn test_broadcast_with() { + let a = arr2(&[[1., 2.], [3., 4.]]); + let b = aview0(&1.); + let (a1, b1) = a.broadcast_with(&b).unwrap(); + assert_eq!(a1, arr2(&[[1.0, 2.0], [3.0, 4.0]])); + assert_eq!(b1, arr2(&[[1.0, 1.0], [1.0, 1.0]])); + + let a = arr2(&[[2], [3], [4]]); + let b = arr1(&[5, 6, 7]); + let (a1, b1) = a.broadcast_with(&b).unwrap(); + assert_eq!(a1, arr2(&[[2, 2, 2], [3, 3, 3], [4, 4, 4]])); + assert_eq!(b1, arr2(&[[5, 6, 7], [5, 6, 7], [5, 6, 7]])); + + // Negative strides and non-contiguous memory + let s = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]; + let s = Array3::from_shape_vec((2, 3, 2).strides((1, 4, 2)), s.to_vec()).unwrap(); + let a = s.slice(s![..;-1,..;2,..]); + let b = s.slice(s![..2, -1, ..]); + let (a1, b1) = a.broadcast_with(&b).unwrap(); + assert_eq!(a1, arr3(&[[[2, 4], [10, 12]], [[1, 3], [9, 11]]])); + assert_eq!(b1, arr3(&[[[9, 11], [10, 12]], [[9, 11], [10, 12]]])); + + // ShapeError + let a = arr2(&[[2, 2], [3, 3], [4, 4]]); + let b = arr1(&[5, 6, 7]); + let e = a.broadcast_with(&b); + assert_eq!(e, Err(ShapeError::from_kind(ErrorKind::IncompatibleShape))); +} From c24e602be41aaabe1e3c457ed0ce23d9754cbca5 Mon Sep 17 00:00:00 2001 From: SparrowLii Date: Fri, 19 Feb 2021 09:08:16 +0800 Subject: [PATCH 12/16] Modify the docs and visibility of broadcast_with --- src/impl_methods.rs | 15 +++------------ tests/broadcast.rs | 30 ------------------------------ 2 files changed, 3 insertions(+), 42 deletions(-) diff --git a/src/impl_methods.rs b/src/impl_methods.rs index 4202ec257..c9b2d6f49 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -1766,20 +1766,11 @@ where unsafe { Some(ArrayView::new(self.ptr, dim, broadcast_strides)) } } - /// Calculate the views of two ArrayBases after broadcasting each other, if possible. + /// For two arrays or views, find their common shape if possible and + /// broadcast them as array views into that shape. /// /// Return `ShapeError` if their shapes can not be broadcast together. - /// - /// ``` - /// use ndarray::{arr1, arr2}; - /// - /// let a = arr2(&[[2], [3], [4]]); - /// let b = arr1(&[5, 6, 7]); - /// let (a1, b1) = a.broadcast_with(&b).unwrap(); - /// assert_eq!(a1, arr2(&[[2, 2, 2], [3, 3, 3], [4, 4, 4]])); - /// assert_eq!(b1, arr2(&[[5, 6, 7], [5, 6, 7], [5, 6, 7]])); - /// ``` - pub fn broadcast_with<'a, 'b, B, S2, E>(&'a self, other: &'b ArrayBase) -> + pub(crate) fn broadcast_with<'a, 'b, B, S2, E>(&'a self, other: &'b ArrayBase) -> Result<(ArrayView<'a, A, >::Output>, ArrayView<'b, B, >::Output>), ShapeError> where S: Data, diff --git a/tests/broadcast.rs b/tests/broadcast.rs index 26111c780..e3d377139 100644 --- a/tests/broadcast.rs +++ b/tests/broadcast.rs @@ -82,33 +82,3 @@ fn test_broadcast_1d() { println!("b2=\n{:?}", b2); assert_eq!(b0, b2); } - -#[test] -fn test_broadcast_with() { - let a = arr2(&[[1., 2.], [3., 4.]]); - let b = aview0(&1.); - let (a1, b1) = a.broadcast_with(&b).unwrap(); - assert_eq!(a1, arr2(&[[1.0, 2.0], [3.0, 4.0]])); - assert_eq!(b1, arr2(&[[1.0, 1.0], [1.0, 1.0]])); - - let a = arr2(&[[2], [3], [4]]); - let b = arr1(&[5, 6, 7]); - let (a1, b1) = a.broadcast_with(&b).unwrap(); - assert_eq!(a1, arr2(&[[2, 2, 2], [3, 3, 3], [4, 4, 4]])); - assert_eq!(b1, arr2(&[[5, 6, 7], [5, 6, 7], [5, 6, 7]])); - - // Negative strides and non-contiguous memory - let s = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]; - let s = Array3::from_shape_vec((2, 3, 2).strides((1, 4, 2)), s.to_vec()).unwrap(); - let a = s.slice(s![..;-1,..;2,..]); - let b = s.slice(s![..2, -1, ..]); - let (a1, b1) = a.broadcast_with(&b).unwrap(); - assert_eq!(a1, arr3(&[[[2, 4], [10, 12]], [[1, 3], [9, 11]]])); - assert_eq!(b1, arr3(&[[[9, 11], [10, 12]], [[9, 11], [10, 12]]])); - - // ShapeError - let a = arr2(&[[2, 2], [3, 3], [4, 4]]); - let b = arr1(&[5, 6, 7]); - let e = a.broadcast_with(&b); - assert_eq!(e, Err(ShapeError::from_kind(ErrorKind::IncompatibleShape))); -} From 16f382b1eeacf1ea7072bd5ea410bc8f44e6ae2e Mon Sep 17 00:00:00 2001 From: bluss Date: Fri, 12 Mar 2021 10:37:31 +0100 Subject: [PATCH 13/16] TEST: Remove unused imports in tests/broadcast.rs --- tests/broadcast.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/broadcast.rs b/tests/broadcast.rs index e3d377139..5416e9017 100644 --- a/tests/broadcast.rs +++ b/tests/broadcast.rs @@ -1,5 +1,4 @@ use ndarray::prelude::*; -use ndarray::{ShapeError, ErrorKind, arr3}; #[test] #[cfg(feature = "std")] From b39593ebc2c2a22c66f58e82fd84f2857c458508 Mon Sep 17 00:00:00 2001 From: bluss Date: Fri, 12 Mar 2021 10:37:49 +0100 Subject: [PATCH 14/16] FIX: Rename BroadcastShape to DimMax For consistency with other dimension traits (to come); with >, Output is the maximum of A and B. --- src/dimension/broadcast.rs | 10 +++++----- src/dimension/dimension_trait.rs | 12 ++++++------ src/dimension/mod.rs | 2 +- src/impl_methods.rs | 6 +++--- src/impl_ops.rs | 22 +++++++++++----------- src/lib.rs | 2 +- tests/dimension.rs | 6 +++--- 7 files changed, 30 insertions(+), 30 deletions(-) diff --git a/src/dimension/broadcast.rs b/src/dimension/broadcast.rs index 377c8509e..6da4375d0 100644 --- a/src/dimension/broadcast.rs +++ b/src/dimension/broadcast.rs @@ -34,7 +34,7 @@ fn co_broadcast(shape1: &D1, shape2: &D2) -> Result { +pub trait DimMax { /// The resulting dimension type after broadcasting. type Output: Dimension; @@ -46,8 +46,8 @@ pub trait BroadcastShape { /// Dimensions of the same type remain unchanged when co_broadcast. /// So you can directly use D as the resulting type. -/// (Instead of >::BroadcastOutput) -impl BroadcastShape for D { +/// (Instead of >::BroadcastOutput) +impl DimMax for D { type Output = D; fn broadcast_shape(&self, other: &D) -> Result { @@ -57,7 +57,7 @@ impl BroadcastShape for D { macro_rules! impl_broadcast_distinct_fixed { ($smaller:ty, $larger:ty) => { - impl BroadcastShape<$larger> for $smaller { + impl DimMax<$larger> for $smaller { type Output = $larger; fn broadcast_shape(&self, other: &$larger) -> Result { @@ -65,7 +65,7 @@ macro_rules! impl_broadcast_distinct_fixed { } } - impl BroadcastShape<$smaller> for $larger { + impl DimMax<$smaller> for $larger { type Output = $larger; fn broadcast_shape(&self, other: &$smaller) -> Result { diff --git a/src/dimension/dimension_trait.rs b/src/dimension/dimension_trait.rs index e3a1ebd09..6007f93ab 100644 --- a/src/dimension/dimension_trait.rs +++ b/src/dimension/dimension_trait.rs @@ -15,7 +15,7 @@ use super::axes_of; use super::conversion::Convert; use super::{stride_offset, stride_offset_checked}; use crate::itertools::{enumerate, zip}; -use crate::{Axis, BroadcastShape}; +use crate::{Axis, DimMax}; use crate::IntoDimension; use crate::RemoveAxis; use crate::{ArrayView1, ArrayViewMut1}; @@ -46,11 +46,11 @@ pub trait Dimension: + MulAssign + for<'x> MulAssign<&'x Self> + MulAssign - + BroadcastShape - + BroadcastShape - + BroadcastShape - + BroadcastShape<::Smaller, Output=Self> - + BroadcastShape<::Larger, Output=::Larger> + + DimMax + + DimMax + + DimMax + + DimMax<::Smaller, Output=Self> + + DimMax<::Larger, Output=::Larger> { /// For fixed-size dimension representations (e.g. `Ix2`), this should be /// `Some(ndim)`, and for variable-size dimension representations (e.g. diff --git a/src/dimension/mod.rs b/src/dimension/mod.rs index 98572ac59..7f5eeeaf7 100644 --- a/src/dimension/mod.rs +++ b/src/dimension/mod.rs @@ -12,7 +12,7 @@ use num_integer::div_floor; pub use self::axes::{axes_of, Axes, AxisDescription}; pub use self::axis::Axis; -pub use self::broadcast::BroadcastShape; +pub use self::broadcast::DimMax; pub use self::conversion::IntoDimension; pub use self::dim::*; pub use self::dimension_trait::Dimension; diff --git a/src/impl_methods.rs b/src/impl_methods.rs index c9b2d6f49..3079c9d2a 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -14,7 +14,7 @@ use rawpointer::PointerExt; use crate::imp_prelude::*; -use crate::{arraytraits, BroadcastShape}; +use crate::{arraytraits, DimMax}; use crate::dimension; use crate::dimension::IntoDimension; use crate::dimension::{ @@ -1771,11 +1771,11 @@ where /// /// Return `ShapeError` if their shapes can not be broadcast together. pub(crate) fn broadcast_with<'a, 'b, B, S2, E>(&'a self, other: &'b ArrayBase) -> - Result<(ArrayView<'a, A, >::Output>, ArrayView<'b, B, >::Output>), ShapeError> + Result<(ArrayView<'a, A, >::Output>, ArrayView<'b, B, >::Output>), ShapeError> where S: Data, S2: Data, - D: Dimension + BroadcastShape, + D: Dimension + DimMax, E: Dimension, { let shape = self.dim.broadcast_shape(&other.dim)?; diff --git a/src/impl_ops.rs b/src/impl_ops.rs index 80aa8e11a..d38cb566a 100644 --- a/src/impl_ops.rs +++ b/src/impl_ops.rs @@ -6,7 +6,7 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -use crate::dimension::BroadcastShape; +use crate::dimension::DimMax; use crate::Zip; use num_complex::Complex; @@ -68,10 +68,10 @@ where B: Clone, S: DataOwned + DataMut, S2: Data, - D: Dimension + BroadcastShape, + D: Dimension + DimMax, E: Dimension, { - type Output = ArrayBase>::Output>; + type Output = ArrayBase>::Output>; fn $mth(self, rhs: ArrayBase) -> Self::Output { self.$mth(&rhs) @@ -95,14 +95,14 @@ where B: Clone, S: DataOwned + DataMut, S2: Data, - D: Dimension + BroadcastShape, + D: Dimension + DimMax, E: Dimension, { - type Output = ArrayBase>::Output>; + type Output = ArrayBase>::Output>; fn $mth(self, rhs: &ArrayBase) -> Self::Output { if self.ndim() == rhs.ndim() && self.shape() == rhs.shape() { - let mut out = self.into_dimensionality::<>::Output>().unwrap(); + let mut out = self.into_dimensionality::<>::Output>().unwrap(); out.zip_mut_with_same_shape(rhs, clone_iopf(A::$mth)); out } else { @@ -130,14 +130,14 @@ where S: Data, S2: DataOwned + DataMut, D: Dimension, - E: Dimension + BroadcastShape, + E: Dimension + DimMax, { - type Output = ArrayBase>::Output>; + type Output = ArrayBase>::Output>; fn $mth(self, rhs: ArrayBase) -> Self::Output where { if self.ndim() == rhs.ndim() && self.shape() == rhs.shape() { - let mut out = rhs.into_dimensionality::<>::Output>().unwrap(); + let mut out = rhs.into_dimensionality::<>::Output>().unwrap(); out.zip_mut_with_same_shape(self, clone_iopf_rev(A::$mth)); out } else { @@ -162,10 +162,10 @@ where B: Clone, S: Data, S2: Data, - D: Dimension + BroadcastShape, + D: Dimension + DimMax, E: Dimension, { - type Output = Array>::Output>; + type Output = Array>::Output>; fn $mth(self, rhs: &'a ArrayBase) -> Self::Output { let (lhs, rhs) = self.broadcast_with(rhs).unwrap(); Zip::from(&lhs).and(&rhs).map_collect(clone_opf(A::$mth)) diff --git a/src/lib.rs b/src/lib.rs index 1a760bf31..c079b4817 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -134,7 +134,7 @@ use std::marker::PhantomData; use alloc::sync::Arc; pub use crate::dimension::dim::*; -pub use crate::dimension::BroadcastShape; +pub use crate::dimension::DimMax; pub use crate::dimension::{Axis, AxisDescription, Dimension, IntoDimension, RemoveAxis}; pub use crate::dimension::IxDynImpl; diff --git a/tests/dimension.rs b/tests/dimension.rs index ede0dc7d2..2bbc50a68 100644 --- a/tests/dimension.rs +++ b/tests/dimension.rs @@ -3,7 +3,7 @@ use defmac::defmac; use ndarray::{arr2, ArcArray, Array, Axis, Dim, Dimension, Ix0, IxDyn, IxDynImpl, RemoveAxis, - ErrorKind, ShapeError, BroadcastShape}; + ErrorKind, ShapeError, DimMax}; use std::hash::{Hash, Hasher}; @@ -347,9 +347,9 @@ fn test_broadcast_shape() { fn test_co( d1: &D1, d2: &D2, - r: Result<>::Output, ShapeError>, + r: Result<>::Output, ShapeError>, ) where - D1: Dimension + BroadcastShape, + D1: Dimension + DimMax, D2: Dimension, { let d = d1.broadcast_shape(d2); From 38f7341c692f00297146b3d7a28e32ba3e097494 Mon Sep 17 00:00:00 2001 From: bluss Date: Fri, 12 Mar 2021 10:47:47 +0100 Subject: [PATCH 15/16] FIX: Remove broadcast_shape from the DimMax trait While calling co_broadcast directly is less convenient, for now they are two different functions. --- src/dimension/broadcast.rs | 82 ++++++++++++++++++++++++++++---------- src/dimension/mod.rs | 2 +- src/impl_methods.rs | 3 +- tests/dimension.rs | 50 +---------------------- 4 files changed, 64 insertions(+), 73 deletions(-) diff --git a/src/dimension/broadcast.rs b/src/dimension/broadcast.rs index 6da4375d0..dc1513f04 100644 --- a/src/dimension/broadcast.rs +++ b/src/dimension/broadcast.rs @@ -6,11 +6,11 @@ use crate::{Dimension, Ix0, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn}; /// /// Uses the [NumPy broadcasting rules] // (https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html#general-broadcasting-rules). -fn co_broadcast(shape1: &D1, shape2: &D2) -> Result - where - D1: Dimension, - D2: Dimension, - Output: Dimension, +pub(crate) fn co_broadcast(shape1: &D1, shape2: &D2) -> Result +where + D1: Dimension, + D2: Dimension, + Output: Dimension, { let (k, overflow) = shape1.ndim().overflowing_sub(shape2.ndim()); // Swap the order if d2 is longer. @@ -37,11 +37,6 @@ fn co_broadcast(shape1: &D1, shape2: &D2) -> Result { /// The resulting dimension type after broadcasting. type Output: Dimension; - - /// Determines the shape after broadcasting the shapes together. - /// - /// If the shapes are not compatible, returns `Err`. - fn broadcast_shape(&self, other: &Other) -> Result; } /// Dimensions of the same type remain unchanged when co_broadcast. @@ -49,28 +44,16 @@ pub trait DimMax { /// (Instead of >::BroadcastOutput) impl DimMax for D { type Output = D; - - fn broadcast_shape(&self, other: &D) -> Result { - co_broadcast::(self, other) - } } macro_rules! impl_broadcast_distinct_fixed { ($smaller:ty, $larger:ty) => { impl DimMax<$larger> for $smaller { type Output = $larger; - - fn broadcast_shape(&self, other: &$larger) -> Result { - co_broadcast::(self, other) - } } impl DimMax<$smaller> for $larger { type Output = $larger; - - fn broadcast_shape(&self, other: &$smaller) -> Result { - co_broadcast::(self, other) - } } }; } @@ -103,3 +86,58 @@ impl_broadcast_distinct_fixed!(Ix3, IxDyn); impl_broadcast_distinct_fixed!(Ix4, IxDyn); impl_broadcast_distinct_fixed!(Ix5, IxDyn); impl_broadcast_distinct_fixed!(Ix6, IxDyn); + + +#[cfg(test)] +#[cfg(feature = "std")] +mod tests { + use super::co_broadcast; + use crate::{Dimension, Dim, DimMax, ShapeError, Ix0, IxDynImpl, ErrorKind}; + + #[test] + fn test_broadcast_shape() { + fn test_co( + d1: &D1, + d2: &D2, + r: Result<>::Output, ShapeError>, + ) where + D1: Dimension + DimMax, + D2: Dimension, + { + let d = co_broadcast::>::Output>(&d1, d2); + assert_eq!(d, r); + } + test_co(&Dim([2, 3]), &Dim([4, 1, 3]), Ok(Dim([4, 2, 3]))); + test_co( + &Dim([1, 2, 2]), + &Dim([1, 3, 4]), + Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)), + ); + test_co(&Dim([3, 4, 5]), &Ix0(), Ok(Dim([3, 4, 5]))); + let v = vec![1, 2, 3, 4, 5, 6, 7]; + test_co( + &Dim(vec![1, 1, 3, 1, 5, 1, 7]), + &Dim([2, 1, 4, 1, 6, 1]), + Ok(Dim(IxDynImpl::from(v.as_slice()))), + ); + let d = Dim([1, 2, 1, 3]); + test_co(&d, &d, Ok(d)); + test_co( + &Dim([2, 1, 2]).into_dyn(), + &Dim(0), + Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)), + ); + test_co( + &Dim([2, 1, 1]), + &Dim([0, 0, 1, 3, 4]), + Ok(Dim([0, 0, 2, 3, 4])), + ); + test_co(&Dim([0]), &Dim([0, 0, 0]), Ok(Dim([0, 0, 0]))); + test_co(&Dim(1), &Dim([1, 0, 0]), Ok(Dim([1, 0, 0]))); + test_co( + &Dim([1, 3, 0, 1, 1]), + &Dim([1, 2, 3, 1]), + Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)), + ); + } +} diff --git a/src/dimension/mod.rs b/src/dimension/mod.rs index 7f5eeeaf7..2505681b5 100644 --- a/src/dimension/mod.rs +++ b/src/dimension/mod.rs @@ -29,7 +29,7 @@ use std::mem; mod macros; mod axes; mod axis; -mod broadcast; +pub(crate) mod broadcast; mod conversion; pub mod dim; mod dimension_trait; diff --git a/src/impl_methods.rs b/src/impl_methods.rs index 3079c9d2a..40c7fe1f2 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -21,6 +21,7 @@ use crate::dimension::{ abs_index, axes_of, do_slice, merge_axes, move_min_stride_axis_to_last, offset_from_ptr_to_memory, size_of_shape_checked, stride_offset, Axes, }; +use crate::dimension::broadcast::co_broadcast; use crate::error::{self, ErrorKind, ShapeError, from_kind}; use crate::math_cell::MathCell; use crate::itertools::zip; @@ -1778,7 +1779,7 @@ where D: Dimension + DimMax, E: Dimension, { - let shape = self.dim.broadcast_shape(&other.dim)?; + let shape = co_broadcast::>::Output>(&self.dim, &other.dim)?; if let Some(view1) = self.broadcast(shape.clone()) { if let Some(view2) = other.broadcast(shape) { return Ok((view1, view2)) diff --git a/tests/dimension.rs b/tests/dimension.rs index 2bbc50a68..939b4f0e3 100644 --- a/tests/dimension.rs +++ b/tests/dimension.rs @@ -2,8 +2,7 @@ use defmac::defmac; -use ndarray::{arr2, ArcArray, Array, Axis, Dim, Dimension, Ix0, IxDyn, IxDynImpl, RemoveAxis, - ErrorKind, ShapeError, DimMax}; +use ndarray::{arr2, ArcArray, Array, Axis, Dim, Dimension, IxDyn, RemoveAxis}; use std::hash::{Hash, Hasher}; @@ -341,50 +340,3 @@ fn test_all_ndindex() { ndindex!(10, 4, 3, 2, 2); ndindex!(10, 4, 3, 2, 2, 2); } - -#[test] -fn test_broadcast_shape() { - fn test_co( - d1: &D1, - d2: &D2, - r: Result<>::Output, ShapeError>, - ) where - D1: Dimension + DimMax, - D2: Dimension, - { - let d = d1.broadcast_shape(d2); - assert_eq!(d, r); - } - test_co(&Dim([2, 3]), &Dim([4, 1, 3]), Ok(Dim([4, 2, 3]))); - test_co( - &Dim([1, 2, 2]), - &Dim([1, 3, 4]), - Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)), - ); - test_co(&Dim([3, 4, 5]), &Ix0(), Ok(Dim([3, 4, 5]))); - let v = vec![1, 2, 3, 4, 5, 6, 7]; - test_co( - &Dim(vec![1, 1, 3, 1, 5, 1, 7]), - &Dim([2, 1, 4, 1, 6, 1]), - Ok(Dim(IxDynImpl::from(v.as_slice()))), - ); - let d = Dim([1, 2, 1, 3]); - test_co(&d, &d, Ok(d)); - test_co( - &Dim([2, 1, 2]).into_dyn(), - &Dim(0), - Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)), - ); - test_co( - &Dim([2, 1, 1]), - &Dim([0, 0, 1, 3, 4]), - Ok(Dim([0, 0, 2, 3, 4])), - ); - test_co(&Dim([0]), &Dim([0, 0, 0]), Ok(Dim([0, 0, 0]))); - test_co(&Dim(1), &Dim([1, 0, 0]), Ok(Dim([1, 0, 0]))); - test_co( - &Dim([1, 3, 0, 1, 1]), - &Dim([1, 2, 3, 1]), - Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)), - ); -} From 03cfdfc445e5672ae47cca34012a3cfbda219db3 Mon Sep 17 00:00:00 2001 From: bluss Date: Fri, 12 Mar 2021 11:50:11 +0100 Subject: [PATCH 16/16] MAINT: Fix clippy warnings for broadcast_with --- src/impl_methods.rs | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/impl_methods.rs b/src/impl_methods.rs index 40c7fe1f2..958fc3f1c 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -1771,8 +1771,9 @@ where /// broadcast them as array views into that shape. /// /// Return `ShapeError` if their shapes can not be broadcast together. + #[allow(clippy::type_complexity)] pub(crate) fn broadcast_with<'a, 'b, B, S2, E>(&'a self, other: &'b ArrayBase) -> - Result<(ArrayView<'a, A, >::Output>, ArrayView<'b, B, >::Output>), ShapeError> + Result<(ArrayView<'a, A, DimMaxOf>, ArrayView<'b, B, DimMaxOf>), ShapeError> where S: Data, S2: Data, @@ -1782,10 +1783,10 @@ where let shape = co_broadcast::>::Output>(&self.dim, &other.dim)?; if let Some(view1) = self.broadcast(shape.clone()) { if let Some(view2) = other.broadcast(shape) { - return Ok((view1, view2)) + return Ok((view1, view2)); } } - return Err(from_kind(ErrorKind::IncompatibleShape)); + Err(from_kind(ErrorKind::IncompatibleShape)) } /// Swap axes `ax` and `bx`. @@ -2465,3 +2466,5 @@ unsafe fn unlimited_transmute(data: A) -> B { let old_data = ManuallyDrop::new(data); (&*old_data as *const A as *const B).read() } + +type DimMaxOf = >::Output;