From 7bfcff65b9ba7fc30ead43cf8d5fed9c1d6fdb78 Mon Sep 17 00:00:00 2001 From: dcraven Date: Tue, 5 May 2020 01:24:23 +0200 Subject: [PATCH] fix: refactor unsafes safer fix: stylistic issues --- src/distribution/beta.rs | 5 +- src/distribution/categorical.rs | 59 ++--- src/distribution/cauchy.rs | 5 +- src/distribution/chi_squared.rs | 5 +- src/distribution/gamma.rs | 5 +- src/distribution/geometric.rs | 2 +- src/distribution/hypergeometric.rs | 6 +- src/distribution/inverse_gamma.rs | 5 +- src/distribution/log_normal.rs | 5 +- src/distribution/multinomial.rs | 5 +- src/distribution/multivariate_normal.rs | 2 +- src/distribution/normal.rs | 5 +- src/distribution/pareto.rs | 5 +- src/distribution/poisson.rs | 2 +- src/distribution/students_t.rs | 6 +- src/distribution/triangular.rs | 6 +- src/distribution/weibull.rs | 4 +- src/statistics/slice_statistics.rs | 297 +++--------------------- 18 files changed, 67 insertions(+), 362 deletions(-) diff --git a/src/distribution/beta.rs b/src/distribution/beta.rs index da98229e..55f24f5f 100644 --- a/src/distribution/beta.rs +++ b/src/distribution/beta.rs @@ -51,10 +51,7 @@ impl Beta { match (shape_a, shape_b, is_nan) { (_, _, true) => Err(StatsError::BadParams), (_, _, false) if shape_a <= 0.0 || shape_b <= 0.0 => Err(StatsError::BadParams), - (_, _, false) => Ok(Beta { - shape_a: shape_a, - shape_b: shape_b, - }), + (_, _, false) => Ok(Beta { shape_a, shape_b }), } } diff --git a/src/distribution/categorical.rs b/src/distribution/categorical.rs index 6064da0b..ca58104f 100644 --- a/src/distribution/categorical.rs +++ b/src/distribution/categorical.rs @@ -62,16 +62,11 @@ impl Categorical { // extract normalized probability mass let sum = cdf[cdf.len() - 1]; let mut norm_pmf = vec![0.0; prob_mass.len()]; - for i in 0..prob_mass.len() { - unsafe { - let elem = norm_pmf.get_unchecked_mut(i); - *elem = prob_mass.get_unchecked(i) / sum; - } - } - Ok(Categorical { - norm_pmf: norm_pmf, - cdf: cdf, - }) + norm_pmf + .iter_mut() + .zip(prob_mass.iter()) + .for_each(|(np, pm)| *np = *pm / sum); + Ok(Categorical { norm_pmf, cdf }) } } @@ -293,11 +288,7 @@ impl Discrete for Categorical { /// p_x /// ``` fn pmf(&self, x: u64) -> f64 { - if x >= self.norm_pmf.len() as u64 { - 0.0 - } else { - unsafe { *self.norm_pmf.get_unchecked(x as usize) } - } + *self.norm_pmf.get(x as usize).unwrap_or(&0.0) } /// Calculates the log probability mass function for the categorical @@ -311,38 +302,22 @@ impl Discrete for Categorical { /// without doing any bounds checking pub fn sample_unchecked(r: &mut R, cdf: &[f64]) -> f64 { let draw = r.gen::() * unsafe { cdf.get_unchecked(cdf.len() - 1) }; - let mut idx = 0; - - if draw == 0.0 { - // skip zero-probability categories - let mut el = unsafe { cdf.get_unchecked(idx) }; - while *el == 0.0 { - // don't need bounds checking because we do not allow - // creating Categorical distributions with all 0.0 probs - idx += 1; - el = unsafe { cdf.get_unchecked(idx) } - } - } - let mut el = unsafe { cdf.get_unchecked(idx) }; - while draw > *el { - idx += 1; - el = unsafe { cdf.get_unchecked(idx) }; - } - idx as f64 + cdf.iter() + .enumerate() + .find(|(_, val)| **val >= draw) + .map(|(i, _)| i) + .unwrap() as f64 } /// Computes the cdf from the given probability masses. Performs /// no parameter or bounds checking. pub fn prob_mass_to_cdf(prob_mass: &[f64]) -> Vec { - let mut cdf = vec![0.0; prob_mass.len()]; - cdf[0] = prob_mass[0]; - for i in 1..prob_mass.len() { - unsafe { - let val = cdf.get_unchecked(i - 1) + prob_mass.get_unchecked(i); - let elem = cdf.get_unchecked_mut(i); - *elem = val; - } - } + let mut cdf = Vec::with_capacity(prob_mass.len()); + prob_mass.iter().fold(0.0, |s, p| { + let sum = s + p; + cdf.push(sum); + sum + }); cdf } diff --git a/src/distribution/cauchy.rs b/src/distribution/cauchy.rs index 0d3229d6..57bde23d 100644 --- a/src/distribution/cauchy.rs +++ b/src/distribution/cauchy.rs @@ -47,10 +47,7 @@ impl Cauchy { if location.is_nan() || scale.is_nan() || scale <= 0.0 { Err(StatsError::BadParams) } else { - Ok(Cauchy { - location: location, - scale: scale, - }) + Ok(Cauchy { location, scale }) } } diff --git a/src/distribution/chi_squared.rs b/src/distribution/chi_squared.rs index 5e00587b..27f00f4d 100644 --- a/src/distribution/chi_squared.rs +++ b/src/distribution/chi_squared.rs @@ -50,10 +50,7 @@ impl ChiSquared { /// assert!(result.is_err()); /// ``` pub fn new(freedom: f64) -> Result { - Gamma::new(freedom / 2.0, 0.5).map(|g| ChiSquared { - freedom: freedom, - g: g, - }) + Gamma::new(freedom / 2.0, 0.5).map(|g| ChiSquared { freedom, g }) } /// Returns the degrees of freedom of the chi-squared diff --git a/src/distribution/gamma.rs b/src/distribution/gamma.rs index 8c4d8def..160b7931 100644 --- a/src/distribution/gamma.rs +++ b/src/distribution/gamma.rs @@ -51,10 +51,7 @@ impl Gamma { match (shape, rate, is_nan) { (_, _, true) => Err(StatsError::BadParams), (_, _, false) if shape <= 0.0 || rate <= 0.0 => Err(StatsError::BadParams), - (_, _, false) => Ok(Gamma { - shape: shape, - rate: rate, - }), + (_, _, false) => Ok(Gamma { shape, rate }), } } diff --git a/src/distribution/geometric.rs b/src/distribution/geometric.rs index f5f9f67d..11dd1854 100644 --- a/src/distribution/geometric.rs +++ b/src/distribution/geometric.rs @@ -49,7 +49,7 @@ impl Geometric { if p <= 0.0 || p > 1.0 || p.is_nan() { Err(StatsError::BadParams) } else { - Ok(Geometric { p: p }) + Ok(Geometric { p }) } } diff --git a/src/distribution/hypergeometric.rs b/src/distribution/hypergeometric.rs index 962470d2..ad7b4c07 100644 --- a/src/distribution/hypergeometric.rs +++ b/src/distribution/hypergeometric.rs @@ -48,9 +48,9 @@ impl Hypergeometric { Err(StatsError::BadParams) } else { Ok(Hypergeometric { - population: population, - successes: successes, - draws: draws, + population, + successes, + draws, }) } } diff --git a/src/distribution/inverse_gamma.rs b/src/distribution/inverse_gamma.rs index 2f22cc81..97a218d4 100644 --- a/src/distribution/inverse_gamma.rs +++ b/src/distribution/inverse_gamma.rs @@ -55,10 +55,7 @@ impl InverseGamma { (_, _, false) if shape == f64::INFINITY || rate == f64::INFINITY => { Err(StatsError::BadParams) } - (_, _, false) => Ok(InverseGamma { - shape: shape, - rate: rate, - }), + (_, _, false) => Ok(InverseGamma { shape, rate }), } } diff --git a/src/distribution/log_normal.rs b/src/distribution/log_normal.rs index 0b041436..5104541c 100644 --- a/src/distribution/log_normal.rs +++ b/src/distribution/log_normal.rs @@ -51,10 +51,7 @@ impl LogNormal { if location.is_nan() || scale.is_nan() || scale <= 0.0 { Err(StatsError::BadParams) } else { - Ok(LogNormal { - location: location, - scale: scale, - }) + Ok(LogNormal { location, scale }) } } } diff --git a/src/distribution/multinomial.rs b/src/distribution/multinomial.rs index a06e9d06..170168d3 100644 --- a/src/distribution/multinomial.rs +++ b/src/distribution/multinomial.rs @@ -54,10 +54,7 @@ impl Multinomial { if !super::internal::is_valid_multinomial(p, true) { Err(StatsError::BadParams) } else { - Ok(Multinomial { - p: p.to_vec(), - n: n, - }) + Ok(Multinomial { p: p.to_vec(), n }) } } diff --git a/src/distribution/multivariate_normal.rs b/src/distribution/multivariate_normal.rs index 80b4e812..fc8e09bc 100644 --- a/src/distribution/multivariate_normal.rs +++ b/src/distribution/multivariate_normal.rs @@ -81,7 +81,7 @@ where mu: mean.clone(), cov: cov.clone(), precision: cholesky_decomp.inverse(), - pdf_const: pdf_const, + pdf_const, }), } } diff --git a/src/distribution/normal.rs b/src/distribution/normal.rs index 3f5dff84..292493b0 100644 --- a/src/distribution/normal.rs +++ b/src/distribution/normal.rs @@ -49,10 +49,7 @@ impl Normal { if mean.is_nan() || std_dev.is_nan() || std_dev <= 0.0 { Err(StatsError::BadParams) } else { - Ok(Normal { - mean: mean, - std_dev: std_dev, - }) + Ok(Normal { mean, std_dev }) } } } diff --git a/src/distribution/pareto.rs b/src/distribution/pareto.rs index 003b066b..70fd6317 100644 --- a/src/distribution/pareto.rs +++ b/src/distribution/pareto.rs @@ -51,10 +51,7 @@ impl Pareto { if is_nan || scale <= 0.0 || shape <= 0.0 { Err(StatsError::BadParams) } else { - Ok(Pareto { - scale: scale, - shape: shape, - }) + Ok(Pareto { scale, shape }) } } diff --git a/src/distribution/poisson.rs b/src/distribution/poisson.rs index 34f7b9c9..0a2c05f5 100644 --- a/src/distribution/poisson.rs +++ b/src/distribution/poisson.rs @@ -49,7 +49,7 @@ impl Poisson { if lambda.is_nan() || lambda <= 0.0 { Err(StatsError::BadParams) } else { - Ok(Poisson { lambda: lambda }) + Ok(Poisson { lambda }) } } diff --git a/src/distribution/students_t.rs b/src/distribution/students_t.rs index a68eabcd..b024dd11 100644 --- a/src/distribution/students_t.rs +++ b/src/distribution/students_t.rs @@ -54,9 +54,9 @@ impl StudentsT { Err(StatsError::BadParams) } else { Ok(StudentsT { - location: location, - scale: scale, - freedom: freedom, + location, + scale, + freedom, }) } } diff --git a/src/distribution/triangular.rs b/src/distribution/triangular.rs index 14d1a682..c8a38423 100644 --- a/src/distribution/triangular.rs +++ b/src/distribution/triangular.rs @@ -59,11 +59,7 @@ impl Triangular { if max == min { return Err(StatsError::BadParams); } - Ok(Triangular { - min: min, - max: max, - mode: mode, - }) + Ok(Triangular { min, max, mode }) } } diff --git a/src/distribution/weibull.rs b/src/distribution/weibull.rs index 1d7f11f1..d2685cee 100644 --- a/src/distribution/weibull.rs +++ b/src/distribution/weibull.rs @@ -54,8 +54,8 @@ impl Weibull { (_, _, true) => Err(StatsError::BadParams), (_, _, false) if shape <= 0.0 || scale <= 0.0 => Err(StatsError::BadParams), (_, _, false) => Ok(Weibull { - shape: shape, - scale: scale, + shape, + scale, scale_pow_shape_inv: scale.powf(-shape), }), } diff --git a/src/statistics/slice_statistics.rs b/src/statistics/slice_statistics.rs index 4d0223e7..0ba4de6a 100644 --- a/src/statistics/slice_statistics.rs +++ b/src/statistics/slice_statistics.rs @@ -61,48 +61,38 @@ impl OrderStatistics for [f64] { fn ranks(&mut self, tie_breaker: RankTieBreaker) -> Vec { let n = self.len(); let mut ranks: Vec = vec![0.0; n]; - let mut index: Vec = (0..n).collect(); - + let mut enumerated: Vec<_> = self.iter().enumerate().collect(); + enumerated.sort_by(|fst, snd| fst.1.partial_cmp(&snd.1).unwrap()); match tie_breaker { RankTieBreaker::First => { - quick_sort_all(self, &mut *index, 0, n - 1); - unsafe { - for i in 0..ranks.len() { - ranks[*index.get_unchecked(i)] = (i + 1) as f64; - } + for (i, idx) in enumerated.into_iter().map(|(idx, _)| idx).enumerate() { + ranks[idx] = (i + 1) as f64 } ranks } _ => { - sort(self, &mut *index); + let mut prev = 0; let mut prev_idx = 0; - unsafe { - for i in 1..n { - if (*self.get_unchecked(i) - *self.get_unchecked(prev_idx)).abs() <= 0.0 { - continue; - } - if i == prev_idx + 1 { - ranks[*index.get_unchecked(prev_idx)] = i as f64; - } else { - handle_rank_ties( - &mut *ranks, - &*index, - prev_idx as isize, - i as isize, - tie_breaker, - ); - } - prev_idx = i; + let mut prev_elt = 0.0; + for (i, (idx, elt)) in enumerated.iter().cloned().enumerate() { + if i == 0 { + prev_idx = idx; + prev_elt = *elt; + } + if (*elt - prev_elt).abs() <= 0.0 { + continue; } + if i == prev + 1 { + ranks[prev_idx] = i as f64; + } else { + handle_rank_ties(&mut ranks, &enumerated, prev, i, tie_breaker); + } + prev = i; + prev_idx = idx; + prev_elt = *elt; } - handle_rank_ties( - &mut *ranks, - &*index, - prev_idx as isize, - n as isize, - tie_breaker, - ); + handle_rank_ties(&mut ranks, &enumerated, prev, n, tie_breaker); ranks } } @@ -282,21 +272,20 @@ impl Median for [f64] { fn handle_rank_ties( ranks: &mut [f64], - index: &[usize], - a: isize, - b: isize, + index: &[(usize, &f64)], + a: usize, + b: usize, tie_breaker: RankTieBreaker, ) { let rank = match tie_breaker { - RankTieBreaker::Average => (b + a - 1) as f64 / 2.0 + 1.0, + // equivalent to (b + a - 1) as f64 / 2.0 + 1.0 but less overflow issues + RankTieBreaker::Average => b as f64 / 2.0 + a as f64 / 2.0 + 0.5, RankTieBreaker::Min => (a + 1) as f64, RankTieBreaker::Max => b as f64, RankTieBreaker::First => unreachable!(), }; - unsafe { - for i in a..b { - ranks[*index.get_unchecked(i as usize)] = rank - } + for i in &index[a..b] { + ranks[i.0] = rank } } @@ -369,234 +358,6 @@ fn select_inplace(arr: &mut [f64], rank: usize) -> f64 { } } -// sorts a primary slice and re-orders the secondary slice automatically. Uses -// insertion sort on small -// containers and quick sorts for larger ones -fn sort(primary: &mut [f64], secondary: &mut [usize]) { - assert_eq!( - primary.len(), - secondary.len(), - "{}", - StatsError::ContainersMustBeSameLength - ); - - let n = primary.len(); - if n <= 1 { - return; - } - if n == 2 { - unsafe { - if *primary.get_unchecked(0) > *primary.get_unchecked(1) { - primary.swap(0, 1); - secondary.swap(0, 1); - } - return; - } - } - - // insertion sort for really short containers - if n <= 10 { - unsafe { - for i in 1..n { - let key = *primary.get_unchecked(i); - let item = *secondary.get_unchecked(i); - let mut j = i as isize - 1; - while j >= 0 && *primary.get_unchecked(j as usize) > key { - primary[j as usize + 1] = *primary.get_unchecked(j as usize); - secondary[j as usize + 1] = *secondary.get_unchecked(j as usize); - j -= 1; - } - primary[j as usize + 1] = key; - secondary[j as usize + 1] = item; - } - return; - } - } - - quick_sort(primary, secondary, 0, n - 1); -} - -// quick sorts a primary slice and re-orders the secondary slice automatically -fn quick_sort(primary: &mut [f64], secondary: &mut [usize], left: usize, right: usize) { - assert_eq!( - primary.len(), - secondary.len(), - "{}", - StatsError::ContainersMustBeSameLength - ); - - // shadow left and right for mutability in loop - let mut left = left; - let mut right = right; - - unsafe { - loop { - // Pivoting - let mut a = left; - let mut b = right; - let p = a + ((b - a) >> 1); - - if *primary.get_unchecked(a) > *primary.get_unchecked(p) { - primary.swap(a, p); - secondary.swap(a, p); - } - if *primary.get_unchecked(a) > *primary.get_unchecked(b) { - primary.swap(a, b); - secondary.swap(a, b); - } - if *primary.get_unchecked(p) > *primary.get_unchecked(b) { - primary.swap(p, b); - secondary.swap(p, b); - } - - let pivot = *primary.get_unchecked(p); - - // Hoare partitioning - loop { - while *primary.get_unchecked(a) < pivot { - a += 1; - } - while pivot < *primary.get_unchecked(b) { - b -= 1; - } - if a > b { - break; - } - if a < b { - primary.swap(a, b); - secondary.swap(a, b); - } - - a += 1; - b -= 1; - - if a > b { - break; - } - } - - // In order to limit recursion depth to log(n), sort the shorter - // partition recursively and the longer partition iteratively. - // - // Must cast to isize as it's possible for left > b or a > right/ - // TODO: make this more robust - if (b as isize - left as isize) <= (right as isize - a as isize) { - if left < b { - quick_sort(primary, secondary, left, b); - } - left = a; - } else { - if a < right { - quick_sort(primary, secondary, a, right); - } - right = b; - } - - if left >= right { - break; - } - } - } -} - -// quick sorts a primary slice and re-orders the secondary slice automatically. -// Sorts secondarily by the secondary slice on primary key duplicates -fn quick_sort_all(primary: &mut [f64], secondary: &mut [usize], left: usize, right: usize) { - assert_eq!( - primary.len(), - secondary.len(), - "{}", - StatsError::ContainersMustBeSameLength - ); - - // shadow left and right for mutability in loop - let mut left = left; - let mut right = right; - - unsafe { - loop { - // Pivoting - let mut a = left; - let mut b = right; - let p = a + ((b - a) >> 1); - - if *primary.get_unchecked(a) > *primary.get_unchecked(p) - || *primary.get_unchecked(a) == *primary.get_unchecked(p) - && *secondary.get_unchecked(a) > *secondary.get_unchecked(p) - { - primary.swap(a, p); - secondary.swap(a, p); - } - if *primary.get_unchecked(a) > *primary.get_unchecked(b) - || *primary.get_unchecked(a) == *primary.get_unchecked(b) - && *secondary.get_unchecked(a) > *secondary.get_unchecked(b) - { - primary.swap(a, b); - secondary.swap(a, b); - } - if *primary.get_unchecked(p) > *primary.get_unchecked(b) - || *primary.get_unchecked(p) == *primary.get_unchecked(b) - && *secondary.get_unchecked(p) > *secondary.get_unchecked(b) - { - primary.swap(p, b); - secondary.swap(p, b); - } - - let pivot1 = *primary.get_unchecked(p); - let pivot2 = *secondary.get_unchecked(p); - - // Hoare partitioning - loop { - while *primary.get_unchecked(a) < pivot1 - || *primary.get_unchecked(a) == pivot1 && *secondary.get_unchecked(a) < pivot2 - { - a += 1; - } - while pivot1 < *primary.get_unchecked(b) - || pivot1 == *primary.get_unchecked(b) && pivot2 < *secondary.get_unchecked(b) - { - b -= 1; - } - if a > b { - break; - } - if a < b { - primary.swap(a, b); - secondary.swap(a, b); - } - - a += 1; - b -= 1; - - if a > b { - break; - } - } - - // In order to limit recursion depth to log(n), sort the shorter - // partition recursively and the longer partition iteratively. - // - // Must cast to isize as it's possible for left > b or a > right/ - // TODO: make this more robust - if (b as isize - left as isize) <= (right as isize - a as isize) { - if left < b { - quick_sort_all(primary, secondary, left, b); - } - left = a; - } else { - if a < right { - quick_sort_all(primary, secondary, a, right); - } - right = b; - } - - if left >= right { - break; - } - } - } -} - #[cfg_attr(rustfmt, rustfmt_skip)] #[cfg(test)] mod test {