diff --git a/src/data_traits.rs b/src/data_traits.rs index 1e191c468..7ac63d54e 100644 --- a/src/data_traits.rs +++ b/src/data_traits.rs @@ -16,9 +16,7 @@ use std::ptr::NonNull; use alloc::sync::Arc; use alloc::vec::Vec; -use crate::{ - ArrayBase, CowRepr, Dimension, OwnedArcRepr, OwnedRepr, RawViewRepr, ViewRepr, -}; +use crate::{ArrayBase, CowRepr, Dimension, OwnedArcRepr, OwnedRepr, RawViewRepr, ViewRepr}; /// Array representation trait. /// @@ -414,7 +412,6 @@ pub unsafe trait DataOwned: Data { /// Corresponding owned data with MaybeUninit elements type MaybeUninit: DataOwned> + RawDataSubst; - #[doc(hidden)] fn new(elements: Vec) -> Self; @@ -440,6 +437,7 @@ unsafe impl DataOwned for OwnedRepr { fn new(elements: Vec) -> Self { OwnedRepr::from(elements) } + fn into_shared(self) -> OwnedArcRepr { OwnedArcRepr(Arc::new(self)) } @@ -622,3 +620,4 @@ impl<'a, A: 'a, B: 'a> RawDataSubst for ViewRepr<&'a mut A> { ViewRepr::new() } } + diff --git a/src/dimension/broadcast.rs b/src/dimension/broadcast.rs new file mode 100644 index 000000000..dc1513f04 --- /dev/null +++ b/src/dimension/broadcast.rs @@ -0,0 +1,143 @@ +use crate::error::*; +use crate::{Dimension, Ix0, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn}; + +/// Calculate the common shape for a pair of array shapes, that they can be broadcasted +/// to. Return an error if the shapes are not compatible. +/// +/// Uses the [NumPy broadcasting rules] +// (https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html#general-broadcasting-rules). +pub(crate) fn co_broadcast(shape1: &D1, shape2: &D2) -> Result +where + D1: Dimension, + D2: Dimension, + Output: Dimension, +{ + let (k, overflow) = shape1.ndim().overflowing_sub(shape2.ndim()); + // Swap the order if d2 is longer. + if overflow { + return co_broadcast::(shape2, shape1); + } + // The output should be the same length as shape1. + let mut out = Output::zeros(shape1.ndim()); + for (out, s) in izip!(out.slice_mut(), shape1.slice()) { + *out = *s; + } + for (out, s2) in izip!(&mut out.slice_mut()[k..], shape2.slice()) { + if *out != *s2 { + if *out == 1 { + *out = *s2 + } else if *s2 != 1 { + return Err(from_kind(ErrorKind::IncompatibleShape)); + } + } + } + Ok(out) +} + +pub trait DimMax { + /// The resulting dimension type after broadcasting. + type Output: Dimension; +} + +/// Dimensions of the same type remain unchanged when co_broadcast. +/// So you can directly use D as the resulting type. +/// (Instead of >::BroadcastOutput) +impl DimMax for D { + type Output = D; +} + +macro_rules! impl_broadcast_distinct_fixed { + ($smaller:ty, $larger:ty) => { + impl DimMax<$larger> for $smaller { + type Output = $larger; + } + + impl DimMax<$smaller> for $larger { + type Output = $larger; + } + }; +} + +impl_broadcast_distinct_fixed!(Ix0, Ix1); +impl_broadcast_distinct_fixed!(Ix0, Ix2); +impl_broadcast_distinct_fixed!(Ix0, Ix3); +impl_broadcast_distinct_fixed!(Ix0, Ix4); +impl_broadcast_distinct_fixed!(Ix0, Ix5); +impl_broadcast_distinct_fixed!(Ix0, Ix6); +impl_broadcast_distinct_fixed!(Ix1, Ix2); +impl_broadcast_distinct_fixed!(Ix1, Ix3); +impl_broadcast_distinct_fixed!(Ix1, Ix4); +impl_broadcast_distinct_fixed!(Ix1, Ix5); +impl_broadcast_distinct_fixed!(Ix1, Ix6); +impl_broadcast_distinct_fixed!(Ix2, Ix3); +impl_broadcast_distinct_fixed!(Ix2, Ix4); +impl_broadcast_distinct_fixed!(Ix2, Ix5); +impl_broadcast_distinct_fixed!(Ix2, Ix6); +impl_broadcast_distinct_fixed!(Ix3, Ix4); +impl_broadcast_distinct_fixed!(Ix3, Ix5); +impl_broadcast_distinct_fixed!(Ix3, Ix6); +impl_broadcast_distinct_fixed!(Ix4, Ix5); +impl_broadcast_distinct_fixed!(Ix4, Ix6); +impl_broadcast_distinct_fixed!(Ix5, Ix6); +impl_broadcast_distinct_fixed!(Ix0, IxDyn); +impl_broadcast_distinct_fixed!(Ix1, IxDyn); +impl_broadcast_distinct_fixed!(Ix2, IxDyn); +impl_broadcast_distinct_fixed!(Ix3, IxDyn); +impl_broadcast_distinct_fixed!(Ix4, IxDyn); +impl_broadcast_distinct_fixed!(Ix5, IxDyn); +impl_broadcast_distinct_fixed!(Ix6, IxDyn); + + +#[cfg(test)] +#[cfg(feature = "std")] +mod tests { + use super::co_broadcast; + use crate::{Dimension, Dim, DimMax, ShapeError, Ix0, IxDynImpl, ErrorKind}; + + #[test] + fn test_broadcast_shape() { + fn test_co( + d1: &D1, + d2: &D2, + r: Result<>::Output, ShapeError>, + ) where + D1: Dimension + DimMax, + D2: Dimension, + { + let d = co_broadcast::>::Output>(&d1, d2); + assert_eq!(d, r); + } + test_co(&Dim([2, 3]), &Dim([4, 1, 3]), Ok(Dim([4, 2, 3]))); + test_co( + &Dim([1, 2, 2]), + &Dim([1, 3, 4]), + Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)), + ); + test_co(&Dim([3, 4, 5]), &Ix0(), Ok(Dim([3, 4, 5]))); + let v = vec![1, 2, 3, 4, 5, 6, 7]; + test_co( + &Dim(vec![1, 1, 3, 1, 5, 1, 7]), + &Dim([2, 1, 4, 1, 6, 1]), + Ok(Dim(IxDynImpl::from(v.as_slice()))), + ); + let d = Dim([1, 2, 1, 3]); + test_co(&d, &d, Ok(d)); + test_co( + &Dim([2, 1, 2]).into_dyn(), + &Dim(0), + Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)), + ); + test_co( + &Dim([2, 1, 1]), + &Dim([0, 0, 1, 3, 4]), + Ok(Dim([0, 0, 2, 3, 4])), + ); + test_co(&Dim([0]), &Dim([0, 0, 0]), Ok(Dim([0, 0, 0]))); + test_co(&Dim(1), &Dim([1, 0, 0]), Ok(Dim([1, 0, 0]))); + test_co( + &Dim([1, 3, 0, 1, 1]), + &Dim([1, 2, 3, 1]), + Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)), + ); + } +} diff --git a/src/dimension/dimension_trait.rs b/src/dimension/dimension_trait.rs index 92f241189..6007f93ab 100644 --- a/src/dimension/dimension_trait.rs +++ b/src/dimension/dimension_trait.rs @@ -15,7 +15,7 @@ use super::axes_of; use super::conversion::Convert; use super::{stride_offset, stride_offset_checked}; use crate::itertools::{enumerate, zip}; -use crate::Axis; +use crate::{Axis, DimMax}; use crate::IntoDimension; use crate::RemoveAxis; use crate::{ArrayView1, ArrayViewMut1}; @@ -46,6 +46,11 @@ pub trait Dimension: + MulAssign + for<'x> MulAssign<&'x Self> + MulAssign + + DimMax + + DimMax + + DimMax + + DimMax<::Smaller, Output=Self> + + DimMax<::Larger, Output=::Larger> { /// For fixed-size dimension representations (e.g. `Ix2`), this should be /// `Some(ndim)`, and for variable-size dimension representations (e.g. diff --git a/src/dimension/mod.rs b/src/dimension/mod.rs index 1359b8f39..2505681b5 100644 --- a/src/dimension/mod.rs +++ b/src/dimension/mod.rs @@ -12,6 +12,7 @@ use num_integer::div_floor; pub use self::axes::{axes_of, Axes, AxisDescription}; pub use self::axis::Axis; +pub use self::broadcast::DimMax; pub use self::conversion::IntoDimension; pub use self::dim::*; pub use self::dimension_trait::Dimension; @@ -28,6 +29,7 @@ use std::mem; mod macros; mod axes; mod axis; +pub(crate) mod broadcast; mod conversion; pub mod dim; mod dimension_trait; diff --git a/src/impl_methods.rs b/src/impl_methods.rs index 9943e7c8e..958fc3f1c 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -14,14 +14,15 @@ use rawpointer::PointerExt; use crate::imp_prelude::*; -use crate::arraytraits; +use crate::{arraytraits, DimMax}; use crate::dimension; use crate::dimension::IntoDimension; use crate::dimension::{ abs_index, axes_of, do_slice, merge_axes, move_min_stride_axis_to_last, offset_from_ptr_to_memory, size_of_shape_checked, stride_offset, Axes, }; -use crate::error::{self, ErrorKind, ShapeError}; +use crate::dimension::broadcast::co_broadcast; +use crate::error::{self, ErrorKind, ShapeError, from_kind}; use crate::math_cell::MathCell; use crate::itertools::zip; use crate::zip::Zip; @@ -1766,6 +1767,28 @@ where unsafe { Some(ArrayView::new(self.ptr, dim, broadcast_strides)) } } + /// For two arrays or views, find their common shape if possible and + /// broadcast them as array views into that shape. + /// + /// Return `ShapeError` if their shapes can not be broadcast together. + #[allow(clippy::type_complexity)] + pub(crate) fn broadcast_with<'a, 'b, B, S2, E>(&'a self, other: &'b ArrayBase) -> + Result<(ArrayView<'a, A, DimMaxOf>, ArrayView<'b, B, DimMaxOf>), ShapeError> + where + S: Data, + S2: Data, + D: Dimension + DimMax, + E: Dimension, + { + let shape = co_broadcast::>::Output>(&self.dim, &other.dim)?; + if let Some(view1) = self.broadcast(shape.clone()) { + if let Some(view2) = other.broadcast(shape) { + return Ok((view1, view2)); + } + } + Err(from_kind(ErrorKind::IncompatibleShape)) + } + /// Swap axes `ax` and `bx`. /// /// This does not move any data, it just adjusts the array’s dimensions @@ -2013,7 +2036,7 @@ where self.map_inplace(move |elt| *elt = x.clone()); } - fn zip_mut_with_same_shape(&mut self, rhs: &ArrayBase, mut f: F) + pub(crate) fn zip_mut_with_same_shape(&mut self, rhs: &ArrayBase, mut f: F) where S: DataMut, S2: Data, @@ -2443,3 +2466,5 @@ unsafe fn unlimited_transmute(data: A) -> B { let old_data = ManuallyDrop::new(data); (&*old_data as *const A as *const B).read() } + +type DimMaxOf = >::Output; diff --git a/src/impl_ops.rs b/src/impl_ops.rs index 256bee3e5..d38cb566a 100644 --- a/src/impl_ops.rs +++ b/src/impl_ops.rs @@ -6,6 +6,8 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. +use crate::dimension::DimMax; +use crate::Zip; use num_complex::Complex; /// Elements that can be used as direct operands in arithmetic with arrays. @@ -53,11 +55,11 @@ macro_rules! impl_binary_op( /// Perform elementwise #[doc=$doc] /// between `self` and `rhs`, -/// and return the result (based on `self`). +/// and return the result. /// /// `self` must be an `Array` or `ArcArray`. /// -/// If their shapes disagree, `rhs` is broadcast to the shape of `self`. +/// If their shapes disagree, `self` is broadcast to their broadcast shape. /// /// **Panics** if broadcasting isn’t possible. impl $trt> for ArrayBase @@ -66,11 +68,11 @@ where B: Clone, S: DataOwned + DataMut, S2: Data, - D: Dimension, + D: Dimension + DimMax, E: Dimension, { - type Output = ArrayBase; - fn $mth(self, rhs: ArrayBase) -> ArrayBase + type Output = ArrayBase>::Output>; + fn $mth(self, rhs: ArrayBase) -> Self::Output { self.$mth(&rhs) } @@ -79,9 +81,12 @@ where /// Perform elementwise #[doc=$doc] /// between `self` and reference `rhs`, -/// and return the result (based on `self`). +/// and return the result. +/// +/// `rhs` must be an `Array` or `ArcArray`. /// -/// If their shapes disagree, `rhs` is broadcast to the shape of `self`. +/// If their shapes disagree, `self` is broadcast to their broadcast shape, +/// cloning the data if needed. /// /// **Panics** if broadcasting isn’t possible. impl<'a, A, B, S, S2, D, E> $trt<&'a ArrayBase> for ArrayBase @@ -90,16 +95,55 @@ where B: Clone, S: DataOwned + DataMut, S2: Data, - D: Dimension, + D: Dimension + DimMax, E: Dimension, { - type Output = ArrayBase; - fn $mth(mut self, rhs: &ArrayBase) -> ArrayBase + type Output = ArrayBase>::Output>; + fn $mth(self, rhs: &ArrayBase) -> Self::Output { - self.zip_mut_with(rhs, |x, y| { - *x = x.clone() $operator y.clone(); - }); - self + if self.ndim() == rhs.ndim() && self.shape() == rhs.shape() { + let mut out = self.into_dimensionality::<>::Output>().unwrap(); + out.zip_mut_with_same_shape(rhs, clone_iopf(A::$mth)); + out + } else { + let (lhs, rhs) = self.broadcast_with(rhs).unwrap(); + Zip::from(&lhs).and(&rhs).map_collect_owned(clone_opf(A::$mth)) + } + } +} + +/// Perform elementwise +#[doc=$doc] +/// between reference `self` and `rhs`, +/// and return the result. +/// +/// `rhs` must be an `Array` or `ArcArray`. +/// +/// If their shapes disagree, `self` is broadcast to their broadcast shape, +/// cloning the data if needed. +/// +/// **Panics** if broadcasting isn’t possible. +impl<'a, A, B, S, S2, D, E> $trt> for &'a ArrayBase +where + A: Clone + $trt, + B: Clone, + S: Data, + S2: DataOwned + DataMut, + D: Dimension, + E: Dimension + DimMax, +{ + type Output = ArrayBase>::Output>; + fn $mth(self, rhs: ArrayBase) -> Self::Output + where + { + if self.ndim() == rhs.ndim() && self.shape() == rhs.shape() { + let mut out = rhs.into_dimensionality::<>::Output>().unwrap(); + out.zip_mut_with_same_shape(self, clone_iopf_rev(A::$mth)); + out + } else { + let (rhs, lhs) = rhs.broadcast_with(self).unwrap(); + Zip::from(&lhs).and(&rhs).map_collect_owned(clone_opf(A::$mth)) + } } } @@ -108,7 +152,8 @@ where /// between references `self` and `rhs`, /// and return the result as a new `Array`. /// -/// If their shapes disagree, `rhs` is broadcast to the shape of `self`. +/// If their shapes disagree, `self` and `rhs` is broadcast to their broadcast shape, +/// cloning the data if needed. /// /// **Panics** if broadcasting isn’t possible. impl<'a, A, B, S, S2, D, E> $trt<&'a ArrayBase> for &'a ArrayBase @@ -117,13 +162,13 @@ where B: Clone, S: Data, S2: Data, - D: Dimension, + D: Dimension + DimMax, E: Dimension, { - type Output = Array; - fn $mth(self, rhs: &'a ArrayBase) -> Array { - // FIXME: Can we co-broadcast arrays here? And how? - self.to_owned().$mth(rhs) + type Output = Array>::Output>; + fn $mth(self, rhs: &'a ArrayBase) -> Self::Output { + let (lhs, rhs) = self.broadcast_with(rhs).unwrap(); + Zip::from(&lhs).and(&rhs).map_collect(clone_opf(A::$mth)) } } @@ -228,6 +273,18 @@ mod arithmetic_ops { use num_complex::Complex; use std::ops::*; + fn clone_opf(f: impl Fn(A, B) -> C) -> impl FnMut(&A, &B) -> C { + move |x, y| f(x.clone(), y.clone()) + } + + fn clone_iopf(f: impl Fn(A, B) -> A) -> impl FnMut(&mut A, &B) { + move |x, y| *x = f(x.clone(), y.clone()) + } + + fn clone_iopf_rev(f: impl Fn(A, B) -> B) -> impl FnMut(&mut B, &A) { + move |x, y| *x = f(y.clone(), x.clone()) + } + impl_binary_op!(Add, +, add, +=, "addition"); impl_binary_op!(Sub, -, sub, -=, "subtraction"); impl_binary_op!(Mul, *, mul, *=, "multiplication"); diff --git a/src/lib.rs b/src/lib.rs index 9cd7dc3f3..c079b4817 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -134,6 +134,7 @@ use std::marker::PhantomData; use alloc::sync::Arc; pub use crate::dimension::dim::*; +pub use crate::dimension::DimMax; pub use crate::dimension::{Axis, AxisDescription, Dimension, IntoDimension, RemoveAxis}; pub use crate::dimension::IxDynImpl; diff --git a/tests/array.rs b/tests/array.rs index 6b72bb5c4..b0a28ca41 100644 --- a/tests/array.rs +++ b/tests/array.rs @@ -9,9 +9,9 @@ use defmac::defmac; use itertools::{enumerate, zip, Itertools}; -use ndarray::indices; use ndarray::prelude::*; use ndarray::{arr3, rcarr2}; +use ndarray::indices; use ndarray::{Slice, SliceInfo, SliceOrIndex}; macro_rules! assert_panics { @@ -1565,6 +1565,49 @@ fn arithmetic_broadcast() { a.swap_axes(0, 1); let b = a.clone() / aview0(&1.); assert_eq!(a, b); + + // reference + let a = arr2(&[[2], [3], [4]]); + let b = arr1(&[5, 6, 7]); + assert_eq!(&a + &b, arr2(&[[7, 8, 9], [8, 9, 10], [9, 10, 11]])); + assert_eq!( + a.clone() - &b, + arr2(&[[-3, -4, -5], [-2, -3, -4], [-1, -2, -3]]) + ); + assert_eq!( + a.clone() * b.clone(), + arr2(&[[10, 12, 14], [15, 18, 21], [20, 24, 28]]) + ); + assert_eq!(&b / a, arr2(&[[2, 3, 3], [1, 2, 2], [1, 1, 1]])); + + // Negative strides and non-contiguous memory + let s = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]; + let s = Array3::from_shape_vec((2, 3, 2).strides((1, 4, 2)), s.to_vec()).unwrap(); + let a = s.slice(s![..;-1,..;2,..]); + let b = s.slice(s![..2, -1, ..]); + let mut c = s.clone(); + c.collapse_axis(Axis(2), 1); + let c = c.slice(s![1,..;2,..]); + assert_eq!( + &a.to_owned() + &b, + arr3(&[[[11, 15], [20, 24]], [[10, 14], [19, 23]]]) + ); + assert_eq!( + &a + b.into_owned() + c, + arr3(&[[[15, 19], [32, 36]], [[14, 18], [31, 35]]]) + ); + + // shared array + let sa = a.to_shared(); + let sa2 = sa.to_shared(); + let sb = b.to_shared(); + let sb2 = sb.to_shared(); + let sc = c.to_shared(); + let sc2 = sc.into_shared(); + assert_eq!( + sa2 + &sb2 + sc2.into_owned(), + arr3(&[[[15, 19], [32, 36]], [[14, 18], [31, 35]]]) + ); } #[test]