From 77235f4f3fbb705d177fc36cadf2b247e913bef9 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Thu, 29 Sep 2022 20:54:44 +0900 Subject: [PATCH] InvWork --- lax/src/solve.rs | 63 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) diff --git a/lax/src/solve.rs b/lax/src/solve.rs index d0f764fd..89dc1113 100644 --- a/lax/src/solve.rs +++ b/lax/src/solve.rs @@ -64,6 +64,69 @@ pub trait Solve_: Scalar + Sized { fn solve(l: MatrixLayout, t: Transpose, a: &[Self], p: &Pivot, b: &mut [Self]) -> Result<()>; } +pub struct InvWork { + pub layout: MatrixLayout, + pub work: Vec>, +} + +pub trait InvWorkImpl: Sized { + type Elem: Scalar; + fn new(layout: MatrixLayout) -> Result; + fn calc(&mut self, a: &mut [Self::Elem], p: &Pivot) -> Result<()>; +} + +macro_rules! impl_inv_work { + ($s:ty, $tri:path) => { + impl InvWorkImpl for InvWork<$s> { + type Elem = $s; + + fn new(layout: MatrixLayout) -> Result { + let (n, _) = layout.size(); + let mut info = 0; + let mut work_size = [Self::Elem::zero()]; + unsafe { + $tri( + &n, + std::ptr::null_mut(), + &layout.lda(), + std::ptr::null(), + AsPtr::as_mut_ptr(&mut work_size), + &(-1), + &mut info, + ) + }; + info.as_lapack_result()?; + let lwork = work_size[0].to_usize().unwrap(); + let work = vec_uninit(lwork); + Ok(InvWork { layout, work }) + } + + fn calc(&mut self, a: &mut [Self::Elem], ipiv: &Pivot) -> Result<()> { + let lwork = self.work.len().to_i32().unwrap(); + let mut info = 0; + unsafe { + $tri( + &self.layout.len(), + AsPtr::as_mut_ptr(a), + &self.layout.lda(), + ipiv.as_ptr(), + AsPtr::as_mut_ptr(&mut self.work), + &lwork, + &mut info, + ) + }; + info.as_lapack_result()?; + Ok(()) + } + } + }; +} + +impl_inv_work!(c64, lapack_sys::zgetri_); +impl_inv_work!(c32, lapack_sys::cgetri_); +impl_inv_work!(f64, lapack_sys::dgetri_); +impl_inv_work!(f32, lapack_sys::sgetri_); + macro_rules! impl_solve { ($scalar:ty, $getrf:path, $getri:path, $getrs:path) => { impl Solve_ for $scalar {