Skip to content

Commit

Permalink
FIX: Remove broadcast_shape from the DimMax trait
Browse files Browse the repository at this point in the history
While calling co_broadcast directly is less convenient, for now they are
two different functions.
  • Loading branch information
bluss committed Mar 12, 2021
1 parent b39593e commit 38f7341
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 73 deletions.
82 changes: 60 additions & 22 deletions src/dimension/broadcast.rs
Expand Up @@ -6,11 +6,11 @@ use crate::{Dimension, Ix0, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn};
///
/// Uses the [NumPy broadcasting rules]
// (https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html#general-broadcasting-rules).
fn co_broadcast<D1, D2, Output>(shape1: &D1, shape2: &D2) -> Result<Output, ShapeError>
where
D1: Dimension,
D2: Dimension,
Output: Dimension,
pub(crate) fn co_broadcast<D1, D2, Output>(shape1: &D1, shape2: &D2) -> Result<Output, ShapeError>
where
D1: Dimension,
D2: Dimension,
Output: Dimension,
{
let (k, overflow) = shape1.ndim().overflowing_sub(shape2.ndim());
// Swap the order if d2 is longer.
Expand All @@ -37,40 +37,23 @@ fn co_broadcast<D1, D2, Output>(shape1: &D1, shape2: &D2) -> Result<Output, Shap
pub trait DimMax<Other: Dimension> {
/// The resulting dimension type after broadcasting.
type Output: Dimension;

/// Determines the shape after broadcasting the shapes together.
///
/// If the shapes are not compatible, returns `Err`.
fn broadcast_shape(&self, other: &Other) -> Result<Self::Output, ShapeError>;
}

/// Dimensions of the same type remain unchanged when co_broadcast.
/// So you can directly use D as the resulting type.
/// (Instead of <D as DimMax<D>>::BroadcastOutput)
impl<D: Dimension> DimMax<D> for D {
type Output = D;

fn broadcast_shape(&self, other: &D) -> Result<Self::Output, ShapeError> {
co_broadcast::<D, D, Self::Output>(self, other)
}
}

macro_rules! impl_broadcast_distinct_fixed {
($smaller:ty, $larger:ty) => {
impl DimMax<$larger> for $smaller {
type Output = $larger;

fn broadcast_shape(&self, other: &$larger) -> Result<Self::Output, ShapeError> {
co_broadcast::<Self, $larger, Self::Output>(self, other)
}
}

impl DimMax<$smaller> for $larger {
type Output = $larger;

fn broadcast_shape(&self, other: &$smaller) -> Result<Self::Output, ShapeError> {
co_broadcast::<Self, $smaller, Self::Output>(self, other)
}
}
};
}
Expand Down Expand Up @@ -103,3 +86,58 @@ impl_broadcast_distinct_fixed!(Ix3, IxDyn);
impl_broadcast_distinct_fixed!(Ix4, IxDyn);
impl_broadcast_distinct_fixed!(Ix5, IxDyn);
impl_broadcast_distinct_fixed!(Ix6, IxDyn);


#[cfg(test)]
#[cfg(feature = "std")]
mod tests {
use super::co_broadcast;
use crate::{Dimension, Dim, DimMax, ShapeError, Ix0, IxDynImpl, ErrorKind};

#[test]
fn test_broadcast_shape() {
fn test_co<D1, D2>(
d1: &D1,
d2: &D2,
r: Result<<D1 as DimMax<D2>>::Output, ShapeError>,
) where
D1: Dimension + DimMax<D2>,
D2: Dimension,
{
let d = co_broadcast::<D1, D2, <D1 as DimMax<D2>>::Output>(&d1, d2);
assert_eq!(d, r);
}
test_co(&Dim([2, 3]), &Dim([4, 1, 3]), Ok(Dim([4, 2, 3])));
test_co(
&Dim([1, 2, 2]),
&Dim([1, 3, 4]),
Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)),
);
test_co(&Dim([3, 4, 5]), &Ix0(), Ok(Dim([3, 4, 5])));
let v = vec![1, 2, 3, 4, 5, 6, 7];
test_co(
&Dim(vec![1, 1, 3, 1, 5, 1, 7]),
&Dim([2, 1, 4, 1, 6, 1]),
Ok(Dim(IxDynImpl::from(v.as_slice()))),
);
let d = Dim([1, 2, 1, 3]);
test_co(&d, &d, Ok(d));
test_co(
&Dim([2, 1, 2]).into_dyn(),
&Dim(0),
Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)),
);
test_co(
&Dim([2, 1, 1]),
&Dim([0, 0, 1, 3, 4]),
Ok(Dim([0, 0, 2, 3, 4])),
);
test_co(&Dim([0]), &Dim([0, 0, 0]), Ok(Dim([0, 0, 0])));
test_co(&Dim(1), &Dim([1, 0, 0]), Ok(Dim([1, 0, 0])));
test_co(
&Dim([1, 3, 0, 1, 1]),
&Dim([1, 2, 3, 1]),
Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)),
);
}
}
2 changes: 1 addition & 1 deletion src/dimension/mod.rs
Expand Up @@ -29,7 +29,7 @@ use std::mem;
mod macros;
mod axes;
mod axis;
mod broadcast;
pub(crate) mod broadcast;
mod conversion;
pub mod dim;
mod dimension_trait;
Expand Down
3 changes: 2 additions & 1 deletion src/impl_methods.rs
Expand Up @@ -21,6 +21,7 @@ use crate::dimension::{
abs_index, axes_of, do_slice, merge_axes, move_min_stride_axis_to_last,
offset_from_ptr_to_memory, size_of_shape_checked, stride_offset, Axes,
};
use crate::dimension::broadcast::co_broadcast;
use crate::error::{self, ErrorKind, ShapeError, from_kind};
use crate::math_cell::MathCell;
use crate::itertools::zip;
Expand Down Expand Up @@ -1778,7 +1779,7 @@ where
D: Dimension + DimMax<E>,
E: Dimension,
{
let shape = self.dim.broadcast_shape(&other.dim)?;
let shape = co_broadcast::<D, E, <D as DimMax<E>>::Output>(&self.dim, &other.dim)?;
if let Some(view1) = self.broadcast(shape.clone()) {
if let Some(view2) = other.broadcast(shape) {
return Ok((view1, view2))
Expand Down
50 changes: 1 addition & 49 deletions tests/dimension.rs
Expand Up @@ -2,8 +2,7 @@

use defmac::defmac;

use ndarray::{arr2, ArcArray, Array, Axis, Dim, Dimension, Ix0, IxDyn, IxDynImpl, RemoveAxis,
ErrorKind, ShapeError, DimMax};
use ndarray::{arr2, ArcArray, Array, Axis, Dim, Dimension, IxDyn, RemoveAxis};

use std::hash::{Hash, Hasher};

Expand Down Expand Up @@ -341,50 +340,3 @@ fn test_all_ndindex() {
ndindex!(10, 4, 3, 2, 2);
ndindex!(10, 4, 3, 2, 2, 2);
}

#[test]
fn test_broadcast_shape() {
fn test_co<D1, D2>(
d1: &D1,
d2: &D2,
r: Result<<D1 as DimMax<D2>>::Output, ShapeError>,
) where
D1: Dimension + DimMax<D2>,
D2: Dimension,
{
let d = d1.broadcast_shape(d2);
assert_eq!(d, r);
}
test_co(&Dim([2, 3]), &Dim([4, 1, 3]), Ok(Dim([4, 2, 3])));
test_co(
&Dim([1, 2, 2]),
&Dim([1, 3, 4]),
Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)),
);
test_co(&Dim([3, 4, 5]), &Ix0(), Ok(Dim([3, 4, 5])));
let v = vec![1, 2, 3, 4, 5, 6, 7];
test_co(
&Dim(vec![1, 1, 3, 1, 5, 1, 7]),
&Dim([2, 1, 4, 1, 6, 1]),
Ok(Dim(IxDynImpl::from(v.as_slice()))),
);
let d = Dim([1, 2, 1, 3]);
test_co(&d, &d, Ok(d));
test_co(
&Dim([2, 1, 2]).into_dyn(),
&Dim(0),
Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)),
);
test_co(
&Dim([2, 1, 1]),
&Dim([0, 0, 1, 3, 4]),
Ok(Dim([0, 0, 2, 3, 4])),
);
test_co(&Dim([0]), &Dim([0, 0, 0]), Ok(Dim([0, 0, 0])));
test_co(&Dim(1), &Dim([1, 0, 0]), Ok(Dim([1, 0, 0])));
test_co(
&Dim([1, 3, 0, 1, 1]),
&Dim([1, 2, 3, 1]),
Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)),
);
}

0 comments on commit 38f7341

Please sign in to comment.