Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issue/655: Standard deviation and variance #753

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
108 changes: 108 additions & 0 deletions benches/varstd.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
#![feature(test)]

extern crate test;
use test::Bencher;

use ndarray::arr3;
use ndarray::prelude::*;
use std::iter::FromIterator;

/// Builds the fixed 22-row by 15-column `f64` matrix shared by the `*_big`
/// benchmarks in this file.
///
/// The values are hard-coded, so every benchmark run sees identical input.
/// `#[rustfmt::skip]` keeps the hand-aligned table layout from being reflowed.
#[rustfmt::skip]
fn big_array() -> Array2<f64> {
arr2(&[
[ 92., 53., 51., 94., 27., 69., 11., 13., 62., 42., 73., 83., 2. , 53., 77. ],
[ 65., 56., 11., 32., 95., 66., 88., 10., 37., 8. , 12., 2. , 59., 78., 48. ],
[ 20., 86., 71., 99., 1. , 76., 29., 53., 87., 88., 61., 84., 2. , 87., 90. ],
[ 19., 22., 44., 38., 85., 12., 8. , 38., 53., 46., 80., 70., 62., 14., 8. ],
[ 51., 70., 71., 21., 14., 48., 34., 4. , 27., 55., 60., 95., 1. , 79., 1. ],
[ 13., 23., 78., 97., 57., 16., 81., 31., 88., 15., 78., 95., 93., 9. , 6. ],
[ 68., 58., 4. , 11., 91., 56., 61., 15., 60., 92., 29., 27., 22., 30., 2. ],
[ 53., 70., 89., 42., 59., 79., 63., 61., 86., 48., 40., 50., 23., 18., 55. ],
[ 14., 96., 68., 16., 52., 16., 70., 12., 16., 60., 28., 52., 56., 12., 37. ],
[ 68., 73., 6. , 51., 54., 51., 97., 88., 36., 32., 83., 52., 53., 86., 4. ],
[ 88., 11., 86., 91., 83., 71., 18., 60., 95., 59., 85., 92., 34., 76., 93. ],
[ 81., 18., 47., 26., 53., 64., 53., 12., 55., 92., 76., 22., 81., 80., 21. ],
[ 86., 48., 42., 19., 94., 86., 16., 37., 74., 85., 11., 9. , 80., 2. , 80. ],
[ 51., 43., 55., 56., 49., 77., 78., 94., 80., 23., 72., 67., 58., 95., 95. ],
[ 92., 24., 45., 41., 33., 64., 89., 8. , 75., 42., 32., 61., 19., 11., 61. ],
[ 81., 35., 75., 67., 73., 30., 95., 17., 24., 48., 72., 2. , 46., 14., 50. ],
[ 99., 87., 41., 87., 68., 22., 94., 73., 82., 87., 86., 46., 36., 26., 57. ],
[ 96., 69., 28., 44., 32., 70., 94., 13., 85., 5. , 13., 44., 60., 79., 76. ],
[ 81., 92., 42., 93., 99., 41., 13., 8. , 68., 92., 89., 83., 16., 82., 92. ],
[ 29., 18., 10., 71., 4. , 20., 99., 10., 91., 51., 90., 78., 20., 25., 44. ],
[ 57., 56., 96., 81., 87., 57., 32., 22., 29., 63., 76., 39., 52., 77., 96. ],
[ 88., 2. , 56., 75., 72., 53., 0. , 57., 42., 83., 77., 85., 14., 15., 19. ],
])
}

/// Benchmark: variance of a small 3x2x2 array, obtained by reshaping a view
/// to one dimension and reducing along that single axis with `var_axis`.
#[bench]
fn var_into_shape_use_var_axis(bench: &mut Bencher) {
    let a = arr3(&[
        [[1., 2.], [1., 3.]],
        [[3., 5.], [1., 3.]],
        [[5., 7.], [1., 3.]],
    ]);
    // Computed once, outside the timed closure.
    let n_elements = a.len();

    bench.iter(|| {
        a.view()
            .into_shape(n_elements)
            .expect("into_shape to a.len() can not fail.")
            .var_axis(Axis(0), 1.)
    });
}

/// Benchmark: variance of the `big_array()` matrix, obtained by reshaping a
/// view to one dimension and reducing along that single axis with `var_axis`.
#[bench]
fn var_into_shape_use_var_axis_big(bench: &mut Bencher) {
    let a = big_array();
    // Computed once, outside the timed closure.
    let n_elements = a.len();

    bench.iter(|| {
        a.view()
            .into_shape(n_elements)
            .expect("into_shape to a.len() can not fail.")
            .var_axis(Axis(0), 1.)
    });
}

/// Benchmark: variance of a small 3x2x2 array, obtained by copying every
/// element into a new one-dimensional array and calling `var_axis` on it.
///
/// The copy is made inside the timed closure, so its cost is part of the
/// measurement.
#[bench]
fn var_flatten_user_var_axis(bench: &mut Bencher) {
    let a = arr3(&[
        [[1., 2.], [1., 3.]],
        [[3., 5.], [1., 3.]],
        [[5., 7.], [1., 3.]],
    ]);

    bench.iter(|| {
        // `copied()` is the idiomatic spelling of `map(|&x| x)` for `Copy` items.
        let flattened = Array::from_iter(a.iter().copied());
        flattened.var_axis(Axis(0), 1.)
    })
}

/// Benchmark: variance of the `big_array()` matrix, obtained by copying every
/// element into a new one-dimensional array and calling `var_axis` on it.
///
/// The copy is made inside the timed closure, so its cost is part of the
/// measurement.
#[bench]
fn var_flatten_user_var_axis_big(bench: &mut Bencher) {
    let a = big_array();

    bench.iter(|| {
        // `copied()` is the idiomatic spelling of `map(|&x| x)` for `Copy` items.
        let flattened = Array::from_iter(a.iter().copied());
        flattened.var_axis(Axis(0), 1.)
    })
}

/// Benchmark: variance of a small 3x2x2 array through the whole-array `var`
/// method, with `ddof = 1`.
#[bench]
fn var_new(bench: &mut Bencher) {
    let input = arr3(&[
        [[1., 2.], [1., 3.]],
        [[3., 5.], [1., 3.]],
        [[5., 7.], [1., 3.]],
    ]);

    bench.iter(|| {
        let ddof = 1.;
        input.var(ddof)
    })
}

/// Benchmark: variance of the `big_array()` matrix through the whole-array
/// `var` method, with `ddof = 1`.
#[bench]
fn var_new_big(bench: &mut Bencher) {
    let input = big_array();

    bench.iter(|| {
        let ddof = 1.;
        input.var(ddof)
    })
}
85 changes: 67 additions & 18 deletions src/numeric/impl_numeric.rs
Original file line number Diff line number Diff line change
Expand Up @@ -235,25 +235,13 @@ where
A: Float + FromPrimitive,
D: RemoveAxis,
{
let zero = A::from_usize(0).expect("Converting 0 to `A` must not fail.");
let n = A::from_usize(self.len_of(axis)).expect("Converting length to `A` must not fail.");
assert!(
!(ddof < zero || ddof > n),
"`ddof` must not be less than zero or greater than the length of \
the axis",
);
let dof = n - ddof;
let mut mean = Array::<A, _>::zeros(self.dim.remove_axis(axis));
let mut sum_sq = Array::<A, _>::zeros(self.dim.remove_axis(axis));
for (i, subview) in self.axis_iter(axis).enumerate() {
let count = A::from_usize(i + 1).expect("Converting index to `A` must not fail.");
azip!((mean in &mut mean, sum_sq in &mut sum_sq, &x in &subview) {
let delta = x - *mean;
*mean = *mean + delta / count;
*sum_sq = (x - *mean).mul_add(delta, *sum_sq);
let mut output = Array::zeros(self.dim.remove_axis(axis));
Zip::from(output.view_mut())
.and(self.lanes(axis))
.apply(|o, l| {
*o = l.var(ddof);
});
}
sum_sq.mapv_into(|s| s / dof)
output
}

/// Return standard deviation along `axis`.
Expand Down Expand Up @@ -306,6 +294,67 @@ where
self.var_axis(axis, ddof).mapv_into(|x| x.sqrt())
}

/// Return the variance of all elements, treating the array as flattened.
///
/// The variance is accumulated in a single pass with Welford's online
/// update (a running mean plus a running sum of squared deviations), and
/// the final sum is divided by `n - ddof`, where `n` is the number of
/// elements in the array.
///
/// If `ddof` equals `n` the divisor is zero, so the result is the outcome
/// of a floating-point division by zero rather than a finite number.
///
/// # Panics
///
/// Panics if `ddof` is less than zero or greater than the number of
/// elements, or if an element count cannot be converted to `A`.
///
/// # Example
///
/// ```
/// use ndarray::{arr2, Axis};
///
/// let a = arr2(&[[1., 2.],
///                [3., 4.],
///                [5., 6.]]);
///
/// let a_flat = a.view().into_shape(6).unwrap();
/// assert_eq!(a.var(1.), a_flat.var_axis(Axis(0), 1.).into_scalar());
/// ```
pub fn var(&self, ddof: A) -> A
where
A: Float + FromPrimitive,
{
let zero = A::from_usize(0).expect("Converting 0 to `A` must not fail.");
let n = A::from_usize(self.len()).expect("Converting length to `A` must not fail.");
// Precondition: 0 <= ddof <= n.
assert!(
!(ddof < zero || ddof > n),
"`ddof` must not be less than zero or greater than the length of \
the axis",
);
// Divisor applied to the accumulated sum of squared deviations.
let dof = n - ddof;
// Running mean and running sum of squared deviations, both starting at zero.
let mut mean = A::from_usize(0).expect("Converting 0 to `A` must not fail.");
let mut sum_sq = A::from_usize(0).expect("Converting 0 to `A` must not fail.");
// Single pass over every element; `count` is zero-based, hence `count + 1` below.
for (count, x) in self.iter().enumerate() {
Copy link
Member

@bluss bluss Apr 20, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to note, iter is not an efficient way to iterate a "flattened" array, so this should potentially use unary Zip instead, for example, or ArrayBase::fold. In the high level, this is documented in the module documentation, about array traversals.

can of course be fixed as a later change.

// Deviation from the mean *before* this element is folded in.
let delta = *x - mean;
mean = mean + delta / A::from_usize(count + 1).unwrap();
// sum_sq += (x - new_mean) * delta, using `mul_add` for the multiply-add.
sum_sq = (*x - mean).mul_add(delta, sum_sq);
}
sum_sq / dof
}

/// Return standard deviation for the flattened array.
///
/// The standard deviation is computed from the variance.
///
/// # Example
///
/// ```
/// use ndarray::{arr2, Axis};
///
/// let a = arr2(&[[1., 2.],
/// [3., 4.],
/// [5., 6.]]);
///
/// let a_flat = a.view().into_shape(6).expect("This must not fail.");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"This must not fail" does not add information, just use .unwrap(). I'm just afraid that the user will conclude some kinds of into_shape() (to 1D) always succeed, and they don't.

/// assert_eq!(a.std(1.), a_flat.std_axis(Axis(0), 1.).into_scalar());
/// ```
pub fn std(&self, ddof: A) -> A
where
    A: Float + FromPrimitive,
{
    // The standard deviation is the square root of the whole-array variance;
    // `ddof` is forwarded to `var` unchanged.
    let variance = self.var(ddof);
    variance.sqrt()
}

/// Return `true` if the arrays' elementwise differences are all within
/// the given absolute tolerance, `false` otherwise.
///
Expand Down
34 changes: 33 additions & 1 deletion tests/numeric.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
)]

use approx::assert_abs_diff_eq;
use ndarray::{arr0, arr1, arr2, array, aview1, Array, Array1, Array2, Array3, Axis};
use ndarray::{arr0, arr1, arr2, arr3, array, aview1, Array, Array1, Array2, Array3, Axis};
use std::f64;

#[test]
Expand Down Expand Up @@ -225,3 +225,35 @@ fn std_axis_empty_axis() {
assert_eq!(v.shape(), &[2]);
v.mapv(|x| assert!(x.is_nan()));
}

/// `var` over a whole 3-D array must equal `var_axis` over the same data
/// viewed as a one-dimensional array.
#[test]
fn var_var_axis() {
    let a = arr3(&[
        [[1., 2.], [1., 3.]],
        [[3., 5.], [1., 3.]],
        [[5., 7.], [1., 3.]],
    ]);

    // `usize` is `Copy`, so the length is passed directly; no `.clone()` needed.
    let a_flat = a
        .view()
        .into_shape(a.len())
        .expect("into_shape to a.len() must not fail.");

    assert_eq!(a.var(1.), a_flat.var_axis(Axis(0), 1.).into_scalar());
}

/// `std` over a whole 3-D array must equal `std_axis` over the same data
/// viewed as a one-dimensional array.
#[test]
fn std_std_axis() {
    let a = arr3(&[
        [[1., 2.], [1., 3.]],
        [[3., 5.], [1., 3.]],
        [[5., 7.], [1., 3.]],
    ]);

    // `usize` is `Copy`, so the length is passed directly; no `.clone()` needed.
    let a_flat = a
        .view()
        .into_shape(a.len())
        .expect("into_shape to a.len() must not fail.");

    assert_eq!(a.std(1.), a_flat.std_axis(Axis(0), 1.).into_scalar());
}