Skip to content

Commit

Permalink
Merge pull request #328 from rust-ndarray/use-lapack-sys-directly
Browse files Browse the repository at this point in the history
Use lapack-sys crate directly from lax crate
  • Loading branch information
termoshtt committed Aug 31, 2022
2 parents 0cfde73 + 2fc0ac8 commit 791713f
Show file tree
Hide file tree
Showing 15 changed files with 682 additions and 411 deletions.
2 changes: 1 addition & 1 deletion lax/Cargo.toml
Expand Up @@ -32,7 +32,7 @@ intel-mkl-system = ["intel-mkl-src/mkl-dynamic-lp64-seq"]
thiserror = "1.0.24"
cauchy = "0.4.0"
num-traits = "0.2.14"
lapack = "0.18.0"
lapack-sys = "0.14.0"

[dependencies.intel-mkl-src]
version = "0.7.0"
Expand Down
43 changes: 36 additions & 7 deletions lax/src/cholesky.rs
Expand Up @@ -29,7 +29,7 @@ macro_rules! impl_cholesky {
}
let mut info = 0;
unsafe {
$trf(uplo as u8, n, a, n, &mut info);
$trf(uplo.as_ptr(), &n, AsPtr::as_mut_ptr(a), &n, &mut info);
}
info.as_lapack_result()?;
if matches!(l, MatrixLayout::C { .. }) {
Expand All @@ -45,7 +45,7 @@ macro_rules! impl_cholesky {
}
let mut info = 0;
unsafe {
$tri(uplo as u8, n, a, l.lda(), &mut info);
$tri(uplo.as_ptr(), &n, AsPtr::as_mut_ptr(a), &l.lda(), &mut info);
}
info.as_lapack_result()?;
if matches!(l, MatrixLayout::C { .. }) {
Expand All @@ -70,7 +70,16 @@ macro_rules! impl_cholesky {
}
}
unsafe {
$trs(uplo as u8, n, nrhs, a, l.lda(), b, n, &mut info);
$trs(
uplo.as_ptr(),
&n,
&nrhs,
AsPtr::as_ptr(a),
&l.lda(),
AsPtr::as_mut_ptr(b),
&n,
&mut info,
);
}
info.as_lapack_result()?;
if matches!(l, MatrixLayout::C { .. }) {
Expand All @@ -84,7 +93,27 @@ macro_rules! impl_cholesky {
};
} // end macro_rules

impl_cholesky!(f64, lapack::dpotrf, lapack::dpotri, lapack::dpotrs);
impl_cholesky!(f32, lapack::spotrf, lapack::spotri, lapack::spotrs);
impl_cholesky!(c64, lapack::zpotrf, lapack::zpotri, lapack::zpotrs);
impl_cholesky!(c32, lapack::cpotrf, lapack::cpotri, lapack::cpotrs);
impl_cholesky!(
f64,
lapack_sys::dpotrf_,
lapack_sys::dpotri_,
lapack_sys::dpotrs_
);
impl_cholesky!(
f32,
lapack_sys::spotrf_,
lapack_sys::spotri_,
lapack_sys::spotrs_
);
impl_cholesky!(
c64,
lapack_sys::zpotrf_,
lapack_sys::zpotri_,
lapack_sys::zpotrs_
);
impl_cholesky!(
c32,
lapack_sys::cpotrf_,
lapack_sys::cpotri_,
lapack_sys::cpotrs_
);
174 changes: 82 additions & 92 deletions lax/src/eig.rs
Expand Up @@ -20,7 +20,7 @@ macro_rules! impl_eig_complex {
fn eig(
calc_v: bool,
l: MatrixLayout,
mut a: &mut [Self],
a: &mut [Self],
) -> Result<(Vec<Self::Complex>, Vec<Self::Complex>)> {
let (n, _) = l.size();
// LAPACK assumes a column-major input. A row-major input can
Expand All @@ -35,74 +35,69 @@ macro_rules! impl_eig_complex {
// eigenvalues are the eigenvalues computed with `A`.
let (jobvl, jobvr) = if calc_v {
match l {
MatrixLayout::C { .. } => (b'V', b'N'),
MatrixLayout::F { .. } => (b'N', b'V'),
MatrixLayout::C { .. } => (EigenVectorFlag::Calc, EigenVectorFlag::Not),
MatrixLayout::F { .. } => (EigenVectorFlag::Not, EigenVectorFlag::Calc),
}
} else {
(b'N', b'N')
(EigenVectorFlag::Not, EigenVectorFlag::Not)
};
let mut eigs = unsafe { vec_uninit(n as usize) };
let mut rwork = unsafe { vec_uninit(2 * n as usize) };
let mut rwork: Vec<Self::Real> = unsafe { vec_uninit(2 * n as usize) };

let mut vl = if jobvl == b'V' {
Some(unsafe { vec_uninit((n * n) as usize) })
} else {
None
};
let mut vr = if jobvr == b'V' {
Some(unsafe { vec_uninit((n * n) as usize) })
} else {
None
};
let mut vl: Option<Vec<Self>> =
jobvl.then(|| unsafe { vec_uninit((n * n) as usize) });
let mut vr: Option<Vec<Self>> =
jobvr.then(|| unsafe { vec_uninit((n * n) as usize) });

// calc work size
let mut info = 0;
let mut work_size = [Self::zero()];
unsafe {
$ev(
jobvl,
jobvr,
n,
&mut a,
n,
&mut eigs,
&mut vl.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut []),
n,
&mut vr.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut []),
n,
&mut work_size,
-1,
&mut rwork,
jobvl.as_ptr(),
jobvr.as_ptr(),
&n,
AsPtr::as_mut_ptr(a),
&n,
AsPtr::as_mut_ptr(&mut eigs),
AsPtr::as_mut_ptr(vl.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut [])),
&n,
AsPtr::as_mut_ptr(vr.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut [])),
&n,
AsPtr::as_mut_ptr(&mut work_size),
&(-1),
AsPtr::as_mut_ptr(&mut rwork),
&mut info,
)
};
info.as_lapack_result()?;

// actual ev
let lwork = work_size[0].to_usize().unwrap();
let mut work = unsafe { vec_uninit(lwork) };
let mut work: Vec<Self> = unsafe { vec_uninit(lwork) };
let lwork = lwork as i32;
unsafe {
$ev(
jobvl,
jobvr,
n,
&mut a,
n,
&mut eigs,
&mut vl.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut []),
n,
&mut vr.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut []),
n,
&mut work,
lwork as i32,
&mut rwork,
jobvl.as_ptr(),
jobvr.as_ptr(),
&n,
AsPtr::as_mut_ptr(a),
&n,
AsPtr::as_mut_ptr(&mut eigs),
AsPtr::as_mut_ptr(vl.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut [])),
&n,
AsPtr::as_mut_ptr(vr.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut [])),
&n,
AsPtr::as_mut_ptr(&mut work),
&lwork,
AsPtr::as_mut_ptr(&mut rwork),
&mut info,
)
};
info.as_lapack_result()?;

// Hermite conjugate
if jobvl == b'V' {
if jobvl.is_calc() {
for c in vl.as_mut().unwrap().iter_mut() {
c.im = -c.im
}
Expand All @@ -114,16 +109,16 @@ macro_rules! impl_eig_complex {
};
}

impl_eig_complex!(c64, lapack::zgeev);
impl_eig_complex!(c32, lapack::cgeev);
impl_eig_complex!(c64, lapack_sys::zgeev_);
impl_eig_complex!(c32, lapack_sys::cgeev_);

macro_rules! impl_eig_real {
($scalar:ty, $ev:path) => {
impl Eig_ for $scalar {
fn eig(
calc_v: bool,
l: MatrixLayout,
mut a: &mut [Self],
a: &mut [Self],
) -> Result<(Vec<Self::Complex>, Vec<Self::Complex>)> {
let (n, _) = l.size();
// LAPACK assumes a column-major input. A row-major input can
Expand All @@ -144,67 +139,62 @@ macro_rules! impl_eig_real {
// `sgeev`/`dgeev`.
let (jobvl, jobvr) = if calc_v {
match l {
MatrixLayout::C { .. } => (b'V', b'N'),
MatrixLayout::F { .. } => (b'N', b'V'),
MatrixLayout::C { .. } => (EigenVectorFlag::Calc, EigenVectorFlag::Not),
MatrixLayout::F { .. } => (EigenVectorFlag::Not, EigenVectorFlag::Calc),
}
} else {
(b'N', b'N')
(EigenVectorFlag::Not, EigenVectorFlag::Not)
};
let mut eig_re = unsafe { vec_uninit(n as usize) };
let mut eig_im = unsafe { vec_uninit(n as usize) };
let mut eig_re: Vec<Self> = unsafe { vec_uninit(n as usize) };
let mut eig_im: Vec<Self> = unsafe { vec_uninit(n as usize) };

let mut vl = if jobvl == b'V' {
Some(unsafe { vec_uninit((n * n) as usize) })
} else {
None
};
let mut vr = if jobvr == b'V' {
Some(unsafe { vec_uninit((n * n) as usize) })
} else {
None
};
let mut vl: Option<Vec<Self>> =
jobvl.then(|| unsafe { vec_uninit((n * n) as usize) });
let mut vr: Option<Vec<Self>> =
jobvr.then(|| unsafe { vec_uninit((n * n) as usize) });

// calc work size
let mut info = 0;
let mut work_size = [0.0];
let mut work_size: [Self; 1] = [0.0];
unsafe {
$ev(
jobvl,
jobvr,
n,
&mut a,
n,
&mut eig_re,
&mut eig_im,
vl.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut []),
n,
vr.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut []),
n,
&mut work_size,
-1,
jobvl.as_ptr(),
jobvr.as_ptr(),
&n,
AsPtr::as_mut_ptr(a),
&n,
AsPtr::as_mut_ptr(&mut eig_re),
AsPtr::as_mut_ptr(&mut eig_im),
AsPtr::as_mut_ptr(vl.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut [])),
&n,
AsPtr::as_mut_ptr(vr.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut [])),
&n,
AsPtr::as_mut_ptr(&mut work_size),
&(-1),
&mut info,
)
};
info.as_lapack_result()?;

// actual ev
let lwork = work_size[0].to_usize().unwrap();
let mut work = unsafe { vec_uninit(lwork) };
let mut work: Vec<Self> = unsafe { vec_uninit(lwork) };
let lwork = lwork as i32;
unsafe {
$ev(
jobvl,
jobvr,
n,
&mut a,
n,
&mut eig_re,
&mut eig_im,
vl.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut []),
n,
vr.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut []),
n,
&mut work,
lwork as i32,
jobvl.as_ptr(),
jobvr.as_ptr(),
&n,
AsPtr::as_mut_ptr(a),
&n,
AsPtr::as_mut_ptr(&mut eig_re),
AsPtr::as_mut_ptr(&mut eig_im),
AsPtr::as_mut_ptr(vl.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut [])),
&n,
AsPtr::as_mut_ptr(vr.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut [])),
&n,
AsPtr::as_mut_ptr(&mut work),
&lwork,
&mut info,
)
};
Expand Down Expand Up @@ -254,7 +244,7 @@ macro_rules! impl_eig_real {
for row in 0..n {
let re = v[row + col * n];
let mut im = v[row + (col + 1) * n];
if jobvl == b'V' {
if jobvl.is_calc() {
im = -im;
}
eigvecs[row + col * n] = Self::complex(re, im);
Expand All @@ -270,5 +260,5 @@ macro_rules! impl_eig_real {
};
}

impl_eig_real!(f64, lapack::dgeev);
impl_eig_real!(f32, lapack::sgeev);
impl_eig_real!(f64, lapack_sys::dgeev_);
impl_eig_real!(f32, lapack_sys::sgeev_);

0 comments on commit 791713f

Please sign in to comment.