diff --git a/Cargo.toml b/Cargo.toml index 2f4f8bb0a..58ede0171 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -47,6 +47,7 @@ serde = { version = "1.0", optional = true } defmac = "0.2" quickcheck = { version = "0.7.2", default-features = false } rawpointer = "0.1" +approx = "0.3" [features] # Enable blas usage diff --git a/examples/column_standardize.rs b/examples/column_standardize.rs index 032c520e3..80ca7332a 100644 --- a/examples/column_standardize.rs +++ b/examples/column_standardize.rs @@ -23,9 +23,9 @@ fn main() { [ 2., 2., 2.]]; println!("{:8.4}", data); - println!("{:8.4} (Mean axis=0)", data.mean_axis(Axis(0))); + println!("{:8.4} (Mean axis=0)", data.mean_axis(Axis(0)).unwrap()); - data -= &data.mean_axis(Axis(0)); + data -= &data.mean_axis(Axis(0)).unwrap(); println!("{:8.4}", data); data /= &std(&data, Axis(0)); diff --git a/src/numeric/impl_numeric.rs b/src/numeric/impl_numeric.rs index f8c1f04a3..87e7be590 100644 --- a/src/numeric/impl_numeric.rs +++ b/src/numeric/impl_numeric.rs @@ -46,6 +46,33 @@ impl ArrayBase sum } + /// Returns the [arithmetic mean] x̅ of all elements in the array: + /// + /// ```text + /// 1 n + /// x̅ = ― ∑ xᵢ + /// n i=1 + /// ``` + /// + /// If the array is empty, `None` is returned. + /// + /// **Panics** if `A::from_usize()` fails to convert the number of elements in the array. + /// + /// [arithmetic mean]: https://en.wikipedia.org/wiki/Arithmetic_mean + pub fn mean(&self) -> Option + where + A: Clone + FromPrimitive + Add + Div + Zero + { + let n_elements = self.len(); + if n_elements == 0 { + None + } else { + let n_elements = A::from_usize(n_elements) + .expect("Converting number of elements to `A` must not fail."); + Some(self.sum() / n_elements) + } + } + /// Return the sum of all elements in the array. /// /// *This method has been renamed to `.sum()` and will be deprecated in the @@ -123,8 +150,9 @@ impl ArrayBase /// Return mean along `axis`. /// - /// **Panics** if `axis` is out of bounds, if the length of the axis is - /// zero and division by zero panics for type `A`, or if `A::from_usize()` + /// Return `None` if the length of the axis is zero. + /// + /// **Panics** if `axis` is out of bounds or if `A::from_usize()` /// fails for the axis length. /// /// ``` @@ -133,19 +161,25 @@ impl ArrayBase /// let a = arr2(&[[1., 2., 3.], /// [4., 5., 6.]]); /// assert!( - /// a.mean_axis(Axis(0)) == aview1(&[2.5, 3.5, 4.5]) && - /// a.mean_axis(Axis(1)) == aview1(&[2., 5.]) && + /// a.mean_axis(Axis(0)).unwrap() == aview1(&[2.5, 3.5, 4.5]) && + /// a.mean_axis(Axis(1)).unwrap() == aview1(&[2., 5.]) && /// - /// a.mean_axis(Axis(0)).mean_axis(Axis(0)) == aview0(&3.5) + /// a.mean_axis(Axis(0)).unwrap().mean_axis(Axis(0)).unwrap() == aview0(&3.5) /// ); /// ``` - pub fn mean_axis(&self, axis: Axis) -> Array + pub fn mean_axis(&self, axis: Axis) -> Option> where A: Clone + Zero + FromPrimitive + Add + Div, D: RemoveAxis, { - let n = A::from_usize(self.len_of(axis)).expect("Converting axis length to `A` must not fail."); - let sum = self.sum_axis(axis); - sum / &aview0(&n) + let axis_length = self.len_of(axis); + if axis_length == 0 { + None + } else { + let axis_length = A::from_usize(axis_length) + .expect("Converting axis length to `A` must not fail."); + let sum = self.sum_axis(axis); + Some(sum / &aview0(&axis_length)) + } } /// Return variance along `axis`. diff --git a/tests/array.rs b/tests/array.rs index 28e2e7fbc..bbe05bbda 100644 --- a/tests/array.rs +++ b/tests/array.rs @@ -925,175 +925,6 @@ fn assign() assert_eq!(a, arr2(&[[0, 0], [3, 4]])); } -#[test] -fn sum_mean() -{ - let a = arr2(&[[1., 2.], [3., 4.]]); - assert_eq!(a.sum_axis(Axis(0)), arr1(&[4., 6.])); - assert_eq!(a.sum_axis(Axis(1)), arr1(&[3., 7.])); - assert_eq!(a.mean_axis(Axis(0)), arr1(&[2., 3.])); - assert_eq!(a.mean_axis(Axis(1)), arr1(&[1.5, 3.5])); - assert_eq!(a.sum_axis(Axis(1)).sum_axis(Axis(0)), arr0(10.)); - assert_eq!(a.view().mean_axis(Axis(1)), aview1(&[1.5, 3.5])); - assert_eq!(a.sum(), 10.); -} - -#[test] -fn sum_mean_empty() { - assert_eq!(Array3::::ones((2, 0, 3)).sum(), 0.); - assert_eq!(Array1::::ones(0).sum_axis(Axis(0)), arr0(0.)); - assert_eq!( - Array3::::ones((2, 0, 3)).sum_axis(Axis(1)), - Array::zeros((2, 3)), - ); - let a = Array1::::ones(0).mean_axis(Axis(0)); - assert_eq!(a.shape(), &[]); - assert!(a[()].is_nan()); - let a = Array3::::ones((2, 0, 3)).mean_axis(Axis(1)); - assert_eq!(a.shape(), &[2, 3]); - a.mapv(|x| assert!(x.is_nan())); -} - -#[test] -fn var_axis() { - let a = array![ - [ - [-9.76, -0.38, 1.59, 6.23], - [-8.57, -9.27, 5.76, 6.01], - [-9.54, 5.09, 3.21, 6.56], - ], - [ - [ 8.23, -9.63, 3.76, -3.48], - [-5.46, 5.86, -2.81, 1.35], - [-1.08, 4.66, 8.34, -0.73], - ], - ]; - assert!(a.var_axis(Axis(0), 1.5).all_close( - &aview2(&[ - [3.236401e+02, 8.556250e+01, 4.708900e+00, 9.428410e+01], - [9.672100e+00, 2.289169e+02, 7.344490e+01, 2.171560e+01], - [7.157160e+01, 1.849000e-01, 2.631690e+01, 5.314410e+01] - ]), - 1e-4, - )); - assert!(a.var_axis(Axis(1), 1.7).all_close( - &aview2(&[ - [0.61676923, 80.81092308, 6.79892308, 0.11789744], - [75.19912821, 114.25235897, 48.32405128, 9.03020513], - ]), - 1e-8, - )); - assert!(a.var_axis(Axis(2), 2.3).all_close( - &aview2(&[ - [ 79.64552941, 129.09663235, 95.98929412], - [109.64952941, 43.28758824, 36.27439706], - ]), - 1e-8, - )); - - let b = array![[1.1, 2.3, 4.7]]; - assert!(b.var_axis(Axis(0), 0.).all_close(&aview1(&[0., 0., 0.]), 1e-12)); - assert!(b.var_axis(Axis(1), 0.).all_close(&aview1(&[2.24]), 1e-12)); - - let c = array![[], []]; - assert_eq!(c.var_axis(Axis(0), 0.), aview1(&[])); - - let d = array![1.1, 2.7, 3.5, 4.9]; - assert!(d.var_axis(Axis(0), 0.).all_close(&aview0(&1.8875), 1e-12)); -} - -#[test] -fn std_axis() { - let a = array![ - [ - [ 0.22935481, 0.08030619, 0.60827517, 0.73684379], - [ 0.90339851, 0.82859436, 0.64020362, 0.2774583 ], - [ 0.44485313, 0.63316367, 0.11005111, 0.08656246] - ], - [ - [ 0.28924665, 0.44082454, 0.59837736, 0.41014531], - [ 0.08382316, 0.43259439, 0.1428889 , 0.44830176], - [ 0.51529756, 0.70111616, 0.20799415, 0.91851457] - ], - ]; - assert!(a.std_axis(Axis(0), 1.5).all_close( - &aview2(&[ - [ 0.05989184, 0.36051836, 0.00989781, 0.32669847], - [ 0.81957535, 0.39599997, 0.49731472, 0.17084346], - [ 0.07044443, 0.06795249, 0.09794304, 0.83195211], - ]), - 1e-4, - )); - assert!(a.std_axis(Axis(1), 1.7).all_close( - &aview2(&[ - [ 0.42698655, 0.48139215, 0.36874991, 0.41458724], - [ 0.26769097, 0.18941435, 0.30555015, 0.35118674], - ]), - 1e-8, - )); - assert!(a.std_axis(Axis(2), 2.3).all_close( - &aview2(&[ - [ 0.41117907, 0.37130425, 0.35332388], - [ 0.16905862, 0.25304841, 0.39978276], - ]), - 1e-8, - )); - - let b = array![[100000., 1., 0.01]]; - assert!(b.std_axis(Axis(0), 0.).all_close(&aview1(&[0., 0., 0.]), 1e-12)); - assert!( - b.std_axis(Axis(1), 0.).all_close(&aview1(&[47140.214021552769]), 1e-6), - ); - - let c = array![[], []]; - assert_eq!(c.std_axis(Axis(0), 0.), aview1(&[])); -} - -#[test] -#[should_panic] -fn var_axis_negative_ddof() { - let a = array![1., 2., 3.]; - a.var_axis(Axis(0), -1.); -} - -#[test] -#[should_panic] -fn var_axis_too_large_ddof() { - let a = array![1., 2., 3.]; - a.var_axis(Axis(0), 4.); -} - -#[test] -fn var_axis_nan_ddof() { - let a = Array2::::zeros((2, 3)); - let v = a.var_axis(Axis(1), ::std::f64::NAN); - assert_eq!(v.shape(), &[2]); - v.mapv(|x| assert!(x.is_nan())); -} - -#[test] -fn var_axis_empty_axis() { - let a = Array2::::zeros((2, 0)); - let v = a.var_axis(Axis(1), 0.); - assert_eq!(v.shape(), &[2]); - v.mapv(|x| assert!(x.is_nan())); -} - -#[test] -#[should_panic] -fn std_axis_bad_dof() { - let a = array![1., 2., 3.]; - a.std_axis(Axis(0), 4.); -} - -#[test] -fn std_axis_empty_axis() { - let a = Array2::::zeros((2, 0)); - let v = a.std_axis(Axis(1), 0.); - assert_eq!(v.shape(), &[2]); - v.mapv(|x| assert!(x.is_nan())); -} - #[test] fn iter_size_hint() { diff --git a/tests/complex.rs b/tests/complex.rs index 8da721c4d..a7449e9f8 100644 --- a/tests/complex.rs +++ b/tests/complex.rs @@ -22,5 +22,5 @@ fn complex_mat_mul() let r = a.dot(&e); println!("{}", a); assert_eq!(r, a); - assert_eq!(a.mean_axis(Axis(0)), arr1(&[c(1.5, 1.), c(2.5, 0.)])); + assert_eq!(a.mean_axis(Axis(0)).unwrap(), arr1(&[c(1.5, 1.), c(2.5, 0.)])); } diff --git a/tests/numeric.rs b/tests/numeric.rs new file mode 100644 index 000000000..e73da4904 --- /dev/null +++ b/tests/numeric.rs @@ -0,0 +1,203 @@ +extern crate approx; +use std::f64; +use ndarray::{array, Axis, aview1, aview2, aview0, arr0, arr1, arr2, Array, Array1, Array2, Array3}; +use approx::abs_diff_eq; + +#[test] +fn test_mean_with_nan_values() { + let a = array![f64::NAN, 1.]; + assert!(a.mean().unwrap().is_nan()); +} + +#[test] +fn test_mean_with_empty_array_of_floats() { + let a: Array1 = array![]; + assert!(a.mean().is_none()); +} + +#[test] +fn test_mean_with_array_of_floats() { + let a: Array1 = array![ + 0.99889651, 0.0150731 , 0.28492482, 0.83819218, 0.48413156, + 0.80710412, 0.41762936, 0.22879429, 0.43997224, 0.23831807, + 0.02416466, 0.6269962 , 0.47420614, 0.56275487, 0.78995021, + 0.16060581, 0.64635041, 0.34876609, 0.78543249, 0.19938356, + 0.34429457, 0.88072369, 0.17638164, 0.60819363, 0.250392 , + 0.69912532, 0.78855523, 0.79140914, 0.85084218, 0.31839879, + 0.63381769, 0.22421048, 0.70760302, 0.99216018, 0.80199153, + 0.19239188, 0.61356023, 0.31505352, 0.06120481, 0.66417377, + 0.63608897, 0.84959691, 0.43599069, 0.77867775, 0.88267754, + 0.83003623, 0.67016118, 0.67547638, 0.65220036, 0.68043427 + ]; + // Computed using NumPy + let expected_mean = 0.5475494059146699; + abs_diff_eq!(a.mean().unwrap(), expected_mean, epsilon = f64::EPSILON); +} + +#[test] +fn sum_mean() +{ + let a = arr2(&[[1., 2.], [3., 4.]]); + assert_eq!(a.sum_axis(Axis(0)), arr1(&[4., 6.])); + assert_eq!(a.sum_axis(Axis(1)), arr1(&[3., 7.])); + assert_eq!(a.mean_axis(Axis(0)), Some(arr1(&[2., 3.]))); + assert_eq!(a.mean_axis(Axis(1)), Some(arr1(&[1.5, 3.5]))); + assert_eq!(a.sum_axis(Axis(1)).sum_axis(Axis(0)), arr0(10.)); + assert_eq!(a.view().mean_axis(Axis(1)).unwrap(), aview1(&[1.5, 3.5])); + assert_eq!(a.sum(), 10.); +} + +#[test] +fn sum_mean_empty() { + assert_eq!(Array3::::ones((2, 0, 3)).sum(), 0.); + assert_eq!(Array1::::ones(0).sum_axis(Axis(0)), arr0(0.)); + assert_eq!( + Array3::::ones((2, 0, 3)).sum_axis(Axis(1)), + Array::zeros((2, 3)), + ); + let a = Array1::::ones(0).mean_axis(Axis(0)); + assert_eq!(a, None); + let a = Array3::::ones((2, 0, 3)).mean_axis(Axis(1)); + assert_eq!(a, None); +} + +#[test] +fn var_axis() { + let a = array![ + [ + [-9.76, -0.38, 1.59, 6.23], + [-8.57, -9.27, 5.76, 6.01], + [-9.54, 5.09, 3.21, 6.56], + ], + [ + [ 8.23, -9.63, 3.76, -3.48], + [-5.46, 5.86, -2.81, 1.35], + [-1.08, 4.66, 8.34, -0.73], + ], + ]; + assert!(a.var_axis(Axis(0), 1.5).all_close( + &aview2(&[ + [3.236401e+02, 8.556250e+01, 4.708900e+00, 9.428410e+01], + [9.672100e+00, 2.289169e+02, 7.344490e+01, 2.171560e+01], + [7.157160e+01, 1.849000e-01, 2.631690e+01, 5.314410e+01] + ]), + 1e-4, + )); + assert!(a.var_axis(Axis(1), 1.7).all_close( + &aview2(&[ + [0.61676923, 80.81092308, 6.79892308, 0.11789744], + [75.19912821, 114.25235897, 48.32405128, 9.03020513], + ]), + 1e-8, + )); + assert!(a.var_axis(Axis(2), 2.3).all_close( + &aview2(&[ + [ 79.64552941, 129.09663235, 95.98929412], + [109.64952941, 43.28758824, 36.27439706], + ]), + 1e-8, + )); + + let b = array![[1.1, 2.3, 4.7]]; + assert!(b.var_axis(Axis(0), 0.).all_close(&aview1(&[0., 0., 0.]), 1e-12)); + assert!(b.var_axis(Axis(1), 0.).all_close(&aview1(&[2.24]), 1e-12)); + + let c = array![[], []]; + assert_eq!(c.var_axis(Axis(0), 0.), aview1(&[])); + + let d = array![1.1, 2.7, 3.5, 4.9]; + assert!(d.var_axis(Axis(0), 0.).all_close(&aview0(&1.8875), 1e-12)); +} + +#[test] +fn std_axis() { + let a = array![ + [ + [ 0.22935481, 0.08030619, 0.60827517, 0.73684379], + [ 0.90339851, 0.82859436, 0.64020362, 0.2774583 ], + [ 0.44485313, 0.63316367, 0.11005111, 0.08656246] + ], + [ + [ 0.28924665, 0.44082454, 0.59837736, 0.41014531], + [ 0.08382316, 0.43259439, 0.1428889 , 0.44830176], + [ 0.51529756, 0.70111616, 0.20799415, 0.91851457] + ], + ]; + assert!(a.std_axis(Axis(0), 1.5).all_close( + &aview2(&[ + [ 0.05989184, 0.36051836, 0.00989781, 0.32669847], + [ 0.81957535, 0.39599997, 0.49731472, 0.17084346], + [ 0.07044443, 0.06795249, 0.09794304, 0.83195211], + ]), + 1e-4, + )); + assert!(a.std_axis(Axis(1), 1.7).all_close( + &aview2(&[ + [ 0.42698655, 0.48139215, 0.36874991, 0.41458724], + [ 0.26769097, 0.18941435, 0.30555015, 0.35118674], + ]), + 1e-8, + )); + assert!(a.std_axis(Axis(2), 2.3).all_close( + &aview2(&[ + [ 0.41117907, 0.37130425, 0.35332388], + [ 0.16905862, 0.25304841, 0.39978276], + ]), + 1e-8, + )); + + let b = array![[100000., 1., 0.01]]; + assert!(b.std_axis(Axis(0), 0.).all_close(&aview1(&[0., 0., 0.]), 1e-12)); + assert!( + b.std_axis(Axis(1), 0.).all_close(&aview1(&[47140.214021552769]), 1e-6), + ); + + let c = array![[], []]; + assert_eq!(c.std_axis(Axis(0), 0.), aview1(&[])); +} + +#[test] +#[should_panic] +fn var_axis_negative_ddof() { + let a = array![1., 2., 3.]; + a.var_axis(Axis(0), -1.); +} + +#[test] +#[should_panic] +fn var_axis_too_large_ddof() { + let a = array![1., 2., 3.]; + a.var_axis(Axis(0), 4.); +} + +#[test] +fn var_axis_nan_ddof() { + let a = Array2::::zeros((2, 3)); + let v = a.var_axis(Axis(1), ::std::f64::NAN); + assert_eq!(v.shape(), &[2]); + v.mapv(|x| assert!(x.is_nan())); +} + +#[test] +fn var_axis_empty_axis() { + let a = Array2::::zeros((2, 0)); + let v = a.var_axis(Axis(1), 0.); + assert_eq!(v.shape(), &[2]); + v.mapv(|x| assert!(x.is_nan())); +} + +#[test] +#[should_panic] +fn std_axis_bad_dof() { + let a = array![1., 2., 3.]; + a.std_axis(Axis(0), 4.); +} + +#[test] +fn std_axis_empty_axis() { + let a = Array2::::zeros((2, 0)); + let v = a.std_axis(Axis(1), 0.); + assert_eq!(v.shape(), &[2]); + v.mapv(|x| assert!(x.is_nan())); +} +