Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #341 from rust-ndarray/lax-eigh-generalized-work
Merge `Eigh_` into `Lapack` trait, add working memory management
- Loading branch information
Showing
3 changed files
with
388 additions
and
122 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,180 +1,190 @@ | ||
//! Eigenvalue problem for symmetric/Hermite matricies | ||
//! | ||
//! LAPACK correspondance | ||
//! ---------------------- | ||
//! | ||
//! | f32 | f64 | c32 | c64 | | ||
//! |:------|:------|:------|:------| | ||
//! | ssyev | dsyev | cheev | zheev | | ||
|
||
use super::*; | ||
use crate::{error::*, layout::MatrixLayout}; | ||
use cauchy::*; | ||
use num_traits::{ToPrimitive, Zero}; | ||
|
||
#[cfg_attr(doc, katexit::katexit)] | ||
/// Eigenvalue problem for symmetric/hermite matrix | ||
pub trait Eigh_: Scalar { | ||
/// Compute right eigenvalue and eigenvectors $Ax = \lambda x$ | ||
/// | ||
/// LAPACK correspondance | ||
/// ---------------------- | ||
/// | ||
/// | f32 | f64 | c32 | c64 | | ||
/// |:------|:------|:------|:------| | ||
/// | ssyev | dsyev | cheev | zheev | | ||
/// | ||
fn eigh( | ||
calc_eigenvec: bool, | ||
layout: MatrixLayout, | ||
uplo: UPLO, | ||
a: &mut [Self], | ||
) -> Result<Vec<Self::Real>>; | ||
pub struct EighWork<T: Scalar> { | ||
pub n: i32, | ||
pub jobz: JobEv, | ||
pub eigs: Vec<MaybeUninit<T::Real>>, | ||
pub work: Vec<MaybeUninit<T>>, | ||
pub rwork: Option<Vec<MaybeUninit<T::Real>>>, | ||
} | ||
|
||
/// Compute generalized right eigenvalue and eigenvectors $Ax = \lambda B x$ | ||
/// | ||
/// LAPACK correspondance | ||
/// ---------------------- | ||
/// | ||
/// | f32 | f64 | c32 | c64 | | ||
/// |:------|:------|:------|:------| | ||
/// | ssygv | dsygv | chegv | zhegv | | ||
/// | ||
fn eigh_generalized( | ||
calc_eigenvec: bool, | ||
layout: MatrixLayout, | ||
uplo: UPLO, | ||
a: &mut [Self], | ||
b: &mut [Self], | ||
) -> Result<Vec<Self::Real>>; | ||
pub trait EighWorkImpl: Sized { | ||
type Elem: Scalar; | ||
fn new(calc_eigenvectors: bool, layout: MatrixLayout) -> Result<Self>; | ||
fn calc(&mut self, uplo: UPLO, a: &mut [Self::Elem]) | ||
-> Result<&[<Self::Elem as Scalar>::Real]>; | ||
fn eval(self, uplo: UPLO, a: &mut [Self::Elem]) -> Result<Vec<<Self::Elem as Scalar>::Real>>; | ||
} | ||
|
||
macro_rules! impl_eigh { | ||
(@real, $scalar:ty, $ev:path, $evg:path) => { | ||
impl_eigh!(@body, $scalar, $ev, $evg, ); | ||
}; | ||
(@complex, $scalar:ty, $ev:path, $evg:path) => { | ||
impl_eigh!(@body, $scalar, $ev, $evg, rwork); | ||
}; | ||
(@body, $scalar:ty, $ev:path, $evg:path, $($rwork_ident:ident),*) => { | ||
impl Eigh_ for $scalar { | ||
fn eigh( | ||
calc_v: bool, | ||
layout: MatrixLayout, | ||
uplo: UPLO, | ||
a: &mut [Self], | ||
) -> Result<Vec<Self::Real>> { | ||
macro_rules! impl_eigh_work_c { | ||
($c:ty, $ev:path) => { | ||
impl EighWorkImpl for EighWork<$c> { | ||
type Elem = $c; | ||
|
||
fn new(calc_eigenvectors: bool, layout: MatrixLayout) -> Result<Self> { | ||
assert_eq!(layout.len(), layout.lda()); | ||
let n = layout.len(); | ||
let jobz = if calc_v { JobEv::All } else { JobEv::None }; | ||
let mut eigs: Vec<MaybeUninit<Self::Real>> = vec_uninit(n as usize); | ||
|
||
$( | ||
let mut $rwork_ident: Vec<MaybeUninit<Self::Real>> = vec_uninit(3 * n as usize - 2 as usize); | ||
)* | ||
|
||
// calc work size | ||
let jobz = if calc_eigenvectors { | ||
JobEv::All | ||
} else { | ||
JobEv::None | ||
}; | ||
let mut eigs = vec_uninit(n as usize); | ||
let mut rwork = vec_uninit(3 * n as usize - 2 as usize); | ||
let mut info = 0; | ||
let mut work_size = [Self::zero()]; | ||
let mut work_size = [Self::Elem::zero()]; | ||
unsafe { | ||
$ev( | ||
jobz.as_ptr() , | ||
uplo.as_ptr(), | ||
jobz.as_ptr(), | ||
UPLO::Upper.as_ptr(), // dummy, working memory is not affected by UPLO | ||
&n, | ||
AsPtr::as_mut_ptr(a), | ||
std::ptr::null_mut(), | ||
&n, | ||
AsPtr::as_mut_ptr(&mut eigs), | ||
AsPtr::as_mut_ptr(&mut work_size), | ||
&(-1), | ||
$(AsPtr::as_mut_ptr(&mut $rwork_ident),)* | ||
AsPtr::as_mut_ptr(&mut rwork), | ||
&mut info, | ||
); | ||
} | ||
info.as_lapack_result()?; | ||
|
||
// actual ev | ||
let lwork = work_size[0].to_usize().unwrap(); | ||
let mut work: Vec<MaybeUninit<Self>> = vec_uninit(lwork); | ||
let lwork = lwork as i32; | ||
let work = vec_uninit(lwork); | ||
Ok(EighWork { | ||
n, | ||
eigs, | ||
jobz, | ||
work, | ||
rwork: Some(rwork), | ||
}) | ||
} | ||
|
||
fn calc( | ||
&mut self, | ||
uplo: UPLO, | ||
a: &mut [Self::Elem], | ||
) -> Result<&[<Self::Elem as Scalar>::Real]> { | ||
let lwork = self.work.len().to_i32().unwrap(); | ||
let mut info = 0; | ||
unsafe { | ||
$ev( | ||
jobz.as_ptr(), | ||
self.jobz.as_ptr(), | ||
uplo.as_ptr(), | ||
&n, | ||
&self.n, | ||
AsPtr::as_mut_ptr(a), | ||
&n, | ||
AsPtr::as_mut_ptr(&mut eigs), | ||
AsPtr::as_mut_ptr(&mut work), | ||
&self.n, | ||
AsPtr::as_mut_ptr(&mut self.eigs), | ||
AsPtr::as_mut_ptr(&mut self.work), | ||
&lwork, | ||
$(AsPtr::as_mut_ptr(&mut $rwork_ident),)* | ||
AsPtr::as_mut_ptr(self.rwork.as_mut().unwrap()), | ||
&mut info, | ||
); | ||
} | ||
info.as_lapack_result()?; | ||
|
||
let eigs = unsafe { eigs.assume_init() }; | ||
Ok(eigs) | ||
Ok(unsafe { self.eigs.slice_assume_init_ref() }) | ||
} | ||
|
||
fn eigh_generalized( | ||
calc_v: bool, | ||
layout: MatrixLayout, | ||
fn eval( | ||
mut self, | ||
uplo: UPLO, | ||
a: &mut [Self], | ||
b: &mut [Self], | ||
) -> Result<Vec<Self::Real>> { | ||
assert_eq!(layout.len(), layout.lda()); | ||
let n = layout.len(); | ||
let jobz = if calc_v { JobEv::All } else { JobEv::None }; | ||
let mut eigs: Vec<MaybeUninit<Self::Real>> = vec_uninit(n as usize); | ||
a: &mut [Self::Elem], | ||
) -> Result<Vec<<Self::Elem as Scalar>::Real>> { | ||
let _eig = self.calc(uplo, a)?; | ||
Ok(unsafe { self.eigs.assume_init() }) | ||
} | ||
} | ||
}; | ||
} | ||
impl_eigh_work_c!(c64, lapack_sys::zheev_); | ||
impl_eigh_work_c!(c32, lapack_sys::cheev_); | ||
|
||
$( | ||
let mut $rwork_ident: Vec<MaybeUninit<Self::Real>> = vec_uninit(3 * n as usize - 2); | ||
)* | ||
macro_rules! impl_eigh_work_r { | ||
($f:ty, $ev:path) => { | ||
impl EighWorkImpl for EighWork<$f> { | ||
type Elem = $f; | ||
|
||
// calc work size | ||
fn new(calc_eigenvectors: bool, layout: MatrixLayout) -> Result<Self> { | ||
assert_eq!(layout.len(), layout.lda()); | ||
let n = layout.len(); | ||
let jobz = if calc_eigenvectors { | ||
JobEv::All | ||
} else { | ||
JobEv::None | ||
}; | ||
let mut eigs = vec_uninit(n as usize); | ||
let mut info = 0; | ||
let mut work_size = [Self::zero()]; | ||
let mut work_size = [Self::Elem::zero()]; | ||
unsafe { | ||
$evg( | ||
&1, // ITYPE A*x = (lambda)*B*x | ||
$ev( | ||
jobz.as_ptr(), | ||
uplo.as_ptr(), | ||
UPLO::Upper.as_ptr(), // dummy, working memory is not affected by UPLO | ||
&n, | ||
AsPtr::as_mut_ptr(a), | ||
&n, | ||
AsPtr::as_mut_ptr(b), | ||
std::ptr::null_mut(), | ||
&n, | ||
AsPtr::as_mut_ptr(&mut eigs), | ||
AsPtr::as_mut_ptr(&mut work_size), | ||
&(-1), | ||
$(AsPtr::as_mut_ptr(&mut $rwork_ident),)* | ||
&mut info, | ||
); | ||
} | ||
info.as_lapack_result()?; | ||
|
||
// actual evg | ||
let lwork = work_size[0].to_usize().unwrap(); | ||
let mut work: Vec<MaybeUninit<Self>> = vec_uninit(lwork); | ||
let lwork = lwork as i32; | ||
let work = vec_uninit(lwork); | ||
Ok(EighWork { | ||
n, | ||
eigs, | ||
jobz, | ||
work, | ||
rwork: None, | ||
}) | ||
} | ||
|
||
fn calc( | ||
&mut self, | ||
uplo: UPLO, | ||
a: &mut [Self::Elem], | ||
) -> Result<&[<Self::Elem as Scalar>::Real]> { | ||
let lwork = self.work.len().to_i32().unwrap(); | ||
let mut info = 0; | ||
unsafe { | ||
$evg( | ||
&1, // ITYPE A*x = (lambda)*B*x | ||
jobz.as_ptr(), | ||
$ev( | ||
self.jobz.as_ptr(), | ||
uplo.as_ptr(), | ||
&n, | ||
&self.n, | ||
AsPtr::as_mut_ptr(a), | ||
&n, | ||
AsPtr::as_mut_ptr(b), | ||
&n, | ||
AsPtr::as_mut_ptr(&mut eigs), | ||
AsPtr::as_mut_ptr(&mut work), | ||
&self.n, | ||
AsPtr::as_mut_ptr(&mut self.eigs), | ||
AsPtr::as_mut_ptr(&mut self.work), | ||
&lwork, | ||
$(AsPtr::as_mut_ptr(&mut $rwork_ident),)* | ||
&mut info, | ||
); | ||
} | ||
info.as_lapack_result()?; | ||
let eigs = unsafe { eigs.assume_init() }; | ||
Ok(eigs) | ||
Ok(unsafe { self.eigs.slice_assume_init_ref() }) | ||
} | ||
|
||
fn eval( | ||
mut self, | ||
uplo: UPLO, | ||
a: &mut [Self::Elem], | ||
) -> Result<Vec<<Self::Elem as Scalar>::Real>> { | ||
let _eig = self.calc(uplo, a)?; | ||
Ok(unsafe { self.eigs.assume_init() }) | ||
} | ||
} | ||
}; | ||
} // impl_eigh! | ||
|
||
impl_eigh!(@real, f64, lapack_sys::dsyev_, lapack_sys::dsygv_); | ||
impl_eigh!(@real, f32, lapack_sys::ssyev_, lapack_sys::ssygv_); | ||
impl_eigh!(@complex, c64, lapack_sys::zheev_, lapack_sys::zhegv_); | ||
impl_eigh!(@complex, c32, lapack_sys::cheev_, lapack_sys::chegv_); | ||
} | ||
impl_eigh_work_r!(f64, lapack_sys::dsyev_); | ||
impl_eigh_work_r!(f32, lapack_sys::ssyev_); |
Oops, something went wrong.