Skip to content

Commit

Permalink
use view().into_dimensionality() instead
Browse files Browse the repository at this point in the history
  • Loading branch information
SparrowLii committed Apr 2, 2021
1 parent d810f1c commit cc16c86
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 27 deletions.
20 changes: 19 additions & 1 deletion benches/bench1.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ extern crate test;

use std::mem::MaybeUninit;

use ndarray::ShapeBuilder;
use ndarray::{ShapeBuilder, Array3, Array4};
use ndarray::{arr0, arr1, arr2, azip, s};
use ndarray::{Array, Array1, Array2, Axis, Ix, Zip};
use ndarray::{Ix1, Ix2, Ix3, Ix5, IxDyn};
Expand Down Expand Up @@ -998,3 +998,21 @@ fn into_dyn_dyn(bench: &mut test::Bencher) {
let a = a.view();
bench.iter(|| a.clone().into_dyn());
}

#[bench]
fn broadcast_same_dim(bench: &mut test::Bencher) {
let s = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12];
let s = Array4::from_shape_vec((2, 2, 3, 2), s.to_vec()).unwrap();
let a = s.slice(s![.., ..;-1, ..;2, ..]);
let b = s.slice(s![.., .., ..;2, ..]);
bench.iter(|| &a + &b);
}

#[bench]
fn broadcast_one_side(bench: &mut test::Bencher) {
let s = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12];
let s2 = [1 ,2 ,3 ,4 ,5 ,6];
let a = Array4::from_shape_vec((4, 1, 3, 2), s.to_vec()).unwrap();
let b = Array3::from_shape_vec((1, 3, 2), s2.to_vec()).unwrap();
bench.iter(|| &a + &b);
}
33 changes: 9 additions & 24 deletions src/impl_methods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1805,35 +1805,20 @@ where
{
let shape = co_broadcast::<D, E, <D as DimMax<E>>::Output>(&self.dim, &other.dim)?;
let view1 = if shape.slice() == self.dim.slice() {
self.to_dimensionality::<<D as DimMax<E>>::Output>()
self.view().into_dimensionality::<<D as DimMax<E>>::Output>().unwrap()
} else if let Some(view1) = self.broadcast(shape.clone()) {
view1
} else {
self.broadcast(shape.clone())
return Err(from_kind(ErrorKind::IncompatibleShape))
};
let view2 = if shape.slice() == other.dim.slice() {
other.to_dimensionality::<<D as DimMax<E>>::Output>()
other.view().into_dimensionality::<<D as DimMax<E>>::Output>().unwrap()
} else if let Some(view2) = other.broadcast(shape) {
view2
} else {
other.broadcast(shape)
return Err(from_kind(ErrorKind::IncompatibleShape))
};
if let Some(view1) = view1 {
if let Some(view2) = view2 {
return Ok((view1, view2));
}
}
Err(from_kind(ErrorKind::IncompatibleShape))
}

/// Creat an array view from an array with the same shape, but different dimensionality
/// type. Return None if the numbers of axes mismatch.
#[inline]
pub(crate) fn to_dimensionality<D2>(&self) -> Option<ArrayView<'_, A, D2>>
where
D2: Dimension,
S: Data,
{
let dim = <D2>::from_dimension(&self.dim)?;
let strides = <D2>::from_dimension(&self.strides)?;

unsafe { Some(ArrayView::new(self.ptr, dim, strides)) }
Ok((view1, view2))
}

/// Swap axes `ax` and `bx`.
Expand Down
4 changes: 2 additions & 2 deletions src/impl_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -180,8 +180,8 @@ where
type Output = Array<A, <D as DimMax<E>>::Output>;
fn $mth(self, rhs: &'a ArrayBase<S2, E>) -> Self::Output {
let (lhs, rhs) = if self.ndim() == rhs.ndim() && self.shape() == rhs.shape() {
let lhs = self.to_dimensionality::<<D as DimMax<E>>::Output>().unwrap();
let rhs = rhs.to_dimensionality::<<D as DimMax<E>>::Output>().unwrap();
let lhs = self.view().into_dimensionality::<<D as DimMax<E>>::Output>().unwrap();
let rhs = rhs.view().into_dimensionality::<<D as DimMax<E>>::Output>().unwrap();
(lhs, rhs)
} else {
self.broadcast_with(rhs).unwrap()
Expand Down

0 comments on commit cc16c86

Please sign in to comment.