Skip to content

Commit

Permalink
Merge pull request #346 from rust-ndarray/lax-solve-impl
Browse files Browse the repository at this point in the history
Merge `Solve_`, `Solveh_` and `Cholesky_` into `Lapack` trait
  • Loading branch information
termoshtt committed Oct 1, 2022
2 parents 07ab31d + 7e61539 commit acd7858
Show file tree
Hide file tree
Showing 4 changed files with 494 additions and 341 deletions.
135 changes: 63 additions & 72 deletions lax/src/cholesky.rs
@@ -1,59 +1,25 @@
//! Factorize positive-definite symmetric/Hermitian matrices using Cholesky algorithm

use super::*;
use crate::{error::*, layout::*};
use cauchy::*;

#[cfg_attr(doc, katexit::katexit)]
/// Solve symmetric/hermite positive-definite linear equations using Cholesky decomposition
///
/// For a given positive definite matrix $A$,
/// Cholesky decomposition is described as $A = U^T U$ or $A = LL^T$ where
/// Compute Cholesky decomposition according to [UPLO]
///
/// - $L$ is lower matrix
/// - $U$ is upper matrix
/// LAPACK correspondance
/// ----------------------
///
/// This is designed as two step computation according to LAPACK API
/// | f32 | f64 | c32 | c64 |
/// |:-------|:-------|:-------|:-------|
/// | spotrf | dpotrf | cpotrf | zpotrf |
///
/// 1. Factorize input matrix $A$ into $L$ or $U$
/// 2. Solve linear equation $Ax = b$ or compute inverse matrix $A^{-1}$
/// using $U$ or $L$.
pub trait Cholesky_: Sized {
/// Compute Cholesky decomposition $A = U^T U$ or $A = L L^T$ according to [UPLO]
///
/// LAPACK correspondance
/// ----------------------
///
/// | f32 | f64 | c32 | c64 |
/// |:-------|:-------|:-------|:-------|
/// | spotrf | dpotrf | cpotrf | zpotrf |
///
pub trait CholeskyImpl: Scalar {
fn cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()>;

/// Compute inverse matrix $A^{-1}$ using $U$ or $L$
///
/// LAPACK correspondance
/// ----------------------
///
/// | f32 | f64 | c32 | c64 |
/// |:-------|:-------|:-------|:-------|
/// | spotri | dpotri | cpotri | zpotri |
///
fn inv_cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()>;

/// Solve linear equation $Ax = b$ using $U$ or $L$
///
/// LAPACK correspondance
/// ----------------------
///
/// | f32 | f64 | c32 | c64 |
/// |:-------|:-------|:-------|:-------|
/// | spotrs | dpotrs | cpotrs | zpotrs |
///
fn solve_cholesky(l: MatrixLayout, uplo: UPLO, a: &[Self], b: &mut [Self]) -> Result<()>;
}

macro_rules! impl_cholesky {
($scalar:ty, $trf:path, $tri:path, $trs:path) => {
impl Cholesky_ for $scalar {
macro_rules! impl_cholesky_ {
($s:ty, $trf:path) => {
impl CholeskyImpl for $s {
fn cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()> {
let (n, _) = l.size();
if matches!(l, MatrixLayout::C { .. }) {
Expand All @@ -69,7 +35,30 @@ macro_rules! impl_cholesky {
}
Ok(())
}
}
};
}
impl_cholesky_!(c64, lapack_sys::zpotrf_);
impl_cholesky_!(c32, lapack_sys::cpotrf_);
impl_cholesky_!(f64, lapack_sys::dpotrf_);
impl_cholesky_!(f32, lapack_sys::spotrf_);

/// Compute inverse matrix using Cholesky factroization result
///
/// LAPACK correspondance
/// ----------------------
///
/// | f32 | f64 | c32 | c64 |
/// |:-------|:-------|:-------|:-------|
/// | spotri | dpotri | cpotri | zpotri |
///
pub trait InvCholeskyImpl: Scalar {
fn inv_cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()>;
}

macro_rules! impl_inv_cholesky {
($s:ty, $tri:path) => {
impl InvCholeskyImpl for $s {
fn inv_cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()> {
let (n, _) = l.size();
if matches!(l, MatrixLayout::C { .. }) {
Expand All @@ -85,7 +74,30 @@ macro_rules! impl_cholesky {
}
Ok(())
}
}
};
}
impl_inv_cholesky!(c64, lapack_sys::zpotri_);
impl_inv_cholesky!(c32, lapack_sys::cpotri_);
impl_inv_cholesky!(f64, lapack_sys::dpotri_);
impl_inv_cholesky!(f32, lapack_sys::spotri_);

/// Solve linear equation using Cholesky factroization result
///
/// LAPACK correspondance
/// ----------------------
///
/// | f32 | f64 | c32 | c64 |
/// |:-------|:-------|:-------|:-------|
/// | spotrs | dpotrs | cpotrs | zpotrs |
///
pub trait SolveCholeskyImpl: Scalar {
fn solve_cholesky(l: MatrixLayout, uplo: UPLO, a: &[Self], b: &mut [Self]) -> Result<()>;
}

macro_rules! impl_solve_cholesky {
($s:ty, $trs:path) => {
impl SolveCholeskyImpl for $s {
fn solve_cholesky(
l: MatrixLayout,
mut uplo: UPLO,
Expand Down Expand Up @@ -123,29 +135,8 @@ macro_rules! impl_cholesky {
}
}
};
} // end macro_rules

impl_cholesky!(
f64,
lapack_sys::dpotrf_,
lapack_sys::dpotri_,
lapack_sys::dpotrs_
);
impl_cholesky!(
f32,
lapack_sys::spotrf_,
lapack_sys::spotri_,
lapack_sys::spotrs_
);
impl_cholesky!(
c64,
lapack_sys::zpotrf_,
lapack_sys::zpotri_,
lapack_sys::zpotrs_
);
impl_cholesky!(
c32,
lapack_sys::cpotrf_,
lapack_sys::cpotri_,
lapack_sys::cpotrs_
);
}
impl_solve_cholesky!(c64, lapack_sys::zpotrs_);
impl_solve_cholesky!(c32, lapack_sys::cpotrs_);
impl_solve_cholesky!(f64, lapack_sys::dpotrs_);
impl_solve_cholesky!(f32, lapack_sys::spotrs_);

0 comments on commit acd7858

Please sign in to comment.