-
Notifications
You must be signed in to change notification settings - Fork 65
/
cholesky.rs
119 lines (111 loc) · 3.5 KB
/
cholesky.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
//! Cholesky decomposition
use super::*;
use crate::{error::*, layout::*};
use cauchy::*;
/// LAPACK-backed Cholesky routines (`*potrf` / `*potri` / `*potrs`) for one scalar type.
///
/// `l` describes how the square matrix `a` is laid out in memory, and `uplo` selects
/// which triangle of `a` is referenced.
pub trait Cholesky_: Sized {
    /// Cholesky: wrapper of `*potrf`.
    ///
    /// Factorizes `a` in place; the factor is written over the `uplo` triangle.
    ///
    /// **Warning: Only the portion of `a` corresponding to `UPLO` is written.**
    fn cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()>;

    /// Wrapper of `*potri`.
    ///
    /// Overwrites `a` with the inverse of the original matrix. As `*potri` requires,
    /// `a` is expected to already hold the factor produced by [`Cholesky_::cholesky`]
    /// with the same `l` and `uplo`.
    ///
    /// **Warning: Only the portion of `a` corresponding to `UPLO` is written.**
    fn inv_cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()>;

    /// Wrapper of `*potrs`.
    ///
    /// Solves the system for a single right-hand side `b`, which is overwritten with
    /// the solution. `a` is expected to hold the factor produced by
    /// [`Cholesky_::cholesky`] with the same `l` and `uplo`.
    fn solve_cholesky(
        l: MatrixLayout,
        uplo: UPLO,
        a: &[Self],
        b: &mut [Self],
    ) -> Result<()>;
}
// Implements `Cholesky_` for one scalar type.
//
// - `$scalar`: the element type the impl is generated for.
// - `$trf`, `$tri`, `$trs`: the matching LAPACK `*potrf_`, `*potri_` and `*potrs_` bindings.
//
// LAPACK works on column-major data. For row-major (`MatrixLayout::C`) input:
// - `cholesky` and `inv_cholesky` transpose `a` in place around the call.
// - `solve_cholesky` only borrows `a`, so it flips `UPLO` and conjugates `b` instead.
//
// In every case the caller's layout is restored before returning, including on error.
macro_rules! impl_cholesky {
    ($scalar:ty, $trf:path, $tri:path, $trs:path) => {
        impl Cholesky_ for $scalar {
            fn cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()> {
                let (n, _) = l.size();
                let row_major = matches!(l, MatrixLayout::C { .. });
                if row_major {
                    square_transpose(l, a);
                }
                let mut info = 0;
                // SAFETY: every pointer refers to a live local or to `a`, all of which
                // outlive the call. NOTE(review): nothing in this function checks that `a`
                // holds at least n * n elements — confirm callers guarantee it.
                unsafe {
                    // LAPACK rejects LDA < max(1, N). Passing `n` directly therefore fails
                    // for a 0x0 matrix, so use `l.lda()` like the other two wrappers do.
                    // `l.lda()` is assumed to be clamped to >= 1 — confirm in `layout`.
                    $trf(uplo.as_ptr(), &n, AsPtr::as_mut_ptr(a), &l.lda(), &mut info);
                }
                // Undo the transpose before reporting failure, so that an error does not
                // also hand `a` back in the wrong layout.
                if row_major {
                    square_transpose(l, a);
                }
                info.as_lapack_result()?;
                Ok(())
            }

            fn inv_cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()> {
                let (n, _) = l.size();
                let row_major = matches!(l, MatrixLayout::C { .. });
                if row_major {
                    square_transpose(l, a);
                }
                let mut info = 0;
                // SAFETY: same pointer-lifetime argument as in `cholesky`; the length of
                // `a` is likewise assumed, not checked, here.
                unsafe {
                    $tri(uplo.as_ptr(), &n, AsPtr::as_mut_ptr(a), &l.lda(), &mut info);
                }
                // Restore the caller's layout even when LAPACK reports an error.
                if row_major {
                    square_transpose(l, a);
                }
                info.as_lapack_result()?;
                Ok(())
            }

            fn solve_cholesky(
                l: MatrixLayout,
                mut uplo: UPLO,
                a: &[Self],
                b: &mut [Self],
            ) -> Result<()> {
                let (n, _) = l.size();
                // Exactly one right-hand side: `b` is a single vector of length n.
                let nrhs = 1;
                // LAPACK rejects LDB < max(1, N); plain `n` would fail for an empty system.
                let ldb = n.max(1);
                let row_major = matches!(l, MatrixLayout::C { .. });
                if row_major {
                    // `a` is borrowed immutably, so it cannot be transposed in place.
                    // LAPACK reads a row-major buffer as the transpose: the stored triangle
                    // swaps sides, and the conjugation of `b` (undone below) compensates
                    // for the transpose.
                    uplo = uplo.t();
                    for val in b.iter_mut() {
                        *val = val.conj();
                    }
                }
                let mut info = 0;
                // SAFETY: every pointer refers to a live local, `a`, or `b`, all of which
                // outlive the call. NOTE(review): `a.len() >= n * n` and `b.len() >= n`
                // are assumed, not checked, here.
                unsafe {
                    $trs(
                        uplo.as_ptr(),
                        &n,
                        &nrhs,
                        AsPtr::as_ptr(a),
                        &l.lda(),
                        AsPtr::as_mut_ptr(b),
                        &ldb,
                        &mut info,
                    );
                }
                // Undo the conjugation before reporting failure, so `b` is not left
                // silently modified on error.
                if row_major {
                    for val in b.iter_mut() {
                        *val = val.conj();
                    }
                }
                info.as_lapack_result()?;
                Ok(())
            }
        }
    };
} // end macro_rules
// Real scalars use the single- (`s`) and double-precision (`d`) LAPACK routines.
impl_cholesky!(f32, lapack_sys::spotrf_, lapack_sys::spotri_, lapack_sys::spotrs_);
impl_cholesky!(f64, lapack_sys::dpotrf_, lapack_sys::dpotri_, lapack_sys::dpotrs_);

// Complex scalars use the single- (`c`) and double-precision (`z`) LAPACK routines.
impl_cholesky!(c32, lapack_sys::cpotrf_, lapack_sys::cpotri_, lapack_sys::cpotrs_);
impl_cholesky!(c64, lapack_sys::zpotrf_, lapack_sys::zpotri_, lapack_sys::zpotrs_);