Skip to content

Commit

Permalink
Merge pull request rust-random#785 from dhardy/distr
Browse files Browse the repository at this point in the history
Make distributions generic / impl for f32
  • Loading branch information
dhardy committed May 10, 2019
2 parents 01343e1 + 19829a4 commit b664e64
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 73 deletions.
1 change: 1 addition & 0 deletions rand_distr/Cargo.toml
Expand Up @@ -20,3 +20,4 @@ appveyor = { repository = "rust-random/rand" }

[dependencies]
rand = { path = "..", version = ">=0.5, <=0.7" }
num-traits = "0.2"
33 changes: 23 additions & 10 deletions rand_distr/src/exponential.rs
Expand Up @@ -12,6 +12,7 @@
use rand::Rng;
use crate::{ziggurat_tables, Distribution};
use crate::utils::ziggurat;
use num_traits::Float;

/// Samples floating-point numbers according to the exponential distribution,
/// with rate parameter `λ = 1`. This is equivalent to `Exp::new(1.0)` or
Expand Down Expand Up @@ -39,6 +40,15 @@ use crate::utils::ziggurat;
#[derive(Clone, Copy, Debug)]
pub struct Exp1;

impl Distribution<f32> for Exp1 {
#[inline]
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f32 {
// TODO: use optimal 32-bit implementation
let x: f64 = self.sample(rng);
x as f32
}
}

// This could be done via `-rng.gen::<f64>().ln()` but that is slower.
impl Distribution<f64> for Exp1 {
#[inline]
Expand Down Expand Up @@ -76,9 +86,9 @@ impl Distribution<f64> for Exp1 {
/// println!("{} is from a Exp(2) distribution", v);
/// ```
#[derive(Clone, Copy, Debug)]
pub struct Exp {
pub struct Exp<N> {
/// `lambda` stored as `1/lambda`, since this is what we scale by.
lambda_inverse: f64
lambda_inverse: N
}

/// Error type returned from `Exp::new`.
Expand All @@ -88,22 +98,25 @@ pub enum Error {
LambdaTooSmall,
}

impl Exp {
impl<N: Float> Exp<N>
where Exp1: Distribution<N>
{
/// Construct a new `Exp` with the given shape parameter
/// `lambda`.
#[inline]
pub fn new(lambda: f64) -> Result<Exp, Error> {
if !(lambda > 0.0) {
pub fn new(lambda: N) -> Result<Exp<N>, Error> {
if !(lambda > N::zero()) {
return Err(Error::LambdaTooSmall);
}
Ok(Exp { lambda_inverse: 1.0 / lambda })
Ok(Exp { lambda_inverse: N::one() / lambda })
}
}

impl Distribution<f64> for Exp {
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f64 {
let n: f64 = rng.sample(Exp1);
n * self.lambda_inverse
impl<N: Float> Distribution<N> for Exp<N>
where Exp1: Distribution<N>
{
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> N {
rng.sample(Exp1) * self.lambda_inverse
}
}

Expand Down
109 changes: 62 additions & 47 deletions rand_distr/src/gamma.rs
Expand Up @@ -14,7 +14,8 @@ use self::ChiSquaredRepr::*;

use rand::Rng;
use crate::normal::StandardNormal;
use crate::{Distribution, Exp, Open01};
use crate::{Distribution, Exp1, Exp, Open01};
use num_traits::Float;

/// The Gamma distribution `Gamma(shape, scale)` distribution.
///
Expand Down Expand Up @@ -47,8 +48,8 @@ use crate::{Distribution, Exp, Open01};
/// (September 2000), 363-372.
/// DOI:[10.1145/358407.358414](https://doi.acm.org/10.1145/358407.358414)
#[derive(Clone, Copy, Debug)]
pub struct Gamma {
repr: GammaRepr,
pub struct Gamma<N> {
repr: GammaRepr<N>,
}

/// Error type returned from `Gamma::new`.
Expand All @@ -63,10 +64,10 @@ pub enum Error {
}

#[derive(Clone, Copy, Debug)]
enum GammaRepr {
Large(GammaLargeShape),
One(Exp),
Small(GammaSmallShape)
enum GammaRepr<N> {
Large(GammaLargeShape<N>),
One(Exp<N>),
Small(GammaSmallShape<N>)
}

// These two helpers could be made public, but saving the
Expand All @@ -84,37 +85,39 @@ enum GammaRepr {
/// See `Gamma` for sampling from a Gamma distribution with general
/// shape parameters.
#[derive(Clone, Copy, Debug)]
struct GammaSmallShape {
inv_shape: f64,
large_shape: GammaLargeShape
struct GammaSmallShape<N> {
inv_shape: N,
large_shape: GammaLargeShape<N>
}

/// Gamma distribution where the shape parameter is larger than 1.
///
/// See `Gamma` for sampling from a Gamma distribution with general
/// shape parameters.
#[derive(Clone, Copy, Debug)]
struct GammaLargeShape {
scale: f64,
c: f64,
d: f64
struct GammaLargeShape<N> {
scale: N,
c: N,
d: N
}

impl Gamma {
impl<N: Float> Gamma<N>
where StandardNormal: Distribution<N>, Exp1: Distribution<N>, Open01: Distribution<N>
{
/// Construct an object representing the `Gamma(shape, scale)`
/// distribution.
#[inline]
pub fn new(shape: f64, scale: f64) -> Result<Gamma, Error> {
if !(shape > 0.0) {
pub fn new(shape: N, scale: N) -> Result<Gamma<N>, Error> {
if !(shape > N::zero()) {
return Err(Error::ShapeTooSmall);
}
if !(scale > 0.0) {
if !(scale > N::zero()) {
return Err(Error::ScaleTooSmall);
}

let repr = if shape == 1.0 {
One(Exp::new(1.0 / scale).map_err(|_| Error::ScaleTooLarge)?)
} else if shape < 1.0 {
let repr = if shape == N::one() {
One(Exp::new(N::one() / scale).map_err(|_| Error::ScaleTooLarge)?)
} else if shape < N::one() {
Small(GammaSmallShape::new_raw(shape, scale))
} else {
Large(GammaLargeShape::new_raw(shape, scale))
Expand All @@ -123,57 +126,69 @@ impl Gamma {
}
}

impl GammaSmallShape {
fn new_raw(shape: f64, scale: f64) -> GammaSmallShape {
impl<N: Float> GammaSmallShape<N>
where StandardNormal: Distribution<N>, Open01: Distribution<N>
{
fn new_raw(shape: N, scale: N) -> GammaSmallShape<N> {
GammaSmallShape {
inv_shape: 1. / shape,
large_shape: GammaLargeShape::new_raw(shape + 1.0, scale)
inv_shape: N::one() / shape,
large_shape: GammaLargeShape::new_raw(shape + N::one(), scale)
}
}
}

impl GammaLargeShape {
fn new_raw(shape: f64, scale: f64) -> GammaLargeShape {
let d = shape - 1. / 3.;
impl<N: Float> GammaLargeShape<N>
where StandardNormal: Distribution<N>, Open01: Distribution<N>
{
fn new_raw(shape: N, scale: N) -> GammaLargeShape<N> {
let d = shape - N::from(1. / 3.).unwrap();
GammaLargeShape {
scale,
c: 1. / (9. * d).sqrt(),
c: N::one() / (N::from(9.).unwrap() * d).sqrt(),
d
}
}
}

impl Distribution<f64> for Gamma {
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f64 {
impl<N: Float> Distribution<N> for Gamma<N>
where StandardNormal: Distribution<N>, Exp1: Distribution<N>, Open01: Distribution<N>
{
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> N {
match self.repr {
Small(ref g) => g.sample(rng),
One(ref g) => g.sample(rng),
Large(ref g) => g.sample(rng),
}
}
}
impl Distribution<f64> for GammaSmallShape {
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f64 {
let u: f64 = rng.sample(Open01);
impl<N: Float> Distribution<N> for GammaSmallShape<N>
where StandardNormal: Distribution<N>, Open01: Distribution<N>
{
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> N {
let u: N = rng.sample(Open01);

self.large_shape.sample(rng) * u.powf(self.inv_shape)
}
}
impl Distribution<f64> for GammaLargeShape {
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f64 {
impl<N: Float> Distribution<N> for GammaLargeShape<N>
where StandardNormal: Distribution<N>, Open01: Distribution<N>
{
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> N {
// Marsaglia & Tsang method, 2000
loop {
let x = rng.sample(StandardNormal);
let v_cbrt = 1.0 + self.c * x;
if v_cbrt <= 0.0 { // a^3 <= 0 iff a <= 0
let x: N = rng.sample(StandardNormal);
let v_cbrt = N::one() + self.c * x;
if v_cbrt <= N::zero() { // a^3 <= 0 iff a <= 0
continue
}

let v = v_cbrt * v_cbrt * v_cbrt;
let u: f64 = rng.sample(Open01);
let u: N = rng.sample(Open01);

let x_sqr = x * x;
if u < 1.0 - 0.0331 * x_sqr * x_sqr ||
u.ln() < 0.5 * x_sqr + self.d * (1.0 - v + v.ln()) {
if u < N::one() - N::from(0.0331).unwrap() * x_sqr * x_sqr ||
u.ln() < N::from(0.5).unwrap() * x_sqr + self.d * (N::one() - v + v.ln())
{
return self.d * v * self.scale
}
}
Expand Down Expand Up @@ -215,7 +230,7 @@ enum ChiSquaredRepr {
// e.g. when alpha = 1/2 as it would be for this case, so special-
// casing and using the definition of N(0,1)^2 is faster.
DoFExactlyOne,
DoFAnythingElse(Gamma),
DoFAnythingElse(Gamma<f64>),
}

impl ChiSquared {
Expand All @@ -238,7 +253,7 @@ impl Distribution<f64> for ChiSquared {
match self.repr {
DoFExactlyOne => {
// k == 1 => N(0,1)^2
let norm = rng.sample(StandardNormal);
let norm: f64 = rng.sample(StandardNormal);
norm * norm
}
DoFAnythingElse(ref g) => g.sample(rng)
Expand Down Expand Up @@ -332,7 +347,7 @@ impl StudentT {
}
impl Distribution<f64> for StudentT {
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f64 {
let norm = rng.sample(StandardNormal);
let norm: f64 = rng.sample(StandardNormal);
norm * (self.dof / self.chi.sample(rng)).sqrt()
}
}
Expand All @@ -350,8 +365,8 @@ impl Distribution<f64> for StudentT {
/// ```
#[derive(Clone, Copy, Debug)]
pub struct Beta {
gamma_a: Gamma,
gamma_b: Gamma,
gamma_a: Gamma<f64>,
gamma_b: Gamma<f64>,
}

/// Error type returned from `Beta::new`.
Expand Down

0 comments on commit b664e64

Please sign in to comment.