-
Notifications
You must be signed in to change notification settings - Fork 65
/
solveh.rs
161 lines (154 loc) · 5.55 KB
/
solveh.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
use crate::{error::*, layout::MatrixLayout, *};
use cauchy::*;
use num_traits::{ToPrimitive, Zero};
#[cfg_attr(doc, katexit::katexit)]
/// Solve symmetric/Hermitian indefinite linear problems using the [Bunch-Kaufman diagonal pivoting method][BK].
///
/// For a given symmetric matrix $A$,
/// this method factorizes $A = U^T D U$ or $A = L D L^T$ where
///
/// - $U$ (or $L$) is a product of permutation and unit upper (lower) triangular matrices
/// - $D$ is symmetric and block diagonal with 1-by-1 and 2-by-2 diagonal blocks.
///
/// This takes a two-step approach based on LAPACK:
///
/// 1. Factorize the given matrix $A$ into upper ($U$) or lower ($L$) form with the block diagonal matrix $D$
/// 2. Then solve the linear equation $Ax = b$, and/or calculate the inverse matrix $A^{-1}$
///
/// [BK]: https://doi.org/10.2307/2005787
///
pub trait Solveh_: Sized {
/// Factorize input matrix using Bunch-Kaufman diagonal pivoting method
///
/// `a` is handed to LAPACK as a mutable buffer, and the returned pivot is the
/// value that `invh` and `solveh` expect as their `ipiv` argument.
/// `uplo` selects which triangle of `a` is used.
///
/// LAPACK correspondence
/// ----------------------
///
/// | f32 | f64 | c32 | c64 |
/// |:---------|:---------|:---------|:---------|
/// | [ssytrf] | [dsytrf] | [chetrf] | [zhetrf] |
///
/// [ssytrf]: https://netlib.org/lapack/explore-html/d0/d14/group__real_s_ycomputational_ga12d2e56511cf7df066712c61d9acec45.html
/// [dsytrf]: https://netlib.org/lapack/explore-html/d3/db6/group__double_s_ycomputational_gad91bde1212277b3e909eb6af7f64858a.html
/// [chetrf]: https://netlib.org/lapack/explore-html/d4/d74/group__complex_h_ecomputational_ga081dd1908e46d064c2bf0a1f6b664b86.html
/// [zhetrf]: https://netlib.org/lapack/explore-html/d3/d80/group__complex16_h_ecomputational_gadc84a5c9818ee12ea19944623131bd52.html
///
fn bk(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<Pivot>;
/// Wrapper of `*sytri` and `*hetri`
///
/// Step 2 of the approach described on the trait: calculates the inverse matrix
/// from a factorization. `a` and `ipiv` are expected to come from a previous
/// `bk` call made with the same `l` and `uplo`; `a` is passed to LAPACK as a
/// mutable buffer.
fn invh(l: MatrixLayout, uplo: UPLO, a: &mut [Self], ipiv: &Pivot) -> Result<()>;
/// Wrapper of `*sytrs` and `*hetrs`
///
/// Step 2 of the approach described on the trait: solves $Ax = b$ from a
/// factorization. `a` and `ipiv` are expected to come from a previous `bk` call
/// made with the same `l` and `uplo`. `b` is passed to LAPACK as a mutable
/// buffer; the implementations in this file request a single right-hand side.
fn solveh(l: MatrixLayout, uplo: UPLO, a: &[Self], ipiv: &Pivot, b: &mut [Self]) -> Result<()>;
}
// Implements `Solveh_` for `$scalar` on top of three raw LAPACK routines:
//
// - `$trf`: Bunch-Kaufman factorization (`*sytrf` / `*hetrf`)
// - `$tri`: inverse from the factorization (`*sytri` / `*hetri`)
// - `$trs`: linear solve from the factorization (`*sytrs` / `*hetrs`)
//
// Every method returns early for an empty (`n == 0`) matrix so that LAPACK is
// never called with a zero leading dimension.
macro_rules! impl_solveh {
    ($scalar:ty, $trf:path, $tri:path, $trs:path) => {
        impl Solveh_ for $scalar {
            fn bk(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<Pivot> {
                let (n, _) = l.size();
                // Nothing to factorize: skip both the pivot allocation and LAPACK.
                if n == 0 {
                    return Ok(Vec::new());
                }
                // SAFETY: the uninitialized pivot buffer is only handed to `$trf` as an
                // output argument and is not read until `assume_init` below.
                let mut ipiv = unsafe { vec_uninit(n as usize) };

                // Workspace query: `lwork = -1` asks the routine to report the
                // workspace size in `work_size[0]` without factorizing.
                let mut info = 0;
                let mut work_size = [Self::zero()];
                // SAFETY: all pointers refer to live local or caller-provided buffers;
                // `a` is assumed to hold the matrix described by `l`.
                unsafe {
                    $trf(
                        uplo.as_ptr(),
                        &n,
                        AsPtr::as_mut_ptr(a),
                        &l.lda(),
                        AsPtr::as_mut_ptr(&mut ipiv),
                        AsPtr::as_mut_ptr(&mut work_size),
                        &(-1),
                        &mut info,
                    )
                };
                info.as_lapack_result()?;

                // Actual factorization with a workspace of the reported size.
                let lwork = work_size[0]
                    .to_usize()
                    .expect("LAPACK workspace query must report a non-negative size");
                // SAFETY: `work` is scratch space written by `$trf`; it is never read here.
                let mut work: Vec<MaybeUninit<Self>> = unsafe { vec_uninit(lwork) };
                // SAFETY: same buffers as the query above, plus `work` of length `lwork`.
                unsafe {
                    $trf(
                        uplo.as_ptr(),
                        &n,
                        AsPtr::as_mut_ptr(a),
                        &l.lda(),
                        AsPtr::as_mut_ptr(&mut ipiv),
                        AsPtr::as_mut_ptr(&mut work),
                        // The routine takes `lwork` as a 32-bit integer.
                        &(lwork as i32),
                        &mut info,
                    )
                };
                info.as_lapack_result()?;
                // SAFETY: reached only after `$trf` reported success, at which point the
                // pivot buffer is expected to be filled by LAPACK.
                let ipiv = unsafe { ipiv.assume_init() };
                Ok(ipiv)
            }

            fn invh(l: MatrixLayout, uplo: UPLO, a: &mut [Self], ipiv: &Pivot) -> Result<()> {
                let (n, _) = l.size();
                // The inverse of an empty matrix is empty; mirror the guard in `bk`.
                if n == 0 {
                    return Ok(());
                }
                let mut info = 0;
                // Scratch space of `n` elements for `$tri`.
                // SAFETY: `work` is only passed to `$tri` as a workspace and is never read here.
                let mut work: Vec<MaybeUninit<Self>> = unsafe { vec_uninit(n as usize) };
                // SAFETY: `a` and `ipiv` are assumed to come from a previous `bk` call
                // with the same layout, so they are sized for an `n`-by-`n` problem.
                unsafe {
                    $tri(
                        uplo.as_ptr(),
                        &n,
                        AsPtr::as_mut_ptr(a),
                        &l.lda(),
                        ipiv.as_ptr(),
                        AsPtr::as_mut_ptr(&mut work),
                        &mut info,
                    )
                };
                info.as_lapack_result()?;
                Ok(())
            }

            fn solveh(
                l: MatrixLayout,
                uplo: UPLO,
                a: &[Self],
                ipiv: &Pivot,
                b: &mut [Self],
            ) -> Result<()> {
                let (n, _) = l.size();
                // `b` is passed with leading dimension `n`, but LAPACK requires
                // `ldb >= max(1, n)`. For an empty system that call would be rejected
                // as an illegal argument, so return the (empty) solution directly.
                if n == 0 {
                    return Ok(());
                }
                let mut info = 0;
                // SAFETY: `a` and `ipiv` are assumed to come from a previous `bk` call
                // with the same layout, and `b` is assumed to hold at least `n` elements.
                unsafe {
                    $trs(
                        uplo.as_ptr(),
                        &n,
                        // Single right-hand side.
                        &1,
                        AsPtr::as_ptr(a),
                        &l.lda(),
                        ipiv.as_ptr(),
                        AsPtr::as_mut_ptr(b),
                        // Leading dimension of `b`.
                        &n,
                        &mut info,
                    )
                };
                info.as_lapack_result()?;
                Ok(())
            }
        }
    };
} // impl_solveh!
// Real scalars: symmetric (`sy`) LAPACK routines.
impl_solveh!(f32, lapack_sys::ssytrf_, lapack_sys::ssytri_, lapack_sys::ssytrs_);
impl_solveh!(f64, lapack_sys::dsytrf_, lapack_sys::dsytri_, lapack_sys::dsytrs_);
// Complex scalars: Hermitian (`he`) LAPACK routines.
impl_solveh!(c32, lapack_sys::chetrf_, lapack_sys::chetri_, lapack_sys::chetrs_);
impl_solveh!(c64, lapack_sys::zhetrf_, lapack_sys::zhetri_, lapack_sys::zhetrs_);