diff --git a/rand_distr/src/inverse_gaussian.rs b/rand_distr/src/inverse_gaussian.rs new file mode 100644 index 00000000000..171aa473eee --- /dev/null +++ b/rand_distr/src/inverse_gaussian.rs @@ -0,0 +1,109 @@ +use crate::{Distribution, Float, Standard, StandardNormal}; +use rand::Rng; + +/// Error type returned from `InverseGaussian::new` +#[derive(Debug, PartialEq)] +pub enum Error { + /// `mean <= 0` or `nan`. + MeanNegativeOrNull, + /// `shape <= 0` or `nan`. + ShapeNegativeOrNull, +} + +/// The [inverse Gaussian distribution](https://en.wikipedia.org/wiki/Inverse_Gaussian_distribution) +#[derive(Debug)] +pub struct InverseGaussian { + mean: N, + shape: N, +} + +impl InverseGaussian +where StandardNormal: Distribution +{ + /// Construct a new `InverseGaussian` distribution with the given mean and + /// shape. + pub fn new(mean: N, shape: N) -> Result, Error> { + if !(mean > N::from(0.0)) { + return Err(Error::MeanNegativeOrNull); + } + + if !(shape > N::from(0.0)) { + return Err(Error::ShapeNegativeOrNull); + } + + Ok(Self { mean, shape }) + } +} + +impl Distribution for InverseGaussian +where + StandardNormal: Distribution, + Standard: Distribution, +{ + fn sample(&self, rng: &mut R) -> N + where R: Rng + ?Sized { + let mu = self.mean; + let l = self.shape; + + let v: N = rng.sample(StandardNormal); + let y = mu * v * v; + + let mu_2l = mu / (N::from(2.) * l); + + let x = mu + mu_2l * (y - (N::from(4.) * l * y + y * y).sqrt()); + + let u: N = rng.gen(); + + if u <= mu / (mu + x) { + return x; + } + + mu * mu / x + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_inverse_gaussian() { + let inv_gauss = InverseGaussian::new(1.0, 1.0).unwrap(); + let mut rng = crate::test::rng(210); + for _ in 0..1000 { + inv_gauss.sample(&mut rng); + } + } + + #[test] + fn test_inverse_gaussian_invalid_param() { + assert!(InverseGaussian::new(-1.0, 1.0).is_err()); + assert!(InverseGaussian::new(-1.0, -1.0).is_err()); + assert!(InverseGaussian::new(1.0, -1.0).is_err()); + assert!(InverseGaussian::new(1.0, 1.0).is_ok()); + } + + #[test] + fn value_stability() { + fn test_samples>( + distr: D, zero: N, expected: &[N], + ) { + let mut rng = crate::test::rng(213); + let mut buf = [zero; 4]; + for x in &mut buf { + *x = rng.sample(&distr); + } + assert_eq!(buf, expected); + } + + test_samples(InverseGaussian::new(1.0, 3.0).unwrap(), 0f32, &[ + 0.9339157, 1.108113, 0.50864697, 0.39849377, + ]); + test_samples(InverseGaussian::new(1.0, 3.0).unwrap(), 0f64, &[ + 1.0707604954722476, + 0.9628140605340697, + 0.4069687656468226, + 0.660283852985818, + ]); + } +} diff --git a/rand_distr/src/lib.rs b/rand_distr/src/lib.rs index a41c579cbcc..ebc14402771 100644 --- a/rand_distr/src/lib.rs +++ b/rand_distr/src/lib.rs @@ -66,6 +66,9 @@ //! - [`UnitBall`] distribution //! - [`UnitCircle`] distribution //! - [`UnitDisc`] distribution +//! - Misc. distributions +//! - [`InverseGaussian`] distribution +//! - [`NormalInverseGaussian`] distribution pub use rand::distributions::{ uniform, Alphanumeric, Bernoulli, BernoulliError, DistIter, Distribution, Open01, OpenClosed01, @@ -80,7 +83,9 @@ pub use self::gamma::{ Beta, BetaError, ChiSquared, ChiSquaredError, Error as GammaError, FisherF, FisherFError, Gamma, StudentT, }; +pub use self::inverse_gaussian::{InverseGaussian, Error as InverseGaussianError}; pub use self::normal::{Error as NormalError, LogNormal, Normal, StandardNormal}; +pub use self::normal_inverse_gaussian::{NormalInverseGaussian, Error as NormalInverseGaussianError}; pub use self::pareto::{Error as ParetoError, Pareto}; pub use self::pert::{Pert, PertError}; pub use self::poisson::{Error as PoissonError, Poisson}; @@ -100,7 +105,9 @@ mod cauchy; mod dirichlet; mod exponential; mod gamma; +mod inverse_gaussian; mod normal; +mod normal_inverse_gaussian; mod pareto; mod pert; mod poisson; diff --git a/rand_distr/src/normal_inverse_gaussian.rs b/rand_distr/src/normal_inverse_gaussian.rs new file mode 100644 index 00000000000..fc6e9801217 --- /dev/null +++ b/rand_distr/src/normal_inverse_gaussian.rs @@ -0,0 +1,107 @@ +use crate::{Distribution, Float, InverseGaussian, Standard, StandardNormal}; +use rand::Rng; + +/// Error type returned from `NormalInverseGaussian::new` +#[derive(Debug, PartialEq)] +pub enum Error { + /// `alpha <= 0` or `nan`. + AlphaNegativeOrNull, + /// `|beta| >= alpha` or `nan`. + AbsoluteBetaNotLessThanAlpha, +} + +/// The [normal-inverse Gaussian distribution](https://en.wikipedia.org/wiki/Normal-inverse_Gaussian_distribution) +#[derive(Debug)] +pub struct NormalInverseGaussian { + alpha: N, + beta: N, + inverse_gaussian: InverseGaussian, +} + +impl NormalInverseGaussian +where StandardNormal: Distribution +{ + /// Construct a new `NormalInverseGaussian` distribution with the given alpha (tail heaviness) and + /// beta (asymmetry) parameters. + pub fn new(alpha: N, beta: N) -> Result, Error> { + if !(alpha > N::from(0.0)) { + return Err(Error::AlphaNegativeOrNull); + } + + if !(beta.abs() < alpha) { + return Err(Error::AbsoluteBetaNotLessThanAlpha); + } + + let gamma = (alpha * alpha - beta * beta).sqrt(); + + let mu = N::from(1.) / gamma; + + let inverse_gaussian = InverseGaussian::new(mu, N::from(1.)).unwrap(); + + Ok(Self { + alpha, + beta, + inverse_gaussian, + }) + } +} + +impl Distribution for NormalInverseGaussian +where + StandardNormal: Distribution, + Standard: Distribution, +{ + fn sample(&self, rng: &mut R) -> N + where R: Rng + ?Sized { + let inv_gauss = rng.sample(&self.inverse_gaussian); + + self.beta * inv_gauss + inv_gauss.sqrt() * rng.sample(StandardNormal) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_normal_inverse_gaussian() { + let norm_inv_gauss = NormalInverseGaussian::new(2.0, 1.0).unwrap(); + let mut rng = crate::test::rng(210); + for _ in 0..1000 { + norm_inv_gauss.sample(&mut rng); + } + } + + #[test] + fn test_normal_inverse_gaussian_invalid_param() { + assert!(NormalInverseGaussian::new(-1.0, 1.0).is_err()); + assert!(NormalInverseGaussian::new(-1.0, -1.0).is_err()); + assert!(NormalInverseGaussian::new(1.0, 2.0).is_err()); + assert!(NormalInverseGaussian::new(2.0, 1.0).is_ok()); + } + + + #[test] + fn value_stability() { + fn test_samples>( + distr: D, zero: N, expected: &[N], + ) { + let mut rng = crate::test::rng(213); + let mut buf = [zero; 4]; + for x in &mut buf { + *x = rng.sample(&distr); + } + assert_eq!(buf, expected); + } + + test_samples(NormalInverseGaussian::new(2.0, 1.0).unwrap(), 0f32, &[ + 0.6568966, 1.3744819, 2.216063, 0.11488572, + ]); + test_samples(NormalInverseGaussian::new(2.0, 1.0).unwrap(), 0f64, &[ + 0.6838707059642927, + 2.4447306460569784, + 0.2361045023235968, + 1.7774534624785319, + ]); + } +}