-
Notifications
You must be signed in to change notification settings - Fork 65
/
svd.rs
137 lines (124 loc) · 5.05 KB
/
svd.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
//! Singular-value decomposition
use super::{error::*, layout::*, *};
use cauchy::*;
use num_traits::{ToPrimitive, Zero};
/// Result of SVD
pub struct SVDOutput<A: Scalar> {
/// diagonal values
pub s: Vec<A::Real>,
/// Unitary matrix for destination space
pub u: Option<Vec<A>>,
/// Unitary matrix for departure space
pub vt: Option<Vec<A>>,
}
#[cfg_attr(doc, katexit::katexit)]
/// Singular value decomposition
pub trait SVD_: Scalar {
    /// Compute singular value decomposition $A = U \Sigma V^H$
    /// (for real scalars $V^H = V^T$)
    ///
    /// - `l`: memory layout (row- or column-major) of `a`
    /// - `calc_u`: compute the left singular vectors $U$ when `true`
    /// - `calc_vt`: compute the right singular vectors $V^H$ when `true`
    /// - `a`: the input matrix; it is passed to LAPACK as a mutable buffer and
    ///   its contents are overwritten
    ///
    /// The returned `u` / `vt` are `None` when the corresponding flag is `false`.
    ///
    /// Errors
    /// ------
    ///
    /// Returns an error when the LAPACK routine reports failure through its
    /// `info` output, e.g. when the iterative algorithm does not converge.
    ///
    /// LAPACK correspondence
    /// ---------------------
    ///
    /// | f32 | f64 | c32 | c64 |
    /// |:-------|:-------|:-------|:-------|
    /// | sgesvd | dgesvd | cgesvd | zgesvd |
    ///
    fn svd(l: MatrixLayout, calc_u: bool, calc_vt: bool, a: &mut [Self])
        -> Result<SVDOutput<Self>>;
}
// Implements `SVD_` for a LAPACK scalar type by calling the matching `?gesvd`.
//
// * `@real`    — f32/f64: the routine has no `rwork` argument.
// * `@complex` — c32/c64: an extra real workspace `rwork` of length
//   `5 * min(m, n)` is allocated and passed right after `lwork`.
// * `@body`    — shared implementation; the optional trailing identifier names
//   the `rwork` buffer and expands to nothing for the real case.
macro_rules! impl_svd {
    (@real, $scalar:ty, $gesvd:path) => {
        impl_svd!(@body, $scalar, $gesvd, );
    };
    (@complex, $scalar:ty, $gesvd:path) => {
        impl_svd!(@body, $scalar, $gesvd, rwork);
    };
    (@body, $scalar:ty, $gesvd:path, $($rwork_ident:ident),*) => {
        impl SVD_ for $scalar {
            fn svd(l: MatrixLayout, calc_u: bool, calc_vt: bool, a: &mut [Self],) -> Result<SVDOutput<Self>> {
                // LAPACK works on column-major data. A row-major (C) matrix is seen by
                // LAPACK as its transpose, so the roles of U and V^T are swapped here
                // and swapped back when building the output below.
                let ju = match l {
                    MatrixLayout::F { .. } => JobSvd::from_bool(calc_u),
                    MatrixLayout::C { .. } => JobSvd::from_bool(calc_vt),
                };
                let jvt = match l {
                    MatrixLayout::F { .. } => JobSvd::from_bool(calc_vt),
                    MatrixLayout::C { .. } => JobSvd::from_bool(calc_u),
                };

                // Row count of the column-major matrix handed to LAPACK.
                let m = l.lda();
                // Buffer lengths are multiplied in `usize`: `m * m` in `i32` would
                // overflow for m > 46340.
                let mut u = match ju {
                    // SAFETY: the buffer is only read back (via `assume_init`) after a
                    // successful gesvd call with JOBU='A', which writes all m*m entries.
                    JobSvd::All => Some(unsafe { vec_uninit(m as usize * m as usize) }),
                    JobSvd::None => None,
                    _ => unimplemented!("SVD with partial vector output is not supported yet")
                };

                // Column count of the column-major matrix handed to LAPACK.
                let n = l.len();
                let mut vt = match jvt {
                    // SAFETY: as for `u`; JOBVT='A' writes all n*n entries.
                    JobSvd::All => Some(unsafe { vec_uninit(n as usize * n as usize) }),
                    JobSvd::None => None,
                    _ => unimplemented!("SVD with partial vector output is not supported yet")
                };

                let k = std::cmp::min(m, n);
                // SAFETY: gesvd writes all min(m, n) singular values on success.
                let mut s = unsafe { vec_uninit(k as usize) };

                $(
                // Real workspace required by the complex routines: 5 * min(m, n).
                // SAFETY: pure scratch space for LAPACK; never read on the Rust side.
                let mut $rwork_ident: Vec<MaybeUninit<Self::Real>> = unsafe { vec_uninit(5 * k as usize) };
                )*

                // eval work size: LWORK = -1 makes gesvd only store the optimal
                // workspace length in `work_size[0]`.
                let mut info = 0;
                let mut work_size = [Self::zero()];
                // SAFETY: every pointer refers to a live buffer sized as LAPACK
                // requires for the given m, n, JOBU and JOBVT.
                unsafe {
                    $gesvd(
                        ju.as_ptr(),
                        jvt.as_ptr(),
                        &m,
                        &n,
                        AsPtr::as_mut_ptr(a),
                        &m,
                        AsPtr::as_mut_ptr(&mut s),
                        AsPtr::as_mut_ptr(u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut [])),
                        &m,
                        AsPtr::as_mut_ptr(vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut [])),
                        &n,
                        AsPtr::as_mut_ptr(&mut work_size),
                        &(-1),
                        $(AsPtr::as_mut_ptr(&mut $rwork_ident),)*
                        &mut info,
                    );
                }
                info.as_lapack_result()?;

                // calc
                let lwork = work_size[0].to_usize().unwrap();
                // SAFETY: scratch space for LAPACK; never read on the Rust side.
                let mut work: Vec<MaybeUninit<Self>> = unsafe { vec_uninit(lwork) };
                // SAFETY: same buffers as the query call plus a `work` buffer of
                // exactly `lwork` elements.
                unsafe {
                    $gesvd(
                        ju.as_ptr(),
                        jvt.as_ptr(),
                        &m,
                        &n,
                        AsPtr::as_mut_ptr(a),
                        &m,
                        AsPtr::as_mut_ptr(&mut s),
                        AsPtr::as_mut_ptr(u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut [])),
                        &m,
                        AsPtr::as_mut_ptr(vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut [])),
                        &n,
                        AsPtr::as_mut_ptr(&mut work),
                        &(lwork as i32),
                        $(AsPtr::as_mut_ptr(&mut $rwork_ident),)*
                        &mut info,
                    );
                }
                info.as_lapack_result()?;

                // SAFETY: gesvd succeeded, so `s`, and `u` / `vt` when requested,
                // have been fully written.
                let s = unsafe { s.assume_init() };
                let u = u.map(|v| unsafe { v.assume_init() });
                let vt = vt.map(|v| unsafe { v.assume_init() });

                // Undo the U <-> V^T swap made for row-major input.
                match l {
                    MatrixLayout::F { .. } => Ok(SVDOutput { s, u, vt }),
                    MatrixLayout::C { .. } => Ok(SVDOutput { s, u: vt, vt: u }),
                }
            }
        }
    };
} // impl_svd!
// Real scalars: `?gesvd` takes no `rwork` argument.
impl_svd!(@real, f32, lapack_sys::sgesvd_);
impl_svd!(@real, f64, lapack_sys::dgesvd_);
// Complex scalars: `?gesvd` takes an extra real `rwork` workspace.
impl_svd!(@complex, c32, lapack_sys::cgesvd_);
impl_svd!(@complex, c64, lapack_sys::zgesvd_);