Skip to content

Commit

Permalink
impl Distribution<f32> for Gamma distribution
Browse files Browse the repository at this point in the history
  • Loading branch information
dhardy committed Apr 27, 2019
1 parent 8e3ed11 commit 19829a4
Showing 1 changed file with 60 additions and 45 deletions.
105 changes: 60 additions & 45 deletions rand_distr/src/gamma.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ use self::ChiSquaredRepr::*;

use rand::Rng;
use crate::normal::StandardNormal;
use crate::{Distribution, Exp, Open01};
use crate::{Distribution, Exp1, Exp, Open01};
use num_traits::Float;

/// The Gamma distribution `Gamma(shape, scale)` distribution.
///
Expand Down Expand Up @@ -47,8 +48,8 @@ use crate::{Distribution, Exp, Open01};
/// (September 2000), 363-372.
/// DOI:[10.1145/358407.358414](https://doi.acm.org/10.1145/358407.358414)
#[derive(Clone, Copy, Debug)]
pub struct Gamma {
repr: GammaRepr,
pub struct Gamma<N> {
repr: GammaRepr<N>,
}

/// Error type returned from `Gamma::new`.
Expand All @@ -63,10 +64,10 @@ pub enum Error {
}

#[derive(Clone, Copy, Debug)]
enum GammaRepr {
Large(GammaLargeShape),
One(Exp<f64>),
Small(GammaSmallShape)
enum GammaRepr<N> {
Large(GammaLargeShape<N>),
One(Exp<N>),
Small(GammaSmallShape<N>)
}

// These two helpers could be made public, but saving the
Expand All @@ -84,37 +85,39 @@ enum GammaRepr {
/// See `Gamma` for sampling from a Gamma distribution with general
/// shape parameters.
#[derive(Clone, Copy, Debug)]
struct GammaSmallShape {
inv_shape: f64,
large_shape: GammaLargeShape
struct GammaSmallShape<N> {
inv_shape: N,
large_shape: GammaLargeShape<N>
}

/// Gamma distribution where the shape parameter is larger than 1.
///
/// See `Gamma` for sampling from a Gamma distribution with general
/// shape parameters.
#[derive(Clone, Copy, Debug)]
struct GammaLargeShape {
scale: f64,
c: f64,
d: f64
struct GammaLargeShape<N> {
scale: N,
c: N,
d: N
}

impl Gamma {
impl<N: Float> Gamma<N>
where StandardNormal: Distribution<N>, Exp1: Distribution<N>, Open01: Distribution<N>
{
/// Construct an object representing the `Gamma(shape, scale)`
/// distribution.
#[inline]
pub fn new(shape: f64, scale: f64) -> Result<Gamma, Error> {
if !(shape > 0.0) {
pub fn new(shape: N, scale: N) -> Result<Gamma<N>, Error> {
if !(shape > N::zero()) {
return Err(Error::ShapeTooSmall);
}
if !(scale > 0.0) {
if !(scale > N::zero()) {
return Err(Error::ScaleTooSmall);
}

let repr = if shape == 1.0 {
One(Exp::new(1.0 / scale).map_err(|_| Error::ScaleTooLarge)?)
} else if shape < 1.0 {
let repr = if shape == N::one() {
One(Exp::new(N::one() / scale).map_err(|_| Error::ScaleTooLarge)?)
} else if shape < N::one() {
Small(GammaSmallShape::new_raw(shape, scale))
} else {
Large(GammaLargeShape::new_raw(shape, scale))
Expand All @@ -123,57 +126,69 @@ impl Gamma {
}
}

impl GammaSmallShape {
fn new_raw(shape: f64, scale: f64) -> GammaSmallShape {
impl<N: Float> GammaSmallShape<N>
where StandardNormal: Distribution<N>, Open01: Distribution<N>
{
fn new_raw(shape: N, scale: N) -> GammaSmallShape<N> {
GammaSmallShape {
inv_shape: 1. / shape,
large_shape: GammaLargeShape::new_raw(shape + 1.0, scale)
inv_shape: N::one() / shape,
large_shape: GammaLargeShape::new_raw(shape + N::one(), scale)
}
}
}

impl GammaLargeShape {
fn new_raw(shape: f64, scale: f64) -> GammaLargeShape {
let d = shape - 1. / 3.;
impl<N: Float> GammaLargeShape<N>
where StandardNormal: Distribution<N>, Open01: Distribution<N>
{
fn new_raw(shape: N, scale: N) -> GammaLargeShape<N> {
let d = shape - N::from(1. / 3.).unwrap();
GammaLargeShape {
scale,
c: 1. / (9. * d).sqrt(),
c: N::one() / (N::from(9.).unwrap() * d).sqrt(),
d
}
}
}

impl Distribution<f64> for Gamma {
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f64 {
impl<N: Float> Distribution<N> for Gamma<N>
where StandardNormal: Distribution<N>, Exp1: Distribution<N>, Open01: Distribution<N>
{
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> N {
match self.repr {
Small(ref g) => g.sample(rng),
One(ref g) => g.sample(rng),
Large(ref g) => g.sample(rng),
}
}
}
impl Distribution<f64> for GammaSmallShape {
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f64 {
let u: f64 = rng.sample(Open01);
impl<N: Float> Distribution<N> for GammaSmallShape<N>
where StandardNormal: Distribution<N>, Open01: Distribution<N>
{
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> N {
let u: N = rng.sample(Open01);

self.large_shape.sample(rng) * u.powf(self.inv_shape)
}
}
impl Distribution<f64> for GammaLargeShape {
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f64 {
impl<N: Float> Distribution<N> for GammaLargeShape<N>
where StandardNormal: Distribution<N>, Open01: Distribution<N>
{
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> N {
// Marsaglia & Tsang method, 2000
loop {
let x: f64 = rng.sample(StandardNormal);
let v_cbrt = 1.0 + self.c * x;
if v_cbrt <= 0.0 { // a^3 <= 0 iff a <= 0
let x: N = rng.sample(StandardNormal);
let v_cbrt = N::one() + self.c * x;
if v_cbrt <= N::zero() { // a^3 <= 0 iff a <= 0
continue
}

let v = v_cbrt * v_cbrt * v_cbrt;
let u: f64 = rng.sample(Open01);
let u: N = rng.sample(Open01);

let x_sqr = x * x;
if u < 1.0 - 0.0331 * x_sqr * x_sqr ||
u.ln() < 0.5 * x_sqr + self.d * (1.0 - v + v.ln()) {
if u < N::one() - N::from(0.0331).unwrap() * x_sqr * x_sqr ||
u.ln() < N::from(0.5).unwrap() * x_sqr + self.d * (N::one() - v + v.ln())
{
return self.d * v * self.scale
}
}
Expand Down Expand Up @@ -215,7 +230,7 @@ enum ChiSquaredRepr {
// e.g. when alpha = 1/2 as it would be for this case, so special-
// casing and using the definition of N(0,1)^2 is faster.
DoFExactlyOne,
DoFAnythingElse(Gamma),
DoFAnythingElse(Gamma<f64>),
}

impl ChiSquared {
Expand Down Expand Up @@ -350,8 +365,8 @@ impl Distribution<f64> for StudentT {
/// ```
#[derive(Clone, Copy, Debug)]
pub struct Beta {
gamma_a: Gamma,
gamma_b: Gamma,
gamma_a: Gamma<f64>,
gamma_b: Gamma<f64>,
}

/// Error type returned from `Beta::new`.
Expand Down

0 comments on commit 19829a4

Please sign in to comment.