diff --git a/lax/src/eig.rs b/lax/src/eig.rs index 30ff5885..e4462be3 100644 --- a/lax/src/eig.rs +++ b/lax/src/eig.rs @@ -83,322 +83,335 @@ pub trait EigWorkImpl: Sized { fn eval(self, a: &mut [Self::Elem]) -> Result>; } -impl EigWorkImpl for EigWork { - type Elem = c64; - - fn new(calc_v: bool, l: MatrixLayout) -> Result { - let (n, _) = l.size(); - let (jobvl, jobvr) = if calc_v { - match l { - MatrixLayout::C { .. } => (JobEv::All, JobEv::None), - MatrixLayout::F { .. } => (JobEv::None, JobEv::All), +macro_rules! impl_eig_work_c { + ($c:ty, $ev:path) => { + impl EigWorkImpl for EigWork<$c> { + type Elem = $c; + + fn new(calc_v: bool, l: MatrixLayout) -> Result { + let (n, _) = l.size(); + let (jobvl, jobvr) = if calc_v { + match l { + MatrixLayout::C { .. } => (JobEv::All, JobEv::None), + MatrixLayout::F { .. } => (JobEv::None, JobEv::All), + } + } else { + (JobEv::None, JobEv::None) + }; + let mut eigs = vec_uninit(n as usize); + let mut rwork = vec_uninit(2 * n as usize); + + let mut vc_l = jobvl.then(|| vec_uninit((n * n) as usize)); + let mut vc_r = jobvr.then(|| vec_uninit((n * n) as usize)); + + // calc work size + let mut info = 0; + let mut work_size = [<$c>::zero()]; + unsafe { + $ev( + jobvl.as_ptr(), + jobvr.as_ptr(), + &n, + std::ptr::null_mut(), + &n, + AsPtr::as_mut_ptr(&mut eigs), + AsPtr::as_mut_ptr(vc_l.as_deref_mut().unwrap_or(&mut [])), + &n, + AsPtr::as_mut_ptr(vc_r.as_deref_mut().unwrap_or(&mut [])), + &n, + AsPtr::as_mut_ptr(&mut work_size), + &(-1), + AsPtr::as_mut_ptr(&mut rwork), + &mut info, + ) + }; + info.as_lapack_result()?; + + let lwork = work_size[0].to_usize().unwrap(); + let work: Vec> = vec_uninit(lwork); + Ok(Self { + n, + jobvl, + jobvr, + eigs, + eigs_re: None, + eigs_im: None, + rwork: Some(rwork), + vc_l, + vc_r, + vr_l: None, + vr_r: None, + work, + }) } - } else { - (JobEv::None, JobEv::None) - }; - let mut eigs: Vec> = vec_uninit(n as usize); - let mut rwork: Vec> = vec_uninit(2 * n as usize); - - let mut vc_l: Option>> = jobvl.then(|| vec_uninit((n * n) as usize)); - let mut vc_r: Option>> = jobvr.then(|| vec_uninit((n * n) as usize)); - - // calc work size - let mut info = 0; - let mut work_size = [c64::zero()]; - unsafe { - lapack_sys::zgeev_( - jobvl.as_ptr(), - jobvr.as_ptr(), - &n, - std::ptr::null_mut(), - &n, - AsPtr::as_mut_ptr(&mut eigs), - AsPtr::as_mut_ptr(vc_l.as_deref_mut().unwrap_or(&mut [])), - &n, - AsPtr::as_mut_ptr(vc_r.as_deref_mut().unwrap_or(&mut [])), - &n, - AsPtr::as_mut_ptr(&mut work_size), - &(-1), - AsPtr::as_mut_ptr(&mut rwork), - &mut info, - ) - }; - info.as_lapack_result()?; - - let lwork = work_size[0].to_usize().unwrap(); - let work: Vec> = vec_uninit(lwork); - Ok(Self { - n, - jobvl, - jobvr, - eigs, - eigs_re: None, - eigs_im: None, - rwork: Some(rwork), - vc_l, - vc_r, - vr_l: None, - vr_r: None, - work, - }) - } - fn calc<'work>(&'work mut self, a: &mut [c64]) -> Result> { - let lwork = self.work.len().to_i32().unwrap(); - let mut info = 0; - unsafe { - lapack_sys::zgeev_( - self.jobvl.as_ptr(), - self.jobvr.as_ptr(), - &self.n, - AsPtr::as_mut_ptr(a), - &self.n, - AsPtr::as_mut_ptr(&mut self.eigs), - AsPtr::as_mut_ptr(self.vc_l.as_deref_mut().unwrap_or(&mut [])), - &self.n, - AsPtr::as_mut_ptr(self.vc_r.as_deref_mut().unwrap_or(&mut [])), - &self.n, - AsPtr::as_mut_ptr(&mut self.work), - &lwork, - AsPtr::as_mut_ptr(self.rwork.as_mut().unwrap()), - &mut info, - ) - }; - info.as_lapack_result()?; - - let eigs = unsafe { self.eigs.slice_assume_init_ref() }; - - // Hermite conjugate - if let Some(vl) = self.vc_l.as_mut() { - for value in vl { - let value = unsafe { value.assume_init_mut() }; - value.im = -value.im; + fn calc<'work>( + &'work mut self, + a: &mut [Self::Elem], + ) -> Result> { + let lwork = self.work.len().to_i32().unwrap(); + let mut info = 0; + unsafe { + $ev( + self.jobvl.as_ptr(), + self.jobvr.as_ptr(), + &self.n, + AsPtr::as_mut_ptr(a), + &self.n, + AsPtr::as_mut_ptr(&mut self.eigs), + AsPtr::as_mut_ptr(self.vc_l.as_deref_mut().unwrap_or(&mut [])), + &self.n, + AsPtr::as_mut_ptr(self.vc_r.as_deref_mut().unwrap_or(&mut [])), + &self.n, + AsPtr::as_mut_ptr(&mut self.work), + &lwork, + AsPtr::as_mut_ptr(self.rwork.as_mut().unwrap()), + &mut info, + ) + }; + info.as_lapack_result()?; + // Hermite conjugate + if let Some(vl) = self.vc_l.as_mut() { + for value in vl { + let value = unsafe { value.assume_init_mut() }; + value.im = -value.im; + } + } + Ok(EigRef { + eigs: unsafe { self.eigs.slice_assume_init_ref() }, + vl: self + .vc_l + .as_ref() + .map(|v| unsafe { v.slice_assume_init_ref() }), + vr: self + .vc_r + .as_ref() + .map(|v| unsafe { v.slice_assume_init_ref() }), + }) } - } - Ok(EigRef { - eigs, - vl: self - .vc_l - .as_ref() - .map(|v| unsafe { v.slice_assume_init_ref() }), - vr: self - .vc_r - .as_ref() - .map(|v| unsafe { v.slice_assume_init_ref() }), - }) - } - fn eval(mut self, a: &mut [c64]) -> Result> { - let lwork = self.work.len().to_i32().unwrap(); - let mut info = 0; - unsafe { - lapack_sys::zgeev_( - self.jobvl.as_ptr(), - self.jobvr.as_ptr(), - &self.n, - AsPtr::as_mut_ptr(a), - &self.n, - AsPtr::as_mut_ptr(&mut self.eigs), - AsPtr::as_mut_ptr(self.vc_l.as_deref_mut().unwrap_or(&mut [])), - &self.n, - AsPtr::as_mut_ptr(self.vc_r.as_deref_mut().unwrap_or(&mut [])), - &self.n, - AsPtr::as_mut_ptr(&mut self.work), - &lwork, - AsPtr::as_mut_ptr(self.rwork.as_mut().unwrap()), - &mut info, - ) - }; - info.as_lapack_result()?; - let eigs = unsafe { self.eigs.assume_init() }; - - // Hermite conjugate - if let Some(vl) = self.vc_l.as_mut() { - for value in vl { - let value = unsafe { value.assume_init_mut() }; - value.im = -value.im; + fn eval(mut self, a: &mut [Self::Elem]) -> Result> { + let lwork = self.work.len().to_i32().unwrap(); + let mut info = 0; + unsafe { + $ev( + self.jobvl.as_ptr(), + self.jobvr.as_ptr(), + &self.n, + AsPtr::as_mut_ptr(a), + &self.n, + AsPtr::as_mut_ptr(&mut self.eigs), + AsPtr::as_mut_ptr(self.vc_l.as_deref_mut().unwrap_or(&mut [])), + &self.n, + AsPtr::as_mut_ptr(self.vc_r.as_deref_mut().unwrap_or(&mut [])), + &self.n, + AsPtr::as_mut_ptr(&mut self.work), + &lwork, + AsPtr::as_mut_ptr(self.rwork.as_mut().unwrap()), + &mut info, + ) + }; + info.as_lapack_result()?; + // Hermite conjugate + if let Some(vl) = self.vc_l.as_mut() { + for value in vl { + let value = unsafe { value.assume_init_mut() }; + value.im = -value.im; + } + } + Ok(Eig { + eigs: unsafe { self.eigs.assume_init() }, + vl: self.vc_l.map(|v| unsafe { v.assume_init() }), + vr: self.vc_r.map(|v| unsafe { v.assume_init() }), + }) } } - Ok(Eig { - eigs, - vl: self.vc_l.map(|v| unsafe { v.assume_init() }), - vr: self.vc_r.map(|v| unsafe { v.assume_init() }), - }) - } + }; } -impl EigWorkImpl for EigWork { - type Elem = f64; +impl_eig_work_c!(c32, lapack_sys::cgeev_); +impl_eig_work_c!(c64, lapack_sys::zgeev_); + +macro_rules! impl_eig_work_r { + ($f:ty, $ev:path) => { + impl EigWorkImpl for EigWork<$f> { + type Elem = $f; + + fn new(calc_v: bool, l: MatrixLayout) -> Result { + let (n, _) = l.size(); + let (jobvl, jobvr) = if calc_v { + match l { + MatrixLayout::C { .. } => (JobEv::All, JobEv::None), + MatrixLayout::F { .. } => (JobEv::None, JobEv::All), + } + } else { + (JobEv::None, JobEv::None) + }; + let mut eigs_re = vec_uninit(n as usize); + let mut eigs_im = vec_uninit(n as usize); + let mut vr_l = jobvl.then(|| vec_uninit((n * n) as usize)); + let mut vr_r = jobvr.then(|| vec_uninit((n * n) as usize)); + let vc_l = jobvl.then(|| vec_uninit((n * n) as usize)); + let vc_r = jobvr.then(|| vec_uninit((n * n) as usize)); + + // calc work size + let mut info = 0; + let mut work_size: [$f; 1] = [0.0]; + unsafe { + $ev( + jobvl.as_ptr(), + jobvr.as_ptr(), + &n, + std::ptr::null_mut(), + &n, + AsPtr::as_mut_ptr(&mut eigs_re), + AsPtr::as_mut_ptr(&mut eigs_im), + AsPtr::as_mut_ptr(vr_l.as_deref_mut().unwrap_or(&mut [])), + &n, + AsPtr::as_mut_ptr(vr_r.as_deref_mut().unwrap_or(&mut [])), + &n, + AsPtr::as_mut_ptr(&mut work_size), + &(-1), + &mut info, + ) + }; + info.as_lapack_result()?; - fn new(calc_v: bool, l: MatrixLayout) -> Result { - let (n, _) = l.size(); - let (jobvl, jobvr) = if calc_v { - match l { - MatrixLayout::C { .. } => (JobEv::All, JobEv::None), - MatrixLayout::F { .. } => (JobEv::None, JobEv::All), + // actual ev + let lwork = work_size[0].to_usize().unwrap(); + let work = vec_uninit(lwork); + + Ok(Self { + n, + jobvr, + jobvl, + eigs: vec_uninit(n as usize), + eigs_re: Some(eigs_re), + eigs_im: Some(eigs_im), + rwork: None, + vr_l, + vr_r, + vc_l, + vc_r, + work, + }) } - } else { - (JobEv::None, JobEv::None) - }; - let mut eigs_re: Vec> = vec_uninit(n as usize); - let mut eigs_im: Vec> = vec_uninit(n as usize); - - let mut vr_l: Option>> = jobvl.then(|| vec_uninit((n * n) as usize)); - let mut vr_r: Option>> = jobvr.then(|| vec_uninit((n * n) as usize)); - let vc_l: Option>> = jobvl.then(|| vec_uninit((n * n) as usize)); - let vc_r: Option>> = jobvr.then(|| vec_uninit((n * n) as usize)); - - // calc work size - let mut info = 0; - let mut work_size: [f64; 1] = [0.0]; - unsafe { - lapack_sys::dgeev_( - jobvl.as_ptr(), - jobvr.as_ptr(), - &n, - std::ptr::null_mut(), - &n, - AsPtr::as_mut_ptr(&mut eigs_re), - AsPtr::as_mut_ptr(&mut eigs_im), - AsPtr::as_mut_ptr(vr_l.as_deref_mut().unwrap_or(&mut [])), - &n, - AsPtr::as_mut_ptr(vr_r.as_deref_mut().unwrap_or(&mut [])), - &n, - AsPtr::as_mut_ptr(&mut work_size), - &(-1), - &mut info, - ) - }; - info.as_lapack_result()?; - - // actual ev - let lwork = work_size[0].to_usize().unwrap(); - let work: Vec> = vec_uninit(lwork); - - Ok(Self { - n, - jobvr, - jobvl, - eigs: vec_uninit(n as usize), - eigs_re: Some(eigs_re), - eigs_im: Some(eigs_im), - rwork: None, - vr_l, - vr_r, - vc_l, - vc_r, - work, - }) - } - fn calc<'work>(&'work mut self, a: &mut [f64]) -> Result> { - let lwork = self.work.len().to_i32().unwrap(); - let mut info = 0; - unsafe { - lapack_sys::dgeev_( - self.jobvl.as_ptr(), - self.jobvr.as_ptr(), - &self.n, - AsPtr::as_mut_ptr(a), - &self.n, - AsPtr::as_mut_ptr(self.eigs_re.as_mut().unwrap()), - AsPtr::as_mut_ptr(self.eigs_im.as_mut().unwrap()), - AsPtr::as_mut_ptr(self.vr_l.as_deref_mut().unwrap_or(&mut [])), - &self.n, - AsPtr::as_mut_ptr(self.vr_r.as_deref_mut().unwrap_or(&mut [])), - &self.n, - AsPtr::as_mut_ptr(&mut self.work), - &lwork, - &mut info, - ) - }; - info.as_lapack_result()?; - - let eigs_re: &[f64] = self - .eigs_re - .as_ref() - .map(|e| unsafe { e.slice_assume_init_ref() }) - .unwrap(); - let eigs_im: &[f64] = self - .eigs_im - .as_ref() - .map(|e| unsafe { e.slice_assume_init_ref() }) - .unwrap(); - reconstruct_eigs(eigs_re, eigs_im, &mut self.eigs); - - if let Some(v) = self.vr_l.as_ref() { - let v = unsafe { v.slice_assume_init_ref() }; - reconstruct_eigenvectors(false, eigs_im, v, self.vc_l.as_mut().unwrap()); - } - if let Some(v) = self.vr_r.as_ref() { - let v = unsafe { v.slice_assume_init_ref() }; - reconstruct_eigenvectors(false, eigs_im, v, self.vc_l.as_mut().unwrap()); - } + fn calc<'work>( + &'work mut self, + a: &mut [Self::Elem], + ) -> Result> { + let lwork = self.work.len().to_i32().unwrap(); + let mut info = 0; + unsafe { + $ev( + self.jobvl.as_ptr(), + self.jobvr.as_ptr(), + &self.n, + AsPtr::as_mut_ptr(a), + &self.n, + AsPtr::as_mut_ptr(self.eigs_re.as_mut().unwrap()), + AsPtr::as_mut_ptr(self.eigs_im.as_mut().unwrap()), + AsPtr::as_mut_ptr(self.vr_l.as_deref_mut().unwrap_or(&mut [])), + &self.n, + AsPtr::as_mut_ptr(self.vr_r.as_deref_mut().unwrap_or(&mut [])), + &self.n, + AsPtr::as_mut_ptr(&mut self.work), + &lwork, + &mut info, + ) + }; + info.as_lapack_result()?; - Ok(EigRef { - eigs: unsafe { self.eigs.slice_assume_init_ref() }, - vl: self - .vc_l - .as_ref() - .map(|v| unsafe { v.slice_assume_init_ref() }), - vr: self - .vc_r - .as_ref() - .map(|v| unsafe { v.slice_assume_init_ref() }), - }) - } + let eigs_re = self + .eigs_re + .as_ref() + .map(|e| unsafe { e.slice_assume_init_ref() }) + .unwrap(); + let eigs_im = self + .eigs_im + .as_ref() + .map(|e| unsafe { e.slice_assume_init_ref() }) + .unwrap(); + reconstruct_eigs(eigs_re, eigs_im, &mut self.eigs); + + if let Some(v) = self.vr_l.as_ref() { + let v = unsafe { v.slice_assume_init_ref() }; + reconstruct_eigenvectors(false, eigs_im, v, self.vc_l.as_mut().unwrap()); + } + if let Some(v) = self.vr_r.as_ref() { + let v = unsafe { v.slice_assume_init_ref() }; + reconstruct_eigenvectors(false, eigs_im, v, self.vc_l.as_mut().unwrap()); + } - fn eval(mut self, a: &mut [f64]) -> Result> { - let lwork = self.work.len().to_i32().unwrap(); - let mut info = 0; - unsafe { - lapack_sys::dgeev_( - self.jobvl.as_ptr(), - self.jobvr.as_ptr(), - &self.n, - AsPtr::as_mut_ptr(a), - &self.n, - AsPtr::as_mut_ptr(self.eigs_re.as_mut().unwrap()), - AsPtr::as_mut_ptr(self.eigs_im.as_mut().unwrap()), - AsPtr::as_mut_ptr(self.vr_l.as_deref_mut().unwrap_or(&mut [])), - &self.n, - AsPtr::as_mut_ptr(self.vr_r.as_deref_mut().unwrap_or(&mut [])), - &self.n, - AsPtr::as_mut_ptr(&mut self.work), - &lwork, - &mut info, - ) - }; - info.as_lapack_result()?; - - let eigs_re: &[f64] = self - .eigs_re - .as_ref() - .map(|e| unsafe { e.slice_assume_init_ref() }) - .unwrap(); - let eigs_im: &[f64] = self - .eigs_im - .as_ref() - .map(|e| unsafe { e.slice_assume_init_ref() }) - .unwrap(); - reconstruct_eigs(eigs_re, eigs_im, &mut self.eigs); - - if let Some(v) = self.vr_l.as_ref() { - let v = unsafe { v.slice_assume_init_ref() }; - reconstruct_eigenvectors(false, eigs_im, v, self.vc_l.as_mut().unwrap()); - } - if let Some(v) = self.vr_r.as_ref() { - let v = unsafe { v.slice_assume_init_ref() }; - reconstruct_eigenvectors(false, eigs_im, v, self.vc_l.as_mut().unwrap()); - } + Ok(EigRef { + eigs: unsafe { self.eigs.slice_assume_init_ref() }, + vl: self + .vc_l + .as_ref() + .map(|v| unsafe { v.slice_assume_init_ref() }), + vr: self + .vc_r + .as_ref() + .map(|v| unsafe { v.slice_assume_init_ref() }), + }) + } - Ok(Eig { - eigs: unsafe { self.eigs.assume_init() }, - vl: self.vc_l.map(|v| unsafe { v.assume_init() }), - vr: self.vc_r.map(|v| unsafe { v.assume_init() }), - }) - } + fn eval(mut self, a: &mut [Self::Elem]) -> Result> { + let lwork = self.work.len().to_i32().unwrap(); + let mut info = 0; + unsafe { + $ev( + self.jobvl.as_ptr(), + self.jobvr.as_ptr(), + &self.n, + AsPtr::as_mut_ptr(a), + &self.n, + AsPtr::as_mut_ptr(self.eigs_re.as_mut().unwrap()), + AsPtr::as_mut_ptr(self.eigs_im.as_mut().unwrap()), + AsPtr::as_mut_ptr(self.vr_l.as_deref_mut().unwrap_or(&mut [])), + &self.n, + AsPtr::as_mut_ptr(self.vr_r.as_deref_mut().unwrap_or(&mut [])), + &self.n, + AsPtr::as_mut_ptr(&mut self.work), + &lwork, + &mut info, + ) + }; + info.as_lapack_result()?; + + let eigs_re = self + .eigs_re + .as_ref() + .map(|e| unsafe { e.slice_assume_init_ref() }) + .unwrap(); + let eigs_im = self + .eigs_im + .as_ref() + .map(|e| unsafe { e.slice_assume_init_ref() }) + .unwrap(); + reconstruct_eigs(eigs_re, eigs_im, &mut self.eigs); + + if let Some(v) = self.vr_l.as_ref() { + let v = unsafe { v.slice_assume_init_ref() }; + reconstruct_eigenvectors(false, eigs_im, v, self.vc_l.as_mut().unwrap()); + } + if let Some(v) = self.vr_r.as_ref() { + let v = unsafe { v.slice_assume_init_ref() }; + reconstruct_eigenvectors(false, eigs_im, v, self.vc_l.as_mut().unwrap()); + } + + Ok(Eig { + eigs: unsafe { self.eigs.assume_init() }, + vl: self.vc_l.map(|v| unsafe { v.assume_init() }), + vr: self.vc_r.map(|v| unsafe { v.assume_init() }), + }) + } + } + }; } +impl_eig_work_r!(f32, lapack_sys::sgeev_); +impl_eig_work_r!(f64, lapack_sys::dgeev_); macro_rules! impl_eig_complex { ($scalar:ty, $ev:path) => {