diff --git a/rand_distr/Cargo.toml b/rand_distr/Cargo.toml index d6e4931ce34..526ea0bc124 100644 --- a/rand_distr/Cargo.toml +++ b/rand_distr/Cargo.toml @@ -20,3 +20,4 @@ appveyor = { repository = "rust-random/rand" } [dependencies] rand = { path = "..", version = ">=0.5, <=0.7" } +num-traits = "0.2" diff --git a/rand_distr/src/exponential.rs b/rand_distr/src/exponential.rs index 35dc0930b94..bd832f8279d 100644 --- a/rand_distr/src/exponential.rs +++ b/rand_distr/src/exponential.rs @@ -12,6 +12,7 @@ use rand::Rng; use crate::{ziggurat_tables, Distribution}; use crate::utils::ziggurat; +use num_traits::Float; /// Samples floating-point numbers according to the exponential distribution, /// with rate parameter `λ = 1`. This is equivalent to `Exp::new(1.0)` or @@ -39,6 +40,15 @@ use crate::utils::ziggurat; #[derive(Clone, Copy, Debug)] pub struct Exp1; +impl Distribution for Exp1 { + #[inline] + fn sample(&self, rng: &mut R) -> f32 { + // TODO: use optimal 32-bit implementation + let x: f64 = self.sample(rng); + x as f32 + } +} + // This could be done via `-rng.gen::().ln()` but that is slower. impl Distribution for Exp1 { #[inline] @@ -76,9 +86,9 @@ impl Distribution for Exp1 { /// println!("{} is from a Exp(2) distribution", v); /// ``` #[derive(Clone, Copy, Debug)] -pub struct Exp { +pub struct Exp { /// `lambda` stored as `1/lambda`, since this is what we scale by. - lambda_inverse: f64 + lambda_inverse: N } /// Error type returned from `Exp::new`. @@ -88,22 +98,25 @@ pub enum Error { LambdaTooSmall, } -impl Exp { +impl Exp +where Exp1: Distribution +{ /// Construct a new `Exp` with the given shape parameter /// `lambda`. #[inline] - pub fn new(lambda: f64) -> Result { - if !(lambda > 0.0) { + pub fn new(lambda: N) -> Result, Error> { + if !(lambda > N::zero()) { return Err(Error::LambdaTooSmall); } - Ok(Exp { lambda_inverse: 1.0 / lambda }) + Ok(Exp { lambda_inverse: N::one() / lambda }) } } -impl Distribution for Exp { - fn sample(&self, rng: &mut R) -> f64 { - let n: f64 = rng.sample(Exp1); - n * self.lambda_inverse +impl Distribution for Exp +where Exp1: Distribution +{ + fn sample(&self, rng: &mut R) -> N { + rng.sample(Exp1) * self.lambda_inverse } } diff --git a/rand_distr/src/gamma.rs b/rand_distr/src/gamma.rs index 7ddc1fb13fa..6035f61107d 100644 --- a/rand_distr/src/gamma.rs +++ b/rand_distr/src/gamma.rs @@ -14,7 +14,8 @@ use self::ChiSquaredRepr::*; use rand::Rng; use crate::normal::StandardNormal; -use crate::{Distribution, Exp, Open01}; +use crate::{Distribution, Exp1, Exp, Open01}; +use num_traits::Float; /// The Gamma distribution `Gamma(shape, scale)` distribution. /// @@ -47,8 +48,8 @@ use crate::{Distribution, Exp, Open01}; /// (September 2000), 363-372. /// DOI:[10.1145/358407.358414](https://doi.acm.org/10.1145/358407.358414) #[derive(Clone, Copy, Debug)] -pub struct Gamma { - repr: GammaRepr, +pub struct Gamma { + repr: GammaRepr, } /// Error type returned from `Gamma::new`. @@ -63,10 +64,10 @@ pub enum Error { } #[derive(Clone, Copy, Debug)] -enum GammaRepr { - Large(GammaLargeShape), - One(Exp), - Small(GammaSmallShape) +enum GammaRepr { + Large(GammaLargeShape), + One(Exp), + Small(GammaSmallShape) } // These two helpers could be made public, but saving the @@ -84,9 +85,9 @@ enum GammaRepr { /// See `Gamma` for sampling from a Gamma distribution with general /// shape parameters. #[derive(Clone, Copy, Debug)] -struct GammaSmallShape { - inv_shape: f64, - large_shape: GammaLargeShape +struct GammaSmallShape { + inv_shape: N, + large_shape: GammaLargeShape } /// Gamma distribution where the shape parameter is larger than 1. @@ -94,27 +95,29 @@ struct GammaSmallShape { /// See `Gamma` for sampling from a Gamma distribution with general /// shape parameters. #[derive(Clone, Copy, Debug)] -struct GammaLargeShape { - scale: f64, - c: f64, - d: f64 +struct GammaLargeShape { + scale: N, + c: N, + d: N } -impl Gamma { +impl Gamma +where StandardNormal: Distribution, Exp1: Distribution, Open01: Distribution +{ /// Construct an object representing the `Gamma(shape, scale)` /// distribution. #[inline] - pub fn new(shape: f64, scale: f64) -> Result { - if !(shape > 0.0) { + pub fn new(shape: N, scale: N) -> Result, Error> { + if !(shape > N::zero()) { return Err(Error::ShapeTooSmall); } - if !(scale > 0.0) { + if !(scale > N::zero()) { return Err(Error::ScaleTooSmall); } - let repr = if shape == 1.0 { - One(Exp::new(1.0 / scale).map_err(|_| Error::ScaleTooLarge)?) - } else if shape < 1.0 { + let repr = if shape == N::one() { + One(Exp::new(N::one() / scale).map_err(|_| Error::ScaleTooLarge)?) + } else if shape < N::one() { Small(GammaSmallShape::new_raw(shape, scale)) } else { Large(GammaLargeShape::new_raw(shape, scale)) @@ -123,28 +126,34 @@ impl Gamma { } } -impl GammaSmallShape { - fn new_raw(shape: f64, scale: f64) -> GammaSmallShape { +impl GammaSmallShape +where StandardNormal: Distribution, Open01: Distribution +{ + fn new_raw(shape: N, scale: N) -> GammaSmallShape { GammaSmallShape { - inv_shape: 1. / shape, - large_shape: GammaLargeShape::new_raw(shape + 1.0, scale) + inv_shape: N::one() / shape, + large_shape: GammaLargeShape::new_raw(shape + N::one(), scale) } } } -impl GammaLargeShape { - fn new_raw(shape: f64, scale: f64) -> GammaLargeShape { - let d = shape - 1. / 3.; +impl GammaLargeShape +where StandardNormal: Distribution, Open01: Distribution +{ + fn new_raw(shape: N, scale: N) -> GammaLargeShape { + let d = shape - N::from(1. / 3.).unwrap(); GammaLargeShape { scale, - c: 1. / (9. * d).sqrt(), + c: N::one() / (N::from(9.).unwrap() * d).sqrt(), d } } } -impl Distribution for Gamma { - fn sample(&self, rng: &mut R) -> f64 { +impl Distribution for Gamma +where StandardNormal: Distribution, Exp1: Distribution, Open01: Distribution +{ + fn sample(&self, rng: &mut R) -> N { match self.repr { Small(ref g) => g.sample(rng), One(ref g) => g.sample(rng), @@ -152,28 +161,34 @@ impl Distribution for Gamma { } } } -impl Distribution for GammaSmallShape { - fn sample(&self, rng: &mut R) -> f64 { - let u: f64 = rng.sample(Open01); +impl Distribution for GammaSmallShape +where StandardNormal: Distribution, Open01: Distribution +{ + fn sample(&self, rng: &mut R) -> N { + let u: N = rng.sample(Open01); self.large_shape.sample(rng) * u.powf(self.inv_shape) } } -impl Distribution for GammaLargeShape { - fn sample(&self, rng: &mut R) -> f64 { +impl Distribution for GammaLargeShape +where StandardNormal: Distribution, Open01: Distribution +{ + fn sample(&self, rng: &mut R) -> N { + // Marsaglia & Tsang method, 2000 loop { - let x = rng.sample(StandardNormal); - let v_cbrt = 1.0 + self.c * x; - if v_cbrt <= 0.0 { // a^3 <= 0 iff a <= 0 + let x: N = rng.sample(StandardNormal); + let v_cbrt = N::one() + self.c * x; + if v_cbrt <= N::zero() { // a^3 <= 0 iff a <= 0 continue } let v = v_cbrt * v_cbrt * v_cbrt; - let u: f64 = rng.sample(Open01); + let u: N = rng.sample(Open01); let x_sqr = x * x; - if u < 1.0 - 0.0331 * x_sqr * x_sqr || - u.ln() < 0.5 * x_sqr + self.d * (1.0 - v + v.ln()) { + if u < N::one() - N::from(0.0331).unwrap() * x_sqr * x_sqr || + u.ln() < N::from(0.5).unwrap() * x_sqr + self.d * (N::one() - v + v.ln()) + { return self.d * v * self.scale } } @@ -215,7 +230,7 @@ enum ChiSquaredRepr { // e.g. when alpha = 1/2 as it would be for this case, so special- // casing and using the definition of N(0,1)^2 is faster. DoFExactlyOne, - DoFAnythingElse(Gamma), + DoFAnythingElse(Gamma), } impl ChiSquared { @@ -238,7 +253,7 @@ impl Distribution for ChiSquared { match self.repr { DoFExactlyOne => { // k == 1 => N(0,1)^2 - let norm = rng.sample(StandardNormal); + let norm: f64 = rng.sample(StandardNormal); norm * norm } DoFAnythingElse(ref g) => g.sample(rng) @@ -332,7 +347,7 @@ impl StudentT { } impl Distribution for StudentT { fn sample(&self, rng: &mut R) -> f64 { - let norm = rng.sample(StandardNormal); + let norm: f64 = rng.sample(StandardNormal); norm * (self.dof / self.chi.sample(rng)).sqrt() } } @@ -350,8 +365,8 @@ impl Distribution for StudentT { /// ``` #[derive(Clone, Copy, Debug)] pub struct Beta { - gamma_a: Gamma, - gamma_b: Gamma, + gamma_a: Gamma, + gamma_b: Gamma, } /// Error type returned from `Beta::new`. diff --git a/rand_distr/src/normal.rs b/rand_distr/src/normal.rs index 436e7dfd9ac..ea700e54d8a 100644 --- a/rand_distr/src/normal.rs +++ b/rand_distr/src/normal.rs @@ -12,6 +12,7 @@ use rand::Rng; use crate::{ziggurat_tables, Distribution, Open01}; use crate::utils::ziggurat; +use num_traits::Float; /// Samples floating-point numbers according to the normal distribution /// `N(0, 1)` (a.k.a. a standard normal, or Gaussian). This is equivalent to @@ -37,6 +38,15 @@ use crate::utils::ziggurat; #[derive(Clone, Copy, Debug)] pub struct StandardNormal; +impl Distribution for StandardNormal { + #[inline] + fn sample(&self, rng: &mut R) -> f32 { + // TODO: use optimal 32-bit implementation + let x: f64 = self.sample(rng); + x as f32 + } +} + impl Distribution for StandardNormal { fn sample(&self, rng: &mut R) -> f64 { #[inline] @@ -93,9 +103,9 @@ impl Distribution for StandardNormal { /// /// [`StandardNormal`]: crate::StandardNormal #[derive(Clone, Copy, Debug)] -pub struct Normal { - mean: f64, - std_dev: f64, +pub struct Normal { + mean: N, + std_dev: N, } /// Error type returned from `Normal::new` and `LogNormal::new`. @@ -105,12 +115,14 @@ pub enum Error { StdDevTooSmall, } -impl Normal { +impl Normal +where StandardNormal: Distribution +{ /// Construct a new `Normal` distribution with the given mean and /// standard deviation. #[inline] - pub fn new(mean: f64, std_dev: f64) -> Result { - if !(std_dev >= 0.0) { + pub fn new(mean: N, std_dev: N) -> Result, Error> { + if !(std_dev >= N::zero()) { return Err(Error::StdDevTooSmall); } Ok(Normal { @@ -119,9 +131,12 @@ impl Normal { }) } } -impl Distribution for Normal { - fn sample(&self, rng: &mut R) -> f64 { - let n = rng.sample(StandardNormal); + +impl Distribution for Normal +where StandardNormal: Distribution +{ + fn sample(&self, rng: &mut R) -> N { + let n: N = rng.sample(StandardNormal); self.mean + self.std_dev * n } } @@ -143,23 +158,28 @@ impl Distribution for Normal { /// println!("{} is from an ln N(2, 9) distribution", v) /// ``` #[derive(Clone, Copy, Debug)] -pub struct LogNormal { - norm: Normal +pub struct LogNormal { + norm: Normal } -impl LogNormal { +impl LogNormal +where StandardNormal: Distribution +{ /// Construct a new `LogNormal` distribution with the given mean /// and standard deviation of the logarithm of the distribution. #[inline] - pub fn new(mean: f64, std_dev: f64) -> Result { - if !(std_dev >= 0.0) { + pub fn new(mean: N, std_dev: N) -> Result, Error> { + if !(std_dev >= N::zero()) { return Err(Error::StdDevTooSmall); } Ok(LogNormal { norm: Normal::new(mean, std_dev).unwrap() }) } } -impl Distribution for LogNormal { - fn sample(&self, rng: &mut R) -> f64 { + +impl Distribution for LogNormal +where StandardNormal: Distribution +{ + fn sample(&self, rng: &mut R) -> N { self.norm.sample(rng).exp() } }