Skip to content

Commit

Permalink
Merge pull request #735 from vks/faster-binomial
Browse files Browse the repository at this point in the history
Binomial: Faster sampling for n * p < 10
  • Loading branch information
dhardy committed Feb 27, 2019
2 parents 270682e + 5f6f89e commit a9f6b8f
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 50 deletions.
1 change: 1 addition & 0 deletions benches/distributions.rs
Expand Up @@ -203,6 +203,7 @@ distr_float!(distr_gamma_large_shape, f64, Gamma::new(10., 1.0));
distr_float!(distr_gamma_small_shape, f64, Gamma::new(0.1, 1.0));
distr_float!(distr_cauchy, f64, Cauchy::new(4.2, 6.9));
distr_int!(distr_binomial, u64, Binomial::new(20, 0.7));
distr_int!(distr_binomial_small, u64, Binomial::new(1000000, 1e-30));
distr_int!(distr_poisson, u64, Poisson::new(4.0));
distr!(distr_bernoulli, bool, Bernoulli::new(0.18));
distr_arr!(distr_circle, [f64; 2], UnitCircle::new());
Expand Down
138 changes: 88 additions & 50 deletions src/distributions/binomial.rs
Expand Up @@ -10,7 +10,7 @@
//! The binomial distribution.

use Rng;
use distributions::{Distribution, Bernoulli, Cauchy};
use distributions::{Distribution, Cauchy};
use distributions::utils::log_gamma;

/// The binomial distribution `Binomial(n, p)`.
Expand Down Expand Up @@ -47,6 +47,31 @@ impl Binomial {
}
}

/// Compute `base` raised to the non-negative integer power `exp`.
///
/// Uses binary exponentiation (repeated squaring), following the approach of
/// the `pow` function in the `num_traits` crate, adapted here to take a `u64`
/// exponent. An exponent of zero yields `1.0` regardless of `base`.
fn pow(mut base: f64, mut exp: u64) -> f64 {
    if exp == 0 {
        return 1.;
    }

    // Every trailing zero bit of the exponent corresponds to one plain
    // squaring of the base with nothing to accumulate yet.
    let shift = exp.trailing_zeros();
    for _ in 0..shift {
        base = base * base;
    }
    exp >>= shift;

    // The lowest bit is now set, so the accumulator starts as the current base.
    let mut acc = base;
    let mut rest = exp >> 1;
    while rest != 0 {
        base = base * base;
        if rest & 1 != 0 {
            acc = acc * base;
        }
        rest >>= 1;
    }
    acc
}

impl Distribution<u64> for Binomial {
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> u64 {
// Handle these values directly.
Expand All @@ -55,19 +80,7 @@ impl Distribution<u64> for Binomial {
} else if self.p == 1.0 {
return self.n;
}

// For low n, it is faster to sample directly. For both methods,
// performance is independent of p. On Intel Haswell CPU this method
// appears to be faster for approx n < 300.
if self.n < 300 {
let mut result = 0;
let d = Bernoulli::new(self.p);
for _ in 0 .. self.n {
result += rng.sample(d) as u32;
}
return result as u64;
}


// binomial distribution is symmetrical with respect to p -> 1-p, k -> n-k
// switch p so that it is less than 0.5 - this allows for lower expected values
// we will just invert the result at the end
Expand All @@ -77,53 +90,78 @@ impl Distribution<u64> for Binomial {
1.0 - self.p
};

// prepare some cached values
let float_n = self.n as f64;
let ln_fact_n = log_gamma(float_n + 1.0);
let pc = 1.0 - p;
let log_p = p.ln();
let log_pc = pc.ln();
let expected = self.n as f64 * p;
let sq = (expected * (2.0 * pc)).sqrt();

let mut lresult;

// we use the Cauchy distribution as the comparison distribution
// f(x) ~ 1/(1+x^2)
let cauchy = Cauchy::new(0.0, 1.0);
loop {
let mut comp_dev: f64;
let result;

// For small n * min(p, 1 - p), the BINV algorithm based on the inverse
// transformation of the binomial distribution is more efficient:
//
// Voratas Kachitvichyanukul and Bruce W. Schmeiser. 1988. Binomial
// random variate generation. Commun. ACM 31, 2 (February 1988),
// 216-222. http://dx.doi.org/10.1145/42372.42381
if (self.n as f64) * p < 10. {
let q = 1. - p;
let s = p / q;
let a = ((self.n + 1) as f64) * s;
let mut r = pow(q, self.n);
let mut u: f64 = rng.gen();
let mut x = 0;
while u > r as f64 {
u -= r;
x += 1;
r *= a / (x as f64) - s;
}
result = x;
} else {
// FIXME: Using the BTPE algorithm is probably faster.

// prepare some cached values
let float_n = self.n as f64;
let ln_fact_n = log_gamma(float_n + 1.0);
let pc = 1.0 - p;
let log_p = p.ln();
let log_pc = pc.ln();
let expected = self.n as f64 * p;
let sq = (expected * (2.0 * pc)).sqrt();
let mut lresult;

// we use the Cauchy distribution as the comparison distribution
// f(x) ~ 1/(1+x^2)
let cauchy = Cauchy::new(0.0, 1.0);
loop {
// draw from the Cauchy distribution
comp_dev = rng.sample(cauchy);
// shift the peak of the comparison distribution
lresult = expected + sq * comp_dev;
// repeat the drawing until we are in the range of possible values
if lresult >= 0.0 && lresult < float_n + 1.0 {
break;
let mut comp_dev: f64;
loop {
// draw from the Cauchy distribution
comp_dev = rng.sample(cauchy);
// shift the peak of the comparison distribution
lresult = expected + sq * comp_dev;
// repeat the drawing until we are in the range of possible values
if lresult >= 0.0 && lresult < float_n + 1.0 {
break;
}
}
}

// the result should be discrete
lresult = lresult.floor();
// the result should be discrete
lresult = lresult.floor();

let log_binomial_dist = ln_fact_n - log_gamma(lresult+1.0) -
log_gamma(float_n - lresult + 1.0) + lresult*log_p + (float_n - lresult)*log_pc;
// this is the binomial probability divided by the comparison probability
// we will generate a uniform random value and if it is larger than this,
// we interpret it as a value falling out of the distribution and repeat
let comparison_coeff = (log_binomial_dist.exp() * sq) * (1.2 * (1.0 + comp_dev*comp_dev));
let log_binomial_dist = ln_fact_n - log_gamma(lresult+1.0) -
log_gamma(float_n - lresult + 1.0) + lresult*log_p + (float_n - lresult)*log_pc;
// this is the binomial probability divided by the comparison probability
// we will generate a uniform random value and if it is larger than this,
// we interpret it as a value falling out of the distribution and repeat
let comparison_coeff = (log_binomial_dist.exp() * sq) * (1.2 * (1.0 + comp_dev*comp_dev));

if comparison_coeff >= rng.gen() {
break;
if comparison_coeff >= rng.gen() {
break;
}
}
result = lresult as u64;
}

// invert the result for p < 0.5
if p != self.p {
self.n - lresult as u64
self.n - result
} else {
lresult as u64
result
}
}
}
Expand Down

0 comments on commit a9f6b8f

Please sign in to comment.