Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Binomial: Faster sampling for n * p < 10 #735

Merged
merged 3 commits into from Feb 27, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions benches/distributions.rs
Expand Up @@ -203,6 +203,7 @@ distr_float!(distr_gamma_large_shape, f64, Gamma::new(10., 1.0));
distr_float!(distr_gamma_small_shape, f64, Gamma::new(0.1, 1.0));
distr_float!(distr_cauchy, f64, Cauchy::new(4.2, 6.9));
distr_int!(distr_binomial, u64, Binomial::new(20, 0.7));
distr_int!(distr_binomial_small, u64, Binomial::new(1000000, 1e-30));
distr_int!(distr_poisson, u64, Poisson::new(4.0));
distr!(distr_bernoulli, bool, Bernoulli::new(0.18));
distr_arr!(distr_circle, [f64; 2], UnitCircle::new());
Expand Down
138 changes: 88 additions & 50 deletions src/distributions/binomial.rs
Expand Up @@ -10,7 +10,7 @@
//! The binomial distribution.

use Rng;
use distributions::{Distribution, Bernoulli, Cauchy};
use distributions::{Distribution, Cauchy};
use distributions::utils::log_gamma;

/// The binomial distribution `Binomial(n, p)`.
Expand Down Expand Up @@ -47,6 +47,31 @@ impl Binomial {
}
}

/// Compute `base` raised to the non-negative integer power `exp` by
/// square-and-multiply.
///
/// Adapted from the `pow` function of the `num_traits` crate, specialised to
/// an `f64` base and a `u64` exponent. `pow(x, 0)` is `1.0` for every `x`.
fn pow(mut base: f64, mut exp: u64) -> f64 {
    if exp == 0 {
        return 1.0;
    }

    // Strip the trailing zero bits of the exponent: each one corresponds to
    // one squaring of the base. `exp` is non-zero here, so the count is < 64.
    let low_zero_bits = exp.trailing_zeros();
    for _ in 0..low_zero_bits {
        base = base * base;
    }
    exp >>= low_zero_bits;

    // `exp` is now odd, so its lowest bit contributes one factor of `base`.
    let mut product = base;
    loop {
        exp >>= 1;
        if exp == 0 {
            break;
        }
        base = base * base;
        // Fold in the current power of the base for every set bit.
        if exp & 1 != 0 {
            product = product * base;
        }
    }
    product
}

impl Distribution<u64> for Binomial {
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> u64 {
// Handle these values directly.
Expand All @@ -55,19 +80,7 @@ impl Distribution<u64> for Binomial {
} else if self.p == 1.0 {
return self.n;
}

// For low n, it is faster to sample directly. For both methods,
// performance is independent of p. On Intel Haswell CPU this method
// appears to be faster for approx n < 300.
if self.n < 300 {
let mut result = 0;
let d = Bernoulli::new(self.p);
for _ in 0 .. self.n {
result += rng.sample(d) as u32;
}
return result as u64;
}


// binomial distribution is symmetrical with respect to p -> 1-p, k -> n-k
// switch p so that it is less than 0.5 - this allows for lower expected values
// we will just invert the result at the end
Expand All @@ -77,53 +90,78 @@ impl Distribution<u64> for Binomial {
1.0 - self.p
};

// prepare some cached values
let float_n = self.n as f64;
let ln_fact_n = log_gamma(float_n + 1.0);
let pc = 1.0 - p;
let log_p = p.ln();
let log_pc = pc.ln();
let expected = self.n as f64 * p;
let sq = (expected * (2.0 * pc)).sqrt();

let mut lresult;

// we use the Cauchy distribution as the comparison distribution
// f(x) ~ 1/(1+x^2)
let cauchy = Cauchy::new(0.0, 1.0);
loop {
let mut comp_dev: f64;
let result;

// For small n * min(p, 1 - p), the BINV algorithm based on the inverse
// transformation of the binomial distribution is more efficient:
//
// Voratas Kachitvichyanukul and Bruce W. Schmeiser. 1988. Binomial
// random variate generation. Commun. ACM 31, 2 (February 1988),
// 216-222. http://dx.doi.org/10.1145/42372.42381
if (self.n as f64) * p < 10. {
let q = 1. - p;
let s = p / q;
let a = ((self.n + 1) as f64) * s;
let mut r = pow(q, self.n);
let mut u: f64 = rng.gen();
let mut x = 0;
while u > r as f64 {
u -= r;
x += 1;
r *= a / (x as f64) - s;
}
result = x;
} else {
// FIXME: Using the BTPE algorithm is probably faster.

// prepare some cached values
let float_n = self.n as f64;
let ln_fact_n = log_gamma(float_n + 1.0);
let pc = 1.0 - p;
let log_p = p.ln();
let log_pc = pc.ln();
let expected = self.n as f64 * p;
let sq = (expected * (2.0 * pc)).sqrt();
let mut lresult;

// we use the Cauchy distribution as the comparison distribution
// f(x) ~ 1/(1+x^2)
let cauchy = Cauchy::new(0.0, 1.0);
loop {
// draw from the Cauchy distribution
comp_dev = rng.sample(cauchy);
// shift the peak of the comparison distribution
lresult = expected + sq * comp_dev;
// repeat the drawing until we are in the range of possible values
if lresult >= 0.0 && lresult < float_n + 1.0 {
break;
let mut comp_dev: f64;
loop {
// draw from the Cauchy distribution
comp_dev = rng.sample(cauchy);
// shift the peak of the comparison distribution
lresult = expected + sq * comp_dev;
// repeat the drawing until we are in the range of possible values
if lresult >= 0.0 && lresult < float_n + 1.0 {
break;
}
}
}

// the result should be discrete
lresult = lresult.floor();
// the result should be discrete
lresult = lresult.floor();

let log_binomial_dist = ln_fact_n - log_gamma(lresult+1.0) -
log_gamma(float_n - lresult + 1.0) + lresult*log_p + (float_n - lresult)*log_pc;
// this is the binomial probability divided by the comparison probability
// we will generate a uniform random value and if it is larger than this,
// we interpret it as a value falling out of the distribution and repeat
let comparison_coeff = (log_binomial_dist.exp() * sq) * (1.2 * (1.0 + comp_dev*comp_dev));
let log_binomial_dist = ln_fact_n - log_gamma(lresult+1.0) -
log_gamma(float_n - lresult + 1.0) + lresult*log_p + (float_n - lresult)*log_pc;
// this is the binomial probability divided by the comparison probability
// we will generate a uniform random value and if it is larger than this,
// we interpret it as a value falling out of the distribution and repeat
let comparison_coeff = (log_binomial_dist.exp() * sq) * (1.2 * (1.0 + comp_dev*comp_dev));

if comparison_coeff >= rng.gen() {
break;
if comparison_coeff >= rng.gen() {
break;
}
}
result = lresult as u64;
}

// invert the result if `p` was switched to `1 - p` above
if p != self.p {
self.n - lresult as u64
self.n - result
} else {
lresult as u64
result
}
}
}
Expand Down