Skip to content

Commit

Permalink
Merge pull request #898 from SparrowLii/co_broadcast
Browse files Browse the repository at this point in the history
Implement co-broadcasting in operator overloading
  • Loading branch information
bluss committed Mar 12, 2021
2 parents 5bd5891 + 03cfdfc commit b5687f8
Show file tree
Hide file tree
Showing 8 changed files with 304 additions and 29 deletions.
7 changes: 3 additions & 4 deletions src/data_traits.rs
Expand Up @@ -16,9 +16,7 @@ use std::ptr::NonNull;
use alloc::sync::Arc;
use alloc::vec::Vec;

use crate::{
ArrayBase, CowRepr, Dimension, OwnedArcRepr, OwnedRepr, RawViewRepr, ViewRepr,
};
use crate::{ArrayBase, CowRepr, Dimension, OwnedArcRepr, OwnedRepr, RawViewRepr, ViewRepr};

/// Array representation trait.
///
Expand Down Expand Up @@ -414,7 +412,6 @@ pub unsafe trait DataOwned: Data {
/// Corresponding owned data with MaybeUninit elements
type MaybeUninit: DataOwned<Elem = MaybeUninit<Self::Elem>>
+ RawDataSubst<Self::Elem, Output=Self>;

#[doc(hidden)]
fn new(elements: Vec<Self::Elem>) -> Self;

Expand All @@ -440,6 +437,7 @@ unsafe impl<A> DataOwned for OwnedRepr<A> {
fn new(elements: Vec<A>) -> Self {
OwnedRepr::from(elements)
}

fn into_shared(self) -> OwnedArcRepr<A> {
OwnedArcRepr(Arc::new(self))
}
Expand Down Expand Up @@ -622,3 +620,4 @@ impl<'a, A: 'a, B: 'a> RawDataSubst<B> for ViewRepr<&'a mut A> {
ViewRepr::new()
}
}

143 changes: 143 additions & 0 deletions src/dimension/broadcast.rs
@@ -0,0 +1,143 @@
use crate::error::*;
use crate::{Dimension, Ix0, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn};

/// Calculate the common shape for a pair of array shapes, that they can be broadcasted
/// to. Return an error if the shapes are not compatible.
///
/// Uses the [NumPy broadcasting rules]
// (https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html#general-broadcasting-rules).
pub(crate) fn co_broadcast<D1, D2, Output>(shape1: &D1, shape2: &D2) -> Result<Output, ShapeError>
where
D1: Dimension,
D2: Dimension,
Output: Dimension,
{
let (k, overflow) = shape1.ndim().overflowing_sub(shape2.ndim());
// Swap the order if d2 is longer.
if overflow {
return co_broadcast::<D2, D1, Output>(shape2, shape1);
}
// The output should be the same length as shape1.
let mut out = Output::zeros(shape1.ndim());
for (out, s) in izip!(out.slice_mut(), shape1.slice()) {
*out = *s;
}
for (out, s2) in izip!(&mut out.slice_mut()[k..], shape2.slice()) {
if *out != *s2 {
if *out == 1 {
*out = *s2
} else if *s2 != 1 {
return Err(from_kind(ErrorKind::IncompatibleShape));
}
}
}
Ok(out)
}

pub trait DimMax<Other: Dimension> {
/// The resulting dimension type after broadcasting.
type Output: Dimension;
}

/// Dimensions of the same type remain unchanged when co_broadcast.
/// So you can directly use D as the resulting type.
/// (Instead of <D as DimMax<D>>::BroadcastOutput)
impl<D: Dimension> DimMax<D> for D {
type Output = D;
}

macro_rules! impl_broadcast_distinct_fixed {
($smaller:ty, $larger:ty) => {
impl DimMax<$larger> for $smaller {
type Output = $larger;
}

impl DimMax<$smaller> for $larger {
type Output = $larger;
}
};
}

impl_broadcast_distinct_fixed!(Ix0, Ix1);
impl_broadcast_distinct_fixed!(Ix0, Ix2);
impl_broadcast_distinct_fixed!(Ix0, Ix3);
impl_broadcast_distinct_fixed!(Ix0, Ix4);
impl_broadcast_distinct_fixed!(Ix0, Ix5);
impl_broadcast_distinct_fixed!(Ix0, Ix6);
impl_broadcast_distinct_fixed!(Ix1, Ix2);
impl_broadcast_distinct_fixed!(Ix1, Ix3);
impl_broadcast_distinct_fixed!(Ix1, Ix4);
impl_broadcast_distinct_fixed!(Ix1, Ix5);
impl_broadcast_distinct_fixed!(Ix1, Ix6);
impl_broadcast_distinct_fixed!(Ix2, Ix3);
impl_broadcast_distinct_fixed!(Ix2, Ix4);
impl_broadcast_distinct_fixed!(Ix2, Ix5);
impl_broadcast_distinct_fixed!(Ix2, Ix6);
impl_broadcast_distinct_fixed!(Ix3, Ix4);
impl_broadcast_distinct_fixed!(Ix3, Ix5);
impl_broadcast_distinct_fixed!(Ix3, Ix6);
impl_broadcast_distinct_fixed!(Ix4, Ix5);
impl_broadcast_distinct_fixed!(Ix4, Ix6);
impl_broadcast_distinct_fixed!(Ix5, Ix6);
impl_broadcast_distinct_fixed!(Ix0, IxDyn);
impl_broadcast_distinct_fixed!(Ix1, IxDyn);
impl_broadcast_distinct_fixed!(Ix2, IxDyn);
impl_broadcast_distinct_fixed!(Ix3, IxDyn);
impl_broadcast_distinct_fixed!(Ix4, IxDyn);
impl_broadcast_distinct_fixed!(Ix5, IxDyn);
impl_broadcast_distinct_fixed!(Ix6, IxDyn);


#[cfg(test)]
#[cfg(feature = "std")]
mod tests {
use super::co_broadcast;
use crate::{Dimension, Dim, DimMax, ShapeError, Ix0, IxDynImpl, ErrorKind};

#[test]
fn test_broadcast_shape() {
fn test_co<D1, D2>(
d1: &D1,
d2: &D2,
r: Result<<D1 as DimMax<D2>>::Output, ShapeError>,
) where
D1: Dimension + DimMax<D2>,
D2: Dimension,
{
let d = co_broadcast::<D1, D2, <D1 as DimMax<D2>>::Output>(&d1, d2);
assert_eq!(d, r);
}
test_co(&Dim([2, 3]), &Dim([4, 1, 3]), Ok(Dim([4, 2, 3])));
test_co(
&Dim([1, 2, 2]),
&Dim([1, 3, 4]),
Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)),
);
test_co(&Dim([3, 4, 5]), &Ix0(), Ok(Dim([3, 4, 5])));
let v = vec![1, 2, 3, 4, 5, 6, 7];
test_co(
&Dim(vec![1, 1, 3, 1, 5, 1, 7]),
&Dim([2, 1, 4, 1, 6, 1]),
Ok(Dim(IxDynImpl::from(v.as_slice()))),
);
let d = Dim([1, 2, 1, 3]);
test_co(&d, &d, Ok(d));
test_co(
&Dim([2, 1, 2]).into_dyn(),
&Dim(0),
Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)),
);
test_co(
&Dim([2, 1, 1]),
&Dim([0, 0, 1, 3, 4]),
Ok(Dim([0, 0, 2, 3, 4])),
);
test_co(&Dim([0]), &Dim([0, 0, 0]), Ok(Dim([0, 0, 0])));
test_co(&Dim(1), &Dim([1, 0, 0]), Ok(Dim([1, 0, 0])));
test_co(
&Dim([1, 3, 0, 1, 1]),
&Dim([1, 2, 3, 1]),
Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)),
);
}
}
7 changes: 6 additions & 1 deletion src/dimension/dimension_trait.rs
Expand Up @@ -15,7 +15,7 @@ use super::axes_of;
use super::conversion::Convert;
use super::{stride_offset, stride_offset_checked};
use crate::itertools::{enumerate, zip};
use crate::Axis;
use crate::{Axis, DimMax};
use crate::IntoDimension;
use crate::RemoveAxis;
use crate::{ArrayView1, ArrayViewMut1};
Expand Down Expand Up @@ -46,6 +46,11 @@ pub trait Dimension:
+ MulAssign
+ for<'x> MulAssign<&'x Self>
+ MulAssign<usize>
+ DimMax<Ix0, Output=Self>
+ DimMax<Self, Output=Self>
+ DimMax<IxDyn, Output=IxDyn>
+ DimMax<<Self as Dimension>::Smaller, Output=Self>
+ DimMax<<Self as Dimension>::Larger, Output=<Self as Dimension>::Larger>
{
/// For fixed-size dimension representations (e.g. `Ix2`), this should be
/// `Some(ndim)`, and for variable-size dimension representations (e.g.
Expand Down
2 changes: 2 additions & 0 deletions src/dimension/mod.rs
Expand Up @@ -12,6 +12,7 @@ use num_integer::div_floor;

pub use self::axes::{axes_of, Axes, AxisDescription};
pub use self::axis::Axis;
pub use self::broadcast::DimMax;
pub use self::conversion::IntoDimension;
pub use self::dim::*;
pub use self::dimension_trait::Dimension;
Expand All @@ -28,6 +29,7 @@ use std::mem;
mod macros;
mod axes;
mod axis;
pub(crate) mod broadcast;
mod conversion;
pub mod dim;
mod dimension_trait;
Expand Down
31 changes: 28 additions & 3 deletions src/impl_methods.rs
Expand Up @@ -14,14 +14,15 @@ use rawpointer::PointerExt;

use crate::imp_prelude::*;

use crate::arraytraits;
use crate::{arraytraits, DimMax};
use crate::dimension;
use crate::dimension::IntoDimension;
use crate::dimension::{
abs_index, axes_of, do_slice, merge_axes, move_min_stride_axis_to_last,
offset_from_ptr_to_memory, size_of_shape_checked, stride_offset, Axes,
};
use crate::error::{self, ErrorKind, ShapeError};
use crate::dimension::broadcast::co_broadcast;
use crate::error::{self, ErrorKind, ShapeError, from_kind};
use crate::math_cell::MathCell;
use crate::itertools::zip;
use crate::zip::Zip;
Expand Down Expand Up @@ -1766,6 +1767,28 @@ where
unsafe { Some(ArrayView::new(self.ptr, dim, broadcast_strides)) }
}

/// For two arrays or views, find their common shape if possible and
/// broadcast them as array views into that shape.
///
/// Return `ShapeError` if their shapes can not be broadcast together.
#[allow(clippy::type_complexity)]
pub(crate) fn broadcast_with<'a, 'b, B, S2, E>(&'a self, other: &'b ArrayBase<S2, E>) ->
Result<(ArrayView<'a, A, DimMaxOf<D, E>>, ArrayView<'b, B, DimMaxOf<D, E>>), ShapeError>
where
S: Data<Elem=A>,
S2: Data<Elem=B>,
D: Dimension + DimMax<E>,
E: Dimension,
{
let shape = co_broadcast::<D, E, <D as DimMax<E>>::Output>(&self.dim, &other.dim)?;
if let Some(view1) = self.broadcast(shape.clone()) {
if let Some(view2) = other.broadcast(shape) {
return Ok((view1, view2));
}
}
Err(from_kind(ErrorKind::IncompatibleShape))
}

/// Swap axes `ax` and `bx`.
///
/// This does not move any data, it just adjusts the array’s dimensions
Expand Down Expand Up @@ -2013,7 +2036,7 @@ where
self.map_inplace(move |elt| *elt = x.clone());
}

fn zip_mut_with_same_shape<B, S2, E, F>(&mut self, rhs: &ArrayBase<S2, E>, mut f: F)
pub(crate) fn zip_mut_with_same_shape<B, S2, E, F>(&mut self, rhs: &ArrayBase<S2, E>, mut f: F)
where
S: DataMut,
S2: Data<Elem = B>,
Expand Down Expand Up @@ -2443,3 +2466,5 @@ unsafe fn unlimited_transmute<A, B>(data: A) -> B {
let old_data = ManuallyDrop::new(data);
(&*old_data as *const A as *const B).read()
}

type DimMaxOf<A, B> = <A as DimMax<B>>::Output;

0 comments on commit b5687f8

Please sign in to comment.