Skip to content

Commit

Permalink
Merge pull request #349 from rust-ndarray/lax-tridiagonal
Browse files Browse the repository at this point in the history
Add `LuTridiagonalWork`, merge `Tridiagonal_` into `Lapack`
  • Loading branch information
termoshtt committed Oct 4, 2022
2 parents ad19250 + 4913818 commit 1539577
Show file tree
Hide file tree
Showing 7 changed files with 412 additions and 270 deletions.
53 changes: 42 additions & 11 deletions lax/src/lib.rs
Expand Up @@ -84,14 +84,14 @@ extern crate openblas_src as _src;
#[cfg(any(feature = "netlib-system", feature = "netlib-static"))]
extern crate netlib_src as _src;

pub mod error;
pub mod flags;
pub mod layout;

pub mod alloc;
pub mod cholesky;
pub mod eig;
pub mod eigh;
pub mod eigh_generalized;
pub mod error;
pub mod flags;
pub mod layout;
pub mod least_squares;
pub mod opnorm;
pub mod qr;
Expand All @@ -101,16 +101,12 @@ pub mod solveh;
pub mod svd;
pub mod svddc;
pub mod triangular;
pub mod tridiagonal;

mod alloc;
mod tridiagonal;

pub use self::cholesky::*;
pub use self::flags::*;
pub use self::least_squares::LeastSquaresOwned;
pub use self::opnorm::*;
pub use self::svd::{SvdOwned, SvdRef};
pub use self::tridiagonal::*;
pub use self::tridiagonal::{LUFactorizedTridiagonal, Tridiagonal};

use self::{alloc::*, error::*, layout::*};
use cauchy::*;
Expand All @@ -120,7 +116,7 @@ pub type Pivot = Vec<i32>;

#[cfg_attr(doc, katexit::katexit)]
/// Trait for primitive types which implements LAPACK subroutines
pub trait Lapack: Tridiagonal_ {
pub trait Lapack: Scalar {
/// Compute right eigenvalue and eigenvectors for a general matrix
fn eig(
calc_v: bool,
Expand Down Expand Up @@ -306,6 +302,19 @@ pub trait Lapack: Tridiagonal_ {
a: &[Self],
b: &mut [Self],
) -> Result<()>;

/// Computes the LU factorization of a tridiagonal `m x n` matrix `a` using
/// partial pivoting with row interchanges.
fn lu_tridiagonal(a: Tridiagonal<Self>) -> Result<LUFactorizedTridiagonal<Self>>;

fn rcond_tridiagonal(lu: &LUFactorizedTridiagonal<Self>) -> Result<Self::Real>;

fn solve_tridiagonal(
lu: &LUFactorizedTridiagonal<Self>,
bl: MatrixLayout,
t: Transpose,
b: &mut [Self],
) -> Result<()>;
}

macro_rules! impl_lapack {
Expand Down Expand Up @@ -491,6 +500,28 @@ macro_rules! impl_lapack {
use triangular::*;
SolveTriangularImpl::solve_triangular(al, bl, uplo, d, a, b)
}

fn lu_tridiagonal(a: Tridiagonal<Self>) -> Result<LUFactorizedTridiagonal<Self>> {
use tridiagonal::*;
let work = LuTridiagonalWork::<$s>::new(a.l);
work.eval(a)
}

fn rcond_tridiagonal(lu: &LUFactorizedTridiagonal<Self>) -> Result<Self::Real> {
use tridiagonal::*;
let mut work = RcondTridiagonalWork::<$s>::new(lu.a.l);
work.calc(lu)
}

fn solve_tridiagonal(
lu: &LUFactorizedTridiagonal<Self>,
bl: MatrixLayout,
t: Transpose,
b: &mut [Self],
) -> Result<()> {
use tridiagonal::*;
SolveTridiagonalImpl::solve_tridiagonal(lu, bl, t, b)
}
}
};
}
Expand Down
259 changes: 0 additions & 259 deletions lax/src/tridiagonal.rs

This file was deleted.

0 comments on commit 1539577

Please sign in to comment.