Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Binomial: Faster sampling for n * p >= 10 #740

Merged
merged 7 commits into from Apr 1, 2019
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
219 changes: 174 additions & 45 deletions src/distributions/binomial.rs
Expand Up @@ -10,8 +10,7 @@
//! The binomial distribution.

use Rng;
use distributions::{Distribution, Cauchy};
use distributions::utils::log_gamma;
use distributions::{Distribution, Uniform};

/// The binomial distribution `Binomial(n, p)`.
///
Expand Down Expand Up @@ -47,6 +46,13 @@ impl Binomial {
}
}

/// Truncate an `f64` toward zero into an `i64`, panicking on overflow.
///
/// # Panics
///
/// Panics unless `x` is strictly below `i64::MAX as f64`. A NaN input fails
/// that comparison and therefore panics as well. Only the upper bound is
/// asserted; large negative inputs are not rejected here.
// In the future (Rust 1.34), this might be replaced with `TryFrom`.
fn f64_to_i64(x: f64) -> i64 {
    let upper_bound = ::std::i64::MAX as f64;
    assert!(x < upper_bound);
    x as i64
}

impl Distribution<u64> for Binomial {
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> u64 {
// Handle these values directly.
Expand All @@ -56,25 +62,33 @@ impl Distribution<u64> for Binomial {
return self.n;
}

// binomial distribution is symmetrical with respect to p -> 1-p, k -> n-k
// switch p so that it is less than 0.5 - this allows for lower expected values
// we will just invert the result at the end
// The binomial distribution is symmetrical with respect to p -> 1-p,
// k -> n-k. Switch p so that it is at most 0.5; this allows for lower
// expected values. We will just invert the result at the end.
let p = if self.p <= 0.5 {
self.p
} else {
1.0 - self.p
};

let result;
let q = 1. - p;

// For small n * min(p, 1 - p), the BINV algorithm based on the inverse
// transformation of the binomial distribution is more efficient:
// transformation of the binomial distribution is efficient. Otherwise,
// the BTPE algorithm is used.
//
// Voratas Kachitvichyanukul and Bruce W. Schmeiser. 1988. Binomial
// random variate generation. Commun. ACM 31, 2 (February 1988),
// 216-222. http://dx.doi.org/10.1145/42372.42381
if (self.n as f64) * p < 10. && self.n <= (::std::i32::MAX as u64) {
let q = 1. - p;

// Threshold for preferring the BINV algorithm. The paper suggests 10,
// Ranlib uses 30, and GSL uses 14.
const BINV_THRESHOLD: f64 = 10.;

if (self.n as f64) * p < BINV_THRESHOLD &&
self.n <= (::std::i32::MAX as u64) {
// Use the BINV algorithm.
let s = p / q;
let a = ((self.n + 1) as f64) * s;
let mut r = q.powi(self.n as i32);
Expand All @@ -87,52 +101,165 @@ impl Distribution<u64> for Binomial {
}
result = x;
} else {
// FIXME: Using the BTPE algorithm is probably faster.

// prepare some cached values
let float_n = self.n as f64;
let ln_fact_n = log_gamma(float_n + 1.0);
let pc = 1.0 - p;
let log_p = p.ln();
let log_pc = pc.ln();
let expected = self.n as f64 * p;
let sq = (expected * (2.0 * pc)).sqrt();
let mut lresult;

// we use the Cauchy distribution as the comparison distribution
// f(x) ~ 1/(1+x^2)
let cauchy = Cauchy::new(0.0, 1.0);
// Use the BTPE algorithm.

// Threshold for using the squeeze algorithm. This can be freely
// chosen based on performance. Ranlib and GSL use 20.
const SQUEEZE_THRESHOLD: i64 = 20;

// Step 0: Calculate constants as functions of `n` and `p`.
let n = self.n as f64;
let np = n * p;
let npq = np * q;
let f_m = np + p;
let m = f64_to_i64(f_m);
// radius of triangle region, since height=1 also area of region
let p1 = (2.195 * npq.sqrt() - 4.6 * q).floor() + 0.5;
// tip of triangle
let x_m = (m as f64) + 0.5;
// left edge of triangle
let x_l = x_m - p1;
// right edge of triangle
let x_r = x_m + p1;
let c = 0.134 + 20.5 / (15.3 + (m as f64));
// p1 + area of parallelogram region
let p2 = p1 * (1. + 2. * c);

// Evaluates `a * (1 + a / 2)`; applied to the ratios from which
// `lambda_l` and `lambda_r` are derived.
fn lambda(a: f64) -> f64 {
    let half_a = 0.5 * a;
    a * (1. + half_a)
}

let lambda_l = lambda((f_m - x_l) / (f_m - x_l * p));
let lambda_r = lambda((x_r - f_m) / (x_r * q));
// p2 + area of left tail
let p3 = p2 + c / lambda_l;
// p3 + area of right tail
let p4 = p3 + c / lambda_r;

// return value
let mut y: i64;

let gen_u = Uniform::new(0., p4);
let gen_v = Uniform::new(0., 1.);

loop {
let mut comp_dev: f64;
loop {
// draw from the Cauchy distribution
comp_dev = rng.sample(cauchy);
// shift the peak of the comparison ditribution
lresult = expected + sq * comp_dev;
// repeat the drawing until we are in the range of possible values
if lresult >= 0.0 && lresult < float_n + 1.0 {
break;
// Step 1: Generate `u` for selecting the region. If region 1 is
// selected, generate a triangularly distributed variate.
let u = gen_u.sample(rng);
let mut v = gen_v.sample(rng);
if !(u > p1) {
y = f64_to_i64(x_m - p1 * v + u);
break;
}

if !(u > p2) {
// Step 2: Region 2, parallelograms. Check if region 2 is
// used. If so, generate `y`.
let x = x_l + (u - p1) / c;
v = v * c + 1.0 - (x - x_m).abs() / p1;
if v > 1. {
continue;
} else {
y = f64_to_i64(x);
}
} else if !(u > p3) {
// Step 3: Region 3, left exponential tail.
y = f64_to_i64(x_l + v.ln() / lambda_l);
if y < 0 {
continue;
} else {
v *= (u - p2) * lambda_l;
}
} else {
// Step 4: Region 4, right exponential tail.
y = f64_to_i64(x_r - v.ln() / lambda_r);
if y > 0 && (y as u64) > self.n {
continue;
} else {
v *= (u - p3) * lambda_r;
}
}

// the result should be discrete
lresult = lresult.floor();
// Step 5: Acceptance/rejection comparison.

let log_binomial_dist = ln_fact_n - log_gamma(lresult+1.0) -
log_gamma(float_n - lresult + 1.0) + lresult*log_p + (float_n - lresult)*log_pc;
// this is the binomial probability divided by the comparison probability
// we will generate a uniform random value and if it is larger than this,
// we interpret it as a value falling out of the distribution and repeat
let comparison_coeff = (log_binomial_dist.exp() * sq) * (1.2 * (1.0 + comp_dev*comp_dev));
// Step 5.0: Test for appropriate method of evaluating f(y).
let k = (y - m).abs();
if !(k > SQUEEZE_THRESHOLD && (k as f64) < 0.5 * npq - 1.) {
// Step 5.1: Evaluate f(y) via the recursive relationship. Start the
// search from the mode.
let s = p / q;
let a = s * (n + 1.);
let mut f = 1.0;
if m < y {
let mut i = m;
loop {
i += 1;
f *= a / (i as f64) - s;
if i == y {
break;
}
}
} else if m > y {
let mut i = y;
loop {
i += 1;
f /= a / (i as f64) - s;
if i == m {
break;
}
}
}
if v > f {
continue;
} else {
break;
}
}

if comparison_coeff >= rng.gen() {
// Step 5.2: Squeezing. Check the value of ln(v) against upper and
// lower bounds of ln(f(y)).
let k = k as f64;
let rho = (k / npq) * ((k * (k / 3. + 0.625) + 1./6.) / npq + 0.5);
let t = -0.5 * k*k / npq;
let alpha = v.ln();
if alpha < t - rho {
break;
}
if alpha > t + rho {
continue;
}

// Step 5.3: Final acceptance/rejection test.
let x1 = (y + 1) as f64;
let f1 = (m + 1) as f64;
let z = (f64_to_i64(n) + 1 - m) as f64;
let w = (f64_to_i64(n) - y + 1) as f64;

// Evaluates the nested expression
//   (13860 - (462 - (132 - (99 - 140/a^2)/a^2)/a^2)/a^2) / a / 166320,
// which is summed over `f1`, `z`, `x1` and `w` in the final
// acceptance/rejection comparison.
fn stirling(a: f64) -> f64 {
    let a2 = a * a;
    // Work from the innermost term outwards; each step computes
    // `coeff - previous / a2`, matching the nesting of the closed form.
    let coeffs: [f64; 4] = [99., 132., 462., 13860.];
    let series = coeffs
        .iter()
        .fold(140.0_f64, |inner, &coeff| coeff - inner / a2);
    series / a / 166320.
}

if alpha > x_m * (f1 / x1).ln()
+ (n - (m as f64) + 0.5) * (z / w).ln()
+ ((y - m) as f64) * (w * p / (x1 * q)).ln()
// We use the signs from the GSL implementation, which are
// different than the ones in the reference. According to
// the GSL authors, the new signs were verified to be
// correct by one of the original designers of the
// algorithm.
+ stirling(f1) + stirling(z) - stirling(x1) - stirling(w)
dhardy marked this conversation as resolved.
Show resolved Hide resolved
{
continue;
}

break;
}
result = lresult as u64;
assert!(y >= 0);
result = y as u64;
}

// invert the result for p < 0.5
// Invert the result if p was flipped to 1 - p above (i.e. self.p > 0.5).
if p != self.p {
self.n - result
} else {
Expand All @@ -157,12 +284,14 @@ mod test {
for i in results.iter_mut() { *i = binomial.sample(rng) as f64; }

let mean = results.iter().sum::<f64>() / results.len() as f64;
assert!((mean as f64 - expected_mean).abs() < expected_mean / 50.0);
assert!((mean as f64 - expected_mean).abs() < expected_mean / 50.0,
"mean: {}, expected_mean: {}", mean, expected_mean);

let variance =
results.iter().map(|x| (x - mean) * (x - mean)).sum::<f64>()
/ results.len() as f64;
assert!((variance - expected_variance).abs() < expected_variance / 10.0);
assert!((variance - expected_variance).abs() < expected_variance / 10.0,
"variance: {}, expected_variance: {}", variance, expected_variance);
}

#[test]
Expand Down