Skip to content

Commit

Permalink
Gather enum definitions
Browse files Browse the repository at this point in the history
  • Loading branch information
termoshtt committed Sep 2, 2022
1 parent 08aae3b commit 9109a2f
Show file tree
Hide file tree
Showing 5 changed files with 148 additions and 147 deletions.
145 changes: 145 additions & 0 deletions lax/src/flags.rs
@@ -0,0 +1,145 @@
/// Upper/Lower specification for seveal usages
#[derive(Debug, Clone, Copy)]
#[repr(u8)]
pub enum UPLO {
Upper = b'U',
Lower = b'L',
}

impl UPLO {
pub fn t(self) -> Self {
match self {
UPLO::Upper => UPLO::Lower,
UPLO::Lower => UPLO::Upper,
}
}

/// To use Fortran LAPACK API in lapack-sys crate
pub fn as_ptr(&self) -> *const i8 {
self as *const UPLO as *const i8
}
}

#[derive(Debug, Clone, Copy)]
#[repr(u8)]
pub enum Transpose {
No = b'N',
Transpose = b'T',
Hermite = b'C',
}

impl Transpose {
/// To use Fortran LAPACK API in lapack-sys crate
pub fn as_ptr(&self) -> *const i8 {
self as *const Transpose as *const i8
}
}

#[derive(Debug, Clone, Copy)]
#[repr(u8)]
pub enum NormType {
One = b'O',
Infinity = b'I',
Frobenius = b'F',
}

impl NormType {
pub fn transpose(self) -> Self {
match self {
NormType::One => NormType::Infinity,
NormType::Infinity => NormType::One,
NormType::Frobenius => NormType::Frobenius,
}
}

/// To use Fortran LAPACK API in lapack-sys crate
pub fn as_ptr(&self) -> *const i8 {
self as *const NormType as *const i8
}
}

/// Flag for calculating eigenvectors or not
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum EigenVectorFlag {
Calc = b'V',
Not = b'N',
}

impl EigenVectorFlag {
pub fn is_calc(&self) -> bool {
match self {
EigenVectorFlag::Calc => true,
EigenVectorFlag::Not => false,
}
}

pub fn then<T, F: FnOnce() -> T>(&self, f: F) -> Option<T> {
if self.is_calc() {
Some(f())
} else {
None
}
}

/// To use Fortran LAPACK API in lapack-sys crate
pub fn as_ptr(&self) -> *const i8 {
self as *const EigenVectorFlag as *const i8
}
}

#[repr(u8)]
#[derive(Debug, Copy, Clone)]
pub enum FlagSVD {
All = b'A',
// OverWrite = b'O',
// Separately = b'S',
No = b'N',
}

impl FlagSVD {
pub fn from_bool(calc_uv: bool) -> Self {
if calc_uv {
FlagSVD::All
} else {
FlagSVD::No
}
}

pub fn as_ptr(&self) -> *const i8 {
self as *const FlagSVD as *const i8
}
}

/// Specifies how many of the columns of *U* and rows of *V*ᵀ are computed and returned.
///
/// For an input array of shape *m*×*n*, the following are computed:
#[derive(Clone, Copy, Eq, PartialEq)]
#[repr(u8)]
pub enum UVTFlag {
/// All *m* columns of *U* and all *n* rows of *V*ᵀ.
Full = b'A',
/// The first min(*m*,*n*) columns of *U* and the first min(*m*,*n*) rows of *V*ᵀ.
Some = b'S',
/// No columns of *U* or rows of *V*ᵀ.
None = b'N',
}

impl UVTFlag {
pub fn as_ptr(&self) -> *const i8 {
self as *const UVTFlag as *const i8
}
}

#[derive(Debug, Clone, Copy)]
#[repr(u8)]
pub enum Diag {
Unit = b'U',
NonUnit = b'N',
}

impl Diag {
pub fn as_ptr(&self) -> *const i8 {
self as *const Diag as *const i8
}
}
92 changes: 2 additions & 90 deletions lax/src/lib.rs
Expand Up @@ -74,6 +74,7 @@ pub mod layout;
mod cholesky;
mod eig;
mod eigh;
mod flags;
mod least_squares;
mod opnorm;
mod qr;
Expand All @@ -88,6 +89,7 @@ mod tridiagonal;
pub use self::cholesky::*;
pub use self::eig::*;
pub use self::eigh::*;
pub use self::flags::*;
pub use self::least_squares::*;
pub use self::opnorm::*;
pub use self::qr::*;
Expand Down Expand Up @@ -173,96 +175,6 @@ impl<T> VecAssumeInit for Vec<MaybeUninit<T>> {
}
}

/// Upper/Lower specification for seveal usages
#[derive(Debug, Clone, Copy)]
#[repr(u8)]
pub enum UPLO {
Upper = b'U',
Lower = b'L',
}

impl UPLO {
pub fn t(self) -> Self {
match self {
UPLO::Upper => UPLO::Lower,
UPLO::Lower => UPLO::Upper,
}
}

/// To use Fortran LAPACK API in lapack-sys crate
pub fn as_ptr(&self) -> *const i8 {
self as *const UPLO as *const i8
}
}

#[derive(Debug, Clone, Copy)]
#[repr(u8)]
pub enum Transpose {
No = b'N',
Transpose = b'T',
Hermite = b'C',
}

impl Transpose {
/// To use Fortran LAPACK API in lapack-sys crate
pub fn as_ptr(&self) -> *const i8 {
self as *const Transpose as *const i8
}
}

#[derive(Debug, Clone, Copy)]
#[repr(u8)]
pub enum NormType {
One = b'O',
Infinity = b'I',
Frobenius = b'F',
}

impl NormType {
pub fn transpose(self) -> Self {
match self {
NormType::One => NormType::Infinity,
NormType::Infinity => NormType::One,
NormType::Frobenius => NormType::Frobenius,
}
}

/// To use Fortran LAPACK API in lapack-sys crate
pub fn as_ptr(&self) -> *const i8 {
self as *const NormType as *const i8
}
}

/// Flag for calculating eigenvectors or not
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum EigenVectorFlag {
Calc = b'V',
Not = b'N',
}

impl EigenVectorFlag {
pub fn is_calc(&self) -> bool {
match self {
EigenVectorFlag::Calc => true,
EigenVectorFlag::Not => false,
}
}

pub fn then<T, F: FnOnce() -> T>(&self, f: F) -> Option<T> {
if self.is_calc() {
Some(f())
} else {
None
}
}

/// To use Fortran LAPACK API in lapack-sys crate
pub fn as_ptr(&self) -> *const i8 {
self as *const EigenVectorFlag as *const i8
}
}

/// Create a vector without initialization
///
/// Safety
Expand Down
25 changes: 1 addition & 24 deletions lax/src/svd.rs
@@ -1,32 +1,9 @@
//! Singular-value decomposition

use crate::{error::*, layout::MatrixLayout, *};
use super::{error::*, layout::*, *};
use cauchy::*;
use num_traits::{ToPrimitive, Zero};

#[repr(u8)]
#[derive(Debug, Copy, Clone)]
enum FlagSVD {
All = b'A',
// OverWrite = b'O',
// Separately = b'S',
No = b'N',
}

impl FlagSVD {
fn from_bool(calc_uv: bool) -> Self {
if calc_uv {
FlagSVD::All
} else {
FlagSVD::No
}
}

fn as_ptr(&self) -> *const i8 {
self as *const FlagSVD as *const i8
}
}

/// Result of SVD
pub struct SVDOutput<A: Scalar> {
/// diagonal values
Expand Down
20 changes: 0 additions & 20 deletions lax/src/svddc.rs
Expand Up @@ -2,26 +2,6 @@ use crate::{error::*, layout::MatrixLayout, *};
use cauchy::*;
use num_traits::{ToPrimitive, Zero};

/// Specifies how many of the columns of *U* and rows of *V*ᵀ are computed and returned.
///
/// For an input array of shape *m*×*n*, the following are computed:
#[derive(Clone, Copy, Eq, PartialEq)]
#[repr(u8)]
pub enum UVTFlag {
/// All *m* columns of *U* and all *n* rows of *V*ᵀ.
Full = b'A',
/// The first min(*m*,*n*) columns of *U* and the first min(*m*,*n*) rows of *V*ᵀ.
Some = b'S',
/// No columns of *U* or rows of *V*ᵀ.
None = b'N',
}

impl UVTFlag {
fn as_ptr(&self) -> *const i8 {
self as *const UVTFlag as *const i8
}
}

pub trait SVDDC_: Scalar {
fn svddc(l: MatrixLayout, jobz: UVTFlag, a: &mut [Self]) -> Result<SVDOutput<Self>>;
}
Expand Down
13 changes: 0 additions & 13 deletions lax/src/triangular.rs
Expand Up @@ -3,19 +3,6 @@
use crate::{error::*, layout::*, *};
use cauchy::*;

#[derive(Debug, Clone, Copy)]
#[repr(u8)]
pub enum Diag {
Unit = b'U',
NonUnit = b'N',
}

impl Diag {
fn as_ptr(&self) -> *const i8 {
self as *const Diag as *const i8
}
}

/// Wraps `*trtri` and `*trtrs`
pub trait Triangular_: Scalar {
fn solve_triangular(
Expand Down

0 comments on commit 9109a2f

Please sign in to comment.