Skip to content

Commit

Permalink
Merge pull request #184 from bytesnake/master
Browse files Browse the repository at this point in the history
Add LOBPCG solver for large symmetric positive definite eigenproblems
  • Loading branch information
termoshtt committed May 6, 2020
2 parents 8b55efc + ae2ce6a commit c033cb9
Show file tree
Hide file tree
Showing 13 changed files with 1,194 additions and 19 deletions.
5 changes: 5 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,11 @@ optional = true
paste = "0.1.9"
criterion = "0.3.1"

[[bench]]
name = "truncated_eig"
harness = false

[[bench]]
name = "eigh"
harness = false

41 changes: 41 additions & 0 deletions benches/truncated_eig.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
#[macro_use]
extern crate criterion;

use criterion::Criterion;
use ndarray::*;
use ndarray_linalg::*;

macro_rules! impl_teig {
($n:expr) => {
paste::item! {
fn [<teig $n>](c: &mut Criterion) {
c.bench_function(&format!("truncated_eig{}", $n), |b| {
let a: Array2<f64> = random(($n, $n));
let a = a.t().dot(&a);

b.iter(move || {
let _result = TruncatedEig::new(a.clone(), TruncatedOrder::Largest).decompose(1);
})
});
c.bench_function(&format!("truncated_eig{}_t", $n), |b| {
let a: Array2<f64> = random(($n, $n).f());
let a = a.t().dot(&a);

b.iter(|| {
let _result = TruncatedEig::new(a.clone(), TruncatedOrder::Largest).decompose(1);
})
});
}
}
};
}

impl_teig!(4);
impl_teig!(8);
impl_teig!(16);
impl_teig!(32);
impl_teig!(64);
impl_teig!(128);

criterion_group!(teig, teig4, teig8, teig16, teig32, teig64, teig128);
criterion_main!(teig);
23 changes: 23 additions & 0 deletions examples/truncated_eig.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
extern crate ndarray;
extern crate ndarray_linalg;

use ndarray::*;
use ndarray_linalg::*;

fn main() {
let n = 10;
let v = random_unitary(n);

// set eigenvalues in decreasing order
let t = Array1::linspace(n as f64, -(n as f64), n);

println!("Generate spectrum: {:?}", &t);

let t = Array2::from_diag(&t);
let a = v.dot(&t.dot(&v.t()));

// calculate the truncated eigenproblem decomposition
for (val, _) in TruncatedEig::new(a, TruncatedOrder::Largest) {
println!("Found eigenvalue {}", val[0]);
}
}
22 changes: 22 additions & 0 deletions examples/truncated_svd.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
extern crate ndarray;
extern crate ndarray_linalg;

use ndarray::*;
use ndarray_linalg::*;

fn main() {
let a = arr2(&[[3., 2., 2.], [2., 3., -2.]]);

// calculate the truncated singular value decomposition for 2 singular values
let result = TruncatedSvd::new(a, TruncatedOrder::Largest).decompose(2).unwrap();

// acquire singular values, left-singular vectors and right-singular vectors
let (u, sigma, v_t) = result.values_vectors();
println!("Result of the singular value decomposition A = UΣV^T:");
println!(" === U ===");
println!("{:?}", u);
println!(" === Σ ===");
println!("{:?}", Array2::from_diag(&sigma));
println!(" === V^T ===");
println!("{:?}", v_t);
}
65 changes: 65 additions & 0 deletions src/eigh.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,20 @@ where
}
}

impl<A, S, S2> EighInto for (ArrayBase<S, Ix2>, ArrayBase<S2, Ix2>)
where
A: Scalar + Lapack,
S: DataMut<Elem = A>,
S2: DataMut<Elem = A>,
{
type EigVal = Array1<A::Real>;

fn eigh_into(mut self, uplo: UPLO) -> Result<(Self::EigVal, Self)> {
let (val, _) = self.eigh_inplace(uplo)?;
Ok((val, self))
}
}

impl<A, S> Eigh for ArrayBase<S, Ix2>
where
A: Scalar + Lapack,
Expand All @@ -56,6 +70,21 @@ where
}
}

impl<A, S, S2> Eigh for (ArrayBase<S, Ix2>, ArrayBase<S2, Ix2>)
where
A: Scalar + Lapack,
S: Data<Elem = A>,
S2: Data<Elem = A>,
{
type EigVal = Array1<A::Real>;
type EigVec = (Array2<A>, Array2<A>);

fn eigh(&self, uplo: UPLO) -> Result<(Self::EigVal, Self::EigVec)> {
let (a, b) = (self.0.to_owned(), self.1.to_owned());
(a, b).eigh_into(uplo)
}
}

impl<A, S> EighInplace for ArrayBase<S, Ix2>
where
A: Scalar + Lapack,
Expand All @@ -75,6 +104,42 @@ where
}
}

impl<A, S, S2> EighInplace for (ArrayBase<S, Ix2>, ArrayBase<S2, Ix2>)
where
A: Scalar + Lapack,
S: DataMut<Elem = A>,
S2: DataMut<Elem = A>,
{
type EigVal = Array1<A::Real>;

fn eigh_inplace(&mut self, uplo: UPLO) -> Result<(Self::EigVal, &mut Self)> {
let layout = self.0.square_layout()?;
// XXX Force layout to be Fortran (see #146)
match layout {
MatrixLayout::C(_) => self.0.swap_axes(0, 1),
MatrixLayout::F(_) => {}
}

let layout = self.1.square_layout()?;
match layout {
MatrixLayout::C(_) => self.1.swap_axes(0, 1),
MatrixLayout::F(_) => {}
}

let s = unsafe {
A::eigh_generalized(
true,
self.0.square_layout()?,
uplo,
self.0.as_allocated_mut()?,
self.1.as_allocated_mut()?,
)?
};

Ok((ArrayBase::from(s), self))
}
}

/// Calculate eigenvalues without eigenvectors
pub trait EigValsh {
type EigVal;
Expand Down
42 changes: 37 additions & 5 deletions src/lapack/eigh.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,17 @@ use super::{into_result, UPLO};
/// Wraps `*syev` for real and `*heev` for complex
pub trait Eigh_: Scalar {
unsafe fn eigh(calc_eigenvec: bool, l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<Vec<Self::Real>>;
unsafe fn eigh_generalized(
calc_eigenvec: bool,
l: MatrixLayout,
uplo: UPLO,
a: &mut [Self],
b: &mut [Self],
) -> Result<Vec<Self::Real>>;
}

macro_rules! impl_eigh {
($scalar:ty, $ev:path) => {
($scalar:ty, $ev:path, $evg:path) => {
impl Eigh_ for $scalar {
unsafe fn eigh(calc_v: bool, l: MatrixLayout, uplo: UPLO, mut a: &mut [Self]) -> Result<Vec<Self::Real>> {
let (n, _) = l.size();
Expand All @@ -24,11 +31,36 @@ macro_rules! impl_eigh {
let info = $ev(l.lapacke_layout(), jobz, uplo as u8, n, &mut a, n, &mut w);
into_result(info, w)
}

unsafe fn eigh_generalized(
calc_v: bool,
l: MatrixLayout,
uplo: UPLO,
mut a: &mut [Self],
mut b: &mut [Self],
) -> Result<Vec<Self::Real>> {
let (n, _) = l.size();
let jobz = if calc_v { b'V' } else { b'N' };
let mut w = vec![Self::Real::zero(); n as usize];
let info = $evg(
l.lapacke_layout(),
1,
jobz,
uplo as u8,
n,
&mut a,
n,
&mut b,
n,
&mut w,
);
into_result(info, w)
}
}
};
} // impl_eigh!

impl_eigh!(f64, lapacke::dsyev);
impl_eigh!(f32, lapacke::ssyev);
impl_eigh!(c64, lapacke::zheev);
impl_eigh!(c32, lapacke::cheev);
impl_eigh!(f64, lapacke::dsyev, lapacke::dsygv);
impl_eigh!(f32, lapacke::ssyev, lapacke::ssygv);
impl_eigh!(c64, lapacke::zheev, lapacke::zhegv);
impl_eigh!(c32, lapacke::cheev, lapacke::chegv);
16 changes: 4 additions & 12 deletions src/lapack/svddc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@ use num_traits::Zero;

use crate::error::*;
use crate::layout::MatrixLayout;
use crate::types::*;
use crate::svddc::UVTFlag;
use crate::types::*;

use super::{SVDOutput, into_result};
use super::{into_result, SVDOutput};

pub trait SVDDC_: Scalar {
unsafe fn svddc(l: MatrixLayout, jobz: UVTFlag, a: &mut [Self]) -> Result<SVDOutput<Self>>;
Expand All @@ -15,11 +15,7 @@ pub trait SVDDC_: Scalar {
macro_rules! impl_svdd {
($scalar:ty, $gesdd:path) => {
impl SVDDC_ for $scalar {
unsafe fn svddc(
l: MatrixLayout,
jobz: UVTFlag,
mut a: &mut [Self],
) -> Result<SVDOutput<Self>> {
unsafe fn svddc(l: MatrixLayout, jobz: UVTFlag, mut a: &mut [Self]) -> Result<SVDOutput<Self>> {
let (m, n) = l.size();
let k = m.min(n);
let lda = l.lda();
Expand Down Expand Up @@ -51,11 +47,7 @@ macro_rules! impl_svdd {
SVDOutput {
s: s,
u: if jobz == UVTFlag::None { None } else { Some(u) },
vt: if jobz == UVTFlag::None {
None
} else {
Some(vt)
},
vt: if jobz == UVTFlag::None { None } else { Some(vt) },
},
)
}
Expand Down
5 changes: 5 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@
//! - [Random matrix generators](generate/index.html)
//! - [Scalar trait](types/trait.Scalar.html)

#[macro_use]
extern crate ndarray;

extern crate blas_src;
extern crate lapack_src;

Expand All @@ -52,6 +55,7 @@ pub mod inner;
pub mod krylov;
pub mod lapack;
pub mod layout;
pub mod lobpcg;
pub mod norm;
pub mod operator;
pub mod opnorm;
Expand All @@ -73,6 +77,7 @@ pub use eigh::*;
pub use generate::*;
pub use inner::*;
pub use layout::*;
pub use lobpcg::{TruncatedEig, TruncatedOrder, TruncatedSvd};
pub use norm::*;
pub use operator::*;
pub use opnorm::*;
Expand Down

0 comments on commit c033cb9

Please sign in to comment.