From 1dec41cc00058a107e912908065e9c0ff397ff38 Mon Sep 17 00:00:00 2001 From: Theodore Lee Date: Tue, 31 Dec 2019 11:34:28 +0800 Subject: [PATCH 01/24] feat: multivariate_normal distribution --- Cargo.toml | 5 +- src/distribution/mod.rs | 2 + src/distribution/multivariate_normal.rs | 170 ++++++++++++++++++++++++ src/statistics/traits.rs | 4 + 4 files changed, 180 insertions(+), 1 deletion(-) create mode 100644 src/distribution/multivariate_normal.rs diff --git a/Cargo.toml b/Cargo.toml index 96a52797..be65717b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,4 +18,7 @@ name = "statrs" path = "src/lib.rs" [dependencies] -rand = "0.7" +rand = "0.6.5" +nalgebra = "0.18.1" +nalgebra-mvn = "0.1.1" +num-traits = "0.2.10" \ No newline at end of file diff --git a/src/distribution/mod.rs b/src/distribution/mod.rs index 8eb7acc9..0c30279e 100644 --- a/src/distribution/mod.rs +++ b/src/distribution/mod.rs @@ -20,6 +20,7 @@ pub use self::hypergeometric::Hypergeometric; pub use self::inverse_gamma::InverseGamma; pub use self::log_normal::LogNormal; pub use self::multinomial::Multinomial; +pub use self::multivariate_normal::MultivariateNormal; pub use self::normal::Normal; pub use self::pareto::Pareto; pub use self::poisson::Poisson; @@ -48,6 +49,7 @@ mod internal; mod inverse_gamma; mod log_normal; mod multinomial; +mod multivariate_normal; mod normal; mod pareto; mod poisson; diff --git a/src/distribution/multivariate_normal.rs b/src/distribution/multivariate_normal.rs new file mode 100644 index 00000000..77cdeeee --- /dev/null +++ b/src/distribution/multivariate_normal.rs @@ -0,0 +1,170 @@ +use crate::distribution::Continuous; +use crate::distribution::Normal; +use crate::statistics::{Covariance, Entropy, Max, Mean, Min, Mode}; +use crate::{Result, StatsError}; +use nalgebra::{ + base::allocator::Allocator, + base::{dimension::DimName, dimension::DimSub, MatrixN, VectorN}, + Cholesky, DefaultAllocator, Dim, DimMin, Dynamic, RealField, LU, U1, +}; +use num_traits::bounds::Bounded; +use rand::distributions::Distribution; +use rand::Rng; + +pub struct MultivariateNormal +where + Real: RealField, + N: Dim + DimMin, + DefaultAllocator: Allocator, + DefaultAllocator: Allocator, + DefaultAllocator: Allocator, + DefaultAllocator: Allocator<(usize, usize), >::Output>, +{ + mvn: nalgebra_mvn::MultivariateNormal, + cov_chol_decomp: MatrixN, +} + +impl MultivariateNormal +where + Real: RealField, + N: Dim + DimMin + DimSub, + DefaultAllocator: Allocator, + DefaultAllocator: Allocator, + DefaultAllocator: Allocator, + DefaultAllocator: Allocator<(usize, usize), >::Output>, +{ + pub fn new(mean: &VectorN, cov: &MatrixN) -> Result { + match nalgebra_mvn::MultivariateNormal::from_mean_and_covariance(&mean, &cov.clone()) { + Ok(mvn) => { + // Store the Cholesky decomposition of the covariance matrix + // for sampling + let cholesky_decomp = Cholesky::new(cov.clone()).unwrap().unpack(); + Ok(MultivariateNormal { + mvn: mvn, + cov_chol_decomp: cholesky_decomp, + }) + } + Err(_) => Err(StatsError::BadParams), + } + } +} + +impl Distribution> for MultivariateNormal +where + N: Dim + DimMin + DimName, + DefaultAllocator: Allocator, + DefaultAllocator: Allocator, + DefaultAllocator: Allocator, + DefaultAllocator: Allocator<(usize, usize), >::Output>, +{ + fn sample(&self, rng: &mut R) -> VectorN { + let d = Normal::new(0., 1.).unwrap(); + let z = VectorN::::from_distribution(&d, rng); + (self.cov_chol_decomp.clone() * z) + self.mvn.mean() + } +} + +impl Min> for MultivariateNormal +where + Real: RealField, + N: Dim + DimMin + DimName, + DefaultAllocator: Allocator, + DefaultAllocator: Allocator, + DefaultAllocator: Allocator, + DefaultAllocator: Allocator<(usize, usize), >::Output>, +{ + fn min(&self) -> VectorN { + VectorN::min_value() + } +} + +impl Max> for MultivariateNormal +where + Real: RealField, + N: Dim + DimMin + DimName, + DefaultAllocator: Allocator, + DefaultAllocator: Allocator, + DefaultAllocator: Allocator, + DefaultAllocator: Allocator<(usize, usize), >::Output>, +{ + fn max(&self) -> VectorN { + VectorN::max_value() + } +} + +impl Mean> for MultivariateNormal +where + Real: RealField, + N: Dim + DimMin + DimName, + DefaultAllocator: Allocator, + DefaultAllocator: Allocator, + DefaultAllocator: Allocator, + DefaultAllocator: Allocator<(usize, usize), >::Output>, +{ + fn mean(&self) -> VectorN { + self.mvn.mean() + } +} + +impl Covariance> for MultivariateNormal +where + Real: RealField, + N: Dim + DimMin + DimName, + DefaultAllocator: Allocator, + DefaultAllocator: Allocator, + DefaultAllocator: Allocator, + DefaultAllocator: Allocator<(usize, usize), >::Output>, +{ + fn variance(&self) -> MatrixN { + Cholesky::new(self.mvn.precision().clone()) + .unwrap() + .inverse() + } +} + +impl Entropy for MultivariateNormal +where + Real: RealField, + N: Dim + DimMin + DimName, + DefaultAllocator: Allocator, + DefaultAllocator: Allocator, + DefaultAllocator: Allocator, + DefaultAllocator: Allocator<(usize, usize), >::Output>, +{ + fn entropy(&self) -> Real { + LU::new(self.variance().clone().scale(Real::two_pi() * Real::e())) + .determinant() + .ln() + } +} + +impl Mode> for MultivariateNormal +where + Real: RealField, + N: Dim + DimMin + DimName, + DefaultAllocator: Allocator, + DefaultAllocator: Allocator, + DefaultAllocator: Allocator, + DefaultAllocator: Allocator<(usize, usize), >::Output>, +{ + fn mode(&self) -> VectorN { + self.mvn.mean() + } +} + +impl Continuous, Real> for MultivariateNormal +where + Real: RealField, + N: Dim + DimMin + DimName, + DefaultAllocator: Allocator, + DefaultAllocator: Allocator, + DefaultAllocator: Allocator, + DefaultAllocator: Allocator<(usize, usize), >::Output>, +{ + fn pdf(&self, x: VectorN) -> Real { + *self.mvn.pdf::(&x.transpose()).get((0, 0)).unwrap() + } + fn ln_pdf(&self, x: VectorN) -> Real { + *self.mvn.logpdf::(&x.transpose()).get((0, 0)).unwrap() + } +} diff --git a/src/statistics/traits.rs b/src/statistics/traits.rs index 4e3b3adf..f068548f 100644 --- a/src/statistics/traits.rs +++ b/src/statistics/traits.rs @@ -104,6 +104,10 @@ pub trait Variance: Mean { fn std_dev(&self) -> T; } +pub trait Covariance { + fn variance(&self) -> T; +} + pub trait CheckedVariance: CheckedMean { /// Returns the variance. /// # Examples From aa2da3b8111d30dc715078a32f9bc7853244875c Mon Sep 17 00:00:00 2001 From: Theodore Lee Date: Wed, 1 Jan 2020 02:15:01 +0800 Subject: [PATCH 02/24] doc: add documentation to distribution::multivariate_normal --- src/distribution/multivariate_normal.rs | 83 +++++++++++++++++++++++++ 1 file changed, 83 insertions(+) diff --git a/src/distribution/multivariate_normal.rs b/src/distribution/multivariate_normal.rs index 77cdeeee..49adac66 100644 --- a/src/distribution/multivariate_normal.rs +++ b/src/distribution/multivariate_normal.rs @@ -11,6 +11,23 @@ use num_traits::bounds::Bounded; use rand::distributions::Distribution; use rand::Rng; +/// Implements the [Multivariate Normal](https://en.wikipedia.org/wiki/Multivariate_normal_distribution) +/// distribution using the "nalgebra" crate for matrix operations +/// +/// # Examples +/// +/// ``` +/// use statrs::distribution::{MultivariateNormal, Continuous}; +/// use nalgebra::base::dimension::U2; +/// use nalgebra::{Vector2, Matrix2}; +/// use statrs::statistics::{Mean, Covariance}; +/// +/// let mvn = MultivariateNormal::::new(&Vector2::::zeros(), &Matrix2::::identity()).unwrap(); +/// assert_eq!(mvn.mean(), Vector2::::new(0., 0.)); +/// assert_eq!(mvn.variance(), Matrix2::::new(1., 0., 0., 1.)); +/// assert_eq!(mvn.pdf(Vector2::::new(1., 1.)), 0.05854983152431917); +/// ``` +#[derive(Debug, Clone)] pub struct MultivariateNormal where Real: RealField, @@ -33,6 +50,13 @@ where DefaultAllocator: Allocator, DefaultAllocator: Allocator<(usize, usize), >::Output>, { + /// Constructs a new multivariate normal distribution with a mean of `mean` + /// and covariance matrix `cov` + /// + /// # Errors + /// + /// Returns an error if the given covariance matrix is not + /// symmetric or positive-definite pub fn new(mean: &VectorN, cov: &MatrixN) -> Result { match nalgebra_mvn::MultivariateNormal::from_mean_and_covariance(&mean, &cov.clone()) { Ok(mvn) => { @@ -57,6 +81,15 @@ where DefaultAllocator: Allocator, DefaultAllocator: Allocator<(usize, usize), >::Output>, { + /// Samples from the multivariate normal distribution + /// + /// # Formula + /// L * Z + μ + /// + /// where `L` is the Cholesky decomposition of the covariance matrix, + /// `Z` is a vector of normally distributed random variables, and + /// `μ` is the mean vector + fn sample(&self, rng: &mut R) -> VectorN { let d = Normal::new(0., 1.).unwrap(); let z = VectorN::::from_distribution(&d, rng); @@ -73,6 +106,8 @@ where DefaultAllocator: Allocator, DefaultAllocator: Allocator<(usize, usize), >::Output>, { + /// Returns the minimum value in the domain of the + /// multivariate normal distribution represented by a real vector fn min(&self) -> VectorN { VectorN::min_value() } @@ -87,6 +122,8 @@ where DefaultAllocator: Allocator, DefaultAllocator: Allocator<(usize, usize), >::Output>, { + /// Returns the maximum value in the domain of the + /// multivariate normal distribution represented by a real vector fn max(&self) -> VectorN { VectorN::max_value() } @@ -101,6 +138,11 @@ where DefaultAllocator: Allocator, DefaultAllocator: Allocator<(usize, usize), >::Output>, { + /// Returns the mean of the normal distribution + /// + /// # Remarks + /// + /// This is the same mean used to construct the distribution fn mean(&self) -> VectorN { self.mvn.mean() } @@ -115,6 +157,7 @@ where DefaultAllocator: Allocator, DefaultAllocator: Allocator<(usize, usize), >::Output>, { + /// Returns the covariance matrix of the multivariate normal distribution fn variance(&self) -> MatrixN { Cholesky::new(self.mvn.precision().clone()) .unwrap() @@ -131,6 +174,15 @@ where DefaultAllocator: Allocator, DefaultAllocator: Allocator<(usize, usize), >::Output>, { + /// Returns the entropy of the multivariate normal distribution + /// + /// # Formula + /// + /// ```ignore + /// (1 / 2) * ln(det(2 * π * e * Σ)) + /// ``` + /// + /// where `Σ` is the covariance matrix and `det` is the determinant fn entropy(&self) -> Real { LU::new(self.variance().clone().scale(Real::two_pi() * Real::e())) .determinant() @@ -147,6 +199,15 @@ where DefaultAllocator: Allocator, DefaultAllocator: Allocator<(usize, usize), >::Output>, { + /// Returns the mode of the multivariate normal distribution + /// + /// # Formula + /// + /// ```ignore + /// μ + /// ``` + /// + /// where `μ` is the mean fn mode(&self) -> VectorN { self.mvn.mean() } @@ -161,9 +222,31 @@ where DefaultAllocator: Allocator, DefaultAllocator: Allocator<(usize, usize), >::Output>, { + /// Calculates the probability density function for the multivariate + /// normal distribution at `x` + /// + /// # Formula + /// + /// ```ignore + /// (2 * π) ^ (-k / 2) * det(Σ) ^ (1 / 2) * e ^ ( -(1 / 2) * transpose(x - μ) * inv(Σ) * (x - μ)) + /// ``` + /// + /// where `μ` is the mean, `inv(Σ)` is the precision matrix, `det(Σ)` is the determinant + /// of the covariance matrix, and `k` is the dimension of the distribution fn pdf(&self, x: VectorN) -> Real { *self.mvn.pdf::(&x.transpose()).get((0, 0)).unwrap() } + /// Calculates the log probability density function for the multivariate + /// normal distribution at `x` + /// + /// # Formula + /// + /// ```ignore + /// ln((2 * π) ^ (-k / 2) * det(Σ) ^ (1 / 2) * e ^ ( -(1 / 2) * transpose(x - μ) * inv(Σ) * (x - μ))) + /// ``` + /// + /// where `μ` is the mean, `inv(Σ)` is the precision matrix, `det(Σ)` is the determinant + /// of the covariance matrix, and `k` is the dimension of the distribution fn ln_pdf(&self, x: VectorN) -> Real { *self.mvn.logpdf::(&x.transpose()).get((0, 0)).unwrap() } From 36a943e7aae68df2f9623bd1a432197f1ba5ff39 Mon Sep 17 00:00:00 2001 From: Theodore Lee Date: Wed, 1 Jan 2020 02:15:59 +0800 Subject: [PATCH 03/24] fix: add checks to constructor for covariance matrix --- src/distribution/multivariate_normal.rs | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/src/distribution/multivariate_normal.rs b/src/distribution/multivariate_normal.rs index 49adac66..5656d6a2 100644 --- a/src/distribution/multivariate_normal.rs +++ b/src/distribution/multivariate_normal.rs @@ -58,15 +58,22 @@ where /// Returns an error if the given covariance matrix is not /// symmetric or positive-definite pub fn new(mean: &VectorN, cov: &MatrixN) -> Result { + // Check that the provided covariance matrix is symmetric + if (cov.lower_triangle() != cov.upper_triangle().transpose()) { return Err(StatsError::BadParams); } match nalgebra_mvn::MultivariateNormal::from_mean_and_covariance(&mean, &cov.clone()) { Ok(mvn) => { // Store the Cholesky decomposition of the covariance matrix // for sampling - let cholesky_decomp = Cholesky::new(cov.clone()).unwrap().unpack(); - Ok(MultivariateNormal { - mvn: mvn, - cov_chol_decomp: cholesky_decomp, - }) + match Cholesky::new(cov.clone()) { + None => Err(StatsError::BadParams), + Some(cholesky_decomp) => { + Ok(MultivariateNormal { + mvn: mvn, + cov_chol_decomp: cholesky_decomp.unpack(), + }) + }, + } + } Err(_) => Err(StatsError::BadParams), } From a6f4c5131d9a665ae60d5d62d2ede4c1ab228591 Mon Sep 17 00:00:00 2001 From: Theodore Lee Date: Thu, 2 Jan 2020 10:55:39 +0800 Subject: [PATCH 04/24] fix: Build with nalgebra-mvn with dep nalgebra updated to 0.19.0, update other dependencies --- Cargo.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index be65717b..2e3834eb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,7 +18,7 @@ name = "statrs" path = "src/lib.rs" [dependencies] -rand = "0.6.5" -nalgebra = "0.18.1" -nalgebra-mvn = "0.1.1" +rand = "0.7.2" +nalgebra = "0.19.0" +nalgebra-mvn = { path = "../nalgebra-mvn" } num-traits = "0.2.10" \ No newline at end of file From 3009cdc31f0afd3086571bf0634b5a2e3c39af9f Mon Sep 17 00:00:00 2001 From: Theodore Lee Date: Thu, 2 Jan 2020 10:55:59 +0800 Subject: [PATCH 05/24] fix: Make trait bounds consistent --- src/distribution/multivariate_normal.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/distribution/multivariate_normal.rs b/src/distribution/multivariate_normal.rs index 5656d6a2..da0bc4cb 100644 --- a/src/distribution/multivariate_normal.rs +++ b/src/distribution/multivariate_normal.rs @@ -31,7 +31,7 @@ use rand::Rng; pub struct MultivariateNormal where Real: RealField, - N: Dim + DimMin, + N: Dim + DimMin + DimName, DefaultAllocator: Allocator, DefaultAllocator: Allocator, DefaultAllocator: Allocator, @@ -44,7 +44,7 @@ where impl MultivariateNormal where Real: RealField, - N: Dim + DimMin + DimSub, + N: Dim + DimMin + DimName, DefaultAllocator: Allocator, DefaultAllocator: Allocator, DefaultAllocator: Allocator, From e8b5a66ccaa513d6acb58f4f758effcb879bb2d9 Mon Sep 17 00:00:00 2001 From: Theodore Lee Date: Thu, 2 Jan 2020 10:58:06 +0800 Subject: [PATCH 06/24] fix: Make dependencies less specific --- Cargo.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 2e3834eb..bb003edd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,7 +18,7 @@ name = "statrs" path = "src/lib.rs" [dependencies] -rand = "0.7.2" -nalgebra = "0.19.0" +rand = "0.7" +nalgebra = "0.19" nalgebra-mvn = { path = "../nalgebra-mvn" } num-traits = "0.2.10" \ No newline at end of file From d9044b86b1df1341ef929cbc45fe1c1dfda403a2 Mon Sep 17 00:00:00 2001 From: Theodore Lee Date: Fri, 3 Jan 2020 10:51:54 +0800 Subject: [PATCH 07/24] fix: Update nalgebra-mvn to v0.2 --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index bb003edd..7fe8d051 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,5 +20,5 @@ path = "src/lib.rs" [dependencies] rand = "0.7" nalgebra = "0.19" -nalgebra-mvn = { path = "../nalgebra-mvn" } +nalgebra-mvn = "0.2" num-traits = "0.2.10" \ No newline at end of file From 989df358bd65244dc908cd41910530a790affeb4 Mon Sep 17 00:00:00 2001 From: Theodore Lee Date: Fri, 3 Jan 2020 10:59:58 +0800 Subject: [PATCH 08/24] fix: Run cargo fmt --- src/distribution/multivariate_normal.rs | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/src/distribution/multivariate_normal.rs b/src/distribution/multivariate_normal.rs index da0bc4cb..0df38e1a 100644 --- a/src/distribution/multivariate_normal.rs +++ b/src/distribution/multivariate_normal.rs @@ -59,21 +59,20 @@ where /// symmetric or positive-definite pub fn new(mean: &VectorN, cov: &MatrixN) -> Result { // Check that the provided covariance matrix is symmetric - if (cov.lower_triangle() != cov.upper_triangle().transpose()) { return Err(StatsError::BadParams); } + if (cov.lower_triangle() != cov.upper_triangle().transpose()) { + return Err(StatsError::BadParams); + } match nalgebra_mvn::MultivariateNormal::from_mean_and_covariance(&mean, &cov.clone()) { Ok(mvn) => { // Store the Cholesky decomposition of the covariance matrix // for sampling match Cholesky::new(cov.clone()) { - None => Err(StatsError::BadParams), - Some(cholesky_decomp) => { - Ok(MultivariateNormal { - mvn: mvn, - cov_chol_decomp: cholesky_decomp.unpack(), - }) - }, + None => Err(StatsError::BadParams), + Some(cholesky_decomp) => Ok(MultivariateNormal { + mvn: mvn, + cov_chol_decomp: cholesky_decomp.unpack(), + }), } - } Err(_) => Err(StatsError::BadParams), } From d4d0b7c76ce8792fd1b9fda8670b066126f29d30 Mon Sep 17 00:00:00 2001 From: Theodore Lee Date: Fri, 3 Jan 2020 13:21:45 +0800 Subject: [PATCH 09/24] fix: Remove dependence on nalgebra-mvn --- Cargo.toml | 1 - src/distribution/multivariate_normal.rs | 60 ++++++++++++++++--------- 2 files changed, 38 insertions(+), 23 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 7fe8d051..db576a46 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,5 +20,4 @@ path = "src/lib.rs" [dependencies] rand = "0.7" nalgebra = "0.19" -nalgebra-mvn = "0.2" num-traits = "0.2.10" \ No newline at end of file diff --git a/src/distribution/multivariate_normal.rs b/src/distribution/multivariate_normal.rs index 0df38e1a..4d681a4a 100644 --- a/src/distribution/multivariate_normal.rs +++ b/src/distribution/multivariate_normal.rs @@ -8,6 +8,7 @@ use nalgebra::{ Cholesky, DefaultAllocator, Dim, DimMin, Dynamic, RealField, LU, U1, }; use num_traits::bounds::Bounded; +use num_traits::real::Real; use rand::distributions::Distribution; use rand::Rng; @@ -37,8 +38,11 @@ where DefaultAllocator: Allocator, DefaultAllocator: Allocator<(usize, usize), >::Output>, { - mvn: nalgebra_mvn::MultivariateNormal, cov_chol_decomp: MatrixN, + mu: VectorN, + cov: MatrixN, + precision: MatrixN, + pdf_const: Real, } impl MultivariateNormal @@ -62,19 +66,23 @@ where if (cov.lower_triangle() != cov.upper_triangle().transpose()) { return Err(StatsError::BadParams); } - match nalgebra_mvn::MultivariateNormal::from_mean_and_covariance(&mean, &cov.clone()) { - Ok(mvn) => { - // Store the Cholesky decomposition of the covariance matrix - // for sampling - match Cholesky::new(cov.clone()) { - None => Err(StatsError::BadParams), - Some(cholesky_decomp) => Ok(MultivariateNormal { - mvn: mvn, - cov_chol_decomp: cholesky_decomp.unpack(), - }), - } - } - Err(_) => Err(StatsError::BadParams), + let cov_det = LU::new(cov.clone()).determinant(); + let pdf_const = (Real::two_pi() + .powi(mean.nrows() as i32) + .recip() + .mul(cov_det.abs())) + .sqrt(); + // Store the Cholesky decomposition of the covariance matrix + // for sampling + match Cholesky::new(cov.clone()) { + None => Err(StatsError::BadParams), + Some(cholesky_decomp) => Ok(MultivariateNormal { + cov_chol_decomp: cholesky_decomp.clone().unpack(), + mu: mean.clone(), + cov: cov.clone(), + precision: cholesky_decomp.inverse(), + pdf_const: pdf_const, + }), } } } @@ -99,7 +107,7 @@ where fn sample(&self, rng: &mut R) -> VectorN { let d = Normal::new(0., 1.).unwrap(); let z = VectorN::::from_distribution(&d, rng); - (self.cov_chol_decomp.clone() * z) + self.mvn.mean() + (self.cov_chol_decomp.clone() * z) + self.mu.clone() } } @@ -150,7 +158,7 @@ where /// /// This is the same mean used to construct the distribution fn mean(&self) -> VectorN { - self.mvn.mean() + self.mu.clone() } } @@ -165,9 +173,7 @@ where { /// Returns the covariance matrix of the multivariate normal distribution fn variance(&self) -> MatrixN { - Cholesky::new(self.mvn.precision().clone()) - .unwrap() - .inverse() + self.cov.clone() } } @@ -215,7 +221,7 @@ where /// /// where `μ` is the mean fn mode(&self) -> VectorN { - self.mvn.mean() + self.mu.clone() } } @@ -240,7 +246,12 @@ where /// where `μ` is the mean, `inv(Σ)` is the precision matrix, `det(Σ)` is the determinant /// of the covariance matrix, and `k` is the dimension of the distribution fn pdf(&self, x: VectorN) -> Real { - *self.mvn.pdf::(&x.transpose()).get((0, 0)).unwrap() + let dv = x - &self.mu; + let exp_term = nalgebra::convert::(-0.5) + * *(&dv.transpose() * &self.precision * &dv) + .get((0, 0)) + .unwrap(); + self.pdf_const * exp_term.exp() } /// Calculates the log probability density function for the multivariate /// normal distribution at `x` @@ -254,6 +265,11 @@ where /// where `μ` is the mean, `inv(Σ)` is the precision matrix, `det(Σ)` is the determinant /// of the covariance matrix, and `k` is the dimension of the distribution fn ln_pdf(&self, x: VectorN) -> Real { - *self.mvn.logpdf::(&x.transpose()).get((0, 0)).unwrap() + let dv = x - &self.mu; + let exp_term = nalgebra::convert::(-0.5) + * *(&dv.transpose() * &self.precision * &dv) + .get((0, 0)) + .unwrap(); + self.pdf_const.ln() + exp_term } } From db05bd3f3ee4aafee4ddac0d66786ff07076496a Mon Sep 17 00:00:00 2001 From: Theodore Lee Date: Fri, 3 Jan 2020 13:29:28 +0800 Subject: [PATCH 10/24] fix: Remove unused imports, refactor --- src/distribution/multivariate_normal.rs | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/distribution/multivariate_normal.rs b/src/distribution/multivariate_normal.rs index 4d681a4a..ffa7c219 100644 --- a/src/distribution/multivariate_normal.rs +++ b/src/distribution/multivariate_normal.rs @@ -4,11 +4,10 @@ use crate::statistics::{Covariance, Entropy, Max, Mean, Min, Mode}; use crate::{Result, StatsError}; use nalgebra::{ base::allocator::Allocator, - base::{dimension::DimName, dimension::DimSub, MatrixN, VectorN}, - Cholesky, DefaultAllocator, Dim, DimMin, Dynamic, RealField, LU, U1, + base::{dimension::DimName, MatrixN, VectorN}, + Cholesky, DefaultAllocator, Dim, DimMin, RealField, LU, U1, }; use num_traits::bounds::Bounded; -use num_traits::real::Real; use rand::distributions::Distribution; use rand::Rng; @@ -63,7 +62,7 @@ where /// symmetric or positive-definite pub fn new(mean: &VectorN, cov: &MatrixN) -> Result { // Check that the provided covariance matrix is symmetric - if (cov.lower_triangle() != cov.upper_triangle().transpose()) { + if cov.lower_triangle() != cov.upper_triangle().transpose() { return Err(StatsError::BadParams); } let cov_det = LU::new(cov.clone()).determinant(); From 79df75f3c00c7a5ae5f1f763c0fdaaa3829fcffb Mon Sep 17 00:00:00 2001 From: Theodore Lee Date: Tue, 7 Jan 2020 11:11:47 +0800 Subject: [PATCH 11/24] fix: Enforce f64 as numeric type --- src/distribution/multivariate_normal.rs | 132 +++++++++++------------- 1 file changed, 61 insertions(+), 71 deletions(-) diff --git a/src/distribution/multivariate_normal.rs b/src/distribution/multivariate_normal.rs index ffa7c219..ab65df8d 100644 --- a/src/distribution/multivariate_normal.rs +++ b/src/distribution/multivariate_normal.rs @@ -10,6 +10,7 @@ use nalgebra::{ use num_traits::bounds::Bounded; use rand::distributions::Distribution; use rand::Rng; +use std::f64::consts::{PI, E}; /// Implements the [Multivariate Normal](https://en.wikipedia.org/wiki/Multivariate_normal_distribution) /// distribution using the "nalgebra" crate for matrix operations @@ -22,35 +23,33 @@ use rand::Rng; /// use nalgebra::{Vector2, Matrix2}; /// use statrs::statistics::{Mean, Covariance}; /// -/// let mvn = MultivariateNormal::::new(&Vector2::::zeros(), &Matrix2::::identity()).unwrap(); -/// assert_eq!(mvn.mean(), Vector2::::new(0., 0.)); -/// assert_eq!(mvn.variance(), Matrix2::::new(1., 0., 0., 1.)); -/// assert_eq!(mvn.pdf(Vector2::::new(1., 1.)), 0.05854983152431917); +/// let mvn = MultivariateNormal::::new(&Vector2::zeros(), &Matrix2::identity()).unwrap(); +/// assert_eq!(mvn.mean(), Vector2::new(0., 0.)); +/// assert_eq!(mvn.variance(), Matrix2::new(1., 0., 0., 1.)); +/// assert_eq!(mvn.pdf(Vector2::new(1., 1.)), 0.05854983152431917); /// ``` #[derive(Debug, Clone)] -pub struct MultivariateNormal +pub struct MultivariateNormal where - Real: RealField, N: Dim + DimMin + DimName, - DefaultAllocator: Allocator, - DefaultAllocator: Allocator, - DefaultAllocator: Allocator, + DefaultAllocator: Allocator, + DefaultAllocator: Allocator, + DefaultAllocator: Allocator, DefaultAllocator: Allocator<(usize, usize), >::Output>, { - cov_chol_decomp: MatrixN, - mu: VectorN, - cov: MatrixN, - precision: MatrixN, - pdf_const: Real, + cov_chol_decomp: MatrixN, + mu: VectorN, + cov: MatrixN, + precision: MatrixN, + pdf_const: f64, } -impl MultivariateNormal +impl MultivariateNormal where - Real: RealField, N: Dim + DimMin + DimName, - DefaultAllocator: Allocator, - DefaultAllocator: Allocator, - DefaultAllocator: Allocator, + DefaultAllocator: Allocator, + DefaultAllocator: Allocator, + DefaultAllocator: Allocator, DefaultAllocator: Allocator<(usize, usize), >::Output>, { /// Constructs a new multivariate normal distribution with a mean of `mean` @@ -60,16 +59,16 @@ where /// /// Returns an error if the given covariance matrix is not /// symmetric or positive-definite - pub fn new(mean: &VectorN, cov: &MatrixN) -> Result { + pub fn new(mean: &VectorN, cov: &MatrixN) -> Result { // Check that the provided covariance matrix is symmetric if cov.lower_triangle() != cov.upper_triangle().transpose() { return Err(StatsError::BadParams); } let cov_det = LU::new(cov.clone()).determinant(); - let pdf_const = (Real::two_pi() + let pdf_const = ((2. * PI) .powi(mean.nrows() as i32) .recip() - .mul(cov_det.abs())) + * cov_det.abs()) .sqrt(); // Store the Cholesky decomposition of the covariance matrix // for sampling @@ -86,7 +85,7 @@ where } } -impl Distribution> for MultivariateNormal +impl Distribution> for MultivariateNormal where N: Dim + DimMin + DimName, DefaultAllocator: Allocator, @@ -110,45 +109,42 @@ where } } -impl Min> for MultivariateNormal +impl Min> for MultivariateNormal where - Real: RealField, N: Dim + DimMin + DimName, - DefaultAllocator: Allocator, - DefaultAllocator: Allocator, - DefaultAllocator: Allocator, + DefaultAllocator: Allocator, + DefaultAllocator: Allocator, + DefaultAllocator: Allocator, DefaultAllocator: Allocator<(usize, usize), >::Output>, { /// Returns the minimum value in the domain of the /// multivariate normal distribution represented by a real vector - fn min(&self) -> VectorN { + fn min(&self) -> VectorN { VectorN::min_value() } } -impl Max> for MultivariateNormal +impl Max> for MultivariateNormal where - Real: RealField, N: Dim + DimMin + DimName, - DefaultAllocator: Allocator, - DefaultAllocator: Allocator, - DefaultAllocator: Allocator, + DefaultAllocator: Allocator, + DefaultAllocator: Allocator, + DefaultAllocator: Allocator, DefaultAllocator: Allocator<(usize, usize), >::Output>, { /// Returns the maximum value in the domain of the /// multivariate normal distribution represented by a real vector - fn max(&self) -> VectorN { + fn max(&self) -> VectorN { VectorN::max_value() } } -impl Mean> for MultivariateNormal +impl Mean> for MultivariateNormal where - Real: RealField, N: Dim + DimMin + DimName, - DefaultAllocator: Allocator, - DefaultAllocator: Allocator, - DefaultAllocator: Allocator, + DefaultAllocator: Allocator, + DefaultAllocator: Allocator, + DefaultAllocator: Allocator, DefaultAllocator: Allocator<(usize, usize), >::Output>, { /// Returns the mean of the normal distribution @@ -156,33 +152,31 @@ where /// # Remarks /// /// This is the same mean used to construct the distribution - fn mean(&self) -> VectorN { + fn mean(&self) -> VectorN { self.mu.clone() } } -impl Covariance> for MultivariateNormal +impl Covariance> for MultivariateNormal where - Real: RealField, N: Dim + DimMin + DimName, - DefaultAllocator: Allocator, - DefaultAllocator: Allocator, - DefaultAllocator: Allocator, + DefaultAllocator: Allocator, + DefaultAllocator: Allocator, + DefaultAllocator: Allocator, DefaultAllocator: Allocator<(usize, usize), >::Output>, { /// Returns the covariance matrix of the multivariate normal distribution - fn variance(&self) -> MatrixN { + fn variance(&self) -> MatrixN { self.cov.clone() } } -impl Entropy for MultivariateNormal +impl Entropy for MultivariateNormal where - Real: RealField, N: Dim + DimMin + DimName, - DefaultAllocator: Allocator, - DefaultAllocator: Allocator, - DefaultAllocator: Allocator, + DefaultAllocator: Allocator, + DefaultAllocator: Allocator, + DefaultAllocator: Allocator, DefaultAllocator: Allocator<(usize, usize), >::Output>, { /// Returns the entropy of the multivariate normal distribution @@ -194,20 +188,19 @@ where /// ``` /// /// where `Σ` is the covariance matrix and `det` is the determinant - fn entropy(&self) -> Real { - LU::new(self.variance().clone().scale(Real::two_pi() * Real::e())) + fn entropy(&self) -> f64 { + LU::new(self.variance().clone().scale(2. * PI * E)) .determinant() .ln() } } -impl Mode> for MultivariateNormal +impl Mode> for MultivariateNormal where - Real: RealField, N: Dim + DimMin + DimName, - DefaultAllocator: Allocator, - DefaultAllocator: Allocator, - DefaultAllocator: Allocator, + DefaultAllocator: Allocator, + DefaultAllocator: Allocator, + DefaultAllocator: Allocator, DefaultAllocator: Allocator<(usize, usize), >::Output>, { /// Returns the mode of the multivariate normal distribution @@ -219,18 +212,17 @@ where /// ``` /// /// where `μ` is the mean - fn mode(&self) -> VectorN { + fn mode(&self) -> VectorN { self.mu.clone() } } -impl Continuous, Real> for MultivariateNormal +impl Continuous, f64> for MultivariateNormal where - Real: RealField, N: Dim + DimMin + DimName, - DefaultAllocator: Allocator, - DefaultAllocator: Allocator, - DefaultAllocator: Allocator, + DefaultAllocator: Allocator, + DefaultAllocator: Allocator, + DefaultAllocator: Allocator, DefaultAllocator: Allocator<(usize, usize), >::Output>, { /// Calculates the probability density function for the multivariate @@ -244,10 +236,9 @@ where /// /// where `μ` is the mean, `inv(Σ)` is the precision matrix, `det(Σ)` is the determinant /// of the covariance matrix, and `k` is the dimension of the distribution - fn pdf(&self, x: VectorN) -> Real { + fn pdf(&self, x: VectorN) -> f64 { let dv = x - &self.mu; - let exp_term = nalgebra::convert::(-0.5) - * *(&dv.transpose() * &self.precision * &dv) + let exp_term = -0.5 * *(&dv.transpose() * &self.precision * &dv) .get((0, 0)) .unwrap(); self.pdf_const * exp_term.exp() @@ -263,10 +254,9 @@ where /// /// where `μ` is the mean, `inv(Σ)` is the precision matrix, `det(Σ)` is the determinant /// of the covariance matrix, and `k` is the dimension of the distribution - fn ln_pdf(&self, x: VectorN) -> Real { + fn ln_pdf(&self, x: VectorN) -> f64 { let dv = x - &self.mu; - let exp_term = nalgebra::convert::(-0.5) - * *(&dv.transpose() * &self.precision * &dv) + let exp_term = -0.5 * *(&dv.transpose() * &self.precision * &dv) .get((0, 0)) .unwrap(); self.pdf_const.ln() + exp_term From e1a6b7ef81b85f1796d7ce85590dc808bf994aba Mon Sep 17 00:00:00 2001 From: Theodore Lee Date: Tue, 7 Jan 2020 11:30:51 +0800 Subject: [PATCH 12/24] doc: Simplify ln_pdf documentation --- src/distribution/multivariate_normal.rs | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/src/distribution/multivariate_normal.rs b/src/distribution/multivariate_normal.rs index ab65df8d..ca8ed639 100644 --- a/src/distribution/multivariate_normal.rs +++ b/src/distribution/multivariate_normal.rs @@ -244,16 +244,7 @@ where self.pdf_const * exp_term.exp() } /// Calculates the log probability density function for the multivariate - /// normal distribution at `x` - /// - /// # Formula - /// - /// ```ignore - /// ln((2 * π) ^ (-k / 2) * det(Σ) ^ (1 / 2) * e ^ ( -(1 / 2) * transpose(x - μ) * inv(Σ) * (x - μ))) - /// ``` - /// - /// where `μ` is the mean, `inv(Σ)` is the precision matrix, `det(Σ)` is the determinant - /// of the covariance matrix, and `k` is the dimension of the distribution + /// normal distribution at `x`. Equivalent to pdf(x).ln(). fn ln_pdf(&self, x: VectorN) -> f64 { let dv = x - &self.mu; let exp_term = -0.5 * *(&dv.transpose() * &self.precision * &dv) From bc97a56de58abe42d87aafb03d8f33642d5b8f48 Mon Sep 17 00:00:00 2001 From: Theodore Lee Date: Tue, 7 Jan 2020 11:32:25 +0800 Subject: [PATCH 13/24] fix: run cargo fmt --- src/distribution/multivariate_normal.rs | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/src/distribution/multivariate_normal.rs b/src/distribution/multivariate_normal.rs index ca8ed639..df288e5f 100644 --- a/src/distribution/multivariate_normal.rs +++ b/src/distribution/multivariate_normal.rs @@ -10,7 +10,7 @@ use nalgebra::{ use num_traits::bounds::Bounded; use rand::distributions::Distribution; use rand::Rng; -use std::f64::consts::{PI, E}; +use std::f64::consts::{E, PI}; /// Implements the [Multivariate Normal](https://en.wikipedia.org/wiki/Multivariate_normal_distribution) /// distribution using the "nalgebra" crate for matrix operations @@ -65,11 +65,7 @@ where return Err(StatsError::BadParams); } let cov_det = LU::new(cov.clone()).determinant(); - let pdf_const = ((2. * PI) - .powi(mean.nrows() as i32) - .recip() - * cov_det.abs()) - .sqrt(); + let pdf_const = ((2. * PI).powi(mean.nrows() as i32).recip() * cov_det.abs()).sqrt(); // Store the Cholesky decomposition of the covariance matrix // for sampling match Cholesky::new(cov.clone()) { @@ -238,7 +234,8 @@ where /// of the covariance matrix, and `k` is the dimension of the distribution fn pdf(&self, x: VectorN) -> f64 { let dv = x - &self.mu; - let exp_term = -0.5 * *(&dv.transpose() * &self.precision * &dv) + let exp_term = -0.5 + * *(&dv.transpose() * &self.precision * &dv) .get((0, 0)) .unwrap(); self.pdf_const * exp_term.exp() @@ -247,7 +244,8 @@ where /// normal distribution at `x`. Equivalent to pdf(x).ln(). fn ln_pdf(&self, x: VectorN) -> f64 { let dv = x - &self.mu; - let exp_term = -0.5 * *(&dv.transpose() * &self.precision * &dv) + let exp_term = -0.5 + * *(&dv.transpose() * &self.precision * &dv) .get((0, 0)) .unwrap(); self.pdf_const.ln() + exp_term From 81241fce2e690a116d87c147955041db4a7d0c0a Mon Sep 17 00:00:00 2001 From: Theodore Lee Date: Wed, 8 Jan 2020 16:21:25 +0800 Subject: [PATCH 14/24] fix: Check for NaN in mean and covariance in constructor --- src/distribution/multivariate_normal.rs | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/distribution/multivariate_normal.rs b/src/distribution/multivariate_normal.rs index df288e5f..6a1ae9ed 100644 --- a/src/distribution/multivariate_normal.rs +++ b/src/distribution/multivariate_normal.rs @@ -5,12 +5,13 @@ use crate::{Result, StatsError}; use nalgebra::{ base::allocator::Allocator, base::{dimension::DimName, MatrixN, VectorN}, - Cholesky, DefaultAllocator, Dim, DimMin, RealField, LU, U1, + Cholesky, DefaultAllocator, Dim, DimMin, LU, U1, }; use num_traits::bounds::Bounded; use rand::distributions::Distribution; use rand::Rng; use std::f64::consts::{E, PI}; +use std::f64; /// Implements the [Multivariate Normal](https://en.wikipedia.org/wiki/Multivariate_normal_distribution) /// distribution using the "nalgebra" crate for matrix operations @@ -61,7 +62,9 @@ where /// symmetric or positive-definite pub fn new(mean: &VectorN, cov: &MatrixN) -> Result { // Check that the provided covariance matrix is symmetric - if cov.lower_triangle() != cov.upper_triangle().transpose() { + // Check that mean and covariance do not contain NaN + if cov.lower_triangle() != cov.upper_triangle().transpose() + || mean.iter().any(|f| f.is_nan()) || cov.iter().any(|f| f.is_nan()) { return Err(StatsError::BadParams); } let cov_det = LU::new(cov.clone()).determinant(); From 09e70d6a2b97e19bf5c564b9b4f571972078e17a Mon Sep 17 00:00:00 2001 From: Theodore Lee Date: Wed, 8 Jan 2020 16:21:56 +0800 Subject: [PATCH 15/24] fix: pdf constant computation --- src/distribution/multivariate_normal.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/distribution/multivariate_normal.rs b/src/distribution/multivariate_normal.rs index 6a1ae9ed..e3f6a430 100644 --- a/src/distribution/multivariate_normal.rs +++ b/src/distribution/multivariate_normal.rs @@ -68,7 +68,7 @@ where return Err(StatsError::BadParams); } let cov_det = LU::new(cov.clone()).determinant(); - let pdf_const = ((2. * PI).powi(mean.nrows() as i32).recip() * cov_det.abs()).sqrt(); + let pdf_const = ((2. * PI).powi(mean.nrows() as i32) * cov_det.abs()).recip().sqrt(); // Store the Cholesky decomposition of the covariance matrix // for sampling match Cholesky::new(cov.clone()) { From 7a5b0c221a68c4a5c5766929e4cf118811d4b14d Mon Sep 17 00:00:00 2001 From: Theodore Lee Date: Wed, 8 Jan 2020 16:22:48 +0800 Subject: [PATCH 16/24] fix: min value and max value reflect mathematical values rather than programmatic --- src/distribution/multivariate_normal.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/distribution/multivariate_normal.rs b/src/distribution/multivariate_normal.rs index e3f6a430..5346d742 100644 --- a/src/distribution/multivariate_normal.rs +++ b/src/distribution/multivariate_normal.rs @@ -119,7 +119,7 @@ where /// Returns the minimum value in the domain of the /// multivariate normal distribution represented by a real vector fn min(&self) -> VectorN { - VectorN::min_value() + VectorN::::repeat(f64::NEG_INFINITY) } } @@ -134,7 +134,7 @@ where /// Returns the maximum value in the domain of the /// multivariate normal distribution represented by a real vector fn max(&self) -> VectorN { - VectorN::max_value() + VectorN::::repeat(f64::INFINITY) } } From 18ede17d801ad4bf844ac673f6921bf3391650ba Mon Sep 17 00:00:00 2001 From: Theodore Lee Date: Wed, 8 Jan 2020 16:23:13 +0800 Subject: [PATCH 17/24] fix: entropy computation --- src/distribution/multivariate_normal.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/distribution/multivariate_normal.rs b/src/distribution/multivariate_normal.rs index 5346d742..f5a514c0 100644 --- a/src/distribution/multivariate_normal.rs +++ b/src/distribution/multivariate_normal.rs @@ -188,7 +188,7 @@ where /// /// where `Σ` is the covariance matrix and `det` is the determinant fn entropy(&self) -> f64 { - LU::new(self.variance().clone().scale(2. * PI * E)) + 0.5 * LU::new(self.variance().clone().scale(2. * PI * E)) .determinant() .ln() } From 4d5e977b603a36e173c09f1d809839a38a3b50cc Mon Sep 17 00:00:00 2001 From: Theodore Lee Date: Wed, 8 Jan 2020 16:43:49 +0800 Subject: [PATCH 18/24] test: Add tests for distribution::multivariate_normal --- src/distribution/multivariate_normal.rs | 156 ++++++++++++++++++++++++ 1 file changed, 156 insertions(+) diff --git a/src/distribution/multivariate_normal.rs b/src/distribution/multivariate_normal.rs index f5a514c0..de7a6e62 100644 --- a/src/distribution/multivariate_normal.rs +++ b/src/distribution/multivariate_normal.rs @@ -254,3 +254,159 @@ where self.pdf_const.ln() + exp_term } } + +#[cfg_attr(rustfmt, rustfmt_skip)] +#[cfg(test)] +mod test { + use std::f64; + use crate::statistics::*; + use crate::distribution::{MultivariateNormal, Continuous}; + use crate::distribution::internal::*; + use nalgebra::base::dimension::U2; + use nalgebra::{Matrix2, Vector2, Matrix3, Vector3, VectorN, MatrixN, Dim, DimMin, DimName, DefaultAllocator, U1}; + use nalgebra::base::allocator::Allocator; + use num_traits::real::Real; + use core::fmt::Debug; + use num_traits::bounds::Bounded; + + fn try_create(mean: VectorN, covariance: MatrixN) -> MultivariateNormal + where + N: Dim + DimMin + DimName, + DefaultAllocator: Allocator, + DefaultAllocator: Allocator, + DefaultAllocator: Allocator, + DefaultAllocator: Allocator<(usize, usize), >::Output>, + { + let mvn = MultivariateNormal::new(&mean, &covariance); + assert!(mvn.is_ok()); + mvn.unwrap() + } + + fn create_case(mean: VectorN, covariance: MatrixN) + where + N: Dim + DimMin + DimName, + DefaultAllocator: Allocator, + DefaultAllocator: Allocator, + DefaultAllocator: Allocator, + DefaultAllocator: Allocator<(usize, usize), >::Output>, + { + let mvn = try_create(mean.clone(), covariance.clone()); + assert_eq!(mean, mvn.mean()); + assert_eq!(covariance, mvn.variance()); + } + + fn bad_create_case(mean: VectorN, covariance: MatrixN) + where + N: Dim + DimMin + DimName, + DefaultAllocator: Allocator, + DefaultAllocator: Allocator, + DefaultAllocator: Allocator, + DefaultAllocator: Allocator<(usize, usize), >::Output>, + { + let mvn = MultivariateNormal::new(&mean, &covariance); + assert!(mvn.is_err()); + } + + fn test_case(mean: VectorN, covariance: MatrixN, expected: T, eval: F) + where + T: Debug + PartialEq, + F: Fn(MultivariateNormal) -> T, + N: Dim + DimMin + DimName, + DefaultAllocator: Allocator, + DefaultAllocator: Allocator, + DefaultAllocator: Allocator, + DefaultAllocator: Allocator<(usize, usize), >::Output>, + { + let mvn = try_create(mean, covariance); + let x = eval(mvn); + assert_eq!(expected, x); + } + + fn test_almost(mean: VectorN, covariance: MatrixN, expected: f64, acc: f64, eval: F) + where + F: Fn(MultivariateNormal) -> f64, + N: Dim + DimMin + DimName, + DefaultAllocator: Allocator, + DefaultAllocator: Allocator, + DefaultAllocator: Allocator, + DefaultAllocator: Allocator<(usize, usize), >::Output>, + { + let mvn = try_create(mean, covariance); + let x = eval(mvn); + assert_almost_eq!(expected, x, acc); + } + + #[test] + fn test_create() { + create_case(Vector2::new(0., 0.), Matrix2::new(1., 0., 0., 1.)); + create_case(Vector2::new(10., 5.), Matrix2::new(2., 1., 1., 2.)); + create_case(Vector3::new(4., 5., 6.), Matrix3::new(2., 1., 0., 1., 2., 1., 0., 1., 2.)); + create_case(Vector2::new(0., f64::INFINITY), Matrix2::identity()); + create_case(Vector2::zeros(), Matrix2::new(f64::INFINITY, 0., 0., f64::INFINITY)); + } + + #[test] + fn test_bad_create() { + // Covariance not symmetric + bad_create_case(Vector2::zeros(), Matrix2::new(1., 1., 0., 1.)); + // Covariance not positive-definite + bad_create_case(Vector2::zeros(), Matrix2::new(1., 2., 2., 1.)); + // NaN in mean + bad_create_case(Vector2::new(0., f64::NAN), Matrix2::identity()); + // NaN in mean + bad_create_case(Vector2::zeros(), Matrix2::new(1., 0., 0., f64::NAN)); + } + + #[test] + fn test_variance() { + test_case(Vector2::zeros(), Matrix2::identity(), Matrix2::new(1., 0., 0., 1.), |x| x.variance()); + test_case(Vector2::zeros(), Matrix2::new(f64::INFINITY, 0., 0., f64::INFINITY), Matrix2::new(f64::INFINITY, 0., 0., f64::INFINITY), |x| x.variance()); + } + + #[test] + fn test_entropy() { + test_case(Vector2::zeros(), Matrix2::identity(), 2.8378770664093453, |x| x.entropy()); + test_case(Vector2::zeros(), Matrix2::new(1., 0.5, 0.5, 1.), 2.694036030183455, |x| x.entropy()); + test_case(Vector2::zeros(), Matrix2::new(f64::INFINITY, 0., 0., f64::INFINITY), f64::INFINITY, |x| x.entropy()); + } + + #[test] + fn test_mode() { + test_case(Vector2::zeros(), Matrix2::identity(), Vector2::new(0., 0.), |x| x.mode()); + test_case(Vector2::::repeat(f64::INFINITY), Matrix2::identity(), Vector2::new(f64::INFINITY, f64::INFINITY), |x| x.mode()); + } + + #[test] + fn test_min_max() { + test_case(Vector2::zeros(), Matrix2::identity(), Vector2::new(f64::NEG_INFINITY, f64::NEG_INFINITY), |x| x.min()); + test_case(Vector2::zeros(), Matrix2::identity(), Vector2::new(f64::INFINITY, f64::INFINITY), |x| x.max()); + test_case(Vector2::new(10., 1.), Matrix2::identity(), Vector2::new(f64::NEG_INFINITY, f64::NEG_INFINITY), |x| x.min()); + test_case(Vector2::new(-3., 5.), Matrix2::identity(), Vector2::new(f64::INFINITY, f64::INFINITY), |x| x.max()); + } + + #[test] + fn test_pdf() { + test_case(Vector2::zeros(), Matrix2::identity(), 0.05854983152431917, |x| x.pdf(Vector2::new(1., 1.))); + test_almost(Vector2::zeros(), Matrix2::identity(), 0.013064233284684921, 1e-15, |x| x.pdf(Vector2::new(1., 2.))); + test_almost(Vector2::zeros(), Matrix2::identity(), 1.8618676045881531e-23, 1e-35, |x| x.pdf(Vector2::new(1., 10.))); + test_almost(Vector2::zeros(), Matrix2::identity(), 5.920684802611216e-45, 1e-58, |x| x.pdf(Vector2::new(10., 10.))); + test_almost(Vector2::zeros(), Matrix2::new(1., 0.9, 0.9, 1.), 1.6576716577547003e-05, 1e-18, |x| x.pdf(Vector2::new(1., -1.))); + test_almost(Vector2::zeros(), Matrix2::new(1., 0.99, 0.99, 1.), 4.1970621773477824e-44, 1e-54, |x| x.pdf(Vector2::new(1., -1.))); + test_almost(Vector2::new(0.5, -0.2), Matrix2::new(2.0, 0.3, 0.3, 0.5), 0.0013075203140666656, 1e-15, |x| x.pdf(Vector2::new(2., 2.))); + test_case(Vector2::zeros(), Matrix2::new(f64::INFINITY, 0., 0., f64::INFINITY), 0.0, |x| x.pdf(Vector2::new(10., 10.))); + test_case(Vector2::zeros(), Matrix2::new(f64::INFINITY, 0., 0., f64::INFINITY), 0.0, |x| x.pdf(Vector2::new(100., 100.))); + } + + #[test] + fn test_ln_pdf() { + test_case(Vector2::zeros(), Matrix2::identity(), (0.05854983152431917).ln(), |x| x.ln_pdf(Vector2::new(1., 1.))); + test_almost(Vector2::zeros(), Matrix2::identity(), (0.013064233284684921f64).ln(), 1e-15, |x| x.ln_pdf(Vector2::new(1., 2.))); + test_almost(Vector2::zeros(), Matrix2::identity(), (1.8618676045881531e-23f64).ln(), 1e-15, |x| x.ln_pdf(Vector2::new(1., 10.))); + test_almost(Vector2::zeros(), Matrix2::identity(), (5.920684802611216e-45f64).ln(), 1e-15, |x| x.ln_pdf(Vector2::new(10., 10.))); + test_almost(Vector2::zeros(), Matrix2::new(1., 0.9, 0.9, 1.), (1.6576716577547003e-05f64).ln(), 1e-14, |x| x.ln_pdf(Vector2::new(1., -1.))); + test_almost(Vector2::zeros(), Matrix2::new(1., 0.99, 0.99, 1.), (4.1970621773477824e-44f64).ln(), 1e-12, |x| x.ln_pdf(Vector2::new(1., -1.))); + test_almost(Vector2::new(0.5, -0.2), Matrix2::new(2.0, 0.3, 0.3, 0.5), (0.0013075203140666656f64).ln(), 1e-15, |x| x.ln_pdf(Vector2::new(2., 2.))); + test_case(Vector2::zeros(), Matrix2::new(f64::INFINITY, 0., 0., f64::INFINITY), f64::NEG_INFINITY, |x| x.ln_pdf(Vector2::new(10., 10.))); + test_case(Vector2::zeros(), Matrix2::new(f64::INFINITY, 0., 0., f64::INFINITY), f64::NEG_INFINITY, |x| x.ln_pdf(Vector2::new(100., 100.))); + } +} \ No newline at end of file From 7ed0d319381d3c80b9ff34fbe3b269df21ed66ea Mon Sep 17 00:00:00 2001 From: Theodore Lee Date: Wed, 8 Jan 2020 16:45:07 +0800 Subject: [PATCH 19/24] fix: remove unused imports --- src/distribution/multivariate_normal.rs | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/distribution/multivariate_normal.rs b/src/distribution/multivariate_normal.rs index de7a6e62..627daef6 100644 --- a/src/distribution/multivariate_normal.rs +++ b/src/distribution/multivariate_normal.rs @@ -7,7 +7,6 @@ use nalgebra::{ base::{dimension::DimName, MatrixN, VectorN}, Cholesky, DefaultAllocator, Dim, DimMin, LU, U1, }; -use num_traits::bounds::Bounded; use rand::distributions::Distribution; use rand::Rng; use std::f64::consts::{E, PI}; @@ -261,13 +260,10 @@ mod test { use std::f64; use crate::statistics::*; use crate::distribution::{MultivariateNormal, Continuous}; - use crate::distribution::internal::*; - use nalgebra::base::dimension::U2; use nalgebra::{Matrix2, Vector2, Matrix3, Vector3, VectorN, MatrixN, Dim, DimMin, DimName, DefaultAllocator, U1}; use nalgebra::base::allocator::Allocator; use num_traits::real::Real; use core::fmt::Debug; - use num_traits::bounds::Bounded; fn try_create(mean: VectorN, covariance: MatrixN) -> MultivariateNormal where From 1ae983e5b4ae37b1467a8ad89475c393ae2b91d5 Mon Sep 17 00:00:00 2001 From: Theodore Lee Date: Wed, 8 Jan 2020 16:45:37 +0800 Subject: [PATCH 20/24] fix: cargo fmt --- src/distribution/multivariate_normal.rs | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/distribution/multivariate_normal.rs b/src/distribution/multivariate_normal.rs index 627daef6..4debbfcd 100644 --- a/src/distribution/multivariate_normal.rs +++ b/src/distribution/multivariate_normal.rs @@ -9,8 +9,8 @@ use nalgebra::{ }; use rand::distributions::Distribution; use rand::Rng; -use std::f64::consts::{E, PI}; use std::f64; +use std::f64::consts::{E, PI}; /// Implements the [Multivariate Normal](https://en.wikipedia.org/wiki/Multivariate_normal_distribution) /// distribution using the "nalgebra" crate for matrix operations @@ -63,11 +63,15 @@ where // Check that the provided covariance matrix is symmetric // Check that mean and covariance do not contain NaN if cov.lower_triangle() != cov.upper_triangle().transpose() - || mean.iter().any(|f| f.is_nan()) || cov.iter().any(|f| f.is_nan()) { + || mean.iter().any(|f| f.is_nan()) + || cov.iter().any(|f| f.is_nan()) + { return Err(StatsError::BadParams); } let cov_det = LU::new(cov.clone()).determinant(); - let pdf_const = ((2. * PI).powi(mean.nrows() as i32) * cov_det.abs()).recip().sqrt(); + let pdf_const = ((2. * PI).powi(mean.nrows() as i32) * cov_det.abs()) + .recip() + .sqrt(); // Store the Cholesky decomposition of the covariance matrix // for sampling match Cholesky::new(cov.clone()) { @@ -405,4 +409,4 @@ mod test { test_case(Vector2::zeros(), Matrix2::new(f64::INFINITY, 0., 0., f64::INFINITY), f64::NEG_INFINITY, |x| x.ln_pdf(Vector2::new(10., 10.))); test_case(Vector2::zeros(), Matrix2::new(f64::INFINITY, 0., 0., f64::INFINITY), f64::NEG_INFINITY, |x| x.ln_pdf(Vector2::new(100., 100.))); } -} \ No newline at end of file +} From c8446ed35f8c0c10098353b1a96f8786eceed69f Mon Sep 17 00:00:00 2001 From: Theodore Lee Date: Wed, 8 Jan 2020 17:33:43 +0800 Subject: [PATCH 21/24] fix: correct comment --- src/distribution/multivariate_normal.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/distribution/multivariate_normal.rs b/src/distribution/multivariate_normal.rs index 4debbfcd..8119f89e 100644 --- a/src/distribution/multivariate_normal.rs +++ b/src/distribution/multivariate_normal.rs @@ -353,7 +353,7 @@ mod test { bad_create_case(Vector2::zeros(), Matrix2::new(1., 2., 2., 1.)); // NaN in mean bad_create_case(Vector2::new(0., f64::NAN), Matrix2::identity()); - // NaN in mean + // NaN in Covariance Matrix bad_create_case(Vector2::zeros(), Matrix2::new(1., 0., 0., f64::NAN)); } From a9e1e8dc780243f55ecae826ad96e06ae683ffc7 Mon Sep 17 00:00:00 2001 From: Theodore Lee Date: Thu, 9 Jan 2020 00:05:51 +0800 Subject: [PATCH 22/24] fix: move num-traits to dev dependencies --- Cargo.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Cargo.toml b/Cargo.toml index db576a46..f5e89cec 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,4 +20,6 @@ path = "src/lib.rs" [dependencies] rand = "0.7" nalgebra = "0.19" + +[dev-dependencies] num-traits = "0.2.10" \ No newline at end of file From 913771502366b94a887e5a4c5e2775eb51a56df7 Mon Sep 17 00:00:00 2001 From: Theodore Lee Date: Thu, 9 Jan 2020 00:05:51 +0800 Subject: [PATCH 23/24] Revert "fix: move num-traits to dev dependencies" This reverts commit a9e1e8dc780243f55ecae826ad96e06ae683ffc7. --- Cargo.toml | 2 -- 1 file changed, 2 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index f5e89cec..db576a46 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,6 +20,4 @@ path = "src/lib.rs" [dependencies] rand = "0.7" nalgebra = "0.19" - -[dev-dependencies] num-traits = "0.2.10" \ No newline at end of file From 1e2b0c56a8b056c7148b5bc99a838d1c62863b3f Mon Sep 17 00:00:00 2001 From: Theodore Lee Date: Thu, 9 Jan 2020 00:08:26 +0800 Subject: [PATCH 24/24] fix: Remove dependency on num-traits --- Cargo.toml | 3 +-- src/distribution/multivariate_normal.rs | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index db576a46..d9aace77 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,5 +19,4 @@ path = "src/lib.rs" [dependencies] rand = "0.7" -nalgebra = "0.19" -num-traits = "0.2.10" \ No newline at end of file +nalgebra = "0.19" \ No newline at end of file diff --git a/src/distribution/multivariate_normal.rs b/src/distribution/multivariate_normal.rs index 8119f89e..80b4e812 100644 --- a/src/distribution/multivariate_normal.rs +++ b/src/distribution/multivariate_normal.rs @@ -266,7 +266,6 @@ mod test { use crate::distribution::{MultivariateNormal, Continuous}; use nalgebra::{Matrix2, Vector2, Matrix3, Vector3, VectorN, MatrixN, Dim, DimMin, DimName, DefaultAllocator, U1}; use nalgebra::base::allocator::Allocator; - use num_traits::real::Real; use core::fmt::Debug; fn try_create(mean: VectorN, covariance: MatrixN) -> MultivariateNormal @@ -399,7 +398,7 @@ mod test { #[test] fn test_ln_pdf() { - test_case(Vector2::zeros(), Matrix2::identity(), (0.05854983152431917).ln(), |x| x.ln_pdf(Vector2::new(1., 1.))); + test_case(Vector2::zeros(), Matrix2::identity(), (0.05854983152431917f64).ln(), |x| x.ln_pdf(Vector2::new(1., 1.))); test_almost(Vector2::zeros(), Matrix2::identity(), (0.013064233284684921f64).ln(), 1e-15, |x| x.ln_pdf(Vector2::new(1., 2.))); test_almost(Vector2::zeros(), Matrix2::identity(), (1.8618676045881531e-23f64).ln(), 1e-15, |x| x.ln_pdf(Vector2::new(1., 10.))); test_almost(Vector2::zeros(), Matrix2::identity(), (5.920684802611216e-45f64).ln(), 1e-15, |x| x.ln_pdf(Vector2::new(10., 10.)));