Skip to content

Commit

Permalink
WIP: EigWorkImpl for f64
Browse files Browse the repository at this point in the history
  • Loading branch information
termoshtt committed Sep 23, 2022
1 parent 0a78593 commit ede3cd9
Showing 1 changed file with 153 additions and 30 deletions.
183 changes: 153 additions & 30 deletions lax/src/eig.rs
Expand Up @@ -40,16 +40,18 @@ pub struct EigWork<T: Scalar> {
pub jobvl: JobEv,

/// Eigenvalues used in complex routines
pub eigs: Option<Vec<MaybeUninit<T>>>,
pub eigs: Vec<MaybeUninit<T::Complex>>,
/// Real part of eigenvalues used in real routines
pub eigs_re: Option<Vec<MaybeUninit<T>>>,
pub eigs_re: Option<Vec<MaybeUninit<T::Real>>>,
/// Imaginary part of eigenvalues used in real routines
pub eigs_im: Option<Vec<MaybeUninit<T>>>,
pub eigs_im: Option<Vec<MaybeUninit<T::Real>>>,

/// Left eigenvectors
pub vl: Option<Vec<MaybeUninit<T>>>,
pub vc_l: Option<Vec<MaybeUninit<T::Complex>>>,
pub vr_l: Option<Vec<MaybeUninit<T::Real>>>,
/// Right eigenvectors
pub vr: Option<Vec<MaybeUninit<T>>>,
pub vc_r: Option<Vec<MaybeUninit<T::Complex>>>,
pub vr_r: Option<Vec<MaybeUninit<T::Real>>>,

/// Working memory
pub work: Vec<MaybeUninit<T>>,
Expand Down Expand Up @@ -97,8 +99,8 @@ impl EigWorkImpl for EigWork<c64> {
let mut eigs: Vec<MaybeUninit<c64>> = vec_uninit(n as usize);
let mut rwork: Vec<MaybeUninit<f64>> = vec_uninit(2 * n as usize);

let mut vl: Option<Vec<MaybeUninit<c64>>> = jobvl.then(|| vec_uninit((n * n) as usize));
let mut vr: Option<Vec<MaybeUninit<c64>>> = jobvr.then(|| vec_uninit((n * n) as usize));
let mut vc_l: Option<Vec<MaybeUninit<c64>>> = jobvl.then(|| vec_uninit((n * n) as usize));
let mut vc_r: Option<Vec<MaybeUninit<c64>>> = jobvr.then(|| vec_uninit((n * n) as usize));

// calc work size
let mut info = 0;
Expand All @@ -111,9 +113,9 @@ impl EigWorkImpl for EigWork<c64> {
std::ptr::null_mut(),
&n,
AsPtr::as_mut_ptr(&mut eigs),
AsPtr::as_mut_ptr(vl.as_deref_mut().unwrap_or(&mut [])),
AsPtr::as_mut_ptr(vc_l.as_deref_mut().unwrap_or(&mut [])),
&n,
AsPtr::as_mut_ptr(vr.as_deref_mut().unwrap_or(&mut [])),
AsPtr::as_mut_ptr(vc_r.as_deref_mut().unwrap_or(&mut [])),
&n,
AsPtr::as_mut_ptr(&mut work_size),
&(-1),
Expand All @@ -129,12 +131,14 @@ impl EigWorkImpl for EigWork<c64> {
n,
jobvl,
jobvr,
eigs: Some(eigs),
eigs,
eigs_re: None,
eigs_im: None,
rwork: Some(rwork),
vl,
vr,
vc_l,
vc_r,
vr_l: None,
vr_r: None,
work,
})
}
Expand All @@ -149,10 +153,10 @@ impl EigWorkImpl for EigWork<c64> {
&self.n,
AsPtr::as_mut_ptr(a),
&self.n,
AsPtr::as_mut_ptr(self.eigs.as_mut().unwrap()),
AsPtr::as_mut_ptr(self.vl.as_deref_mut().unwrap_or(&mut [])),
AsPtr::as_mut_ptr(&mut self.eigs),
AsPtr::as_mut_ptr(self.vc_l.as_deref_mut().unwrap_or(&mut [])),
&self.n,
AsPtr::as_mut_ptr(self.vr.as_deref_mut().unwrap_or(&mut [])),
AsPtr::as_mut_ptr(self.vc_r.as_deref_mut().unwrap_or(&mut [])),
&self.n,
AsPtr::as_mut_ptr(&mut self.work),
&lwork,
Expand All @@ -162,14 +166,10 @@ impl EigWorkImpl for EigWork<c64> {
};
info.as_lapack_result()?;

let eigs = self
.eigs
.as_ref()
.map(|v| unsafe { v.slice_assume_init_ref() })
.unwrap();
let eigs = unsafe { self.eigs.slice_assume_init_ref() };

// Hermite conjugate
if let Some(vl) = self.vl.as_mut() {
if let Some(vl) = self.vc_l.as_mut() {
for value in vl {
let value = unsafe { value.assume_init_mut() };
value.im = -value.im;
Expand All @@ -178,11 +178,11 @@ impl EigWorkImpl for EigWork<c64> {
Ok(EigRef {
eigs,
vl: self
.vl
.vc_l
.as_ref()
.map(|v| unsafe { v.slice_assume_init_ref() }),
vr: self
.vr
.vc_r
.as_ref()
.map(|v| unsafe { v.slice_assume_init_ref() }),
})
Expand All @@ -198,10 +198,10 @@ impl EigWorkImpl for EigWork<c64> {
&self.n,
AsPtr::as_mut_ptr(a),
&self.n,
AsPtr::as_mut_ptr(self.eigs.as_mut().unwrap()),
AsPtr::as_mut_ptr(self.vl.as_deref_mut().unwrap_or(&mut [])),
AsPtr::as_mut_ptr(&mut self.eigs),
AsPtr::as_mut_ptr(self.vc_l.as_deref_mut().unwrap_or(&mut [])),
&self.n,
AsPtr::as_mut_ptr(self.vr.as_deref_mut().unwrap_or(&mut [])),
AsPtr::as_mut_ptr(self.vc_r.as_deref_mut().unwrap_or(&mut [])),
&self.n,
AsPtr::as_mut_ptr(&mut self.work),
&lwork,
Expand All @@ -210,21 +210,134 @@ impl EigWorkImpl for EigWork<c64> {
)
};
info.as_lapack_result()?;
let eigs = self.eigs.map(|v| unsafe { v.assume_init() }).unwrap();
let eigs = unsafe { self.eigs.assume_init() };

// Hermite conjugate
if let Some(vl) = self.vl.as_mut() {
if let Some(vl) = self.vc_l.as_mut() {
for value in vl {
let value = unsafe { value.assume_init_mut() };
value.im = -value.im;
}
}
Ok(Eig {
eigs,
vl: self.vl.map(|v| unsafe { v.assume_init() }),
vr: self.vr.map(|v| unsafe { v.assume_init() }),
vl: self.vc_l.map(|v| unsafe { v.assume_init() }),
vr: self.vc_r.map(|v| unsafe { v.assume_init() }),
})
}
}

impl EigWorkImpl for EigWork<f64> {
type Elem = f64;

fn new(calc_v: bool, l: MatrixLayout) -> Result<Self> {
let (n, _) = l.size();
let (jobvl, jobvr) = if calc_v {
match l {
MatrixLayout::C { .. } => (JobEv::All, JobEv::None),
MatrixLayout::F { .. } => (JobEv::None, JobEv::All),
}
} else {
(JobEv::None, JobEv::None)
};
let mut eigs_re: Vec<MaybeUninit<f64>> = vec_uninit(n as usize);
let mut eigs_im: Vec<MaybeUninit<f64>> = vec_uninit(n as usize);

let mut vr_l: Option<Vec<MaybeUninit<f64>>> = jobvl.then(|| vec_uninit((n * n) as usize));
let mut vr_r: Option<Vec<MaybeUninit<f64>>> = jobvr.then(|| vec_uninit((n * n) as usize));

// calc work size
let mut info = 0;
let mut work_size: [f64; 1] = [0.0];
unsafe {
lapack_sys::dgeev_(
jobvl.as_ptr(),
jobvr.as_ptr(),
&n,
std::ptr::null_mut(),
&n,
AsPtr::as_mut_ptr(&mut eigs_re),
AsPtr::as_mut_ptr(&mut eigs_im),
AsPtr::as_mut_ptr(vr_l.as_deref_mut().unwrap_or(&mut [])),
&n,
AsPtr::as_mut_ptr(vr_r.as_deref_mut().unwrap_or(&mut [])),
&n,
AsPtr::as_mut_ptr(&mut work_size),
&(-1),
&mut info,
)
};
info.as_lapack_result()?;

// actual ev
let lwork = work_size[0].to_usize().unwrap();
let work: Vec<MaybeUninit<f64>> = vec_uninit(lwork);

Ok(Self {
n,
jobvr,
jobvl,
eigs: vec_uninit(n as usize),
eigs_re: Some(eigs_re),
eigs_im: Some(eigs_im),
rwork: None,
vr_l,
vr_r,
vc_l: None,
vc_r: None,
work,
})
}

fn calc<'work>(&'work mut self, _a: &mut [f64]) -> Result<EigRef<'work, f64>> {
todo!()
}

fn eval(mut self, a: &mut [f64]) -> Result<Eig<f64>> {
let lwork = self.work.len().to_i32().unwrap();
let mut info = 0;
unsafe {
lapack_sys::dgeev_(
self.jobvl.as_ptr(),
self.jobvr.as_ptr(),
&self.n,
AsPtr::as_mut_ptr(a),
&self.n,
AsPtr::as_mut_ptr(self.eigs_re.as_mut().unwrap()),
AsPtr::as_mut_ptr(self.eigs_im.as_mut().unwrap()),
AsPtr::as_mut_ptr(self.vr_l.as_deref_mut().unwrap_or(&mut [])),
&self.n,
AsPtr::as_mut_ptr(self.vr_r.as_deref_mut().unwrap_or(&mut [])),
&self.n,
AsPtr::as_mut_ptr(&mut self.work),
&lwork,
&mut info,
)
};
info.as_lapack_result()?;

let eigs_re = unsafe { self.eigs_re.unwrap().assume_init() };
let eigs_im = unsafe { self.eigs_im.unwrap().assume_init() };

let n = self.n as usize;
let vl = self.vr_l.map(|v| {
let v = unsafe { v.assume_init() };
let mut vc = vec_uninit(n * n);
reconstruct_eigenvectors(false, &eigs_im, &v, &mut vc);
unsafe { vc.assume_init() }
});
let vr = self.vr_r.map(|v| {
let v = unsafe { v.assume_init() };
let mut vc = vec_uninit(n * n);
reconstruct_eigenvectors(true, &eigs_im, &v, &mut vc);
unsafe { vc.assume_init() }
});

reconstruct_eigs(&eigs_re, &eigs_im, &mut self.eigs);
let eigs = unsafe { self.eigs.assume_init() };

Ok(Eig { eigs, vl, vr })
}
}

macro_rules! impl_eig_complex {
Expand Down Expand Up @@ -497,3 +610,13 @@ fn reconstruct_eigenvectors<T: Scalar>(
}
}
}

/// Create complex eigenvalues from real and imaginary parts.
fn reconstruct_eigs<T: Scalar>(re: &[T], im: &[T], eigs: &mut [MaybeUninit<T::Complex>]) {
let n = eigs.len();
assert_eq!(re.len(), n);
assert_eq!(im.len(), n);
for i in 0..n {
eigs[i].write(T::complex(re[i], im[i]));
}
}

0 comments on commit ede3cd9

Please sign in to comment.