diff --git a/lax/src/eig.rs b/lax/src/eig.rs index f11f5287..d3c07222 100644 --- a/lax/src/eig.rs +++ b/lax/src/eig.rs @@ -35,11 +35,11 @@ macro_rules! impl_eig_complex { // eigenvalues are the eigenvalues computed with `A`. let (jobvl, jobvr) = if calc_v { match l { - MatrixLayout::C { .. } => (EigenVectorFlag::Calc, EigenVectorFlag::Not), - MatrixLayout::F { .. } => (EigenVectorFlag::Not, EigenVectorFlag::Calc), + MatrixLayout::C { .. } => (JobEv::All, JobEv::None), + MatrixLayout::F { .. } => (JobEv::None, JobEv::All), } } else { - (EigenVectorFlag::Not, EigenVectorFlag::Not) + (JobEv::None, JobEv::None) }; let mut eigs: Vec> = unsafe { vec_uninit(n as usize) }; let mut rwork: Vec> = unsafe { vec_uninit(2 * n as usize) }; @@ -143,11 +143,11 @@ macro_rules! impl_eig_real { // `sgeev`/`dgeev`. let (jobvl, jobvr) = if calc_v { match l { - MatrixLayout::C { .. } => (EigenVectorFlag::Calc, EigenVectorFlag::Not), - MatrixLayout::F { .. } => (EigenVectorFlag::Not, EigenVectorFlag::Calc), + MatrixLayout::C { .. } => (JobEv::All, JobEv::None), + MatrixLayout::F { .. } => (JobEv::None, JobEv::All), } } else { - (EigenVectorFlag::Not, EigenVectorFlag::Not) + (JobEv::None, JobEv::None) }; let mut eig_re: Vec> = unsafe { vec_uninit(n as usize) }; let mut eig_im: Vec> = unsafe { vec_uninit(n as usize) }; diff --git a/lax/src/eigh.rs b/lax/src/eigh.rs index 0692f921..a9406ee6 100644 --- a/lax/src/eigh.rs +++ b/lax/src/eigh.rs @@ -41,7 +41,7 @@ macro_rules! impl_eigh { ) -> Result> { assert_eq!(layout.len(), layout.lda()); let n = layout.len(); - let jobz = if calc_v { EigenVectorFlag::Calc } else { EigenVectorFlag::Not }; + let jobz = if calc_v { JobEv::All } else { JobEv::None }; let mut eigs: Vec> = unsafe { vec_uninit(n as usize) }; $( @@ -100,7 +100,7 @@ macro_rules! impl_eigh { ) -> Result> { assert_eq!(layout.len(), layout.lda()); let n = layout.len(); - let jobz = if calc_v { EigenVectorFlag::Calc } else { EigenVectorFlag::Not }; + let jobz = if calc_v { JobEv::All } else { JobEv::None }; let mut eigs: Vec> = unsafe { vec_uninit(n as usize) }; $( diff --git a/lax/src/flags.rs b/lax/src/flags.rs new file mode 100644 index 00000000..37a11b3c --- /dev/null +++ b/lax/src/flags.rs @@ -0,0 +1,137 @@ +//! Charactor flags, e.g. `'T'`, used in LAPACK API + +/// Upper/Lower specification for seveal usages +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[repr(u8)] +pub enum UPLO { + Upper = b'U', + Lower = b'L', +} + +impl UPLO { + pub fn t(self) -> Self { + match self { + UPLO::Upper => UPLO::Lower, + UPLO::Lower => UPLO::Upper, + } + } + + /// To use Fortran LAPACK API in lapack-sys crate + pub fn as_ptr(&self) -> *const i8 { + self as *const UPLO as *const i8 + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[repr(u8)] +pub enum Transpose { + No = b'N', + Transpose = b'T', + Hermite = b'C', +} + +impl Transpose { + /// To use Fortran LAPACK API in lapack-sys crate + pub fn as_ptr(&self) -> *const i8 { + self as *const Transpose as *const i8 + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[repr(u8)] +pub enum NormType { + One = b'O', + Infinity = b'I', + Frobenius = b'F', +} + +impl NormType { + pub fn transpose(self) -> Self { + match self { + NormType::One => NormType::Infinity, + NormType::Infinity => NormType::One, + NormType::Frobenius => NormType::Frobenius, + } + } + + /// To use Fortran LAPACK API in lapack-sys crate + pub fn as_ptr(&self) -> *const i8 { + self as *const NormType as *const i8 + } +} + +/// Flag for calculating eigenvectors or not +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[repr(u8)] +pub enum JobEv { + /// Calculate eigenvectors in addition to eigenvalues + All = b'V', + /// Do not calculate eigenvectors. Only calculate eigenvalues. + None = b'N', +} + +impl JobEv { + pub fn is_calc(&self) -> bool { + match self { + JobEv::All => true, + JobEv::None => false, + } + } + + pub fn then T>(&self, f: F) -> Option { + if self.is_calc() { + Some(f()) + } else { + None + } + } + + /// To use Fortran LAPACK API in lapack-sys crate + pub fn as_ptr(&self) -> *const i8 { + self as *const JobEv as *const i8 + } +} + +/// Specifies how many of the columns of *U* and rows of *V*ᵀ are computed and returned. +/// +/// For an input array of shape *m*×*n*, the following are computed: +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[repr(u8)] +pub enum JobSvd { + /// All *m* columns of *U* and all *n* rows of *V*ᵀ. + All = b'A', + /// The first min(*m*,*n*) columns of *U* and the first min(*m*,*n*) rows of *V*ᵀ. + Some = b'S', + /// No columns of *U* or rows of *V*ᵀ. + None = b'N', +} + +impl JobSvd { + pub fn from_bool(calc_uv: bool) -> Self { + if calc_uv { + JobSvd::All + } else { + JobSvd::None + } + } + + pub fn as_ptr(&self) -> *const i8 { + self as *const JobSvd as *const i8 + } +} + +/// Specify whether input triangular matrix is unit or not +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[repr(u8)] +pub enum Diag { + /// Unit triangular matrix, i.e. all diagonal elements of the matrix are `1` + Unit = b'U', + /// Non-unit triangular matrix. Its diagonal elements may be different from `1` + NonUnit = b'N', +} + +impl Diag { + pub fn as_ptr(&self) -> *const i8 { + self as *const Diag as *const i8 + } +} diff --git a/lax/src/lib.rs b/lax/src/lib.rs index c8d2264d..83cf3658 100644 --- a/lax/src/lib.rs +++ b/lax/src/lib.rs @@ -69,6 +69,7 @@ extern crate openblas_src as _src; extern crate netlib_src as _src; pub mod error; +pub mod flags; pub mod layout; mod cholesky; @@ -88,6 +89,7 @@ mod tridiagonal; pub use self::cholesky::*; pub use self::eig::*; pub use self::eigh::*; +pub use self::flags::*; pub use self::least_squares::*; pub use self::opnorm::*; pub use self::qr::*; @@ -173,96 +175,6 @@ impl VecAssumeInit for Vec> { } } -/// Upper/Lower specification for seveal usages -#[derive(Debug, Clone, Copy)] -#[repr(u8)] -pub enum UPLO { - Upper = b'U', - Lower = b'L', -} - -impl UPLO { - pub fn t(self) -> Self { - match self { - UPLO::Upper => UPLO::Lower, - UPLO::Lower => UPLO::Upper, - } - } - - /// To use Fortran LAPACK API in lapack-sys crate - pub fn as_ptr(&self) -> *const i8 { - self as *const UPLO as *const i8 - } -} - -#[derive(Debug, Clone, Copy)] -#[repr(u8)] -pub enum Transpose { - No = b'N', - Transpose = b'T', - Hermite = b'C', -} - -impl Transpose { - /// To use Fortran LAPACK API in lapack-sys crate - pub fn as_ptr(&self) -> *const i8 { - self as *const Transpose as *const i8 - } -} - -#[derive(Debug, Clone, Copy)] -#[repr(u8)] -pub enum NormType { - One = b'O', - Infinity = b'I', - Frobenius = b'F', -} - -impl NormType { - pub fn transpose(self) -> Self { - match self { - NormType::One => NormType::Infinity, - NormType::Infinity => NormType::One, - NormType::Frobenius => NormType::Frobenius, - } - } - - /// To use Fortran LAPACK API in lapack-sys crate - pub fn as_ptr(&self) -> *const i8 { - self as *const NormType as *const i8 - } -} - -/// Flag for calculating eigenvectors or not -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -#[repr(u8)] -pub enum EigenVectorFlag { - Calc = b'V', - Not = b'N', -} - -impl EigenVectorFlag { - pub fn is_calc(&self) -> bool { - match self { - EigenVectorFlag::Calc => true, - EigenVectorFlag::Not => false, - } - } - - pub fn then T>(&self, f: F) -> Option { - if self.is_calc() { - Some(f()) - } else { - None - } - } - - /// To use Fortran LAPACK API in lapack-sys crate - pub fn as_ptr(&self) -> *const i8 { - self as *const EigenVectorFlag as *const i8 - } -} - /// Create a vector without initialization /// /// Safety diff --git a/lax/src/svd.rs b/lax/src/svd.rs index 8c731c7a..0a509a0e 100644 --- a/lax/src/svd.rs +++ b/lax/src/svd.rs @@ -1,32 +1,9 @@ //! Singular-value decomposition -use crate::{error::*, layout::MatrixLayout, *}; +use super::{error::*, layout::*, *}; use cauchy::*; use num_traits::{ToPrimitive, Zero}; -#[repr(u8)] -#[derive(Debug, Copy, Clone)] -enum FlagSVD { - All = b'A', - // OverWrite = b'O', - // Separately = b'S', - No = b'N', -} - -impl FlagSVD { - fn from_bool(calc_uv: bool) -> Self { - if calc_uv { - FlagSVD::All - } else { - FlagSVD::No - } - } - - fn as_ptr(&self) -> *const i8 { - self as *const FlagSVD as *const i8 - } -} - /// Result of SVD pub struct SVDOutput { /// diagonal values @@ -55,24 +32,26 @@ macro_rules! impl_svd { impl SVD_ for $scalar { fn svd(l: MatrixLayout, calc_u: bool, calc_vt: bool, a: &mut [Self],) -> Result> { let ju = match l { - MatrixLayout::F { .. } => FlagSVD::from_bool(calc_u), - MatrixLayout::C { .. } => FlagSVD::from_bool(calc_vt), + MatrixLayout::F { .. } => JobSvd::from_bool(calc_u), + MatrixLayout::C { .. } => JobSvd::from_bool(calc_vt), }; let jvt = match l { - MatrixLayout::F { .. } => FlagSVD::from_bool(calc_vt), - MatrixLayout::C { .. } => FlagSVD::from_bool(calc_u), + MatrixLayout::F { .. } => JobSvd::from_bool(calc_vt), + MatrixLayout::C { .. } => JobSvd::from_bool(calc_u), }; let m = l.lda(); let mut u = match ju { - FlagSVD::All => Some(unsafe { vec_uninit( (m * m) as usize) }), - FlagSVD::No => None, + JobSvd::All => Some(unsafe { vec_uninit( (m * m) as usize) }), + JobSvd::None => None, + _ => unimplemented!("SVD with partial vector output is not supported yet") }; let n = l.len(); let mut vt = match jvt { - FlagSVD::All => Some(unsafe { vec_uninit( (n * n) as usize) }), - FlagSVD::No => None, + JobSvd::All => Some(unsafe { vec_uninit( (n * n) as usize) }), + JobSvd::None => None, + _ => unimplemented!("SVD with partial vector output is not supported yet") }; let k = std::cmp::min(m, n); diff --git a/lax/src/svddc.rs b/lax/src/svddc.rs index f956d848..bb59348f 100644 --- a/lax/src/svddc.rs +++ b/lax/src/svddc.rs @@ -2,28 +2,8 @@ use crate::{error::*, layout::MatrixLayout, *}; use cauchy::*; use num_traits::{ToPrimitive, Zero}; -/// Specifies how many of the columns of *U* and rows of *V*ᵀ are computed and returned. -/// -/// For an input array of shape *m*×*n*, the following are computed: -#[derive(Clone, Copy, Eq, PartialEq)] -#[repr(u8)] -pub enum UVTFlag { - /// All *m* columns of *U* and all *n* rows of *V*ᵀ. - Full = b'A', - /// The first min(*m*,*n*) columns of *U* and the first min(*m*,*n*) rows of *V*ᵀ. - Some = b'S', - /// No columns of *U* or rows of *V*ᵀ. - None = b'N', -} - -impl UVTFlag { - fn as_ptr(&self) -> *const i8 { - self as *const UVTFlag as *const i8 - } -} - pub trait SVDDC_: Scalar { - fn svddc(l: MatrixLayout, jobz: UVTFlag, a: &mut [Self]) -> Result>; + fn svddc(l: MatrixLayout, jobz: JobSvd, a: &mut [Self]) -> Result>; } macro_rules! impl_svddc { @@ -35,33 +15,33 @@ macro_rules! impl_svddc { }; (@body, $scalar:ty, $gesdd:path, $($rwork_ident:ident),*) => { impl SVDDC_ for $scalar { - fn svddc(l: MatrixLayout, jobz: UVTFlag, a: &mut [Self],) -> Result> { + fn svddc(l: MatrixLayout, jobz: JobSvd, a: &mut [Self],) -> Result> { let m = l.lda(); let n = l.len(); let k = m.min(n); let mut s = unsafe { vec_uninit( k as usize) }; let (u_col, vt_row) = match jobz { - UVTFlag::Full | UVTFlag::None => (m, n), - UVTFlag::Some => (k, k), + JobSvd::All | JobSvd::None => (m, n), + JobSvd::Some => (k, k), }; let (mut u, mut vt) = match jobz { - UVTFlag::Full => ( + JobSvd::All => ( Some(unsafe { vec_uninit( (m * m) as usize) }), Some(unsafe { vec_uninit( (n * n) as usize) }), ), - UVTFlag::Some => ( + JobSvd::Some => ( Some(unsafe { vec_uninit( (m * u_col) as usize) }), Some(unsafe { vec_uninit( (n * vt_row) as usize) }), ), - UVTFlag::None => (None, None), + JobSvd::None => (None, None), }; $( // for complex only let mx = n.max(m) as usize; let mn = n.min(m) as usize; let lrwork = match jobz { - UVTFlag::None => 7 * mn, + JobSvd::None => 7 * mn, _ => std::cmp::max(5*mn*mn + 5*mn, 2*mx*mn + 2*mn*mn + mn), }; let mut $rwork_ident: Vec> = unsafe { vec_uninit( lrwork) }; diff --git a/lax/src/triangular.rs b/lax/src/triangular.rs index e8825758..14f29807 100644 --- a/lax/src/triangular.rs +++ b/lax/src/triangular.rs @@ -3,19 +3,6 @@ use crate::{error::*, layout::*, *}; use cauchy::*; -#[derive(Debug, Clone, Copy)] -#[repr(u8)] -pub enum Diag { - Unit = b'U', - NonUnit = b'N', -} - -impl Diag { - fn as_ptr(&self) -> *const i8 { - self as *const Diag as *const i8 - } -} - /// Wraps `*trtri` and `*trtrs` pub trait Triangular_: Scalar { fn solve_triangular( diff --git a/ndarray-linalg/src/svddc.rs b/ndarray-linalg/src/svddc.rs index ff73407f..0b0ae237 100644 --- a/ndarray-linalg/src/svddc.rs +++ b/ndarray-linalg/src/svddc.rs @@ -3,14 +3,14 @@ use super::{convert::*, error::*, layout::*, types::*}; use ndarray::*; -pub use lax::UVTFlag; +pub use lax::JobSvd; /// Singular-value decomposition of matrix (copying) by divide-and-conquer pub trait SVDDC { type U; type VT; type Sigma; - fn svddc(&self, uvt_flag: UVTFlag) -> Result<(Option, Self::Sigma, Option)>; + fn svddc(&self, uvt_flag: JobSvd) -> Result<(Option, Self::Sigma, Option)>; } /// Singular-value decomposition of matrix by divide-and-conquer @@ -20,7 +20,7 @@ pub trait SVDDCInto { type Sigma; fn svddc_into( self, - uvt_flag: UVTFlag, + uvt_flag: JobSvd, ) -> Result<(Option, Self::Sigma, Option)>; } @@ -31,7 +31,7 @@ pub trait SVDDCInplace { type Sigma; fn svddc_inplace( &mut self, - uvt_flag: UVTFlag, + uvt_flag: JobSvd, ) -> Result<(Option, Self::Sigma, Option)>; } @@ -44,7 +44,7 @@ where type VT = Array2; type Sigma = Array1; - fn svddc(&self, uvt_flag: UVTFlag) -> Result<(Option, Self::Sigma, Option)> { + fn svddc(&self, uvt_flag: JobSvd) -> Result<(Option, Self::Sigma, Option)> { self.to_owned().svddc_into(uvt_flag) } } @@ -60,7 +60,7 @@ where fn svddc_into( mut self, - uvt_flag: UVTFlag, + uvt_flag: JobSvd, ) -> Result<(Option, Self::Sigma, Option)> { self.svddc_inplace(uvt_flag) } @@ -77,7 +77,7 @@ where fn svddc_inplace( &mut self, - uvt_flag: UVTFlag, + uvt_flag: JobSvd, ) -> Result<(Option, Self::Sigma, Option)> { let l = self.layout()?; let svd_res = A::svddc(l, uvt_flag, self.as_allocated_mut()?)?; @@ -85,9 +85,9 @@ where let k = m.min(n); let (u_col, vt_row) = match uvt_flag { - UVTFlag::Full => (m, n), - UVTFlag::Some => (k, k), - UVTFlag::None => (0, 0), + JobSvd::All => (m, n), + JobSvd::Some => (k, k), + JobSvd::None => (0, 0), }; let u = svd_res diff --git a/ndarray-linalg/tests/svddc.rs b/ndarray-linalg/tests/svddc.rs index fb26c8d5..ed28fc30 100644 --- a/ndarray-linalg/tests/svddc.rs +++ b/ndarray-linalg/tests/svddc.rs @@ -1,16 +1,16 @@ use ndarray::*; use ndarray_linalg::*; -fn test(a: &Array2, flag: UVTFlag) { +fn test(a: &Array2, flag: JobSvd) { let (n, m) = a.dim(); let k = n.min(m); let answer = a.clone(); println!("a = \n{:?}", a); let (u, s, vt): (_, Array1<_>, _) = a.svddc(flag).unwrap(); let mut sm: Array2 = match flag { - UVTFlag::Full => Array::zeros((n, m)), - UVTFlag::Some => Array::zeros((k, k)), - UVTFlag::None => { + JobSvd::All => Array::zeros((n, m)), + JobSvd::Some => Array::zeros((k, k)), + JobSvd::None => { assert!(u.is_none()); assert!(vt.is_none()); return; @@ -33,37 +33,37 @@ macro_rules! test_svd_impl { #[test] fn []() { let a = random(($n, $m)); - test::<$scalar>(&a, UVTFlag::Full); + test::<$scalar>(&a, JobSvd::All); } #[test] fn []() { let a = random(($n, $m)); - test::<$scalar>(&a, UVTFlag::Some); + test::<$scalar>(&a, JobSvd::Some); } #[test] fn []() { let a = random(($n, $m)); - test::<$scalar>(&a, UVTFlag::None); + test::<$scalar>(&a, JobSvd::None); } #[test] fn []() { let a = random(($n, $m).f()); - test::<$scalar>(&a, UVTFlag::Full); + test::<$scalar>(&a, JobSvd::All); } #[test] fn []() { let a = random(($n, $m).f()); - test::<$scalar>(&a, UVTFlag::Some); + test::<$scalar>(&a, JobSvd::Some); } #[test] fn []() { let a = random(($n, $m).f()); - test::<$scalar>(&a, UVTFlag::None); + test::<$scalar>(&a, JobSvd::None); } } };