Skip to content

Commit

Permalink
InvWork
Browse files Browse the repository at this point in the history
  • Loading branch information
termoshtt committed Sep 29, 2022
1 parent 07ab31d commit 77235f4
Showing 1 changed file with 63 additions and 0 deletions.
63 changes: 63 additions & 0 deletions lax/src/solve.rs
Expand Up @@ -64,6 +64,69 @@ pub trait Solve_: Scalar + Sized {
fn solve(l: MatrixLayout, t: Transpose, a: &[Self], p: &Pivot, b: &mut [Self]) -> Result<()>;
}

pub struct InvWork<T: Scalar> {
pub layout: MatrixLayout,
pub work: Vec<MaybeUninit<T>>,
}

pub trait InvWorkImpl: Sized {
type Elem: Scalar;
fn new(layout: MatrixLayout) -> Result<Self>;
fn calc(&mut self, a: &mut [Self::Elem], p: &Pivot) -> Result<()>;
}

macro_rules! impl_inv_work {
($s:ty, $tri:path) => {
impl InvWorkImpl for InvWork<$s> {
type Elem = $s;

fn new(layout: MatrixLayout) -> Result<Self> {
let (n, _) = layout.size();
let mut info = 0;
let mut work_size = [Self::Elem::zero()];
unsafe {
$tri(
&n,
std::ptr::null_mut(),
&layout.lda(),
std::ptr::null(),
AsPtr::as_mut_ptr(&mut work_size),
&(-1),
&mut info,
)
};
info.as_lapack_result()?;
let lwork = work_size[0].to_usize().unwrap();
let work = vec_uninit(lwork);
Ok(InvWork { layout, work })
}

fn calc(&mut self, a: &mut [Self::Elem], ipiv: &Pivot) -> Result<()> {
let lwork = self.work.len().to_i32().unwrap();
let mut info = 0;
unsafe {
$tri(
&self.layout.len(),
AsPtr::as_mut_ptr(a),
&self.layout.lda(),
ipiv.as_ptr(),
AsPtr::as_mut_ptr(&mut self.work),
&lwork,
&mut info,
)
};
info.as_lapack_result()?;
Ok(())
}
}
};
}

impl_inv_work!(c64, lapack_sys::zgetri_);
impl_inv_work!(c32, lapack_sys::cgetri_);
impl_inv_work!(f64, lapack_sys::dgetri_);
impl_inv_work!(f32, lapack_sys::sgetri_);

macro_rules! impl_solve {
($scalar:ty, $getrf:path, $getri:path, $getrs:path) => {
impl Solve_ for $scalar {
Expand Down

0 comments on commit 77235f4

Please sign in to comment.