Skip to content

Commit

Permalink
Merge pull request #740 from vks/faster-binomial2
Browse files Browse the repository at this point in the history
Binomial: Faster sampling for n * p >= 10
  • Loading branch information
dhardy committed Apr 1, 2019
2 parents 1eef88c + f8149ab commit e47c5a9
Showing 1 changed file with 174 additions and 45 deletions.
219 changes: 174 additions & 45 deletions src/distributions/binomial.rs
Expand Up @@ -10,8 +10,7 @@
//! The binomial distribution.

use Rng;
use distributions::{Distribution, Cauchy};
use distributions::utils::log_gamma;
use distributions::{Distribution, Uniform};

/// The binomial distribution `Binomial(n, p)`.
///
Expand Down Expand Up @@ -47,6 +46,13 @@ impl Binomial {
}
}

/// Convert a `f64` to an `i64`, panicing on overflow.
// In the future (Rust 1.34), this might be replaced with `TryFrom`.
fn f64_to_i64(x: f64) -> i64 {
assert!(x < (::std::i64::MAX as f64));
x as i64
}

impl Distribution<u64> for Binomial {
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> u64 {
// Handle these values directly.
Expand All @@ -56,25 +62,33 @@ impl Distribution<u64> for Binomial {
return self.n;
}

// binomial distribution is symmetrical with respect to p -> 1-p, k -> n-k
// switch p so that it is less than 0.5 - this allows for lower expected values
// we will just invert the result at the end
// The binomial distribution is symmetrical with respect to p -> 1-p,
// k -> n-k switch p so that it is less than 0.5 - this allows for lower
// expected values we will just invert the result at the end
let p = if self.p <= 0.5 {
self.p
} else {
1.0 - self.p
};

let result;
let q = 1. - p;

// For small n * min(p, 1 - p), the BINV algorithm based on the inverse
// transformation of the binomial distribution is more efficient:
// transformation of the binomial distribution is efficient. Otherwise,
// the BTPE algorithm is used.
//
// Voratas Kachitvichyanukul and Bruce W. Schmeiser. 1988. Binomial
// random variate generation. Commun. ACM 31, 2 (February 1988),
// 216-222. http://dx.doi.org/10.1145/42372.42381
if (self.n as f64) * p < 10. && self.n <= (::std::i32::MAX as u64) {
let q = 1. - p;

// Threshold for prefering the BINV algorithm. The paper suggests 10,
// Ranlib uses 30, and GSL uses 14.
const BINV_THRESHOLD: f64 = 10.;

if (self.n as f64) * p < BINV_THRESHOLD &&
self.n <= (::std::i32::MAX as u64) {
// Use the BINV algorithm.
let s = p / q;
let a = ((self.n + 1) as f64) * s;
let mut r = q.powi(self.n as i32);
Expand All @@ -87,52 +101,165 @@ impl Distribution<u64> for Binomial {
}
result = x;
} else {
// FIXME: Using the BTPE algorithm is probably faster.

// prepare some cached values
let float_n = self.n as f64;
let ln_fact_n = log_gamma(float_n + 1.0);
let pc = 1.0 - p;
let log_p = p.ln();
let log_pc = pc.ln();
let expected = self.n as f64 * p;
let sq = (expected * (2.0 * pc)).sqrt();
let mut lresult;

// we use the Cauchy distribution as the comparison distribution
// f(x) ~ 1/(1+x^2)
let cauchy = Cauchy::new(0.0, 1.0);
// Use the BTPE algorithm.

// Threshold for using the squeeze algorithm. This can be freely
// chosen based on performance. Ranlib and GSL use 20.
const SQUEEZE_THRESHOLD: i64 = 20;

// Step 0: Calculate constants as functions of `n` and `p`.
let n = self.n as f64;
let np = n * p;
let npq = np * q;
let f_m = np + p;
let m = f64_to_i64(f_m);
// radius of triangle region, since height=1 also area of region
let p1 = (2.195 * npq.sqrt() - 4.6 * q).floor() + 0.5;
// tip of triangle
let x_m = (m as f64) + 0.5;
// left edge of triangle
let x_l = x_m - p1;
// right edge of triangle
let x_r = x_m + p1;
let c = 0.134 + 20.5 / (15.3 + (m as f64));
// p1 + area of parallelogram region
let p2 = p1 * (1. + 2. * c);

fn lambda(a: f64) -> f64 {
a * (1. + 0.5 * a)
}

let lambda_l = lambda((f_m - x_l) / (f_m - x_l * p));
let lambda_r = lambda((x_r - f_m) / (x_r * q));
// p1 + area of left tail
let p3 = p2 + c / lambda_l;
// p1 + area of right tail
let p4 = p3 + c / lambda_r;

// return value
let mut y: i64;

let gen_u = Uniform::new(0., p4);
let gen_v = Uniform::new(0., 1.);

loop {
let mut comp_dev: f64;
loop {
// draw from the Cauchy distribution
comp_dev = rng.sample(cauchy);
// shift the peak of the comparison ditribution
lresult = expected + sq * comp_dev;
// repeat the drawing until we are in the range of possible values
if lresult >= 0.0 && lresult < float_n + 1.0 {
break;
// Step 1: Generate `u` for selecting the region. If region 1 is
// selected, generate a triangularly distributed variate.
let u = gen_u.sample(rng);
let mut v = gen_v.sample(rng);
if !(u > p1) {
y = f64_to_i64(x_m - p1 * v + u);
break;
}

if !(u > p2) {
// Step 2: Region 2, parallelograms. Check if region 2 is
// used. If so, generate `y`.
let x = x_l + (u - p1) / c;
v = v * c + 1.0 - (x - x_m).abs() / p1;
if v > 1. {
continue;
} else {
y = f64_to_i64(x);
}
} else if !(u > p3) {
// Step 3: Region 3, left exponential tail.
y = f64_to_i64(x_l + v.ln() / lambda_l);
if y < 0 {
continue;
} else {
v *= (u - p2) * lambda_l;
}
} else {
// Step 4: Region 4, right exponential tail.
y = f64_to_i64(x_r - v.ln() / lambda_r);
if y > 0 && (y as u64) > self.n {
continue;
} else {
v *= (u - p3) * lambda_r;
}
}

// the result should be discrete
lresult = lresult.floor();
// Step 5: Acceptance/rejection comparison.

let log_binomial_dist = ln_fact_n - log_gamma(lresult+1.0) -
log_gamma(float_n - lresult + 1.0) + lresult*log_p + (float_n - lresult)*log_pc;
// this is the binomial probability divided by the comparison probability
// we will generate a uniform random value and if it is larger than this,
// we interpret it as a value falling out of the distribution and repeat
let comparison_coeff = (log_binomial_dist.exp() * sq) * (1.2 * (1.0 + comp_dev*comp_dev));
// Step 5.0: Test for appropriate method of evaluating f(y).
let k = (y - m).abs();
if !(k > SQUEEZE_THRESHOLD && (k as f64) < 0.5 * npq - 1.) {
// Step 5.1: Evaluate f(y) via the recursive relationship. Start the
// search from the mode.
let s = p / q;
let a = s * (n + 1.);
let mut f = 1.0;
if m < y {
let mut i = m;
loop {
i += 1;
f *= a / (i as f64) - s;
if i == y {
break;
}
}
} else if m > y {
let mut i = y;
loop {
i += 1;
f /= a / (i as f64) - s;
if i == m {
break;
}
}
}
if v > f {
continue;
} else {
break;
}
}

if comparison_coeff >= rng.gen() {
// Step 5.2: Squeezing. Check the value of ln(v) againts upper and
// lower bound of ln(f(y)).
let k = k as f64;
let rho = (k / npq) * ((k * (k / 3. + 0.625) + 1./6.) / npq + 0.5);
let t = -0.5 * k*k / npq;
let alpha = v.ln();
if alpha < t - rho {
break;
}
if alpha > t + rho {
continue;
}

// Step 5.3: Final acceptance/rejection test.
let x1 = (y + 1) as f64;
let f1 = (m + 1) as f64;
let z = (f64_to_i64(n) + 1 - m) as f64;
let w = (f64_to_i64(n) - y + 1) as f64;

fn stirling(a: f64) -> f64 {
let a2 = a * a;
(13860. - (462. - (132. - (99. - 140. / a2) / a2) / a2) / a2) / a / 166320.
}

if alpha > x_m * (f1 / x1).ln()
+ (n - (m as f64) + 0.5) * (z / w).ln()
+ ((y - m) as f64) * (w * p / (x1 * q)).ln()
// We use the signs from the GSL implementation, which are
// different than the ones in the reference. According to
// the GSL authors, the new signs were verified to be
// correct by one of the original designers of the
// algorithm.
+ stirling(f1) + stirling(z) - stirling(x1) - stirling(w)
{
continue;
}

break;
}
result = lresult as u64;
assert!(y >= 0);
result = y as u64;
}

// invert the result for p < 0.5
// Invert the result for p < 0.5.
if p != self.p {
self.n - result
} else {
Expand All @@ -157,12 +284,14 @@ mod test {
for i in results.iter_mut() { *i = binomial.sample(rng) as f64; }

let mean = results.iter().sum::<f64>() / results.len() as f64;
assert!((mean as f64 - expected_mean).abs() < expected_mean / 50.0);
assert!((mean as f64 - expected_mean).abs() < expected_mean / 50.0,
"mean: {}, expected_mean: {}", mean, expected_mean);

let variance =
results.iter().map(|x| (x - mean) * (x - mean)).sum::<f64>()
/ results.len() as f64;
assert!((variance - expected_variance).abs() < expected_variance / 10.0);
assert!((variance - expected_variance).abs() < expected_variance / 10.0,
"variance: {}, expected_variance: {}", variance, expected_variance);
}

#[test]
Expand Down

0 comments on commit e47c5a9

Please sign in to comment.