diff --git a/lax/Cargo.toml b/lax/Cargo.toml index 1ed31f73..a36c34ab 100644 --- a/lax/Cargo.toml +++ b/lax/Cargo.toml @@ -32,7 +32,7 @@ intel-mkl-system = ["intel-mkl-src/mkl-dynamic-lp64-seq"] thiserror = "1.0.24" cauchy = "0.4.0" num-traits = "0.2.14" -lapack = "0.18.0" +lapack-sys = "0.14.0" [dependencies.intel-mkl-src] version = "0.7.0" diff --git a/lax/src/cholesky.rs b/lax/src/cholesky.rs index 8305efe5..94e683ff 100644 --- a/lax/src/cholesky.rs +++ b/lax/src/cholesky.rs @@ -29,7 +29,7 @@ macro_rules! impl_cholesky { } let mut info = 0; unsafe { - $trf(uplo as u8, n, a, n, &mut info); + $trf(uplo.as_ptr(), &n, AsPtr::as_mut_ptr(a), &n, &mut info); } info.as_lapack_result()?; if matches!(l, MatrixLayout::C { .. }) { @@ -45,7 +45,7 @@ macro_rules! impl_cholesky { } let mut info = 0; unsafe { - $tri(uplo as u8, n, a, l.lda(), &mut info); + $tri(uplo.as_ptr(), &n, AsPtr::as_mut_ptr(a), &l.lda(), &mut info); } info.as_lapack_result()?; if matches!(l, MatrixLayout::C { .. }) { @@ -70,7 +70,16 @@ macro_rules! impl_cholesky { } } unsafe { - $trs(uplo as u8, n, nrhs, a, l.lda(), b, n, &mut info); + $trs( + uplo.as_ptr(), + &n, + &nrhs, + AsPtr::as_ptr(a), + &l.lda(), + AsPtr::as_mut_ptr(b), + &n, + &mut info, + ); } info.as_lapack_result()?; if matches!(l, MatrixLayout::C { .. }) { @@ -84,7 +93,27 @@ macro_rules! impl_cholesky { }; } // end macro_rules -impl_cholesky!(f64, lapack::dpotrf, lapack::dpotri, lapack::dpotrs); -impl_cholesky!(f32, lapack::spotrf, lapack::spotri, lapack::spotrs); -impl_cholesky!(c64, lapack::zpotrf, lapack::zpotri, lapack::zpotrs); -impl_cholesky!(c32, lapack::cpotrf, lapack::cpotri, lapack::cpotrs); +impl_cholesky!( + f64, + lapack_sys::dpotrf_, + lapack_sys::dpotri_, + lapack_sys::dpotrs_ +); +impl_cholesky!( + f32, + lapack_sys::spotrf_, + lapack_sys::spotri_, + lapack_sys::spotrs_ +); +impl_cholesky!( + c64, + lapack_sys::zpotrf_, + lapack_sys::zpotri_, + lapack_sys::zpotrs_ +); +impl_cholesky!( + c32, + lapack_sys::cpotrf_, + lapack_sys::cpotri_, + lapack_sys::cpotrs_ +); diff --git a/lax/src/eig.rs b/lax/src/eig.rs index a6d6a16e..bf3b9f8c 100644 --- a/lax/src/eig.rs +++ b/lax/src/eig.rs @@ -20,7 +20,7 @@ macro_rules! impl_eig_complex { fn eig( calc_v: bool, l: MatrixLayout, - mut a: &mut [Self], + a: &mut [Self], ) -> Result<(Vec, Vec)> { let (n, _) = l.size(); // LAPACK assumes a column-major input. A row-major input can @@ -35,44 +35,38 @@ macro_rules! impl_eig_complex { // eigenvalues are the eigenvalues computed with `A`. let (jobvl, jobvr) = if calc_v { match l { - MatrixLayout::C { .. } => (b'V', b'N'), - MatrixLayout::F { .. } => (b'N', b'V'), + MatrixLayout::C { .. } => (EigenVectorFlag::Calc, EigenVectorFlag::Not), + MatrixLayout::F { .. } => (EigenVectorFlag::Not, EigenVectorFlag::Calc), } } else { - (b'N', b'N') + (EigenVectorFlag::Not, EigenVectorFlag::Not) }; let mut eigs = unsafe { vec_uninit(n as usize) }; - let mut rwork = unsafe { vec_uninit(2 * n as usize) }; + let mut rwork: Vec = unsafe { vec_uninit(2 * n as usize) }; - let mut vl = if jobvl == b'V' { - Some(unsafe { vec_uninit((n * n) as usize) }) - } else { - None - }; - let mut vr = if jobvr == b'V' { - Some(unsafe { vec_uninit((n * n) as usize) }) - } else { - None - }; + let mut vl: Option> = + jobvl.then(|| unsafe { vec_uninit((n * n) as usize) }); + let mut vr: Option> = + jobvr.then(|| unsafe { vec_uninit((n * n) as usize) }); // calc work size let mut info = 0; let mut work_size = [Self::zero()]; unsafe { $ev( - jobvl, - jobvr, - n, - &mut a, - n, - &mut eigs, - &mut vl.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut []), - n, - &mut vr.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut []), - n, - &mut work_size, - -1, - &mut rwork, + jobvl.as_ptr(), + jobvr.as_ptr(), + &n, + AsPtr::as_mut_ptr(a), + &n, + AsPtr::as_mut_ptr(&mut eigs), + AsPtr::as_mut_ptr(vl.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut [])), + &n, + AsPtr::as_mut_ptr(vr.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut [])), + &n, + AsPtr::as_mut_ptr(&mut work_size), + &(-1), + AsPtr::as_mut_ptr(&mut rwork), &mut info, ) }; @@ -80,29 +74,30 @@ macro_rules! impl_eig_complex { // actal ev let lwork = work_size[0].to_usize().unwrap(); - let mut work = unsafe { vec_uninit(lwork) }; + let mut work: Vec = unsafe { vec_uninit(lwork) }; + let lwork = lwork as i32; unsafe { $ev( - jobvl, - jobvr, - n, - &mut a, - n, - &mut eigs, - &mut vl.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut []), - n, - &mut vr.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut []), - n, - &mut work, - lwork as i32, - &mut rwork, + jobvl.as_ptr(), + jobvr.as_ptr(), + &n, + AsPtr::as_mut_ptr(a), + &n, + AsPtr::as_mut_ptr(&mut eigs), + AsPtr::as_mut_ptr(vl.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut [])), + &n, + AsPtr::as_mut_ptr(vr.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut [])), + &n, + AsPtr::as_mut_ptr(&mut work), + &lwork, + AsPtr::as_mut_ptr(&mut rwork), &mut info, ) }; info.as_lapack_result()?; // Hermite conjugate - if jobvl == b'V' { + if jobvl.is_calc() { for c in vl.as_mut().unwrap().iter_mut() { c.im = -c.im } @@ -114,8 +109,8 @@ macro_rules! impl_eig_complex { }; } -impl_eig_complex!(c64, lapack::zgeev); -impl_eig_complex!(c32, lapack::cgeev); +impl_eig_complex!(c64, lapack_sys::zgeev_); +impl_eig_complex!(c32, lapack_sys::cgeev_); macro_rules! impl_eig_real { ($scalar:ty, $ev:path) => { @@ -123,7 +118,7 @@ macro_rules! impl_eig_real { fn eig( calc_v: bool, l: MatrixLayout, - mut a: &mut [Self], + a: &mut [Self], ) -> Result<(Vec, Vec)> { let (n, _) = l.size(); // LAPACK assumes a column-major input. A row-major input can @@ -144,44 +139,38 @@ macro_rules! impl_eig_real { // `sgeev`/`dgeev`. let (jobvl, jobvr) = if calc_v { match l { - MatrixLayout::C { .. } => (b'V', b'N'), - MatrixLayout::F { .. } => (b'N', b'V'), + MatrixLayout::C { .. } => (EigenVectorFlag::Calc, EigenVectorFlag::Not), + MatrixLayout::F { .. } => (EigenVectorFlag::Not, EigenVectorFlag::Calc), } } else { - (b'N', b'N') + (EigenVectorFlag::Not, EigenVectorFlag::Not) }; - let mut eig_re = unsafe { vec_uninit(n as usize) }; - let mut eig_im = unsafe { vec_uninit(n as usize) }; + let mut eig_re: Vec = unsafe { vec_uninit(n as usize) }; + let mut eig_im: Vec = unsafe { vec_uninit(n as usize) }; - let mut vl = if jobvl == b'V' { - Some(unsafe { vec_uninit((n * n) as usize) }) - } else { - None - }; - let mut vr = if jobvr == b'V' { - Some(unsafe { vec_uninit((n * n) as usize) }) - } else { - None - }; + let mut vl: Option> = + jobvl.then(|| unsafe { vec_uninit((n * n) as usize) }); + let mut vr: Option> = + jobvr.then(|| unsafe { vec_uninit((n * n) as usize) }); // calc work size let mut info = 0; - let mut work_size = [0.0]; + let mut work_size: [Self; 1] = [0.0]; unsafe { $ev( - jobvl, - jobvr, - n, - &mut a, - n, - &mut eig_re, - &mut eig_im, - vl.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut []), - n, - vr.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut []), - n, - &mut work_size, - -1, + jobvl.as_ptr(), + jobvr.as_ptr(), + &n, + AsPtr::as_mut_ptr(a), + &n, + AsPtr::as_mut_ptr(&mut eig_re), + AsPtr::as_mut_ptr(&mut eig_im), + AsPtr::as_mut_ptr(vl.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut [])), + &n, + AsPtr::as_mut_ptr(vr.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut [])), + &n, + AsPtr::as_mut_ptr(&mut work_size), + &(-1), &mut info, ) }; @@ -189,22 +178,23 @@ macro_rules! impl_eig_real { // actual ev let lwork = work_size[0].to_usize().unwrap(); - let mut work = unsafe { vec_uninit(lwork) }; + let mut work: Vec = unsafe { vec_uninit(lwork) }; + let lwork = lwork as i32; unsafe { $ev( - jobvl, - jobvr, - n, - &mut a, - n, - &mut eig_re, - &mut eig_im, - vl.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut []), - n, - vr.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut []), - n, - &mut work, - lwork as i32, + jobvl.as_ptr(), + jobvr.as_ptr(), + &n, + AsPtr::as_mut_ptr(a), + &n, + AsPtr::as_mut_ptr(&mut eig_re), + AsPtr::as_mut_ptr(&mut eig_im), + AsPtr::as_mut_ptr(vl.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut [])), + &n, + AsPtr::as_mut_ptr(vr.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut [])), + &n, + AsPtr::as_mut_ptr(&mut work), + &lwork, &mut info, ) }; @@ -254,7 +244,7 @@ macro_rules! impl_eig_real { for row in 0..n { let re = v[row + col * n]; let mut im = v[row + (col + 1) * n]; - if jobvl == b'V' { + if jobvl.is_calc() { im = -im; } eigvecs[row + col * n] = Self::complex(re, im); @@ -270,5 +260,5 @@ macro_rules! impl_eig_real { }; } -impl_eig_real!(f64, lapack::dgeev); -impl_eig_real!(f32, lapack::sgeev); +impl_eig_real!(f64, lapack_sys::dgeev_); +impl_eig_real!(f32, lapack_sys::sgeev_); diff --git a/lax/src/eigh.rs b/lax/src/eigh.rs index ad8963dc..a8403e90 100644 --- a/lax/src/eigh.rs +++ b/lax/src/eigh.rs @@ -37,15 +37,15 @@ macro_rules! impl_eigh { calc_v: bool, layout: MatrixLayout, uplo: UPLO, - mut a: &mut [Self], + a: &mut [Self], ) -> Result> { assert_eq!(layout.len(), layout.lda()); let n = layout.len(); - let jobz = if calc_v { b'V' } else { b'N' }; + let jobz = if calc_v { EigenVectorFlag::Calc } else { EigenVectorFlag::Not }; let mut eigs = unsafe { vec_uninit(n as usize) }; $( - let mut $rwork_ident = unsafe { vec_uninit(3 * n as usize - 2 as usize) }; + let mut $rwork_ident: Vec = unsafe { vec_uninit(3 * n as usize - 2 as usize) }; )* // calc work size @@ -53,15 +53,15 @@ macro_rules! impl_eigh { let mut work_size = [Self::zero()]; unsafe { $ev( - jobz, - uplo as u8, - n, - &mut a, - n, - &mut eigs, - &mut work_size, - -1, - $(&mut $rwork_ident,)* + jobz.as_ptr() , + uplo.as_ptr(), + &n, + AsPtr::as_mut_ptr(a), + &n, + AsPtr::as_mut_ptr(&mut eigs), + AsPtr::as_mut_ptr(&mut work_size), + &(-1), + $(AsPtr::as_mut_ptr(&mut $rwork_ident),)* &mut info, ); } @@ -69,18 +69,19 @@ macro_rules! impl_eigh { // actual ev let lwork = work_size[0].to_usize().unwrap(); - let mut work = unsafe { vec_uninit(lwork) }; + let mut work: Vec = unsafe { vec_uninit(lwork) }; + let lwork = lwork as i32; unsafe { $ev( - jobz, - uplo as u8, - n, - &mut a, - n, - &mut eigs, - &mut work, - lwork as i32, - $(&mut $rwork_ident,)* + jobz.as_ptr(), + uplo.as_ptr(), + &n, + AsPtr::as_mut_ptr(a), + &n, + AsPtr::as_mut_ptr(&mut eigs), + AsPtr::as_mut_ptr(&mut work), + &lwork, + $(AsPtr::as_mut_ptr(&mut $rwork_ident),)* &mut info, ); } @@ -92,16 +93,16 @@ macro_rules! impl_eigh { calc_v: bool, layout: MatrixLayout, uplo: UPLO, - mut a: &mut [Self], - mut b: &mut [Self], + a: &mut [Self], + b: &mut [Self], ) -> Result> { assert_eq!(layout.len(), layout.lda()); let n = layout.len(); - let jobz = if calc_v { b'V' } else { b'N' }; + let jobz = if calc_v { EigenVectorFlag::Calc } else { EigenVectorFlag::Not }; let mut eigs = unsafe { vec_uninit(n as usize) }; $( - let mut $rwork_ident = unsafe { vec_uninit(3 * n as usize - 2) }; + let mut $rwork_ident: Vec = unsafe { vec_uninit(3 * n as usize - 2) }; )* // calc work size @@ -109,18 +110,18 @@ macro_rules! impl_eigh { let mut work_size = [Self::zero()]; unsafe { $evg( - &[1], - jobz, - uplo as u8, - n, - &mut a, - n, - &mut b, - n, - &mut eigs, - &mut work_size, - -1, - $(&mut $rwork_ident,)* + &1, // ITYPE A*x = (lambda)*B*x + jobz.as_ptr(), + uplo.as_ptr(), + &n, + AsPtr::as_mut_ptr(a), + &n, + AsPtr::as_mut_ptr(b), + &n, + AsPtr::as_mut_ptr(&mut eigs), + AsPtr::as_mut_ptr(&mut work_size), + &(-1), + $(AsPtr::as_mut_ptr(&mut $rwork_ident),)* &mut info, ); } @@ -128,21 +129,22 @@ macro_rules! impl_eigh { // actual evg let lwork = work_size[0].to_usize().unwrap(); - let mut work = unsafe { vec_uninit(lwork) }; + let mut work: Vec = unsafe { vec_uninit(lwork) }; + let lwork = lwork as i32; unsafe { $evg( - &[1], - jobz, - uplo as u8, - n, - &mut a, - n, - &mut b, - n, - &mut eigs, - &mut work, - lwork as i32, - $(&mut $rwork_ident,)* + &1, // ITYPE A*x = (lambda)*B*x + jobz.as_ptr(), + uplo.as_ptr(), + &n, + AsPtr::as_mut_ptr(a), + &n, + AsPtr::as_mut_ptr(b), + &n, + AsPtr::as_mut_ptr(&mut eigs), + AsPtr::as_mut_ptr(&mut work), + &lwork, + $(AsPtr::as_mut_ptr(&mut $rwork_ident),)* &mut info, ); } @@ -153,7 +155,7 @@ macro_rules! impl_eigh { }; } // impl_eigh! -impl_eigh!(@real, f64, lapack::dsyev, lapack::dsygv); -impl_eigh!(@real, f32, lapack::ssyev, lapack::ssygv); -impl_eigh!(@complex, c64, lapack::zheev, lapack::zhegv); -impl_eigh!(@complex, c32, lapack::cheev, lapack::chegv); +impl_eigh!(@real, f64, lapack_sys::dsyev_, lapack_sys::dsygv_); +impl_eigh!(@real, f32, lapack_sys::ssyev_, lapack_sys::ssygv_); +impl_eigh!(@complex, c64, lapack_sys::zheev_, lapack_sys::zhegv_); +impl_eigh!(@complex, c32, lapack_sys::cheev_, lapack_sys::chegv_); diff --git a/lax/src/least_squares.rs b/lax/src/least_squares.rs index fc378aa6..97f9a839 100644 --- a/lax/src/least_squares.rs +++ b/lax/src/least_squares.rs @@ -97,20 +97,20 @@ macro_rules! impl_least_squares { )* unsafe { $gelsd( - m, - n, - nrhs, - a_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(a), - a_layout.lda(), - b_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(b), - b_layout.lda(), - &mut singular_values, - rcond, + &m, + &n, + &nrhs, + AsPtr::as_mut_ptr(a_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(a)), + &a_layout.lda(), + AsPtr::as_mut_ptr(b_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(b)), + &b_layout.lda(), + AsPtr::as_mut_ptr(&mut singular_values), + &rcond, &mut rank, - &mut work_size, - -1, - $(&mut $rwork,)* - &mut iwork_size, + AsPtr::as_mut_ptr(&mut work_size), + &(-1), + $(AsPtr::as_mut_ptr(&mut $rwork),)* + iwork_size.as_mut_ptr(), &mut info, ) }; @@ -118,29 +118,29 @@ macro_rules! impl_least_squares { // calc let lwork = work_size[0].to_usize().unwrap(); - let mut work = unsafe { vec_uninit( lwork) }; + let mut work: Vec = unsafe { vec_uninit(lwork) }; let liwork = iwork_size[0].to_usize().unwrap(); - let mut iwork = unsafe { vec_uninit( liwork) }; + let mut iwork = unsafe { vec_uninit(liwork) }; $( let lrwork = $rwork[0].to_usize().unwrap(); - let mut $rwork = unsafe { vec_uninit( lrwork) }; + let mut $rwork: Vec = unsafe { vec_uninit(lrwork) }; )* unsafe { $gelsd( - m, - n, - nrhs, - a_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(a), - a_layout.lda(), - b_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(b), - b_layout.lda(), - &mut singular_values, - rcond, + &m, + &n, + &nrhs, + AsPtr::as_mut_ptr(a_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(a)), + &a_layout.lda(), + AsPtr::as_mut_ptr(b_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(b)), + &b_layout.lda(), + AsPtr::as_mut_ptr(&mut singular_values), + &rcond, &mut rank, - &mut work, - lwork as i32, - $(&mut $rwork,)* - &mut iwork, + AsPtr::as_mut_ptr(&mut work), + &(lwork as i32), + $(AsPtr::as_mut_ptr(&mut $rwork),)* + iwork.as_mut_ptr(), &mut info, ); } @@ -161,7 +161,7 @@ macro_rules! impl_least_squares { }; } -impl_least_squares!(@real, f64, lapack::dgelsd); -impl_least_squares!(@real, f32, lapack::sgelsd); -impl_least_squares!(@complex, c64, lapack::zgelsd); -impl_least_squares!(@complex, c32, lapack::cgelsd); +impl_least_squares!(@real, f64, lapack_sys::dgelsd_); +impl_least_squares!(@real, f32, lapack_sys::sgelsd_); +impl_least_squares!(@complex, c64, lapack_sys::zgelsd_); +impl_least_squares!(@complex, c32, lapack_sys::cgelsd_); diff --git a/lax/src/lib.rs b/lax/src/lib.rs index 41c15237..26b740bb 100644 --- a/lax/src/lib.rs +++ b/lax/src/lib.rs @@ -126,6 +126,31 @@ impl Lapack for f64 {} impl Lapack for c32 {} impl Lapack for c64 {} +/// Helper for getting pointer of slice +pub(crate) trait AsPtr: Sized { + type Elem; + fn as_ptr(vec: &[Self]) -> *const Self::Elem; + fn as_mut_ptr(vec: &mut [Self]) -> *mut Self::Elem; +} + +macro_rules! impl_as_ptr { + ($target:ty, $elem:ty) => { + impl AsPtr for $target { + type Elem = $elem; + fn as_ptr(vec: &[Self]) -> *const Self::Elem { + vec.as_ptr() as *const _ + } + fn as_mut_ptr(vec: &mut [Self]) -> *mut Self::Elem { + vec.as_mut_ptr() as *mut _ + } + } + }; +} +impl_as_ptr!(f32, f32); +impl_as_ptr!(f64, f64); +impl_as_ptr!(c32, lapack_sys::__BindgenComplex); +impl_as_ptr!(c64, lapack_sys::__BindgenComplex); + /// Upper/Lower specification for seveal usages #[derive(Debug, Clone, Copy)] #[repr(u8)] @@ -141,6 +166,11 @@ impl UPLO { UPLO::Lower => UPLO::Upper, } } + + /// To use Fortran LAPACK API in lapack-sys crate + pub fn as_ptr(&self) -> *const i8 { + self as *const UPLO as *const i8 + } } #[derive(Debug, Clone, Copy)] @@ -151,6 +181,13 @@ pub enum Transpose { Hermite = b'C', } +impl Transpose { + /// To use Fortran LAPACK API in lapack-sys crate + pub fn as_ptr(&self) -> *const i8 { + self as *const Transpose as *const i8 + } +} + #[derive(Debug, Clone, Copy)] #[repr(u8)] pub enum NormType { @@ -167,6 +204,41 @@ impl NormType { NormType::Frobenius => NormType::Frobenius, } } + + /// To use Fortran LAPACK API in lapack-sys crate + pub fn as_ptr(&self) -> *const i8 { + self as *const NormType as *const i8 + } +} + +/// Flag for calculating eigenvectors or not +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(u8)] +pub enum EigenVectorFlag { + Calc = b'V', + Not = b'N', +} + +impl EigenVectorFlag { + pub fn is_calc(&self) -> bool { + match self { + EigenVectorFlag::Calc => true, + EigenVectorFlag::Not => false, + } + } + + pub fn then T>(&self, f: F) -> Option { + if self.is_calc() { + Some(f()) + } else { + None + } + } + + /// To use Fortran LAPACK API in lapack-sys crate + pub fn as_ptr(&self) -> *const i8 { + self as *const EigenVectorFlag as *const i8 + } } /// Create a vector without initialization diff --git a/lax/src/opnorm.rs b/lax/src/opnorm.rs index dd84f441..ddcc2c85 100644 --- a/lax/src/opnorm.rs +++ b/lax/src/opnorm.rs @@ -1,6 +1,6 @@ //! Operator norms of matrices -use super::NormType; +use super::{AsPtr, NormType}; use crate::{layout::MatrixLayout, *}; use cauchy::*; @@ -18,18 +18,27 @@ macro_rules! impl_opnorm { MatrixLayout::F { .. } => t, MatrixLayout::C { .. } => t.transpose(), }; - let mut work = if matches!(t, NormType::Infinity) { + let mut work: Vec = if matches!(t, NormType::Infinity) { unsafe { vec_uninit(m as usize) } } else { Vec::new() }; - unsafe { $lange(t as u8, m, n, a, m, &mut work) } + unsafe { + $lange( + t.as_ptr(), + &m, + &n, + AsPtr::as_ptr(a), + &m, + AsPtr::as_mut_ptr(&mut work), + ) + } } } }; } // impl_opnorm! -impl_opnorm!(f64, lapack::dlange); -impl_opnorm!(f32, lapack::slange); -impl_opnorm!(c64, lapack::zlange); -impl_opnorm!(c32, lapack::clange); +impl_opnorm!(f64, lapack_sys::dlange_); +impl_opnorm!(f32, lapack_sys::slange_); +impl_opnorm!(c64, lapack_sys::zlange_); +impl_opnorm!(c32, lapack_sys::clange_); diff --git a/lax/src/qr.rs b/lax/src/qr.rs index 6460b8b9..33de0372 100644 --- a/lax/src/qr.rs +++ b/lax/src/qr.rs @@ -21,7 +21,7 @@ pub trait QR_: Sized { macro_rules! impl_qr { ($scalar:ty, $qrf:path, $lqf:path, $gqr:path, $glq:path) => { impl QR_ for $scalar { - fn householder(l: MatrixLayout, mut a: &mut [Self]) -> Result> { + fn householder(l: MatrixLayout, a: &mut [Self]) -> Result> { let m = l.lda(); let n = l.len(); let k = m.min(n); @@ -33,10 +33,28 @@ macro_rules! impl_qr { unsafe { match l { MatrixLayout::F { .. } => { - $qrf(m, n, &mut a, m, &mut tau, &mut work_size, -1, &mut info); + $qrf( + &m, + &n, + AsPtr::as_mut_ptr(a), + &m, + AsPtr::as_mut_ptr(&mut tau), + AsPtr::as_mut_ptr(&mut work_size), + &(-1), + &mut info, + ); } MatrixLayout::C { .. } => { - $lqf(m, n, &mut a, m, &mut tau, &mut work_size, -1, &mut info); + $lqf( + &m, + &n, + AsPtr::as_mut_ptr(a), + &m, + AsPtr::as_mut_ptr(&mut tau), + AsPtr::as_mut_ptr(&mut work_size), + &(-1), + &mut info, + ); } } } @@ -44,30 +62,30 @@ macro_rules! impl_qr { // calc let lwork = work_size[0].to_usize().unwrap(); - let mut work = unsafe { vec_uninit(lwork) }; + let mut work: Vec = unsafe { vec_uninit(lwork) }; unsafe { match l { MatrixLayout::F { .. } => { $qrf( - m, - n, - &mut a, - m, - &mut tau, - &mut work, - lwork as i32, + &m, + &n, + AsPtr::as_mut_ptr(a), + &m, + AsPtr::as_mut_ptr(&mut tau), + AsPtr::as_mut_ptr(&mut work), + &(lwork as i32), &mut info, ); } MatrixLayout::C { .. } => { $lqf( - m, - n, - &mut a, - m, - &mut tau, - &mut work, - lwork as i32, + &m, + &n, + AsPtr::as_mut_ptr(a), + &m, + AsPtr::as_mut_ptr(&mut tau), + AsPtr::as_mut_ptr(&mut work), + &(lwork as i32), &mut info, ); } @@ -78,7 +96,7 @@ macro_rules! impl_qr { Ok(tau) } - fn q(l: MatrixLayout, mut a: &mut [Self], tau: &[Self]) -> Result<()> { + fn q(l: MatrixLayout, a: &mut [Self], tau: &[Self]) -> Result<()> { let m = l.lda(); let n = l.len(); let k = m.min(n); @@ -89,26 +107,58 @@ macro_rules! impl_qr { let mut work_size = [Self::zero()]; unsafe { match l { - MatrixLayout::F { .. } => { - $gqr(m, k, k, &mut a, m, &tau, &mut work_size, -1, &mut info) - } - MatrixLayout::C { .. } => { - $glq(k, n, k, &mut a, m, &tau, &mut work_size, -1, &mut info) - } + MatrixLayout::F { .. } => $gqr( + &m, + &k, + &k, + AsPtr::as_mut_ptr(a), + &m, + AsPtr::as_ptr(&tau), + AsPtr::as_mut_ptr(&mut work_size), + &(-1), + &mut info, + ), + MatrixLayout::C { .. } => $glq( + &k, + &n, + &k, + AsPtr::as_mut_ptr(a), + &m, + AsPtr::as_ptr(&tau), + AsPtr::as_mut_ptr(&mut work_size), + &(-1), + &mut info, + ), } }; // calc let lwork = work_size[0].to_usize().unwrap(); - let mut work = unsafe { vec_uninit(lwork) }; + let mut work: Vec = unsafe { vec_uninit(lwork) }; unsafe { match l { - MatrixLayout::F { .. } => { - $gqr(m, k, k, &mut a, m, &tau, &mut work, lwork as i32, &mut info) - } - MatrixLayout::C { .. } => { - $glq(k, n, k, &mut a, m, &tau, &mut work, lwork as i32, &mut info) - } + MatrixLayout::F { .. } => $gqr( + &m, + &k, + &k, + AsPtr::as_mut_ptr(a), + &m, + AsPtr::as_ptr(&tau), + AsPtr::as_mut_ptr(&mut work), + &(lwork as i32), + &mut info, + ), + MatrixLayout::C { .. } => $glq( + &k, + &n, + &k, + AsPtr::as_mut_ptr(a), + &m, + AsPtr::as_ptr(&tau), + AsPtr::as_mut_ptr(&mut work), + &(lwork as i32), + &mut info, + ), } } info.as_lapack_result()?; @@ -127,29 +177,29 @@ macro_rules! impl_qr { impl_qr!( f64, - lapack::dgeqrf, - lapack::dgelqf, - lapack::dorgqr, - lapack::dorglq + lapack_sys::dgeqrf_, + lapack_sys::dgelqf_, + lapack_sys::dorgqr_, + lapack_sys::dorglq_ ); impl_qr!( f32, - lapack::sgeqrf, - lapack::sgelqf, - lapack::sorgqr, - lapack::sorglq + lapack_sys::sgeqrf_, + lapack_sys::sgelqf_, + lapack_sys::sorgqr_, + lapack_sys::sorglq_ ); impl_qr!( c64, - lapack::zgeqrf, - lapack::zgelqf, - lapack::zungqr, - lapack::zunglq + lapack_sys::zgeqrf_, + lapack_sys::zgelqf_, + lapack_sys::zungqr_, + lapack_sys::zunglq_ ); impl_qr!( c32, - lapack::cgeqrf, - lapack::cgelqf, - lapack::cungqr, - lapack::cunglq + lapack_sys::cgeqrf_, + lapack_sys::cgelqf_, + lapack_sys::cungqr_, + lapack_sys::cunglq_ ); diff --git a/lax/src/rcond.rs b/lax/src/rcond.rs index 91d7458c..fcd4211f 100644 --- a/lax/src/rcond.rs +++ b/lax/src/rcond.rs @@ -17,22 +17,22 @@ macro_rules! impl_rcond_real { let mut rcond = Self::Real::zero(); let mut info = 0; - let mut work = unsafe { vec_uninit(4 * n as usize) }; + let mut work: Vec = unsafe { vec_uninit(4 * n as usize) }; let mut iwork = unsafe { vec_uninit(n as usize) }; let norm_type = match l { MatrixLayout::C { .. } => NormType::Infinity, MatrixLayout::F { .. } => NormType::One, - } as u8; + }; unsafe { $gecon( - norm_type, - n, - a, - l.lda(), - anorm, + norm_type.as_ptr(), + &n, + AsPtr::as_ptr(a), + &l.lda(), + &anorm, &mut rcond, - &mut work, - &mut iwork, + AsPtr::as_mut_ptr(&mut work), + iwork.as_mut_ptr(), &mut info, ) }; @@ -44,8 +44,8 @@ macro_rules! impl_rcond_real { }; } -impl_rcond_real!(f32, lapack::sgecon); -impl_rcond_real!(f64, lapack::dgecon); +impl_rcond_real!(f32, lapack_sys::sgecon_); +impl_rcond_real!(f64, lapack_sys::dgecon_); macro_rules! impl_rcond_complex { ($scalar:ty, $gecon:path) => { @@ -54,22 +54,22 @@ macro_rules! impl_rcond_complex { let (n, _) = l.size(); let mut rcond = Self::Real::zero(); let mut info = 0; - let mut work = unsafe { vec_uninit(2 * n as usize) }; - let mut rwork = unsafe { vec_uninit(2 * n as usize) }; + let mut work: Vec = unsafe { vec_uninit(2 * n as usize) }; + let mut rwork: Vec = unsafe { vec_uninit(2 * n as usize) }; let norm_type = match l { MatrixLayout::C { .. } => NormType::Infinity, MatrixLayout::F { .. } => NormType::One, - } as u8; + }; unsafe { $gecon( - norm_type, - n, - a, - l.lda(), - anorm, + norm_type.as_ptr(), + &n, + AsPtr::as_ptr(a), + &l.lda(), + &anorm, &mut rcond, - &mut work, - &mut rwork, + AsPtr::as_mut_ptr(&mut work), + AsPtr::as_mut_ptr(&mut rwork), &mut info, ) }; @@ -81,5 +81,5 @@ macro_rules! impl_rcond_complex { }; } -impl_rcond_complex!(c32, lapack::cgecon); -impl_rcond_complex!(c64, lapack::zgecon); +impl_rcond_complex!(c32, lapack_sys::cgecon_); +impl_rcond_complex!(c64, lapack_sys::zgecon_); diff --git a/lax/src/solve.rs b/lax/src/solve.rs index e09a7e8d..9c19c874 100644 --- a/lax/src/solve.rs +++ b/lax/src/solve.rs @@ -35,7 +35,16 @@ macro_rules! impl_solve { let k = ::std::cmp::min(row, col); let mut ipiv = unsafe { vec_uninit(k as usize) }; let mut info = 0; - unsafe { $getrf(l.lda(), l.len(), a, l.lda(), &mut ipiv, &mut info) }; + unsafe { + $getrf( + &l.lda(), + &l.len(), + AsPtr::as_mut_ptr(a), + &l.lda(), + ipiv.as_mut_ptr(), + &mut info, + ) + }; info.as_lapack_result()?; Ok(ipiv) } @@ -50,20 +59,30 @@ macro_rules! impl_solve { // calc work size let mut info = 0; let mut work_size = [Self::zero()]; - unsafe { $getri(n, a, l.lda(), ipiv, &mut work_size, -1, &mut info) }; + unsafe { + $getri( + &n, + AsPtr::as_mut_ptr(a), + &l.lda(), + ipiv.as_ptr(), + AsPtr::as_mut_ptr(&mut work_size), + &(-1), + &mut info, + ) + }; info.as_lapack_result()?; // actual let lwork = work_size[0].to_usize().unwrap(); - let mut work = unsafe { vec_uninit(lwork) }; + let mut work: Vec = unsafe { vec_uninit(lwork) }; unsafe { $getri( - l.len(), - a, - l.lda(), - ipiv, - &mut work, - lwork as i32, + &l.len(), + AsPtr::as_mut_ptr(a), + &l.lda(), + ipiv.as_ptr(), + AsPtr::as_mut_ptr(&mut work), + &(lwork as i32), &mut info, ) }; @@ -116,7 +135,19 @@ macro_rules! impl_solve { *b_elem = b_elem.conj(); } } - unsafe { $getrs(t as u8, n, nrhs, a, l.lda(), ipiv, b, ldb, &mut info) }; + unsafe { + $getrs( + t.as_ptr(), + &n, + &nrhs, + AsPtr::as_ptr(a), + &l.lda(), + ipiv.as_ptr(), + AsPtr::as_mut_ptr(b), + &ldb, + &mut info, + ) + }; if conj { for b_elem in &mut *b { *b_elem = b_elem.conj(); @@ -129,7 +160,27 @@ macro_rules! impl_solve { }; } // impl_solve! -impl_solve!(f64, lapack::dgetrf, lapack::dgetri, lapack::dgetrs); -impl_solve!(f32, lapack::sgetrf, lapack::sgetri, lapack::sgetrs); -impl_solve!(c64, lapack::zgetrf, lapack::zgetri, lapack::zgetrs); -impl_solve!(c32, lapack::cgetrf, lapack::cgetri, lapack::cgetrs); +impl_solve!( + f64, + lapack_sys::dgetrf_, + lapack_sys::dgetri_, + lapack_sys::dgetrs_ +); +impl_solve!( + f32, + lapack_sys::sgetrf_, + lapack_sys::sgetri_, + lapack_sys::sgetrs_ +); +impl_solve!( + c64, + lapack_sys::zgetrf_, + lapack_sys::zgetri_, + lapack_sys::zgetrs_ +); +impl_solve!( + c32, + lapack_sys::cgetrf_, + lapack_sys::cgetri_, + lapack_sys::cgetrs_ +); diff --git a/lax/src/solveh.rs b/lax/src/solveh.rs index 1a4d6e3e..c5259dda 100644 --- a/lax/src/solveh.rs +++ b/lax/src/solveh.rs @@ -30,13 +30,13 @@ macro_rules! impl_solveh { let mut work_size = [Self::zero()]; unsafe { $trf( - uplo as u8, - n, - a, - l.lda(), - &mut ipiv, - &mut work_size, - -1, + uplo.as_ptr(), + &n, + AsPtr::as_mut_ptr(a), + &l.lda(), + ipiv.as_mut_ptr(), + AsPtr::as_mut_ptr(&mut work_size), + &(-1), &mut info, ) }; @@ -44,16 +44,16 @@ macro_rules! impl_solveh { // actual let lwork = work_size[0].to_usize().unwrap(); - let mut work = unsafe { vec_uninit(lwork) }; + let mut work: Vec = unsafe { vec_uninit(lwork) }; unsafe { $trf( - uplo as u8, - n, - a, - l.lda(), - &mut ipiv, - &mut work, - lwork as i32, + uplo.as_ptr(), + &n, + AsPtr::as_mut_ptr(a), + &l.lda(), + ipiv.as_mut_ptr(), + AsPtr::as_mut_ptr(&mut work), + &(lwork as i32), &mut info, ) }; @@ -64,8 +64,18 @@ macro_rules! impl_solveh { fn invh(l: MatrixLayout, uplo: UPLO, a: &mut [Self], ipiv: &Pivot) -> Result<()> { let (n, _) = l.size(); let mut info = 0; - let mut work = unsafe { vec_uninit(n as usize) }; - unsafe { $tri(uplo as u8, n, a, l.lda(), ipiv, &mut work, &mut info) }; + let mut work: Vec = unsafe { vec_uninit(n as usize) }; + unsafe { + $tri( + uplo.as_ptr(), + &n, + AsPtr::as_mut_ptr(a), + &l.lda(), + ipiv.as_ptr(), + AsPtr::as_mut_ptr(&mut work), + &mut info, + ) + }; info.as_lapack_result()?; Ok(()) } @@ -79,7 +89,19 @@ macro_rules! impl_solveh { ) -> Result<()> { let (n, _) = l.size(); let mut info = 0; - unsafe { $trs(uplo as u8, n, 1, a, l.lda(), ipiv, b, n, &mut info) }; + unsafe { + $trs( + uplo.as_ptr(), + &n, + &1, + AsPtr::as_ptr(a), + &l.lda(), + ipiv.as_ptr(), + AsPtr::as_mut_ptr(b), + &n, + &mut info, + ) + }; info.as_lapack_result()?; Ok(()) } @@ -87,7 +109,27 @@ macro_rules! impl_solveh { }; } // impl_solveh! -impl_solveh!(f64, lapack::dsytrf, lapack::dsytri, lapack::dsytrs); -impl_solveh!(f32, lapack::ssytrf, lapack::ssytri, lapack::ssytrs); -impl_solveh!(c64, lapack::zhetrf, lapack::zhetri, lapack::zhetrs); -impl_solveh!(c32, lapack::chetrf, lapack::chetri, lapack::chetrs); +impl_solveh!( + f64, + lapack_sys::dsytrf_, + lapack_sys::dsytri_, + lapack_sys::dsytrs_ +); +impl_solveh!( + f32, + lapack_sys::ssytrf_, + lapack_sys::ssytri_, + lapack_sys::ssytrs_ +); +impl_solveh!( + c64, + lapack_sys::zhetrf_, + lapack_sys::zhetri_, + lapack_sys::zhetrs_ +); +impl_solveh!( + c32, + lapack_sys::chetrf_, + lapack_sys::chetri_, + lapack_sys::chetrs_ +); diff --git a/lax/src/svd.rs b/lax/src/svd.rs index c990cd27..0ee56428 100644 --- a/lax/src/svd.rs +++ b/lax/src/svd.rs @@ -21,6 +21,10 @@ impl FlagSVD { FlagSVD::No } } + + fn as_ptr(&self) -> *const i8 { + self as *const FlagSVD as *const i8 + } } /// Result of SVD @@ -49,7 +53,7 @@ macro_rules! impl_svd { }; (@body, $scalar:ty, $gesvd:path, $($rwork_ident:ident),*) => { impl SVD_ for $scalar { - fn svd(l: MatrixLayout, calc_u: bool, calc_vt: bool, mut a: &mut [Self],) -> Result> { + fn svd(l: MatrixLayout, calc_u: bool, calc_vt: bool, a: &mut [Self],) -> Result> { let ju = match l { MatrixLayout::F { .. } => FlagSVD::from_bool(calc_u), MatrixLayout::C { .. } => FlagSVD::from_bool(calc_vt), @@ -75,7 +79,7 @@ macro_rules! impl_svd { let mut s = unsafe { vec_uninit( k as usize) }; $( - let mut $rwork_ident = unsafe { vec_uninit( 5 * k as usize) }; + let mut $rwork_ident: Vec = unsafe { vec_uninit( 5 * k as usize) }; )* // eval work size @@ -83,20 +87,20 @@ macro_rules! impl_svd { let mut work_size = [Self::zero()]; unsafe { $gesvd( - ju as u8, - jvt as u8, - m, - n, - &mut a, - m, - &mut s, - u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []), - m, - vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []), - n, - &mut work_size, - -1, - $(&mut $rwork_ident,)* + ju.as_ptr(), + jvt.as_ptr(), + &m, + &n, + AsPtr::as_mut_ptr(a), + &m, + AsPtr::as_mut_ptr(&mut s), + AsPtr::as_mut_ptr(u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut [])), + &m, + AsPtr::as_mut_ptr(vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut [])), + &n, + AsPtr::as_mut_ptr(&mut work_size), + &(-1), + $(AsPtr::as_mut_ptr(&mut $rwork_ident),)* &mut info, ); } @@ -104,23 +108,23 @@ macro_rules! impl_svd { // calc let lwork = work_size[0].to_usize().unwrap(); - let mut work = unsafe { vec_uninit( lwork) }; + let mut work: Vec = unsafe { vec_uninit( lwork) }; unsafe { $gesvd( - ju as u8, - jvt as u8, - m, - n, - &mut a, - m, - &mut s, - u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []), - m, - vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []), - n, - &mut work, - lwork as i32, - $(&mut $rwork_ident,)* + ju.as_ptr(), + jvt.as_ptr() , + &m, + &n, + AsPtr::as_mut_ptr(a), + &m, + AsPtr::as_mut_ptr(&mut s), + AsPtr::as_mut_ptr(u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut [])), + &m, + AsPtr::as_mut_ptr(vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut [])), + &n, + AsPtr::as_mut_ptr(&mut work), + &(lwork as i32), + $(AsPtr::as_mut_ptr(&mut $rwork_ident),)* &mut info, ); } @@ -134,7 +138,7 @@ macro_rules! impl_svd { }; } // impl_svd! -impl_svd!(@real, f64, lapack::dgesvd); -impl_svd!(@real, f32, lapack::sgesvd); -impl_svd!(@complex, c64, lapack::zgesvd); -impl_svd!(@complex, c32, lapack::cgesvd); +impl_svd!(@real, f64, lapack_sys::dgesvd_); +impl_svd!(@real, f32, lapack_sys::sgesvd_); +impl_svd!(@complex, c64, lapack_sys::zgesvd_); +impl_svd!(@complex, c32, lapack_sys::cgesvd_); diff --git a/lax/src/svddc.rs b/lax/src/svddc.rs index a94bdace..c1198286 100644 --- a/lax/src/svddc.rs +++ b/lax/src/svddc.rs @@ -16,6 +16,12 @@ pub enum UVTFlag { None = b'N', } +impl UVTFlag { + fn as_ptr(&self) -> *const i8 { + self as *const UVTFlag as *const i8 + } +} + pub trait SVDDC_: Scalar { fn svddc(l: MatrixLayout, jobz: UVTFlag, a: &mut [Self]) -> Result>; } @@ -29,7 +35,7 @@ macro_rules! impl_svddc { }; (@body, $scalar:ty, $gesdd:path, $($rwork_ident:ident),*) => { impl SVDDC_ for $scalar { - fn svddc(l: MatrixLayout, jobz: UVTFlag, mut a: &mut [Self],) -> Result> { + fn svddc(l: MatrixLayout, jobz: UVTFlag, a: &mut [Self],) -> Result> { let m = l.lda(); let n = l.len(); let k = m.min(n); @@ -58,7 +64,7 @@ macro_rules! impl_svddc { UVTFlag::None => 7 * mn, _ => std::cmp::max(5*mn*mn + 5*mn, 2*mx*mn + 2*mn*mn + mn), }; - let mut $rwork_ident = unsafe { vec_uninit( lrwork) }; + let mut $rwork_ident: Vec = unsafe { vec_uninit( lrwork) }; )* // eval work size @@ -67,20 +73,20 @@ macro_rules! impl_svddc { let mut work_size = [Self::zero()]; unsafe { $gesdd( - jobz as u8, - m, - n, - &mut a, - m, - &mut s, - u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []), - m, - vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []), - vt_row, - &mut work_size, - -1, - $(&mut $rwork_ident,)* - &mut iwork, + jobz.as_ptr(), + &m, + &n, + AsPtr::as_mut_ptr(a), + &m, + AsPtr::as_mut_ptr(&mut s), + AsPtr::as_mut_ptr(u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut [])), + &m, + AsPtr::as_mut_ptr(vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut [])), + &vt_row, + AsPtr::as_mut_ptr(&mut work_size), + &(-1), + $(AsPtr::as_mut_ptr(&mut $rwork_ident),)* + iwork.as_mut_ptr(), &mut info, ); } @@ -88,23 +94,23 @@ macro_rules! impl_svddc { // do svd let lwork = work_size[0].to_usize().unwrap(); - let mut work = unsafe { vec_uninit( lwork) }; + let mut work: Vec = unsafe { vec_uninit( lwork) }; unsafe { $gesdd( - jobz as u8, - m, - n, - &mut a, - m, - &mut s, - u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []), - m, - vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []), - vt_row, - &mut work, - lwork as i32, - $(&mut $rwork_ident,)* - &mut iwork, + jobz.as_ptr(), + &m, + &n, + AsPtr::as_mut_ptr(a), + &m, + AsPtr::as_mut_ptr(&mut s), + AsPtr::as_mut_ptr(u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut [])), + &m, + AsPtr::as_mut_ptr(vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut [])), + &vt_row, + AsPtr::as_mut_ptr(&mut work), + &(lwork as i32), + $(AsPtr::as_mut_ptr(&mut $rwork_ident),)* + iwork.as_mut_ptr(), &mut info, ); } @@ -119,7 +125,7 @@ macro_rules! impl_svddc { }; } -impl_svddc!(@real, f32, lapack::sgesdd); -impl_svddc!(@real, f64, lapack::dgesdd); -impl_svddc!(@complex, c32, lapack::cgesdd); -impl_svddc!(@complex, c64, lapack::zgesdd); +impl_svddc!(@real, f32, lapack_sys::sgesdd_); +impl_svddc!(@real, f64, lapack_sys::dgesdd_); +impl_svddc!(@complex, c32, lapack_sys::cgesdd_); +impl_svddc!(@complex, c64, lapack_sys::zgesdd_); diff --git a/lax/src/triangular.rs b/lax/src/triangular.rs index a48b12b3..0288d6ba 100644 --- a/lax/src/triangular.rs +++ b/lax/src/triangular.rs @@ -10,6 +10,12 @@ pub enum Diag { NonUnit = b'N', } +impl Diag { + fn as_ptr(&self) -> *const i8 { + self as *const Diag as *const i8 + } +} + /// Wraps `*trtri` and `*trtrs` pub trait Triangular_: Scalar { fn solve_triangular( @@ -60,15 +66,15 @@ macro_rules! impl_triangular { let mut info = 0; unsafe { $trtrs( - uplo as u8, - Transpose::No as u8, - diag as u8, - m, - nrhs, - a_t.as_ref().map(|v| v.as_slice()).unwrap_or(a), - a_layout.lda(), - b_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(b), - b_layout.lda(), + uplo.as_ptr(), + Transpose::No.as_ptr(), + diag.as_ptr(), + &m, + &nrhs, + AsPtr::as_ptr(a_t.as_ref().map(|v| v.as_slice()).unwrap_or(a)), + &a_layout.lda(), + AsPtr::as_mut_ptr(b_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(b)), + &b_layout.lda(), &mut info, ); } @@ -84,7 +90,7 @@ macro_rules! impl_triangular { }; } // impl_triangular! -impl_triangular!(f64, lapack::dtrtri, lapack::dtrtrs); -impl_triangular!(f32, lapack::strtri, lapack::strtrs); -impl_triangular!(c64, lapack::ztrtri, lapack::ztrtrs); -impl_triangular!(c32, lapack::ctrtri, lapack::ctrtrs); +impl_triangular!(f64, lapack_sys::dtrtri_, lapack_sys::dtrtrs_); +impl_triangular!(f32, lapack_sys::strtri_, lapack_sys::strtrs_); +impl_triangular!(c64, lapack_sys::ztrtri_, lapack_sys::ztrtrs_); +impl_triangular!(c32, lapack_sys::ctrtri_, lapack_sys::ctrtrs_); diff --git a/lax/src/tridiagonal.rs b/lax/src/tridiagonal.rs index 2b1d98e2..c80ad4b5 100644 --- a/lax/src/tridiagonal.rs +++ b/lax/src/tridiagonal.rs @@ -157,7 +157,17 @@ macro_rules! impl_tridiagonal { // We have to calc one-norm before LU factorization let a_opnorm_one = a.opnorm_one(); let mut info = 0; - unsafe { $gttrf(n, &mut a.dl, &mut a.d, &mut a.du, &mut du2, &mut ipiv, &mut info,) }; + unsafe { + $gttrf( + &n, + AsPtr::as_mut_ptr(&mut a.dl), + AsPtr::as_mut_ptr(&mut a.d), + AsPtr::as_mut_ptr(&mut a.du), + AsPtr::as_mut_ptr(&mut du2), + ipiv.as_mut_ptr(), + &mut info, + ) + }; info.as_lapack_result()?; Ok(LUFactorizedTridiagonal { a, @@ -170,7 +180,7 @@ macro_rules! impl_tridiagonal { fn rcond_tridiagonal(lu: &LUFactorizedTridiagonal) -> Result { let (n, _) = lu.a.l.size(); let ipiv = &lu.ipiv; - let mut work = unsafe { vec_uninit( 2 * n as usize) }; + let mut work: Vec = unsafe { vec_uninit( 2 * n as usize) }; $( let mut $iwork = unsafe { vec_uninit( n as usize) }; )* @@ -178,17 +188,17 @@ macro_rules! impl_tridiagonal { let mut info = 0; unsafe { $gtcon( - NormType::One as u8, - n, - &lu.a.dl, - &lu.a.d, - &lu.a.du, - &lu.du2, - ipiv, - lu.a_opnorm_one, + NormType::One.as_ptr(), + &n, + AsPtr::as_ptr(&lu.a.dl), + AsPtr::as_ptr(&lu.a.d), + AsPtr::as_ptr(&lu.a.du), + AsPtr::as_ptr(&lu.du2), + ipiv.as_ptr(), + &lu.a_opnorm_one, &mut rcond, - &mut work, - $(&mut $iwork,)* + AsPtr::as_mut_ptr(&mut work), + $($iwork.as_mut_ptr(),)* &mut info, ); } @@ -217,16 +227,16 @@ macro_rules! impl_tridiagonal { let mut info = 0; unsafe { $gttrs( - t as u8, - n, - nrhs, - &lu.a.dl, - &lu.a.d, - &lu.a.du, - &lu.du2, - ipiv, - b_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(b), - ldb, + t.as_ptr(), + &n, + &nrhs, + AsPtr::as_ptr(&lu.a.dl), + AsPtr::as_ptr(&lu.a.d), + AsPtr::as_ptr(&lu.a.du), + AsPtr::as_ptr(&lu.du2), + ipiv.as_ptr(), + AsPtr::as_mut_ptr(b_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(b)), + &ldb, &mut info, ); } @@ -240,7 +250,7 @@ macro_rules! impl_tridiagonal { }; } // impl_tridiagonal! -impl_tridiagonal!(@real, f64, lapack::dgttrf, lapack::dgtcon, lapack::dgttrs); -impl_tridiagonal!(@real, f32, lapack::sgttrf, lapack::sgtcon, lapack::sgttrs); -impl_tridiagonal!(@complex, c64, lapack::zgttrf, lapack::zgtcon, lapack::zgttrs); -impl_tridiagonal!(@complex, c32, lapack::cgttrf, lapack::cgtcon, lapack::cgttrs); +impl_tridiagonal!(@real, f64, lapack_sys::dgttrf_, lapack_sys::dgtcon_, lapack_sys::dgttrs_); +impl_tridiagonal!(@real, f32, lapack_sys::sgttrf_, lapack_sys::sgtcon_, lapack_sys::sgttrs_); +impl_tridiagonal!(@complex, c64, lapack_sys::zgttrf_, lapack_sys::zgtcon_, lapack_sys::zgttrs_); +impl_tridiagonal!(@complex, c32, lapack_sys::cgttrf_, lapack_sys::cgtcon_, lapack_sys::cgttrs_);