Skip to content

Commit

Permalink
Add function broadcast_with
Browse files Browse the repository at this point in the history
  • Loading branch information
SparrowLii authored and bluss committed Mar 12, 2021
1 parent e3b73cc commit b623239
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 11 deletions.
34 changes: 32 additions & 2 deletions src/impl_methods.rs
Expand Up @@ -14,14 +14,14 @@ use rawpointer::PointerExt;

use crate::imp_prelude::*;

use crate::arraytraits;
use crate::{arraytraits, BroadcastShape};
use crate::dimension;
use crate::dimension::IntoDimension;
use crate::dimension::{
abs_index, axes_of, do_slice, merge_axes, move_min_stride_axis_to_last,
offset_from_ptr_to_memory, size_of_shape_checked, stride_offset, Axes,
};
use crate::error::{self, ErrorKind, ShapeError};
use crate::error::{self, ErrorKind, ShapeError, from_kind};
use crate::math_cell::MathCell;
use crate::itertools::zip;
use crate::zip::Zip;
Expand Down Expand Up @@ -1766,6 +1766,36 @@ where
unsafe { Some(ArrayView::new(self.ptr, dim, broadcast_strides)) }
}

/// Calculate the views of two ArrayBases after broadcasting each other, if possible.
///
/// Return `ShapeError` if their shapes can not be broadcast together.
///
/// ```
/// use ndarray::{arr1, arr2};
///
/// let a = arr2(&[[2], [3], [4]]);
/// let b = arr1(&[5, 6, 7]);
/// let (a1, b1) = a.broadcast_with(&b).unwrap();
/// assert_eq!(a1, arr2(&[[2, 2, 2], [3, 3, 3], [4, 4, 4]]));
/// assert_eq!(b1, arr2(&[[5, 6, 7], [5, 6, 7], [5, 6, 7]]));
/// ```
pub fn broadcast_with<'a, 'b, B, S2, E>(&'a self, other: &'b ArrayBase<S2, E>) ->
Result<(ArrayView<'a, A, <D as BroadcastShape<E>>::Output>, ArrayView<'b, B, <D as BroadcastShape<E>>::Output>), ShapeError>
where
S: Data<Elem=A>,
S2: Data<Elem=B>,
D: Dimension + BroadcastShape<E>,
E: Dimension,
{
let shape = self.dim.broadcast_shape(&other.dim)?;
if let Some(view1) = self.broadcast(shape.clone()) {
if let Some(view2) = other.broadcast(shape) {
return Ok((view1, view2))
}
}
return Err(from_kind(ErrorKind::IncompatibleShape));
}

/// Swap axes `ax` and `bx`.
///
/// This does not move any data, it just adjusts the array’s dimensions
Expand Down
12 changes: 3 additions & 9 deletions src/impl_ops.rs
Expand Up @@ -106,9 +106,7 @@ where
out.zip_mut_with_same_shape(rhs, clone_iopf(A::$mth));
out
} else {
let shape = self.dim.broadcast_shape(&rhs.dim).unwrap();
let lhs = self.broadcast(shape.clone()).unwrap();
let rhs = rhs.broadcast(shape).unwrap();
let (lhs, rhs) = self.broadcast_with(rhs).unwrap();
Zip::from(&lhs).and(&rhs).map_collect_owned(clone_opf(A::$mth))
}
}
Expand Down Expand Up @@ -143,9 +141,7 @@ where
out.zip_mut_with_same_shape(self, clone_iopf_rev(A::$mth));
out
} else {
let shape = rhs.dim.broadcast_shape(&self.dim).unwrap();
let lhs = self.broadcast(shape.clone()).unwrap();
let rhs = rhs.broadcast(shape).unwrap();
let (rhs, lhs) = rhs.broadcast_with(self).unwrap();
Zip::from(&lhs).and(&rhs).map_collect_owned(clone_opf(A::$mth))
}
}
Expand All @@ -171,9 +167,7 @@ where
{
type Output = Array<A, <D as BroadcastShape<E>>::Output>;
fn $mth(self, rhs: &'a ArrayBase<S2, E>) -> Self::Output {
let shape = self.dim.broadcast_shape(&rhs.dim).unwrap();
let lhs = self.broadcast(shape.clone()).unwrap();
let rhs = rhs.broadcast(shape).unwrap();
let (lhs, rhs) = self.broadcast_with(rhs).unwrap();
Zip::from(&lhs).and(&rhs).map_collect(clone_opf(A::$mth))
}
}
Expand Down
31 changes: 31 additions & 0 deletions tests/broadcast.rs
@@ -1,4 +1,5 @@
use ndarray::prelude::*;
use ndarray::{ShapeError, ErrorKind, arr3};

#[test]
#[cfg(feature = "std")]
Expand Down Expand Up @@ -81,3 +82,33 @@ fn test_broadcast_1d() {
println!("b2=\n{:?}", b2);
assert_eq!(b0, b2);
}

#[test]
fn test_broadcast_with() {
let a = arr2(&[[1., 2.], [3., 4.]]);
let b = aview0(&1.);
let (a1, b1) = a.broadcast_with(&b).unwrap();
assert_eq!(a1, arr2(&[[1.0, 2.0], [3.0, 4.0]]));
assert_eq!(b1, arr2(&[[1.0, 1.0], [1.0, 1.0]]));

let a = arr2(&[[2], [3], [4]]);
let b = arr1(&[5, 6, 7]);
let (a1, b1) = a.broadcast_with(&b).unwrap();
assert_eq!(a1, arr2(&[[2, 2, 2], [3, 3, 3], [4, 4, 4]]));
assert_eq!(b1, arr2(&[[5, 6, 7], [5, 6, 7], [5, 6, 7]]));

// Negative strides and non-contiguous memory
let s = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12];
let s = Array3::from_shape_vec((2, 3, 2).strides((1, 4, 2)), s.to_vec()).unwrap();
let a = s.slice(s![..;-1,..;2,..]);
let b = s.slice(s![..2, -1, ..]);
let (a1, b1) = a.broadcast_with(&b).unwrap();
assert_eq!(a1, arr3(&[[[2, 4], [10, 12]], [[1, 3], [9, 11]]]));
assert_eq!(b1, arr3(&[[[9, 11], [10, 12]], [[9, 11], [10, 12]]]));

// ShapeError
let a = arr2(&[[2, 2], [3, 3], [4, 4]]);
let b = arr1(&[5, 6, 7]);
let e = a.broadcast_with(&b);
assert_eq!(e, Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)));
}

0 comments on commit b623239

Please sign in to comment.