From ce962f14fb53aa2f92741f4f2ff8ed27fd413b6e Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Wed, 31 Aug 2022 15:08:56 +0900 Subject: [PATCH 01/21] Add lapack-sys --- lax/Cargo.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/lax/Cargo.toml b/lax/Cargo.toml index 1ed31f73..d62f2064 100644 --- a/lax/Cargo.toml +++ b/lax/Cargo.toml @@ -33,6 +33,7 @@ thiserror = "1.0.24" cauchy = "0.4.0" num-traits = "0.2.14" lapack = "0.18.0" +lapack-sys = "0.14.0" [dependencies.intel-mkl-src] version = "0.7.0" From 328abd3524d97517a75a336c01bb70c46942872e Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Wed, 31 Aug 2022 15:46:49 +0900 Subject: [PATCH 02/21] Add `as_ptr` for primitive enums --- lax/src/lib.rs | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/lax/src/lib.rs b/lax/src/lib.rs index 41c15237..c04efe60 100644 --- a/lax/src/lib.rs +++ b/lax/src/lib.rs @@ -141,6 +141,11 @@ impl UPLO { UPLO::Lower => UPLO::Upper, } } + + /// To use Fortran LAPACK API in lapack-sys crate + pub fn as_ptr(&self) -> *const i8 { + self as *const UPLO as *const i8 + } } #[derive(Debug, Clone, Copy)] @@ -151,6 +156,13 @@ pub enum Transpose { Hermite = b'C', } +impl Transpose { + /// To use Fortran LAPACK API in lapack-sys crate + pub fn as_ptr(&self) -> *const i8 { + self as *const Transpose as *const i8 + } +} + #[derive(Debug, Clone, Copy)] #[repr(u8)] pub enum NormType { @@ -167,6 +179,11 @@ impl NormType { NormType::Frobenius => NormType::Frobenius, } } + + /// To use Fortran LAPACK API in lapack-sys crate + pub fn as_ptr(&self) -> *const i8 { + self as *const NormType as *const i8 + } } /// Create a vector without initialization From 6d3924b97b0139884406da2d2bc5268eba91dc3e Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Wed, 31 Aug 2022 16:14:55 +0900 Subject: [PATCH 03/21] `AsPtr` helper trait to get pointer of slice --- lax/src/lib.rs | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/lax/src/lib.rs b/lax/src/lib.rs index c04efe60..00aacecd 100644 --- a/lax/src/lib.rs +++ b/lax/src/lib.rs @@ -126,6 +126,31 @@ impl Lapack for f64 {} impl Lapack for c32 {} impl Lapack for c64 {} +/// Helper for getting pointer of slice +pub(crate) trait AsPtr: Sized { + type Elem; + fn as_ptr(vec: &[Self]) -> *const Self::Elem; + fn as_mut_ptr(vec: &mut [Self]) -> *mut Self::Elem; +} + +macro_rules! impl_as_ptr { + ($target:ty, $elem:ty) => { + impl AsPtr for $target { + type Elem = $elem; + fn as_ptr(vec: &[Self]) -> *const Self::Elem { + vec.as_ptr() as *const _ + } + fn as_mut_ptr(vec: &mut [Self]) -> *mut Self::Elem { + vec.as_mut_ptr() as *mut _ + } + } + }; +} +impl_as_ptr!(f32, f32); +impl_as_ptr!(f64, f64); +impl_as_ptr!(c32, lapack_sys::__BindgenComplex); +impl_as_ptr!(c64, lapack_sys::__BindgenComplex); + /// Upper/Lower specification for seveal usages #[derive(Debug, Clone, Copy)] #[repr(u8)] From 7dd8e258435620ce559997d9536a79e68962b4a6 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Wed, 31 Aug 2022 16:15:41 +0900 Subject: [PATCH 04/21] Use lapack_sys in OperatorNorm_ trait impl --- lax/src/opnorm.rs | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/lax/src/opnorm.rs b/lax/src/opnorm.rs index dd84f441..ddcc2c85 100644 --- a/lax/src/opnorm.rs +++ b/lax/src/opnorm.rs @@ -1,6 +1,6 @@ //! Operator norms of matrices -use super::NormType; +use super::{AsPtr, NormType}; use crate::{layout::MatrixLayout, *}; use cauchy::*; @@ -18,18 +18,27 @@ macro_rules! impl_opnorm { MatrixLayout::F { .. } => t, MatrixLayout::C { .. } => t.transpose(), }; - let mut work = if matches!(t, NormType::Infinity) { + let mut work: Vec = if matches!(t, NormType::Infinity) { unsafe { vec_uninit(m as usize) } } else { Vec::new() }; - unsafe { $lange(t as u8, m, n, a, m, &mut work) } + unsafe { + $lange( + t.as_ptr(), + &m, + &n, + AsPtr::as_ptr(a), + &m, + AsPtr::as_mut_ptr(&mut work), + ) + } } } }; } // impl_opnorm! -impl_opnorm!(f64, lapack::dlange); -impl_opnorm!(f32, lapack::slange); -impl_opnorm!(c64, lapack::zlange); -impl_opnorm!(c32, lapack::clange); +impl_opnorm!(f64, lapack_sys::dlange_); +impl_opnorm!(f32, lapack_sys::slange_); +impl_opnorm!(c64, lapack_sys::zlange_); +impl_opnorm!(c32, lapack_sys::clange_); From b96462591a3fb3039fd6667cc34eae05a94cf341 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Wed, 31 Aug 2022 17:23:36 +0900 Subject: [PATCH 05/21] Use lapack_sys in Cholesky_ --- lax/src/cholesky.rs | 43 ++++++++++++++++++++++++++++++++++++------- 1 file changed, 36 insertions(+), 7 deletions(-) diff --git a/lax/src/cholesky.rs b/lax/src/cholesky.rs index 8305efe5..94e683ff 100644 --- a/lax/src/cholesky.rs +++ b/lax/src/cholesky.rs @@ -29,7 +29,7 @@ macro_rules! impl_cholesky { } let mut info = 0; unsafe { - $trf(uplo as u8, n, a, n, &mut info); + $trf(uplo.as_ptr(), &n, AsPtr::as_mut_ptr(a), &n, &mut info); } info.as_lapack_result()?; if matches!(l, MatrixLayout::C { .. }) { @@ -45,7 +45,7 @@ macro_rules! impl_cholesky { } let mut info = 0; unsafe { - $tri(uplo as u8, n, a, l.lda(), &mut info); + $tri(uplo.as_ptr(), &n, AsPtr::as_mut_ptr(a), &l.lda(), &mut info); } info.as_lapack_result()?; if matches!(l, MatrixLayout::C { .. }) { @@ -70,7 +70,16 @@ macro_rules! impl_cholesky { } } unsafe { - $trs(uplo as u8, n, nrhs, a, l.lda(), b, n, &mut info); + $trs( + uplo.as_ptr(), + &n, + &nrhs, + AsPtr::as_ptr(a), + &l.lda(), + AsPtr::as_mut_ptr(b), + &n, + &mut info, + ); } info.as_lapack_result()?; if matches!(l, MatrixLayout::C { .. }) { @@ -84,7 +93,27 @@ macro_rules! impl_cholesky { }; } // end macro_rules -impl_cholesky!(f64, lapack::dpotrf, lapack::dpotri, lapack::dpotrs); -impl_cholesky!(f32, lapack::spotrf, lapack::spotri, lapack::spotrs); -impl_cholesky!(c64, lapack::zpotrf, lapack::zpotri, lapack::zpotrs); -impl_cholesky!(c32, lapack::cpotrf, lapack::cpotri, lapack::cpotrs); +impl_cholesky!( + f64, + lapack_sys::dpotrf_, + lapack_sys::dpotri_, + lapack_sys::dpotrs_ +); +impl_cholesky!( + f32, + lapack_sys::spotrf_, + lapack_sys::spotri_, + lapack_sys::spotrs_ +); +impl_cholesky!( + c64, + lapack_sys::zpotrf_, + lapack_sys::zpotri_, + lapack_sys::zpotrs_ +); +impl_cholesky!( + c32, + lapack_sys::cpotrf_, + lapack_sys::cpotri_, + lapack_sys::cpotrs_ +); From c0c3a374c05c614f2894949b0f99275a864116d4 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Wed, 31 Aug 2022 17:37:32 +0900 Subject: [PATCH 06/21] Introduce EigenVectorFlag --- lax/src/eig.rs | 40 ++++++++++++++++++++-------------------- lax/src/eigh.rs | 12 ++++++------ lax/src/lib.rs | 15 +++++++++++++++ 3 files changed, 41 insertions(+), 26 deletions(-) diff --git a/lax/src/eig.rs b/lax/src/eig.rs index a6d6a16e..1f15ddfe 100644 --- a/lax/src/eig.rs +++ b/lax/src/eig.rs @@ -35,21 +35,21 @@ macro_rules! impl_eig_complex { // eigenvalues are the eigenvalues computed with `A`. let (jobvl, jobvr) = if calc_v { match l { - MatrixLayout::C { .. } => (b'V', b'N'), - MatrixLayout::F { .. } => (b'N', b'V'), + MatrixLayout::C { .. } => (EigenVectorFlag::Calc, EigenVectorFlag::Not), + MatrixLayout::F { .. } => (EigenVectorFlag::Not, EigenVectorFlag::Calc), } } else { - (b'N', b'N') + (EigenVectorFlag::Not, EigenVectorFlag::Not) }; let mut eigs = unsafe { vec_uninit(n as usize) }; let mut rwork = unsafe { vec_uninit(2 * n as usize) }; - let mut vl = if jobvl == b'V' { + let mut vl = if jobvl == EigenVectorFlag::Calc { Some(unsafe { vec_uninit((n * n) as usize) }) } else { None }; - let mut vr = if jobvr == b'V' { + let mut vr = if jobvr == EigenVectorFlag::Calc { Some(unsafe { vec_uninit((n * n) as usize) }) } else { None @@ -60,8 +60,8 @@ macro_rules! impl_eig_complex { let mut work_size = [Self::zero()]; unsafe { $ev( - jobvl, - jobvr, + jobvl as u8, + jobvr as u8, n, &mut a, n, @@ -83,8 +83,8 @@ macro_rules! impl_eig_complex { let mut work = unsafe { vec_uninit(lwork) }; unsafe { $ev( - jobvl, - jobvr, + jobvl as u8, + jobvr as u8, n, &mut a, n, @@ -102,7 +102,7 @@ macro_rules! impl_eig_complex { info.as_lapack_result()?; // Hermite conjugate - if jobvl == b'V' { + if jobvl == EigenVectorFlag::Calc { for c in vl.as_mut().unwrap().iter_mut() { c.im = -c.im } @@ -144,21 +144,21 @@ macro_rules! impl_eig_real { // `sgeev`/`dgeev`. let (jobvl, jobvr) = if calc_v { match l { - MatrixLayout::C { .. } => (b'V', b'N'), - MatrixLayout::F { .. } => (b'N', b'V'), + MatrixLayout::C { .. } => (EigenVectorFlag::Calc, EigenVectorFlag::Not), + MatrixLayout::F { .. } => (EigenVectorFlag::Not, EigenVectorFlag::Calc), } } else { - (b'N', b'N') + (EigenVectorFlag::Not, EigenVectorFlag::Not) }; let mut eig_re = unsafe { vec_uninit(n as usize) }; let mut eig_im = unsafe { vec_uninit(n as usize) }; - let mut vl = if jobvl == b'V' { + let mut vl = if jobvl == EigenVectorFlag::Calc { Some(unsafe { vec_uninit((n * n) as usize) }) } else { None }; - let mut vr = if jobvr == b'V' { + let mut vr = if jobvr == EigenVectorFlag::Calc { Some(unsafe { vec_uninit((n * n) as usize) }) } else { None @@ -169,8 +169,8 @@ macro_rules! impl_eig_real { let mut work_size = [0.0]; unsafe { $ev( - jobvl, - jobvr, + jobvl as u8, + jobvr as u8, n, &mut a, n, @@ -192,8 +192,8 @@ macro_rules! impl_eig_real { let mut work = unsafe { vec_uninit(lwork) }; unsafe { $ev( - jobvl, - jobvr, + jobvl as u8, + jobvr as u8, n, &mut a, n, @@ -254,7 +254,7 @@ macro_rules! impl_eig_real { for row in 0..n { let re = v[row + col * n]; let mut im = v[row + (col + 1) * n]; - if jobvl == b'V' { + if jobvl == EigenVectorFlag::Calc { im = -im; } eigvecs[row + col * n] = Self::complex(re, im); diff --git a/lax/src/eigh.rs b/lax/src/eigh.rs index ad8963dc..238c95c2 100644 --- a/lax/src/eigh.rs +++ b/lax/src/eigh.rs @@ -41,7 +41,7 @@ macro_rules! impl_eigh { ) -> Result> { assert_eq!(layout.len(), layout.lda()); let n = layout.len(); - let jobz = if calc_v { b'V' } else { b'N' }; + let jobz = if calc_v { EigenVectorFlag::Calc } else { EigenVectorFlag::Not }; let mut eigs = unsafe { vec_uninit(n as usize) }; $( @@ -53,7 +53,7 @@ macro_rules! impl_eigh { let mut work_size = [Self::zero()]; unsafe { $ev( - jobz, + jobz as u8, uplo as u8, n, &mut a, @@ -72,7 +72,7 @@ macro_rules! impl_eigh { let mut work = unsafe { vec_uninit(lwork) }; unsafe { $ev( - jobz, + jobz as u8, uplo as u8, n, &mut a, @@ -97,7 +97,7 @@ macro_rules! impl_eigh { ) -> Result> { assert_eq!(layout.len(), layout.lda()); let n = layout.len(); - let jobz = if calc_v { b'V' } else { b'N' }; + let jobz = if calc_v { EigenVectorFlag::Calc } else { EigenVectorFlag::Not }; let mut eigs = unsafe { vec_uninit(n as usize) }; $( @@ -110,7 +110,7 @@ macro_rules! impl_eigh { unsafe { $evg( &[1], - jobz, + jobz as u8, uplo as u8, n, &mut a, @@ -132,7 +132,7 @@ macro_rules! impl_eigh { unsafe { $evg( &[1], - jobz, + jobz as u8, uplo as u8, n, &mut a, diff --git a/lax/src/lib.rs b/lax/src/lib.rs index 00aacecd..a6d838b7 100644 --- a/lax/src/lib.rs +++ b/lax/src/lib.rs @@ -211,6 +211,21 @@ impl NormType { } } +/// Flag for calculating eigenvectors or not +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(u8)] +pub enum EigenVectorFlag { + Calc = b'V', + Not = b'N', +} + +impl EigenVectorFlag { + /// To use Fortran LAPACK API in lapack-sys crate + pub fn as_ptr(&self) -> *const i8 { + self as *const EigenVectorFlag as *const i8 + } +} + /// Create a vector without initialization /// /// Safety From 25dc071c7efa751e26510eeec44febf667331b97 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Wed, 31 Aug 2022 17:57:52 +0900 Subject: [PATCH 07/21] Use lapack_sys in Eig_ impl for real --- lax/src/eig.rs | 71 +++++++++++++++++++++++++------------------------- 1 file changed, 36 insertions(+), 35 deletions(-) diff --git a/lax/src/eig.rs b/lax/src/eig.rs index 1f15ddfe..dae63240 100644 --- a/lax/src/eig.rs +++ b/lax/src/eig.rs @@ -123,7 +123,7 @@ macro_rules! impl_eig_real { fn eig( calc_v: bool, l: MatrixLayout, - mut a: &mut [Self], + a: &mut [Self], ) -> Result<(Vec, Vec)> { let (n, _) = l.size(); // LAPACK assumes a column-major input. A row-major input can @@ -150,15 +150,15 @@ macro_rules! impl_eig_real { } else { (EigenVectorFlag::Not, EigenVectorFlag::Not) }; - let mut eig_re = unsafe { vec_uninit(n as usize) }; - let mut eig_im = unsafe { vec_uninit(n as usize) }; + let mut eig_re: Vec = unsafe { vec_uninit(n as usize) }; + let mut eig_im: Vec = unsafe { vec_uninit(n as usize) }; - let mut vl = if jobvl == EigenVectorFlag::Calc { + let mut vl: Option> = if jobvl == EigenVectorFlag::Calc { Some(unsafe { vec_uninit((n * n) as usize) }) } else { None }; - let mut vr = if jobvr == EigenVectorFlag::Calc { + let mut vr: Option> = if jobvr == EigenVectorFlag::Calc { Some(unsafe { vec_uninit((n * n) as usize) }) } else { None @@ -166,22 +166,22 @@ macro_rules! impl_eig_real { // calc work size let mut info = 0; - let mut work_size = [0.0]; + let mut work_size: [Self; 1] = [0.0]; unsafe { $ev( - jobvl as u8, - jobvr as u8, - n, - &mut a, - n, - &mut eig_re, - &mut eig_im, - vl.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut []), - n, - vr.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut []), - n, - &mut work_size, - -1, + jobvl.as_ptr(), + jobvr.as_ptr(), + &n, + AsPtr::as_mut_ptr(a), + &n, + AsPtr::as_mut_ptr(&mut eig_re), + AsPtr::as_mut_ptr(&mut eig_im), + AsPtr::as_mut_ptr(vl.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut [])), + &n, + AsPtr::as_mut_ptr(vr.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut [])), + &n, + AsPtr::as_mut_ptr(&mut work_size), + &(-1), &mut info, ) }; @@ -189,22 +189,23 @@ macro_rules! impl_eig_real { // actual ev let lwork = work_size[0].to_usize().unwrap(); - let mut work = unsafe { vec_uninit(lwork) }; + let mut work: Vec = unsafe { vec_uninit(lwork) }; + let lwork = lwork as i32; unsafe { $ev( - jobvl as u8, - jobvr as u8, - n, - &mut a, - n, - &mut eig_re, - &mut eig_im, - vl.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut []), - n, - vr.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut []), - n, - &mut work, - lwork as i32, + jobvl.as_ptr(), + jobvr.as_ptr(), + &n, + AsPtr::as_mut_ptr(a), + &n, + AsPtr::as_mut_ptr(&mut eig_re), + AsPtr::as_mut_ptr(&mut eig_im), + AsPtr::as_mut_ptr(vl.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut [])), + &n, + AsPtr::as_mut_ptr(vr.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut [])), + &n, + AsPtr::as_mut_ptr(&mut work), + &lwork, &mut info, ) }; @@ -270,5 +271,5 @@ macro_rules! impl_eig_real { }; } -impl_eig_real!(f64, lapack::dgeev); -impl_eig_real!(f32, lapack::sgeev); +impl_eig_real!(f64, lapack_sys::dgeev_); +impl_eig_real!(f32, lapack_sys::sgeev_); From e21e932298240205c991e74eb537d40f2139cbcb Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Wed, 31 Aug 2022 18:10:24 +0900 Subject: [PATCH 08/21] Introduce EigenVectorFlag::then --- lax/src/eig.rs | 30 ++++++++---------------------- lax/src/lib.rs | 15 +++++++++++++++ 2 files changed, 23 insertions(+), 22 deletions(-) diff --git a/lax/src/eig.rs b/lax/src/eig.rs index dae63240..eff4591a 100644 --- a/lax/src/eig.rs +++ b/lax/src/eig.rs @@ -44,16 +44,8 @@ macro_rules! impl_eig_complex { let mut eigs = unsafe { vec_uninit(n as usize) }; let mut rwork = unsafe { vec_uninit(2 * n as usize) }; - let mut vl = if jobvl == EigenVectorFlag::Calc { - Some(unsafe { vec_uninit((n * n) as usize) }) - } else { - None - }; - let mut vr = if jobvr == EigenVectorFlag::Calc { - Some(unsafe { vec_uninit((n * n) as usize) }) - } else { - None - }; + let mut vl = jobvl.then(|| unsafe { vec_uninit((n * n) as usize) }); + let mut vr = jobvr.then(|| unsafe { vec_uninit((n * n) as usize) }); // calc work size let mut info = 0; @@ -102,7 +94,7 @@ macro_rules! impl_eig_complex { info.as_lapack_result()?; // Hermite conjugate - if jobvl == EigenVectorFlag::Calc { + if jobvl.is_calc() { for c in vl.as_mut().unwrap().iter_mut() { c.im = -c.im } @@ -153,16 +145,10 @@ macro_rules! impl_eig_real { let mut eig_re: Vec = unsafe { vec_uninit(n as usize) }; let mut eig_im: Vec = unsafe { vec_uninit(n as usize) }; - let mut vl: Option> = if jobvl == EigenVectorFlag::Calc { - Some(unsafe { vec_uninit((n * n) as usize) }) - } else { - None - }; - let mut vr: Option> = if jobvr == EigenVectorFlag::Calc { - Some(unsafe { vec_uninit((n * n) as usize) }) - } else { - None - }; + let mut vl: Option> = + jobvl.then(|| unsafe { vec_uninit((n * n) as usize) }); + let mut vr: Option> = + jobvr.then(|| unsafe { vec_uninit((n * n) as usize) }); // calc work size let mut info = 0; @@ -255,7 +241,7 @@ macro_rules! impl_eig_real { for row in 0..n { let re = v[row + col * n]; let mut im = v[row + (col + 1) * n]; - if jobvl == EigenVectorFlag::Calc { + if jobvl.is_calc() { im = -im; } eigvecs[row + col * n] = Self::complex(re, im); diff --git a/lax/src/lib.rs b/lax/src/lib.rs index a6d838b7..26b740bb 100644 --- a/lax/src/lib.rs +++ b/lax/src/lib.rs @@ -220,6 +220,21 @@ pub enum EigenVectorFlag { } impl EigenVectorFlag { + pub fn is_calc(&self) -> bool { + match self { + EigenVectorFlag::Calc => true, + EigenVectorFlag::Not => false, + } + } + + pub fn then T>(&self, f: F) -> Option { + if self.is_calc() { + Some(f()) + } else { + None + } + } + /// To use Fortran LAPACK API in lapack-sys crate pub fn as_ptr(&self) -> *const i8 { self as *const EigenVectorFlag as *const i8 From 3c4dc5ea2d2b586c0560443ad8d5f94a795c1f6f Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Wed, 31 Aug 2022 18:18:43 +0900 Subject: [PATCH 09/21] Use lapack_sys in Eig_ impl for complex --- lax/src/eig.rs | 69 ++++++++++++++++++++++++++------------------------ 1 file changed, 36 insertions(+), 33 deletions(-) diff --git a/lax/src/eig.rs b/lax/src/eig.rs index eff4591a..bf3b9f8c 100644 --- a/lax/src/eig.rs +++ b/lax/src/eig.rs @@ -20,7 +20,7 @@ macro_rules! impl_eig_complex { fn eig( calc_v: bool, l: MatrixLayout, - mut a: &mut [Self], + a: &mut [Self], ) -> Result<(Vec, Vec)> { let (n, _) = l.size(); // LAPACK assumes a column-major input. A row-major input can @@ -42,29 +42,31 @@ macro_rules! impl_eig_complex { (EigenVectorFlag::Not, EigenVectorFlag::Not) }; let mut eigs = unsafe { vec_uninit(n as usize) }; - let mut rwork = unsafe { vec_uninit(2 * n as usize) }; + let mut rwork: Vec = unsafe { vec_uninit(2 * n as usize) }; - let mut vl = jobvl.then(|| unsafe { vec_uninit((n * n) as usize) }); - let mut vr = jobvr.then(|| unsafe { vec_uninit((n * n) as usize) }); + let mut vl: Option> = + jobvl.then(|| unsafe { vec_uninit((n * n) as usize) }); + let mut vr: Option> = + jobvr.then(|| unsafe { vec_uninit((n * n) as usize) }); // calc work size let mut info = 0; let mut work_size = [Self::zero()]; unsafe { $ev( - jobvl as u8, - jobvr as u8, - n, - &mut a, - n, - &mut eigs, - &mut vl.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut []), - n, - &mut vr.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut []), - n, - &mut work_size, - -1, - &mut rwork, + jobvl.as_ptr(), + jobvr.as_ptr(), + &n, + AsPtr::as_mut_ptr(a), + &n, + AsPtr::as_mut_ptr(&mut eigs), + AsPtr::as_mut_ptr(vl.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut [])), + &n, + AsPtr::as_mut_ptr(vr.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut [])), + &n, + AsPtr::as_mut_ptr(&mut work_size), + &(-1), + AsPtr::as_mut_ptr(&mut rwork), &mut info, ) }; @@ -72,22 +74,23 @@ macro_rules! impl_eig_complex { // actal ev let lwork = work_size[0].to_usize().unwrap(); - let mut work = unsafe { vec_uninit(lwork) }; + let mut work: Vec = unsafe { vec_uninit(lwork) }; + let lwork = lwork as i32; unsafe { $ev( - jobvl as u8, - jobvr as u8, - n, - &mut a, - n, - &mut eigs, - &mut vl.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut []), - n, - &mut vr.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut []), - n, - &mut work, - lwork as i32, - &mut rwork, + jobvl.as_ptr(), + jobvr.as_ptr(), + &n, + AsPtr::as_mut_ptr(a), + &n, + AsPtr::as_mut_ptr(&mut eigs), + AsPtr::as_mut_ptr(vl.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut [])), + &n, + AsPtr::as_mut_ptr(vr.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut [])), + &n, + AsPtr::as_mut_ptr(&mut work), + &lwork, + AsPtr::as_mut_ptr(&mut rwork), &mut info, ) }; @@ -106,8 +109,8 @@ macro_rules! impl_eig_complex { }; } -impl_eig_complex!(c64, lapack::zgeev); -impl_eig_complex!(c32, lapack::cgeev); +impl_eig_complex!(c64, lapack_sys::zgeev_); +impl_eig_complex!(c32, lapack_sys::cgeev_); macro_rules! impl_eig_real { ($scalar:ty, $ev:path) => { From 4bf441472401f777b74394ab72fa7bbf0f95458f Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Wed, 31 Aug 2022 18:46:16 +0900 Subject: [PATCH 10/21] Use lapack_sys in Eigh_ --- lax/src/eigh.rs | 108 ++++++++++++++++++++++++------------------------ 1 file changed, 55 insertions(+), 53 deletions(-) diff --git a/lax/src/eigh.rs b/lax/src/eigh.rs index 238c95c2..a8403e90 100644 --- a/lax/src/eigh.rs +++ b/lax/src/eigh.rs @@ -37,7 +37,7 @@ macro_rules! impl_eigh { calc_v: bool, layout: MatrixLayout, uplo: UPLO, - mut a: &mut [Self], + a: &mut [Self], ) -> Result> { assert_eq!(layout.len(), layout.lda()); let n = layout.len(); @@ -45,7 +45,7 @@ macro_rules! impl_eigh { let mut eigs = unsafe { vec_uninit(n as usize) }; $( - let mut $rwork_ident = unsafe { vec_uninit(3 * n as usize - 2 as usize) }; + let mut $rwork_ident: Vec = unsafe { vec_uninit(3 * n as usize - 2 as usize) }; )* // calc work size @@ -53,15 +53,15 @@ macro_rules! impl_eigh { let mut work_size = [Self::zero()]; unsafe { $ev( - jobz as u8, - uplo as u8, - n, - &mut a, - n, - &mut eigs, - &mut work_size, - -1, - $(&mut $rwork_ident,)* + jobz.as_ptr() , + uplo.as_ptr(), + &n, + AsPtr::as_mut_ptr(a), + &n, + AsPtr::as_mut_ptr(&mut eigs), + AsPtr::as_mut_ptr(&mut work_size), + &(-1), + $(AsPtr::as_mut_ptr(&mut $rwork_ident),)* &mut info, ); } @@ -69,18 +69,19 @@ macro_rules! impl_eigh { // actual ev let lwork = work_size[0].to_usize().unwrap(); - let mut work = unsafe { vec_uninit(lwork) }; + let mut work: Vec = unsafe { vec_uninit(lwork) }; + let lwork = lwork as i32; unsafe { $ev( - jobz as u8, - uplo as u8, - n, - &mut a, - n, - &mut eigs, - &mut work, - lwork as i32, - $(&mut $rwork_ident,)* + jobz.as_ptr(), + uplo.as_ptr(), + &n, + AsPtr::as_mut_ptr(a), + &n, + AsPtr::as_mut_ptr(&mut eigs), + AsPtr::as_mut_ptr(&mut work), + &lwork, + $(AsPtr::as_mut_ptr(&mut $rwork_ident),)* &mut info, ); } @@ -92,8 +93,8 @@ macro_rules! impl_eigh { calc_v: bool, layout: MatrixLayout, uplo: UPLO, - mut a: &mut [Self], - mut b: &mut [Self], + a: &mut [Self], + b: &mut [Self], ) -> Result> { assert_eq!(layout.len(), layout.lda()); let n = layout.len(); @@ -101,7 +102,7 @@ macro_rules! impl_eigh { let mut eigs = unsafe { vec_uninit(n as usize) }; $( - let mut $rwork_ident = unsafe { vec_uninit(3 * n as usize - 2) }; + let mut $rwork_ident: Vec = unsafe { vec_uninit(3 * n as usize - 2) }; )* // calc work size @@ -109,18 +110,18 @@ macro_rules! impl_eigh { let mut work_size = [Self::zero()]; unsafe { $evg( - &[1], - jobz as u8, - uplo as u8, - n, - &mut a, - n, - &mut b, - n, - &mut eigs, - &mut work_size, - -1, - $(&mut $rwork_ident,)* + &1, // ITYPE A*x = (lambda)*B*x + jobz.as_ptr(), + uplo.as_ptr(), + &n, + AsPtr::as_mut_ptr(a), + &n, + AsPtr::as_mut_ptr(b), + &n, + AsPtr::as_mut_ptr(&mut eigs), + AsPtr::as_mut_ptr(&mut work_size), + &(-1), + $(AsPtr::as_mut_ptr(&mut $rwork_ident),)* &mut info, ); } @@ -128,21 +129,22 @@ macro_rules! impl_eigh { // actual evg let lwork = work_size[0].to_usize().unwrap(); - let mut work = unsafe { vec_uninit(lwork) }; + let mut work: Vec = unsafe { vec_uninit(lwork) }; + let lwork = lwork as i32; unsafe { $evg( - &[1], - jobz as u8, - uplo as u8, - n, - &mut a, - n, - &mut b, - n, - &mut eigs, - &mut work, - lwork as i32, - $(&mut $rwork_ident,)* + &1, // ITYPE A*x = (lambda)*B*x + jobz.as_ptr(), + uplo.as_ptr(), + &n, + AsPtr::as_mut_ptr(a), + &n, + AsPtr::as_mut_ptr(b), + &n, + AsPtr::as_mut_ptr(&mut eigs), + AsPtr::as_mut_ptr(&mut work), + &lwork, + $(AsPtr::as_mut_ptr(&mut $rwork_ident),)* &mut info, ); } @@ -153,7 +155,7 @@ macro_rules! impl_eigh { }; } // impl_eigh! -impl_eigh!(@real, f64, lapack::dsyev, lapack::dsygv); -impl_eigh!(@real, f32, lapack::ssyev, lapack::ssygv); -impl_eigh!(@complex, c64, lapack::zheev, lapack::zhegv); -impl_eigh!(@complex, c32, lapack::cheev, lapack::chegv); +impl_eigh!(@real, f64, lapack_sys::dsyev_, lapack_sys::dsygv_); +impl_eigh!(@real, f32, lapack_sys::ssyev_, lapack_sys::ssygv_); +impl_eigh!(@complex, c64, lapack_sys::zheev_, lapack_sys::zhegv_); +impl_eigh!(@complex, c32, lapack_sys::cheev_, lapack_sys::chegv_); From 6be0ad1b2459dd9c4f6ed423857de271677986fb Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Wed, 31 Aug 2022 19:01:01 +0900 Subject: [PATCH 11/21] Use lapack_sys in least_squares.rs --- lax/src/least_squares.rs | 66 ++++++++++++++++++++-------------------- 1 file changed, 33 insertions(+), 33 deletions(-) diff --git a/lax/src/least_squares.rs b/lax/src/least_squares.rs index fc378aa6..97f9a839 100644 --- a/lax/src/least_squares.rs +++ b/lax/src/least_squares.rs @@ -97,20 +97,20 @@ macro_rules! impl_least_squares { )* unsafe { $gelsd( - m, - n, - nrhs, - a_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(a), - a_layout.lda(), - b_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(b), - b_layout.lda(), - &mut singular_values, - rcond, + &m, + &n, + &nrhs, + AsPtr::as_mut_ptr(a_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(a)), + &a_layout.lda(), + AsPtr::as_mut_ptr(b_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(b)), + &b_layout.lda(), + AsPtr::as_mut_ptr(&mut singular_values), + &rcond, &mut rank, - &mut work_size, - -1, - $(&mut $rwork,)* - &mut iwork_size, + AsPtr::as_mut_ptr(&mut work_size), + &(-1), + $(AsPtr::as_mut_ptr(&mut $rwork),)* + iwork_size.as_mut_ptr(), &mut info, ) }; @@ -118,29 +118,29 @@ macro_rules! impl_least_squares { // calc let lwork = work_size[0].to_usize().unwrap(); - let mut work = unsafe { vec_uninit( lwork) }; + let mut work: Vec = unsafe { vec_uninit(lwork) }; let liwork = iwork_size[0].to_usize().unwrap(); - let mut iwork = unsafe { vec_uninit( liwork) }; + let mut iwork = unsafe { vec_uninit(liwork) }; $( let lrwork = $rwork[0].to_usize().unwrap(); - let mut $rwork = unsafe { vec_uninit( lrwork) }; + let mut $rwork: Vec = unsafe { vec_uninit(lrwork) }; )* unsafe { $gelsd( - m, - n, - nrhs, - a_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(a), - a_layout.lda(), - b_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(b), - b_layout.lda(), - &mut singular_values, - rcond, + &m, + &n, + &nrhs, + AsPtr::as_mut_ptr(a_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(a)), + &a_layout.lda(), + AsPtr::as_mut_ptr(b_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(b)), + &b_layout.lda(), + AsPtr::as_mut_ptr(&mut singular_values), + &rcond, &mut rank, - &mut work, - lwork as i32, - $(&mut $rwork,)* - &mut iwork, + AsPtr::as_mut_ptr(&mut work), + &(lwork as i32), + $(AsPtr::as_mut_ptr(&mut $rwork),)* + iwork.as_mut_ptr(), &mut info, ); } @@ -161,7 +161,7 @@ macro_rules! impl_least_squares { }; } -impl_least_squares!(@real, f64, lapack::dgelsd); -impl_least_squares!(@real, f32, lapack::sgelsd); -impl_least_squares!(@complex, c64, lapack::zgelsd); -impl_least_squares!(@complex, c32, lapack::cgelsd); +impl_least_squares!(@real, f64, lapack_sys::dgelsd_); +impl_least_squares!(@real, f32, lapack_sys::sgelsd_); +impl_least_squares!(@complex, c64, lapack_sys::zgelsd_); +impl_least_squares!(@complex, c32, lapack_sys::cgelsd_); From b470f5f6d66685d02f2bff13faa397a8a447325a Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Wed, 31 Aug 2022 19:05:28 +0900 Subject: [PATCH 12/21] Use lapack_sys in rcond.rs --- lax/src/rcond.rs | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/lax/src/rcond.rs b/lax/src/rcond.rs index 91d7458c..75a93a9f 100644 --- a/lax/src/rcond.rs +++ b/lax/src/rcond.rs @@ -54,22 +54,22 @@ macro_rules! impl_rcond_complex { let (n, _) = l.size(); let mut rcond = Self::Real::zero(); let mut info = 0; - let mut work = unsafe { vec_uninit(2 * n as usize) }; - let mut rwork = unsafe { vec_uninit(2 * n as usize) }; + let mut work: Vec = unsafe { vec_uninit(2 * n as usize) }; + let mut rwork: Vec = unsafe { vec_uninit(2 * n as usize) }; let norm_type = match l { MatrixLayout::C { .. } => NormType::Infinity, MatrixLayout::F { .. } => NormType::One, - } as u8; + }; unsafe { $gecon( - norm_type, - n, - a, - l.lda(), - anorm, + norm_type.as_ptr(), + &n, + AsPtr::as_ptr(a), + &l.lda(), + &anorm, &mut rcond, - &mut work, - &mut rwork, + AsPtr::as_mut_ptr(&mut work), + AsPtr::as_mut_ptr(&mut rwork), &mut info, ) }; @@ -81,5 +81,5 @@ macro_rules! impl_rcond_complex { }; } -impl_rcond_complex!(c32, lapack::cgecon); -impl_rcond_complex!(c64, lapack::zgecon); +impl_rcond_complex!(c32, lapack_sys::cgecon_); +impl_rcond_complex!(c64, lapack_sys::zgecon_); From 68b8ae54c44704ba13db455dce1c0e0f1a1a6357 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Wed, 31 Aug 2022 19:09:34 +0900 Subject: [PATCH 13/21] Use lapack_sys in solve.rs --- lax/src/solve.rs | 79 +++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 65 insertions(+), 14 deletions(-) diff --git a/lax/src/solve.rs b/lax/src/solve.rs index e09a7e8d..9c19c874 100644 --- a/lax/src/solve.rs +++ b/lax/src/solve.rs @@ -35,7 +35,16 @@ macro_rules! impl_solve { let k = ::std::cmp::min(row, col); let mut ipiv = unsafe { vec_uninit(k as usize) }; let mut info = 0; - unsafe { $getrf(l.lda(), l.len(), a, l.lda(), &mut ipiv, &mut info) }; + unsafe { + $getrf( + &l.lda(), + &l.len(), + AsPtr::as_mut_ptr(a), + &l.lda(), + ipiv.as_mut_ptr(), + &mut info, + ) + }; info.as_lapack_result()?; Ok(ipiv) } @@ -50,20 +59,30 @@ macro_rules! impl_solve { // calc work size let mut info = 0; let mut work_size = [Self::zero()]; - unsafe { $getri(n, a, l.lda(), ipiv, &mut work_size, -1, &mut info) }; + unsafe { + $getri( + &n, + AsPtr::as_mut_ptr(a), + &l.lda(), + ipiv.as_ptr(), + AsPtr::as_mut_ptr(&mut work_size), + &(-1), + &mut info, + ) + }; info.as_lapack_result()?; // actual let lwork = work_size[0].to_usize().unwrap(); - let mut work = unsafe { vec_uninit(lwork) }; + let mut work: Vec = unsafe { vec_uninit(lwork) }; unsafe { $getri( - l.len(), - a, - l.lda(), - ipiv, - &mut work, - lwork as i32, + &l.len(), + AsPtr::as_mut_ptr(a), + &l.lda(), + ipiv.as_ptr(), + AsPtr::as_mut_ptr(&mut work), + &(lwork as i32), &mut info, ) }; @@ -116,7 +135,19 @@ macro_rules! impl_solve { *b_elem = b_elem.conj(); } } - unsafe { $getrs(t as u8, n, nrhs, a, l.lda(), ipiv, b, ldb, &mut info) }; + unsafe { + $getrs( + t.as_ptr(), + &n, + &nrhs, + AsPtr::as_ptr(a), + &l.lda(), + ipiv.as_ptr(), + AsPtr::as_mut_ptr(b), + &ldb, + &mut info, + ) + }; if conj { for b_elem in &mut *b { *b_elem = b_elem.conj(); @@ -129,7 +160,27 @@ macro_rules! impl_solve { }; } // impl_solve! -impl_solve!(f64, lapack::dgetrf, lapack::dgetri, lapack::dgetrs); -impl_solve!(f32, lapack::sgetrf, lapack::sgetri, lapack::sgetrs); -impl_solve!(c64, lapack::zgetrf, lapack::zgetri, lapack::zgetrs); -impl_solve!(c32, lapack::cgetrf, lapack::cgetri, lapack::cgetrs); +impl_solve!( + f64, + lapack_sys::dgetrf_, + lapack_sys::dgetri_, + lapack_sys::dgetrs_ +); +impl_solve!( + f32, + lapack_sys::sgetrf_, + lapack_sys::sgetri_, + lapack_sys::sgetrs_ +); +impl_solve!( + c64, + lapack_sys::zgetrf_, + lapack_sys::zgetri_, + lapack_sys::zgetrs_ +); +impl_solve!( + c32, + lapack_sys::cgetrf_, + lapack_sys::cgetri_, + lapack_sys::cgetrs_ +); From f155620c58646edda38aadbffb6b18a8e6a62fe2 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Wed, 31 Aug 2022 19:25:06 +0900 Subject: [PATCH 14/21] Use lapack_sys in solveh.rs --- lax/src/solveh.rs | 86 +++++++++++++++++++++++++++++++++++------------ 1 file changed, 64 insertions(+), 22 deletions(-) diff --git a/lax/src/solveh.rs b/lax/src/solveh.rs index 1a4d6e3e..c5259dda 100644 --- a/lax/src/solveh.rs +++ b/lax/src/solveh.rs @@ -30,13 +30,13 @@ macro_rules! impl_solveh { let mut work_size = [Self::zero()]; unsafe { $trf( - uplo as u8, - n, - a, - l.lda(), - &mut ipiv, - &mut work_size, - -1, + uplo.as_ptr(), + &n, + AsPtr::as_mut_ptr(a), + &l.lda(), + ipiv.as_mut_ptr(), + AsPtr::as_mut_ptr(&mut work_size), + &(-1), &mut info, ) }; @@ -44,16 +44,16 @@ macro_rules! impl_solveh { // actual let lwork = work_size[0].to_usize().unwrap(); - let mut work = unsafe { vec_uninit(lwork) }; + let mut work: Vec = unsafe { vec_uninit(lwork) }; unsafe { $trf( - uplo as u8, - n, - a, - l.lda(), - &mut ipiv, - &mut work, - lwork as i32, + uplo.as_ptr(), + &n, + AsPtr::as_mut_ptr(a), + &l.lda(), + ipiv.as_mut_ptr(), + AsPtr::as_mut_ptr(&mut work), + &(lwork as i32), &mut info, ) }; @@ -64,8 +64,18 @@ macro_rules! impl_solveh { fn invh(l: MatrixLayout, uplo: UPLO, a: &mut [Self], ipiv: &Pivot) -> Result<()> { let (n, _) = l.size(); let mut info = 0; - let mut work = unsafe { vec_uninit(n as usize) }; - unsafe { $tri(uplo as u8, n, a, l.lda(), ipiv, &mut work, &mut info) }; + let mut work: Vec = unsafe { vec_uninit(n as usize) }; + unsafe { + $tri( + uplo.as_ptr(), + &n, + AsPtr::as_mut_ptr(a), + &l.lda(), + ipiv.as_ptr(), + AsPtr::as_mut_ptr(&mut work), + &mut info, + ) + }; info.as_lapack_result()?; Ok(()) } @@ -79,7 +89,19 @@ macro_rules! impl_solveh { ) -> Result<()> { let (n, _) = l.size(); let mut info = 0; - unsafe { $trs(uplo as u8, n, 1, a, l.lda(), ipiv, b, n, &mut info) }; + unsafe { + $trs( + uplo.as_ptr(), + &n, + &1, + AsPtr::as_ptr(a), + &l.lda(), + ipiv.as_ptr(), + AsPtr::as_mut_ptr(b), + &n, + &mut info, + ) + }; info.as_lapack_result()?; Ok(()) } @@ -87,7 +109,27 @@ macro_rules! impl_solveh { }; } // impl_solveh! -impl_solveh!(f64, lapack::dsytrf, lapack::dsytri, lapack::dsytrs); -impl_solveh!(f32, lapack::ssytrf, lapack::ssytri, lapack::ssytrs); -impl_solveh!(c64, lapack::zhetrf, lapack::zhetri, lapack::zhetrs); -impl_solveh!(c32, lapack::chetrf, lapack::chetri, lapack::chetrs); +impl_solveh!( + f64, + lapack_sys::dsytrf_, + lapack_sys::dsytri_, + lapack_sys::dsytrs_ +); +impl_solveh!( + f32, + lapack_sys::ssytrf_, + lapack_sys::ssytri_, + lapack_sys::ssytrs_ +); +impl_solveh!( + c64, + lapack_sys::zhetrf_, + lapack_sys::zhetri_, + lapack_sys::zhetrs_ +); +impl_solveh!( + c32, + lapack_sys::chetrf_, + lapack_sys::chetri_, + lapack_sys::chetrs_ +); From 0ab9c599da20004e7a1b5e92a5f7b5bbb36a8bf8 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Wed, 31 Aug 2022 20:00:22 +0900 Subject: [PATCH 15/21] Use lapack_sys in svd.rs --- lax/src/svd.rs | 74 ++++++++++++++++++++++++++------------------------ 1 file changed, 39 insertions(+), 35 deletions(-) diff --git a/lax/src/svd.rs b/lax/src/svd.rs index c990cd27..0ee56428 100644 --- a/lax/src/svd.rs +++ b/lax/src/svd.rs @@ -21,6 +21,10 @@ impl FlagSVD { FlagSVD::No } } + + fn as_ptr(&self) -> *const i8 { + self as *const FlagSVD as *const i8 + } } /// Result of SVD @@ -49,7 +53,7 @@ macro_rules! impl_svd { }; (@body, $scalar:ty, $gesvd:path, $($rwork_ident:ident),*) => { impl SVD_ for $scalar { - fn svd(l: MatrixLayout, calc_u: bool, calc_vt: bool, mut a: &mut [Self],) -> Result> { + fn svd(l: MatrixLayout, calc_u: bool, calc_vt: bool, a: &mut [Self],) -> Result> { let ju = match l { MatrixLayout::F { .. } => FlagSVD::from_bool(calc_u), MatrixLayout::C { .. } => FlagSVD::from_bool(calc_vt), @@ -75,7 +79,7 @@ macro_rules! impl_svd { let mut s = unsafe { vec_uninit( k as usize) }; $( - let mut $rwork_ident = unsafe { vec_uninit( 5 * k as usize) }; + let mut $rwork_ident: Vec = unsafe { vec_uninit( 5 * k as usize) }; )* // eval work size @@ -83,20 +87,20 @@ macro_rules! impl_svd { let mut work_size = [Self::zero()]; unsafe { $gesvd( - ju as u8, - jvt as u8, - m, - n, - &mut a, - m, - &mut s, - u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []), - m, - vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []), - n, - &mut work_size, - -1, - $(&mut $rwork_ident,)* + ju.as_ptr(), + jvt.as_ptr(), + &m, + &n, + AsPtr::as_mut_ptr(a), + &m, + AsPtr::as_mut_ptr(&mut s), + AsPtr::as_mut_ptr(u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut [])), + &m, + AsPtr::as_mut_ptr(vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut [])), + &n, + AsPtr::as_mut_ptr(&mut work_size), + &(-1), + $(AsPtr::as_mut_ptr(&mut $rwork_ident),)* &mut info, ); } @@ -104,23 +108,23 @@ macro_rules! impl_svd { // calc let lwork = work_size[0].to_usize().unwrap(); - let mut work = unsafe { vec_uninit( lwork) }; + let mut work: Vec = unsafe { vec_uninit( lwork) }; unsafe { $gesvd( - ju as u8, - jvt as u8, - m, - n, - &mut a, - m, - &mut s, - u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []), - m, - vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []), - n, - &mut work, - lwork as i32, - $(&mut $rwork_ident,)* + ju.as_ptr(), + jvt.as_ptr() , + &m, + &n, + AsPtr::as_mut_ptr(a), + &m, + AsPtr::as_mut_ptr(&mut s), + AsPtr::as_mut_ptr(u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut [])), + &m, + AsPtr::as_mut_ptr(vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut [])), + &n, + AsPtr::as_mut_ptr(&mut work), + &(lwork as i32), + $(AsPtr::as_mut_ptr(&mut $rwork_ident),)* &mut info, ); } @@ -134,7 +138,7 @@ macro_rules! impl_svd { }; } // impl_svd! -impl_svd!(@real, f64, lapack::dgesvd); -impl_svd!(@real, f32, lapack::sgesvd); -impl_svd!(@complex, c64, lapack::zgesvd); -impl_svd!(@complex, c32, lapack::cgesvd); +impl_svd!(@real, f64, lapack_sys::dgesvd_); +impl_svd!(@real, f32, lapack_sys::sgesvd_); +impl_svd!(@complex, c64, lapack_sys::zgesvd_); +impl_svd!(@complex, c32, lapack_sys::cgesvd_); From aff89b7e4254791fc1d28d1af79359374f0a8447 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Wed, 31 Aug 2022 20:16:50 +0900 Subject: [PATCH 16/21] Use lapack_sys in svddc.rs --- lax/src/svddc.rs | 76 ++++++++++++++++++++++++++---------------------- 1 file changed, 41 insertions(+), 35 deletions(-) diff --git a/lax/src/svddc.rs b/lax/src/svddc.rs index a94bdace..c1198286 100644 --- a/lax/src/svddc.rs +++ b/lax/src/svddc.rs @@ -16,6 +16,12 @@ pub enum UVTFlag { None = b'N', } +impl UVTFlag { + fn as_ptr(&self) -> *const i8 { + self as *const UVTFlag as *const i8 + } +} + pub trait SVDDC_: Scalar { fn svddc(l: MatrixLayout, jobz: UVTFlag, a: &mut [Self]) -> Result>; } @@ -29,7 +35,7 @@ macro_rules! impl_svddc { }; (@body, $scalar:ty, $gesdd:path, $($rwork_ident:ident),*) => { impl SVDDC_ for $scalar { - fn svddc(l: MatrixLayout, jobz: UVTFlag, mut a: &mut [Self],) -> Result> { + fn svddc(l: MatrixLayout, jobz: UVTFlag, a: &mut [Self],) -> Result> { let m = l.lda(); let n = l.len(); let k = m.min(n); @@ -58,7 +64,7 @@ macro_rules! impl_svddc { UVTFlag::None => 7 * mn, _ => std::cmp::max(5*mn*mn + 5*mn, 2*mx*mn + 2*mn*mn + mn), }; - let mut $rwork_ident = unsafe { vec_uninit( lrwork) }; + let mut $rwork_ident: Vec = unsafe { vec_uninit( lrwork) }; )* // eval work size @@ -67,20 +73,20 @@ macro_rules! impl_svddc { let mut work_size = [Self::zero()]; unsafe { $gesdd( - jobz as u8, - m, - n, - &mut a, - m, - &mut s, - u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []), - m, - vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []), - vt_row, - &mut work_size, - -1, - $(&mut $rwork_ident,)* - &mut iwork, + jobz.as_ptr(), + &m, + &n, + AsPtr::as_mut_ptr(a), + &m, + AsPtr::as_mut_ptr(&mut s), + AsPtr::as_mut_ptr(u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut [])), + &m, + AsPtr::as_mut_ptr(vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut [])), + &vt_row, + AsPtr::as_mut_ptr(&mut work_size), + &(-1), + $(AsPtr::as_mut_ptr(&mut $rwork_ident),)* + iwork.as_mut_ptr(), &mut info, ); } @@ -88,23 +94,23 @@ macro_rules! impl_svddc { // do svd let lwork = work_size[0].to_usize().unwrap(); - let mut work = unsafe { vec_uninit( lwork) }; + let mut work: Vec = unsafe { vec_uninit( lwork) }; unsafe { $gesdd( - jobz as u8, - m, - n, - &mut a, - m, - &mut s, - u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []), - m, - vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []), - vt_row, - &mut work, - lwork as i32, - $(&mut $rwork_ident,)* - &mut iwork, + jobz.as_ptr(), + &m, + &n, + AsPtr::as_mut_ptr(a), + &m, + AsPtr::as_mut_ptr(&mut s), + AsPtr::as_mut_ptr(u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut [])), + &m, + AsPtr::as_mut_ptr(vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut [])), + &vt_row, + AsPtr::as_mut_ptr(&mut work), + &(lwork as i32), + $(AsPtr::as_mut_ptr(&mut $rwork_ident),)* + iwork.as_mut_ptr(), &mut info, ); } @@ -119,7 +125,7 @@ macro_rules! impl_svddc { }; } -impl_svddc!(@real, f32, lapack::sgesdd); -impl_svddc!(@real, f64, lapack::dgesdd); -impl_svddc!(@complex, c32, lapack::cgesdd); -impl_svddc!(@complex, c64, lapack::zgesdd); +impl_svddc!(@real, f32, lapack_sys::sgesdd_); +impl_svddc!(@real, f64, lapack_sys::dgesdd_); +impl_svddc!(@complex, c32, lapack_sys::cgesdd_); +impl_svddc!(@complex, c64, lapack_sys::zgesdd_); From e15c93c1628ac4741ab4be0905826976e5ae64f5 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Wed, 31 Aug 2022 20:24:35 +0900 Subject: [PATCH 17/21] Use lapack_sys in triangular.rs --- lax/src/triangular.rs | 32 +++++++++++++++++++------------- 1 file changed, 19 insertions(+), 13 deletions(-) diff --git a/lax/src/triangular.rs b/lax/src/triangular.rs index a48b12b3..0288d6ba 100644 --- a/lax/src/triangular.rs +++ b/lax/src/triangular.rs @@ -10,6 +10,12 @@ pub enum Diag { NonUnit = b'N', } +impl Diag { + fn as_ptr(&self) -> *const i8 { + self as *const Diag as *const i8 + } +} + /// Wraps `*trtri` and `*trtrs` pub trait Triangular_: Scalar { fn solve_triangular( @@ -60,15 +66,15 @@ macro_rules! impl_triangular { let mut info = 0; unsafe { $trtrs( - uplo as u8, - Transpose::No as u8, - diag as u8, - m, - nrhs, - a_t.as_ref().map(|v| v.as_slice()).unwrap_or(a), - a_layout.lda(), - b_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(b), - b_layout.lda(), + uplo.as_ptr(), + Transpose::No.as_ptr(), + diag.as_ptr(), + &m, + &nrhs, + AsPtr::as_ptr(a_t.as_ref().map(|v| v.as_slice()).unwrap_or(a)), + &a_layout.lda(), + AsPtr::as_mut_ptr(b_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(b)), + &b_layout.lda(), &mut info, ); } @@ -84,7 +90,7 @@ macro_rules! impl_triangular { }; } // impl_triangular! -impl_triangular!(f64, lapack::dtrtri, lapack::dtrtrs); -impl_triangular!(f32, lapack::strtri, lapack::strtrs); -impl_triangular!(c64, lapack::ztrtri, lapack::ztrtrs); -impl_triangular!(c32, lapack::ctrtri, lapack::ctrtrs); +impl_triangular!(f64, lapack_sys::dtrtri_, lapack_sys::dtrtrs_); +impl_triangular!(f32, lapack_sys::strtri_, lapack_sys::strtrs_); +impl_triangular!(c64, lapack_sys::ztrtri_, lapack_sys::ztrtrs_); +impl_triangular!(c32, lapack_sys::ctrtri_, lapack_sys::ctrtrs_); From bd6f3feab1a01742020f055cc2e8807ce1c5c0d9 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Wed, 31 Aug 2022 21:04:16 +0900 Subject: [PATCH 18/21] Use lapack_sys in tridiagonal.rs --- lax/src/tridiagonal.rs | 62 ++++++++++++++++++++++++------------------ 1 file changed, 36 insertions(+), 26 deletions(-) diff --git a/lax/src/tridiagonal.rs b/lax/src/tridiagonal.rs index 2b1d98e2..c80ad4b5 100644 --- a/lax/src/tridiagonal.rs +++ b/lax/src/tridiagonal.rs @@ -157,7 +157,17 @@ macro_rules! impl_tridiagonal { // We have to calc one-norm before LU factorization let a_opnorm_one = a.opnorm_one(); let mut info = 0; - unsafe { $gttrf(n, &mut a.dl, &mut a.d, &mut a.du, &mut du2, &mut ipiv, &mut info,) }; + unsafe { + $gttrf( + &n, + AsPtr::as_mut_ptr(&mut a.dl), + AsPtr::as_mut_ptr(&mut a.d), + AsPtr::as_mut_ptr(&mut a.du), + AsPtr::as_mut_ptr(&mut du2), + ipiv.as_mut_ptr(), + &mut info, + ) + }; info.as_lapack_result()?; Ok(LUFactorizedTridiagonal { a, @@ -170,7 +180,7 @@ macro_rules! impl_tridiagonal { fn rcond_tridiagonal(lu: &LUFactorizedTridiagonal) -> Result { let (n, _) = lu.a.l.size(); let ipiv = &lu.ipiv; - let mut work = unsafe { vec_uninit( 2 * n as usize) }; + let mut work: Vec = unsafe { vec_uninit( 2 * n as usize) }; $( let mut $iwork = unsafe { vec_uninit( n as usize) }; )* @@ -178,17 +188,17 @@ macro_rules! impl_tridiagonal { let mut info = 0; unsafe { $gtcon( - NormType::One as u8, - n, - &lu.a.dl, - &lu.a.d, - &lu.a.du, - &lu.du2, - ipiv, - lu.a_opnorm_one, + NormType::One.as_ptr(), + &n, + AsPtr::as_ptr(&lu.a.dl), + AsPtr::as_ptr(&lu.a.d), + AsPtr::as_ptr(&lu.a.du), + AsPtr::as_ptr(&lu.du2), + ipiv.as_ptr(), + &lu.a_opnorm_one, &mut rcond, - &mut work, - $(&mut $iwork,)* + AsPtr::as_mut_ptr(&mut work), + $($iwork.as_mut_ptr(),)* &mut info, ); } @@ -217,16 +227,16 @@ macro_rules! impl_tridiagonal { let mut info = 0; unsafe { $gttrs( - t as u8, - n, - nrhs, - &lu.a.dl, - &lu.a.d, - &lu.a.du, - &lu.du2, - ipiv, - b_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(b), - ldb, + t.as_ptr(), + &n, + &nrhs, + AsPtr::as_ptr(&lu.a.dl), + AsPtr::as_ptr(&lu.a.d), + AsPtr::as_ptr(&lu.a.du), + AsPtr::as_ptr(&lu.du2), + ipiv.as_ptr(), + AsPtr::as_mut_ptr(b_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(b)), + &ldb, &mut info, ); } @@ -240,7 +250,7 @@ macro_rules! impl_tridiagonal { }; } // impl_tridiagonal! -impl_tridiagonal!(@real, f64, lapack::dgttrf, lapack::dgtcon, lapack::dgttrs); -impl_tridiagonal!(@real, f32, lapack::sgttrf, lapack::sgtcon, lapack::sgttrs); -impl_tridiagonal!(@complex, c64, lapack::zgttrf, lapack::zgtcon, lapack::zgttrs); -impl_tridiagonal!(@complex, c32, lapack::cgttrf, lapack::cgtcon, lapack::cgttrs); +impl_tridiagonal!(@real, f64, lapack_sys::dgttrf_, lapack_sys::dgtcon_, lapack_sys::dgttrs_); +impl_tridiagonal!(@real, f32, lapack_sys::sgttrf_, lapack_sys::sgtcon_, lapack_sys::sgttrs_); +impl_tridiagonal!(@complex, c64, lapack_sys::zgttrf_, lapack_sys::zgtcon_, lapack_sys::zgttrs_); +impl_tridiagonal!(@complex, c32, lapack_sys::cgttrf_, lapack_sys::cgtcon_, lapack_sys::cgttrs_); From efc877d06d4335fe7c1d08a6b0d39e762e5c2db6 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Wed, 31 Aug 2022 21:12:18 +0900 Subject: [PATCH 19/21] Use lapack_sys in qr.rs --- lax/src/qr.rs | 146 +++++++++++++++++++++++++++++++++----------------- 1 file changed, 98 insertions(+), 48 deletions(-) diff --git a/lax/src/qr.rs b/lax/src/qr.rs index 6460b8b9..33de0372 100644 --- a/lax/src/qr.rs +++ b/lax/src/qr.rs @@ -21,7 +21,7 @@ pub trait QR_: Sized { macro_rules! impl_qr { ($scalar:ty, $qrf:path, $lqf:path, $gqr:path, $glq:path) => { impl QR_ for $scalar { - fn householder(l: MatrixLayout, mut a: &mut [Self]) -> Result> { + fn householder(l: MatrixLayout, a: &mut [Self]) -> Result> { let m = l.lda(); let n = l.len(); let k = m.min(n); @@ -33,10 +33,28 @@ macro_rules! impl_qr { unsafe { match l { MatrixLayout::F { .. } => { - $qrf(m, n, &mut a, m, &mut tau, &mut work_size, -1, &mut info); + $qrf( + &m, + &n, + AsPtr::as_mut_ptr(a), + &m, + AsPtr::as_mut_ptr(&mut tau), + AsPtr::as_mut_ptr(&mut work_size), + &(-1), + &mut info, + ); } MatrixLayout::C { .. } => { - $lqf(m, n, &mut a, m, &mut tau, &mut work_size, -1, &mut info); + $lqf( + &m, + &n, + AsPtr::as_mut_ptr(a), + &m, + AsPtr::as_mut_ptr(&mut tau), + AsPtr::as_mut_ptr(&mut work_size), + &(-1), + &mut info, + ); } } } @@ -44,30 +62,30 @@ macro_rules! impl_qr { // calc let lwork = work_size[0].to_usize().unwrap(); - let mut work = unsafe { vec_uninit(lwork) }; + let mut work: Vec = unsafe { vec_uninit(lwork) }; unsafe { match l { MatrixLayout::F { .. } => { $qrf( - m, - n, - &mut a, - m, - &mut tau, - &mut work, - lwork as i32, + &m, + &n, + AsPtr::as_mut_ptr(a), + &m, + AsPtr::as_mut_ptr(&mut tau), + AsPtr::as_mut_ptr(&mut work), + &(lwork as i32), &mut info, ); } MatrixLayout::C { .. } => { $lqf( - m, - n, - &mut a, - m, - &mut tau, - &mut work, - lwork as i32, + &m, + &n, + AsPtr::as_mut_ptr(a), + &m, + AsPtr::as_mut_ptr(&mut tau), + AsPtr::as_mut_ptr(&mut work), + &(lwork as i32), &mut info, ); } @@ -78,7 +96,7 @@ macro_rules! impl_qr { Ok(tau) } - fn q(l: MatrixLayout, mut a: &mut [Self], tau: &[Self]) -> Result<()> { + fn q(l: MatrixLayout, a: &mut [Self], tau: &[Self]) -> Result<()> { let m = l.lda(); let n = l.len(); let k = m.min(n); @@ -89,26 +107,58 @@ macro_rules! impl_qr { let mut work_size = [Self::zero()]; unsafe { match l { - MatrixLayout::F { .. } => { - $gqr(m, k, k, &mut a, m, &tau, &mut work_size, -1, &mut info) - } - MatrixLayout::C { .. } => { - $glq(k, n, k, &mut a, m, &tau, &mut work_size, -1, &mut info) - } + MatrixLayout::F { .. } => $gqr( + &m, + &k, + &k, + AsPtr::as_mut_ptr(a), + &m, + AsPtr::as_ptr(&tau), + AsPtr::as_mut_ptr(&mut work_size), + &(-1), + &mut info, + ), + MatrixLayout::C { .. } => $glq( + &k, + &n, + &k, + AsPtr::as_mut_ptr(a), + &m, + AsPtr::as_ptr(&tau), + AsPtr::as_mut_ptr(&mut work_size), + &(-1), + &mut info, + ), } }; // calc let lwork = work_size[0].to_usize().unwrap(); - let mut work = unsafe { vec_uninit(lwork) }; + let mut work: Vec = unsafe { vec_uninit(lwork) }; unsafe { match l { - MatrixLayout::F { .. } => { - $gqr(m, k, k, &mut a, m, &tau, &mut work, lwork as i32, &mut info) - } - MatrixLayout::C { .. } => { - $glq(k, n, k, &mut a, m, &tau, &mut work, lwork as i32, &mut info) - } + MatrixLayout::F { .. } => $gqr( + &m, + &k, + &k, + AsPtr::as_mut_ptr(a), + &m, + AsPtr::as_ptr(&tau), + AsPtr::as_mut_ptr(&mut work), + &(lwork as i32), + &mut info, + ), + MatrixLayout::C { .. } => $glq( + &k, + &n, + &k, + AsPtr::as_mut_ptr(a), + &m, + AsPtr::as_ptr(&tau), + AsPtr::as_mut_ptr(&mut work), + &(lwork as i32), + &mut info, + ), } } info.as_lapack_result()?; @@ -127,29 +177,29 @@ macro_rules! impl_qr { impl_qr!( f64, - lapack::dgeqrf, - lapack::dgelqf, - lapack::dorgqr, - lapack::dorglq + lapack_sys::dgeqrf_, + lapack_sys::dgelqf_, + lapack_sys::dorgqr_, + lapack_sys::dorglq_ ); impl_qr!( f32, - lapack::sgeqrf, - lapack::sgelqf, - lapack::sorgqr, - lapack::sorglq + lapack_sys::sgeqrf_, + lapack_sys::sgelqf_, + lapack_sys::sorgqr_, + lapack_sys::sorglq_ ); impl_qr!( c64, - lapack::zgeqrf, - lapack::zgelqf, - lapack::zungqr, - lapack::zunglq + lapack_sys::zgeqrf_, + lapack_sys::zgelqf_, + lapack_sys::zungqr_, + lapack_sys::zunglq_ ); impl_qr!( c32, - lapack::cgeqrf, - lapack::cgelqf, - lapack::cungqr, - lapack::cunglq + lapack_sys::cgeqrf_, + lapack_sys::cgelqf_, + lapack_sys::cungqr_, + lapack_sys::cunglq_ ); From 5fc0f2d73cbdb1b1706ffc75f91e070b645818a8 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Wed, 31 Aug 2022 21:14:26 +0900 Subject: [PATCH 20/21] Fix rcond not to use lapack --- lax/src/rcond.rs | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/lax/src/rcond.rs b/lax/src/rcond.rs index 75a93a9f..fcd4211f 100644 --- a/lax/src/rcond.rs +++ b/lax/src/rcond.rs @@ -17,22 +17,22 @@ macro_rules! impl_rcond_real { let mut rcond = Self::Real::zero(); let mut info = 0; - let mut work = unsafe { vec_uninit(4 * n as usize) }; + let mut work: Vec = unsafe { vec_uninit(4 * n as usize) }; let mut iwork = unsafe { vec_uninit(n as usize) }; let norm_type = match l { MatrixLayout::C { .. } => NormType::Infinity, MatrixLayout::F { .. } => NormType::One, - } as u8; + }; unsafe { $gecon( - norm_type, - n, - a, - l.lda(), - anorm, + norm_type.as_ptr(), + &n, + AsPtr::as_ptr(a), + &l.lda(), + &anorm, &mut rcond, - &mut work, - &mut iwork, + AsPtr::as_mut_ptr(&mut work), + iwork.as_mut_ptr(), &mut info, ) }; @@ -44,8 +44,8 @@ macro_rules! impl_rcond_real { }; } -impl_rcond_real!(f32, lapack::sgecon); -impl_rcond_real!(f64, lapack::dgecon); +impl_rcond_real!(f32, lapack_sys::sgecon_); +impl_rcond_real!(f64, lapack_sys::dgecon_); macro_rules! impl_rcond_complex { ($scalar:ty, $gecon:path) => { From 2fc0ac815a7cc92b67883998b79da8fe785ee509 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Wed, 31 Aug 2022 21:14:36 +0900 Subject: [PATCH 21/21] Drop lapack dependency --- lax/Cargo.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/lax/Cargo.toml b/lax/Cargo.toml index d62f2064..a36c34ab 100644 --- a/lax/Cargo.toml +++ b/lax/Cargo.toml @@ -32,7 +32,6 @@ intel-mkl-system = ["intel-mkl-src/mkl-dynamic-lp64-seq"] thiserror = "1.0.24" cauchy = "0.4.0" num-traits = "0.2.14" -lapack = "0.18.0" lapack-sys = "0.14.0" [dependencies.intel-mkl-src]