diff --git a/lax/src/solve.rs b/lax/src/solve.rs index d0f764fd..2f251fb2 100644 --- a/lax/src/solve.rs +++ b/lax/src/solve.rs @@ -64,6 +64,50 @@ pub trait Solve_: Scalar + Sized { fn solve(l: MatrixLayout, t: Transpose, a: &[Self], p: &Pivot, b: &mut [Self]) -> Result<()>; } +pub struct InvWork { + pub work: Vec>, +} + +pub trait InvWorkImpl: Sized { + type Elem: Scalar; + fn new(layout: MatrixLayout) -> Result; + fn calc(&mut self, a: &mut [Self::Elem], p: &Pivot) -> Result<()>; +} + +macro_rules! impl_inv_work { + ($s:ty, $tri:path) => { + impl InvWorkImpl for InvWork { + type Elem = c64; + fn new(layout: MatrixLayout) -> Result { + let (n, _) = layout.size(); + let mut info = 0; + let mut work_size = [Self::Elem::zero()]; + unsafe { + $tri( + &n, + std::ptr::null_mut(), + &layout.lda(), + std::ptr::null(), + AsPtr::as_mut_ptr(&mut work_size), + &(-1), + &mut info, + ) + }; + info.as_lapack_result()?; + let lwork = work_size[0].to_usize().unwrap(); + let work = vec_uninit(lwork); + Ok(InvWork { work }) + } + + fn calc(&mut self, a: &mut [Self::Elem], p: &Pivot) -> Result<()> { + todo!() + } + } + }; +} + +impl_inv_work!(c64, lapack_sys::zgetri_); + macro_rules! impl_solve { ($scalar:ty, $getrf:path, $getri:path, $getrs:path) => { impl Solve_ for $scalar {