From 27c7fa11230424c59c61d45b9957e89afe1e4032 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Mon, 3 Oct 2022 21:56:16 +0900 Subject: [PATCH] Merge Triangular_ into Lapack --- lax/src/lib.rs | 26 +++++++++++++++++++++++--- lax/src/triangular.rs | 26 +++++++++++++++++--------- 2 files changed, 40 insertions(+), 12 deletions(-) diff --git a/lax/src/lib.rs b/lax/src/lib.rs index 4989fe15..4e886ea4 100644 --- a/lax/src/lib.rs +++ b/lax/src/lib.rs @@ -100,9 +100,9 @@ pub mod solve; pub mod solveh; pub mod svd; pub mod svddc; +pub mod triangular; mod alloc; -mod triangular; mod tridiagonal; pub use self::cholesky::*; @@ -110,7 +110,6 @@ pub use self::flags::*; pub use self::least_squares::LeastSquaresOwned; pub use self::opnorm::*; pub use self::svd::{SvdOwned, SvdRef}; -pub use self::triangular::*; pub use self::tridiagonal::*; use self::{alloc::*, error::*, layout::*}; @@ -121,7 +120,7 @@ pub type Pivot = Vec; #[cfg_attr(doc, katexit::katexit)] /// Trait for primitive types which implements LAPACK subroutines -pub trait Lapack: Triangular_ + Tridiagonal_ { +pub trait Lapack: Tridiagonal_ { /// Compute right eigenvalue and eigenvectors for a general matrix fn eig( calc_v: bool, @@ -298,6 +297,15 @@ pub trait Lapack: Triangular_ + Tridiagonal_ { /// $$ /// fn opnorm(t: NormType, l: MatrixLayout, a: &[Self]) -> Self::Real; + + fn solve_triangular( + al: MatrixLayout, + bl: MatrixLayout, + uplo: UPLO, + d: Diag, + a: &[Self], + b: &mut [Self], + ) -> Result<()>; } macro_rules! impl_lapack { @@ -471,6 +479,18 @@ macro_rules! impl_lapack { let mut work = OperatorNormWork::<$s>::new(t, l); work.calc(a) } + + fn solve_triangular( + al: MatrixLayout, + bl: MatrixLayout, + uplo: UPLO, + d: Diag, + a: &[Self], + b: &mut [Self], + ) -> Result<()> { + use triangular::*; + SolveTriangularImpl::solve_triangular(al, bl, uplo, d, a, b) + } } }; } diff --git a/lax/src/triangular.rs b/lax/src/triangular.rs index 14f29807..da4dc4cf 100644 --- a/lax/src/triangular.rs +++ b/lax/src/triangular.rs @@ -1,10 +1,18 @@ -//! Implement linear solver and inverse matrix +//! Linear problem for triangular matrices use crate::{error::*, layout::*, *}; use cauchy::*; -/// Wraps `*trtri` and `*trtrs` -pub trait Triangular_: Scalar { +/// Solve linear problem for triangular matrices +/// +/// LAPACK correspondance +/// ---------------------- +/// +/// | f32 | f64 | c32 | c64 | +/// |:-------|:-------|:-------|:-------| +/// | strtrs | dtrtrs | ctrtrs | ztrtrs | +/// +pub trait SolveTriangularImpl: Scalar { fn solve_triangular( al: MatrixLayout, bl: MatrixLayout, @@ -16,8 +24,8 @@ pub trait Triangular_: Scalar { } macro_rules! impl_triangular { - ($scalar:ty, $trtri:path, $trtrs:path) => { - impl Triangular_ for $scalar { + ($scalar:ty, $trtrs:path) => { + impl SolveTriangularImpl for $scalar { fn solve_triangular( a_layout: MatrixLayout, b_layout: MatrixLayout, @@ -79,7 +87,7 @@ macro_rules! impl_triangular { }; } // impl_triangular! -impl_triangular!(f64, lapack_sys::dtrtri_, lapack_sys::dtrtrs_); -impl_triangular!(f32, lapack_sys::strtri_, lapack_sys::strtrs_); -impl_triangular!(c64, lapack_sys::ztrtri_, lapack_sys::ztrtrs_); -impl_triangular!(c32, lapack_sys::ctrtri_, lapack_sys::ctrtrs_); +impl_triangular!(f64, lapack_sys::dtrtrs_); +impl_triangular!(f32, lapack_sys::strtrs_); +impl_triangular!(c64, lapack_sys::ztrtrs_); +impl_triangular!(c32, lapack_sys::ctrtrs_);