From f8b587e4afd99d125fd7d2ece7430d92b153d411 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Elie=20G=C3=A9nard?= Date: Fri, 20 Mar 2020 12:24:11 +0100 Subject: [PATCH 1/7] Add inverse Gaussian distribution --- rand_distr/src/inverse_gaussian.rs | 85 ++++++++++++++++++++++++++++++ rand_distr/src/lib.rs | 4 ++ 2 files changed, 89 insertions(+) create mode 100644 rand_distr/src/inverse_gaussian.rs diff --git a/rand_distr/src/inverse_gaussian.rs b/rand_distr/src/inverse_gaussian.rs new file mode 100644 index 00000000000..e24fba093fb --- /dev/null +++ b/rand_distr/src/inverse_gaussian.rs @@ -0,0 +1,85 @@ +use crate::{Distribution, Float, Standard, StandardNormal}; +use rand::prelude::*; + +/// Error type returned from `InverseGaussian::new` +#[derive(Debug, PartialEq)] +pub enum Error { + /// `mean <= 0` or `nan`. + MeanNegativeOrNull, + /// `shape <= 0` or `nan`. + ShapeNegativeOrNull, +} + +/// The [inverse Gaussian distribution](https://en.wikipedia.org/wiki/Inverse_Gaussian_distribution) +#[derive(Debug)] +pub struct InverseGaussian { + mean: N, + shape: N, +} + +impl InverseGaussian +where StandardNormal: Distribution +{ + /// Construct a new `InverseGaussian` distribution with the given mean and + /// shape. + pub fn new(mean: N, shape: N) -> Result, Error> { + if !(mean > N::from(0.0)) { + return Err(Error::MeanNegativeOrNull); + } + + if !(shape > N::from(0.0)) { + return Err(Error::ShapeNegativeOrNull); + } + + Ok(Self { mean, shape }) + } +} + +impl Distribution for InverseGaussian +where + StandardNormal: Distribution, + Standard: Distribution, +{ + fn sample(&self, rng: &mut R) -> N + where R: Rng + ?Sized { + let mu = self.mean; + let l = self.shape; + + let v: N = rng.sample(StandardNormal); + let y = mu * v * v; + + let mu_2l = mu / (N::from(2.) * l); + + let x = mu + mu_2l * (y - (N::from(4.) * l * y + y * y).sqrt()); + + let u: N = rng.gen(); + + if u <= mu / (mu + x) { + return x; + } + + mu * mu / x + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_inverse_gaussian() { + let inv_gauss = InverseGaussian::new(1.0, 1.0).unwrap(); + let mut rng = crate::test::rng(210); + for _ in 0..1000 { + inv_gauss.sample(&mut rng); + } + } + + #[test] + fn test_inverse_gaussian_invalid_param() { + assert!(InverseGaussian::new(-1.0, 1.0).is_err()); + assert!(InverseGaussian::new(-1.0, -1.0).is_err()); + assert!(InverseGaussian::new(1.0, -1.0).is_err()); + assert!(InverseGaussian::new(1.0, 1.0).is_ok()); + } +} diff --git a/rand_distr/src/lib.rs b/rand_distr/src/lib.rs index 0e7beb91b5d..38e514fb588 100644 --- a/rand_distr/src/lib.rs +++ b/rand_distr/src/lib.rs @@ -66,6 +66,8 @@ //! - [`UnitBall`] distribution //! - [`UnitCircle`] distribution //! - [`UnitDisc`] distribution +//! - Misc. distributions +//! - [`InverseGaussian`] distribution pub use rand::distributions::{ uniform, weighted, Alphanumeric, Bernoulli, BernoulliError, DistIter, Distribution, Open01, @@ -80,6 +82,7 @@ pub use self::gamma::{ Beta, BetaError, ChiSquared, ChiSquaredError, Error as GammaError, FisherF, FisherFError, Gamma, StudentT, }; +pub use self::inverse_gaussian::{InverseGaussian, Error as InverseGaussianError}; pub use self::normal::{Error as NormalError, LogNormal, Normal, StandardNormal}; pub use self::pareto::{Error as ParetoError, Pareto}; pub use self::pert::{Pert, PertError}; @@ -97,6 +100,7 @@ mod cauchy; mod dirichlet; mod exponential; mod gamma; +mod inverse_gaussian; mod normal; mod pareto; mod pert; From 162d0a24c32c75ef589b2d2b10a7c343d7eaf147 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Elie=20G=C3=A9nard?= Date: Fri, 20 Mar 2020 12:24:44 +0100 Subject: [PATCH 2/7] Add normal-inverse Gaussian distribution --- rand_distr/src/lib.rs | 3 + rand_distr/src/normal_inverse_gaussian.rs | 82 +++++++++++++++++++++++ 2 files changed, 85 insertions(+) create mode 100644 rand_distr/src/normal_inverse_gaussian.rs diff --git a/rand_distr/src/lib.rs b/rand_distr/src/lib.rs index 38e514fb588..a21abd07b62 100644 --- a/rand_distr/src/lib.rs +++ b/rand_distr/src/lib.rs @@ -68,6 +68,7 @@ //! - [`UnitDisc`] distribution //! - Misc. distributions //! - [`InverseGaussian`] distribution +//! - [`NormalInverseGaussian`] distribution pub use rand::distributions::{ uniform, weighted, Alphanumeric, Bernoulli, BernoulliError, DistIter, Distribution, Open01, @@ -84,6 +85,7 @@ pub use self::gamma::{ }; pub use self::inverse_gaussian::{InverseGaussian, Error as InverseGaussianError}; pub use self::normal::{Error as NormalError, LogNormal, Normal, StandardNormal}; +pub use self::normal_inverse_gaussian::{NormalInverseGaussian, Error as NormalInverseGaussianError}; pub use self::pareto::{Error as ParetoError, Pareto}; pub use self::pert::{Pert, PertError}; pub use self::poisson::{Error as PoissonError, Poisson}; @@ -102,6 +104,7 @@ mod exponential; mod gamma; mod inverse_gaussian; mod normal; +mod normal_inverse_gaussian; mod pareto; mod pert; mod poisson; diff --git a/rand_distr/src/normal_inverse_gaussian.rs b/rand_distr/src/normal_inverse_gaussian.rs new file mode 100644 index 00000000000..0018de1e7b7 --- /dev/null +++ b/rand_distr/src/normal_inverse_gaussian.rs @@ -0,0 +1,82 @@ +use crate::{Distribution, Float, InverseGaussian, Standard, StandardNormal}; +use rand::prelude::*; + +/// Error type returned from `NormalInverseGaussian::new` +#[derive(Debug, PartialEq)] +pub enum Error { + /// `alpha <= 0` or `nan`. + AlphaNegativeOrNull, + /// `|beta| >= alpha` or `nan`. + AbsoluteBetaLessThanAlpha, +} + +/// The [normal-inverse Gaussian distribution](https://en.wikipedia.org/wiki/Normal-inverse_Gaussian_distribution) +#[derive(Debug)] +pub struct NormalInverseGaussian { + alpha: N, + beta: N, + inverse_gaussian: InverseGaussian, +} + +impl NormalInverseGaussian +where StandardNormal: Distribution +{ + /// Construct a new `NormalInverseGaussian` distribution with the given alpha (tail heaviness) and + /// beta (asymmetry) parameters. + pub fn new(alpha: N, beta: N) -> Result, Error> { + if !(alpha > N::from(0.0)) { + return Err(Error::AlphaNegativeOrNull); + } + + if !(beta.abs() < alpha) { + return Err(Error::AbsoluteBetaLessThanAlpha); + } + + let gamma = (alpha.powf(N::from(2.)) - beta.powf(N::from(2.))).sqrt(); + + let mu = N::from(1.) / gamma; + + let inverse_gaussian = InverseGaussian::new(mu, N::from(1.)).unwrap(); + + Ok(Self { + alpha, + beta, + inverse_gaussian, + }) + } +} + +impl Distribution for NormalInverseGaussian +where + StandardNormal: Distribution, + Standard: Distribution, +{ + fn sample(&self, rng: &mut R) -> N + where R: Rng + ?Sized { + let inv_gauss = rng.sample(&self.inverse_gaussian); + + self.beta * inv_gauss + inv_gauss.sqrt() * rng.sample(StandardNormal) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_normal_inverse_gaussian() { + let norm_inv_gauss = NormalInverseGaussian::new(2.0, 1.0).unwrap(); + let mut rng = crate::test::rng(210); + for _ in 0..1000 { + norm_inv_gauss.sample(&mut rng); + } + } + + #[test] + fn test_normal_inverse_gaussian_invalid_param() { + assert!(NormalInverseGaussian::new(-1.0, 1.0).is_err()); + assert!(NormalInverseGaussian::new(-1.0, -1.0).is_err()); + assert!(NormalInverseGaussian::new(1.0, 2.0).is_err()); + assert!(NormalInverseGaussian::new(2.0, 1.0).is_ok()); + } +} From 3c9172975bca8873a15d1e9a2a0e020097fc9c44 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Elie=20G=C3=A9nard?= Date: Sat, 21 Mar 2020 10:36:07 +0100 Subject: [PATCH 3/7] Qualify imports --- rand_distr/src/inverse_gaussian.rs | 2 +- rand_distr/src/normal_inverse_gaussian.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/rand_distr/src/inverse_gaussian.rs b/rand_distr/src/inverse_gaussian.rs index e24fba093fb..210a06491b1 100644 --- a/rand_distr/src/inverse_gaussian.rs +++ b/rand_distr/src/inverse_gaussian.rs @@ -1,5 +1,5 @@ use crate::{Distribution, Float, Standard, StandardNormal}; -use rand::prelude::*; +use rand::Rng; /// Error type returned from `InverseGaussian::new` #[derive(Debug, PartialEq)] diff --git a/rand_distr/src/normal_inverse_gaussian.rs b/rand_distr/src/normal_inverse_gaussian.rs index 0018de1e7b7..adc98579851 100644 --- a/rand_distr/src/normal_inverse_gaussian.rs +++ b/rand_distr/src/normal_inverse_gaussian.rs @@ -1,5 +1,5 @@ use crate::{Distribution, Float, InverseGaussian, Standard, StandardNormal}; -use rand::prelude::*; +use rand::Rng; /// Error type returned from `NormalInverseGaussian::new` #[derive(Debug, PartialEq)] From d841db0575d1f9bdfa130d4532ec2f5c73fc3887 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Elie=20G=C3=A9nard?= Date: Sat, 21 Mar 2020 10:36:21 +0100 Subject: [PATCH 4/7] Add value stability tests --- rand_distr/src/inverse_gaussian.rs | 24 ++++++++++++++++++++++ rand_distr/src/normal_inverse_gaussian.rs | 25 +++++++++++++++++++++++ 2 files changed, 49 insertions(+) diff --git a/rand_distr/src/inverse_gaussian.rs b/rand_distr/src/inverse_gaussian.rs index 210a06491b1..171aa473eee 100644 --- a/rand_distr/src/inverse_gaussian.rs +++ b/rand_distr/src/inverse_gaussian.rs @@ -82,4 +82,28 @@ mod tests { assert!(InverseGaussian::new(1.0, -1.0).is_err()); assert!(InverseGaussian::new(1.0, 1.0).is_ok()); } + + #[test] + fn value_stability() { + fn test_samples>( + distr: D, zero: N, expected: &[N], + ) { + let mut rng = crate::test::rng(213); + let mut buf = [zero; 4]; + for x in &mut buf { + *x = rng.sample(&distr); + } + assert_eq!(buf, expected); + } + + test_samples(InverseGaussian::new(1.0, 3.0).unwrap(), 0f32, &[ + 0.9339157, 1.108113, 0.50864697, 0.39849377, + ]); + test_samples(InverseGaussian::new(1.0, 3.0).unwrap(), 0f64, &[ + 1.0707604954722476, + 0.9628140605340697, + 0.4069687656468226, + 0.660283852985818, + ]); + } } diff --git a/rand_distr/src/normal_inverse_gaussian.rs b/rand_distr/src/normal_inverse_gaussian.rs index adc98579851..1e2c5886ab5 100644 --- a/rand_distr/src/normal_inverse_gaussian.rs +++ b/rand_distr/src/normal_inverse_gaussian.rs @@ -79,4 +79,29 @@ mod tests { assert!(NormalInverseGaussian::new(1.0, 2.0).is_err()); assert!(NormalInverseGaussian::new(2.0, 1.0).is_ok()); } + + + #[test] + fn value_stability() { + fn test_samples>( + distr: D, zero: N, expected: &[N], + ) { + let mut rng = crate::test::rng(213); + let mut buf = [zero; 4]; + for x in &mut buf { + *x = rng.sample(&distr); + } + assert_eq!(buf, expected); + } + + test_samples(NormalInverseGaussian::new(2.0, 1.0).unwrap(), 0f32, &[ + 0.6568966, 1.3744819, 2.216063, 0.11488572, + ]); + test_samples(NormalInverseGaussian::new(2.0, 1.0).unwrap(), 0f64, &[ + 0.6838707059642927, + 2.4447306460569784, + 0.2361045023235968, + 1.7774534624785319, + ]); + } } From 868eaa93d4f492ef1a99683202bea7ea509b06a9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Elie=20G=C3=A9nard?= Date: Sat, 21 Mar 2020 12:21:18 +0100 Subject: [PATCH 5/7] Fix error name --- rand_distr/src/normal_inverse_gaussian.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rand_distr/src/normal_inverse_gaussian.rs b/rand_distr/src/normal_inverse_gaussian.rs index 1e2c5886ab5..9c404db581b 100644 --- a/rand_distr/src/normal_inverse_gaussian.rs +++ b/rand_distr/src/normal_inverse_gaussian.rs @@ -7,7 +7,7 @@ pub enum Error { /// `alpha <= 0` or `nan`. AlphaNegativeOrNull, /// `|beta| >= alpha` or `nan`. - AbsoluteBetaLessThanAlpha, + AbsoluteBetaNotLessThanAlpha, } /// The [normal-inverse Gaussian distribution](https://en.wikipedia.org/wiki/Normal-inverse_Gaussian_distribution) From 0205195e72470e5de945b999fae0e23c00f1c4ac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Elie=20G=C3=A9nard?= Date: Sat, 21 Mar 2020 12:25:06 +0100 Subject: [PATCH 6/7] Improve exponentiation performance --- rand_distr/src/normal_inverse_gaussian.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rand_distr/src/normal_inverse_gaussian.rs b/rand_distr/src/normal_inverse_gaussian.rs index 9c404db581b..819f7040349 100644 --- a/rand_distr/src/normal_inverse_gaussian.rs +++ b/rand_distr/src/normal_inverse_gaussian.rs @@ -32,7 +32,7 @@ where StandardNormal: Distribution return Err(Error::AbsoluteBetaLessThanAlpha); } - let gamma = (alpha.powf(N::from(2.)) - beta.powf(N::from(2.))).sqrt(); + let gamma = (alpha * alpha - beta * beta).sqrt(); let mu = N::from(1.) / gamma; From 7983840f72ec57357614bb0f6c096e89987eb74d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Elie=20G=C3=A9nard?= Date: Sat, 21 Mar 2020 16:24:18 +0100 Subject: [PATCH 7/7] Fix error variant name --- rand_distr/src/normal_inverse_gaussian.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rand_distr/src/normal_inverse_gaussian.rs b/rand_distr/src/normal_inverse_gaussian.rs index 819f7040349..fc6e9801217 100644 --- a/rand_distr/src/normal_inverse_gaussian.rs +++ b/rand_distr/src/normal_inverse_gaussian.rs @@ -29,7 +29,7 @@ where StandardNormal: Distribution } if !(beta.abs() < alpha) { - return Err(Error::AbsoluteBetaLessThanAlpha); + return Err(Error::AbsoluteBetaNotLessThanAlpha); } let gamma = (alpha * alpha - beta * beta).sqrt();