From cc369fbb41835a5999278fd164033b94b47e0aeb Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Mon, 26 Sep 2022 01:48:53 +0900 Subject: [PATCH] HouseholderWork and HouseholderWorkImpl --- lax/src/qr.rs | 110 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 110 insertions(+) diff --git a/lax/src/qr.rs b/lax/src/qr.rs index bdfb7571..bbaf9cd0 100644 --- a/lax/src/qr.rs +++ b/lax/src/qr.rs @@ -18,6 +18,116 @@ pub trait QR_: Sized { fn qr(l: MatrixLayout, a: &mut [Self]) -> Result>; } +pub struct HouseholderWork { + pub m: i32, + pub n: i32, + pub layout: MatrixLayout, + pub tau: Vec>, + pub work: Vec>, +} + +pub trait HouseholderWorkImpl: Sized { + type Elem: Scalar; + fn new(l: MatrixLayout) -> Result; + fn calc(&mut self, a: &mut [Self::Elem]) -> Result<&[Self::Elem]>; + fn eval(self, a: &mut [Self::Elem]) -> Result>; +} + +macro_rules! impl_householder_work { + ($s:ty, $qrf:path, $lqf: path) => { + impl HouseholderWorkImpl for HouseholderWork<$s> { + type Elem = $s; + + fn new(layout: MatrixLayout) -> Result { + let m = layout.lda(); + let n = layout.len(); + let k = m.min(n); + let mut tau = vec_uninit(k as usize); + let mut info = 0; + let mut work_size = [Self::Elem::zero()]; + match layout { + MatrixLayout::F { .. } => unsafe { + $qrf( + &m, + &n, + std::ptr::null_mut(), + &m, + AsPtr::as_mut_ptr(&mut tau), + AsPtr::as_mut_ptr(&mut work_size), + &(-1), + &mut info, + ) + }, + MatrixLayout::C { .. } => unsafe { + $lqf( + &m, + &n, + std::ptr::null_mut(), + &m, + AsPtr::as_mut_ptr(&mut tau), + AsPtr::as_mut_ptr(&mut work_size), + &(-1), + &mut info, + ) + }, + } + info.as_lapack_result()?; + let lwork = work_size[0].to_usize().unwrap(); + let work = vec_uninit(lwork); + Ok(HouseholderWork { + n, + m, + layout, + tau, + work, + }) + } + + fn calc(&mut self, a: &mut [Self::Elem]) -> Result<&[Self::Elem]> { + let lwork = self.work.len().to_i32().unwrap(); + let mut info = 0; + match self.layout { + MatrixLayout::F { .. } => unsafe { + $qrf( + &self.m, + &self.n, + AsPtr::as_mut_ptr(a), + &self.m, + AsPtr::as_mut_ptr(&mut self.tau), + AsPtr::as_mut_ptr(&mut self.work), + &lwork, + &mut info, + ); + }, + MatrixLayout::C { .. } => unsafe { + $lqf( + &self.m, + &self.n, + AsPtr::as_mut_ptr(a), + &self.m, + AsPtr::as_mut_ptr(&mut self.tau), + AsPtr::as_mut_ptr(&mut self.work), + &lwork, + &mut info, + ); + }, + } + info.as_lapack_result()?; + Ok(unsafe { self.tau.slice_assume_init_ref() }) + } + + fn eval(mut self, a: &mut [Self::Elem]) -> Result> { + let _eig = self.calc(a)?; + Ok(unsafe { self.tau.assume_init() }) + } + } + }; +} +impl_householder_work!(c64, lapack_sys::zgeqrf_, lapack_sys::zgelqf_); +impl_householder_work!(c32, lapack_sys::cgeqrf_, lapack_sys::cgelqf_); +impl_householder_work!(f64, lapack_sys::dgeqrf_, lapack_sys::dgelqf_); +impl_householder_work!(f32, lapack_sys::sgeqrf_, lapack_sys::sgelqf_); + macro_rules! impl_qr { ($scalar:ty, $qrf:path, $lqf:path, $gqr:path, $glq:path) => { impl QR_ for $scalar {