Skip to content

Commit

Permalink
FIX: Rename BroadcastShape to DimMax
Browse files Browse the repository at this point in the history
For consistency with other dimension traits (to come);
with <A as DimMax<B>>, Output is the maximum of A and B.
  • Loading branch information
bluss committed Mar 12, 2021
1 parent 16f382b commit b39593e
Show file tree
Hide file tree
Showing 7 changed files with 30 additions and 30 deletions.
10 changes: 5 additions & 5 deletions src/dimension/broadcast.rs
Expand Up @@ -34,7 +34,7 @@ fn co_broadcast<D1, D2, Output>(shape1: &D1, shape2: &D2) -> Result<Output, Shap
Ok(out)
}

pub trait BroadcastShape<Other: Dimension> {
pub trait DimMax<Other: Dimension> {
/// The resulting dimension type after broadcasting.
type Output: Dimension;

Expand All @@ -46,8 +46,8 @@ pub trait BroadcastShape<Other: Dimension> {

/// Dimensions of the same type remain unchanged when co_broadcast.
/// So you can directly use D as the resulting type.
/// (Instead of <D as BroadcastShape<D>>::BroadcastOutput)
impl<D: Dimension> BroadcastShape<D> for D {
/// (Instead of <D as DimMax<D>>::BroadcastOutput)
impl<D: Dimension> DimMax<D> for D {
type Output = D;

fn broadcast_shape(&self, other: &D) -> Result<Self::Output, ShapeError> {
Expand All @@ -57,15 +57,15 @@ impl<D: Dimension> BroadcastShape<D> for D {

macro_rules! impl_broadcast_distinct_fixed {
($smaller:ty, $larger:ty) => {
impl BroadcastShape<$larger> for $smaller {
impl DimMax<$larger> for $smaller {
type Output = $larger;

fn broadcast_shape(&self, other: &$larger) -> Result<Self::Output, ShapeError> {
co_broadcast::<Self, $larger, Self::Output>(self, other)
}
}

impl BroadcastShape<$smaller> for $larger {
impl DimMax<$smaller> for $larger {
type Output = $larger;

fn broadcast_shape(&self, other: &$smaller) -> Result<Self::Output, ShapeError> {
Expand Down
12 changes: 6 additions & 6 deletions src/dimension/dimension_trait.rs
Expand Up @@ -15,7 +15,7 @@ use super::axes_of;
use super::conversion::Convert;
use super::{stride_offset, stride_offset_checked};
use crate::itertools::{enumerate, zip};
use crate::{Axis, BroadcastShape};
use crate::{Axis, DimMax};
use crate::IntoDimension;
use crate::RemoveAxis;
use crate::{ArrayView1, ArrayViewMut1};
Expand Down Expand Up @@ -46,11 +46,11 @@ pub trait Dimension:
+ MulAssign
+ for<'x> MulAssign<&'x Self>
+ MulAssign<usize>
+ BroadcastShape<Ix0, Output=Self>
+ BroadcastShape<Self, Output=Self>
+ BroadcastShape<IxDyn, Output=IxDyn>
+ BroadcastShape<<Self as Dimension>::Smaller, Output=Self>
+ BroadcastShape<<Self as Dimension>::Larger, Output=<Self as Dimension>::Larger>
+ DimMax<Ix0, Output=Self>
+ DimMax<Self, Output=Self>
+ DimMax<IxDyn, Output=IxDyn>
+ DimMax<<Self as Dimension>::Smaller, Output=Self>
+ DimMax<<Self as Dimension>::Larger, Output=<Self as Dimension>::Larger>
{
/// For fixed-size dimension representations (e.g. `Ix2`), this should be
/// `Some(ndim)`, and for variable-size dimension representations (e.g.
Expand Down
2 changes: 1 addition & 1 deletion src/dimension/mod.rs
Expand Up @@ -12,7 +12,7 @@ use num_integer::div_floor;

pub use self::axes::{axes_of, Axes, AxisDescription};
pub use self::axis::Axis;
pub use self::broadcast::BroadcastShape;
pub use self::broadcast::DimMax;
pub use self::conversion::IntoDimension;
pub use self::dim::*;
pub use self::dimension_trait::Dimension;
Expand Down
6 changes: 3 additions & 3 deletions src/impl_methods.rs
Expand Up @@ -14,7 +14,7 @@ use rawpointer::PointerExt;

use crate::imp_prelude::*;

use crate::{arraytraits, BroadcastShape};
use crate::{arraytraits, DimMax};
use crate::dimension;
use crate::dimension::IntoDimension;
use crate::dimension::{
Expand Down Expand Up @@ -1771,11 +1771,11 @@ where
///
/// Return `ShapeError` if their shapes can not be broadcast together.
pub(crate) fn broadcast_with<'a, 'b, B, S2, E>(&'a self, other: &'b ArrayBase<S2, E>) ->
Result<(ArrayView<'a, A, <D as BroadcastShape<E>>::Output>, ArrayView<'b, B, <D as BroadcastShape<E>>::Output>), ShapeError>
Result<(ArrayView<'a, A, <D as DimMax<E>>::Output>, ArrayView<'b, B, <D as DimMax<E>>::Output>), ShapeError>
where
S: Data<Elem=A>,
S2: Data<Elem=B>,
D: Dimension + BroadcastShape<E>,
D: Dimension + DimMax<E>,
E: Dimension,
{
let shape = self.dim.broadcast_shape(&other.dim)?;
Expand Down
22 changes: 11 additions & 11 deletions src/impl_ops.rs
Expand Up @@ -6,7 +6,7 @@
// option. This file may not be copied, modified, or distributed
// except according to those terms.

use crate::dimension::BroadcastShape;
use crate::dimension::DimMax;
use crate::Zip;
use num_complex::Complex;

Expand Down Expand Up @@ -68,10 +68,10 @@ where
B: Clone,
S: DataOwned<Elem=A> + DataMut,
S2: Data<Elem=B>,
D: Dimension + BroadcastShape<E>,
D: Dimension + DimMax<E>,
E: Dimension,
{
type Output = ArrayBase<S, <D as BroadcastShape<E>>::Output>;
type Output = ArrayBase<S, <D as DimMax<E>>::Output>;
fn $mth(self, rhs: ArrayBase<S2, E>) -> Self::Output
{
self.$mth(&rhs)
Expand All @@ -95,14 +95,14 @@ where
B: Clone,
S: DataOwned<Elem=A> + DataMut,
S2: Data<Elem=B>,
D: Dimension + BroadcastShape<E>,
D: Dimension + DimMax<E>,
E: Dimension,
{
type Output = ArrayBase<S, <D as BroadcastShape<E>>::Output>;
type Output = ArrayBase<S, <D as DimMax<E>>::Output>;
fn $mth(self, rhs: &ArrayBase<S2, E>) -> Self::Output
{
if self.ndim() == rhs.ndim() && self.shape() == rhs.shape() {
let mut out = self.into_dimensionality::<<D as BroadcastShape<E>>::Output>().unwrap();
let mut out = self.into_dimensionality::<<D as DimMax<E>>::Output>().unwrap();
out.zip_mut_with_same_shape(rhs, clone_iopf(A::$mth));
out
} else {
Expand Down Expand Up @@ -130,14 +130,14 @@ where
S: Data<Elem=A>,
S2: DataOwned<Elem=B> + DataMut,
D: Dimension,
E: Dimension + BroadcastShape<D>,
E: Dimension + DimMax<D>,
{
type Output = ArrayBase<S2, <E as BroadcastShape<D>>::Output>;
type Output = ArrayBase<S2, <E as DimMax<D>>::Output>;
fn $mth(self, rhs: ArrayBase<S2, E>) -> Self::Output
where
{
if self.ndim() == rhs.ndim() && self.shape() == rhs.shape() {
let mut out = rhs.into_dimensionality::<<E as BroadcastShape<D>>::Output>().unwrap();
let mut out = rhs.into_dimensionality::<<E as DimMax<D>>::Output>().unwrap();
out.zip_mut_with_same_shape(self, clone_iopf_rev(A::$mth));
out
} else {
Expand All @@ -162,10 +162,10 @@ where
B: Clone,
S: Data<Elem=A>,
S2: Data<Elem=B>,
D: Dimension + BroadcastShape<E>,
D: Dimension + DimMax<E>,
E: Dimension,
{
type Output = Array<A, <D as BroadcastShape<E>>::Output>;
type Output = Array<A, <D as DimMax<E>>::Output>;
fn $mth(self, rhs: &'a ArrayBase<S2, E>) -> Self::Output {
let (lhs, rhs) = self.broadcast_with(rhs).unwrap();
Zip::from(&lhs).and(&rhs).map_collect(clone_opf(A::$mth))
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Expand Up @@ -134,7 +134,7 @@ use std::marker::PhantomData;
use alloc::sync::Arc;

pub use crate::dimension::dim::*;
pub use crate::dimension::BroadcastShape;
pub use crate::dimension::DimMax;
pub use crate::dimension::{Axis, AxisDescription, Dimension, IntoDimension, RemoveAxis};

pub use crate::dimension::IxDynImpl;
Expand Down
6 changes: 3 additions & 3 deletions tests/dimension.rs
Expand Up @@ -3,7 +3,7 @@
use defmac::defmac;

use ndarray::{arr2, ArcArray, Array, Axis, Dim, Dimension, Ix0, IxDyn, IxDynImpl, RemoveAxis,
ErrorKind, ShapeError, BroadcastShape};
ErrorKind, ShapeError, DimMax};

use std::hash::{Hash, Hasher};

Expand Down Expand Up @@ -347,9 +347,9 @@ fn test_broadcast_shape() {
fn test_co<D1, D2>(
d1: &D1,
d2: &D2,
r: Result<<D1 as BroadcastShape<D2>>::Output, ShapeError>,
r: Result<<D1 as DimMax<D2>>::Output, ShapeError>,
) where
D1: Dimension + BroadcastShape<D2>,
D1: Dimension + DimMax<D2>,
D2: Dimension,
{
let d = d1.broadcast_shape(d2);
Expand Down

0 comments on commit b39593e

Please sign in to comment.