-
Notifications
You must be signed in to change notification settings - Fork 291
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.
Already on GitHub? Sign in to your account
Issue/655: Standard deviation and variance #753
Closed
Closed
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
#![feature(test)] | ||
|
||
extern crate test; | ||
use test::Bencher; | ||
|
||
use ndarray::arr3; | ||
use ndarray::prelude::*; | ||
use std::iter::FromIterator; | ||
|
||
#[rustfmt::skip] | ||
fn big_array() -> Array2<f64> { | ||
arr2(&[ | ||
[ 92., 53., 51., 94., 27., 69., 11., 13., 62., 42., 73., 83., 2. , 53., 77. ], | ||
[ 65., 56., 11., 32., 95., 66., 88., 10., 37., 8. , 12., 2. , 59., 78., 48. ], | ||
[ 20., 86., 71., 99., 1. , 76., 29., 53., 87., 88., 61., 84., 2. , 87., 90. ], | ||
[ 19., 22., 44., 38., 85., 12., 8. , 38., 53., 46., 80., 70., 62., 14., 8. ], | ||
[ 51., 70., 71., 21., 14., 48., 34., 4. , 27., 55., 60., 95., 1. , 79., 1. ], | ||
[ 13., 23., 78., 97., 57., 16., 81., 31., 88., 15., 78., 95., 93., 9. , 6. ], | ||
[ 68., 58., 4. , 11., 91., 56., 61., 15., 60., 92., 29., 27., 22., 30., 2. ], | ||
[ 53., 70., 89., 42., 59., 79., 63., 61., 86., 48., 40., 50., 23., 18., 55. ], | ||
[ 14., 96., 68., 16., 52., 16., 70., 12., 16., 60., 28., 52., 56., 12., 37. ], | ||
[ 68., 73., 6. , 51., 54., 51., 97., 88., 36., 32., 83., 52., 53., 86., 4. ], | ||
[ 88., 11., 86., 91., 83., 71., 18., 60., 95., 59., 85., 92., 34., 76., 93. ], | ||
[ 81., 18., 47., 26., 53., 64., 53., 12., 55., 92., 76., 22., 81., 80., 21. ], | ||
[ 86., 48., 42., 19., 94., 86., 16., 37., 74., 85., 11., 9. , 80., 2. , 80. ], | ||
[ 51., 43., 55., 56., 49., 77., 78., 94., 80., 23., 72., 67., 58., 95., 95. ], | ||
[ 92., 24., 45., 41., 33., 64., 89., 8. , 75., 42., 32., 61., 19., 11., 61. ], | ||
[ 81., 35., 75., 67., 73., 30., 95., 17., 24., 48., 72., 2. , 46., 14., 50. ], | ||
[ 99., 87., 41., 87., 68., 22., 94., 73., 82., 87., 86., 46., 36., 26., 57. ], | ||
[ 96., 69., 28., 44., 32., 70., 94., 13., 85., 5. , 13., 44., 60., 79., 76. ], | ||
[ 81., 92., 42., 93., 99., 41., 13., 8. , 68., 92., 89., 83., 16., 82., 92. ], | ||
[ 29., 18., 10., 71., 4. , 20., 99., 10., 91., 51., 90., 78., 20., 25., 44. ], | ||
[ 57., 56., 96., 81., 87., 57., 32., 22., 29., 63., 76., 39., 52., 77., 96. ], | ||
[ 88., 2. , 56., 75., 72., 53., 0. , 57., 42., 83., 77., 85., 14., 15., 19. ], | ||
]) | ||
} | ||
|
||
#[bench] | ||
fn var_into_shape_use_var_axis(bench: &mut Bencher) { | ||
let a = arr3(&[ | ||
[[1., 2.], [1., 3.]], | ||
[[3., 5.], [1., 3.]], | ||
[[5., 7.], [1., 3.]], | ||
]); | ||
|
||
let len = a.len(); | ||
bench.iter(|| { | ||
let flattened = a | ||
.view() | ||
.into_shape(len) | ||
.expect("into_shape to a.len() can not fail."); | ||
flattened.var_axis(Axis(0), 1.) | ||
}); | ||
} | ||
|
||
#[bench] | ||
fn var_into_shape_use_var_axis_big(bench: &mut Bencher) { | ||
let a = big_array(); | ||
|
||
let len = a.len(); | ||
bench.iter(|| { | ||
let flattened = a | ||
.view() | ||
.into_shape(len) | ||
.expect("into_shape to a.len() can not fail."); | ||
flattened.var_axis(Axis(0), 1.) | ||
}); | ||
} | ||
|
||
#[bench] | ||
fn var_flatten_user_var_axis(bench: &mut Bencher) { | ||
let a = arr3(&[ | ||
[[1., 2.], [1., 3.]], | ||
[[3., 5.], [1., 3.]], | ||
[[5., 7.], [1., 3.]], | ||
]); | ||
|
||
bench.iter(|| { | ||
let flattened = Array::from_iter(a.iter().map(|&x| x)); | ||
flattened.var_axis(Axis(0), 1.) | ||
}) | ||
} | ||
|
||
#[bench] | ||
fn var_flatten_user_var_axis_big(bench: &mut Bencher) { | ||
let a = big_array(); | ||
bench.iter(|| { | ||
let flattened = Array::from_iter(a.iter().map(|&x| x)); | ||
flattened.var_axis(Axis(0), 1.) | ||
}) | ||
} | ||
|
||
#[bench] | ||
fn var_new(bench: &mut Bencher) { | ||
let a = arr3(&[ | ||
[[1., 2.], [1., 3.]], | ||
[[3., 5.], [1., 3.]], | ||
[[5., 7.], [1., 3.]], | ||
]); | ||
|
||
bench.iter(|| a.var(1.)) | ||
} | ||
|
||
#[bench] | ||
fn var_new_big(bench: &mut Bencher) { | ||
let a = big_array(); | ||
bench.iter(|| a.var(1.)) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -235,25 +235,13 @@ where | |
A: Float + FromPrimitive, | ||
D: RemoveAxis, | ||
{ | ||
let zero = A::from_usize(0).expect("Converting 0 to `A` must not fail."); | ||
let n = A::from_usize(self.len_of(axis)).expect("Converting length to `A` must not fail."); | ||
assert!( | ||
!(ddof < zero || ddof > n), | ||
"`ddof` must not be less than zero or greater than the length of \ | ||
the axis", | ||
); | ||
let dof = n - ddof; | ||
let mut mean = Array::<A, _>::zeros(self.dim.remove_axis(axis)); | ||
let mut sum_sq = Array::<A, _>::zeros(self.dim.remove_axis(axis)); | ||
for (i, subview) in self.axis_iter(axis).enumerate() { | ||
let count = A::from_usize(i + 1).expect("Converting index to `A` must not fail."); | ||
azip!((mean in &mut mean, sum_sq in &mut sum_sq, &x in &subview) { | ||
let delta = x - *mean; | ||
*mean = *mean + delta / count; | ||
*sum_sq = (x - *mean).mul_add(delta, *sum_sq); | ||
let mut output = Array::zeros(self.dim.remove_axis(axis)); | ||
Zip::from(output.view_mut()) | ||
.and(self.lanes(axis)) | ||
.apply(|o, l| { | ||
*o = l.var(ddof); | ||
}); | ||
} | ||
sum_sq.mapv_into(|s| s / dof) | ||
output | ||
} | ||
|
||
/// Return standard deviation along `axis`. | ||
|
@@ -306,6 +294,67 @@ where | |
self.var_axis(axis, ddof).mapv_into(|x| x.sqrt()) | ||
} | ||
|
||
/// Return variance for the flattened array. | ||
/// | ||
/// This uses the same method as var_axis. | ||
/// | ||
/// # Example | ||
/// | ||
/// ``` | ||
/// use ndarray::{arr2, Axis}; | ||
/// | ||
/// let a = arr2(&[[1., 2.], | ||
/// [3., 4.], | ||
/// [5., 6.]]); | ||
/// | ||
/// let a_flat = a.view().into_shape(6).expect("This must not fail."); | ||
/// assert_eq!(a.var(1.), a_flat.var_axis(Axis(0), 1.).into_scalar()); | ||
/// ``` | ||
pub fn var(&self, ddof: A) -> A | ||
where | ||
A: Float + FromPrimitive, | ||
{ | ||
let zero = A::from_usize(0).expect("Converting 0 to `A` must not fail."); | ||
let n = A::from_usize(self.len()).expect("Converting length to `A` must not fail."); | ||
assert!( | ||
!(ddof < zero || ddof > n), | ||
"`ddof` must not be less than zero or greater than the length of \ | ||
the axis", | ||
); | ||
let dof = n - ddof; | ||
let mut mean = A::from_usize(0).expect("Converting 0 to `A` must not fail."); | ||
let mut sum_sq = A::from_usize(0).expect("Converting 0 to `A` must not fail."); | ||
for (count, x) in self.iter().enumerate() { | ||
let delta = *x - mean; | ||
mean = mean + delta / A::from_usize(count + 1).unwrap(); | ||
sum_sq = (*x - mean).mul_add(delta, sum_sq); | ||
} | ||
sum_sq / dof | ||
} | ||
|
||
/// Return standard deviation for the flattened array. | ||
/// | ||
/// The standard deviation is computed from the variance. | ||
/// | ||
/// # Example | ||
/// | ||
/// ``` | ||
/// use ndarray::{arr2, Axis}; | ||
/// | ||
/// let a = arr2(&[[1., 2.], | ||
/// [3., 4.], | ||
/// [5., 6.]]); | ||
/// | ||
/// let a_flat = a.view().into_shape(6).expect("This must not fail."); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. "This must not fail" does not add information, just use |
||
/// assert_eq!(a.std(1.), a_flat.std_axis(Axis(0), 1.).into_scalar()); | ||
/// ``` | ||
pub fn std(&self, ddof: A) -> A | ||
where | ||
A: Float + FromPrimitive, | ||
{ | ||
self.var(ddof).sqrt() | ||
} | ||
|
||
/// Return `true` if the arrays' elementwise differences are all within | ||
/// the given absolute tolerance, `false` otherwise. | ||
/// | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just to note, iter is not an efficient way to iterate a "flattened" array, so this should potentially use unary Zip instead, for example, or
ArrayBase::fold
. In the high level, this is documented in the module documentation, about array traversals.can of course be fixed as a later change.