Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

A stab at removing unsafes #109

Merged
merged 6 commits into from Jun 5, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
9 changes: 8 additions & 1 deletion Cargo.toml
Expand Up @@ -19,4 +19,11 @@ path = "src/lib.rs"

[dependencies]
rand = "0.7"
nalgebra = "0.19"
nalgebra = "0.19"

[dev-dependencies]
criterion = "0.3" # pinned: wildcard ("*") versions are rejected by crates.io and make builds non-reproducible

[[bench]]
name = "order_statistics"
harness = false
93 changes: 93 additions & 0 deletions benches/order_statistics.rs
@@ -0,0 +1,93 @@
extern crate rand;
extern crate statrs;
use criterion::{black_box, criterion_group, criterion_main, BatchSize, Criterion};
use rand::prelude::*;
use statrs::statistics::*;

fn bench_order_statistic(c: &mut Criterion) {
let mut rng = thread_rng();
let to_random_owned = |data: &[f64]| -> Vec<f64> {
let mut rng = thread_rng();
let mut owned = data.to_vec();
owned.shuffle(&mut rng);
owned
};
let k = black_box(rng.gen());
let tau = black_box(rng.gen_range(0.0, 1.0));
let mut group = c.benchmark_group("order statistic");
let data: Vec<_> = (0..100).map(|x| x as f64).collect();
group.bench_function("order_statistic", |b| {
b.iter_batched(
|| to_random_owned(&data),
|mut data| data.order_statistic(k),
BatchSize::SmallInput,
)
});
group.bench_function("median", |b| {
b.iter_batched(
|| to_random_owned(&data),
|data| data.median(),
BatchSize::SmallInput,
)
});
group.bench_function("quantile", |b| {
b.iter_batched(
|| to_random_owned(&data),
|mut data| data.quantile(tau),
BatchSize::SmallInput,
)
});
group.bench_function("percentile", |b| {
b.iter_batched(
|| to_random_owned(&data),
|mut data| data.percentile(k),
BatchSize::SmallInput,
)
});
group.bench_function("lower_quartile", |b| {
b.iter_batched(
|| to_random_owned(&data),
|mut data| data.lower_quartile(),
BatchSize::SmallInput,
)
});
group.bench_function("upper_quartile", |b| {
b.iter_batched(
|| to_random_owned(&data),
|mut data| data.upper_quartile(),
BatchSize::SmallInput,
)
});
group.bench_function("interquartile_range", |b| {
b.iter_batched(
|| to_random_owned(&data),
|mut data| data.interquartile_range(),
BatchSize::SmallInput,
)
});
group.bench_function("ranks: RankTieBreaker::First", |b| {
b.iter_batched(
|| to_random_owned(&data),
|mut data| data.ranks(RankTieBreaker::First),
BatchSize::SmallInput,
)
});
group.bench_function("ranks: RankTieBreaker::Average", |b| {
b.iter_batched(
|| to_random_owned(&data),
|mut data| data.ranks(RankTieBreaker::Average),
BatchSize::SmallInput,
)
});
group.bench_function("ranks: RankTieBreaker::Min", |b| {
b.iter_batched(
|| to_random_owned(&data),
|mut data| data.ranks(RankTieBreaker::Min),
BatchSize::SmallInput,
)
});
group.finish();
}

// Wire the benchmark group into Criterion's generated `main` entry point
// (the default libtest harness is disabled for this target via
// `harness = false` in Cargo.toml).
criterion_group!(benches, bench_order_statistic);
criterion_main!(benches);
5 changes: 1 addition & 4 deletions src/distribution/beta.rs
Expand Up @@ -51,10 +51,7 @@ impl Beta {
match (shape_a, shape_b, is_nan) {
(_, _, true) => Err(StatsError::BadParams),
(_, _, false) if shape_a <= 0.0 || shape_b <= 0.0 => Err(StatsError::BadParams),
(_, _, false) => Ok(Beta {
shape_a: shape_a,
shape_b: shape_b,
}),
(_, _, false) => Ok(Beta { shape_a, shape_b }),
}
}

Expand Down
67 changes: 21 additions & 46 deletions src/distribution/categorical.rs
Expand Up @@ -62,21 +62,16 @@ impl Categorical {
// extract normalized probability mass
let sum = cdf[cdf.len() - 1];
let mut norm_pmf = vec![0.0; prob_mass.len()];
for i in 0..prob_mass.len() {
unsafe {
let elem = norm_pmf.get_unchecked_mut(i);
*elem = prob_mass.get_unchecked(i) / sum;
}
}
Ok(Categorical {
norm_pmf: norm_pmf,
cdf: cdf,
})
norm_pmf
.iter_mut()
.zip(prob_mass.iter())
.for_each(|(np, pm)| *np = *pm / sum);
Ok(Categorical { norm_pmf, cdf })
}
}

fn cdf_max(&self) -> f64 {
    // Final cumulative value, i.e. the total (unnormalized) probability
    // mass. `cdf` is non-empty for any successfully constructed
    // Categorical, so `last()` cannot be `None`.
    *self.cdf.last().unwrap()
}
}

Expand All @@ -103,7 +98,7 @@ impl Univariate<u64, f64> for Categorical {
} else if x >= self.cdf.len() as f64 {
1.0
} else {
unsafe { self.cdf.get_unchecked(x as usize) / self.cdf_max() }
self.cdf.get(x as usize).unwrap() / self.cdf_max()
}
}
}
Expand Down Expand Up @@ -293,11 +288,7 @@ impl Discrete<u64, f64> for Categorical {
/// p_x
/// ```
fn pmf(&self, x: u64) -> f64 {
if x >= self.norm_pmf.len() as u64 {
0.0
} else {
unsafe { *self.norm_pmf.get_unchecked(x as usize) }
}
*self.norm_pmf.get(x as usize).unwrap_or(&0.0)
}

/// Calculates the log probability mass function for the categorical
Expand All @@ -310,39 +301,23 @@ impl Discrete<u64, f64> for Categorical {
/// Draws a sample from the categorical distribution described by `cdf`
/// without doing any bounds checking
pub fn sample_unchecked<R: Rng + ?Sized>(r: &mut R, cdf: &[f64]) -> f64 {
let draw = r.gen::<f64>() * unsafe { cdf.get_unchecked(cdf.len() - 1) };
let mut idx = 0;

if draw == 0.0 {
// skip zero-probability categories
let mut el = unsafe { cdf.get_unchecked(idx) };
while *el == 0.0 {
// don't need bounds checking because we do not allow
// creating Categorical distributions with all 0.0 probs
idx += 1;
el = unsafe { cdf.get_unchecked(idx) }
}
}
let mut el = unsafe { cdf.get_unchecked(idx) };
while draw > *el {
idx += 1;
el = unsafe { cdf.get_unchecked(idx) };
}
idx as f64
let draw = r.gen::<f64>() * cdf.last().unwrap();
cdf.iter()
.enumerate()
.find(|(_, val)| **val >= draw)
.map(|(i, _)| i)
.unwrap() as f64
}

/// Computes the cdf from the given probability masses. Performs
/// no parameter or bounds checking.
pub fn prob_mass_to_cdf(prob_mass: &[f64]) -> Vec<f64> {
    // Running (prefix) sum of the masses: element i of the result is
    // prob_mass[0] + ... + prob_mass[i]. `scan` threads the accumulator
    // through and yields each partial sum in order.
    prob_mass
        .iter()
        .scan(0.0, |acc, &mass| {
            *acc += mass;
            Some(*acc)
        })
        .collect()
}

Expand All @@ -358,7 +333,7 @@ fn binary_index(search: &[f64], val: f64) -> usize {
let mut high = search.len() as isize - 1;
while low <= high {
let mid = low + ((high - low) / 2);
let el = *unsafe { search.get_unchecked(mid as usize) };
let el = *search.get(mid as usize).unwrap();
if el > val {
high = mid - 1;
} else if el < val {
Expand Down
5 changes: 1 addition & 4 deletions src/distribution/cauchy.rs
Expand Up @@ -47,10 +47,7 @@ impl Cauchy {
if location.is_nan() || scale.is_nan() || scale <= 0.0 {
Err(StatsError::BadParams)
} else {
Ok(Cauchy {
location: location,
scale: scale,
})
Ok(Cauchy { location, scale })
}
}

Expand Down
5 changes: 1 addition & 4 deletions src/distribution/chi_squared.rs
Expand Up @@ -50,10 +50,7 @@ impl ChiSquared {
/// assert!(result.is_err());
/// ```
pub fn new(freedom: f64) -> Result<ChiSquared> {
    // A chi-squared with k degrees of freedom is Gamma(shape = k/2,
    // rate = 1/2); parameter validation is delegated to `Gamma::new`.
    Gamma::new(freedom / 2.0, 0.5).map(|g| ChiSquared { freedom, g })
}

/// Returns the degrees of freedom of the chi-squared
Expand Down
5 changes: 1 addition & 4 deletions src/distribution/gamma.rs
Expand Up @@ -51,10 +51,7 @@ impl Gamma {
match (shape, rate, is_nan) {
(_, _, true) => Err(StatsError::BadParams),
(_, _, false) if shape <= 0.0 || rate <= 0.0 => Err(StatsError::BadParams),
(_, _, false) => Ok(Gamma {
shape: shape,
rate: rate,
}),
(_, _, false) => Ok(Gamma { shape, rate }),
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/distribution/geometric.rs
Expand Up @@ -49,7 +49,7 @@ impl Geometric {
if p <= 0.0 || p > 1.0 || p.is_nan() {
Err(StatsError::BadParams)
} else {
Ok(Geometric { p: p })
Ok(Geometric { p })
}
}

Expand Down
6 changes: 3 additions & 3 deletions src/distribution/hypergeometric.rs
Expand Up @@ -48,9 +48,9 @@ impl Hypergeometric {
Err(StatsError::BadParams)
} else {
Ok(Hypergeometric {
population: population,
successes: successes,
draws: draws,
population,
successes,
draws,
})
}
}
Expand Down
11 changes: 5 additions & 6 deletions src/distribution/internal.rs
Expand Up @@ -3,16 +3,15 @@
/// If `incl_zero` is true, it tests for `x < 0.0` instead of `x <= 0.0`
pub fn is_valid_multinomial(arr: &[f64], incl_zero: bool) -> bool {
    let mut total = 0.0;
    for &p in arr {
        // With `incl_zero` the lower bound is inclusive (p >= 0.0),
        // otherwise strictly positive. NaN fails either comparison and is
        // rejected as well.
        let in_range = if incl_zero { p >= 0.0 } else { p > 0.0 };
        if !in_range || p.is_nan() {
            return false;
        }
        total += p;
    }
    // An all-zero (or empty) mass vector is invalid.
    total != 0.0
}
Expand Down
5 changes: 1 addition & 4 deletions src/distribution/inverse_gamma.rs
Expand Up @@ -55,10 +55,7 @@ impl InverseGamma {
(_, _, false) if shape == f64::INFINITY || rate == f64::INFINITY => {
Err(StatsError::BadParams)
}
(_, _, false) => Ok(InverseGamma {
shape: shape,
rate: rate,
}),
(_, _, false) => Ok(InverseGamma { shape, rate }),
}
}

Expand Down
5 changes: 1 addition & 4 deletions src/distribution/log_normal.rs
Expand Up @@ -51,10 +51,7 @@ impl LogNormal {
if location.is_nan() || scale.is_nan() || scale <= 0.0 {
Err(StatsError::BadParams)
} else {
Ok(LogNormal {
location: location,
scale: scale,
})
Ok(LogNormal { location, scale })
}
}
}
Expand Down
9 changes: 3 additions & 6 deletions src/distribution/multinomial.rs
Expand Up @@ -54,10 +54,7 @@ impl Multinomial {
if !super::internal::is_valid_multinomial(p, true) {
Err(StatsError::BadParams)
} else {
Ok(Multinomial {
p: p.to_vec(),
n: n,
})
Ok(Multinomial { p: p.to_vec(), n })
}
}

Expand Down Expand Up @@ -98,8 +95,8 @@ impl Distribution<Vec<f64>> for Multinomial {
let mut res = vec![0.0; self.p.len()];
for _ in 0..self.n {
let i = super::categorical::sample_unchecked(r, &p_cdf);
let el = unsafe { res.get_unchecked_mut(i as usize) };
*el = *el + 1.0;
let el = res.get_mut(i as usize).unwrap();
*el += 1.0;
}
res
}
Expand Down
2 changes: 1 addition & 1 deletion src/distribution/multivariate_normal.rs
Expand Up @@ -81,7 +81,7 @@ where
mu: mean.clone(),
cov: cov.clone(),
precision: cholesky_decomp.inverse(),
pdf_const: pdf_const,
pdf_const,
}),
}
}
Expand Down
5 changes: 1 addition & 4 deletions src/distribution/normal.rs
Expand Up @@ -49,10 +49,7 @@ impl Normal {
if mean.is_nan() || std_dev.is_nan() || std_dev <= 0.0 {
Err(StatsError::BadParams)
} else {
Ok(Normal {
mean: mean,
std_dev: std_dev,
})
Ok(Normal { mean, std_dev })
}
}
}
Expand Down
5 changes: 1 addition & 4 deletions src/distribution/pareto.rs
Expand Up @@ -51,10 +51,7 @@ impl Pareto {
if is_nan || scale <= 0.0 || shape <= 0.0 {
Err(StatsError::BadParams)
} else {
Ok(Pareto {
scale: scale,
shape: shape,
})
Ok(Pareto { scale, shape })
}
}

Expand Down