Skip to content

Commit

Permalink
Split generalized eigenvalue routine
Browse files Browse the repository at this point in the history
  • Loading branch information
termoshtt committed Sep 24, 2022
1 parent 33e2dc3 commit c953001
Show file tree
Hide file tree
Showing 3 changed files with 118 additions and 90 deletions.
99 changes: 9 additions & 90 deletions lax/src/eigh.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,33 +21,16 @@ pub trait Eigh_: Scalar {
uplo: UPLO,
a: &mut [Self],
) -> Result<Vec<Self::Real>>;

/// Compute generalized right eigenvalue and eigenvectors $Ax = \lambda B x$
///
/// LAPACK correspondance
/// ----------------------
///
/// | f32 | f64 | c32 | c64 |
/// |:------|:------|:------|:------|
/// | ssygv | dsygv | chegv | zhegv |
///
fn eigh_generalized(
calc_eigenvec: bool,
layout: MatrixLayout,
uplo: UPLO,
a: &mut [Self],
b: &mut [Self],
) -> Result<Vec<Self::Real>>;
}

macro_rules! impl_eigh {
(@real, $scalar:ty, $ev:path, $evg:path) => {
impl_eigh!(@body, $scalar, $ev, $evg, );
(@real, $scalar:ty, $ev:path) => {
impl_eigh!(@body, $scalar, $ev, );
};
(@complex, $scalar:ty, $ev:path, $evg:path) => {
impl_eigh!(@body, $scalar, $ev, $evg, rwork);
(@complex, $scalar:ty, $ev:path) => {
impl_eigh!(@body, $scalar, $ev, rwork);
};
(@body, $scalar:ty, $ev:path, $evg:path, $($rwork_ident:ident),*) => {
(@body, $scalar:ty, $ev:path, $($rwork_ident:ident),*) => {
impl Eigh_ for $scalar {
fn eigh(
calc_v: bool,
Expand Down Expand Up @@ -106,75 +89,11 @@ macro_rules! impl_eigh {
let eigs = unsafe { eigs.assume_init() };
Ok(eigs)
}

fn eigh_generalized(
calc_v: bool,
layout: MatrixLayout,
uplo: UPLO,
a: &mut [Self],
b: &mut [Self],
) -> Result<Vec<Self::Real>> {
assert_eq!(layout.len(), layout.lda());
let n = layout.len();
let jobz = if calc_v { JobEv::All } else { JobEv::None };
let mut eigs: Vec<MaybeUninit<Self::Real>> = vec_uninit(n as usize);

$(
let mut $rwork_ident: Vec<MaybeUninit<Self::Real>> = vec_uninit(3 * n as usize - 2);
)*

// calc work size
let mut info = 0;
let mut work_size = [Self::zero()];
unsafe {
$evg(
&1, // ITYPE A*x = (lambda)*B*x
jobz.as_ptr(),
uplo.as_ptr(),
&n,
AsPtr::as_mut_ptr(a),
&n,
AsPtr::as_mut_ptr(b),
&n,
AsPtr::as_mut_ptr(&mut eigs),
AsPtr::as_mut_ptr(&mut work_size),
&(-1),
$(AsPtr::as_mut_ptr(&mut $rwork_ident),)*
&mut info,
);
}
info.as_lapack_result()?;

// actual evg
let lwork = work_size[0].to_usize().unwrap();
let mut work: Vec<MaybeUninit<Self>> = vec_uninit(lwork);
let lwork = lwork as i32;
unsafe {
$evg(
&1, // ITYPE A*x = (lambda)*B*x
jobz.as_ptr(),
uplo.as_ptr(),
&n,
AsPtr::as_mut_ptr(a),
&n,
AsPtr::as_mut_ptr(b),
&n,
AsPtr::as_mut_ptr(&mut eigs),
AsPtr::as_mut_ptr(&mut work),
&lwork,
$(AsPtr::as_mut_ptr(&mut $rwork_ident),)*
&mut info,
);
}
info.as_lapack_result()?;
let eigs = unsafe { eigs.assume_init() };
Ok(eigs)
}
}
};
} // impl_eigh!

impl_eigh!(@real, f64, lapack_sys::dsyev_, lapack_sys::dsygv_);
impl_eigh!(@real, f32, lapack_sys::ssyev_, lapack_sys::ssygv_);
impl_eigh!(@complex, c64, lapack_sys::zheev_, lapack_sys::zhegv_);
impl_eigh!(@complex, c32, lapack_sys::cheev_, lapack_sys::chegv_);
impl_eigh!(@real, f64, lapack_sys::dsyev_);
impl_eigh!(@real, f32, lapack_sys::ssyev_);
impl_eigh!(@complex, c64, lapack_sys::zheev_);
impl_eigh!(@complex, c32, lapack_sys::cheev_);
106 changes: 106 additions & 0 deletions lax/src/eigh_generalized.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
use super::*;
use crate::{error::*, layout::MatrixLayout};
use cauchy::*;
use num_traits::{ToPrimitive, Zero};

#[cfg_attr(doc, katexit::katexit)]
/// Eigenvalue problem for symmetric/hermite matrix
pub trait EighGeneralized_: Scalar {
/// Compute generalized right eigenvalue and eigenvectors $Ax = \lambda B x$
///
/// LAPACK correspondance
/// ----------------------
///
/// | f32 | f64 | c32 | c64 |
/// |:------|:------|:------|:------|
/// | ssygv | dsygv | chegv | zhegv |
///
fn eigh_generalized(
calc_eigenvec: bool,
layout: MatrixLayout,
uplo: UPLO,
a: &mut [Self],
b: &mut [Self],
) -> Result<Vec<Self::Real>>;
}

macro_rules! impl_eigh {
(@real, $scalar:ty, $evg:path) => {
impl_eigh!(@body, $scalar, $evg, );
};
(@complex, $scalar:ty, $evg:path) => {
impl_eigh!(@body, $scalar, $evg, rwork);
};
(@body, $scalar:ty, $evg:path, $($rwork_ident:ident),*) => {
impl EighGeneralized_ for $scalar {
fn eigh_generalized(
calc_v: bool,
layout: MatrixLayout,
uplo: UPLO,
a: &mut [Self],
b: &mut [Self],
) -> Result<Vec<Self::Real>> {
assert_eq!(layout.len(), layout.lda());
let n = layout.len();
let jobz = if calc_v { JobEv::All } else { JobEv::None };
let mut eigs: Vec<MaybeUninit<Self::Real>> = vec_uninit(n as usize);

$(
let mut $rwork_ident: Vec<MaybeUninit<Self::Real>> = vec_uninit(3 * n as usize - 2);
)*

// calc work size
let mut info = 0;
let mut work_size = [Self::zero()];
unsafe {
$evg(
&1, // ITYPE A*x = (lambda)*B*x
jobz.as_ptr(),
uplo.as_ptr(),
&n,
AsPtr::as_mut_ptr(a),
&n,
AsPtr::as_mut_ptr(b),
&n,
AsPtr::as_mut_ptr(&mut eigs),
AsPtr::as_mut_ptr(&mut work_size),
&(-1),
$(AsPtr::as_mut_ptr(&mut $rwork_ident),)*
&mut info,
);
}
info.as_lapack_result()?;

// actual evg
let lwork = work_size[0].to_usize().unwrap();
let mut work: Vec<MaybeUninit<Self>> = vec_uninit(lwork);
let lwork = lwork as i32;
unsafe {
$evg(
&1, // ITYPE A*x = (lambda)*B*x
jobz.as_ptr(),
uplo.as_ptr(),
&n,
AsPtr::as_mut_ptr(a),
&n,
AsPtr::as_mut_ptr(b),
&n,
AsPtr::as_mut_ptr(&mut eigs),
AsPtr::as_mut_ptr(&mut work),
&lwork,
$(AsPtr::as_mut_ptr(&mut $rwork_ident),)*
&mut info,
);
}
info.as_lapack_result()?;
let eigs = unsafe { eigs.assume_init() };
Ok(eigs)
}
}
};
} // impl_eigh!

impl_eigh!(@real, f64, lapack_sys::dsygv_);
impl_eigh!(@real, f32, lapack_sys::ssygv_);
impl_eigh!(@complex, c64, lapack_sys::zhegv_);
impl_eigh!(@complex, c32, lapack_sys::chegv_);
3 changes: 3 additions & 0 deletions lax/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ pub mod eig;
mod alloc;
mod cholesky;
mod eigh;
mod eigh_generalized;
mod least_squares;
mod opnorm;
mod qr;
Expand All @@ -102,6 +103,7 @@ mod tridiagonal;

pub use self::cholesky::*;
pub use self::eigh::*;
pub use self::eigh_generalized::*;
pub use self::flags::*;
pub use self::least_squares::*;
pub use self::opnorm::*;
Expand Down Expand Up @@ -130,6 +132,7 @@ pub trait Lapack:
+ Solveh_
+ Cholesky_
+ Eigh_
+ EighGeneralized_
+ Triangular_
+ Tridiagonal_
+ Rcond_
Expand Down

0 comments on commit c953001

Please sign in to comment.