Skip to content

Commit

Permalink
Merge Triangular_ into Lapack
Browse files Browse the repository at this point in the history
  • Loading branch information
termoshtt committed Oct 3, 2022
1 parent 120fb07 commit 27c7fa1
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 12 deletions.
26 changes: 23 additions & 3 deletions lax/src/lib.rs
Expand Up @@ -100,17 +100,16 @@ pub mod solve;
pub mod solveh;
pub mod svd;
pub mod svddc;
pub mod triangular;

mod alloc;
mod triangular;
mod tridiagonal;

pub use self::cholesky::*;
pub use self::flags::*;
pub use self::least_squares::LeastSquaresOwned;
pub use self::opnorm::*;
pub use self::svd::{SvdOwned, SvdRef};
pub use self::triangular::*;
pub use self::tridiagonal::*;

use self::{alloc::*, error::*, layout::*};
Expand All @@ -121,7 +120,7 @@ pub type Pivot = Vec<i32>;

#[cfg_attr(doc, katexit::katexit)]
/// Trait for primitive types which implements LAPACK subroutines
pub trait Lapack: Triangular_ + Tridiagonal_ {
pub trait Lapack: Tridiagonal_ {
/// Compute right eigenvalue and eigenvectors for a general matrix
fn eig(
calc_v: bool,
Expand Down Expand Up @@ -298,6 +297,15 @@ pub trait Lapack: Triangular_ + Tridiagonal_ {
/// $$
///
fn opnorm(t: NormType, l: MatrixLayout, a: &[Self]) -> Self::Real;

fn solve_triangular(
al: MatrixLayout,
bl: MatrixLayout,
uplo: UPLO,
d: Diag,
a: &[Self],
b: &mut [Self],
) -> Result<()>;
}

macro_rules! impl_lapack {
Expand Down Expand Up @@ -471,6 +479,18 @@ macro_rules! impl_lapack {
let mut work = OperatorNormWork::<$s>::new(t, l);
work.calc(a)
}

fn solve_triangular(
al: MatrixLayout,
bl: MatrixLayout,
uplo: UPLO,
d: Diag,
a: &[Self],
b: &mut [Self],
) -> Result<()> {
use triangular::*;
SolveTriangularImpl::solve_triangular(al, bl, uplo, d, a, b)
}
}
};
}
Expand Down
26 changes: 17 additions & 9 deletions lax/src/triangular.rs
@@ -1,10 +1,18 @@
//! Implement linear solver and inverse matrix
//! Linear problem for triangular matrices

use crate::{error::*, layout::*, *};
use cauchy::*;

/// Wraps `*trtri` and `*trtrs`
pub trait Triangular_: Scalar {
/// Solve linear problem for triangular matrices
///
/// LAPACK correspondance
/// ----------------------
///
/// | f32 | f64 | c32 | c64 |
/// |:-------|:-------|:-------|:-------|
/// | strtrs | dtrtrs | ctrtrs | ztrtrs |
///
pub trait SolveTriangularImpl: Scalar {
fn solve_triangular(
al: MatrixLayout,
bl: MatrixLayout,
Expand All @@ -16,8 +24,8 @@ pub trait Triangular_: Scalar {
}

macro_rules! impl_triangular {
($scalar:ty, $trtri:path, $trtrs:path) => {
impl Triangular_ for $scalar {
($scalar:ty, $trtrs:path) => {
impl SolveTriangularImpl for $scalar {
fn solve_triangular(
a_layout: MatrixLayout,
b_layout: MatrixLayout,
Expand Down Expand Up @@ -79,7 +87,7 @@ macro_rules! impl_triangular {
};
} // impl_triangular!

impl_triangular!(f64, lapack_sys::dtrtri_, lapack_sys::dtrtrs_);
impl_triangular!(f32, lapack_sys::strtri_, lapack_sys::strtrs_);
impl_triangular!(c64, lapack_sys::ztrtri_, lapack_sys::ztrtrs_);
impl_triangular!(c32, lapack_sys::ctrtri_, lapack_sys::ctrtrs_);
impl_triangular!(f64, lapack_sys::dtrtrs_);
impl_triangular!(f32, lapack_sys::strtrs_);
impl_triangular!(c64, lapack_sys::ztrtrs_);
impl_triangular!(c32, lapack_sys::ctrtrs_);

0 comments on commit 27c7fa1

Please sign in to comment.