Skip to content

Commit

Permalink
Optimize Bernoulli::new
Browse files Browse the repository at this point in the history
  • Loading branch information
pitdicker committed Jun 9, 2018
1 parent 7a30fef commit 6044cc8
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 20 deletions.
3 changes: 1 addition & 2 deletions benches/misc.rs
Expand Up @@ -63,9 +63,8 @@ fn misc_gen_ratio_var(b: &mut Bencher) {
#[bench]
fn misc_bernoulli_const(b: &mut Bencher) {
let mut rng = StdRng::from_rng(&mut thread_rng()).unwrap();
let d = rand::distributions::Bernoulli::new(0.18);
b.iter(|| {
// Can be evaluated at compile time.
let d = rand::distributions::Bernoulli::new(0.18);
let mut accum = true;
for _ in 0..::RAND_BENCH_N {
accum ^= rng.sample(d);
Expand Down
51 changes: 33 additions & 18 deletions src/distributions/bernoulli.rs
Expand Up @@ -37,6 +37,31 @@ pub struct Bernoulli {
p_int: u64,
}

// To sample from the Bernoulli distribution we use a method that compares a
// random `u64` value `v < (p * 2^64)`.
//
// If `p == 1.0`, the integer `v` to compare against can not represented as a
// `u64`. We manually set it to `u64::MAX` instead (2^64 - 1 instead of 2^64).
// Note that value of `p < 1.0` can never result in `u64::MAX`, because an
// `f64` only has 53 bits of precision, and the next largest value of `p` will
// result in `2^64 - 2048`.
//
// Also there is a 100% theoretical concern: if someone consistenly wants to
// generate `true` using the Bernoulli distribution (i.e. by using a probability
// of `1.0`), just using `u64::MAX` is not enough. On average it would return
// false once every 2^64 iterations. Some people apparently care about this
// case.
//
// That is why we special-case `u64::MAX` to always return `true`, without using
// the RNG, and pay the performance price for all uses that *are* reasonable.
// Luckily, if `new()` and `sample` are close, the compiler can optimize out the
// extra check.
const ALWAYS_TRUE: u64 = ::core::u64::MAX;

// This is just `2.0.powi(64)`, but written this way because it is not available
// in `no_std` mode.
const SCALE: f64 = 2.0 * (1u64 << 63) as f64;

impl Bernoulli {
/// Construct a new `Bernoulli` with the given probability of success `p`.
///
Expand All @@ -54,18 +79,11 @@ impl Bernoulli {
/// 2<sup>-64</sup> in `[0, 1]` can be represented as a `f64`.)
#[inline]
pub fn new(p: f64) -> Bernoulli {
assert!((p >= 0.0) & (p <= 1.0), "Bernoulli::new not called with 0 <= p <= 0");
// Technically, this should be 2^64 or `u64::MAX + 1` because we compare
// using `<` when sampling. However, `u64::MAX` rounds to an `f64`
// larger than `u64::MAX` anyway.
const MAX_P_INT: f64 = ::core::u64::MAX as f64;
let p_int = if p < 1.0 {
(p * MAX_P_INT) as u64
} else {
// Avoid overflow: `MAX_P_INT` cannot be represented as u64.
::core::u64::MAX
};
Bernoulli { p_int }
if p < 0.0 || p >= 1.0 {
if p == 1.0 { return Bernoulli { p_int: ALWAYS_TRUE } }
panic!("Bernoulli::new not called with 0.0 <= p <= 1.0");
}
Bernoulli { p_int: (p * SCALE) as u64 }
}

/// Construct a new `Bernoulli` with the probability of success of
Expand All @@ -85,7 +103,6 @@ impl Bernoulli {
if numerator == denominator {
return Bernoulli { p_int: ::core::u64::MAX }
}
const SCALE: f64 = 2.0 * (1u64 << 63) as f64;
let p_int = ((numerator as f64 / denominator as f64) * SCALE) as u64;
Bernoulli { p_int }
}
Expand All @@ -95,11 +112,9 @@ impl Distribution<bool> for Bernoulli {
#[inline]
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> bool {
// Make sure to always return true for p = 1.0.
if self.p_int == ::core::u64::MAX {
return true;
}
let r: u64 = rng.gen();
r < self.p_int
if self.p_int == ALWAYS_TRUE { return true; }
let v: u64 = rng.gen();
v < self.p_int
}
}

Expand Down

0 comments on commit 6044cc8

Please sign in to comment.