From e2d8008bb0762778c2130492a7171da9d671444b Mon Sep 17 00:00:00 2001 From: SparrowLii Date: Wed, 27 Jan 2021 11:16:14 +0800 Subject: [PATCH] treat zero dimension like numpy does --- src/dimension/broadcast.rs | 8 +++----- src/numeric/impl_numeric.rs | 4 ++-- tests/array.rs | 8 ++++---- 3 files changed, 9 insertions(+), 11 deletions(-) diff --git a/src/dimension/broadcast.rs b/src/dimension/broadcast.rs index d13e2ac3a..7c158f5b2 100644 --- a/src/dimension/broadcast.rs +++ b/src/dimension/broadcast.rs @@ -23,15 +23,13 @@ where // (https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html#general-broadcasting-rules). // // Zero dimension element is not in the original rules of broadcasting. - // We currently treat it as the same as 1. Especially, when one side is - // zero with one side is empty, or both sides are zero, the result will - // remain zero. + // We currently treat it like any other number greater than 1. As numpy does. for i in 0..shape1.ndim() { out_slice[i] = s1[i]; } for i in 0..shape2.ndim() { - if out_slice[i + k] != s2[i] && s2[i] != 0 { - if out_slice[i + k] <= 1 { + if out_slice[i + k] != s2[i] { + if out_slice[i + k] == 1 { out_slice[i + k] = s2[i] } else if s2[i] != 1 { return Err(from_kind(ErrorKind::IncompatibleShape)); diff --git a/src/numeric/impl_numeric.rs b/src/numeric/impl_numeric.rs index 39ec24a74..9d1cef7e1 100644 --- a/src/numeric/impl_numeric.rs +++ b/src/numeric/impl_numeric.rs @@ -243,7 +243,7 @@ where /// **Panics** if `axis` is out of bounds. pub fn sum_axis(&self, axis: Axis) -> Array where - A: Copy + Zero + Add, + A: Clone + Zero + Add, D: RemoveAxis, { let n = self.len_of(axis); @@ -285,7 +285,7 @@ where /// ``` pub fn mean_axis(&self, axis: Axis) -> Option> where - A: Copy + Zero + FromPrimitive + Add + Div, + A: Clone + Zero + FromPrimitive + Add + Div, D: RemoveAxis, { let axis_length = self.len_of(axis); diff --git a/tests/array.rs b/tests/array.rs index bd9b10ac4..2be897d31 100644 --- a/tests/array.rs +++ b/tests/array.rs @@ -1589,19 +1589,19 @@ fn test_broadcast_shape() { test_co( &Dim([2, 1, 2]).into_dyn(), &Dim(0), - Ok(Dim([2, 1, 2]).into_dyn()), + Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)), ); test_co( &Dim([2, 1, 1]), - &Dim([0, 0, 0, 3, 4]), + &Dim([0, 0, 1, 3, 4]), Ok(Dim([0, 0, 2, 3, 4])), ); test_co(&Dim([0]), &Dim([0, 0, 0]), Ok(Dim([0, 0, 0]))); - test_co(&Dim(1), &Dim([1, 0, 0]), Ok(Dim([1, 0, 1]))); + test_co(&Dim(1), &Dim([1, 0, 0]), Ok(Dim([1, 0, 0]))); test_co( &Dim([1, 3, 0, 1, 1]), &Dim([1, 2, 3, 1]), - Ok(Dim([1, 3, 2, 3, 1])), + Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)), ); }