Skip to content

Commit

Permalink
treat zero dimension like numpy does
Browse files Browse the repository at this point in the history
  • Loading branch information
SparrowLii committed Jan 27, 2021
1 parent 5eec58d commit e2d8008
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 11 deletions.
8 changes: 3 additions & 5 deletions src/dimension/broadcast.rs
Expand Up @@ -23,15 +23,13 @@ where
// (https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html#general-broadcasting-rules).
//
// Zero dimension element is not in the original rules of broadcasting.
// We currently treat it as the same as 1. Especially, when one side is
// zero with one side is empty, or both sides are zero, the result will
// remain zero.
// We currently treat it like any other number greater than 1. As numpy does.
for i in 0..shape1.ndim() {
out_slice[i] = s1[i];
}
for i in 0..shape2.ndim() {
if out_slice[i + k] != s2[i] && s2[i] != 0 {
if out_slice[i + k] <= 1 {
if out_slice[i + k] != s2[i] {
if out_slice[i + k] == 1 {
out_slice[i + k] = s2[i]
} else if s2[i] != 1 {
return Err(from_kind(ErrorKind::IncompatibleShape));
Expand Down
4 changes: 2 additions & 2 deletions src/numeric/impl_numeric.rs
Expand Up @@ -243,7 +243,7 @@ where
/// **Panics** if `axis` is out of bounds.
pub fn sum_axis(&self, axis: Axis) -> Array<A, D::Smaller>
where
A: Copy + Zero + Add<Output = A>,
A: Clone + Zero + Add<Output = A>,
D: RemoveAxis,
{
let n = self.len_of(axis);
Expand Down Expand Up @@ -285,7 +285,7 @@ where
/// ```
pub fn mean_axis(&self, axis: Axis) -> Option<Array<A, D::Smaller>>
where
A: Copy + Zero + FromPrimitive + Add<Output = A> + Div<Output = A>,
A: Clone + Zero + FromPrimitive + Add<Output = A> + Div<Output = A>,
D: RemoveAxis,
{
let axis_length = self.len_of(axis);
Expand Down
8 changes: 4 additions & 4 deletions tests/array.rs
Expand Up @@ -1589,19 +1589,19 @@ fn test_broadcast_shape() {
test_co(
&Dim([2, 1, 2]).into_dyn(),
&Dim(0),
Ok(Dim([2, 1, 2]).into_dyn()),
Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)),
);
test_co(
&Dim([2, 1, 1]),
&Dim([0, 0, 0, 3, 4]),
&Dim([0, 0, 1, 3, 4]),
Ok(Dim([0, 0, 2, 3, 4])),
);
test_co(&Dim([0]), &Dim([0, 0, 0]), Ok(Dim([0, 0, 0])));
test_co(&Dim(1), &Dim([1, 0, 0]), Ok(Dim([1, 0, 1])));
test_co(&Dim(1), &Dim([1, 0, 0]), Ok(Dim([1, 0, 0])));
test_co(
&Dim([1, 3, 0, 1, 1]),
&Dim([1, 2, 3, 1]),
Ok(Dim([1, 3, 2, 3, 1])),
Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)),
);
}

Expand Down

0 comments on commit e2d8008

Please sign in to comment.