diff --git a/lax/src/eig.rs b/lax/src/eig.rs index beb480e6..b730d3c9 100644 --- a/lax/src/eig.rs +++ b/lax/src/eig.rs @@ -77,26 +77,28 @@ impl EigWork { } } +#[derive(Debug, Clone, PartialEq)] +pub struct Eig { + pub eigs: Vec, + pub vr: Option>, + pub vl: Option>, +} + +#[derive(Debug, Clone, PartialEq)] +pub struct EigRef<'work, T: Scalar> { + pub eigs: &'work [T::Complex], + pub vr: Option<&'work [T::Complex]>, + pub vl: Option<&'work [T::Complex]>, +} + pub trait EigWorkImpl: Sized { type Elem: Scalar; /// Create new working memory for eigenvalues compution. fn new(calc_v: bool, l: MatrixLayout) -> Result; /// Compute eigenvalues and vectors on this working memory. - fn calc<'work>( - &'work mut self, - a: &mut [Self::Elem], - ) -> Result<( - &'work [::Complex], - Option<&'work [::Complex]>, - )>; + fn calc<'work>(&'work mut self, a: &mut [Self::Elem]) -> Result>; /// Compute eigenvalues and vectors by consuming this working memory. - fn eval( - self, - a: &mut [Self::Elem], - ) -> Result<( - Vec<::Complex>, - Option::Complex>>, - )>; + fn eval(self, a: &mut [Self::Elem]) -> Result>; } impl EigWorkImpl for EigWork { @@ -157,7 +159,7 @@ impl EigWorkImpl for EigWork { }) } - fn calc<'work>(&'work mut self, a: &mut [c64]) -> Result<(&'work [c64], Option<&'work [c64]>)> { + fn calc<'work>(&'work mut self, a: &mut [c64]) -> Result> { let lwork = self.work.len().to_i32().unwrap(); let mut info = 0; unsafe { @@ -193,15 +195,20 @@ impl EigWorkImpl for EigWork { value.im = -value.im; } } - let v = match (self.vl.as_ref(), self.vr.as_ref()) { - (Some(v), None) | (None, Some(v)) => Some(unsafe { v.slice_assume_init_ref() }), - (None, None) => None, - _ => unreachable!(), - }; - Ok((eigs, v)) + Ok(EigRef { + eigs, + vl: self + .vl + .as_ref() + .map(|v| unsafe { v.slice_assume_init_ref() }), + vr: self + .vr + .as_ref() + .map(|v| unsafe { v.slice_assume_init_ref() }), + }) } - fn eval(mut self, a: &mut [c64]) -> Result<(Vec, Option>)> { + fn eval(mut self, a: &mut [c64]) -> Result> { let lwork = self.work.len().to_i32().unwrap(); let mut info = 0; unsafe { @@ -232,12 +239,11 @@ impl EigWorkImpl for EigWork { value.im = -value.im; } } - let v = match (self.vl, self.vr) { - (Some(v), None) | (None, Some(v)) => Some(unsafe { v.assume_init() }), - (None, None) => None, - _ => unreachable!(), - }; - Ok((eigs, v)) + Ok(Eig { + eigs, + vl: self.vl.map(|v| unsafe { v.assume_init() }), + vr: self.vr.map(|v| unsafe { v.assume_init() }), + }) } } @@ -301,14 +307,11 @@ impl EigWorkImpl for EigWork { }) } - fn calc<'work>( - &'work mut self, - _a: &mut [f64], - ) -> Result<(&'work [c64], Option<&'work [c64]>)> { + fn calc<'work>(&'work mut self, _a: &mut [f64]) -> Result> { todo!() } - fn eval(mut self, a: &mut [f64]) -> Result<(Vec, Option>)> { + fn eval(mut self, a: &mut [f64]) -> Result> { let lwork = self.work.len().to_i32().unwrap(); let mut info = 0; unsafe { @@ -343,12 +346,7 @@ impl EigWorkImpl for EigWork { .map(|(&re, &im)| c64::new(re, im)) .collect(); - if self.jobvl.is_calc() || self.jobvr.is_calc() { - let eigvecs = reconstruct_eigenvectors(self.jobvl, self.n, &eig_im, vr, vl); - Ok((eigs, Some(eigvecs))) - } else { - Ok((eigs, None)) - } + Ok(Eig { eigs, vl, vr }) } }