Skip to content

Commit

Permalink
fix: refactor unsafes safer
Browse files Browse the repository at this point in the history
fix: stylistic issues
  • Loading branch information
troublescooter committed May 11, 2020
1 parent 7b20b39 commit 7bfcff6
Show file tree
Hide file tree
Showing 18 changed files with 67 additions and 362 deletions.
5 changes: 1 addition & 4 deletions src/distribution/beta.rs
Expand Up @@ -51,10 +51,7 @@ impl Beta {
match (shape_a, shape_b, is_nan) {
(_, _, true) => Err(StatsError::BadParams),
(_, _, false) if shape_a <= 0.0 || shape_b <= 0.0 => Err(StatsError::BadParams),
(_, _, false) => Ok(Beta {
shape_a: shape_a,
shape_b: shape_b,
}),
(_, _, false) => Ok(Beta { shape_a, shape_b }),
}
}

Expand Down
59 changes: 17 additions & 42 deletions src/distribution/categorical.rs
Expand Up @@ -62,16 +62,11 @@ impl Categorical {
// extract normalized probability mass
let sum = cdf[cdf.len() - 1];
let mut norm_pmf = vec![0.0; prob_mass.len()];
for i in 0..prob_mass.len() {
unsafe {
let elem = norm_pmf.get_unchecked_mut(i);
*elem = prob_mass.get_unchecked(i) / sum;
}
}
Ok(Categorical {
norm_pmf: norm_pmf,
cdf: cdf,
})
norm_pmf
.iter_mut()
.zip(prob_mass.iter())
.for_each(|(np, pm)| *np = *pm / sum);
Ok(Categorical { norm_pmf, cdf })
}
}

Expand Down Expand Up @@ -293,11 +288,7 @@ impl Discrete<u64, f64> for Categorical {
/// p_x
/// ```
fn pmf(&self, x: u64) -> f64 {
if x >= self.norm_pmf.len() as u64 {
0.0
} else {
unsafe { *self.norm_pmf.get_unchecked(x as usize) }
}
*self.norm_pmf.get(x as usize).unwrap_or(&0.0)
}

/// Calculates the log probability mass function for the categorical
Expand All @@ -311,38 +302,22 @@ impl Discrete<u64, f64> for Categorical {
/// without doing any bounds checking
pub fn sample_unchecked<R: Rng + ?Sized>(r: &mut R, cdf: &[f64]) -> f64 {
let draw = r.gen::<f64>() * unsafe { cdf.get_unchecked(cdf.len() - 1) };
let mut idx = 0;

if draw == 0.0 {
// skip zero-probability categories
let mut el = unsafe { cdf.get_unchecked(idx) };
while *el == 0.0 {
// don't need bounds checking because we do not allow
// creating Categorical distributions with all 0.0 probs
idx += 1;
el = unsafe { cdf.get_unchecked(idx) }
}
}
let mut el = unsafe { cdf.get_unchecked(idx) };
while draw > *el {
idx += 1;
el = unsafe { cdf.get_unchecked(idx) };
}
idx as f64
cdf.iter()
.enumerate()
.find(|(_, val)| **val >= draw)
.map(|(i, _)| i)
.unwrap() as f64
}

/// Computes the cdf from the given probability masses. Performs
/// no parameter or bounds checking.
pub fn prob_mass_to_cdf(prob_mass: &[f64]) -> Vec<f64> {
let mut cdf = vec![0.0; prob_mass.len()];
cdf[0] = prob_mass[0];
for i in 1..prob_mass.len() {
unsafe {
let val = cdf.get_unchecked(i - 1) + prob_mass.get_unchecked(i);
let elem = cdf.get_unchecked_mut(i);
*elem = val;
}
}
let mut cdf = Vec::with_capacity(prob_mass.len());
prob_mass.iter().fold(0.0, |s, p| {
let sum = s + p;
cdf.push(sum);
sum
});
cdf
}

Expand Down
5 changes: 1 addition & 4 deletions src/distribution/cauchy.rs
Expand Up @@ -47,10 +47,7 @@ impl Cauchy {
if location.is_nan() || scale.is_nan() || scale <= 0.0 {
Err(StatsError::BadParams)
} else {
Ok(Cauchy {
location: location,
scale: scale,
})
Ok(Cauchy { location, scale })
}
}

Expand Down
5 changes: 1 addition & 4 deletions src/distribution/chi_squared.rs
Expand Up @@ -50,10 +50,7 @@ impl ChiSquared {
/// assert!(result.is_err());
/// ```
pub fn new(freedom: f64) -> Result<ChiSquared> {
Gamma::new(freedom / 2.0, 0.5).map(|g| ChiSquared {
freedom: freedom,
g: g,
})
Gamma::new(freedom / 2.0, 0.5).map(|g| ChiSquared { freedom, g })
}

/// Returns the degrees of freedom of the chi-squared
Expand Down
5 changes: 1 addition & 4 deletions src/distribution/gamma.rs
Expand Up @@ -51,10 +51,7 @@ impl Gamma {
match (shape, rate, is_nan) {
(_, _, true) => Err(StatsError::BadParams),
(_, _, false) if shape <= 0.0 || rate <= 0.0 => Err(StatsError::BadParams),
(_, _, false) => Ok(Gamma {
shape: shape,
rate: rate,
}),
(_, _, false) => Ok(Gamma { shape, rate }),
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/distribution/geometric.rs
Expand Up @@ -49,7 +49,7 @@ impl Geometric {
if p <= 0.0 || p > 1.0 || p.is_nan() {
Err(StatsError::BadParams)
} else {
Ok(Geometric { p: p })
Ok(Geometric { p })
}
}

Expand Down
6 changes: 3 additions & 3 deletions src/distribution/hypergeometric.rs
Expand Up @@ -48,9 +48,9 @@ impl Hypergeometric {
Err(StatsError::BadParams)
} else {
Ok(Hypergeometric {
population: population,
successes: successes,
draws: draws,
population,
successes,
draws,
})
}
}
Expand Down
5 changes: 1 addition & 4 deletions src/distribution/inverse_gamma.rs
Expand Up @@ -55,10 +55,7 @@ impl InverseGamma {
(_, _, false) if shape == f64::INFINITY || rate == f64::INFINITY => {
Err(StatsError::BadParams)
}
(_, _, false) => Ok(InverseGamma {
shape: shape,
rate: rate,
}),
(_, _, false) => Ok(InverseGamma { shape, rate }),
}
}

Expand Down
5 changes: 1 addition & 4 deletions src/distribution/log_normal.rs
Expand Up @@ -51,10 +51,7 @@ impl LogNormal {
if location.is_nan() || scale.is_nan() || scale <= 0.0 {
Err(StatsError::BadParams)
} else {
Ok(LogNormal {
location: location,
scale: scale,
})
Ok(LogNormal { location, scale })
}
}
}
Expand Down
5 changes: 1 addition & 4 deletions src/distribution/multinomial.rs
Expand Up @@ -54,10 +54,7 @@ impl Multinomial {
if !super::internal::is_valid_multinomial(p, true) {
Err(StatsError::BadParams)
} else {
Ok(Multinomial {
p: p.to_vec(),
n: n,
})
Ok(Multinomial { p: p.to_vec(), n })
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/distribution/multivariate_normal.rs
Expand Up @@ -81,7 +81,7 @@ where
mu: mean.clone(),
cov: cov.clone(),
precision: cholesky_decomp.inverse(),
pdf_const: pdf_const,
pdf_const,
}),
}
}
Expand Down
5 changes: 1 addition & 4 deletions src/distribution/normal.rs
Expand Up @@ -49,10 +49,7 @@ impl Normal {
if mean.is_nan() || std_dev.is_nan() || std_dev <= 0.0 {
Err(StatsError::BadParams)
} else {
Ok(Normal {
mean: mean,
std_dev: std_dev,
})
Ok(Normal { mean, std_dev })
}
}
}
Expand Down
5 changes: 1 addition & 4 deletions src/distribution/pareto.rs
Expand Up @@ -51,10 +51,7 @@ impl Pareto {
if is_nan || scale <= 0.0 || shape <= 0.0 {
Err(StatsError::BadParams)
} else {
Ok(Pareto {
scale: scale,
shape: shape,
})
Ok(Pareto { scale, shape })
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/distribution/poisson.rs
Expand Up @@ -49,7 +49,7 @@ impl Poisson {
if lambda.is_nan() || lambda <= 0.0 {
Err(StatsError::BadParams)
} else {
Ok(Poisson { lambda: lambda })
Ok(Poisson { lambda })
}
}

Expand Down
6 changes: 3 additions & 3 deletions src/distribution/students_t.rs
Expand Up @@ -54,9 +54,9 @@ impl StudentsT {
Err(StatsError::BadParams)
} else {
Ok(StudentsT {
location: location,
scale: scale,
freedom: freedom,
location,
scale,
freedom,
})
}
}
Expand Down
6 changes: 1 addition & 5 deletions src/distribution/triangular.rs
Expand Up @@ -59,11 +59,7 @@ impl Triangular {
if max == min {
return Err(StatsError::BadParams);
}
Ok(Triangular {
min: min,
max: max,
mode: mode,
})
Ok(Triangular { min, max, mode })
}
}

Expand Down
4 changes: 2 additions & 2 deletions src/distribution/weibull.rs
Expand Up @@ -54,8 +54,8 @@ impl Weibull {
(_, _, true) => Err(StatsError::BadParams),
(_, _, false) if shape <= 0.0 || scale <= 0.0 => Err(StatsError::BadParams),
(_, _, false) => Ok(Weibull {
shape: shape,
scale: scale,
shape,
scale,
scale_pow_shape_inv: scale.powf(-shape),
}),
}
Expand Down

0 comments on commit 7bfcff6

Please sign in to comment.