diff --git a/lax/src/flags.rs b/lax/src/flags.rs new file mode 100644 index 00000000..a0e146db --- /dev/null +++ b/lax/src/flags.rs @@ -0,0 +1,145 @@ +/// Upper/Lower specification for seveal usages +#[derive(Debug, Clone, Copy)] +#[repr(u8)] +pub enum UPLO { + Upper = b'U', + Lower = b'L', +} + +impl UPLO { + pub fn t(self) -> Self { + match self { + UPLO::Upper => UPLO::Lower, + UPLO::Lower => UPLO::Upper, + } + } + + /// To use Fortran LAPACK API in lapack-sys crate + pub fn as_ptr(&self) -> *const i8 { + self as *const UPLO as *const i8 + } +} + +#[derive(Debug, Clone, Copy)] +#[repr(u8)] +pub enum Transpose { + No = b'N', + Transpose = b'T', + Hermite = b'C', +} + +impl Transpose { + /// To use Fortran LAPACK API in lapack-sys crate + pub fn as_ptr(&self) -> *const i8 { + self as *const Transpose as *const i8 + } +} + +#[derive(Debug, Clone, Copy)] +#[repr(u8)] +pub enum NormType { + One = b'O', + Infinity = b'I', + Frobenius = b'F', +} + +impl NormType { + pub fn transpose(self) -> Self { + match self { + NormType::One => NormType::Infinity, + NormType::Infinity => NormType::One, + NormType::Frobenius => NormType::Frobenius, + } + } + + /// To use Fortran LAPACK API in lapack-sys crate + pub fn as_ptr(&self) -> *const i8 { + self as *const NormType as *const i8 + } +} + +/// Flag for calculating eigenvectors or not +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(u8)] +pub enum EigenVectorFlag { + Calc = b'V', + Not = b'N', +} + +impl EigenVectorFlag { + pub fn is_calc(&self) -> bool { + match self { + EigenVectorFlag::Calc => true, + EigenVectorFlag::Not => false, + } + } + + pub fn then T>(&self, f: F) -> Option { + if self.is_calc() { + Some(f()) + } else { + None + } + } + + /// To use Fortran LAPACK API in lapack-sys crate + pub fn as_ptr(&self) -> *const i8 { + self as *const EigenVectorFlag as *const i8 + } +} + +#[repr(u8)] +#[derive(Debug, Copy, Clone)] +pub enum FlagSVD { + All = b'A', + // OverWrite = b'O', + // Separately = b'S', + No = b'N', +} + +impl FlagSVD { + pub fn from_bool(calc_uv: bool) -> Self { + if calc_uv { + FlagSVD::All + } else { + FlagSVD::No + } + } + + pub fn as_ptr(&self) -> *const i8 { + self as *const FlagSVD as *const i8 + } +} + +/// Specifies how many of the columns of *U* and rows of *V*ᵀ are computed and returned. +/// +/// For an input array of shape *m*×*n*, the following are computed: +#[derive(Clone, Copy, Eq, PartialEq)] +#[repr(u8)] +pub enum UVTFlag { + /// All *m* columns of *U* and all *n* rows of *V*ᵀ. + Full = b'A', + /// The first min(*m*,*n*) columns of *U* and the first min(*m*,*n*) rows of *V*ᵀ. + Some = b'S', + /// No columns of *U* or rows of *V*ᵀ. + None = b'N', +} + +impl UVTFlag { + pub fn as_ptr(&self) -> *const i8 { + self as *const UVTFlag as *const i8 + } +} + +#[derive(Debug, Clone, Copy)] +#[repr(u8)] +pub enum Diag { + Unit = b'U', + NonUnit = b'N', +} + +impl Diag { + pub fn as_ptr(&self) -> *const i8 { + self as *const Diag as *const i8 + } +} diff --git a/lax/src/lib.rs b/lax/src/lib.rs index c8d2264d..ee611893 100644 --- a/lax/src/lib.rs +++ b/lax/src/lib.rs @@ -74,6 +74,7 @@ pub mod layout; mod cholesky; mod eig; mod eigh; +mod flags; mod least_squares; mod opnorm; mod qr; @@ -85,19 +86,20 @@ mod svddc; mod triangular; mod tridiagonal; -pub use self::cholesky::*; -pub use self::eig::*; -pub use self::eigh::*; -pub use self::least_squares::*; -pub use self::opnorm::*; -pub use self::qr::*; -pub use self::rcond::*; -pub use self::solve::*; -pub use self::solveh::*; -pub use self::svd::*; -pub use self::svddc::*; -pub use self::triangular::*; -pub use self::tridiagonal::*; +pub use cholesky::*; +pub use eig::*; +pub use eigh::*; +pub use flags::*; +pub use least_squares::*; +pub use opnorm::*; +pub use qr::*; +pub use rcond::*; +pub use solve::*; +pub use solveh::*; +pub use svd::*; +pub use svddc::*; +pub use triangular::*; +pub use tridiagonal::*; use cauchy::*; use std::mem::MaybeUninit; @@ -173,96 +175,6 @@ impl VecAssumeInit for Vec> { } } -/// Upper/Lower specification for seveal usages -#[derive(Debug, Clone, Copy)] -#[repr(u8)] -pub enum UPLO { - Upper = b'U', - Lower = b'L', -} - -impl UPLO { - pub fn t(self) -> Self { - match self { - UPLO::Upper => UPLO::Lower, - UPLO::Lower => UPLO::Upper, - } - } - - /// To use Fortran LAPACK API in lapack-sys crate - pub fn as_ptr(&self) -> *const i8 { - self as *const UPLO as *const i8 - } -} - -#[derive(Debug, Clone, Copy)] -#[repr(u8)] -pub enum Transpose { - No = b'N', - Transpose = b'T', - Hermite = b'C', -} - -impl Transpose { - /// To use Fortran LAPACK API in lapack-sys crate - pub fn as_ptr(&self) -> *const i8 { - self as *const Transpose as *const i8 - } -} - -#[derive(Debug, Clone, Copy)] -#[repr(u8)] -pub enum NormType { - One = b'O', - Infinity = b'I', - Frobenius = b'F', -} - -impl NormType { - pub fn transpose(self) -> Self { - match self { - NormType::One => NormType::Infinity, - NormType::Infinity => NormType::One, - NormType::Frobenius => NormType::Frobenius, - } - } - - /// To use Fortran LAPACK API in lapack-sys crate - pub fn as_ptr(&self) -> *const i8 { - self as *const NormType as *const i8 - } -} - -/// Flag for calculating eigenvectors or not -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -#[repr(u8)] -pub enum EigenVectorFlag { - Calc = b'V', - Not = b'N', -} - -impl EigenVectorFlag { - pub fn is_calc(&self) -> bool { - match self { - EigenVectorFlag::Calc => true, - EigenVectorFlag::Not => false, - } - } - - pub fn then T>(&self, f: F) -> Option { - if self.is_calc() { - Some(f()) - } else { - None - } - } - - /// To use Fortran LAPACK API in lapack-sys crate - pub fn as_ptr(&self) -> *const i8 { - self as *const EigenVectorFlag as *const i8 - } -} - /// Create a vector without initialization /// /// Safety diff --git a/lax/src/svd.rs b/lax/src/svd.rs index 8c731c7a..c9e73f89 100644 --- a/lax/src/svd.rs +++ b/lax/src/svd.rs @@ -1,32 +1,9 @@ //! Singular-value decomposition -use crate::{error::*, layout::MatrixLayout, *}; +use super::{error::*, layout::*, *}; use cauchy::*; use num_traits::{ToPrimitive, Zero}; -#[repr(u8)] -#[derive(Debug, Copy, Clone)] -enum FlagSVD { - All = b'A', - // OverWrite = b'O', - // Separately = b'S', - No = b'N', -} - -impl FlagSVD { - fn from_bool(calc_uv: bool) -> Self { - if calc_uv { - FlagSVD::All - } else { - FlagSVD::No - } - } - - fn as_ptr(&self) -> *const i8 { - self as *const FlagSVD as *const i8 - } -} - /// Result of SVD pub struct SVDOutput { /// diagonal values diff --git a/lax/src/svddc.rs b/lax/src/svddc.rs index f956d848..e31f126f 100644 --- a/lax/src/svddc.rs +++ b/lax/src/svddc.rs @@ -2,26 +2,6 @@ use crate::{error::*, layout::MatrixLayout, *}; use cauchy::*; use num_traits::{ToPrimitive, Zero}; -/// Specifies how many of the columns of *U* and rows of *V*ᵀ are computed and returned. -/// -/// For an input array of shape *m*×*n*, the following are computed: -#[derive(Clone, Copy, Eq, PartialEq)] -#[repr(u8)] -pub enum UVTFlag { - /// All *m* columns of *U* and all *n* rows of *V*ᵀ. - Full = b'A', - /// The first min(*m*,*n*) columns of *U* and the first min(*m*,*n*) rows of *V*ᵀ. - Some = b'S', - /// No columns of *U* or rows of *V*ᵀ. - None = b'N', -} - -impl UVTFlag { - fn as_ptr(&self) -> *const i8 { - self as *const UVTFlag as *const i8 - } -} - pub trait SVDDC_: Scalar { fn svddc(l: MatrixLayout, jobz: UVTFlag, a: &mut [Self]) -> Result>; } diff --git a/lax/src/triangular.rs b/lax/src/triangular.rs index e8825758..14f29807 100644 --- a/lax/src/triangular.rs +++ b/lax/src/triangular.rs @@ -3,19 +3,6 @@ use crate::{error::*, layout::*, *}; use cauchy::*; -#[derive(Debug, Clone, Copy)] -#[repr(u8)] -pub enum Diag { - Unit = b'U', - NonUnit = b'N', -} - -impl Diag { - fn as_ptr(&self) -> *const i8 { - self as *const Diag as *const i8 - } -} - /// Wraps `*trtri` and `*trtrs` pub trait Triangular_: Scalar { fn solve_triangular(