From 546caa87a88451afbd93a8cf9b999431e6953807 Mon Sep 17 00:00:00 2001 From: Rohit Joshi Date: Tue, 29 May 2018 22:59:38 -0400 Subject: [PATCH 01/10] adding support for dirichlet distribution --- src/distributions/dirichlet.rs | 138 +++++++++++++++++++++++++++++++++ src/distributions/mod.rs | 5 +- 2 files changed, 142 insertions(+), 1 deletion(-) create mode 100644 src/distributions/dirichlet.rs diff --git a/src/distributions/dirichlet.rs b/src/distributions/dirichlet.rs new file mode 100644 index 00000000000..a0696cb046f --- /dev/null +++ b/src/distributions/dirichlet.rs @@ -0,0 +1,138 @@ +// Copyright 2013 The Rust Project Developers. See the COPYRIGHT +// file at the top-level directory of this distribution and at +// https://rust-lang.org/COPYRIGHT. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! The dirichlet distribution. + +use Rng; +use distributions::{Distribution, gamma::Gamma}; + +/// The dirichelet distribution `Dirichlet(alpha)`. +/// +/// The Dirichlet distribution } is a family of continuous multivariate probability distributions parameterized by +/// a vector alpha of positive reals +/// It is a multivariate generalization of the beta distribution. +/// +/// # Example +/// +/// ``` +/// use rand::prelude::*; +/// use rand::distributions::Dirichlet; +/// +/// let dirichlet = Dirichlet::new(&vec![1.0, 2.0, 3.0]); +/// let samples = dirichlet.sample(&mut rand::thread_rng()); +/// println!("{:?} is from a Dirichlet([1.0, 2.0, 3.0]) distribution", samples); +/// ``` + +#[derive(Clone, Debug)] +pub struct Dirichlet { + /// Concentration parameters (alpha) + alpha: Vec, +} + +impl Dirichlet { + /// Construct a new `Dirichlet` with the given alpha parameter + /// `alpha`. Panics if `alpha.len() < 2`. + #[inline] + pub fn new(alpha: &[f64]) -> Dirichlet { + assert!( + alpha.len() > 1, + "Dirichlet::new called with `alpha` with length < 2" + ); + for i in 0..alpha.len() { + assert!( + alpha[i] > 0.0, + "Dirichlet::new called with `alpha` <= 0.0" + ); + } + + Dirichlet { + alpha: alpha.to_vec(), + } + } + + /// Construct a new `Dirichlet` with the given shape parameter and size + /// `alpha`. Panics if `alpha <= 0.0`. + /// `size` . Panic if `size < 2` + #[inline] + pub fn new_with_param(alpha: f64, size: usize) -> Dirichlet { + assert!(alpha > 0.0, "Dirichlet::new called with `alpha` <= 0.0"); + assert!(size > 1, "Dirichlet::new called with `size` <= 1"); + Dirichlet { + alpha: vec![alpha; size], + } + } +} + +impl Distribution> for Dirichlet { + fn sample(&self, rng: &mut R) -> Vec { + let n = self.alpha.len(); + let mut samples = vec![0.0f64; n]; + let mut sum = 0.0f64; + + for i in 0..n { + let g = Gamma::new(self.alpha[i], 1.0); + samples[i] = g.sample(rng); + sum += samples[i]; + } + let invacc = 1.0 / sum; + for i in 0..n { + samples[i] *= invacc; + } + samples + } +} + +#[cfg(test)] +mod test { + use super::Dirichlet; + use distributions::Distribution; + + #[test] + fn test_dirichlet() { + let d = Dirichlet::new(&vec![1.0, 2.0, 3.0]); + let mut rng = ::test::rng(221); + let samples = d.sample(&mut rng); + let _: Vec = samples + .into_iter() + .map(|x| { + assert!(x > 0.0); + x + }) + .collect(); + } + + #[test] + fn test_dirichlet_with_param() { + let alpha = 0.5f64; + let size = 2; + let d = Dirichlet::new_with_param(alpha, size); + let mut rng = ::test::rng(221); + let samples = d.sample(&mut rng); + let _: Vec = samples + .into_iter() + .map(|x| { + assert!(x > 0.0); + x + }) + .collect(); + } + + #[test] + #[should_panic] + fn test_dirichlet_invalid_length() { + Dirichlet::new_with_param(0.5f64, 1); + } + + #[test] + #[should_panic] + fn test_dirichlet_invalid_alpha() { + Dirichlet::new_with_param(0.0f64, 2); + } +} diff --git a/src/distributions/mod.rs b/src/distributions/mod.rs index e73fec9fd39..5029414c55d 100644 --- a/src/distributions/mod.rs +++ b/src/distributions/mod.rs @@ -95,7 +95,7 @@ //! - [`ChiSquared`] distribution //! - [`StudentT`] distribution //! - [`FisherF`] distribution -//! +//! - Dirichlet distribution //! //! # Examples //! @@ -148,6 +148,7 @@ //! [`Bernoulli`]: struct.Bernoulli.html //! [`Binomial`]: struct.Binomial.html //! [`ChiSquared`]: struct.ChiSquared.html +//! [`Dirichlet`]: struct.Dirichlet.html //! [`Exp`]: struct.Exp.html //! [`Exp1`]: struct.Exp1.html //! [`FisherF`]: struct.FisherF.html @@ -180,6 +181,7 @@ pub use self::uniform::Uniform as Range; #[cfg(feature = "std")] #[doc(inline)] pub use self::binomial::Binomial; #[doc(inline)] pub use self::bernoulli::Bernoulli; +#[doc(inline)] pub use self::dirichlet::Dirichlet; pub mod uniform; #[cfg(feature="std")] @@ -193,6 +195,7 @@ pub mod uniform; #[cfg(feature = "std")] #[doc(hidden)] pub mod binomial; #[doc(hidden)] pub mod bernoulli; +#[doc(hidden)] pub mod dirichlet; mod float; mod integer; From 14462ca122036aa74b90299a4954dd14ac7702e7 Mon Sep 17 00:00:00 2001 From: Rohit Joshi Date: Wed, 30 May 2018 00:07:19 -0400 Subject: [PATCH 02/10] modifying to support stable version --- src/distributions/dirichlet.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/distributions/dirichlet.rs b/src/distributions/dirichlet.rs index a0696cb046f..de7cf2ece37 100644 --- a/src/distributions/dirichlet.rs +++ b/src/distributions/dirichlet.rs @@ -11,7 +11,8 @@ //! The dirichlet distribution. use Rng; -use distributions::{Distribution, gamma::Gamma}; +use distributions::Distribution; +use distributions::gamma::Gamma; /// The dirichelet distribution `Dirichlet(alpha)`. /// From c530db7d647c7138b77596cd6b2a534b1d8d0c82 Mon Sep 17 00:00:00 2001 From: Rohit Joshi Date: Wed, 30 May 2018 00:15:22 -0400 Subject: [PATCH 03/10] added feature std --- src/distributions/mod.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/distributions/mod.rs b/src/distributions/mod.rs index 5029414c55d..bc78748c7e4 100644 --- a/src/distributions/mod.rs +++ b/src/distributions/mod.rs @@ -181,6 +181,7 @@ pub use self::uniform::Uniform as Range; #[cfg(feature = "std")] #[doc(inline)] pub use self::binomial::Binomial; #[doc(inline)] pub use self::bernoulli::Bernoulli; +#[cfg(feature = "std")] #[doc(inline)] pub use self::dirichlet::Dirichlet; pub mod uniform; @@ -195,6 +196,7 @@ pub mod uniform; #[cfg(feature = "std")] #[doc(hidden)] pub mod binomial; #[doc(hidden)] pub mod bernoulli; +#[cfg(feature = "std")] #[doc(hidden)] pub mod dirichlet; mod float; @@ -207,6 +209,7 @@ mod ziggurat_tables; #[cfg(feature="std")] use distributions::float::IntoFloat; + /// Types that can be used to create a random instance of `Support`. #[deprecated(since="0.5.0", note="use Distribution instead")] pub trait Sample { From 57fb716e34aedba3c3225fdffc0939be49434925 Mon Sep 17 00:00:00 2001 From: Rohit Joshi Date: Wed, 30 May 2018 20:24:05 -0400 Subject: [PATCH 04/10] adding a link --- src/distributions/mod.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/distributions/mod.rs b/src/distributions/mod.rs index bc78748c7e4..d7b04fa8f3e 100644 --- a/src/distributions/mod.rs +++ b/src/distributions/mod.rs @@ -96,6 +96,7 @@ //! - [`StudentT`] distribution //! - [`FisherF`] distribution //! - Dirichlet distribution +//! - [`Dirichlet`] distribution //! //! # Examples //! From aad2f4536ee345521fde39758ae78c21e90e8ad3 Mon Sep 17 00:00:00 2001 From: Rohit Joshi Date: Wed, 30 May 2018 20:37:01 -0400 Subject: [PATCH 05/10] review comments --- src/distributions/dirichlet.rs | 29 +++++++++++------------------ 1 file changed, 11 insertions(+), 18 deletions(-) diff --git a/src/distributions/dirichlet.rs b/src/distributions/dirichlet.rs index de7cf2ece37..9bd7846e4e1 100644 --- a/src/distributions/dirichlet.rs +++ b/src/distributions/dirichlet.rs @@ -17,7 +17,7 @@ use distributions::gamma::Gamma; /// The dirichelet distribution `Dirichlet(alpha)`. /// /// The Dirichlet distribution } is a family of continuous multivariate probability distributions parameterized by -/// a vector alpha of positive reals +/// a vector alpha of positive reals. https://en.wikipedia.org/wiki/Dirichlet_distribution /// It is a multivariate generalization of the beta distribution. /// /// # Example @@ -26,7 +26,7 @@ use distributions::gamma::Gamma; /// use rand::prelude::*; /// use rand::distributions::Dirichlet; /// -/// let dirichlet = Dirichlet::new(&vec![1.0, 2.0, 3.0]); +/// let dirichlet = Dirichlet::new(vec![1.0, 2.0, 3.0]); /// let samples = dirichlet.sample(&mut rand::thread_rng()); /// println!("{:?} is from a Dirichlet([1.0, 2.0, 3.0]) distribution", samples); /// ``` @@ -41,21 +41,14 @@ impl Dirichlet { /// Construct a new `Dirichlet` with the given alpha parameter /// `alpha`. Panics if `alpha.len() < 2`. #[inline] - pub fn new(alpha: &[f64]) -> Dirichlet { - assert!( - alpha.len() > 1, - "Dirichlet::new called with `alpha` with length < 2" - ); - for i in 0..alpha.len() { - assert!( - alpha[i] > 0.0, - "Dirichlet::new called with `alpha` <= 0.0" - ); + pub fn new>>(alpha: V) -> Dirichlet { + let a = alpha.into(); + assert!(a.len() > 1); + for i in 0..a.len() { + assert!(a[i] > 0.0); } - Dirichlet { - alpha: alpha.to_vec(), - } + Dirichlet { alpha: a.into() } } /// Construct a new `Dirichlet` with the given shape parameter and size @@ -63,8 +56,8 @@ impl Dirichlet { /// `size` . Panic if `size < 2` #[inline] pub fn new_with_param(alpha: f64, size: usize) -> Dirichlet { - assert!(alpha > 0.0, "Dirichlet::new called with `alpha` <= 0.0"); - assert!(size > 1, "Dirichlet::new called with `size` <= 1"); + assert!(alpha > 0.0); + assert!(size > 1); Dirichlet { alpha: vec![alpha; size], } @@ -97,7 +90,7 @@ mod test { #[test] fn test_dirichlet() { - let d = Dirichlet::new(&vec![1.0, 2.0, 3.0]); + let d = Dirichlet::new(vec![1.0, 2.0, 3.0]); let mut rng = ::test::rng(221); let samples = d.sample(&mut rng); let _: Vec = samples From ad13516c2286e835b20339a8083bce4203c714dc Mon Sep 17 00:00:00 2001 From: Rohit Joshi Date: Wed, 30 May 2018 20:48:57 -0400 Subject: [PATCH 06/10] removing extra empty line --- src/distributions/mod.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/distributions/mod.rs b/src/distributions/mod.rs index f3f5e16fb55..fe4af8d94be 100644 --- a/src/distributions/mod.rs +++ b/src/distributions/mod.rs @@ -213,7 +213,6 @@ mod ziggurat_tables; #[cfg(feature="std")] use distributions::float::IntoFloat; - /// Types that can be used to create a random instance of `Support`. #[deprecated(since="0.5.0", note="use Distribution instead")] pub trait Sample { From d7a5ced48b84e19e2ec918bfd90409ad1a7f9a16 Mon Sep 17 00:00:00 2001 From: Rohit Joshi Date: Wed, 30 May 2018 21:08:16 -0400 Subject: [PATCH 07/10] added std feature --- src/distributions/mod.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/distributions/mod.rs b/src/distributions/mod.rs index fe4af8d94be..490c9214614 100644 --- a/src/distributions/mod.rs +++ b/src/distributions/mod.rs @@ -183,6 +183,7 @@ pub use self::uniform::Uniform as Range; #[cfg(feature = "std")] #[doc(inline)] pub use self::binomial::Binomial; #[doc(inline)] pub use self::bernoulli::Bernoulli; +#[cfg(feature = "std")] #[doc(inline)] pub use self::cauchy::Cauchy; #[cfg(feature = "std")] #[doc(inline)] pub use self::dirichlet::Dirichlet; @@ -199,6 +200,7 @@ pub mod uniform; #[cfg(feature = "std")] #[doc(hidden)] pub mod binomial; #[doc(hidden)] pub mod bernoulli; +#[cfg(feature = "std")] #[doc(hidden)] pub mod cauchy; #[cfg(feature = "std")] #[doc(hidden)] pub mod dirichlet; From d7289890600e96b28888e2f63fed8d51d353b771 Mon Sep 17 00:00:00 2001 From: Rohit Joshi Date: Wed, 30 May 2018 21:22:52 -0400 Subject: [PATCH 08/10] removed stray } --- src/distributions/dirichlet.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/distributions/dirichlet.rs b/src/distributions/dirichlet.rs index 9bd7846e4e1..93da1639473 100644 --- a/src/distributions/dirichlet.rs +++ b/src/distributions/dirichlet.rs @@ -16,7 +16,7 @@ use distributions::gamma::Gamma; /// The dirichelet distribution `Dirichlet(alpha)`. /// -/// The Dirichlet distribution } is a family of continuous multivariate probability distributions parameterized by +/// The Dirichlet distribution is a family of continuous multivariate probability distributions parameterized by /// a vector alpha of positive reals. https://en.wikipedia.org/wiki/Dirichlet_distribution /// It is a multivariate generalization of the beta distribution. /// From 732209a8cffd7b6cb8afa78296035e8ca8026cbf Mon Sep 17 00:00:00 2001 From: Rohit Joshi Date: Thu, 31 May 2018 07:48:35 -0400 Subject: [PATCH 09/10] addressing review comments --- src/distributions/dirichlet.rs | 2 +- src/distributions/mod.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/distributions/dirichlet.rs b/src/distributions/dirichlet.rs index 93da1639473..4560d80f3b1 100644 --- a/src/distributions/dirichlet.rs +++ b/src/distributions/dirichlet.rs @@ -48,7 +48,7 @@ impl Dirichlet { assert!(a[i] > 0.0); } - Dirichlet { alpha: a.into() } + Dirichlet { alpha: a } } /// Construct a new `Dirichlet` with the given shape parameter and size diff --git a/src/distributions/mod.rs b/src/distributions/mod.rs index 490c9214614..6904d7b62b2 100644 --- a/src/distributions/mod.rs +++ b/src/distributions/mod.rs @@ -95,7 +95,7 @@ //! - [`ChiSquared`] distribution //! - [`StudentT`] distribution //! - [`FisherF`] distribution -//! - Dirichlet distribution +//! - Related to continuous multivariate probability distributions //! - [`Dirichlet`] distribution //! //! # Examples From 2f706a6a2db0f4fe2e48b282398e7423c18f5609 Mon Sep 17 00:00:00 2001 From: Rohit Joshi Date: Fri, 1 Jun 2018 17:13:47 -0400 Subject: [PATCH 10/10] doc markup update --- src/distributions/dirichlet.rs | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/src/distributions/dirichlet.rs b/src/distributions/dirichlet.rs index 4560d80f3b1..281bdfb1821 100644 --- a/src/distributions/dirichlet.rs +++ b/src/distributions/dirichlet.rs @@ -38,8 +38,11 @@ pub struct Dirichlet { } impl Dirichlet { - /// Construct a new `Dirichlet` with the given alpha parameter - /// `alpha`. Panics if `alpha.len() < 2`. + /// Construct a new `Dirichlet` with the given alpha parameter `alpha`. + /// + /// # Panics + /// - if `alpha.len() < 2` + /// #[inline] pub fn new>>(alpha: V) -> Dirichlet { let a = alpha.into(); @@ -51,9 +54,12 @@ impl Dirichlet { Dirichlet { alpha: a } } - /// Construct a new `Dirichlet` with the given shape parameter and size - /// `alpha`. Panics if `alpha <= 0.0`. - /// `size` . Panic if `size < 2` + /// Construct a new `Dirichlet` with the given shape parameter `alpha` and `size`. + /// + /// # Panics + /// - if `alpha <= 0.0` + /// - if `size < 2` + /// #[inline] pub fn new_with_param(alpha: f64, size: usize) -> Dirichlet { assert!(alpha > 0.0);