Multivariate normal distribution in ndarray-rand #582

Open
ManifoldFR opened this issue Jan 14, 2019 · 2 comments · May be fixed by #583

Comments

@ManifoldFR

Implementing multivariate normal distributions involves a bit of boilerplate, and possibly the use of ndarray-linalg to perform a Cholesky decomposition. Would it be worthwhile to implement this in the crate itself? I made a fork and started writing some code.

@jturner314
Member

Fwiw, I've been using this for a while in my own code:

use failure::{Context, ResultExt};
use ndarray::{Data, DataClone, DataOwned, OwnedRepr, ViewRepr};
use ndarray::prelude::*;
use ndarray_linalg::cholesky::{CholeskyInto, UPLO};
use ndarray_rand::RandomExt;
use rand::distributions::{Distribution, Normal};
use rand::Rng;
use std::clone::Clone;
use std::fmt::{self, Debug};
use std::ops::AddAssign;

// ...

/// Multivariate Gaussian distribution.
#[derive(PartialEq, Deserialize, Serialize)]
#[serde(bound(deserialize = "S: DataOwned, S::Elem: ::serde::Deserialize<'de>"))]
#[serde(bound(serialize = "S: Data, S::Elem: ::serde::Serialize"))]
pub struct GaussianDistroBase<S>
where
    S: Data<Elem = f64>,
{
    pub mean: ArrayBase<S, Ix1>,
    pub covariance: ArrayBase<S, Ix2>,
}

pub type GaussianDistro = GaussianDistroBase<OwnedRepr<f64>>;

pub type GaussianDistroView<'a> = GaussianDistroBase<ViewRepr<&'a f64>>;

impl<S> GaussianDistroBase<S>
where
    S: Data<Elem = f64>,
{
    pub fn len(&self) -> usize {
        assert_eq!(self.mean.len(), self.covariance.len_of(Axis(0)));
        assert_eq!(self.mean.len(), self.covariance.len_of(Axis(1)));
        self.mean.len()
    }

    pub fn to_owned(&self) -> GaussianDistro {
        GaussianDistro {
            mean: self.mean.to_owned(),
            covariance: self.covariance.to_owned(),
        }
    }

    pub fn view(&self) -> GaussianDistroView {
        GaussianDistroView {
            mean: self.mean.view(),
            covariance: self.covariance.view(),
        }
    }
}

impl<S> Clone for GaussianDistroBase<S>
where
    S: DataClone<Elem = f64>,
{
    fn clone(&self) -> Self {
        GaussianDistroBase {
            mean: self.mean.clone(),
            covariance: self.covariance.clone(),
        }
    }
}

impl<S> Debug for GaussianDistroBase<S>
where
    S: Data<Elem = f64>,
{
    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
        fmt.debug_struct("GaussianDistroBase")
            .field("mean", &self.mean)
            .field("covariance", &self.covariance)
            .finish()
    }
}

#[derive(Debug, Fail)]
#[fail(display = "error sampling from multivariate normal distribution: {}", _0)]
pub struct GaussianSampleError(Context<String>);

impl From<Context<String>> for GaussianSampleError {
    fn from(context: Context<String>) -> GaussianSampleError {
        GaussianSampleError(context)
    }
}

impl<S> Distribution<Result<Array1<f64>, GaussianSampleError>> for GaussianDistroBase<S>
where
    S: Data<Elem = f64>,
{
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Result<Array1<f64>, GaussianSampleError> {
        let mut cov = self.covariance.to_owned();
        // Add a small multiple of I for numerical reasons.
        cov.diag_mut().add_assign(1e2 * ::std::f64::EPSILON);
        let chol = cov.cholesky_into(UPLO::Lower)
            .context("error factoring covariance".into())?;
        Ok(chol.dot(&Array1::random_using(self.len(), Normal::new(0., 1.), rng)) + &self.mean)
    }
}

// ...

It's probably a bit more complex than what you're looking for (since GaussianDistroBase is generic over storage S), but it could be simplified.
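
As a rough illustration of what a simplified version boils down to, here is a dependency-free sketch of the underlying technique: factor the covariance as Σ = L Lᵀ with a lower-triangular L, then map a vector z of independent standard-normal draws to x = mean + L z. The function names `cholesky_lower` and `sample_from_standard_normal` are hypothetical and are not part of ndarray, ndarray-linalg, or ndarray-rand. To avoid external crates, the draws z are passed in by the caller rather than generated, and the small diagonal regularization used in the code above is omitted.

```rust
/// Lower-triangular Cholesky factor `l` of a symmetric positive-definite
/// matrix `a` (n x n, one Vec per row), so that a = l * l^T.
/// Returns None if `a` is not square or not positive definite.
fn cholesky_lower(a: &[Vec<f64>]) -> Option<Vec<Vec<f64>>> {
    let n = a.len();
    if a.iter().any(|row| row.len() != n) {
        return None;
    }
    let mut l = vec![vec![0.0; n]; n];
    for i in 0..n {
        for j in 0..=i {
            // Dot product of the already-computed parts of rows i and j.
            let s: f64 = (0..j).map(|k| l[i][k] * l[j][k]).sum();
            if i == j {
                let d = a[i][i] - s;
                if d.is_nan() || d <= 0.0 {
                    return None; // not positive definite
                }
                l[i][j] = d.sqrt();
            } else {
                l[i][j] = (a[i][j] - s) / l[j][j];
            }
        }
    }
    Some(l)
}

/// Maps independent standard-normal draws `z` to a sample from
/// N(mean, cov) via x = mean + L z, where cov = L L^T.
/// Returns None on a dimension mismatch or a failed factorization.
fn sample_from_standard_normal(
    mean: &[f64],
    cov: &[Vec<f64>],
    z: &[f64],
) -> Option<Vec<f64>> {
    let n = mean.len();
    if cov.len() != n || z.len() != n {
        return None;
    }
    let l = cholesky_lower(cov)?;
    // L is lower triangular, so row i only touches z[0..=i].
    Some(
        (0..n)
            .map(|i| mean[i] + (0..=i).map(|k| l[i][k] * z[k]).sum::<f64>())
            .collect(),
    )
}
```

In real use, `z` would come from a standard-normal sampler, and the factorization would be delegated to a LAPACK-backed routine as in the code above. Factoring once and reusing L across samples would also avoid repeating the decomposition on every draw.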

For the purpose of ndarray-rand, adding a dependency on ndarray-linalg seems unfortunate because ndarray-linalg requires non-Rust code (the LAPACK implementation). I suppose it would be fine if we put the functionality behind a feature flag. What do you think @bluss?
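
A minimal sketch of how such gating could look on the Rust side, assuming a hypothetical Cargo feature named `multivariate` (not an existing ndarray-rand feature) that would enable an optional ndarray-linalg dependency in `Cargo.toml`. The module and function names here are illustrative only.

```rust
// Compiled only when the (hypothetical) "multivariate" feature is enabled.
// Code that depends on ndarray-linalg, and therefore on LAPACK, would live
// in this module so that default builds stay pure Rust.
#[cfg(feature = "multivariate")]
pub mod multivariate {
    pub fn describe() -> &'static str {
        "multivariate normal support enabled"
    }
}

/// Reports whether the optional module was compiled in.
/// `cfg!` is evaluated at compile time.
pub fn multivariate_enabled() -> bool {
    cfg!(feature = "multivariate")
}
```

With this layout, users who do not enable the feature never need to build or link a LAPACK implementation.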

@ManifoldFR
Author

ManifoldFR commented Jan 14, 2019

Yes, a feature flag seems appropriate. I'm using the same approach in an MCMC algorithms crate I'm working on (to add support for multivariate distributions using ndarray).

Your code seems a bit overkill for what I had in mind: I was only planning to implement this for the OwnedRepr<f64> data type at first, but I'll look into making it a bit more generic!

I'll make a pull request so you can check my work.
