Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for Dirichlet distribution #485

Merged
merged 12 commits into from Jun 12, 2018
29 changes: 11 additions & 18 deletions src/distributions/dirichlet.rs
Expand Up @@ -17,7 +17,7 @@ use distributions::gamma::Gamma;
/// The dirichelet distribution `Dirichlet(alpha)`.
///
/// The Dirichlet distribution } is a family of continuous multivariate probability distributions parameterized by
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Stray }

Copying a fancy description from Wikipedia doesn't really explain much, especially since the links are missing. Not that I have a better idea.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like the Mathematica explanation a bit more than Wikipedia's.

/// a vector alpha of positive reals
/// a vector alpha of positive reals. https://en.wikipedia.org/wiki/Dirichlet_distribution
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The naked link will look weird in the docs. I think you can remove, there is not precedence in Rand for linking to Wikipedia.

/// It is a multivariate generalization of the beta distribution.
///
/// # Example
Expand All @@ -26,7 +26,7 @@ use distributions::gamma::Gamma;
/// use rand::prelude::*;
/// use rand::distributions::Dirichlet;
///
/// let dirichlet = Dirichlet::new(&vec![1.0, 2.0, 3.0]);
/// let dirichlet = Dirichlet::new(vec![1.0, 2.0, 3.0]);
/// let samples = dirichlet.sample(&mut rand::thread_rng());
/// println!("{:?} is from a Dirichlet([1.0, 2.0, 3.0]) distribution", samples);
/// ```
Expand All @@ -41,30 +41,23 @@ impl Dirichlet {
/// Construct a new `Dirichlet` with the given alpha parameter
/// `alpha`. Panics if `alpha.len() < 2`.
#[inline]
pub fn new(alpha: &[f64]) -> Dirichlet {
assert!(
alpha.len() > 1,
"Dirichlet::new called with `alpha` with length < 2"
);
for i in 0..alpha.len() {
assert!(
alpha[i] > 0.0,
"Dirichlet::new called with `alpha` <= 0.0"
);
pub fn new<V: Into<Vec<f64>>>(alpha: V) -> Dirichlet {
let a = alpha.into();
assert!(a.len() > 1);
for i in 0..a.len() {
assert!(a[i] > 0.0);
}

Dirichlet {
alpha: alpha.to_vec(),
}
Dirichlet { alpha: a.into() }
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you don't need into() again here — it's already a Vec

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This bit. Not sure why my last comment went somewhere else. a is already your target type, so you don't need .into() again.

}

/// Construct a new `Dirichlet` with the given shape parameter and size
/// `alpha`. Panics if `alpha <= 0.0`.
/// `size` . Panic if `size < 2`
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This won't render well. If you want a list, leave a blank line, the prefix each item with - (it's Markdown). Otherwise just rewrite as two sentences.

#[inline]
pub fn new_with_param(alpha: f64, size: usize) -> Dirichlet {
assert!(alpha > 0.0, "Dirichlet::new called with `alpha` <= 0.0");
assert!(size > 1, "Dirichlet::new called with `size` <= 1");
assert!(alpha > 0.0);
assert!(size > 1);
Dirichlet {
alpha: vec![alpha; size],
}
Expand Down Expand Up @@ -97,7 +90,7 @@ mod test {

#[test]
fn test_dirichlet() {
let d = Dirichlet::new(&vec![1.0, 2.0, 3.0]);
let d = Dirichlet::new(vec![1.0, 2.0, 3.0]);
let mut rng = ::test::rng(221);
let samples = d.sample(&mut rng);
let _: Vec<f64> = samples
Expand Down