Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for Dirichlet distribution #485

Merged
merged 12 commits into from Jun 12, 2018
132 changes: 132 additions & 0 deletions src/distributions/dirichlet.rs
@@ -0,0 +1,132 @@
// Copyright 2013 The Rust Project Developers. See the COPYRIGHT
// file at the top-level directory of this distribution and at
// https://rust-lang.org/COPYRIGHT.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

//! The dirichlet distribution.

use Rng;
use distributions::Distribution;
use distributions::gamma::Gamma;

/// The dirichelet distribution `Dirichlet(alpha)`.
///
/// The Dirichlet distribution is a family of continuous multivariate probability distributions parameterized by
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This line is a bit long. I think we usually wrap comments at 80 characters.

/// a vector alpha of positive reals. https://en.wikipedia.org/wiki/Dirichlet_distribution
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The naked link will look weird in the docs. I think you can remove, there is not precedence in Rand for linking to Wikipedia.

/// It is a multivariate generalization of the beta distribution.
///
/// # Example
///
/// ```
/// use rand::prelude::*;
/// use rand::distributions::Dirichlet;
///
/// let dirichlet = Dirichlet::new(vec![1.0, 2.0, 3.0]);
/// let samples = dirichlet.sample(&mut rand::thread_rng());
/// println!("{:?} is from a Dirichlet([1.0, 2.0, 3.0]) distribution", samples);
/// ```

#[derive(Clone, Debug)]
pub struct Dirichlet {
/// Concentration parameters (alpha)
alpha: Vec<f64>,
Copy link
Member

@dhardy dhardy Jun 5, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thinking about it, we could probably use alpha: [f64] here. It makes the type "unsized" (i.e. users have to write Box<Dirichlet>) but is more flexible (potentially more optimal).

On the other hand it may not be worth it since it makes the type less ergonomic to use for what is probably not a lot of gain.

Another option would be Dirichlet<N: usize> { alpha: [f64; N] } — except I don't think Rust supports that yet (though it would also allow sample(..) -> [f64; N], thus side-stepping @vks's concerns).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do our distributions even work when you write Box<Dirichlet>?

Copy link
Member

@dhardy dhardy Jun 7, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rng::sample won't of course but Distribution::sample will. Either way it's not really a great choice (less convenient for users).

}

impl Dirichlet {
/// Construct a new `Dirichlet` with the given alpha parameter
/// `alpha`. Panics if `alpha.len() < 2`.
#[inline]
pub fn new<V: Into<Vec<f64>>>(alpha: V) -> Dirichlet {
let a = alpha.into();
assert!(a.len() > 1);
for i in 0..a.len() {
assert!(a[i] > 0.0);
}

Dirichlet { alpha: a.into() }
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you don't need into() again here — it's already a Vec

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This bit. Not sure why my last comment went somewhere else. a is already your target type, so you don't need .into() again.

}

/// Construct a new `Dirichlet` with the given shape parameter and size
/// `alpha`. Panics if `alpha <= 0.0`.
/// `size` . Panic if `size < 2`
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This won't render well. If you want a list, leave a blank line, the prefix each item with - (it's Markdown). Otherwise just rewrite as two sentences.

#[inline]
pub fn new_with_param(alpha: f64, size: usize) -> Dirichlet {
assert!(alpha > 0.0);
assert!(size > 1);
Dirichlet {
alpha: vec![alpha; size],
}
}
}

impl Distribution<Vec<f64>> for Dirichlet {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure our current distribution trait is well suited for multivariate distributions. It would be nice to sample without allocating, but this requires different method. Something like fn sample_multi(&self, &mut Rng, &mut [f64]).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is why you opened #496. I agree. On the other hand, I'm not too fussed about having to make breaking changes to this distribution later (it's still better for users than not having it, and we're not close to 1.0).

fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Vec<f64> {
let n = self.alpha.len();
let mut samples = vec![0.0f64; n];
let mut sum = 0.0f64;

for i in 0..n {
let g = Gamma::new(self.alpha[i], 1.0);
samples[i] = g.sample(rng);
sum += samples[i];
}
let invacc = 1.0 / sum;
for i in 0..n {
samples[i] *= invacc;
}
samples
}
}

#[cfg(test)]
mod test {
use super::Dirichlet;
use distributions::Distribution;

#[test]
fn test_dirichlet() {
let d = Dirichlet::new(vec![1.0, 2.0, 3.0]);
let mut rng = ::test::rng(221);
let samples = d.sample(&mut rng);
let _: Vec<f64> = samples
.into_iter()
.map(|x| {
assert!(x > 0.0);
x
})
.collect();
}

#[test]
fn test_dirichlet_with_param() {
let alpha = 0.5f64;
let size = 2;
let d = Dirichlet::new_with_param(alpha, size);
let mut rng = ::test::rng(221);
let samples = d.sample(&mut rng);
let _: Vec<f64> = samples
.into_iter()
.map(|x| {
assert!(x > 0.0);
x
})
.collect();
}

#[test]
#[should_panic]
fn test_dirichlet_invalid_length() {
Dirichlet::new_with_param(0.5f64, 1);
}

#[test]
#[should_panic]
fn test_dirichlet_invalid_alpha() {
Dirichlet::new_with_param(0.0f64, 2);
}
}
9 changes: 7 additions & 2 deletions src/distributions/mod.rs
Expand Up @@ -81,7 +81,6 @@
//! - Related to real-valued quantities that grow linearly
//! (e.g. errors, offsets):
//! - [`Normal`] distribution, and [`StandardNormal`] as a primitive
//! - [`Cauchy`] distribution
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@rohitjoshi did you intend to move the reference to Cauchy in this documentation? Because you haven't added it back.

//! - Related to Bernoulli trials (yes/no events, with a given probability):
//! - [`Binomial`] distribution
//! - [`Bernoulli`] distribution, similar to [`Rng::gen_bool`].
Expand All @@ -96,7 +95,8 @@
//! - [`ChiSquared`] distribution
//! - [`StudentT`] distribution
//! - [`FisherF`] distribution
//!
//! - Dirichlet distribution
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The [`Dirichlet`] link is missing.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's been added now, but Dirichlet seems to be mentioned twice in a row

//! - [`Dirichlet`] distribution
//!
//! # Examples
//!
Expand Down Expand Up @@ -150,6 +150,7 @@
//! [`Binomial`]: struct.Binomial.html
//! [`Cauchy`]: struct.Cauchy.html
//! [`ChiSquared`]: struct.ChiSquared.html
//! [`Dirichlet`]: struct.Dirichlet.html
//! [`Exp`]: struct.Exp.html
//! [`Exp1`]: struct.Exp1.html
//! [`FisherF`]: struct.FisherF.html
Expand Down Expand Up @@ -184,6 +185,8 @@ pub use self::uniform::Uniform as Range;
#[doc(inline)] pub use self::bernoulli::Bernoulli;
#[cfg(feature = "std")]
#[doc(inline)] pub use self::cauchy::Cauchy;
#[cfg(feature = "std")]
#[doc(inline)] pub use self::dirichlet::Dirichlet;

pub mod uniform;
#[cfg(feature="std")]
Expand All @@ -199,6 +202,8 @@ pub mod uniform;
#[doc(hidden)] pub mod bernoulli;
#[cfg(feature = "std")]
#[doc(hidden)] pub mod cauchy;
#[cfg(feature = "std")]
#[doc(hidden)] pub mod dirichlet;

mod float;
mod integer;
Expand Down