From 7facd3f9a8a5a901603225c610d81b35fb4071bb Mon Sep 17 00:00:00 2001 From: Paul Dicker Date: Thu, 7 Jun 2018 08:25:21 +0200 Subject: [PATCH] Optimize Bernoulli::new --- benches/misc.rs | 3 +-- src/distributions/bernoulli.rs | 49 ++++++++++++++++++++++------------ 2 files changed, 33 insertions(+), 19 deletions(-) diff --git a/benches/misc.rs b/benches/misc.rs index a1822a53a83..6ccb6e5105b 100644 --- a/benches/misc.rs +++ b/benches/misc.rs @@ -39,9 +39,8 @@ fn misc_gen_bool_var(b: &mut Bencher) { #[bench] fn misc_bernoulli_const(b: &mut Bencher) { let mut rng = StdRng::from_rng(&mut thread_rng()).unwrap(); - let d = rand::distributions::Bernoulli::new(0.18); b.iter(|| { - // Can be evaluated at compile time. + let d = rand::distributions::Bernoulli::new(0.18); let mut accum = true; for _ in 0..::RAND_BENCH_N { accum ^= rng.sample(d); diff --git a/src/distributions/bernoulli.rs b/src/distributions/bernoulli.rs index 2361fac0c21..04cbb966684 100644 --- a/src/distributions/bernoulli.rs +++ b/src/distributions/bernoulli.rs @@ -37,6 +37,27 @@ pub struct Bernoulli { p_int: u64, } +// To sample from the Bernoulli distribution we use a method that compares a +// random `u64` value `v < (p * 2^64)`. +// +// If `p == 1.0`, the integer `v` to compare against can not represented as a +// `u64`. We manually set it to `u64::MAX` instead (2^64 - 1 instead of 2^64). +// Note that value of `p < 1.0` can never result in `u64::MAX`, because an +// `f64` only has 53 bits of precision, and the next largest value of `p` will +// result in `2^64 - 2048`. +// +// Also there is a 100% theoretical concern: if someone consistenly wants to +// generate `true` using the Bernoulli distribution (i.e. by using a probability +// of `1.0`), just using `u64::MAX` is not enough. On average it would return +// false once every 2^64 iterations. Some people apparently care about this +// case. +// +// That is why we special-case `u64::MAX` to always return `true`, without using +// the RNG, and pay the performance price for all uses that *are* reasonable. +// Luckily, if `new()` and `sample` are close, the compiler can optimize out the +// extra check. +const ALWAYS_TRUE: u64 = ::core::u64::MAX; + impl Bernoulli { /// Construct a new `Bernoulli` with the given probability of success `p`. /// @@ -54,18 +75,14 @@ impl Bernoulli { /// 2-64 in `[0, 1]` can be represented as a `f64`.) #[inline] pub fn new(p: f64) -> Bernoulli { - assert!((p >= 0.0) & (p <= 1.0), "Bernoulli::new not called with 0 <= p <= 0"); - // Technically, this should be 2^64 or `u64::MAX + 1` because we compare - // using `<` when sampling. However, `u64::MAX` rounds to an `f64` - // larger than `u64::MAX` anyway. - const MAX_P_INT: f64 = ::core::u64::MAX as f64; - let p_int = if p < 1.0 { - (p * MAX_P_INT) as u64 - } else { - // Avoid overflow: `MAX_P_INT` cannot be represented as u64. - ::core::u64::MAX - }; - Bernoulli { p_int } + if p < 0.0 || p >= 1.0 { + if p == 1.0 { return Bernoulli { p_int: ALWAYS_TRUE } } + panic!("Bernoulli::new not called with 0.0 <= p <= 1.0"); + } + // This is just `2.0.powi(64)`, but written this way because it is not + // available in `no_std` mode. + const SCALE: f64 = 2.0 * (1u64 << 63) as f64; + Bernoulli { p_int: (p * SCALE) as u64 } } } @@ -73,11 +90,9 @@ impl Distribution for Bernoulli { #[inline] fn sample(&self, rng: &mut R) -> bool { // Make sure to always return true for p = 1.0. - if self.p_int == ::core::u64::MAX { - return true; - } - let r: u64 = rng.gen(); - r < self.p_int + if self.p_int == ALWAYS_TRUE { return true; } + let v: u64 = rng.gen(); + v < self.p_int } }