Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merge Solve_, Solveh_ and Cholesky_ into Lapack trait #346

Merged
merged 15 commits into from Oct 1, 2022
135 changes: 63 additions & 72 deletions lax/src/cholesky.rs
@@ -1,59 +1,25 @@
//! Factorize positive-definite symmetric/Hermitian matrices using Cholesky algorithm

use super::*;
use crate::{error::*, layout::*};
use cauchy::*;

#[cfg_attr(doc, katexit::katexit)]
/// Solve symmetric/hermite positive-definite linear equations using Cholesky decomposition
///
/// For a given positive definite matrix $A$,
/// Cholesky decomposition is described as $A = U^T U$ or $A = LL^T$ where
/// Compute Cholesky decomposition according to [UPLO]
///
/// - $L$ is lower matrix
/// - $U$ is upper matrix
/// LAPACK correspondance
/// ----------------------
///
/// This is designed as two step computation according to LAPACK API
/// | f32 | f64 | c32 | c64 |
/// |:-------|:-------|:-------|:-------|
/// | spotrf | dpotrf | cpotrf | zpotrf |
///
/// 1. Factorize input matrix $A$ into $L$ or $U$
/// 2. Solve linear equation $Ax = b$ or compute inverse matrix $A^{-1}$
/// using $U$ or $L$.
pub trait Cholesky_: Sized {
/// Compute Cholesky decomposition $A = U^T U$ or $A = L L^T$ according to [UPLO]
///
/// LAPACK correspondance
/// ----------------------
///
/// | f32 | f64 | c32 | c64 |
/// |:-------|:-------|:-------|:-------|
/// | spotrf | dpotrf | cpotrf | zpotrf |
///
pub trait CholeskyImpl: Scalar {
fn cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()>;

/// Compute inverse matrix $A^{-1}$ using $U$ or $L$
///
/// LAPACK correspondance
/// ----------------------
///
/// | f32 | f64 | c32 | c64 |
/// |:-------|:-------|:-------|:-------|
/// | spotri | dpotri | cpotri | zpotri |
///
fn inv_cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()>;

/// Solve linear equation $Ax = b$ using $U$ or $L$
///
/// LAPACK correspondance
/// ----------------------
///
/// | f32 | f64 | c32 | c64 |
/// |:-------|:-------|:-------|:-------|
/// | spotrs | dpotrs | cpotrs | zpotrs |
///
fn solve_cholesky(l: MatrixLayout, uplo: UPLO, a: &[Self], b: &mut [Self]) -> Result<()>;
}

macro_rules! impl_cholesky {
($scalar:ty, $trf:path, $tri:path, $trs:path) => {
impl Cholesky_ for $scalar {
macro_rules! impl_cholesky_ {
($s:ty, $trf:path) => {
impl CholeskyImpl for $s {
fn cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()> {
let (n, _) = l.size();
if matches!(l, MatrixLayout::C { .. }) {
Expand All @@ -69,7 +35,30 @@ macro_rules! impl_cholesky {
}
Ok(())
}
}
};
}
impl_cholesky_!(c64, lapack_sys::zpotrf_);
impl_cholesky_!(c32, lapack_sys::cpotrf_);
impl_cholesky_!(f64, lapack_sys::dpotrf_);
impl_cholesky_!(f32, lapack_sys::spotrf_);

/// Compute inverse matrix using Cholesky factroization result
///
/// LAPACK correspondance
/// ----------------------
///
/// | f32 | f64 | c32 | c64 |
/// |:-------|:-------|:-------|:-------|
/// | spotri | dpotri | cpotri | zpotri |
///
pub trait InvCholeskyImpl: Scalar {
fn inv_cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()>;
}

macro_rules! impl_inv_cholesky {
($s:ty, $tri:path) => {
impl InvCholeskyImpl for $s {
fn inv_cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()> {
let (n, _) = l.size();
if matches!(l, MatrixLayout::C { .. }) {
Expand All @@ -85,7 +74,30 @@ macro_rules! impl_cholesky {
}
Ok(())
}
}
};
}
impl_inv_cholesky!(c64, lapack_sys::zpotri_);
impl_inv_cholesky!(c32, lapack_sys::cpotri_);
impl_inv_cholesky!(f64, lapack_sys::dpotri_);
impl_inv_cholesky!(f32, lapack_sys::spotri_);

/// Solve linear equation using Cholesky factroization result
///
/// LAPACK correspondance
/// ----------------------
///
/// | f32 | f64 | c32 | c64 |
/// |:-------|:-------|:-------|:-------|
/// | spotrs | dpotrs | cpotrs | zpotrs |
///
pub trait SolveCholeskyImpl: Scalar {
fn solve_cholesky(l: MatrixLayout, uplo: UPLO, a: &[Self], b: &mut [Self]) -> Result<()>;
}

macro_rules! impl_solve_cholesky {
($s:ty, $trs:path) => {
impl SolveCholeskyImpl for $s {
fn solve_cholesky(
l: MatrixLayout,
mut uplo: UPLO,
Expand Down Expand Up @@ -123,29 +135,8 @@ macro_rules! impl_cholesky {
}
}
};
} // end macro_rules

impl_cholesky!(
f64,
lapack_sys::dpotrf_,
lapack_sys::dpotri_,
lapack_sys::dpotrs_
);
impl_cholesky!(
f32,
lapack_sys::spotrf_,
lapack_sys::spotri_,
lapack_sys::spotrs_
);
impl_cholesky!(
c64,
lapack_sys::zpotrf_,
lapack_sys::zpotri_,
lapack_sys::zpotrs_
);
impl_cholesky!(
c32,
lapack_sys::cpotrf_,
lapack_sys::cpotri_,
lapack_sys::cpotrs_
);
}
impl_solve_cholesky!(c64, lapack_sys::zpotrs_);
impl_solve_cholesky!(c32, lapack_sys::cpotrs_);
impl_solve_cholesky!(f64, lapack_sys::dpotrs_);
impl_solve_cholesky!(f32, lapack_sys::spotrs_);