-
Notifications
You must be signed in to change notification settings - Fork 65
/
solveh.rs
140 lines (133 loc) · 4.42 KB
/
solveh.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
use crate::{error::*, layout::MatrixLayout, *};
use cauchy::*;
use num_traits::{ToPrimitive, Zero};
/// Solve symmetric (real) or Hermitian (complex) linear problems using the
/// Bunch-Kaufman diagonal pivoting method.
///
/// For a given symmetric matrix $A$,
/// this method factorizes $A = U^T D U$ or $A = L D L^T$ where
///
/// - $U$ (or $L$) is a product of permutation and unit upper (lower) triangular matrices
/// - $D$ is symmetric and block diagonal with 1-by-1 and 2-by-2 diagonal blocks.
///
/// For complex scalars the Hermitian LAPACK routines (`*hetrf`, `*hetri`, `*hetrs`)
/// are used instead of the symmetric ones, so $A$ and $D$ are Hermitian and the
/// transposes above are conjugate transposes.
///
/// # Errors
///
/// Every method returns an error when the underlying LAPACK routine reports a
/// nonzero `info` (an illegal argument or a singular factor).
pub trait Solveh_: Sized {
    /// Bunch-Kaufman factorization: wrapper of `*sytrf` and `*hetrf`.
    ///
    /// The triangle of `a` selected by `uplo` is overwritten with the factorization,
    /// and the returned pivot indices are what [`Solveh_::invh`] and
    /// [`Solveh_::solveh`] expect alongside it.
    fn bk(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<Pivot>;

    /// Inverse from a Bunch-Kaufman factorization: wrapper of `*sytri` and `*hetri`.
    ///
    /// `a` and `ipiv` are expected to be the output of [`Solveh_::bk`] with the same
    /// `l` and `uplo`; on success the `uplo` triangle of `a` holds the inverse.
    fn invh(l: MatrixLayout, uplo: UPLO, a: &mut [Self], ipiv: &Pivot) -> Result<()>;

    /// Solve $A x = b$ for a single right-hand side: wrapper of `*sytrs` and `*hetrs`.
    ///
    /// `a` and `ipiv` are expected to be the output of [`Solveh_::bk`] with the same
    /// `l` and `uplo`; on success `b` is overwritten with the solution $x$.
    fn solveh(
        l: MatrixLayout,
        uplo: UPLO,
        a: &[Self],
        ipiv: &Pivot,
        b: &mut [Self],
    ) -> Result<()>;
}
/// Implements `Solveh_` for `$scalar` on top of the LAPACK routines
/// `$trf` (factorize), `$tri` (invert) and `$trs` (solve).
macro_rules! impl_solveh {
    ($scalar:ty, $trf:path, $tri:path, $trs:path) => {
        impl Solveh_ for $scalar {
            fn bk(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<Pivot> {
                let (n, _) = l.size();
                // An empty matrix has an empty factorization; skip LAPACK (and the
                // pivot allocation) entirely.
                if n == 0 {
                    return Ok(Vec::new());
                }
                let mut ipiv = unsafe { vec_uninit(n as usize) };

                // Workspace query: with LWORK = -1 LAPACK only writes the optimal
                // workspace size into `work_size[0]`.
                let mut info = 0;
                let mut work_size = [Self::zero()];
                // SAFETY: every pointer comes from a live buffer borrowed for the call.
                // NOTE(review): `a.len()` is not checked against `l` here; the caller
                // must pass a buffer that matches the layout.
                unsafe {
                    $trf(
                        uplo.as_ptr(),
                        &n,
                        AsPtr::as_mut_ptr(a),
                        &l.lda(),
                        AsPtr::as_mut_ptr(&mut ipiv),
                        AsPtr::as_mut_ptr(&mut work_size),
                        &(-1),
                        &mut info,
                    )
                };
                info.as_lapack_result()?;

                // Actual factorization using the queried workspace size.
                let lwork = work_size[0]
                    .to_usize()
                    .expect("LAPACK workspace query returned an invalid size");
                let mut work: Vec<MaybeUninit<Self>> = unsafe { vec_uninit(lwork) };
                // SAFETY: as above; `work` has exactly `lwork` elements, matching LWORK.
                unsafe {
                    $trf(
                        uplo.as_ptr(),
                        &n,
                        AsPtr::as_mut_ptr(a),
                        &l.lda(),
                        AsPtr::as_mut_ptr(&mut ipiv),
                        AsPtr::as_mut_ptr(&mut work),
                        &(lwork as i32),
                        &mut info,
                    )
                };
                info.as_lapack_result()?;
                // SAFETY: on success `*trf` has written all `n` pivot entries.
                let ipiv = unsafe { ipiv.assume_init() };
                Ok(ipiv)
            }

            fn invh(l: MatrixLayout, uplo: UPLO, a: &mut [Self], ipiv: &Pivot) -> Result<()> {
                let (n, _) = l.size();
                // The inverse of an empty matrix is empty; nothing to do.
                if n == 0 {
                    return Ok(());
                }
                let mut info = 0;
                // `*sytri` / `*hetri` take a workspace of N elements (no size query).
                let mut work: Vec<MaybeUninit<Self>> = unsafe { vec_uninit(n as usize) };
                // SAFETY: every pointer comes from a live buffer borrowed for the call.
                unsafe {
                    $tri(
                        uplo.as_ptr(),
                        &n,
                        AsPtr::as_mut_ptr(a),
                        &l.lda(),
                        ipiv.as_ptr(),
                        AsPtr::as_mut_ptr(&mut work),
                        &mut info,
                    )
                };
                info.as_lapack_result()?;
                Ok(())
            }

            fn solveh(
                l: MatrixLayout,
                uplo: UPLO,
                a: &[Self],
                ipiv: &Pivot,
                b: &mut [Self],
            ) -> Result<()> {
                let (n, _) = l.size();
                // Empty system: nothing to solve. Without this guard LAPACK would
                // reject the call, because it requires LDB >= max(1, N) and LDB = N
                // is passed below.
                if n == 0 {
                    return Ok(());
                }
                let mut info = 0;
                // SAFETY: every pointer comes from a live buffer borrowed for the call.
                // NRHS = 1 and LDB = n, so `b` is expected to hold at least `n` elements.
                unsafe {
                    $trs(
                        uplo.as_ptr(),
                        &n,
                        &1,
                        AsPtr::as_ptr(a),
                        &l.lda(),
                        ipiv.as_ptr(),
                        AsPtr::as_mut_ptr(b),
                        &n,
                        &mut info,
                    )
                };
                info.as_lapack_result()?;
                Ok(())
            }
        }
    };
} // impl_solveh!
// Real scalars dispatch to the symmetric (`?sy*`) LAPACK routines,
// complex scalars to the Hermitian (`?he*`) ones.
impl_solveh!(f32, lapack_sys::ssytrf_, lapack_sys::ssytri_, lapack_sys::ssytrs_);
impl_solveh!(f64, lapack_sys::dsytrf_, lapack_sys::dsytri_, lapack_sys::dsytrs_);
impl_solveh!(c32, lapack_sys::chetrf_, lapack_sys::chetri_, lapack_sys::chetrs_);
impl_solveh!(c64, lapack_sys::zhetrf_, lapack_sys::zhetri_, lapack_sys::zhetrs_);