diff --git a/lax/src/lib.rs b/lax/src/lib.rs index 4e886ea4..af48f257 100644 --- a/lax/src/lib.rs +++ b/lax/src/lib.rs @@ -84,14 +84,14 @@ extern crate openblas_src as _src; #[cfg(any(feature = "netlib-system", feature = "netlib-static"))] extern crate netlib_src as _src; -pub mod error; -pub mod flags; -pub mod layout; - +pub mod alloc; pub mod cholesky; pub mod eig; pub mod eigh; pub mod eigh_generalized; +pub mod error; +pub mod flags; +pub mod layout; pub mod least_squares; pub mod opnorm; pub mod qr; @@ -101,16 +101,12 @@ pub mod solveh; pub mod svd; pub mod svddc; pub mod triangular; +pub mod tridiagonal; -mod alloc; -mod tridiagonal; - -pub use self::cholesky::*; pub use self::flags::*; pub use self::least_squares::LeastSquaresOwned; -pub use self::opnorm::*; pub use self::svd::{SvdOwned, SvdRef}; -pub use self::tridiagonal::*; +pub use self::tridiagonal::{LUFactorizedTridiagonal, Tridiagonal}; use self::{alloc::*, error::*, layout::*}; use cauchy::*; @@ -120,7 +116,7 @@ pub type Pivot = Vec; #[cfg_attr(doc, katexit::katexit)] /// Trait for primitive types which implements LAPACK subroutines -pub trait Lapack: Tridiagonal_ { +pub trait Lapack: Scalar { /// Compute right eigenvalue and eigenvectors for a general matrix fn eig( calc_v: bool, @@ -306,6 +302,19 @@ pub trait Lapack: Tridiagonal_ { a: &[Self], b: &mut [Self], ) -> Result<()>; + + /// Computes the LU factorization of a tridiagonal `m x n` matrix `a` using + /// partial pivoting with row interchanges. + fn lu_tridiagonal(a: Tridiagonal) -> Result>; + + fn rcond_tridiagonal(lu: &LUFactorizedTridiagonal) -> Result; + + fn solve_tridiagonal( + lu: &LUFactorizedTridiagonal, + bl: MatrixLayout, + t: Transpose, + b: &mut [Self], + ) -> Result<()>; } macro_rules! impl_lapack { @@ -491,6 +500,28 @@ macro_rules! impl_lapack { use triangular::*; SolveTriangularImpl::solve_triangular(al, bl, uplo, d, a, b) } + + fn lu_tridiagonal(a: Tridiagonal) -> Result> { + use tridiagonal::*; + let work = LuTridiagonalWork::<$s>::new(a.l); + work.eval(a) + } + + fn rcond_tridiagonal(lu: &LUFactorizedTridiagonal) -> Result { + use tridiagonal::*; + let mut work = RcondTridiagonalWork::<$s>::new(lu.a.l); + work.calc(lu) + } + + fn solve_tridiagonal( + lu: &LUFactorizedTridiagonal, + bl: MatrixLayout, + t: Transpose, + b: &mut [Self], + ) -> Result<()> { + use tridiagonal::*; + SolveTridiagonalImpl::solve_tridiagonal(lu, bl, t, b) + } } }; } diff --git a/lax/src/tridiagonal.rs b/lax/src/tridiagonal.rs deleted file mode 100644 index 3d28c63a..00000000 --- a/lax/src/tridiagonal.rs +++ /dev/null @@ -1,259 +0,0 @@ -//! Implement linear solver using LU decomposition -//! for tridiagonal matrix - -use crate::{error::*, layout::*, *}; -use cauchy::*; -use num_traits::Zero; -use std::ops::{Index, IndexMut}; - -/// Represents a tridiagonal matrix as 3 one-dimensional vectors. -/// -/// ```text -/// [d0, u1, 0, ..., 0, -/// l1, d1, u2, ..., -/// 0, l2, d2, -/// ... ..., u{n-1}, -/// 0, ..., l{n-1}, d{n-1},] -/// ``` -#[derive(Clone, PartialEq, Eq)] -pub struct Tridiagonal { - /// layout of raw matrix - pub l: MatrixLayout, - /// (n-1) sub-diagonal elements of matrix. - pub dl: Vec, - /// (n) diagonal elements of matrix. - pub d: Vec, - /// (n-1) super-diagonal elements of matrix. - pub du: Vec, -} - -impl Tridiagonal { - fn opnorm_one(&self) -> A::Real { - let mut col_sum: Vec = self.d.iter().map(|val| val.abs()).collect(); - for i in 0..col_sum.len() { - if i < self.dl.len() { - col_sum[i] += self.dl[i].abs(); - } - if i > 0 { - col_sum[i] += self.du[i - 1].abs(); - } - } - let mut max = A::Real::zero(); - for &val in &col_sum { - if max < val { - max = val; - } - } - max - } -} - -/// Represents the LU factorization of a tridiagonal matrix `A` as `A = P*L*U`. -#[derive(Clone, PartialEq)] -pub struct LUFactorizedTridiagonal { - /// A tridiagonal matrix which consists of - /// - l : layout of raw matrix - /// - dl: (n-1) multipliers that define the matrix L. - /// - d : (n) diagonal elements of the upper triangular matrix U. - /// - du: (n-1) elements of the first super-diagonal of U. - pub a: Tridiagonal, - /// (n-2) elements of the second super-diagonal of U. - pub du2: Vec, - /// The pivot indices that define the permutation matrix `P`. - pub ipiv: Pivot, - - a_opnorm_one: A::Real, -} - -impl Index<(i32, i32)> for Tridiagonal { - type Output = A; - #[inline] - fn index(&self, (row, col): (i32, i32)) -> &A { - let (n, _) = self.l.size(); - assert!( - std::cmp::max(row, col) < n, - "ndarray: index {:?} is out of bounds for array of shape {}", - [row, col], - n - ); - match row - col { - 0 => &self.d[row as usize], - 1 => &self.dl[col as usize], - -1 => &self.du[row as usize], - _ => panic!( - "ndarray-linalg::tridiagonal: index {:?} is not tridiagonal element", - [row, col] - ), - } - } -} - -impl Index<[i32; 2]> for Tridiagonal { - type Output = A; - #[inline] - fn index(&self, [row, col]: [i32; 2]) -> &A { - &self[(row, col)] - } -} - -impl IndexMut<(i32, i32)> for Tridiagonal { - #[inline] - fn index_mut(&mut self, (row, col): (i32, i32)) -> &mut A { - let (n, _) = self.l.size(); - assert!( - std::cmp::max(row, col) < n, - "ndarray: index {:?} is out of bounds for array of shape {}", - [row, col], - n - ); - match row - col { - 0 => &mut self.d[row as usize], - 1 => &mut self.dl[col as usize], - -1 => &mut self.du[row as usize], - _ => panic!( - "ndarray-linalg::tridiagonal: index {:?} is not tridiagonal element", - [row, col] - ), - } - } -} - -impl IndexMut<[i32; 2]> for Tridiagonal { - #[inline] - fn index_mut(&mut self, [row, col]: [i32; 2]) -> &mut A { - &mut self[(row, col)] - } -} - -/// Wraps `*gttrf`, `*gtcon` and `*gttrs` -pub trait Tridiagonal_: Scalar + Sized { - /// Computes the LU factorization of a tridiagonal `m x n` matrix `a` using - /// partial pivoting with row interchanges. - fn lu_tridiagonal(a: Tridiagonal) -> Result>; - - fn rcond_tridiagonal(lu: &LUFactorizedTridiagonal) -> Result; - - fn solve_tridiagonal( - lu: &LUFactorizedTridiagonal, - bl: MatrixLayout, - t: Transpose, - b: &mut [Self], - ) -> Result<()>; -} - -macro_rules! impl_tridiagonal { - (@real, $scalar:ty, $gttrf:path, $gtcon:path, $gttrs:path) => { - impl_tridiagonal!(@body, $scalar, $gttrf, $gtcon, $gttrs, iwork); - }; - (@complex, $scalar:ty, $gttrf:path, $gtcon:path, $gttrs:path) => { - impl_tridiagonal!(@body, $scalar, $gttrf, $gtcon, $gttrs, ); - }; - (@body, $scalar:ty, $gttrf:path, $gtcon:path, $gttrs:path, $($iwork:ident)*) => { - impl Tridiagonal_ for $scalar { - fn lu_tridiagonal(mut a: Tridiagonal) -> Result> { - let (n, _) = a.l.size(); - let mut du2 = vec_uninit( (n - 2) as usize); - let mut ipiv = vec_uninit( n as usize); - // We have to calc one-norm before LU factorization - let a_opnorm_one = a.opnorm_one(); - let mut info = 0; - unsafe { - $gttrf( - &n, - AsPtr::as_mut_ptr(&mut a.dl), - AsPtr::as_mut_ptr(&mut a.d), - AsPtr::as_mut_ptr(&mut a.du), - AsPtr::as_mut_ptr(&mut du2), - AsPtr::as_mut_ptr(&mut ipiv), - &mut info, - ) - }; - info.as_lapack_result()?; - let du2 = unsafe { du2.assume_init() }; - let ipiv = unsafe { ipiv.assume_init() }; - Ok(LUFactorizedTridiagonal { - a, - du2, - ipiv, - a_opnorm_one, - }) - } - - fn rcond_tridiagonal(lu: &LUFactorizedTridiagonal) -> Result { - let (n, _) = lu.a.l.size(); - let ipiv = &lu.ipiv; - let mut work: Vec> = vec_uninit(2 * n as usize); - $( - let mut $iwork: Vec> = vec_uninit(n as usize); - )* - let mut rcond = Self::Real::zero(); - let mut info = 0; - unsafe { - $gtcon( - NormType::One.as_ptr(), - &n, - AsPtr::as_ptr(&lu.a.dl), - AsPtr::as_ptr(&lu.a.d), - AsPtr::as_ptr(&lu.a.du), - AsPtr::as_ptr(&lu.du2), - ipiv.as_ptr(), - &lu.a_opnorm_one, - &mut rcond, - AsPtr::as_mut_ptr(&mut work), - $(AsPtr::as_mut_ptr(&mut $iwork),)* - &mut info, - ); - } - info.as_lapack_result()?; - Ok(rcond) - } - - fn solve_tridiagonal( - lu: &LUFactorizedTridiagonal, - b_layout: MatrixLayout, - t: Transpose, - b: &mut [Self], - ) -> Result<()> { - let (n, _) = lu.a.l.size(); - let ipiv = &lu.ipiv; - // Transpose if b is C-continuous - let mut b_t = None; - let b_layout = match b_layout { - MatrixLayout::C { .. } => { - let (layout, t) = transpose(b_layout, b); - b_t = Some(t); - layout - } - MatrixLayout::F { .. } => b_layout, - }; - let (ldb, nrhs) = b_layout.size(); - let mut info = 0; - unsafe { - $gttrs( - t.as_ptr(), - &n, - &nrhs, - AsPtr::as_ptr(&lu.a.dl), - AsPtr::as_ptr(&lu.a.d), - AsPtr::as_ptr(&lu.a.du), - AsPtr::as_ptr(&lu.du2), - ipiv.as_ptr(), - AsPtr::as_mut_ptr(b_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(b)), - &ldb, - &mut info, - ); - } - info.as_lapack_result()?; - if let Some(b_t) = b_t { - transpose_over(b_layout, &b_t, b); - } - Ok(()) - } - } - }; -} // impl_tridiagonal! - -impl_tridiagonal!(@real, f64, lapack_sys::dgttrf_, lapack_sys::dgtcon_, lapack_sys::dgttrs_); -impl_tridiagonal!(@real, f32, lapack_sys::sgttrf_, lapack_sys::sgtcon_, lapack_sys::sgttrs_); -impl_tridiagonal!(@complex, c64, lapack_sys::zgttrf_, lapack_sys::zgtcon_, lapack_sys::zgttrs_); -impl_tridiagonal!(@complex, c32, lapack_sys::cgttrf_, lapack_sys::cgtcon_, lapack_sys::cgttrs_); diff --git a/lax/src/tridiagonal/lu.rs b/lax/src/tridiagonal/lu.rs new file mode 100644 index 00000000..e159bec6 --- /dev/null +++ b/lax/src/tridiagonal/lu.rs @@ -0,0 +1,101 @@ +use crate::*; +use cauchy::*; +use num_traits::Zero; + +/// Represents the LU factorization of a tridiagonal matrix `A` as `A = P*L*U`. +#[derive(Clone, PartialEq)] +pub struct LUFactorizedTridiagonal { + /// A tridiagonal matrix which consists of + /// - l : layout of raw matrix + /// - dl: (n-1) multipliers that define the matrix L. + /// - d : (n) diagonal elements of the upper triangular matrix U. + /// - du: (n-1) elements of the first super-diagonal of U. + pub a: Tridiagonal, + /// (n-2) elements of the second super-diagonal of U. + pub du2: Vec, + /// The pivot indices that define the permutation matrix `P`. + pub ipiv: Pivot, + + pub a_opnorm_one: A::Real, +} + +impl Tridiagonal { + fn opnorm_one(&self) -> A::Real { + let mut col_sum: Vec = self.d.iter().map(|val| val.abs()).collect(); + for i in 0..col_sum.len() { + if i < self.dl.len() { + col_sum[i] += self.dl[i].abs(); + } + if i > 0 { + col_sum[i] += self.du[i - 1].abs(); + } + } + let mut max = A::Real::zero(); + for &val in &col_sum { + if max < val { + max = val; + } + } + max + } +} + +pub struct LuTridiagonalWork { + pub layout: MatrixLayout, + pub du2: Vec>, + pub ipiv: Vec>, +} + +pub trait LuTridiagonalWorkImpl { + type Elem: Scalar; + fn new(layout: MatrixLayout) -> Self; + fn eval(self, a: Tridiagonal) -> Result>; +} + +macro_rules! impl_lu_tridiagonal_work { + ($s:ty, $trf:path) => { + impl LuTridiagonalWorkImpl for LuTridiagonalWork<$s> { + type Elem = $s; + + fn new(layout: MatrixLayout) -> Self { + let (n, _) = layout.size(); + let du2 = vec_uninit((n - 2) as usize); + let ipiv = vec_uninit(n as usize); + LuTridiagonalWork { layout, du2, ipiv } + } + + fn eval( + mut self, + mut a: Tridiagonal, + ) -> Result> { + let (n, _) = self.layout.size(); + // We have to calc one-norm before LU factorization + let a_opnorm_one = a.opnorm_one(); + let mut info = 0; + unsafe { + $trf( + &n, + AsPtr::as_mut_ptr(&mut a.dl), + AsPtr::as_mut_ptr(&mut a.d), + AsPtr::as_mut_ptr(&mut a.du), + AsPtr::as_mut_ptr(&mut self.du2), + AsPtr::as_mut_ptr(&mut self.ipiv), + &mut info, + ) + }; + info.as_lapack_result()?; + Ok(LUFactorizedTridiagonal { + a, + du2: unsafe { self.du2.assume_init() }, + ipiv: unsafe { self.ipiv.assume_init() }, + a_opnorm_one, + }) + } + } + }; +} + +impl_lu_tridiagonal_work!(c64, lapack_sys::zgttrf_); +impl_lu_tridiagonal_work!(c32, lapack_sys::cgttrf_); +impl_lu_tridiagonal_work!(f64, lapack_sys::dgttrf_); +impl_lu_tridiagonal_work!(f32, lapack_sys::sgttrf_); diff --git a/lax/src/tridiagonal/matrix.rs b/lax/src/tridiagonal/matrix.rs new file mode 100644 index 00000000..47401430 --- /dev/null +++ b/lax/src/tridiagonal/matrix.rs @@ -0,0 +1,84 @@ +use crate::layout::*; +use cauchy::*; +use std::ops::{Index, IndexMut}; + +/// Represents a tridiagonal matrix as 3 one-dimensional vectors. +/// +/// ```text +/// [d0, u1, 0, ..., 0, +/// l1, d1, u2, ..., +/// 0, l2, d2, +/// ... ..., u{n-1}, +/// 0, ..., l{n-1}, d{n-1},] +/// ``` +#[derive(Clone, PartialEq, Eq)] +pub struct Tridiagonal { + /// layout of raw matrix + pub l: MatrixLayout, + /// (n-1) sub-diagonal elements of matrix. + pub dl: Vec, + /// (n) diagonal elements of matrix. + pub d: Vec, + /// (n-1) super-diagonal elements of matrix. + pub du: Vec, +} + +impl Index<(i32, i32)> for Tridiagonal { + type Output = A; + #[inline] + fn index(&self, (row, col): (i32, i32)) -> &A { + let (n, _) = self.l.size(); + assert!( + std::cmp::max(row, col) < n, + "ndarray: index {:?} is out of bounds for array of shape {}", + [row, col], + n + ); + match row - col { + 0 => &self.d[row as usize], + 1 => &self.dl[col as usize], + -1 => &self.du[row as usize], + _ => panic!( + "ndarray-linalg::tridiagonal: index {:?} is not tridiagonal element", + [row, col] + ), + } + } +} + +impl Index<[i32; 2]> for Tridiagonal { + type Output = A; + #[inline] + fn index(&self, [row, col]: [i32; 2]) -> &A { + &self[(row, col)] + } +} + +impl IndexMut<(i32, i32)> for Tridiagonal { + #[inline] + fn index_mut(&mut self, (row, col): (i32, i32)) -> &mut A { + let (n, _) = self.l.size(); + assert!( + std::cmp::max(row, col) < n, + "ndarray: index {:?} is out of bounds for array of shape {}", + [row, col], + n + ); + match row - col { + 0 => &mut self.d[row as usize], + 1 => &mut self.dl[col as usize], + -1 => &mut self.du[row as usize], + _ => panic!( + "ndarray-linalg::tridiagonal: index {:?} is not tridiagonal element", + [row, col] + ), + } + } +} + +impl IndexMut<[i32; 2]> for Tridiagonal { + #[inline] + fn index_mut(&mut self, [row, col]: [i32; 2]) -> &mut A { + &mut self[(row, col)] + } +} diff --git a/lax/src/tridiagonal/mod.rs b/lax/src/tridiagonal/mod.rs new file mode 100644 index 00000000..834ddd8a --- /dev/null +++ b/lax/src/tridiagonal/mod.rs @@ -0,0 +1,12 @@ +//! Implement linear solver using LU decomposition +//! for tridiagonal matrix + +mod lu; +mod matrix; +mod rcond; +mod solve; + +pub use lu::*; +pub use matrix::*; +pub use rcond::*; +pub use solve::*; diff --git a/lax/src/tridiagonal/rcond.rs b/lax/src/tridiagonal/rcond.rs new file mode 100644 index 00000000..a309cae4 --- /dev/null +++ b/lax/src/tridiagonal/rcond.rs @@ -0,0 +1,109 @@ +use crate::*; +use cauchy::*; +use num_traits::Zero; + +pub struct RcondTridiagonalWork { + pub work: Vec>, + pub iwork: Option>>, +} + +pub trait RcondTridiagonalWorkImpl { + type Elem: Scalar; + fn new(layout: MatrixLayout) -> Self; + fn calc( + &mut self, + lu: &LUFactorizedTridiagonal, + ) -> Result<::Real>; +} + +macro_rules! impl_rcond_tridiagonal_work_c { + ($c:ty, $gtcon:path) => { + impl RcondTridiagonalWorkImpl for RcondTridiagonalWork<$c> { + type Elem = $c; + + fn new(layout: MatrixLayout) -> Self { + let (n, _) = layout.size(); + let work = vec_uninit(2 * n as usize); + RcondTridiagonalWork { work, iwork: None } + } + + fn calc( + &mut self, + lu: &LUFactorizedTridiagonal, + ) -> Result<::Real> { + let (n, _) = lu.a.l.size(); + let ipiv = &lu.ipiv; + let mut rcond = ::Real::zero(); + let mut info = 0; + unsafe { + $gtcon( + NormType::One.as_ptr(), + &n, + AsPtr::as_ptr(&lu.a.dl), + AsPtr::as_ptr(&lu.a.d), + AsPtr::as_ptr(&lu.a.du), + AsPtr::as_ptr(&lu.du2), + ipiv.as_ptr(), + &lu.a_opnorm_one, + &mut rcond, + AsPtr::as_mut_ptr(&mut self.work), + &mut info, + ); + } + info.as_lapack_result()?; + Ok(rcond) + } + } + }; +} + +impl_rcond_tridiagonal_work_c!(c64, lapack_sys::zgtcon_); +impl_rcond_tridiagonal_work_c!(c32, lapack_sys::cgtcon_); + +macro_rules! impl_rcond_tridiagonal_work_r { + ($c:ty, $gtcon:path) => { + impl RcondTridiagonalWorkImpl for RcondTridiagonalWork<$c> { + type Elem = $c; + + fn new(layout: MatrixLayout) -> Self { + let (n, _) = layout.size(); + let work = vec_uninit(2 * n as usize); + let iwork = vec_uninit(n as usize); + RcondTridiagonalWork { + work, + iwork: Some(iwork), + } + } + + fn calc( + &mut self, + lu: &LUFactorizedTridiagonal, + ) -> Result<::Real> { + let (n, _) = lu.a.l.size(); + let mut rcond = ::Real::zero(); + let mut info = 0; + unsafe { + $gtcon( + NormType::One.as_ptr(), + &n, + AsPtr::as_ptr(&lu.a.dl), + AsPtr::as_ptr(&lu.a.d), + AsPtr::as_ptr(&lu.a.du), + AsPtr::as_ptr(&lu.du2), + AsPtr::as_ptr(&lu.ipiv), + &lu.a_opnorm_one, + &mut rcond, + AsPtr::as_mut_ptr(&mut self.work), + AsPtr::as_mut_ptr(self.iwork.as_mut().unwrap()), + &mut info, + ); + } + info.as_lapack_result()?; + Ok(rcond) + } + } + }; +} + +impl_rcond_tridiagonal_work_r!(f64, lapack_sys::dgtcon_); +impl_rcond_tridiagonal_work_r!(f32, lapack_sys::sgtcon_); diff --git a/lax/src/tridiagonal/solve.rs b/lax/src/tridiagonal/solve.rs new file mode 100644 index 00000000..43f7d120 --- /dev/null +++ b/lax/src/tridiagonal/solve.rs @@ -0,0 +1,64 @@ +use crate::{error::*, layout::*, *}; +use cauchy::*; + +pub trait SolveTridiagonalImpl: Scalar { + fn solve_tridiagonal( + lu: &LUFactorizedTridiagonal, + bl: MatrixLayout, + t: Transpose, + b: &mut [Self], + ) -> Result<()>; +} + +macro_rules! impl_solve_tridiagonal { + ($s:ty, $trs:path) => { + impl SolveTridiagonalImpl for $s { + fn solve_tridiagonal( + lu: &LUFactorizedTridiagonal, + b_layout: MatrixLayout, + t: Transpose, + b: &mut [Self], + ) -> Result<()> { + let (n, _) = lu.a.l.size(); + let ipiv = &lu.ipiv; + // Transpose if b is C-continuous + let mut b_t = None; + let b_layout = match b_layout { + MatrixLayout::C { .. } => { + let (layout, t) = transpose(b_layout, b); + b_t = Some(t); + layout + } + MatrixLayout::F { .. } => b_layout, + }; + let (ldb, nrhs) = b_layout.size(); + let mut info = 0; + unsafe { + $trs( + t.as_ptr(), + &n, + &nrhs, + AsPtr::as_ptr(&lu.a.dl), + AsPtr::as_ptr(&lu.a.d), + AsPtr::as_ptr(&lu.a.du), + AsPtr::as_ptr(&lu.du2), + ipiv.as_ptr(), + AsPtr::as_mut_ptr(b_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(b)), + &ldb, + &mut info, + ); + } + info.as_lapack_result()?; + if let Some(b_t) = b_t { + transpose_over(b_layout, &b_t, b); + } + Ok(()) + } + } + }; +} + +impl_solve_tridiagonal!(c64, lapack_sys::zgttrs_); +impl_solve_tridiagonal!(c32, lapack_sys::cgttrs_); +impl_solve_tridiagonal!(f64, lapack_sys::dgttrs_); +impl_solve_tridiagonal!(f32, lapack_sys::sgttrs_);