New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement Multivariate Normal Distribution with nalgebra #100
Merged
Merged
Changes from 10 commits
Commits
Show all changes
24 commits
Select commit
Hold shift + click to select a range
1dec41c
feat: multivariate_normal distribution
sphqxe aa2da3b
doc: add documentation to distribution::multivariate_normal
sphqxe 36a943e
fix: add checks to constructor for covariance matrix
sphqxe a6f4c51
fix: Build with nalgebra-mvn with dep nalgebra updated to 0.19.0, upd…
sphqxe 3009cdc
fix: Make trait bounds consistent
sphqxe e8b5a66
fix: Make dependencies less specific
sphqxe d9044b8
fix: Update nalgebra-mvn to v0.2
sphqxe 989df35
fix: Run cargo fmt
sphqxe d4d0b7c
fix: Remove dependence on nalgebra-mvn
sphqxe db05bd3
fix: Remove unused imports, refactor
sphqxe 79df75f
fix: Enforce f64 as numeric type
sphqxe e1a6b7e
doc: Simplify ln_pdf documentation
sphqxe bc97a56
fix: run cargo fmt
sphqxe 81241fc
fix: Check for NaN in mean and covariance in constructor
sphqxe 09e70d6
fix: pdf constant computation
sphqxe 7a5b0c2
fix: min value and max value reflect mathematical values rather than …
sphqxe 18ede17
fix: entropy computation
sphqxe 4d5e977
test: Add tests for distribution::multivariate_normal
sphqxe 7ed0d31
fix: remove unused imports
sphqxe 1ae983e
fix: cargo fmt
sphqxe c8446ed
fix: correct comment
sphqxe a9e1e8d
fix: move num-traits to dev dependencies
sphqxe 9137715
Revert "fix: move num-traits to dev dependencies"
sphqxe 1e2b0c5
fix: Remove dependency on num-traits
sphqxe File filter
Filter by extension
Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,3 +19,5 @@ path = "src/lib.rs" | |
|
||
[dependencies] | ||
rand = "0.7" | ||
nalgebra = "0.19" | ||
num-traits = "0.2.10" | ||
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,274 @@ | ||
use crate::distribution::Continuous; | ||
use crate::distribution::Normal; | ||
use crate::statistics::{Covariance, Entropy, Max, Mean, Min, Mode}; | ||
use crate::{Result, StatsError}; | ||
use nalgebra::{ | ||
base::allocator::Allocator, | ||
base::{dimension::DimName, MatrixN, VectorN}, | ||
Cholesky, DefaultAllocator, Dim, DimMin, RealField, LU, U1, | ||
}; | ||
use num_traits::bounds::Bounded; | ||
use rand::distributions::Distribution; | ||
use rand::Rng; | ||
|
||
/// Implements the [Multivariate Normal](https://en.wikipedia.org/wiki/Multivariate_normal_distribution) | ||
/// distribution using the "nalgebra" crate for matrix operations | ||
/// | ||
/// # Examples | ||
/// | ||
/// ``` | ||
/// use statrs::distribution::{MultivariateNormal, Continuous}; | ||
/// use nalgebra::base::dimension::U2; | ||
/// use nalgebra::{Vector2, Matrix2}; | ||
/// use statrs::statistics::{Mean, Covariance}; | ||
/// | ||
/// let mvn = MultivariateNormal::<f64, U2>::new(&Vector2::<f64>::zeros(), &Matrix2::<f64>::identity()).unwrap(); | ||
/// assert_eq!(mvn.mean(), Vector2::<f64>::new(0., 0.)); | ||
/// assert_eq!(mvn.variance(), Matrix2::<f64>::new(1., 0., 0., 1.)); | ||
/// assert_eq!(mvn.pdf(Vector2::<f64>::new(1., 1.)), 0.05854983152431917); | ||
/// ``` | ||
#[derive(Debug, Clone)] | ||
pub struct MultivariateNormal<Real, N> | ||
where | ||
Real: RealField, | ||
N: Dim + DimMin<N, Output = N> + DimName, | ||
DefaultAllocator: Allocator<Real, N>, | ||
DefaultAllocator: Allocator<Real, N, N>, | ||
DefaultAllocator: Allocator<Real, U1, N>, | ||
DefaultAllocator: Allocator<(usize, usize), <N as DimMin<N>>::Output>, | ||
{ | ||
cov_chol_decomp: MatrixN<Real, N>, | ||
mu: VectorN<Real, N>, | ||
cov: MatrixN<Real, N>, | ||
precision: MatrixN<Real, N>, | ||
pdf_const: Real, | ||
} | ||
|
||
impl<Real, N> MultivariateNormal<Real, N> | ||
where | ||
Real: RealField, | ||
N: Dim + DimMin<N, Output = N> + DimName, | ||
DefaultAllocator: Allocator<Real, N>, | ||
DefaultAllocator: Allocator<Real, N, N>, | ||
DefaultAllocator: Allocator<Real, U1, N>, | ||
DefaultAllocator: Allocator<(usize, usize), <N as DimMin<N>>::Output>, | ||
{ | ||
/// Constructs a new multivariate normal distribution with a mean of `mean` | ||
/// and covariance matrix `cov` | ||
/// | ||
/// # Errors | ||
/// | ||
/// Returns an error if the given covariance matrix is not | ||
/// symmetric or positive-definite | ||
pub fn new(mean: &VectorN<Real, N>, cov: &MatrixN<Real, N>) -> Result<Self> { | ||
// Check that the provided covariance matrix is symmetric | ||
if cov.lower_triangle() != cov.upper_triangle().transpose() { | ||
return Err(StatsError::BadParams); | ||
} | ||
let cov_det = LU::new(cov.clone()).determinant(); | ||
let pdf_const = (Real::two_pi() | ||
.powi(mean.nrows() as i32) | ||
.recip() | ||
.mul(cov_det.abs())) | ||
.sqrt(); | ||
// Store the Cholesky decomposition of the covariance matrix | ||
// for sampling | ||
match Cholesky::new(cov.clone()) { | ||
None => Err(StatsError::BadParams), | ||
Some(cholesky_decomp) => Ok(MultivariateNormal { | ||
cov_chol_decomp: cholesky_decomp.clone().unpack(), | ||
mu: mean.clone(), | ||
cov: cov.clone(), | ||
precision: cholesky_decomp.inverse(), | ||
pdf_const: pdf_const, | ||
}), | ||
} | ||
} | ||
} | ||
|
||
impl<N> Distribution<VectorN<f64, N>> for MultivariateNormal<f64, N> | ||
where | ||
N: Dim + DimMin<N, Output = N> + DimName, | ||
DefaultAllocator: Allocator<f64, N>, | ||
DefaultAllocator: Allocator<f64, N, N>, | ||
DefaultAllocator: Allocator<f64, U1, N>, | ||
DefaultAllocator: Allocator<(usize, usize), <N as DimMin<N>>::Output>, | ||
{ | ||
/// Samples from the multivariate normal distribution | ||
/// | ||
/// # Formula | ||
/// L * Z + μ | ||
/// | ||
/// where `L` is the Cholesky decomposition of the covariance matrix, | ||
/// `Z` is a vector of normally distributed random variables, and | ||
/// `μ` is the mean vector | ||
|
||
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> VectorN<f64, N> { | ||
let d = Normal::new(0., 1.).unwrap(); | ||
let z = VectorN::<f64, N>::from_distribution(&d, rng); | ||
(self.cov_chol_decomp.clone() * z) + self.mu.clone() | ||
} | ||
} | ||
|
||
impl<Real, N> Min<VectorN<Real, N>> for MultivariateNormal<Real, N> | ||
where | ||
Real: RealField, | ||
N: Dim + DimMin<N, Output = N> + DimName, | ||
DefaultAllocator: Allocator<Real, N>, | ||
DefaultAllocator: Allocator<Real, N, N>, | ||
DefaultAllocator: Allocator<Real, U1, N>, | ||
DefaultAllocator: Allocator<(usize, usize), <N as DimMin<N>>::Output>, | ||
{ | ||
/// Returns the minimum value in the domain of the | ||
/// multivariate normal distribution represented by a real vector | ||
fn min(&self) -> VectorN<Real, N> { | ||
VectorN::min_value() | ||
} | ||
} | ||
|
||
impl<Real, N> Max<VectorN<Real, N>> for MultivariateNormal<Real, N> | ||
where | ||
Real: RealField, | ||
N: Dim + DimMin<N, Output = N> + DimName, | ||
DefaultAllocator: Allocator<Real, N>, | ||
DefaultAllocator: Allocator<Real, N, N>, | ||
DefaultAllocator: Allocator<Real, U1, N>, | ||
DefaultAllocator: Allocator<(usize, usize), <N as DimMin<N>>::Output>, | ||
{ | ||
/// Returns the maximum value in the domain of the | ||
/// multivariate normal distribution represented by a real vector | ||
fn max(&self) -> VectorN<Real, N> { | ||
VectorN::max_value() | ||
} | ||
} | ||
|
||
impl<Real, N> Mean<VectorN<Real, N>> for MultivariateNormal<Real, N> | ||
where | ||
Real: RealField, | ||
N: Dim + DimMin<N, Output = N> + DimName, | ||
DefaultAllocator: Allocator<Real, N>, | ||
DefaultAllocator: Allocator<Real, N, N>, | ||
DefaultAllocator: Allocator<Real, U1, N>, | ||
DefaultAllocator: Allocator<(usize, usize), <N as DimMin<N>>::Output>, | ||
{ | ||
/// Returns the mean of the normal distribution | ||
/// | ||
/// # Remarks | ||
/// | ||
/// This is the same mean used to construct the distribution | ||
fn mean(&self) -> VectorN<Real, N> { | ||
self.mu.clone() | ||
} | ||
} | ||
|
||
impl<Real, N> Covariance<MatrixN<Real, N>> for MultivariateNormal<Real, N> | ||
where | ||
Real: RealField, | ||
N: Dim + DimMin<N, Output = N> + DimName, | ||
DefaultAllocator: Allocator<Real, N>, | ||
DefaultAllocator: Allocator<Real, N, N>, | ||
DefaultAllocator: Allocator<Real, U1, N>, | ||
DefaultAllocator: Allocator<(usize, usize), <N as DimMin<N>>::Output>, | ||
{ | ||
/// Returns the covariance matrix of the multivariate normal distribution | ||
fn variance(&self) -> MatrixN<Real, N> { | ||
self.cov.clone() | ||
} | ||
} | ||
|
||
impl<Real, N> Entropy<Real> for MultivariateNormal<Real, N> | ||
where | ||
Real: RealField, | ||
N: Dim + DimMin<N, Output = N> + DimName, | ||
DefaultAllocator: Allocator<Real, N>, | ||
DefaultAllocator: Allocator<Real, N, N>, | ||
DefaultAllocator: Allocator<Real, U1, N>, | ||
DefaultAllocator: Allocator<(usize, usize), <N as DimMin<N>>::Output>, | ||
{ | ||
/// Returns the entropy of the multivariate normal distribution | ||
/// | ||
/// # Formula | ||
/// | ||
/// ```ignore | ||
/// (1 / 2) * ln(det(2 * π * e * Σ)) | ||
/// ``` | ||
/// | ||
/// where `Σ` is the covariance matrix and `det` is the determinant | ||
fn entropy(&self) -> Real { | ||
LU::new(self.variance().clone().scale(Real::two_pi() * Real::e())) | ||
.determinant() | ||
.ln() | ||
} | ||
} | ||
|
||
impl<Real, N> Mode<VectorN<Real, N>> for MultivariateNormal<Real, N> | ||
where | ||
Real: RealField, | ||
N: Dim + DimMin<N, Output = N> + DimName, | ||
DefaultAllocator: Allocator<Real, N>, | ||
DefaultAllocator: Allocator<Real, N, N>, | ||
DefaultAllocator: Allocator<Real, U1, N>, | ||
DefaultAllocator: Allocator<(usize, usize), <N as DimMin<N>>::Output>, | ||
{ | ||
/// Returns the mode of the multivariate normal distribution | ||
/// | ||
/// # Formula | ||
/// | ||
/// ```ignore | ||
/// μ | ||
/// ``` | ||
/// | ||
/// where `μ` is the mean | ||
fn mode(&self) -> VectorN<Real, N> { | ||
self.mu.clone() | ||
} | ||
} | ||
|
||
impl<Real, N> Continuous<VectorN<Real, N>, Real> for MultivariateNormal<Real, N> | ||
where | ||
Real: RealField, | ||
N: Dim + DimMin<N, Output = N> + DimName, | ||
DefaultAllocator: Allocator<Real, N>, | ||
DefaultAllocator: Allocator<Real, N, N>, | ||
DefaultAllocator: Allocator<Real, U1, N>, | ||
DefaultAllocator: Allocator<(usize, usize), <N as DimMin<N>>::Output>, | ||
{ | ||
/// Calculates the probability density function for the multivariate | ||
/// normal distribution at `x` | ||
/// | ||
/// # Formula | ||
/// | ||
/// ```ignore | ||
/// (2 * π) ^ (-k / 2) * det(Σ) ^ (1 / 2) * e ^ ( -(1 / 2) * transpose(x - μ) * inv(Σ) * (x - μ)) | ||
/// ``` | ||
/// | ||
/// where `μ` is the mean, `inv(Σ)` is the precision matrix, `det(Σ)` is the determinant | ||
/// of the covariance matrix, and `k` is the dimension of the distribution | ||
fn pdf(&self, x: VectorN<Real, N>) -> Real { | ||
let dv = x - &self.mu; | ||
let exp_term = nalgebra::convert::<f64, Real>(-0.5) | ||
* *(&dv.transpose() * &self.precision * &dv) | ||
.get((0, 0)) | ||
.unwrap(); | ||
self.pdf_const * exp_term.exp() | ||
} | ||
/// Calculates the log probability density function for the multivariate | ||
/// normal distribution at `x` | ||
/// | ||
/// # Formula | ||
/// | ||
/// ```ignore | ||
/// ln((2 * π) ^ (-k / 2) * det(Σ) ^ (1 / 2) * e ^ ( -(1 / 2) * transpose(x - μ) * inv(Σ) * (x - μ))) | ||
/// ``` | ||
/// | ||
/// where `μ` is the mean, `inv(Σ)` is the precision matrix, `det(Σ)` is the determinant | ||
/// of the covariance matrix, and `k` is the dimension of the distribution | ||
fn ln_pdf(&self, x: VectorN<Real, N>) -> Real { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can this be expressed in terms of |
||
let dv = x - &self.mu; | ||
let exp_term = nalgebra::convert::<f64, Real>(-0.5) | ||
* *(&dv.transpose() * &self.precision * &dv) | ||
.get((0, 0)) | ||
.unwrap(); | ||
self.pdf_const.ln() + exp_term | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do we still need the explicit dependency on
num-traits
?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nope, removed it.