From af63c919e863e389e4849b144d1e783476ed8c0a Mon Sep 17 00:00:00 2001 From: dcraven Date: Mon, 18 May 2020 23:18:05 +0200 Subject: [PATCH 1/6] bench: add criterion bench for order_statistic --- Cargo.toml | 9 ++- benches/order_statistics.rs | 116 ++++++++++++++++++++++++++++++++++++ 2 files changed, 124 insertions(+), 1 deletion(-) create mode 100644 benches/order_statistics.rs diff --git a/Cargo.toml b/Cargo.toml index d9aace77..f8788488 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,4 +19,11 @@ path = "src/lib.rs" [dependencies] rand = "0.7" -nalgebra = "0.19" \ No newline at end of file +nalgebra = "0.19" + +[dev-dependencies] +criterion = "*" + +[[bench]] +name = "order_statistics" +harness = false \ No newline at end of file diff --git a/benches/order_statistics.rs b/benches/order_statistics.rs new file mode 100644 index 00000000..d9b71710 --- /dev/null +++ b/benches/order_statistics.rs @@ -0,0 +1,116 @@ +extern crate rand; +extern crate statrs; +use criterion::{black_box, criterion_group, criterion_main, BatchSize, Criterion}; +use rand::prelude::*; +// use statrs::distribution::Categorical; +// use statrs::distribution::InverseCDF; +use statrs::statistics::*; + +fn bench_order_statistic(c: &mut Criterion) { + let mut rng = thread_rng(); + let to_random_owned = |data: &[f64]| -> Vec { + let mut rng = thread_rng(); + let mut owned = data.to_vec(); + owned.shuffle(&mut rng); + owned + }; + let k = black_box(rng.gen()); + let tau = black_box(rng.gen_range(0.0, 1.0)); + let mut group = c.benchmark_group("order statistic"); + let data: Vec<_> = (0..100).map(|x| x as f64).collect(); + group.bench_function("order_statistic", |b| { + b.iter_batched( + || to_random_owned(&data), + |mut data| data.order_statistic(k), + BatchSize::SmallInput, + ) + }); + group.bench_function("median", |b| { + b.iter_batched( + || to_random_owned(&data), + |data| data.median(), + BatchSize::SmallInput, + ) + }); + group.bench_function("quantile", |b| { + b.iter_batched( + || to_random_owned(&data), + |mut data| data.quantile(tau), + BatchSize::SmallInput, + ) + }); + group.bench_function("percentile", |b| { + b.iter_batched( + || to_random_owned(&data), + |mut data| data.percentile(k), + BatchSize::SmallInput, + ) + }); + group.bench_function("lower_quartile", |b| { + b.iter_batched( + || to_random_owned(&data), + |mut data| data.lower_quartile(), + BatchSize::SmallInput, + ) + }); + group.bench_function("upper_quartile", |b| { + b.iter_batched( + || to_random_owned(&data), + |mut data| data.upper_quartile(), + BatchSize::SmallInput, + ) + }); + group.bench_function("interquartile_range", |b| { + b.iter_batched( + || to_random_owned(&data), + |mut data| data.interquartile_range(), + BatchSize::SmallInput, + ) + }); + group.bench_function("ranks: RankTieBreaker::First", |b| { + b.iter_batched( + || to_random_owned(&data), + |mut data| data.ranks(RankTieBreaker::First), + BatchSize::SmallInput, + ) + }); + group.bench_function("ranks: RankTieBreaker::Average", |b| { + b.iter_batched( + || to_random_owned(&data), + |mut data| data.ranks(RankTieBreaker::Average), + BatchSize::SmallInput, + ) + }); + group.bench_function("ranks: RankTieBreaker::Min", |b| { + b.iter_batched( + || to_random_owned(&data), + |mut data| data.ranks(RankTieBreaker::Min), + BatchSize::SmallInput, + ) + }); + group.finish(); +} + +// fn bench_categorical_distribution(c: &mut Criterion) { +// c.bench_function("new categorical distribution", |b| { +// b.iter_batched( +// || rand_vec(100), +// |slice| Categorical::new(&slice), +// BatchSize::SmallInput, +// ) +// }); +// } + +// fn categorical_distribution_inverse_cdf(c: &mut Criterion) { +// let x = black_box(5.0); +// c.bench_function("inverse_cdf", |b| { +// b.iter_batched( +// || Categorical::new(&rand_vec(100)).unwrap(), +// |categorical| categorical.inverse_cdf(x), +// BatchSize::SmallInput, +// ) +// }); +// } + +criterion_group!(benches, bench_order_statistic); +criterion_main!(benches); From c8fed8452d9ccbff4e3547f250911cd0db21de11 Mon Sep 17 00:00:00 2001 From: dcraven Date: Tue, 5 May 2020 01:24:23 +0200 Subject: [PATCH 2/6] fix: refactor unsafes safer fix: stylistic issues --- src/distribution/beta.rs | 5 +- src/distribution/categorical.rs | 59 ++--- src/distribution/cauchy.rs | 5 +- src/distribution/chi_squared.rs | 5 +- src/distribution/gamma.rs | 5 +- src/distribution/geometric.rs | 2 +- src/distribution/hypergeometric.rs | 6 +- src/distribution/inverse_gamma.rs | 5 +- src/distribution/log_normal.rs | 5 +- src/distribution/multinomial.rs | 5 +- src/distribution/multivariate_normal.rs | 2 +- src/distribution/normal.rs | 5 +- src/distribution/pareto.rs | 5 +- src/distribution/poisson.rs | 2 +- src/distribution/students_t.rs | 6 +- src/distribution/triangular.rs | 6 +- src/distribution/weibull.rs | 4 +- src/statistics/slice_statistics.rs | 297 +++--------------------- 18 files changed, 67 insertions(+), 362 deletions(-) diff --git a/src/distribution/beta.rs b/src/distribution/beta.rs index da98229e..55f24f5f 100644 --- a/src/distribution/beta.rs +++ b/src/distribution/beta.rs @@ -51,10 +51,7 @@ impl Beta { match (shape_a, shape_b, is_nan) { (_, _, true) => Err(StatsError::BadParams), (_, _, false) if shape_a <= 0.0 || shape_b <= 0.0 => Err(StatsError::BadParams), - (_, _, false) => Ok(Beta { - shape_a: shape_a, - shape_b: shape_b, - }), + (_, _, false) => Ok(Beta { shape_a, shape_b }), } } diff --git a/src/distribution/categorical.rs b/src/distribution/categorical.rs index 6064da0b..ca58104f 100644 --- a/src/distribution/categorical.rs +++ b/src/distribution/categorical.rs @@ -62,16 +62,11 @@ impl Categorical { // extract normalized probability mass let sum = cdf[cdf.len() - 1]; let mut norm_pmf = vec![0.0; prob_mass.len()]; - for i in 0..prob_mass.len() { - unsafe { - let elem = norm_pmf.get_unchecked_mut(i); - *elem = prob_mass.get_unchecked(i) / sum; - } - } - Ok(Categorical { - norm_pmf: norm_pmf, - cdf: cdf, - }) + norm_pmf + .iter_mut() + .zip(prob_mass.iter()) + .for_each(|(np, pm)| *np = *pm / sum); + Ok(Categorical { norm_pmf, cdf }) } } @@ -293,11 +288,7 @@ impl Discrete for Categorical { /// p_x /// ``` fn pmf(&self, x: u64) -> f64 { - if x >= self.norm_pmf.len() as u64 { - 0.0 - } else { - unsafe { *self.norm_pmf.get_unchecked(x as usize) } - } + *self.norm_pmf.get(x as usize).unwrap_or(&0.0) } /// Calculates the log probability mass function for the categorical @@ -311,38 +302,22 @@ impl Discrete for Categorical { /// without doing any bounds checking pub fn sample_unchecked(r: &mut R, cdf: &[f64]) -> f64 { let draw = r.gen::() * unsafe { cdf.get_unchecked(cdf.len() - 1) }; - let mut idx = 0; - - if draw == 0.0 { - // skip zero-probability categories - let mut el = unsafe { cdf.get_unchecked(idx) }; - while *el == 0.0 { - // don't need bounds checking because we do not allow - // creating Categorical distributions with all 0.0 probs - idx += 1; - el = unsafe { cdf.get_unchecked(idx) } - } - } - let mut el = unsafe { cdf.get_unchecked(idx) }; - while draw > *el { - idx += 1; - el = unsafe { cdf.get_unchecked(idx) }; - } - idx as f64 + cdf.iter() + .enumerate() + .find(|(_, val)| **val >= draw) + .map(|(i, _)| i) + .unwrap() as f64 } /// Computes the cdf from the given probability masses. Performs /// no parameter or bounds checking. pub fn prob_mass_to_cdf(prob_mass: &[f64]) -> Vec { - let mut cdf = vec![0.0; prob_mass.len()]; - cdf[0] = prob_mass[0]; - for i in 1..prob_mass.len() { - unsafe { - let val = cdf.get_unchecked(i - 1) + prob_mass.get_unchecked(i); - let elem = cdf.get_unchecked_mut(i); - *elem = val; - } - } + let mut cdf = Vec::with_capacity(prob_mass.len()); + prob_mass.iter().fold(0.0, |s, p| { + let sum = s + p; + cdf.push(sum); + sum + }); cdf } diff --git a/src/distribution/cauchy.rs b/src/distribution/cauchy.rs index 0d3229d6..57bde23d 100644 --- a/src/distribution/cauchy.rs +++ b/src/distribution/cauchy.rs @@ -47,10 +47,7 @@ impl Cauchy { if location.is_nan() || scale.is_nan() || scale <= 0.0 { Err(StatsError::BadParams) } else { - Ok(Cauchy { - location: location, - scale: scale, - }) + Ok(Cauchy { location, scale }) } } diff --git a/src/distribution/chi_squared.rs b/src/distribution/chi_squared.rs index 5e00587b..27f00f4d 100644 --- a/src/distribution/chi_squared.rs +++ b/src/distribution/chi_squared.rs @@ -50,10 +50,7 @@ impl ChiSquared { /// assert!(result.is_err()); /// ``` pub fn new(freedom: f64) -> Result { - Gamma::new(freedom / 2.0, 0.5).map(|g| ChiSquared { - freedom: freedom, - g: g, - }) + Gamma::new(freedom / 2.0, 0.5).map(|g| ChiSquared { freedom, g }) } /// Returns the degrees of freedom of the chi-squared diff --git a/src/distribution/gamma.rs b/src/distribution/gamma.rs index 8c4d8def..160b7931 100644 --- a/src/distribution/gamma.rs +++ b/src/distribution/gamma.rs @@ -51,10 +51,7 @@ impl Gamma { match (shape, rate, is_nan) { (_, _, true) => Err(StatsError::BadParams), (_, _, false) if shape <= 0.0 || rate <= 0.0 => Err(StatsError::BadParams), - (_, _, false) => Ok(Gamma { - shape: shape, - rate: rate, - }), + (_, _, false) => Ok(Gamma { shape, rate }), } } diff --git a/src/distribution/geometric.rs b/src/distribution/geometric.rs index f5f9f67d..11dd1854 100644 --- a/src/distribution/geometric.rs +++ b/src/distribution/geometric.rs @@ -49,7 +49,7 @@ impl Geometric { if p <= 0.0 || p > 1.0 || p.is_nan() { Err(StatsError::BadParams) } else { - Ok(Geometric { p: p }) + Ok(Geometric { p }) } } diff --git a/src/distribution/hypergeometric.rs b/src/distribution/hypergeometric.rs index 962470d2..ad7b4c07 100644 --- a/src/distribution/hypergeometric.rs +++ b/src/distribution/hypergeometric.rs @@ -48,9 +48,9 @@ impl Hypergeometric { Err(StatsError::BadParams) } else { Ok(Hypergeometric { - population: population, - successes: successes, - draws: draws, + population, + successes, + draws, }) } } diff --git a/src/distribution/inverse_gamma.rs b/src/distribution/inverse_gamma.rs index 2f22cc81..97a218d4 100644 --- a/src/distribution/inverse_gamma.rs +++ b/src/distribution/inverse_gamma.rs @@ -55,10 +55,7 @@ impl InverseGamma { (_, _, false) if shape == f64::INFINITY || rate == f64::INFINITY => { Err(StatsError::BadParams) } - (_, _, false) => Ok(InverseGamma { - shape: shape, - rate: rate, - }), + (_, _, false) => Ok(InverseGamma { shape, rate }), } } diff --git a/src/distribution/log_normal.rs b/src/distribution/log_normal.rs index 0b041436..5104541c 100644 --- a/src/distribution/log_normal.rs +++ b/src/distribution/log_normal.rs @@ -51,10 +51,7 @@ impl LogNormal { if location.is_nan() || scale.is_nan() || scale <= 0.0 { Err(StatsError::BadParams) } else { - Ok(LogNormal { - location: location, - scale: scale, - }) + Ok(LogNormal { location, scale }) } } } diff --git a/src/distribution/multinomial.rs b/src/distribution/multinomial.rs index a06e9d06..170168d3 100644 --- a/src/distribution/multinomial.rs +++ b/src/distribution/multinomial.rs @@ -54,10 +54,7 @@ impl Multinomial { if !super::internal::is_valid_multinomial(p, true) { Err(StatsError::BadParams) } else { - Ok(Multinomial { - p: p.to_vec(), - n: n, - }) + Ok(Multinomial { p: p.to_vec(), n }) } } diff --git a/src/distribution/multivariate_normal.rs b/src/distribution/multivariate_normal.rs index 80b4e812..fc8e09bc 100644 --- a/src/distribution/multivariate_normal.rs +++ b/src/distribution/multivariate_normal.rs @@ -81,7 +81,7 @@ where mu: mean.clone(), cov: cov.clone(), precision: cholesky_decomp.inverse(), - pdf_const: pdf_const, + pdf_const, }), } } diff --git a/src/distribution/normal.rs b/src/distribution/normal.rs index 3f5dff84..292493b0 100644 --- a/src/distribution/normal.rs +++ b/src/distribution/normal.rs @@ -49,10 +49,7 @@ impl Normal { if mean.is_nan() || std_dev.is_nan() || std_dev <= 0.0 { Err(StatsError::BadParams) } else { - Ok(Normal { - mean: mean, - std_dev: std_dev, - }) + Ok(Normal { mean, std_dev }) } } } diff --git a/src/distribution/pareto.rs b/src/distribution/pareto.rs index 003b066b..70fd6317 100644 --- a/src/distribution/pareto.rs +++ b/src/distribution/pareto.rs @@ -51,10 +51,7 @@ impl Pareto { if is_nan || scale <= 0.0 || shape <= 0.0 { Err(StatsError::BadParams) } else { - Ok(Pareto { - scale: scale, - shape: shape, - }) + Ok(Pareto { scale, shape }) } } diff --git a/src/distribution/poisson.rs b/src/distribution/poisson.rs index 34f7b9c9..0a2c05f5 100644 --- a/src/distribution/poisson.rs +++ b/src/distribution/poisson.rs @@ -49,7 +49,7 @@ impl Poisson { if lambda.is_nan() || lambda <= 0.0 { Err(StatsError::BadParams) } else { - Ok(Poisson { lambda: lambda }) + Ok(Poisson { lambda }) } } diff --git a/src/distribution/students_t.rs b/src/distribution/students_t.rs index a68eabcd..b024dd11 100644 --- a/src/distribution/students_t.rs +++ b/src/distribution/students_t.rs @@ -54,9 +54,9 @@ impl StudentsT { Err(StatsError::BadParams) } else { Ok(StudentsT { - location: location, - scale: scale, - freedom: freedom, + location, + scale, + freedom, }) } } diff --git a/src/distribution/triangular.rs b/src/distribution/triangular.rs index 14d1a682..c8a38423 100644 --- a/src/distribution/triangular.rs +++ b/src/distribution/triangular.rs @@ -59,11 +59,7 @@ impl Triangular { if max == min { return Err(StatsError::BadParams); } - Ok(Triangular { - min: min, - max: max, - mode: mode, - }) + Ok(Triangular { min, max, mode }) } } diff --git a/src/distribution/weibull.rs b/src/distribution/weibull.rs index 1d7f11f1..d2685cee 100644 --- a/src/distribution/weibull.rs +++ b/src/distribution/weibull.rs @@ -54,8 +54,8 @@ impl Weibull { (_, _, true) => Err(StatsError::BadParams), (_, _, false) if shape <= 0.0 || scale <= 0.0 => Err(StatsError::BadParams), (_, _, false) => Ok(Weibull { - shape: shape, - scale: scale, + shape, + scale, scale_pow_shape_inv: scale.powf(-shape), }), } diff --git a/src/statistics/slice_statistics.rs b/src/statistics/slice_statistics.rs index 4d0223e7..0ba4de6a 100644 --- a/src/statistics/slice_statistics.rs +++ b/src/statistics/slice_statistics.rs @@ -61,48 +61,38 @@ impl OrderStatistics for [f64] { fn ranks(&mut self, tie_breaker: RankTieBreaker) -> Vec { let n = self.len(); let mut ranks: Vec = vec![0.0; n]; - let mut index: Vec = (0..n).collect(); - + let mut enumerated: Vec<_> = self.iter().enumerate().collect(); + enumerated.sort_by(|fst, snd| fst.1.partial_cmp(&snd.1).unwrap()); match tie_breaker { RankTieBreaker::First => { - quick_sort_all(self, &mut *index, 0, n - 1); - unsafe { - for i in 0..ranks.len() { - ranks[*index.get_unchecked(i)] = (i + 1) as f64; - } + for (i, idx) in enumerated.into_iter().map(|(idx, _)| idx).enumerate() { + ranks[idx] = (i + 1) as f64 } ranks } _ => { - sort(self, &mut *index); + let mut prev = 0; let mut prev_idx = 0; - unsafe { - for i in 1..n { - if (*self.get_unchecked(i) - *self.get_unchecked(prev_idx)).abs() <= 0.0 { - continue; - } - if i == prev_idx + 1 { - ranks[*index.get_unchecked(prev_idx)] = i as f64; - } else { - handle_rank_ties( - &mut *ranks, - &*index, - prev_idx as isize, - i as isize, - tie_breaker, - ); - } - prev_idx = i; + let mut prev_elt = 0.0; + for (i, (idx, elt)) in enumerated.iter().cloned().enumerate() { + if i == 0 { + prev_idx = idx; + prev_elt = *elt; + } + if (*elt - prev_elt).abs() <= 0.0 { + continue; } + if i == prev + 1 { + ranks[prev_idx] = i as f64; + } else { + handle_rank_ties(&mut ranks, &enumerated, prev, i, tie_breaker); + } + prev = i; + prev_idx = idx; + prev_elt = *elt; } - handle_rank_ties( - &mut *ranks, - &*index, - prev_idx as isize, - n as isize, - tie_breaker, - ); + handle_rank_ties(&mut ranks, &enumerated, prev, n, tie_breaker); ranks } } @@ -282,21 +272,20 @@ impl Median for [f64] { fn handle_rank_ties( ranks: &mut [f64], - index: &[usize], - a: isize, - b: isize, + index: &[(usize, &f64)], + a: usize, + b: usize, tie_breaker: RankTieBreaker, ) { let rank = match tie_breaker { - RankTieBreaker::Average => (b + a - 1) as f64 / 2.0 + 1.0, + // equivalent to (b + a - 1) as f64 / 2.0 + 1.0 but less overflow issues + RankTieBreaker::Average => b as f64 / 2.0 + a as f64 / 2.0 + 0.5, RankTieBreaker::Min => (a + 1) as f64, RankTieBreaker::Max => b as f64, RankTieBreaker::First => unreachable!(), }; - unsafe { - for i in a..b { - ranks[*index.get_unchecked(i as usize)] = rank - } + for i in &index[a..b] { + ranks[i.0] = rank } } @@ -369,234 +358,6 @@ fn select_inplace(arr: &mut [f64], rank: usize) -> f64 { } } -// sorts a primary slice and re-orders the secondary slice automatically. Uses -// insertion sort on small -// containers and quick sorts for larger ones -fn sort(primary: &mut [f64], secondary: &mut [usize]) { - assert_eq!( - primary.len(), - secondary.len(), - "{}", - StatsError::ContainersMustBeSameLength - ); - - let n = primary.len(); - if n <= 1 { - return; - } - if n == 2 { - unsafe { - if *primary.get_unchecked(0) > *primary.get_unchecked(1) { - primary.swap(0, 1); - secondary.swap(0, 1); - } - return; - } - } - - // insertion sort for really short containers - if n <= 10 { - unsafe { - for i in 1..n { - let key = *primary.get_unchecked(i); - let item = *secondary.get_unchecked(i); - let mut j = i as isize - 1; - while j >= 0 && *primary.get_unchecked(j as usize) > key { - primary[j as usize + 1] = *primary.get_unchecked(j as usize); - secondary[j as usize + 1] = *secondary.get_unchecked(j as usize); - j -= 1; - } - primary[j as usize + 1] = key; - secondary[j as usize + 1] = item; - } - return; - } - } - - quick_sort(primary, secondary, 0, n - 1); -} - -// quick sorts a primary slice and re-orders the secondary slice automatically -fn quick_sort(primary: &mut [f64], secondary: &mut [usize], left: usize, right: usize) { - assert_eq!( - primary.len(), - secondary.len(), - "{}", - StatsError::ContainersMustBeSameLength - ); - - // shadow left and right for mutability in loop - let mut left = left; - let mut right = right; - - unsafe { - loop { - // Pivoting - let mut a = left; - let mut b = right; - let p = a + ((b - a) >> 1); - - if *primary.get_unchecked(a) > *primary.get_unchecked(p) { - primary.swap(a, p); - secondary.swap(a, p); - } - if *primary.get_unchecked(a) > *primary.get_unchecked(b) { - primary.swap(a, b); - secondary.swap(a, b); - } - if *primary.get_unchecked(p) > *primary.get_unchecked(b) { - primary.swap(p, b); - secondary.swap(p, b); - } - - let pivot = *primary.get_unchecked(p); - - // Hoare partitioning - loop { - while *primary.get_unchecked(a) < pivot { - a += 1; - } - while pivot < *primary.get_unchecked(b) { - b -= 1; - } - if a > b { - break; - } - if a < b { - primary.swap(a, b); - secondary.swap(a, b); - } - - a += 1; - b -= 1; - - if a > b { - break; - } - } - - // In order to limit recursion depth to log(n), sort the shorter - // partition recursively and the longer partition iteratively. - // - // Must cast to isize as it's possible for left > b or a > right/ - // TODO: make this more robust - if (b as isize - left as isize) <= (right as isize - a as isize) { - if left < b { - quick_sort(primary, secondary, left, b); - } - left = a; - } else { - if a < right { - quick_sort(primary, secondary, a, right); - } - right = b; - } - - if left >= right { - break; - } - } - } -} - -// quick sorts a primary slice and re-orders the secondary slice automatically. -// Sorts secondarily by the secondary slice on primary key duplicates -fn quick_sort_all(primary: &mut [f64], secondary: &mut [usize], left: usize, right: usize) { - assert_eq!( - primary.len(), - secondary.len(), - "{}", - StatsError::ContainersMustBeSameLength - ); - - // shadow left and right for mutability in loop - let mut left = left; - let mut right = right; - - unsafe { - loop { - // Pivoting - let mut a = left; - let mut b = right; - let p = a + ((b - a) >> 1); - - if *primary.get_unchecked(a) > *primary.get_unchecked(p) - || *primary.get_unchecked(a) == *primary.get_unchecked(p) - && *secondary.get_unchecked(a) > *secondary.get_unchecked(p) - { - primary.swap(a, p); - secondary.swap(a, p); - } - if *primary.get_unchecked(a) > *primary.get_unchecked(b) - || *primary.get_unchecked(a) == *primary.get_unchecked(b) - && *secondary.get_unchecked(a) > *secondary.get_unchecked(b) - { - primary.swap(a, b); - secondary.swap(a, b); - } - if *primary.get_unchecked(p) > *primary.get_unchecked(b) - || *primary.get_unchecked(p) == *primary.get_unchecked(b) - && *secondary.get_unchecked(p) > *secondary.get_unchecked(b) - { - primary.swap(p, b); - secondary.swap(p, b); - } - - let pivot1 = *primary.get_unchecked(p); - let pivot2 = *secondary.get_unchecked(p); - - // Hoare partitioning - loop { - while *primary.get_unchecked(a) < pivot1 - || *primary.get_unchecked(a) == pivot1 && *secondary.get_unchecked(a) < pivot2 - { - a += 1; - } - while pivot1 < *primary.get_unchecked(b) - || pivot1 == *primary.get_unchecked(b) && pivot2 < *secondary.get_unchecked(b) - { - b -= 1; - } - if a > b { - break; - } - if a < b { - primary.swap(a, b); - secondary.swap(a, b); - } - - a += 1; - b -= 1; - - if a > b { - break; - } - } - - // In order to limit recursion depth to log(n), sort the shorter - // partition recursively and the longer partition iteratively. - // - // Must cast to isize as it's possible for left > b or a > right/ - // TODO: make this more robust - if (b as isize - left as isize) <= (right as isize - a as isize) { - if left < b { - quick_sort_all(primary, secondary, left, b); - } - left = a; - } else { - if a < right { - quick_sort_all(primary, secondary, a, right); - } - right = b; - } - - if left >= right { - break; - } - } - } -} - #[cfg_attr(rustfmt, rustfmt_skip)] #[cfg(test)] mod test { From 9620b53c4bc32b23ab3b688142360549239e59a5 Mon Sep 17 00:00:00 2001 From: dcraven Date: Mon, 11 May 2020 16:50:58 +0200 Subject: [PATCH 3/6] fix: compiler warnings --- src/function/factorial.rs | 4 ++-- src/statistics/slice_statistics.rs | 1 - 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/function/factorial.rs b/src/function/factorial.rs index 3c3473d2..5f3267cc 100644 --- a/src/function/factorial.rs +++ b/src/function/factorial.rs @@ -5,7 +5,7 @@ use crate::error::StatsError; use crate::function::gamma; use crate::Result; use std::f64; -use std::sync::{Once, ONCE_INIT}; +use std::sync::Once; /// The maximum factorial representable /// by a 64-bit floating point without @@ -101,7 +101,7 @@ pub fn checked_multinomial(n: u64, ni: &[u64]) -> Result { const CACHE_SIZE: usize = 171; static mut FCACHE: &'static mut [f64; CACHE_SIZE] = &mut [1.0; CACHE_SIZE]; -static START: Once = ONCE_INIT; +static START: Once = Once::new(); fn get_fcache() -> &'static [f64; CACHE_SIZE] { unsafe { diff --git a/src/statistics/slice_statistics.rs b/src/statistics/slice_statistics.rs index 0ba4de6a..60462734 100644 --- a/src/statistics/slice_statistics.rs +++ b/src/statistics/slice_statistics.rs @@ -1,4 +1,3 @@ -use crate::error::StatsError; use crate::statistics::*; use std::f64; From a0bd3b8f426c6537ed9614b10e3d4f2846154fe1 Mon Sep 17 00:00:00 2001 From: dcraven Date: Mon, 18 May 2020 23:17:35 +0200 Subject: [PATCH 4/6] fix: removed more unsafes --- src/distribution/categorical.rs | 8 +-- src/distribution/internal.rs | 11 ++-- src/distribution/multinomial.rs | 4 +- src/function/evaluate.rs | 11 ++-- src/statistics/slice_statistics.rs | 86 +++++++++++++++--------------- 5 files changed, 57 insertions(+), 63 deletions(-) diff --git a/src/distribution/categorical.rs b/src/distribution/categorical.rs index ca58104f..6fbd65f1 100644 --- a/src/distribution/categorical.rs +++ b/src/distribution/categorical.rs @@ -71,7 +71,7 @@ impl Categorical { } fn cdf_max(&self) -> f64 { - *unsafe { self.cdf.get_unchecked(self.cdf.len() - 1) } + *self.cdf.last().unwrap() } } @@ -98,7 +98,7 @@ impl Univariate for Categorical { } else if x >= self.cdf.len() as f64 { 1.0 } else { - unsafe { self.cdf.get_unchecked(x as usize) / self.cdf_max() } + self.cdf.get(x as usize).unwrap() / self.cdf_max() } } } @@ -301,7 +301,7 @@ impl Discrete for Categorical { /// Draws a sample from the categorical distribution described by `cdf` /// without doing any bounds checking pub fn sample_unchecked(r: &mut R, cdf: &[f64]) -> f64 { - let draw = r.gen::() * unsafe { cdf.get_unchecked(cdf.len() - 1) }; + let draw = r.gen::() * cdf.last().unwrap(); cdf.iter() .enumerate() .find(|(_, val)| **val >= draw) @@ -333,7 +333,7 @@ fn binary_index(search: &[f64], val: f64) -> usize { let mut high = search.len() as isize - 1; while low <= high { let mid = low + ((high - low) / 2); - let el = *unsafe { search.get_unchecked(mid as usize) }; + let el = *search.get(mid as usize).unwrap(); if el > val { high = mid - 1; } else if el < val { diff --git a/src/distribution/internal.rs b/src/distribution/internal.rs index 35ed8569..76f2bf9b 100644 --- a/src/distribution/internal.rs +++ b/src/distribution/internal.rs @@ -3,16 +3,15 @@ /// IF `incl_zero` is true, it tests for `x < 0.0` instead of `x <= 0.0` pub fn is_valid_multinomial(arr: &[f64], incl_zero: bool) -> bool { let mut sum = 0.0; - for i in 0..arr.len() { - let el = *unsafe { arr.get_unchecked(i) }; - if incl_zero && el < 0.0 { + for &elt in arr { + if incl_zero && elt < 0.0 { return false; - } else if !incl_zero && el <= 0.0 { + } else if !incl_zero && elt <= 0.0 { return false; - } else if el.is_nan() { + } else if elt.is_nan() { return false; } - sum += el; + sum += elt; } sum != 0.0 } diff --git a/src/distribution/multinomial.rs b/src/distribution/multinomial.rs index 170168d3..63037eac 100644 --- a/src/distribution/multinomial.rs +++ b/src/distribution/multinomial.rs @@ -95,8 +95,8 @@ impl Distribution> for Multinomial { let mut res = vec![0.0; self.p.len()]; for _ in 0..self.n { let i = super::categorical::sample_unchecked(r, &p_cdf); - let el = unsafe { res.get_unchecked_mut(i as usize) }; - *el = *el + 1.0; + let el = res.get_mut(i as usize).unwrap(); + *el += 1.0; } res } diff --git a/src/function/evaluate.rs b/src/function/evaluate.rs index 24493ffc..59886a4c 100644 --- a/src/function/evaluate.rs +++ b/src/function/evaluate.rs @@ -16,14 +16,11 @@ pub fn polynomial(z: f64, coeff: &[f64]) -> f64 { return 0.0; } - unsafe { - let mut sum = *coeff.get_unchecked(n - 1); - for i in (0..n - 1).rev() { - sum *= z; - sum += *coeff.get_unchecked(i); - } - sum + let mut sum = *coeff.last().unwrap(); + for c in coeff[0..n - 1].iter().rev() { + sum = *c + z * sum; } + sum } #[cfg_attr(rustfmt, rustfmt_skip)] diff --git a/src/statistics/slice_statistics.rs b/src/statistics/slice_statistics.rs index 60462734..a0ff088b 100644 --- a/src/statistics/slice_statistics.rs +++ b/src/statistics/slice_statistics.rs @@ -298,61 +298,59 @@ fn select_inplace(arr: &mut [f64], rank: usize) -> f64 { return arr.max(); } - unsafe { - let mut low = 0; - let mut high = arr.len() - 1; - loop { - if high <= low + 1 { - if high == low + 1 && *arr.get_unchecked(high) < *arr.get_unchecked(low) { - arr.swap(low, high) - } - return *arr.get_unchecked(rank); + let mut low = 0; + let mut high = arr.len() - 1; + loop { + if high <= low + 1 { + if high == low + 1 && arr[high] < arr[low] { + arr.swap(low, high) } + return arr[rank]; + } - let middle = (low + high) >> 1; - arr.swap(middle, low + 1); + let middle = (low + high) / 2; + arr.swap(middle, low + 1); - if *arr.get_unchecked(low) > *arr.get_unchecked(high) { - arr.swap(low, high); - } - if *arr.get_unchecked(low + 1) > *arr.get_unchecked(high) { - arr.swap(low + 1, high); - } - if *arr.get_unchecked(low) > *arr.get_unchecked(low + 1) { - arr.swap(low, low + 1); - } + if arr[low] > arr[high] { + arr.swap(low, high); + } + if arr[low + 1] > arr[high] { + arr.swap(low + 1, high); + } + if arr[low] > arr[low + 1] { + arr.swap(low, low + 1); + } - let mut begin = low + 1; - let mut end = high; - let pivot = *arr.get_unchecked(begin); + let mut begin = low + 1; + let mut end = high; + let pivot = arr[begin]; + loop { loop { - loop { - begin += 1; - if *arr.get_unchecked(begin) >= pivot { - break; - } - } - loop { - end -= 1; - if *arr.get_unchecked(end) <= pivot { - break; - } + begin += 1; + if arr[begin] >= pivot { + break; } - if end < begin { + } + loop { + end -= 1; + if arr[end] <= pivot { break; } - arr.swap(begin, end); } + if end < begin { + break; + } + arr.swap(begin, end); + } - arr[low + 1] = *arr.get_unchecked(end); - arr[end] = pivot; + arr[low + 1] = arr[end]; + arr[end] = pivot; - if end >= rank { - high = end - 1; - } - if end <= rank { - low = begin; - } + if end >= rank { + high = end - 1; + } + if end <= rank { + low = begin; } } } From fe67e53b960f622831979b9efe09efed9c2bd30f Mon Sep 17 00:00:00 2001 From: dcraven Date: Sat, 30 May 2020 14:18:32 +0200 Subject: [PATCH 5/6] fix: cleanup bench --- benches/order_statistics.rs | 23 ----------------------- 1 file changed, 23 deletions(-) diff --git a/benches/order_statistics.rs b/benches/order_statistics.rs index d9b71710..71f36f5b 100644 --- a/benches/order_statistics.rs +++ b/benches/order_statistics.rs @@ -2,8 +2,6 @@ extern crate rand; extern crate statrs; use criterion::{black_box, criterion_group, criterion_main, BatchSize, Criterion}; use rand::prelude::*; -// use statrs::distribution::Categorical; -// use statrs::distribution::InverseCDF; use statrs::statistics::*; fn bench_order_statistic(c: &mut Criterion) { @@ -91,26 +89,5 @@ fn bench_order_statistic(c: &mut Criterion) { group.finish(); } -// fn bench_categorical_distribution(c: &mut Criterion) { -// c.bench_function("new categorical distribution", |b| { -// b.iter_batched( -// || rand_vec(100), -// |slice| Categorical::new(&slice), -// BatchSize::SmallInput, -// ) -// }); -// } - -// fn categorical_distribution_inverse_cdf(c: &mut Criterion) { -// let x = black_box(5.0); -// c.bench_function("inverse_cdf", |b| { -// b.iter_batched( -// || Categorical::new(&rand_vec(100)).unwrap(), -// |categorical| categorical.inverse_cdf(x), -// BatchSize::SmallInput, -// ) -// }); -// } - criterion_group!(benches, bench_order_statistic); criterion_main!(benches); From 37dc023d85cf8c94f4ed8e5a5e27b9e608ecf88f Mon Sep 17 00:00:00 2001 From: dcraven Date: Tue, 2 Jun 2020 23:59:10 +0200 Subject: [PATCH 6/6] fix: make sort function clearer --- src/statistics/slice_statistics.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/statistics/slice_statistics.rs b/src/statistics/slice_statistics.rs index a0ff088b..fb9042ce 100644 --- a/src/statistics/slice_statistics.rs +++ b/src/statistics/slice_statistics.rs @@ -61,7 +61,7 @@ impl OrderStatistics for [f64] { let n = self.len(); let mut ranks: Vec = vec![0.0; n]; let mut enumerated: Vec<_> = self.iter().enumerate().collect(); - enumerated.sort_by(|fst, snd| fst.1.partial_cmp(&snd.1).unwrap()); + enumerated.sort_by(|(_, el_a), (_, el_b)| el_a.partial_cmp(el_b).unwrap()); match tie_breaker { RankTieBreaker::First => { for (i, idx) in enumerated.into_iter().map(|(idx, _)| idx).enumerate() {