Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revise enum namings #331

Merged
merged 8 commits into from Sep 3, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
12 changes: 6 additions & 6 deletions lax/src/eig.rs
Expand Up @@ -35,11 +35,11 @@ macro_rules! impl_eig_complex {
// eigenvalues are the eigenvalues computed with `A`.
let (jobvl, jobvr) = if calc_v {
match l {
MatrixLayout::C { .. } => (EigenVectorFlag::Calc, EigenVectorFlag::Not),
MatrixLayout::F { .. } => (EigenVectorFlag::Not, EigenVectorFlag::Calc),
MatrixLayout::C { .. } => (JobEv::All, JobEv::None),
MatrixLayout::F { .. } => (JobEv::None, JobEv::All),
}
} else {
(EigenVectorFlag::Not, EigenVectorFlag::Not)
(JobEv::None, JobEv::None)
};
let mut eigs: Vec<MaybeUninit<Self>> = unsafe { vec_uninit(n as usize) };
let mut rwork: Vec<MaybeUninit<Self::Real>> = unsafe { vec_uninit(2 * n as usize) };
Expand Down Expand Up @@ -143,11 +143,11 @@ macro_rules! impl_eig_real {
// `sgeev`/`dgeev`.
let (jobvl, jobvr) = if calc_v {
match l {
MatrixLayout::C { .. } => (EigenVectorFlag::Calc, EigenVectorFlag::Not),
MatrixLayout::F { .. } => (EigenVectorFlag::Not, EigenVectorFlag::Calc),
MatrixLayout::C { .. } => (JobEv::All, JobEv::None),
MatrixLayout::F { .. } => (JobEv::None, JobEv::All),
}
} else {
(EigenVectorFlag::Not, EigenVectorFlag::Not)
(JobEv::None, JobEv::None)
};
let mut eig_re: Vec<MaybeUninit<Self>> = unsafe { vec_uninit(n as usize) };
let mut eig_im: Vec<MaybeUninit<Self>> = unsafe { vec_uninit(n as usize) };
Expand Down
4 changes: 2 additions & 2 deletions lax/src/eigh.rs
Expand Up @@ -41,7 +41,7 @@ macro_rules! impl_eigh {
) -> Result<Vec<Self::Real>> {
assert_eq!(layout.len(), layout.lda());
let n = layout.len();
let jobz = if calc_v { EigenVectorFlag::Calc } else { EigenVectorFlag::Not };
let jobz = if calc_v { JobEv::All } else { JobEv::None };
let mut eigs: Vec<MaybeUninit<Self::Real>> = unsafe { vec_uninit(n as usize) };

$(
Expand Down Expand Up @@ -100,7 +100,7 @@ macro_rules! impl_eigh {
) -> Result<Vec<Self::Real>> {
assert_eq!(layout.len(), layout.lda());
let n = layout.len();
let jobz = if calc_v { EigenVectorFlag::Calc } else { EigenVectorFlag::Not };
let jobz = if calc_v { JobEv::All } else { JobEv::None };
let mut eigs: Vec<MaybeUninit<Self::Real>> = unsafe { vec_uninit(n as usize) };

$(
Expand Down
137 changes: 137 additions & 0 deletions lax/src/flags.rs
@@ -0,0 +1,137 @@
//! Charactor flags, e.g. `'T'`, used in LAPACK API

/// Upper/Lower specification for seveal usages
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(u8)]
pub enum UPLO {
Upper = b'U',
Lower = b'L',
}

impl UPLO {
pub fn t(self) -> Self {
match self {
UPLO::Upper => UPLO::Lower,
UPLO::Lower => UPLO::Upper,
}
}

/// To use Fortran LAPACK API in lapack-sys crate
pub fn as_ptr(&self) -> *const i8 {
self as *const UPLO as *const i8
}
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(u8)]
pub enum Transpose {
No = b'N',
Transpose = b'T',
Hermite = b'C',
}

impl Transpose {
/// To use Fortran LAPACK API in lapack-sys crate
pub fn as_ptr(&self) -> *const i8 {
self as *const Transpose as *const i8
}
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(u8)]
pub enum NormType {
One = b'O',
Infinity = b'I',
Frobenius = b'F',
}

impl NormType {
pub fn transpose(self) -> Self {
match self {
NormType::One => NormType::Infinity,
NormType::Infinity => NormType::One,
NormType::Frobenius => NormType::Frobenius,
}
}

/// To use Fortran LAPACK API in lapack-sys crate
pub fn as_ptr(&self) -> *const i8 {
self as *const NormType as *const i8
}
}

/// Flag for calculating eigenvectors or not
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(u8)]
pub enum JobEv {
/// Calculate eigenvectors in addition to eigenvalues
All = b'V',
/// Do not calculate eigenvectors. Only calculate eigenvalues.
None = b'N',
}

impl JobEv {
pub fn is_calc(&self) -> bool {
match self {
JobEv::All => true,
JobEv::None => false,
}
}

pub fn then<T, F: FnOnce() -> T>(&self, f: F) -> Option<T> {
if self.is_calc() {
Some(f())
} else {
None
}
}

/// To use Fortran LAPACK API in lapack-sys crate
pub fn as_ptr(&self) -> *const i8 {
self as *const JobEv as *const i8
}
}

/// Specifies how many of the columns of *U* and rows of *V*ᵀ are computed and returned.
///
/// For an input array of shape *m*×*n*, the following are computed:
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(u8)]
pub enum JobSvd {
/// All *m* columns of *U* and all *n* rows of *V*ᵀ.
All = b'A',
/// The first min(*m*,*n*) columns of *U* and the first min(*m*,*n*) rows of *V*ᵀ.
Some = b'S',
/// No columns of *U* or rows of *V*ᵀ.
None = b'N',
}

impl JobSvd {
pub fn from_bool(calc_uv: bool) -> Self {
if calc_uv {
JobSvd::All
} else {
JobSvd::None
}
}

pub fn as_ptr(&self) -> *const i8 {
self as *const JobSvd as *const i8
}
}

/// Specify whether input triangular matrix is unit or not
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(u8)]
pub enum Diag {
/// Unit triangular matrix, i.e. all diagonal elements of the matrix are `1`
Unit = b'U',
/// Non-unit triangular matrix. Its diagonal elements may be different from `1`
NonUnit = b'N',
}

impl Diag {
pub fn as_ptr(&self) -> *const i8 {
self as *const Diag as *const i8
}
}
92 changes: 2 additions & 90 deletions lax/src/lib.rs
Expand Up @@ -69,6 +69,7 @@ extern crate openblas_src as _src;
extern crate netlib_src as _src;

pub mod error;
pub mod flags;
pub mod layout;

mod cholesky;
Expand All @@ -88,6 +89,7 @@ mod tridiagonal;
pub use self::cholesky::*;
pub use self::eig::*;
pub use self::eigh::*;
pub use self::flags::*;
pub use self::least_squares::*;
pub use self::opnorm::*;
pub use self::qr::*;
Expand Down Expand Up @@ -173,96 +175,6 @@ impl<T> VecAssumeInit for Vec<MaybeUninit<T>> {
}
}

/// Upper/Lower specification for seveal usages
#[derive(Debug, Clone, Copy)]
#[repr(u8)]
pub enum UPLO {
Upper = b'U',
Lower = b'L',
}

impl UPLO {
pub fn t(self) -> Self {
match self {
UPLO::Upper => UPLO::Lower,
UPLO::Lower => UPLO::Upper,
}
}

/// To use Fortran LAPACK API in lapack-sys crate
pub fn as_ptr(&self) -> *const i8 {
self as *const UPLO as *const i8
}
}

#[derive(Debug, Clone, Copy)]
#[repr(u8)]
pub enum Transpose {
No = b'N',
Transpose = b'T',
Hermite = b'C',
}

impl Transpose {
/// To use Fortran LAPACK API in lapack-sys crate
pub fn as_ptr(&self) -> *const i8 {
self as *const Transpose as *const i8
}
}

#[derive(Debug, Clone, Copy)]
#[repr(u8)]
pub enum NormType {
One = b'O',
Infinity = b'I',
Frobenius = b'F',
}

impl NormType {
pub fn transpose(self) -> Self {
match self {
NormType::One => NormType::Infinity,
NormType::Infinity => NormType::One,
NormType::Frobenius => NormType::Frobenius,
}
}

/// To use Fortran LAPACK API in lapack-sys crate
pub fn as_ptr(&self) -> *const i8 {
self as *const NormType as *const i8
}
}

/// Flag for calculating eigenvectors or not
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum EigenVectorFlag {
Calc = b'V',
Not = b'N',
}

impl EigenVectorFlag {
pub fn is_calc(&self) -> bool {
match self {
EigenVectorFlag::Calc => true,
EigenVectorFlag::Not => false,
}
}

pub fn then<T, F: FnOnce() -> T>(&self, f: F) -> Option<T> {
if self.is_calc() {
Some(f())
} else {
None
}
}

/// To use Fortran LAPACK API in lapack-sys crate
pub fn as_ptr(&self) -> *const i8 {
self as *const EigenVectorFlag as *const i8
}
}

/// Create a vector without initialization
///
/// Safety
Expand Down
43 changes: 11 additions & 32 deletions lax/src/svd.rs
@@ -1,32 +1,9 @@
//! Singular-value decomposition

use crate::{error::*, layout::MatrixLayout, *};
use super::{error::*, layout::*, *};
use cauchy::*;
use num_traits::{ToPrimitive, Zero};

#[repr(u8)]
#[derive(Debug, Copy, Clone)]
enum FlagSVD {
All = b'A',
// OverWrite = b'O',
// Separately = b'S',
No = b'N',
}

impl FlagSVD {
fn from_bool(calc_uv: bool) -> Self {
if calc_uv {
FlagSVD::All
} else {
FlagSVD::No
}
}

fn as_ptr(&self) -> *const i8 {
self as *const FlagSVD as *const i8
}
}

/// Result of SVD
pub struct SVDOutput<A: Scalar> {
/// diagonal values
Expand Down Expand Up @@ -55,24 +32,26 @@ macro_rules! impl_svd {
impl SVD_ for $scalar {
fn svd(l: MatrixLayout, calc_u: bool, calc_vt: bool, a: &mut [Self],) -> Result<SVDOutput<Self>> {
let ju = match l {
MatrixLayout::F { .. } => FlagSVD::from_bool(calc_u),
MatrixLayout::C { .. } => FlagSVD::from_bool(calc_vt),
MatrixLayout::F { .. } => JobSvd::from_bool(calc_u),
MatrixLayout::C { .. } => JobSvd::from_bool(calc_vt),
};
let jvt = match l {
MatrixLayout::F { .. } => FlagSVD::from_bool(calc_vt),
MatrixLayout::C { .. } => FlagSVD::from_bool(calc_u),
MatrixLayout::F { .. } => JobSvd::from_bool(calc_vt),
MatrixLayout::C { .. } => JobSvd::from_bool(calc_u),
};

let m = l.lda();
let mut u = match ju {
FlagSVD::All => Some(unsafe { vec_uninit( (m * m) as usize) }),
FlagSVD::No => None,
JobSvd::All => Some(unsafe { vec_uninit( (m * m) as usize) }),
JobSvd::None => None,
_ => unimplemented!("SVD with partial vector output is not supported yet")
};

let n = l.len();
let mut vt = match jvt {
FlagSVD::All => Some(unsafe { vec_uninit( (n * n) as usize) }),
FlagSVD::No => None,
JobSvd::All => Some(unsafe { vec_uninit( (n * n) as usize) }),
JobSvd::None => None,
_ => unimplemented!("SVD with partial vector output is not supported yet")
};

let k = std::cmp::min(m, n);
Expand Down