Skip to content

Commit

Permalink
Modify the docs and visibility of broadcast_with
Browse files Browse the repository at this point in the history
  • Loading branch information
SparrowLii committed Feb 19, 2021
1 parent 6c40d61 commit 0dbeaf3
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 42 deletions.
15 changes: 3 additions & 12 deletions src/impl_methods.rs
Expand Up @@ -1707,20 +1707,11 @@ where
unsafe { Some(ArrayView::new(self.ptr, dim, broadcast_strides)) }
}

/// Calculate the views of two ArrayBases after broadcasting each other, if possible.
/// For two arrays or views, find their common shape if possible and
/// broadcast them as array views into that shape.
///
/// Return `ShapeError` if their shapes can not be broadcast together.
///
/// ```
/// use ndarray::{arr1, arr2};
///
/// let a = arr2(&[[2], [3], [4]]);
/// let b = arr1(&[5, 6, 7]);
/// let (a1, b1) = a.broadcast_with(&b).unwrap();
/// assert_eq!(a1, arr2(&[[2, 2, 2], [3, 3, 3], [4, 4, 4]]));
/// assert_eq!(b1, arr2(&[[5, 6, 7], [5, 6, 7], [5, 6, 7]]));
/// ```
pub fn broadcast_with<'a, 'b, B, S2, E>(&'a self, other: &'b ArrayBase<S2, E>) ->
pub(crate) fn broadcast_with<'a, 'b, B, S2, E>(&'a self, other: &'b ArrayBase<S2, E>) ->
Result<(ArrayView<'a, A, <D as BroadcastShape<E>>::Output>, ArrayView<'b, B, <D as BroadcastShape<E>>::Output>), ShapeError>
where
S: Data<Elem=A>,
Expand Down
30 changes: 0 additions & 30 deletions tests/broadcast.rs
Expand Up @@ -82,33 +82,3 @@ fn test_broadcast_1d() {
println!("b2=\n{:?}", b2);
assert_eq!(b0, b2);
}

#[test]
fn test_broadcast_with() {
let a = arr2(&[[1., 2.], [3., 4.]]);
let b = aview0(&1.);
let (a1, b1) = a.broadcast_with(&b).unwrap();
assert_eq!(a1, arr2(&[[1.0, 2.0], [3.0, 4.0]]));
assert_eq!(b1, arr2(&[[1.0, 1.0], [1.0, 1.0]]));

let a = arr2(&[[2], [3], [4]]);
let b = arr1(&[5, 6, 7]);
let (a1, b1) = a.broadcast_with(&b).unwrap();
assert_eq!(a1, arr2(&[[2, 2, 2], [3, 3, 3], [4, 4, 4]]));
assert_eq!(b1, arr2(&[[5, 6, 7], [5, 6, 7], [5, 6, 7]]));

// Negative strides and non-contiguous memory
let s = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12];
let s = Array3::from_shape_vec((2, 3, 2).strides((1, 4, 2)), s.to_vec()).unwrap();
let a = s.slice(s![..;-1,..;2,..]);
let b = s.slice(s![..2, -1, ..]);
let (a1, b1) = a.broadcast_with(&b).unwrap();
assert_eq!(a1, arr3(&[[[2, 4], [10, 12]], [[1, 3], [9, 11]]]));
assert_eq!(b1, arr3(&[[[9, 11], [10, 12]], [[9, 11], [10, 12]]]));

// ShapeError
let a = arr2(&[[2, 2], [3, 3], [4, 4]]);
let b = arr1(&[5, 6, 7]);
let e = a.broadcast_with(&b);
assert_eq!(e, Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)));
}

0 comments on commit 0dbeaf3

Please sign in to comment.