From 9109a2f6370e4e7e3f79df673cdf6e59ffbd6dfa Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Fri, 2 Sep 2022 17:17:04 +0900 Subject: [PATCH 1/8] Gather enum definitions --- lax/src/flags.rs | 145 ++++++++++++++++++++++++++++++++++++++++++ lax/src/lib.rs | 92 +-------------------------- lax/src/svd.rs | 25 +------- lax/src/svddc.rs | 20 ------ lax/src/triangular.rs | 13 ---- 5 files changed, 148 insertions(+), 147 deletions(-) create mode 100644 lax/src/flags.rs diff --git a/lax/src/flags.rs b/lax/src/flags.rs new file mode 100644 index 00000000..a0e146db --- /dev/null +++ b/lax/src/flags.rs @@ -0,0 +1,145 @@ +/// Upper/Lower specification for seveal usages +#[derive(Debug, Clone, Copy)] +#[repr(u8)] +pub enum UPLO { + Upper = b'U', + Lower = b'L', +} + +impl UPLO { + pub fn t(self) -> Self { + match self { + UPLO::Upper => UPLO::Lower, + UPLO::Lower => UPLO::Upper, + } + } + + /// To use Fortran LAPACK API in lapack-sys crate + pub fn as_ptr(&self) -> *const i8 { + self as *const UPLO as *const i8 + } +} + +#[derive(Debug, Clone, Copy)] +#[repr(u8)] +pub enum Transpose { + No = b'N', + Transpose = b'T', + Hermite = b'C', +} + +impl Transpose { + /// To use Fortran LAPACK API in lapack-sys crate + pub fn as_ptr(&self) -> *const i8 { + self as *const Transpose as *const i8 + } +} + +#[derive(Debug, Clone, Copy)] +#[repr(u8)] +pub enum NormType { + One = b'O', + Infinity = b'I', + Frobenius = b'F', +} + +impl NormType { + pub fn transpose(self) -> Self { + match self { + NormType::One => NormType::Infinity, + NormType::Infinity => NormType::One, + NormType::Frobenius => NormType::Frobenius, + } + } + + /// To use Fortran LAPACK API in lapack-sys crate + pub fn as_ptr(&self) -> *const i8 { + self as *const NormType as *const i8 + } +} + +/// Flag for calculating eigenvectors or not +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(u8)] +pub enum EigenVectorFlag { + Calc = b'V', + Not = b'N', +} + +impl EigenVectorFlag { + pub fn is_calc(&self) -> bool { + match self { + EigenVectorFlag::Calc => true, + EigenVectorFlag::Not => false, + } + } + + pub fn then T>(&self, f: F) -> Option { + if self.is_calc() { + Some(f()) + } else { + None + } + } + + /// To use Fortran LAPACK API in lapack-sys crate + pub fn as_ptr(&self) -> *const i8 { + self as *const EigenVectorFlag as *const i8 + } +} + +#[repr(u8)] +#[derive(Debug, Copy, Clone)] +pub enum FlagSVD { + All = b'A', + // OverWrite = b'O', + // Separately = b'S', + No = b'N', +} + +impl FlagSVD { + pub fn from_bool(calc_uv: bool) -> Self { + if calc_uv { + FlagSVD::All + } else { + FlagSVD::No + } + } + + pub fn as_ptr(&self) -> *const i8 { + self as *const FlagSVD as *const i8 + } +} + +/// Specifies how many of the columns of *U* and rows of *V*ᵀ are computed and returned. +/// +/// For an input array of shape *m*×*n*, the following are computed: +#[derive(Clone, Copy, Eq, PartialEq)] +#[repr(u8)] +pub enum UVTFlag { + /// All *m* columns of *U* and all *n* rows of *V*ᵀ. + Full = b'A', + /// The first min(*m*,*n*) columns of *U* and the first min(*m*,*n*) rows of *V*ᵀ. + Some = b'S', + /// No columns of *U* or rows of *V*ᵀ. + None = b'N', +} + +impl UVTFlag { + pub fn as_ptr(&self) -> *const i8 { + self as *const UVTFlag as *const i8 + } +} + +#[derive(Debug, Clone, Copy)] +#[repr(u8)] +pub enum Diag { + Unit = b'U', + NonUnit = b'N', +} + +impl Diag { + pub fn as_ptr(&self) -> *const i8 { + self as *const Diag as *const i8 + } +} diff --git a/lax/src/lib.rs b/lax/src/lib.rs index c8d2264d..0a6249ec 100644 --- a/lax/src/lib.rs +++ b/lax/src/lib.rs @@ -74,6 +74,7 @@ pub mod layout; mod cholesky; mod eig; mod eigh; +mod flags; mod least_squares; mod opnorm; mod qr; @@ -88,6 +89,7 @@ mod tridiagonal; pub use self::cholesky::*; pub use self::eig::*; pub use self::eigh::*; +pub use self::flags::*; pub use self::least_squares::*; pub use self::opnorm::*; pub use self::qr::*; @@ -173,96 +175,6 @@ impl VecAssumeInit for Vec> { } } -/// Upper/Lower specification for seveal usages -#[derive(Debug, Clone, Copy)] -#[repr(u8)] -pub enum UPLO { - Upper = b'U', - Lower = b'L', -} - -impl UPLO { - pub fn t(self) -> Self { - match self { - UPLO::Upper => UPLO::Lower, - UPLO::Lower => UPLO::Upper, - } - } - - /// To use Fortran LAPACK API in lapack-sys crate - pub fn as_ptr(&self) -> *const i8 { - self as *const UPLO as *const i8 - } -} - -#[derive(Debug, Clone, Copy)] -#[repr(u8)] -pub enum Transpose { - No = b'N', - Transpose = b'T', - Hermite = b'C', -} - -impl Transpose { - /// To use Fortran LAPACK API in lapack-sys crate - pub fn as_ptr(&self) -> *const i8 { - self as *const Transpose as *const i8 - } -} - -#[derive(Debug, Clone, Copy)] -#[repr(u8)] -pub enum NormType { - One = b'O', - Infinity = b'I', - Frobenius = b'F', -} - -impl NormType { - pub fn transpose(self) -> Self { - match self { - NormType::One => NormType::Infinity, - NormType::Infinity => NormType::One, - NormType::Frobenius => NormType::Frobenius, - } - } - - /// To use Fortran LAPACK API in lapack-sys crate - pub fn as_ptr(&self) -> *const i8 { - self as *const NormType as *const i8 - } -} - -/// Flag for calculating eigenvectors or not -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -#[repr(u8)] -pub enum EigenVectorFlag { - Calc = b'V', - Not = b'N', -} - -impl EigenVectorFlag { - pub fn is_calc(&self) -> bool { - match self { - EigenVectorFlag::Calc => true, - EigenVectorFlag::Not => false, - } - } - - pub fn then T>(&self, f: F) -> Option { - if self.is_calc() { - Some(f()) - } else { - None - } - } - - /// To use Fortran LAPACK API in lapack-sys crate - pub fn as_ptr(&self) -> *const i8 { - self as *const EigenVectorFlag as *const i8 - } -} - /// Create a vector without initialization /// /// Safety diff --git a/lax/src/svd.rs b/lax/src/svd.rs index 8c731c7a..c9e73f89 100644 --- a/lax/src/svd.rs +++ b/lax/src/svd.rs @@ -1,32 +1,9 @@ //! Singular-value decomposition -use crate::{error::*, layout::MatrixLayout, *}; +use super::{error::*, layout::*, *}; use cauchy::*; use num_traits::{ToPrimitive, Zero}; -#[repr(u8)] -#[derive(Debug, Copy, Clone)] -enum FlagSVD { - All = b'A', - // OverWrite = b'O', - // Separately = b'S', - No = b'N', -} - -impl FlagSVD { - fn from_bool(calc_uv: bool) -> Self { - if calc_uv { - FlagSVD::All - } else { - FlagSVD::No - } - } - - fn as_ptr(&self) -> *const i8 { - self as *const FlagSVD as *const i8 - } -} - /// Result of SVD pub struct SVDOutput { /// diagonal values diff --git a/lax/src/svddc.rs b/lax/src/svddc.rs index f956d848..e31f126f 100644 --- a/lax/src/svddc.rs +++ b/lax/src/svddc.rs @@ -2,26 +2,6 @@ use crate::{error::*, layout::MatrixLayout, *}; use cauchy::*; use num_traits::{ToPrimitive, Zero}; -/// Specifies how many of the columns of *U* and rows of *V*ᵀ are computed and returned. -/// -/// For an input array of shape *m*×*n*, the following are computed: -#[derive(Clone, Copy, Eq, PartialEq)] -#[repr(u8)] -pub enum UVTFlag { - /// All *m* columns of *U* and all *n* rows of *V*ᵀ. - Full = b'A', - /// The first min(*m*,*n*) columns of *U* and the first min(*m*,*n*) rows of *V*ᵀ. - Some = b'S', - /// No columns of *U* or rows of *V*ᵀ. - None = b'N', -} - -impl UVTFlag { - fn as_ptr(&self) -> *const i8 { - self as *const UVTFlag as *const i8 - } -} - pub trait SVDDC_: Scalar { fn svddc(l: MatrixLayout, jobz: UVTFlag, a: &mut [Self]) -> Result>; } diff --git a/lax/src/triangular.rs b/lax/src/triangular.rs index e8825758..14f29807 100644 --- a/lax/src/triangular.rs +++ b/lax/src/triangular.rs @@ -3,19 +3,6 @@ use crate::{error::*, layout::*, *}; use cauchy::*; -#[derive(Debug, Clone, Copy)] -#[repr(u8)] -pub enum Diag { - Unit = b'U', - NonUnit = b'N', -} - -impl Diag { - fn as_ptr(&self) -> *const i8 { - self as *const Diag as *const i8 - } -} - /// Wraps `*trtri` and `*trtrs` pub trait Triangular_: Scalar { fn solve_triangular( From 262a2c966e97abcdde957b5e5daa33babc66d6b9 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Fri, 2 Sep 2022 17:28:30 +0900 Subject: [PATCH 2/8] Make `lax::flags` module public --- lax/src/flags.rs | 2 ++ lax/src/lib.rs | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/lax/src/flags.rs b/lax/src/flags.rs index a0e146db..f3b3ecd0 100644 --- a/lax/src/flags.rs +++ b/lax/src/flags.rs @@ -1,3 +1,5 @@ +//! Charactor flags, e.g. `'T'`, used in LAPACK API + /// Upper/Lower specification for seveal usages #[derive(Debug, Clone, Copy)] #[repr(u8)] diff --git a/lax/src/lib.rs b/lax/src/lib.rs index 0a6249ec..83cf3658 100644 --- a/lax/src/lib.rs +++ b/lax/src/lib.rs @@ -69,12 +69,12 @@ extern crate openblas_src as _src; extern crate netlib_src as _src; pub mod error; +pub mod flags; pub mod layout; mod cholesky; mod eig; mod eigh; -mod flags; mod least_squares; mod opnorm; mod qr; From 4912a115a156d6fd25c786cc81260a49975f546b Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Sat, 3 Sep 2022 00:20:14 +0900 Subject: [PATCH 3/8] Rename `EigenVectorFlag` to `JobEv` --- lax/src/eig.rs | 12 ++++++------ lax/src/eigh.rs | 4 ++-- lax/src/flags.rs | 10 +++++----- 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/lax/src/eig.rs b/lax/src/eig.rs index f11f5287..184172a0 100644 --- a/lax/src/eig.rs +++ b/lax/src/eig.rs @@ -35,11 +35,11 @@ macro_rules! impl_eig_complex { // eigenvalues are the eigenvalues computed with `A`. let (jobvl, jobvr) = if calc_v { match l { - MatrixLayout::C { .. } => (EigenVectorFlag::Calc, EigenVectorFlag::Not), - MatrixLayout::F { .. } => (EigenVectorFlag::Not, EigenVectorFlag::Calc), + MatrixLayout::C { .. } => (JobEv::Calc, JobEv::Not), + MatrixLayout::F { .. } => (JobEv::Not, JobEv::Calc), } } else { - (EigenVectorFlag::Not, EigenVectorFlag::Not) + (JobEv::Not, JobEv::Not) }; let mut eigs: Vec> = unsafe { vec_uninit(n as usize) }; let mut rwork: Vec> = unsafe { vec_uninit(2 * n as usize) }; @@ -143,11 +143,11 @@ macro_rules! impl_eig_real { // `sgeev`/`dgeev`. let (jobvl, jobvr) = if calc_v { match l { - MatrixLayout::C { .. } => (EigenVectorFlag::Calc, EigenVectorFlag::Not), - MatrixLayout::F { .. } => (EigenVectorFlag::Not, EigenVectorFlag::Calc), + MatrixLayout::C { .. } => (JobEv::Calc, JobEv::Not), + MatrixLayout::F { .. } => (JobEv::Not, JobEv::Calc), } } else { - (EigenVectorFlag::Not, EigenVectorFlag::Not) + (JobEv::Not, JobEv::Not) }; let mut eig_re: Vec> = unsafe { vec_uninit(n as usize) }; let mut eig_im: Vec> = unsafe { vec_uninit(n as usize) }; diff --git a/lax/src/eigh.rs b/lax/src/eigh.rs index 0692f921..08e5f689 100644 --- a/lax/src/eigh.rs +++ b/lax/src/eigh.rs @@ -41,7 +41,7 @@ macro_rules! impl_eigh { ) -> Result> { assert_eq!(layout.len(), layout.lda()); let n = layout.len(); - let jobz = if calc_v { EigenVectorFlag::Calc } else { EigenVectorFlag::Not }; + let jobz = if calc_v { JobEv::Calc } else { JobEv::Not }; let mut eigs: Vec> = unsafe { vec_uninit(n as usize) }; $( @@ -100,7 +100,7 @@ macro_rules! impl_eigh { ) -> Result> { assert_eq!(layout.len(), layout.lda()); let n = layout.len(); - let jobz = if calc_v { EigenVectorFlag::Calc } else { EigenVectorFlag::Not }; + let jobz = if calc_v { JobEv::Calc } else { JobEv::Not }; let mut eigs: Vec> = unsafe { vec_uninit(n as usize) }; $( diff --git a/lax/src/flags.rs b/lax/src/flags.rs index f3b3ecd0..86d1d985 100644 --- a/lax/src/flags.rs +++ b/lax/src/flags.rs @@ -63,16 +63,16 @@ impl NormType { /// Flag for calculating eigenvectors or not #[derive(Debug, Clone, Copy, PartialEq, Eq)] #[repr(u8)] -pub enum EigenVectorFlag { +pub enum JobEv { Calc = b'V', Not = b'N', } -impl EigenVectorFlag { +impl JobEv { pub fn is_calc(&self) -> bool { match self { - EigenVectorFlag::Calc => true, - EigenVectorFlag::Not => false, + JobEv::Calc => true, + JobEv::Not => false, } } @@ -86,7 +86,7 @@ impl EigenVectorFlag { /// To use Fortran LAPACK API in lapack-sys crate pub fn as_ptr(&self) -> *const i8 { - self as *const EigenVectorFlag as *const i8 + self as *const JobEv as *const i8 } } From 62e40d134cf6f3eb07a277a4f43e48078d6ffeee Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Sat, 3 Sep 2022 14:40:14 +0900 Subject: [PATCH 4/8] Replace FlagSVD by UVTFlag --- lax/src/flags.rs | 31 ++++++++----------------------- lax/src/svd.rs | 18 ++++++++++-------- 2 files changed, 18 insertions(+), 31 deletions(-) diff --git a/lax/src/flags.rs b/lax/src/flags.rs index 86d1d985..23625ee3 100644 --- a/lax/src/flags.rs +++ b/lax/src/flags.rs @@ -90,29 +90,6 @@ impl JobEv { } } -#[repr(u8)] -#[derive(Debug, Copy, Clone)] -pub enum FlagSVD { - All = b'A', - // OverWrite = b'O', - // Separately = b'S', - No = b'N', -} - -impl FlagSVD { - pub fn from_bool(calc_uv: bool) -> Self { - if calc_uv { - FlagSVD::All - } else { - FlagSVD::No - } - } - - pub fn as_ptr(&self) -> *const i8 { - self as *const FlagSVD as *const i8 - } -} - /// Specifies how many of the columns of *U* and rows of *V*ᵀ are computed and returned. /// /// For an input array of shape *m*×*n*, the following are computed: @@ -128,6 +105,14 @@ pub enum UVTFlag { } impl UVTFlag { + pub fn from_bool(calc_uv: bool) -> Self { + if calc_uv { + UVTFlag::Full + } else { + UVTFlag::None + } + } + pub fn as_ptr(&self) -> *const i8 { self as *const UVTFlag as *const i8 } diff --git a/lax/src/svd.rs b/lax/src/svd.rs index c9e73f89..7807acc9 100644 --- a/lax/src/svd.rs +++ b/lax/src/svd.rs @@ -32,24 +32,26 @@ macro_rules! impl_svd { impl SVD_ for $scalar { fn svd(l: MatrixLayout, calc_u: bool, calc_vt: bool, a: &mut [Self],) -> Result> { let ju = match l { - MatrixLayout::F { .. } => FlagSVD::from_bool(calc_u), - MatrixLayout::C { .. } => FlagSVD::from_bool(calc_vt), + MatrixLayout::F { .. } => UVTFlag::from_bool(calc_u), + MatrixLayout::C { .. } => UVTFlag::from_bool(calc_vt), }; let jvt = match l { - MatrixLayout::F { .. } => FlagSVD::from_bool(calc_vt), - MatrixLayout::C { .. } => FlagSVD::from_bool(calc_u), + MatrixLayout::F { .. } => UVTFlag::from_bool(calc_vt), + MatrixLayout::C { .. } => UVTFlag::from_bool(calc_u), }; let m = l.lda(); let mut u = match ju { - FlagSVD::All => Some(unsafe { vec_uninit( (m * m) as usize) }), - FlagSVD::No => None, + UVTFlag::Full => Some(unsafe { vec_uninit( (m * m) as usize) }), + UVTFlag::None => None, + _ => unimplemented!("SVD with partial vector output is not supported yet") }; let n = l.len(); let mut vt = match jvt { - FlagSVD::All => Some(unsafe { vec_uninit( (n * n) as usize) }), - FlagSVD::No => None, + UVTFlag::Full => Some(unsafe { vec_uninit( (n * n) as usize) }), + UVTFlag::None => None, + _ => unimplemented!("SVD with partial vector output is not supported yet") }; let k = std::cmp::min(m, n); From 950b7dabdba51c53ac0cba8f4868636e4a2d7b13 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Sat, 3 Sep 2022 14:51:27 +0900 Subject: [PATCH 5/8] Rename UVTFlag to JobSvd, and `Full` to `All` --- lax/src/flags.rs | 12 ++++++------ lax/src/svd.rs | 16 ++++++++-------- lax/src/svddc.rs | 16 ++++++++-------- ndarray-linalg/src/svddc.rs | 20 ++++++++++---------- ndarray-linalg/tests/svddc.rs | 20 ++++++++++---------- 5 files changed, 42 insertions(+), 42 deletions(-) diff --git a/lax/src/flags.rs b/lax/src/flags.rs index 23625ee3..e8f5b9a3 100644 --- a/lax/src/flags.rs +++ b/lax/src/flags.rs @@ -95,26 +95,26 @@ impl JobEv { /// For an input array of shape *m*×*n*, the following are computed: #[derive(Clone, Copy, Eq, PartialEq)] #[repr(u8)] -pub enum UVTFlag { +pub enum JobSvd { /// All *m* columns of *U* and all *n* rows of *V*ᵀ. - Full = b'A', + All = b'A', /// The first min(*m*,*n*) columns of *U* and the first min(*m*,*n*) rows of *V*ᵀ. Some = b'S', /// No columns of *U* or rows of *V*ᵀ. None = b'N', } -impl UVTFlag { +impl JobSvd { pub fn from_bool(calc_uv: bool) -> Self { if calc_uv { - UVTFlag::Full + JobSvd::All } else { - UVTFlag::None + JobSvd::None } } pub fn as_ptr(&self) -> *const i8 { - self as *const UVTFlag as *const i8 + self as *const JobSvd as *const i8 } } diff --git a/lax/src/svd.rs b/lax/src/svd.rs index 7807acc9..0a509a0e 100644 --- a/lax/src/svd.rs +++ b/lax/src/svd.rs @@ -32,25 +32,25 @@ macro_rules! impl_svd { impl SVD_ for $scalar { fn svd(l: MatrixLayout, calc_u: bool, calc_vt: bool, a: &mut [Self],) -> Result> { let ju = match l { - MatrixLayout::F { .. } => UVTFlag::from_bool(calc_u), - MatrixLayout::C { .. } => UVTFlag::from_bool(calc_vt), + MatrixLayout::F { .. } => JobSvd::from_bool(calc_u), + MatrixLayout::C { .. } => JobSvd::from_bool(calc_vt), }; let jvt = match l { - MatrixLayout::F { .. } => UVTFlag::from_bool(calc_vt), - MatrixLayout::C { .. } => UVTFlag::from_bool(calc_u), + MatrixLayout::F { .. } => JobSvd::from_bool(calc_vt), + MatrixLayout::C { .. } => JobSvd::from_bool(calc_u), }; let m = l.lda(); let mut u = match ju { - UVTFlag::Full => Some(unsafe { vec_uninit( (m * m) as usize) }), - UVTFlag::None => None, + JobSvd::All => Some(unsafe { vec_uninit( (m * m) as usize) }), + JobSvd::None => None, _ => unimplemented!("SVD with partial vector output is not supported yet") }; let n = l.len(); let mut vt = match jvt { - UVTFlag::Full => Some(unsafe { vec_uninit( (n * n) as usize) }), - UVTFlag::None => None, + JobSvd::All => Some(unsafe { vec_uninit( (n * n) as usize) }), + JobSvd::None => None, _ => unimplemented!("SVD with partial vector output is not supported yet") }; diff --git a/lax/src/svddc.rs b/lax/src/svddc.rs index e31f126f..bb59348f 100644 --- a/lax/src/svddc.rs +++ b/lax/src/svddc.rs @@ -3,7 +3,7 @@ use cauchy::*; use num_traits::{ToPrimitive, Zero}; pub trait SVDDC_: Scalar { - fn svddc(l: MatrixLayout, jobz: UVTFlag, a: &mut [Self]) -> Result>; + fn svddc(l: MatrixLayout, jobz: JobSvd, a: &mut [Self]) -> Result>; } macro_rules! impl_svddc { @@ -15,33 +15,33 @@ macro_rules! impl_svddc { }; (@body, $scalar:ty, $gesdd:path, $($rwork_ident:ident),*) => { impl SVDDC_ for $scalar { - fn svddc(l: MatrixLayout, jobz: UVTFlag, a: &mut [Self],) -> Result> { + fn svddc(l: MatrixLayout, jobz: JobSvd, a: &mut [Self],) -> Result> { let m = l.lda(); let n = l.len(); let k = m.min(n); let mut s = unsafe { vec_uninit( k as usize) }; let (u_col, vt_row) = match jobz { - UVTFlag::Full | UVTFlag::None => (m, n), - UVTFlag::Some => (k, k), + JobSvd::All | JobSvd::None => (m, n), + JobSvd::Some => (k, k), }; let (mut u, mut vt) = match jobz { - UVTFlag::Full => ( + JobSvd::All => ( Some(unsafe { vec_uninit( (m * m) as usize) }), Some(unsafe { vec_uninit( (n * n) as usize) }), ), - UVTFlag::Some => ( + JobSvd::Some => ( Some(unsafe { vec_uninit( (m * u_col) as usize) }), Some(unsafe { vec_uninit( (n * vt_row) as usize) }), ), - UVTFlag::None => (None, None), + JobSvd::None => (None, None), }; $( // for complex only let mx = n.max(m) as usize; let mn = n.min(m) as usize; let lrwork = match jobz { - UVTFlag::None => 7 * mn, + JobSvd::None => 7 * mn, _ => std::cmp::max(5*mn*mn + 5*mn, 2*mx*mn + 2*mn*mn + mn), }; let mut $rwork_ident: Vec> = unsafe { vec_uninit( lrwork) }; diff --git a/ndarray-linalg/src/svddc.rs b/ndarray-linalg/src/svddc.rs index ff73407f..0b0ae237 100644 --- a/ndarray-linalg/src/svddc.rs +++ b/ndarray-linalg/src/svddc.rs @@ -3,14 +3,14 @@ use super::{convert::*, error::*, layout::*, types::*}; use ndarray::*; -pub use lax::UVTFlag; +pub use lax::JobSvd; /// Singular-value decomposition of matrix (copying) by divide-and-conquer pub trait SVDDC { type U; type VT; type Sigma; - fn svddc(&self, uvt_flag: UVTFlag) -> Result<(Option, Self::Sigma, Option)>; + fn svddc(&self, uvt_flag: JobSvd) -> Result<(Option, Self::Sigma, Option)>; } /// Singular-value decomposition of matrix by divide-and-conquer @@ -20,7 +20,7 @@ pub trait SVDDCInto { type Sigma; fn svddc_into( self, - uvt_flag: UVTFlag, + uvt_flag: JobSvd, ) -> Result<(Option, Self::Sigma, Option)>; } @@ -31,7 +31,7 @@ pub trait SVDDCInplace { type Sigma; fn svddc_inplace( &mut self, - uvt_flag: UVTFlag, + uvt_flag: JobSvd, ) -> Result<(Option, Self::Sigma, Option)>; } @@ -44,7 +44,7 @@ where type VT = Array2; type Sigma = Array1; - fn svddc(&self, uvt_flag: UVTFlag) -> Result<(Option, Self::Sigma, Option)> { + fn svddc(&self, uvt_flag: JobSvd) -> Result<(Option, Self::Sigma, Option)> { self.to_owned().svddc_into(uvt_flag) } } @@ -60,7 +60,7 @@ where fn svddc_into( mut self, - uvt_flag: UVTFlag, + uvt_flag: JobSvd, ) -> Result<(Option, Self::Sigma, Option)> { self.svddc_inplace(uvt_flag) } @@ -77,7 +77,7 @@ where fn svddc_inplace( &mut self, - uvt_flag: UVTFlag, + uvt_flag: JobSvd, ) -> Result<(Option, Self::Sigma, Option)> { let l = self.layout()?; let svd_res = A::svddc(l, uvt_flag, self.as_allocated_mut()?)?; @@ -85,9 +85,9 @@ where let k = m.min(n); let (u_col, vt_row) = match uvt_flag { - UVTFlag::Full => (m, n), - UVTFlag::Some => (k, k), - UVTFlag::None => (0, 0), + JobSvd::All => (m, n), + JobSvd::Some => (k, k), + JobSvd::None => (0, 0), }; let u = svd_res diff --git a/ndarray-linalg/tests/svddc.rs b/ndarray-linalg/tests/svddc.rs index fb26c8d5..ed28fc30 100644 --- a/ndarray-linalg/tests/svddc.rs +++ b/ndarray-linalg/tests/svddc.rs @@ -1,16 +1,16 @@ use ndarray::*; use ndarray_linalg::*; -fn test(a: &Array2, flag: UVTFlag) { +fn test(a: &Array2, flag: JobSvd) { let (n, m) = a.dim(); let k = n.min(m); let answer = a.clone(); println!("a = \n{:?}", a); let (u, s, vt): (_, Array1<_>, _) = a.svddc(flag).unwrap(); let mut sm: Array2 = match flag { - UVTFlag::Full => Array::zeros((n, m)), - UVTFlag::Some => Array::zeros((k, k)), - UVTFlag::None => { + JobSvd::All => Array::zeros((n, m)), + JobSvd::Some => Array::zeros((k, k)), + JobSvd::None => { assert!(u.is_none()); assert!(vt.is_none()); return; @@ -33,37 +33,37 @@ macro_rules! test_svd_impl { #[test] fn []() { let a = random(($n, $m)); - test::<$scalar>(&a, UVTFlag::Full); + test::<$scalar>(&a, JobSvd::All); } #[test] fn []() { let a = random(($n, $m)); - test::<$scalar>(&a, UVTFlag::Some); + test::<$scalar>(&a, JobSvd::Some); } #[test] fn []() { let a = random(($n, $m)); - test::<$scalar>(&a, UVTFlag::None); + test::<$scalar>(&a, JobSvd::None); } #[test] fn []() { let a = random(($n, $m).f()); - test::<$scalar>(&a, UVTFlag::Full); + test::<$scalar>(&a, JobSvd::All); } #[test] fn []() { let a = random(($n, $m).f()); - test::<$scalar>(&a, UVTFlag::Some); + test::<$scalar>(&a, JobSvd::Some); } #[test] fn []() { let a = random(($n, $m).f()); - test::<$scalar>(&a, UVTFlag::None); + test::<$scalar>(&a, JobSvd::None); } } }; From b029bfbf1745e9ffcdf98bae7fee7d91cd106326 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Sat, 3 Sep 2022 15:10:38 +0900 Subject: [PATCH 6/8] Rename JobEv::{Calc, Not} to {All, None} to match JobSvd --- lax/src/eig.rs | 12 ++++++------ lax/src/eigh.rs | 4 ++-- lax/src/flags.rs | 10 ++++++---- 3 files changed, 14 insertions(+), 12 deletions(-) diff --git a/lax/src/eig.rs b/lax/src/eig.rs index 184172a0..d3c07222 100644 --- a/lax/src/eig.rs +++ b/lax/src/eig.rs @@ -35,11 +35,11 @@ macro_rules! impl_eig_complex { // eigenvalues are the eigenvalues computed with `A`. let (jobvl, jobvr) = if calc_v { match l { - MatrixLayout::C { .. } => (JobEv::Calc, JobEv::Not), - MatrixLayout::F { .. } => (JobEv::Not, JobEv::Calc), + MatrixLayout::C { .. } => (JobEv::All, JobEv::None), + MatrixLayout::F { .. } => (JobEv::None, JobEv::All), } } else { - (JobEv::Not, JobEv::Not) + (JobEv::None, JobEv::None) }; let mut eigs: Vec> = unsafe { vec_uninit(n as usize) }; let mut rwork: Vec> = unsafe { vec_uninit(2 * n as usize) }; @@ -143,11 +143,11 @@ macro_rules! impl_eig_real { // `sgeev`/`dgeev`. let (jobvl, jobvr) = if calc_v { match l { - MatrixLayout::C { .. } => (JobEv::Calc, JobEv::Not), - MatrixLayout::F { .. } => (JobEv::Not, JobEv::Calc), + MatrixLayout::C { .. } => (JobEv::All, JobEv::None), + MatrixLayout::F { .. } => (JobEv::None, JobEv::All), } } else { - (JobEv::Not, JobEv::Not) + (JobEv::None, JobEv::None) }; let mut eig_re: Vec> = unsafe { vec_uninit(n as usize) }; let mut eig_im: Vec> = unsafe { vec_uninit(n as usize) }; diff --git a/lax/src/eigh.rs b/lax/src/eigh.rs index 08e5f689..a9406ee6 100644 --- a/lax/src/eigh.rs +++ b/lax/src/eigh.rs @@ -41,7 +41,7 @@ macro_rules! impl_eigh { ) -> Result> { assert_eq!(layout.len(), layout.lda()); let n = layout.len(); - let jobz = if calc_v { JobEv::Calc } else { JobEv::Not }; + let jobz = if calc_v { JobEv::All } else { JobEv::None }; let mut eigs: Vec> = unsafe { vec_uninit(n as usize) }; $( @@ -100,7 +100,7 @@ macro_rules! impl_eigh { ) -> Result> { assert_eq!(layout.len(), layout.lda()); let n = layout.len(); - let jobz = if calc_v { JobEv::Calc } else { JobEv::Not }; + let jobz = if calc_v { JobEv::All } else { JobEv::None }; let mut eigs: Vec> = unsafe { vec_uninit(n as usize) }; $( diff --git a/lax/src/flags.rs b/lax/src/flags.rs index e8f5b9a3..cdf7faf5 100644 --- a/lax/src/flags.rs +++ b/lax/src/flags.rs @@ -64,15 +64,17 @@ impl NormType { #[derive(Debug, Clone, Copy, PartialEq, Eq)] #[repr(u8)] pub enum JobEv { - Calc = b'V', - Not = b'N', + /// Calculate eigenvectors in addition to eigenvalues + All = b'V', + /// Do not calculate eigenvectors. Only calculate eigenvalues. + None = b'N', } impl JobEv { pub fn is_calc(&self) -> bool { match self { - JobEv::Calc => true, - JobEv::Not => false, + JobEv::All => true, + JobEv::None => false, } } From ef80e1a66a2da56b124bb8b83964a5bcda48f8df Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Sat, 3 Sep 2022 15:17:36 +0900 Subject: [PATCH 7/8] Document for Diag --- lax/src/flags.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/lax/src/flags.rs b/lax/src/flags.rs index cdf7faf5..4d4f9bd1 100644 --- a/lax/src/flags.rs +++ b/lax/src/flags.rs @@ -120,10 +120,13 @@ impl JobSvd { } } +/// Specify whether input triangular matrix is unit or not #[derive(Debug, Clone, Copy)] #[repr(u8)] pub enum Diag { + /// Unit triangular matrix, i.e. all diagonal elements of the matrix are `1` Unit = b'U', + /// Non-unit triangular matrix. Its diagonal elements may be different from `1` NonUnit = b'N', } From 86c61c39556949ff19fd297296a4105608fc1449 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Sat, 3 Sep 2022 15:19:22 +0900 Subject: [PATCH 8/8] More auto-derives --- lax/src/flags.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/lax/src/flags.rs b/lax/src/flags.rs index 4d4f9bd1..37a11b3c 100644 --- a/lax/src/flags.rs +++ b/lax/src/flags.rs @@ -1,7 +1,7 @@ //! Charactor flags, e.g. `'T'`, used in LAPACK API /// Upper/Lower specification for seveal usages -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] #[repr(u8)] pub enum UPLO { Upper = b'U', @@ -22,7 +22,7 @@ impl UPLO { } } -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] #[repr(u8)] pub enum Transpose { No = b'N', @@ -37,7 +37,7 @@ impl Transpose { } } -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] #[repr(u8)] pub enum NormType { One = b'O', @@ -61,7 +61,7 @@ impl NormType { } /// Flag for calculating eigenvectors or not -#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] #[repr(u8)] pub enum JobEv { /// Calculate eigenvectors in addition to eigenvalues @@ -95,7 +95,7 @@ impl JobEv { /// Specifies how many of the columns of *U* and rows of *V*ᵀ are computed and returned. /// /// For an input array of shape *m*×*n*, the following are computed: -#[derive(Clone, Copy, Eq, PartialEq)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] #[repr(u8)] pub enum JobSvd { /// All *m* columns of *U* and all *n* rows of *V*ᵀ. @@ -121,7 +121,7 @@ impl JobSvd { } /// Specify whether input triangular matrix is unit or not -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] #[repr(u8)] pub enum Diag { /// Unit triangular matrix, i.e. all diagonal elements of the matrix are `1`