/
impl_std_ops.rs
331 lines (303 loc) · 12.7 KB
/
impl_std_ops.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
use crate::csc::CscMatrix;
use crate::csr::CsrMatrix;
use crate::ops::serial::{
spadd_csc_prealloc, spadd_csr_prealloc, spadd_pattern, spmm_csc_dense, spmm_csc_pattern,
spmm_csc_prealloc_unchecked, spmm_csr_dense, spmm_csr_pattern, spmm_csr_prealloc_unchecked,
};
use crate::ops::Op;
use nalgebra::allocator::Allocator;
use nalgebra::base::storage::RawStorage;
use nalgebra::constraint::{DimEq, ShapeConstraint};
use nalgebra::{
ClosedAdd, ClosedDiv, ClosedMul, ClosedSub, DefaultAllocator, Dim, Dyn, Matrix, OMatrix,
Scalar, U1,
};
use num_traits::{One, Zero};
use std::ops::{Add, Div, DivAssign, Mul, MulAssign, Neg, Sub};
/// Helper macro for implementing binary operators for different matrix types.
///
/// Invocation shape:
/// `impl_bin_op!(Trait, method, <'lifetimes, T: ExtraBound>(a: Lhs, b: Rhs) -> Output { body })`
///
/// * `Trait` / `method`: the operator trait and its method (e.g. `Add` / `add`).
/// * The angle brackets hold zero or more lifetimes, optionally followed by a single scalar
///   type parameter with at most one extra bound. `<>` produces a non-generic impl with an
///   empty `where` clause.
/// * `a` and `b` are the names under which `body` sees the left operand (`self`) and the
///   right operand.
/// * `body` is a block that evaluates to `Output`.
///
/// See below for usage.
macro_rules! impl_bin_op {
    ($trait:ident, $method:ident,
     <$($life:lifetime),* $(,)? $($scalar_type:ident $(: $bounds:path)?)?>($a:ident : $a_type:ty, $b:ident : $b_type:ty) -> $ret:ty $body:block)
    =>
    {
        impl<$($life,)* $($scalar_type)?> $trait<$b_type> for $a_type
        where
            // Note: The Neg bound is currently required because we delegate e.g.
            // Sub to SpAdd with negative coefficients. This is not well-defined for
            // unsigned data types.
            //
            // The bound refers to the caller's scalar parameter rather than a literal `T`,
            // so the macro also works when that parameter carries a different name.
            $($scalar_type: $($bounds + )? Scalar + ClosedAdd + ClosedSub + ClosedMul + Zero + One + Neg<Output=$scalar_type>)?
        {
            type Output = $ret;

            fn $method(self, $b: $b_type) -> Self::Output {
                // Expose `self` under the caller-chosen operand name.
                let $a = self;
                $body
            }
        }
    };
}
/// Implements a +/- b for all combinations of reference and owned matrices, for
/// CsrMatrix or CscMatrix.
///
/// Invoke as `impl_sp_plus_minus!(MatrixType, spadd_fn, +)` or with `-` as the last
/// argument. `spadd_fn` is the preallocated sparse-add routine matching `MatrixType`.
macro_rules! impl_sp_plus_minus {
    // We first match on some special-case syntax, and forward to the actual implementation
    ($matrix_type:ident, $spadd_fn:ident, +) => {
        impl_sp_plus_minus!(Add, add, $matrix_type, $spadd_fn, +, T::one());
    };
    ($matrix_type:ident, $spadd_fn:ident, -) => {
        impl_sp_plus_minus!(Sub, sub, $matrix_type, $spadd_fn, -, -T::one());
    };
    // Actual implementation. `$factor` is the coefficient applied to the right operand.
    ($trait:ident, $method:ident, $matrix_type:ident, $spadd_fn:ident, $sign:tt, $factor:expr) => {
        impl_bin_op!($trait, $method,
                     <'a, T>(a: &'a $matrix_type<T>, b: &'a $matrix_type<T>) -> $matrix_type<T> {
            // Build the sparsity pattern of the result from the patterns of both operands.
            let pattern = spadd_pattern(a.pattern(), b.pattern());
            let values = vec![T::zero(); pattern.nnz()];
            // We are giving data that is valid by definition, so it is safe to unwrap below
            let mut result = $matrix_type::try_from_pattern_and_values(pattern, values)
                .unwrap();
            // Accumulate into the preallocated result in two steps: first `a`, then `b`
            // weighted by `$factor`. Presumably the routine computes
            // `C <- beta * C + alpha * op(A)`; confirm against the spadd prealloc docs.
            $spadd_fn(T::zero(), &mut result, T::one(), Op::NoOp(a)).unwrap();
            $spadd_fn(T::one(), &mut result, $factor, Op::NoOp(b)).unwrap();
            result
        });

        // The remaining combinations borrow their owned operands and defer to ref-ref.
        impl_bin_op!($trait, $method,
                     <'a, T>(a: $matrix_type<T>, b: &'a $matrix_type<T>) -> $matrix_type<T> {
            &a $sign b
        });

        impl_bin_op!($trait, $method,
                     <'a, T>(a: &'a $matrix_type<T>, b: $matrix_type<T>) -> $matrix_type<T> {
            a $sign &b
        });

        impl_bin_op!($trait, $method, <T>(a: $matrix_type<T>, b: $matrix_type<T>) -> $matrix_type<T> {
            &a $sign &b
        });
    }
}
// Generate `+` and `-` for both sparse formats, pairing each matrix type with the
// preallocated sparse-add routine of the same format.
impl_sp_plus_minus!(CsrMatrix, spadd_csr_prealloc, +);
impl_sp_plus_minus!(CsrMatrix, spadd_csr_prealloc, -);
impl_sp_plus_minus!(CscMatrix, spadd_csc_prealloc, +);
impl_sp_plus_minus!(CscMatrix, spadd_csc_prealloc, -);
/// Shorthand for `impl_bin_op!` fixed to the `Mul` trait and its `mul` method; every
/// token passed in is forwarded unchanged after those two arguments.
macro_rules! impl_mul {
    ($($forwarded:tt)*) => {
        impl_bin_op! { Mul, mul, $($forwarded)* }
    };
}
/// Implements a * b (sparse-sparse product) for all combinations of reference and owned
/// matrices, for CsrMatrix or CscMatrix.
///
/// `$pattern_fn` computes the sparsity pattern of the product from the operand patterns;
/// `$spmm_fn` fills a preallocated matrix with that pattern.
macro_rules! impl_spmm {
    ($matrix_type:ident, $pattern_fn:expr, $spmm_fn:expr) => {
        // Combinations involving an owned operand borrow it and defer to ref-ref.
        impl_mul!(<T>(a: $matrix_type<T>, b: $matrix_type<T>) -> $matrix_type<T> { &a * &b });
        impl_mul!(<'a, T>(a: $matrix_type<T>, b: &'a $matrix_type<T>) -> $matrix_type<T> { &a * b });
        impl_mul!(<'a, T>(a: &'a $matrix_type<T>, b: $matrix_type<T>) -> $matrix_type<T> { a * &b });

        // ref-ref: the only variant that performs actual work.
        impl_mul!(<'a, T>(a: &'a $matrix_type<T>, b: &'a $matrix_type<T>) -> $matrix_type<T> {
            let product_pattern = $pattern_fn(a.pattern(), b.pattern());
            // Read the entry count before the pattern is moved into the matrix.
            let nnz = product_pattern.nnz();
            let mut product =
                $matrix_type::try_from_pattern_and_values(product_pattern, vec![T::zero(); nnz])
                    .unwrap();
            $spmm_fn(T::zero(), &mut product, T::one(), Op::NoOp(a), Op::NoOp(b))
                .expect("Internal error: spmm failed (please debug).");
            product
        });
    }
}
// Generate sparse-sparse multiplication for both formats, pairing each matrix type with
// the pattern and product routines of the same format.
impl_spmm!(CsrMatrix, spmm_csr_pattern, spmm_csr_prealloc_unchecked);
// Need to switch order of operations for CSC pattern
// NOTE(review): the call below passes its arguments in the same order as the CSR call,
// so any operand swap presumably happens inside `spmm_csc_pattern` — confirm, or drop the
// comment above if it is stale.
impl_spmm!(CscMatrix, spmm_csc_pattern, spmm_csc_prealloc_unchecked);
/// Generates `scalar * matrix` for a list of *concrete* scalar types.
///
/// A blanket `impl<T> Mul<Matrix<T>> for T` is rejected by the orphan rules, so the
/// impls have to be spelled out per scalar type. Each one simply flips the operands and
/// defers to the generic `matrix * scalar` implementation.
macro_rules! impl_concrete_scalar_matrix_mul {
    ($matrix:ident, $($scalar:ty),*) => {
        $(
            // Borrowed scalar on the left: dereference, then flip the operands.
            impl_mul!(<'a>(a: &'a $scalar, b: &'a $matrix<$scalar>) -> $matrix<$scalar> {
                b * *a
            });
            impl_mul!(<'a>(a: &'a $scalar, b: $matrix<$scalar>) -> $matrix<$scalar> {
                b * *a
            });

            // Owned scalar on the left: flip the operands directly.
            impl_mul!(<'a>(a: $scalar, b: &'a $matrix<$scalar>) -> $matrix<$scalar> {
                b * a
            });
            impl_mul!(<>(a: $scalar, b: $matrix<$scalar>) -> $matrix<$scalar> {
                b * a
            });
        )*
    }
}
/// Implements multiplication between matrix and scalar for various matrix types.
///
/// Generated for `$matrix_type`:
/// * `matrix * scalar` for every owned/borrowed combination of the two operands,
/// * `scalar * matrix` for a fixed list of concrete signed and floating-point scalars,
/// * `matrix *= scalar` for owned and borrowed scalars.
///
/// Every variant multiplies each stored value by the scalar on the right
/// (`value * scalar`), so all paths agree even for non-commutative scalar types.
macro_rules! impl_scalar_mul {
    ($matrix_type: ident) => {
        // Borrowed matrix: allocate new values and re-use a clone of the pattern.
        impl_mul!(<'a, T>(a: &'a $matrix_type<T>, b: &'a T) -> $matrix_type<T> {
            let values: Vec<_> = a.values()
                .iter()
                .map(|v_i| v_i.clone() * b.clone())
                .collect();
            $matrix_type::try_from_pattern_and_values(a.pattern().clone(), values).unwrap()
        });
        impl_mul!(<'a, T>(a: &'a $matrix_type<T>, b: T) -> $matrix_type<T> {
            a * &b
        });
        // Owned matrix: scale the stored values in place, keeping the scalar on the
        // right-hand side just like the borrowed variant and `MulAssign` do.
        impl_mul!(<'a, T>(a: $matrix_type<T>, b: &'a T) -> $matrix_type<T> {
            let mut a = a;
            for value in a.values_mut() {
                *value *= b.clone();
            }
            a
        });
        impl_mul!(<T>(a: $matrix_type<T>, b: T) -> $matrix_type<T> {
            a * &b
        });

        // Unsigned integers are left out of this list; the operator impls generated via
        // `impl_mul!` are only emitted for the scalar types named here.
        impl_concrete_scalar_matrix_mul!(
            $matrix_type,
            i8, i16, i32, i64, isize, f32, f64);

        impl<T> MulAssign<T> for $matrix_type<T>
        where
            T: Scalar + ClosedAdd + ClosedMul + Zero + One
        {
            fn mul_assign(&mut self, scalar: T) {
                for val in self.values_mut() {
                    *val *= scalar.clone();
                }
            }
        }

        impl<'a, T> MulAssign<&'a T> for $matrix_type<T>
        where
            T: Scalar + ClosedAdd + ClosedMul + Zero + One
        {
            fn mul_assign(&mut self, scalar: &'a T) {
                for val in self.values_mut() {
                    *val *= scalar.clone();
                }
            }
        }
    }
}
// Generate the matrix/scalar multiplication impls for both sparse formats.
impl_scalar_mul!(CsrMatrix);
impl_scalar_mul!(CscMatrix);
/// Implements unary negation for owned and borrowed matrices of type `$matrix_type`.
macro_rules! impl_neg {
    ($matrix_type:ident) => {
        impl<T> Neg for $matrix_type<T>
        where
            T: Scalar + Neg<Output = T>,
        {
            type Output = $matrix_type<T>;

            /// Negates every stored value in place and returns the same matrix.
            fn neg(mut self) -> Self::Output {
                for v_i in self.values_mut() {
                    *v_i = -v_i.clone();
                }
                self
            }
        }

        impl<'a, T> Neg for &'a $matrix_type<T>
        where
            T: Scalar + Neg<Output = T>,
        {
            type Output = $matrix_type<T>;

            /// Builds a new matrix with a clone of the pattern and negated values.
            fn neg(self) -> Self::Output {
                // Negate while copying, instead of cloning the whole matrix first and
                // then rewriting every value in a second pass.
                let negated: Vec<T> = self.values()
                    .iter()
                    .map(|v_i| -v_i.clone())
                    .collect();
                // One value was produced per stored entry, so the data is valid for
                // the cloned pattern and it is safe to unwrap.
                $matrix_type::try_from_pattern_and_values(self.pattern().clone(), negated)
                    .unwrap()
            }
        }
    };
}
// Generate the negation impls for both sparse formats.
impl_neg!(CsrMatrix);
impl_neg!(CscMatrix);
/// Implements division of a matrix by a scalar (`matrix / scalar` and `matrix /= scalar`)
/// for every owned/borrowed combination of the operands.
macro_rules! impl_div {
    ($matrix_type:ident) => {
        // Owned matrix: divide the stored values in place.
        impl_bin_op!(Div, div, <T: ClosedDiv>(matrix: $matrix_type<T>, scalar: T) -> $matrix_type<T> {
            let mut matrix = matrix;
            matrix /= scalar;
            matrix
        });
        // The scalar reference carries the declared lifetime, matching the other
        // borrowed-operand impls (it was previously elided, leaving `'a` unused).
        impl_bin_op!(Div, div, <'a, T: ClosedDiv>(matrix: $matrix_type<T>, scalar: &'a T) -> $matrix_type<T> {
            matrix / scalar.clone()
        });
        // Borrowed matrix: allocate new values and re-use a clone of the pattern.
        impl_bin_op!(Div, div, <'a, T: ClosedDiv>(matrix: &'a $matrix_type<T>, scalar: T) -> $matrix_type<T> {
            let new_values: Vec<T> = matrix.values()
                .iter()
                .map(|v_i| v_i.clone() / scalar.clone())
                .collect();
            $matrix_type::try_from_pattern_and_values(matrix.pattern().clone(), new_values)
                .unwrap()
        });
        impl_bin_op!(Div, div, <'a, T: ClosedDiv>(matrix: &'a $matrix_type<T>, scalar: &'a T) -> $matrix_type<T> {
            matrix / scalar.clone()
        });

        impl<T> DivAssign<T> for $matrix_type<T>
            where T : Scalar + ClosedAdd + ClosedMul + ClosedDiv + Zero + One
        {
            fn div_assign(&mut self, scalar: T) {
                self.values_mut().iter_mut().for_each(|v_i| *v_i /= scalar.clone());
            }
        }

        impl<'a, T> DivAssign<&'a T> for $matrix_type<T>
            where T : Scalar + ClosedAdd + ClosedMul + ClosedDiv + Zero + One
        {
            fn div_assign(&mut self, scalar: &'a T) {
                // Defer to the by-value implementation.
                *self /= scalar.clone();
            }
        }
    }
}
// Generate the scalar-division impls for both sparse formats.
impl_div!(CsrMatrix);
impl_div!(CscMatrix);
/// Implements `sparse * dense` multiplication for all combinations of reference and owned
/// operands, where the left operand is the given sparse matrix type and the right operand
/// is a dense `Matrix<T, R, C, S>`.
///
/// The macro has two arms. The first (`MatrixType, spmm_fn`) is the public entry point: it
/// calls the second arm four times, once per owned/borrowed combination. The second arm
/// emits a single `Mul` impl whose body is the closure-like `|lhs, rhs| { ... }` argument.
macro_rules! impl_spmm_cs_dense {
($matrix_type_name:ident, $spmm_fn:ident) => {
// Implement ref-ref
impl_spmm_cs_dense!(&'a $matrix_type_name<T>, &'a Matrix<T, R, C, S>, $spmm_fn, |lhs, rhs| {
// The result has a dynamic row count taken from the sparse operand and the
// (possibly static) column dimension of the dense operand.
let (_, ncols) = rhs.shape_generic();
let nrows = Dyn(lhs.nrows());
let mut result = OMatrix::<T, Dyn, C>::zeros_generic(nrows, ncols);
// Fill the zero-initialised result; the return value of `$spmm_fn` (if any) is discarded.
$spmm_fn(T::zero(), &mut result, T::one(), Op::NoOp(lhs), Op::NoOp(rhs));
result
});
// Implement the other combinations by deferring to ref-ref
impl_spmm_cs_dense!(&'a $matrix_type_name<T>, Matrix<T, R, C, S>, $spmm_fn, |lhs, rhs| {
lhs * &rhs
});
impl_spmm_cs_dense!($matrix_type_name<T>, &'a Matrix<T, R, C, S>, $spmm_fn, |lhs, rhs| {
&lhs * rhs
});
impl_spmm_cs_dense!($matrix_type_name<T>, Matrix<T, R, C, S>, $spmm_fn, |lhs, rhs| {
&lhs * &rhs
});
};
// Main body of the macro. The first pattern just forwards to this pattern but with
// different arguments
// Note: `$spmm_fn` is matched by this arm but not used in its expansion; the body passed
// in by the first arm already names the function where needed.
($sparse_matrix_type:ty, $dense_matrix_type:ty, $spmm_fn:ident,
|$lhs:ident, $rhs:ident| $body:tt) =>
{
impl<'a, T, R, C, S> Mul<$dense_matrix_type> for $sparse_matrix_type
where
T: Scalar + ClosedMul + ClosedAdd + ClosedSub + ClosedDiv + Neg + Zero + One,
R: Dim,
C: Dim,
S: RawStorage<T, R, C>,
DefaultAllocator: Allocator<T, Dyn, C>,
// TODO: Is it possible to simplify these bounds?
ShapeConstraint:
// Bounds so that we can turn OMatrix<T, Dyn, C> into a DMatrixSliceMut
// NOTE(review): the type names in these two bound comments may predate a rename
// in nalgebra — confirm the current names of the dense slice/view types.
DimEq<U1, <<DefaultAllocator as Allocator<T, Dyn, C>>::Buffer as RawStorage<T, Dyn, C>>::RStride>
+ DimEq<C, Dyn>
+ DimEq<Dyn, <<DefaultAllocator as Allocator<T, Dyn, C>>::Buffer as RawStorage<T, Dyn, C>>::CStride>
// Bounds so that we can turn &Matrix<T, R, C, S> into a DMatrixSlice
+ DimEq<U1, S::RStride>
+ DimEq<R, Dyn>
+ DimEq<Dyn, S::CStride>
{
// We need the column dimension to be generic, so that if RHS is a vector, then
// we also get a vector (and not a matrix)
type Output = OMatrix<T, Dyn, C>;
fn mul(self, rhs: $dense_matrix_type) -> Self::Output {
// Bind the operands to the names chosen in the `|lhs, rhs|` header so that
// `$body` can refer to them.
let $lhs = self;
let $rhs = rhs;
$body
}
}
}
}
// Generate sparse * dense multiplication for both sparse formats, pairing each matrix
// type with the dense product routine of the same format.
impl_spmm_cs_dense!(CsrMatrix, spmm_csr_dense);
impl_spmm_cs_dense!(CscMatrix, spmm_csc_dense);