/
bcoo.py
2613 lines (2239 loc) · 118 KB
/
bcoo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BCOO (Batched coordinate format) matrix object and associated primitives."""
from __future__ import annotations
import functools
from functools import partial
import operator
from typing import Any, List, NamedTuple, Optional, Protocol, Sequence, Tuple, Union
import warnings
import numpy as np
import jax
from jax import core
from jax import lax
from jax import tree_util
from jax import vmap
from jax.config import config
from jax.experimental.sparse._base import JAXSparse
from jax.experimental.sparse.util import (
_broadcasting_vmap, _count_stored_elements,
_dot_general_validated_shape, CuSparseEfficiencyWarning,
SparseEfficiencyError, SparseEfficiencyWarning)
from jax.interpreters import batching
from jax.interpreters import partial_eval as pe
from jax.interpreters import mlir
import jax.numpy as jnp
from jax.interpreters import ad
from jax.util import safe_zip, unzip2, split_list
from jax._src import api_util
from jax._src.lax.lax import (
_const, ranges_like, remaining, _dot_general_batch_dim_nums, _dot_general_shape_rule,
DotDimensionNumbers)
from jax._src.lax.slicing import GatherDimensionNumbers, GatherScatterMode
from jax._src.lib.mlir import ir
from jax._src.lib import xla_bridge
from jax._src.lib import gpu_sparse
from jax._src.lib.mlir.dialects import mhlo
from jax._src.numpy.setops import _unique
from jax._src.typing import Array, ArrayLike, DType, DTypeLike
from jax._src.util import canonicalize_axis
# Loose alias used for dtype annotations.
Dtype = Any
# Array shapes are represented as tuples of ints.
Shape = Tuple[int, ...]
#----------------------------------------------------------------------
# BCOO primitives: batched extension of COO.
def _bcoo_set_nse(mat: BCOO, nse: int) -> BCOO:
  """Return a copy of ``mat`` whose number of stored elements is ``nse``.

  If ``nse`` equals ``mat.nse`` the input is returned unchanged. Shrinking
  keeps only the first ``nse`` stored elements of each batch and may therefore
  discard data. Growing pads ``data`` with zeros and pads ``indices`` with the
  index ``mat.shape[n_batch:n_batch + n_sparse]``, which lies outside the
  sparse dimensions.
  """
  new_nse = operator.index(nse)
  assert new_nse >= 0
  old_nse = mat.nse
  if new_nse == old_nse:
    return mat
  n_batch = mat.n_batch
  batch_slices = tuple(slice(None) for _ in range(n_batch))
  if new_nse < old_nse:
    data = mat.data[(*batch_slices, slice(new_nse))]
    indices = mat.indices[..., :new_nse, :]
  else:
    padded_data_shape = (*mat.data.shape[:n_batch], new_nse, *mat.data.shape[n_batch + 1:])
    data = jnp.zeros_like(mat.data, shape=padded_data_shape)
    data = data.at[(*batch_slices, slice(old_nse))].set(mat.data)
    padded_indices_shape = (*mat.indices.shape[:-2], new_nse, mat.indices.shape[-1])
    indices = jnp.zeros_like(mat.indices, shape=padded_indices_shape)
    indices = indices.at[..., :old_nse, :].set(mat.indices)
    sparse_shape = mat.shape[n_batch:n_batch + mat.n_sparse]
    indices = indices.at[..., old_nse:, :].set(jnp.array(sparse_shape, dtype=indices.dtype))
  return BCOO((data, indices), shape=mat.shape,
              indices_sorted=mat.indices_sorted,
              unique_indices=mat.unique_indices)
# TODO(jakevdp) this can be problematic when used with autodiff; see
# https://github.com/google/jax/issues/10163. Should this be a primitive?
# Alternatively, maybe roll this into bcoo_sum_duplicates as an optional argument.
def bcoo_eliminate_zeros(mat: BCOO, nse: Optional[int] = None) -> BCOO:
  """Drop explicitly stored zeros from a BCOO array.

  A stored element counts as zero when all of its data entries are zero. For
  batch dimensions along which ``indices`` has size 1 (shared indices), the
  element must be zero in every batch entry. The index of each zero element is
  replaced by ``shape[n_batch:n_batch + n_sparse]``. The result is then passed
  through ``bcoo_sum_duplicates``.

  Args:
    mat: BCOO array.
    nse: optional number of stored elements, forwarded to ``bcoo_sum_duplicates``.

  Returns:
    A BCOO array with the same shape as ``mat``.
  """
  data, indices, shape = mat.data, mat.indices, mat.shape
  n_batch, n_sparse, _, _ = _validate_bcoo(data, indices, shape)
  dense_axes = tuple(range(n_batch + 1, data.ndim))
  is_zero = (data == 0).all(dense_axes)
  shared_batch_axes = tuple(axis for axis, size in enumerate(indices.shape[:n_batch]) if size == 1)
  is_zero = is_zero.all(shared_batch_axes, keepdims=True)
  fill_value = jnp.array(shape[n_batch:n_batch + n_sparse], dtype=indices.dtype)

  def replace_zero_indices(idx, zero_mask):
    return jnp.where(zero_mask[:, None], fill_value[None, :], idx)

  for _ in range(n_batch):
    replace_zero_indices = vmap(replace_zero_indices)
  new_indices = replace_zero_indices(indices, is_zero)
  return bcoo_sum_duplicates(BCOO((data, new_indices), shape=shape), nse=nse)
class BCOOProperties(NamedTuple):
  """Dimension bookkeeping derived from BCOO buffers by ``_validate_bcoo_indices``."""
  n_batch: int  # number of leading batch dimensions (indices.ndim - 2)
  n_sparse: int  # number of sparse dimensions (indices.shape[-1])
  n_dense: int  # number of trailing dense dimensions of the full shape
  nse: int  # number of stored elements per batch (indices.shape[-2])
class BCOOInfo(NamedTuple):
  """Static (non-array) metadata of a BCOO array, passed to primitives as ``spinfo``."""
  shape: Shape  # full shape: batch dims + sparse dims + dense dims
  indices_sorted: bool = False  # sortedness flag carried through from the BCOO object
  unique_indices: bool = False  # uniqueness flag carried through from the BCOO object
class Buffer(Protocol):
  """Structural type for anything exposing ``shape`` and ``dtype``.

  This lets the validation helpers accept concrete arrays as well as abstract
  values such as the avals received by the abstract-eval rules.
  """
  @property
  def shape(self) -> Tuple[int, ...]: ...
  @property
  def dtype(self) -> DType: ...
def _validate_bcoo(data: Buffer, indices: Buffer, shape: Sequence[int]) -> BCOOProperties:
  """Check that ``data`` and ``indices`` jointly describe an array of ``shape``.

  Returns:
    The ``BCOOProperties`` derived from ``indices`` and ``shape``.

  Raises:
    ValueError: if a batch dimension of ``data`` is neither 1 nor the matching
      entry of ``shape``, or if the trailing dims of ``data`` are not
      ``(nse,) + dense_dims``.
  """
  props = _validate_bcoo_indices(indices, shape)
  n_batch, n_sparse, n_dense, nse = props
  shape = tuple(shape)
  batch_pairs = safe_zip(data.shape[:n_batch], shape[:n_batch])
  if not all(d in (1, s) for d, s in batch_pairs):
    raise ValueError(f"data batch dimensions not compatible for {data.shape=}, {shape=}")
  expected_tail = (nse, *shape[n_batch + n_sparse:])
  if data.shape[n_batch:] != expected_tail:
    raise ValueError(f"Invalid {data.shape=} for {nse=}, {n_batch=}, {n_dense=}")
  return props
def _validate_bcoo_indices(indices: Buffer, shape: Sequence[int]) -> BCOOProperties:
assert jnp.issubdtype(indices.dtype, jnp.integer)
shape = tuple(shape)
nse, n_sparse = indices.shape[-2:]
n_batch = len(indices.shape) - 2
n_dense = len(shape) - n_batch - n_sparse
assert n_dense >= 0
if any(s1 not in (1, s2) for s1, s2 in safe_zip(indices.shape[:n_batch], shape[:n_batch])):
raise ValueError(f"indices batch dimensions not compatible for {indices.shape=}, {shape=}")
if indices.shape[n_batch:] != (nse, n_sparse):
raise ValueError(f"Invalid ={indices.shape=} for {nse=}, {n_batch=}, {n_dense=}")
return BCOOProperties(n_batch=n_batch, n_sparse=n_sparse, n_dense=n_dense, nse=nse)
def _bcoo_to_bcsr(indices: Array, *, shape: Sequence[int],
                  index_dtype: DTypeLike = jnp.int32) -> Tuple[Array, Array]:
  """Convert BCOO ``indices`` into BCSR ``(indices, indptr)``.

  The BCSR column indices are the second sparse coordinate. ``indptr`` is the
  cumulative count of entries per row, with a leading 0, computed separately
  for each batch. NOTE(review): the row pointers are only meaningful if the
  entries are ordered by row; confirm callers guarantee this.

  Raises:
    ValueError: if ``indices`` does not have exactly 2 sparse dimensions.
  """
  n_batch, n_sparse, _, _ = _validate_bcoo_indices(indices, shape)
  if n_sparse != 2:
    raise ValueError("Must have 2 sparse dimensions to be converted to BCSR.")
  n_rows = shape[n_batch]

  def row_pointers(row_ids):
    counts = jnp.bincount(row_ids, length=n_rows).astype(index_dtype)
    return jnp.zeros(n_rows + 1, index_dtype).at[1:].set(jnp.cumsum(counts))

  for _ in range(n_batch):
    row_pointers = vmap(row_pointers)
  return indices[..., 1], row_pointers(indices[..., 0])
#----------------------------------------------------------------------
# bcoo_todense
# Primitive that scatters BCOO (data, indices) buffers into a dense array.
bcoo_todense_p = core.Primitive('bcoo_todense')
def bcoo_todense(mat: BCOO) -> Array:
  """Materialize a batched sparse matrix as a dense array.

  Args:
    mat: BCOO matrix.

  Returns:
    mat_dense: dense array with shape ``mat.shape`` and the dtype of ``mat.data``.
  """
  data, indices, info = mat.data, mat.indices, mat._info
  return _bcoo_todense(data, indices, spinfo=info)
def _bcoo_todense(data: Array, indices: Array, *, spinfo: BCOOInfo) -> Array:
  """Convert batched sparse buffers to a dense matrix.

  Args:
    data : array of shape ``batch_dims + (nse,) + block_dims``.
    indices : integer array of shape ``batch_dims + (nse, n_sparse)``; entry
      ``indices[..., k, :]`` holds the sparse coordinates of stored element ``k``.
    spinfo : BCOOInfo. In particular, this includes the shape
      of the matrix, which is equal to ``batch_dims + sparse_dims + block_dims``
      where ``len(sparse_dims) == n_sparse``.

  Returns:
    mat : array with shape ``spinfo.shape`` and dtype matching ``data``.
  """
  return bcoo_todense_p.bind(jnp.asarray(data), jnp.asarray(indices), spinfo=spinfo)
@bcoo_todense_p.def_impl
def _bcoo_todense_impl(data, indices, *, spinfo):
  # Eager implementation of bcoo_todense_p (also used for lowering via
  # mlir.lower_fun): scatter-add ``data`` into zeros of shape ``spinfo.shape``.
  shape = spinfo.shape
  n_batch, n_sparse, _, _ = _validate_bcoo(data, indices, shape)
  # Index grid over the batch dims of ``indices``; a size-1 (broadcast) batch
  # dim of ``indices`` always selects position 0.
  ind_slices = tuple(np.zeros(s, int) if i_s == 1 else np.arange(s)
                     for s, i_s in zip(shape[:n_batch], indices.shape[:n_batch]))
  grid = tuple(np.meshgrid(*ind_slices, indexing='ij', sparse=True))
  # One coordinate array per sparse dimension.
  sparse_ind = tuple(indices[grid + (slice(None), i)] for i in range(n_sparse))
  # Batch coordinates. The extra np.arange(1) gives every batch index array a
  # trailing length-1 axis (its own grid array is dropped), presumably so it
  # broadcasts against the nse axis of ``sparse_ind``.
  batch_slices = tuple(np.arange(s) for s in shape[:n_batch])
  grid = np.meshgrid(*batch_slices, np.arange(1), indexing='ij', sparse=True)
  batch_ind = tuple(grid)[:-1]
  if not sparse_ind:
    # No sparse dims: every stored element targets the same location, so sum
    # over the nse axis first.
    data = data.sum(n_batch, keepdims=bool(batch_ind), dtype=data.dtype)
  # .add rather than .set, so duplicate indices accumulate.
  return jnp.zeros(shape, data.dtype).at[batch_ind + sparse_ind].add(data)
@bcoo_todense_p.def_abstract_eval
def _bcoo_todense_abstract_eval(data, indices, *, spinfo):
  """Abstract evaluation: a dense value of shape ``spinfo.shape`` with ``data``'s dtype."""
  dense_shape = spinfo.shape
  _validate_bcoo(data, indices, dense_shape)
  return core.ShapedArray(dense_shape, data.dtype)
def _bcoo_todense_jvp(data_dot, data, indices, *, spinfo):
  """JVP with respect to ``data``: todense is linear in ``data``, so apply it to the tangent."""
  del data  # not needed: the map is linear in data
  return _bcoo_todense(data_dot, indices, spinfo=spinfo)
def _bcoo_todense_transpose(ct, data, indices, *, spinfo):
  """Transpose rule: the cotangent for ``data`` is ``ct`` gathered at ``indices``.

  Raises:
    ValueError: if asked to transpose with respect to ``indices``.
  """
  assert ad.is_undefined_primal(data)
  if ad.is_undefined_primal(indices):
    raise ValueError("Cannot transpose with respect to sparse indices")
  assert ct.shape == spinfo.shape
  assert ct.dtype == data.aval.dtype
  data_ct = bcoo_extract(indices, ct)
  return data_ct, indices
def _bcoo_todense_batching_rule(batched_args, batch_dims, *, spinfo):
data, indices = batched_args
if any(b not in [0, None] for b in batch_dims):
raise NotImplementedError(f"{batch_dims=}. Only 0 and None are supported.")
if batch_dims[0] is None:
data = data[None, ...]
if batch_dims[1] is None:
indices = indices[None, ...]
new_spinfo = BCOOInfo(
shape=(max(data.shape[0], indices.shape[0]), *spinfo.shape),
indices_sorted=spinfo.indices_sorted,
unique_indices=spinfo.unique_indices)
return _bcoo_todense(data, indices, spinfo=new_spinfo), 0
# Register bcoo_todense_p with JAX transformations. The JVP rule covers only
# ``data`` (``None`` = no JVP rule for ``indices``); lowering reuses the eager
# implementation via mlir.lower_fun.
ad.defjvp(bcoo_todense_p, _bcoo_todense_jvp, None)
ad.primitive_transposes[bcoo_todense_p] = _bcoo_todense_transpose
batching.primitive_batchers[bcoo_todense_p] = _bcoo_todense_batching_rule
mlir.register_lowering(bcoo_todense_p, mlir.lower_fun(
    _bcoo_todense_impl, multiple_results=False))
#--------------------------------------------------------------------
# bcoo_fromdense
# Primitive computing BCOO buffers from a dense array; it has two outputs,
# (data, indices).
bcoo_fromdense_p = core.Primitive('bcoo_fromdense')
bcoo_fromdense_p.multiple_results = True
# Context message passed to core.concrete_or_error when ``nse`` is not a
# concrete value (used by bcoo_fromdense and _bcoo_fromdense).
_TRACED_NSE_ERROR = """
The error arose for the nse argument of bcoo_fromdense. In order for
BCOO.fromdense() to be used in traced/compiled code, you must pass a concrete
value to the nse (number of stored elements) argument.
"""
def bcoo_fromdense(mat: Array, *, nse: Optional[int] = None, n_batch: int = 0,
                   n_dense: int = 0, index_dtype: DTypeLike = jnp.int32) -> BCOO:
  """Build a BCOO sparse array from a dense array.

  Args:
    mat: array to convert.
    nse: number of specified elements in each batch. If omitted it is counted
      from ``mat``; either way the value must be concrete (not traced).
    n_batch: number of batch dimensions (default 0).
    n_dense: number of dense (block) dimensions (default 0).
    index_dtype: dtype of the sparse indices (default int32).

  Returns:
    mat_bcoo: BCOO representation of ``mat``, flagged as having sorted and
    unique indices.
  """
  dense = jnp.asarray(mat)
  nse_value = _count_stored_elements(dense, n_batch, n_dense) if nse is None else nse
  nse_int = core.concrete_or_error(operator.index, nse_value, _TRACED_NSE_ERROR)
  buffers = _bcoo_fromdense(dense, nse=nse_int, n_batch=n_batch, n_dense=n_dense,
                            index_dtype=index_dtype)
  return BCOO(buffers, shape=dense.shape, indices_sorted=True, unique_indices=True)
def _bcoo_fromdense(mat: Array, *, nse: int, n_batch: int = 0, n_dense: int = 0,
                    index_dtype: DTypeLike = jnp.int32) -> Tuple[Array, Array]:
  """Compute the BCOO buffers of a dense matrix.

  Args:
    mat : array to be converted to BCOO, with ``ndim = n_batch + n_sparse + n_dense``.
    nse : number of specified elements in each batch; must be concrete.
    n_batch : number of batch dimensions (default: 0)
    n_dense : number of block_dimensions (default: 0)
    index_dtype : dtype of sparse indices (default: int32)

  Returns:
    data : array of shape ``mat.shape[:n_batch] + (nse,) + mat.shape[mat.ndim - n_dense:]``
      and dtype ``mat.dtype``
    indices : array of shape ``mat.shape[:n_batch] + (nse, n_sparse)`` and
      dtype ``index_dtype``
  """
  mat = jnp.asarray(mat)
  nse = core.concrete_or_error(operator.index, nse, _TRACED_NSE_ERROR)
  return bcoo_fromdense_p.bind(mat, nse=nse, n_batch=n_batch, n_dense=n_dense,
                               index_dtype=index_dtype)
@bcoo_fromdense_p.def_impl
def _bcoo_fromdense_impl(mat, *, nse, n_batch, n_dense, index_dtype):
  # Eager implementation of bcoo_fromdense_p (also lowered via mlir.lower_fun).
  mat = jnp.asarray(mat)
  n_sparse = mat.ndim - n_dense - n_batch
  # An element is stored if any entry of its dense block is nonzero.
  mask = (mat != 0)
  if n_dense > 0:
    mask = mask.any([-(i + 1) for i in range(n_dense)])
  def _nonzero(a):
    # Fixed-size nonzero; missing entries are padded with ``a.shape``, which is
    # out of range in every sparse dimension.
    if a.ndim:
      return jnp.nonzero(a, size=nse, fill_value=a.shape[:n_sparse])
    return ()
  # Map over batch dims so each batch gets its own ``nse`` indices.
  for _ in range(n_batch):
    _nonzero = vmap(_nonzero, 0)
  indices = _nonzero(mask)
  if not indices:
    # n_sparse == 0: the indices array has an empty last axis.
    indices = jnp.zeros(mask.shape[:n_batch] + (nse, 0), index_dtype)
  else:
    # Stack per-dimension index arrays into batch_dims + (nse, n_sparse).
    indices = jnp.moveaxis(jnp.array(indices, index_dtype), 0, n_batch + 1)
  data = bcoo_extract(indices, mat)
  # Slots at or beyond the per-batch count of true nonzeros are padding; force
  # their data to 0.
  true_nse = mask.sum(list(range(n_batch, mask.ndim)))[..., None]
  true_nonzeros = lax.broadcasted_iota(true_nse.dtype, (1,) * n_batch + (nse,), n_batch) < true_nse
  true_nonzeros = true_nonzeros[(n_batch + 1) * (slice(None),) + n_dense * (None,)]
  data = jnp.where(true_nonzeros, data, 0)
  return data, indices
@bcoo_fromdense_p.def_abstract_eval
def _bcoo_fromdense_abstract_eval(mat, *, nse, n_batch, n_dense, index_dtype):
  """Abstract evaluation.

  ``data`` is ``batch + (nse,) + dense`` and ``indices`` is
  ``batch + (nse, n_sparse)``.
  """
  batch_shape = mat.shape[:n_batch]
  n_sparse = mat.ndim - n_batch - n_dense
  dense_shape = mat.shape[n_batch + n_sparse:]
  data_aval = core.ShapedArray((*batch_shape, nse, *dense_shape), mat.dtype)
  indices_aval = core.ShapedArray((*batch_shape, nse, n_sparse), index_dtype)
  return data_aval, indices_aval
def _bcoo_fromdense_jvp(primals, tangents, *, nse, n_batch, n_dense, index_dtype):
  """JVP rule for bcoo_fromdense_p.

  The data tangent is the input tangent gathered at the primal indices. The
  indices get a symbolic zero tangent.
  """
  (M,) = primals
  (Mdot,) = tangents
  primals_out = _bcoo_fromdense(M, nse=nse, n_batch=n_batch, n_dense=n_dense, index_dtype=index_dtype)
  data, indices = primals_out
  if type(Mdot) is not ad.Zero:
    data_dot = bcoo_extract(indices, Mdot)
  else:
    data_dot = ad.Zero.from_value(data)
  return primals_out, (data_dot, ad.Zero.from_value(indices))
def _bcoo_fromdense_transpose(ct, M, *, nse, n_batch, n_dense, index_dtype):
data, indices = ct
n_sparse = M.ndim - n_batch - n_dense
assert data.shape == M.shape[:n_batch] + (nse,) + M.shape[n_batch + n_sparse:]
assert indices.shape == M.shape[:n_batch] + (n_sparse, nse)
assert indices.dtype == index_dtype
if isinstance(indices, ad.Zero):
raise ValueError("Cannot transpose with respect to sparse indices")
assert ad.is_undefined_primal(M)
return _bcoo_todense(data, indices, spinfo=BCOOInfo(M.aval.shape))
def _bcoo_fromdense_batching_rule(batched_args, batch_dims, *, nse, n_batch, n_dense, index_dtype):
M, = batched_args
if batch_dims != (0,):
raise NotImplementedError(f"{batch_dims=}")
return _bcoo_fromdense(M, nse=nse, n_batch=n_batch + 1, n_dense=n_dense, index_dtype=index_dtype), (0, 0)
# Register bcoo_fromdense_p with JAX transformations. It has a full JVP rule
# (not ad.defjvp) because it returns two outputs.
ad.primitive_jvps[bcoo_fromdense_p] = _bcoo_fromdense_jvp
ad.primitive_transposes[bcoo_fromdense_p] = _bcoo_fromdense_transpose
batching.primitive_batchers[bcoo_fromdense_p] = _bcoo_fromdense_batching_rule
mlir.register_lowering(bcoo_fromdense_p, mlir.lower_fun(
    _bcoo_fromdense_impl, multiple_results=True))
#----------------------------------------------------------------------
# bcoo_extract
# Primitive that gathers BCOO data values from a dense array at given indices.
bcoo_extract_p = core.Primitive('bcoo_extract')
def bcoo_extract(indices: Array, mat: Array) -> Array:
  """Gather BCOO data values from a dense array at the given BCOO indices.

  Args:
    indices: BCOO indices array, of shape ``batch_dims + (nse, n_sparse)``.
    mat: dense array to read from.

  Returns:
    The BCOO data array, of shape ``batch_dims + (nse,) + dense_dims``.
  """
  extracted = bcoo_extract_p.bind(indices, mat)
  return extracted
@bcoo_extract_p.def_impl
def _bcoo_extract_impl(indices, mat):
  # Eager implementation of bcoo_extract_p (also lowered via mlir.lower_fun):
  # gather entries of the dense ``mat`` at the BCOO ``indices``.
  mat = jnp.asarray(mat)
  n_batch, n_sparse, _, _ = _validate_bcoo_indices(indices, mat.shape)
  # Index grid over the batch dims of ``indices``; a size-1 (broadcast) batch
  # dim of ``indices`` always selects position 0.
  ind_slices = tuple(np.zeros(s, int) if i_s == 1 else np.arange(s)
                     for s, i_s in zip(mat.shape[:n_batch], indices.shape[:n_batch]))
  grid = tuple(np.meshgrid(*ind_slices, indexing='ij', sparse=True))
  # One coordinate array per sparse dimension.
  sparse_ind = tuple(indices[grid + (slice(None), i)] for i in range(n_sparse))
  # Batch coordinates. The extra np.arange(1) gives every batch index array a
  # trailing length-1 axis (its own grid array is dropped), presumably so it
  # broadcasts against the nse axis of ``sparse_ind``.
  batch_slices = tuple(np.arange(s) for s in mat.shape[:n_batch])
  grid = np.meshgrid(*batch_slices, np.arange(1), indexing='ij', sparse=True)
  batch_ind = tuple(grid)[:-1]
  if not sparse_ind + batch_ind:
    # No batch and no sparse dims: add a leading nse axis of length 1.
    # NOTE(review): this axis has length 1 even if indices.shape[-2] (nse) > 1,
    # whereas _bcoo_extract_abstract_eval reports nse; confirm nse <= 1 here.
    return mat[None]
  # Out-of-bounds indices (such as padding entries) read as 0.
  return mat.at[batch_ind + sparse_ind].get(mode='fill', fill_value=0)
@bcoo_extract_p.def_abstract_eval
def _bcoo_extract_abstract_eval(indices, mat):
  """Abstract evaluation: output is ``batch_dims + (nse,) + dense_dims`` with ``mat.dtype``."""
  props = _validate_bcoo_indices(indices, mat.shape)
  batch_shape = mat.shape[:props.n_batch]
  dense_shape = mat.shape[mat.ndim - props.n_dense:]
  return core.ShapedArray((*batch_shape, props.nse, *dense_shape), mat.dtype)
def _bcoo_extract_jvp(mat_dot, indices, mat):
  """JVP with respect to ``mat``: extraction is linear, so apply it to the tangent."""
  assert mat_dot.shape == mat.shape
  tangent_out = bcoo_extract(indices, mat_dot)
  return tangent_out
def _bcoo_extract_transpose(ct, indices, mat):
  """Transpose rule: scatter the data cotangent into a dense array of ``mat``'s shape.

  Raises:
    ValueError: if asked to transpose with respect to ``indices``.
  """
  assert ad.is_undefined_primal(mat)
  if ad.is_undefined_primal(indices):
    raise ValueError("Cannot transpose with respect to sparse indices")
  assert ct.dtype == mat.aval.dtype
  dense_spinfo = BCOOInfo(mat.aval.shape)
  return indices, _bcoo_todense(ct, indices, spinfo=dense_spinfo)
def _bcoo_extract_batching_rule(batched_args, batch_dims):
  """Batching rule for bcoo_extract_p.

  The vmapped axis of both operands is aligned at the same position, which
  must be a batch dimension of ``indices``. An unbatched ``indices`` gets a
  size-1 (broadcast) axis; an unbatched ``mat`` is explicitly broadcast.

  Raises:
    NotImplementedError: if both operands are batched along different axes.
    ValueError: if the aligned axis is not a batch dimension of ``indices``.
  """
  indices, mat = batched_args
  assert any(b is not None for b in batch_dims)
  if batch_dims[0] is None:
    bdim = batch_dims[1]
    indices = lax.expand_dims(indices, (bdim,))
  elif batch_dims[1] is None:
    # TODO(jakevdp) can we handle this case without explicit broadcasting?
    bdim = batch_dims[0]
    result_shape = list(mat.shape)
    result_shape.insert(bdim, indices.shape[bdim])
    # broadcast_in_dim needs one output position per operand dimension: every
    # existing axis of ``mat`` lands on the output axes other than ``bdim``.
    # Passing only ``(bdim,)`` fails for any ``mat`` with ndim >= 2.
    operand_dims = tuple(d for d in range(len(result_shape)) if d != bdim)
    mat = lax.broadcast_in_dim(mat, result_shape, operand_dims)
  else:
    if batch_dims[0] != batch_dims[1]:
      raise NotImplementedError("bcoo_extract with unequal batch dimensions.")
    bdim = batch_dims[0]
  n_batch = indices.ndim - 2
  if bdim >= n_batch:
    raise ValueError(f"{batch_dims=} out of range for indices with {n_batch=}")
  return bcoo_extract(indices, mat), bdim
# Register bcoo_extract_p with JAX transformations. The JVP rule covers only
# ``mat`` (``None`` = no JVP rule for ``indices``); lowering reuses the eager
# implementation via mlir.lower_fun.
ad.defjvp(bcoo_extract_p, None, _bcoo_extract_jvp)
ad.primitive_transposes[bcoo_extract_p] = _bcoo_extract_transpose
batching.primitive_batchers[bcoo_extract_p] = _bcoo_extract_batching_rule
mlir.register_lowering(bcoo_extract_p, mlir.lower_fun(
    _bcoo_extract_impl, multiple_results=False))
#----------------------------------------------------------------------
# bcoo_transpose
# transpose of a BCOO array
# Primitive permuting the axes of BCOO buffers; it returns (data, indices).
bcoo_transpose_p = core.Primitive('bcoo_transpose')
bcoo_transpose_p.multiple_results = True
def bcoo_transpose(mat: BCOO, *, permutation: Sequence[int]) -> BCOO:
  """Permute the axes of a BCOO-format array.

  Args:
    mat: A BCOO-format array.
    permutation: A tuple, list or ndarray containing a permutation of
      ``[0, 1, ..., N-1]``, where N is the number of axes of ``mat`` (batch,
      then sparse, then dense). Axis i of the result is axis ``permutation[i]``
      of ``mat``. Batch axes may only be permuted among themselves, and so may
      dense axes.

  Returns:
    A BCOO-format array. Only the ``unique_indices`` flag is carried over;
    ``indices_sorted`` is not passed on, presumably because permuting sparse
    axes can break the sort order.
  """
  buffers = _bcoo_transpose(mat.data, mat.indices, permutation=permutation, spinfo=mat._info)
  permuted_shape = tuple(mat.shape[axis] for axis in permutation)
  return BCOO(buffers, shape=permuted_shape, unique_indices=mat.unique_indices)
def _bcoo_transpose(data: Array, indices: Array, *,
permutation: Sequence[int], spinfo: BCOOInfo) -> Tuple[Array, Array]:
permutation = tuple(permutation)
if permutation == tuple(range(len(spinfo.shape))):
return data, indices
else:
return bcoo_transpose_p.bind(data, indices, permutation=permutation,
spinfo=spinfo)
def _validate_permutation(data, indices, permutation, shape):
if not isinstance(permutation, (tuple, list, np.ndarray)):
raise TypeError(f"transpose permutation must be a tuple/list/ndarray, got {type(permutation)}.")
if tuple(sorted(permutation)) != tuple(range(len(shape))):
raise TypeError("transpose permutation isn't a permutation of operand dimensions, "
f"got permutation {permutation} for shape {shape}.")
n_batch, n_sparse, n_dense, _ = _validate_bcoo(data, indices, shape)
batch_perm = permutation[:n_batch]
sparse_perm = [p - n_batch for p in permutation[n_batch: n_batch + n_sparse]]
dense_perm = [p - n_sparse - n_batch for p in permutation[n_batch + n_sparse:]]
if n_batch and tuple(sorted(batch_perm)) != tuple(range(n_batch)):
raise NotImplementedError("transpose permutation cannot permute batch axes with non-batch axes; "
f"got permutation {permutation}, with {n_batch=}.")
if n_dense and tuple(sorted(dense_perm)) != tuple(range(n_dense)):
raise NotImplementedError("transpose permutation cannot permute dense axes with non-dense axes; "
f"got permutation {permutation}, with {n_dense=}.")
return batch_perm, sparse_perm, dense_perm
@bcoo_transpose_p.def_impl
def _bcoo_transpose_impl(data, indices, *, permutation: Sequence[int], spinfo: BCOOInfo):
  """Eager implementation of bcoo_transpose_p.

  Reorders the sparse index columns and transposes the batch and dense axes;
  the nse axis stays in place.
  """
  batch_perm, sparse_perm, dense_perm = _validate_permutation(data, indices, permutation, spinfo.shape)
  nse_axis = len(batch_perm)
  indices_axes = (*batch_perm, nse_axis, nse_axis + 1)
  data_axes = (*batch_perm, nse_axis, *(nse_axis + 1 + d for d in dense_perm))
  new_indices = indices[..., sparse_perm].transpose(*indices_axes)
  new_data = data.transpose(*data_axes)
  return new_data, new_indices
@bcoo_transpose_p.def_abstract_eval
def _bcoo_transpose_abstract_eval(data, indices, *, permutation: Sequence[int], spinfo: BCOOInfo):
  """Abstract evaluation: permute the buffer shapes as _bcoo_transpose_impl does."""
  batch_perm, _, dense_perm = _validate_permutation(data, indices, permutation, spinfo.shape)
  n_batch = len(batch_perm)
  indices_order = [*batch_perm, n_batch, n_batch + 1]
  data_order = [*batch_perm, n_batch, *(d + n_batch + 1 for d in dense_perm)]
  indices_shape = np.array(indices.shape)[indices_order]
  data_shape = np.array(data.shape)[data_order]
  data_aval = core.ShapedArray(data_shape, data.dtype)
  indices_aval = core.ShapedArray(indices_shape, indices.dtype)
  return data_aval, indices_aval
def _bcoo_transpose_jvp(primals, tangents, *, permutation: Sequence[int], spinfo: BCOOInfo):
  """JVP rule for bcoo_transpose_p.

  The data tangent is permuted like the primal data; the indices get a
  symbolic zero tangent.
  """
  data, indices = primals
  data_dot, _ = tangents
  primals_out = _bcoo_transpose(data, indices, permutation=permutation, spinfo=spinfo)
  permuted_tangent, _ = _bcoo_transpose(data_dot, indices, permutation=permutation, spinfo=spinfo)
  indices_tangent = ad.Zero.from_value(indices)
  return primals_out, (permuted_tangent, indices_tangent)
def _bcoo_transpose_transpose(ct, data, indices, *, permutation: Sequence[int], spinfo: BCOOInfo):
  """Transpose rule: apply the inverse permutation to the data cotangent.

  Raises:
    ValueError: if asked to transpose with respect to ``indices``.
  """
  data_ct, indices_ct = ct
  assert isinstance(indices_ct, ad.Zero)
  if ad.is_undefined_primal(indices):
    raise ValueError("Cannot transpose with respect to sparse indices")
  assert data_ct.dtype == data.aval.dtype
  permuted_shape = tuple(spinfo.shape[p] for p in permutation)
  ct_spinfo = BCOOInfo(permuted_shape)
  inverse_permutation = list(np.argsort(permutation))
  # TODO(jakevdp) avoid dummy indices?
  # The transposed dummy indices are discarded; they use size-1 batch dims and
  # the same trailing (nse, n_sparse) shape as ``indices``.
  dummy_shape = [1] * (indices.ndim - 2) + list(indices.shape[-2:])
  dummy_indices = jnp.zeros(dummy_shape, dtype=int)
  data_trans, _ = _bcoo_transpose(data_ct, dummy_indices, permutation=inverse_permutation, spinfo=ct_spinfo)
  return data_trans, indices_ct
def _bcoo_transpose_batch_rule(batched_args, batch_dims, *, permutation: Sequence[int], spinfo: BCOOInfo):
  """Batching rule: prepend the vmapped axis as a fixed leading batch dimension.

  Unbatched operands get a temporary size-1 leading axis, which is removed
  again after the transpose.
  """
  data, indices = batched_args
  batch_dims = list(batch_dims)
  batch_size = max(0 if dim is None else arg.shape[dim]
                   for arg, dim in zip(batched_args, batch_dims))
  data_unbatched = batch_dims[0] is None
  indices_unbatched = batch_dims[1] is None
  if data_unbatched:
    data = data[None]
  else:
    assert batch_dims[0] == 0
  if indices_unbatched:
    indices = indices[None]
  else:
    assert batch_dims[1] == 0
  batched_spinfo = BCOOInfo((batch_size, *spinfo.shape))
  batched_permutation = (0, *(axis + 1 for axis in permutation))
  data, indices = _bcoo_transpose(data, indices, permutation=batched_permutation, spinfo=batched_spinfo)
  if data_unbatched:
    data = data[0]
  if indices_unbatched:
    indices = indices[0]
  return (data, indices), batch_dims
# Register bcoo_transpose_p with JAX transformations. It has a full JVP rule
# because it returns two outputs; lowering reuses the eager implementation.
ad.primitive_jvps[bcoo_transpose_p] = _bcoo_transpose_jvp
ad.primitive_transposes[bcoo_transpose_p] = _bcoo_transpose_transpose
batching.primitive_batchers[bcoo_transpose_p] = _bcoo_transpose_batch_rule
mlir.register_lowering(bcoo_transpose_p, mlir.lower_fun(
    _bcoo_transpose_impl, multiple_results=True))
#----------------------------------------------------------------------
# bcoo_dot_general
# (batched) general dot product of a BCOO sparse ND array and a dense ND array,
# returning a dense ND array.
# Primitive for dot_general with a sparse (BCOO) lhs and a dense rhs.
bcoo_dot_general_p = core.Primitive('bcoo_dot_general')
def bcoo_dot_general(lhs: Union[BCOO, Array], rhs: Union[BCOO, Array], *, dimension_numbers: DotDimensionNumbers,
                     precision: None = None, preferred_element_type: None = None) -> Union[BCOO, Array]:
  """A general contraction operation.

  Args:
    lhs: An ndarray or BCOO-format sparse array.
    rhs: An ndarray or BCOO-format sparse array.
    dimension_numbers: a tuple of tuples of the form
      `((lhs_contracting_dims, rhs_contracting_dims),
      (lhs_batch_dims, rhs_batch_dims))`.
    precision: unused
    preferred_element_type: unused

  Returns:
    An ndarray or BCOO-format sparse array containing the result. If both inputs
    are sparse, the result will be sparse, of type BCOO. If either input is dense,
    the result will be dense, of type ndarray.
  """
  # TODO(jakevdp) make use of these?
  del precision, preferred_element_type  # unused
  if isinstance(lhs, BCOO) and isinstance(rhs, BCOO):
    shape = _dot_general_validated_shape(lhs.shape, rhs.shape,
                                         dimension_numbers)
    bufs = _bcoo_spdot_general(lhs.data, lhs.indices, rhs.data, rhs.indices,
                               lhs_spinfo=lhs._info, rhs_spinfo=rhs._info,
                               dimension_numbers=dimension_numbers)
    return BCOO(bufs, shape=shape)
  elif isinstance(lhs, BCOO):
    return _bcoo_dot_general(lhs.data, lhs.indices, rhs, dimension_numbers=dimension_numbers,  # type: ignore[arg-type]
                             lhs_spinfo=lhs._info)
  elif isinstance(rhs, BCOO):
    return _bcoo_rdot_general(lhs, rhs.data, rhs.indices, dimension_numbers=dimension_numbers,  # type: ignore[arg-type]
                              rhs_spinfo=rhs._info)
  else:
    return lax.dot_general(lhs, rhs, dimension_numbers=dimension_numbers)
def _bcoo_dot_general(lhs_data: Array, lhs_indices: Array, rhs: Array, *,
                      dimension_numbers: DotDimensionNumbers, lhs_spinfo: BCOOInfo) -> Array:
  """Canonicalize ``dimension_numbers`` to index tuples and bind bcoo_dot_general_p."""
  (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
  as_index_tuple = api_util._ensure_index_tuple
  cdims = (as_index_tuple(lhs_contract), as_index_tuple(rhs_contract))
  bdims = (as_index_tuple(lhs_batch), as_index_tuple(rhs_batch))
  operands = (jnp.asarray(lhs_data), jnp.asarray(lhs_indices), jnp.asarray(rhs))
  return bcoo_dot_general_p.bind(*operands,
                                 dimension_numbers=(cdims, bdims),
                                 lhs_spinfo=lhs_spinfo)
def _bcoo_rdot_general(lhs: Array, rhs_data: Array, rhs_indices: Array, *,
                       dimension_numbers: DotDimensionNumbers, rhs_spinfo: BCOOInfo) -> Array:
  """Compute ``dot_general(dense lhs, sparse rhs)`` via the sparse-lhs primitive.

  The operands are swapped and the product is computed with the sparse array
  on the left. The result axes are then transposed back to lax.dot_general
  order: batch, lhs free, rhs free.
  """
  # TODO(jakevdp): perhaps this should be part of the bcoo_dot_general primitive?
  swapped_dimension_numbers: DotDimensionNumbers = tuple(pair[::-1] for pair in dimension_numbers)  # type: ignore[assignment]
  swapped = _bcoo_dot_general(rhs_data, rhs_indices, lhs, lhs_spinfo=rhs_spinfo,
                              dimension_numbers=swapped_dimension_numbers)
  contracting, batch = dimension_numbers
  n_contract = len(contracting[0])
  n_batch = len(batch[0])
  # Axes of ``swapped``: batch, then sparse-rhs free (ending at n_swap), then dense-lhs free.
  n_swap = len(rhs_spinfo.shape) - n_contract
  permutation = (*range(n_batch), *range(n_swap, swapped.ndim), *range(n_batch, n_swap))
  return lax.transpose(swapped, permutation)
@bcoo_dot_general_p.def_impl
def _bcoo_dot_general_impl(lhs_data, lhs_indices, rhs, *, dimension_numbers, lhs_spinfo: BCOOInfo):
  # Eager implementation of bcoo_dot_general_p (sparse lhs, dense rhs); also
  # used as the default lowering via mlir.lower_fun.
  lhs_data = jnp.asarray(lhs_data)
  lhs_indices = jnp.asarray(lhs_indices)
  rhs = jnp.asarray(rhs)
  # Validate all inputs via abstract_eval
  out_aval = _bcoo_dot_general_abstract_eval(lhs_data.aval, lhs_indices.aval, rhs.aval,
                                             dimension_numbers=dimension_numbers,
                                             lhs_spinfo=lhs_spinfo)
  n_sparse = lhs_indices.shape[-1]
  n_batch = lhs_indices.ndim - 2
  (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
  # Split the contracting dims into those over lhs batch dims (suffix _b) and
  # those over lhs sparse dims (suffix _s).
  lhs_contracting_b, rhs_contracting_b = unzip2([
    (l, r) for l, r in safe_zip(lhs_contracting, rhs_contracting) if l < n_batch])
  lhs_contracting_s, rhs_contracting_s = unzip2([
    (l, r) for l, r in safe_zip(lhs_contracting, rhs_contracting) if l >= n_batch])
  # Reorder lhs batch dimensions
  # (dot batch dims first, then the remaining batch dims, then contracted batch dims).
  if lhs_batch or lhs_contracting_b:
    batch_perm = [*lhs_batch, *remaining(range(n_batch), lhs_batch, lhs_contracting_b), *lhs_contracting_b]
    lhs_data = lhs_data.transpose([*batch_perm, *range(n_batch, lhs_data.ndim)])
    lhs_indices = lhs_indices.transpose([*batch_perm, *range(n_batch, lhs_indices.ndim)])
  # Reorder lhs sparse dimensions
  # (contracted sparse dims first, made relative to the first sparse dim).
  if lhs_contracting_s:
    lhs_contracting_s = tuple(d - n_batch for d in lhs_contracting_s)
    sparse_perm = jnp.array([*lhs_contracting_s, *remaining(range(n_sparse), lhs_contracting_s)])
    lhs_indices = lhs_indices[..., sparse_perm]
  # Reorder rhs dimensions
  # (batch, contracted against lhs batch, contracted against lhs sparse, free).
  rhs_perm = [*rhs_batch, *rhs_contracting_b, *rhs_contracting_s,
              *remaining(range(rhs.ndim), rhs_batch, rhs_contracting)]
  rhs = rhs.transpose(rhs_perm)
  def result(out_array, lhs_data, lhs_indices, rhs):
    # Gather rhs at the contracted sparse coordinates, multiply with the
    # stored data, then scatter-add into the output at the free sparse
    # coordinates (or reduce if there are none).
    idx = tuple(lhs_indices[..., i] for i in range(n_sparse))
    idx_right = idx[:len(lhs_contracting_s)]  # coordinates into rhs contracted dims
    idx_out = idx[len(lhs_contracting_s):]    # coordinates into output sparse dims
    if idx_right and lhs_indices.ndim > 2:
      # Prepend coordinates for the remaining (contracted) batch dims of
      # lhs_indices so they also index rhs.
      idx_batch = jnp.meshgrid(
          *(jnp.arange(n) for n in lhs_indices.shape[:-1]),
          indexing='ij')[:lhs_indices.ndim - 2]
      idx_right = (*idx_batch, *idx_right)
    batch_dims = list(range(len(lhs_contracting_b) + bool(lhs_contracting_s)))
    # Out-of-range (e.g. padding) coordinates read 0 from rhs.
    prod = lax.dot_general(lhs_data, rhs.at[idx_right].get(mode='fill', fill_value=0),
                           (([], []), (batch_dims, batch_dims)))
    if idx_out:
      return out_array.at[idx_out].add(prod)
    else:
      # No free sparse dims: sum away the leading axes of prod that out_array lacks.
      return prod.sum(tuple(range(prod.ndim - out_array.ndim)), dtype=out_array.dtype)
  # Map over the lhs batch dims that are not contracted.
  for _ in range(n_batch - len(lhs_contracting_b)):
    result = _broadcasting_vmap(result)
  # Insert size-1 axes into rhs for the non-dot lhs batch dims so it
  # broadcasts under the vmaps above.
  rhs = lax.expand_dims(rhs, range(len(rhs_batch), n_batch - len(lhs_contracting_b)))
  out_array = jnp.zeros(out_aval.shape, out_aval.dtype)
  return result(out_array, lhs_data, lhs_indices, rhs)
@bcoo_dot_general_p.def_abstract_eval
def _bcoo_dot_general_abstract_eval(lhs_data, lhs_indices, rhs, *, dimension_numbers, lhs_spinfo: BCOOInfo):
  """Abstract evaluation rule for ``bcoo_dot_general_p``.

  Checks that the operands are compatible and returns the abstract output
  value: a ``core.ShapedArray`` with the validated dot_general output shape and
  the dtype of ``lhs_data``.

  Raises:
    ValueError: if ``lhs_data`` and ``rhs`` have different dtypes.
    NotImplementedError: if an lhs batch dimension is not one of the batch
      dimensions of the BCOO representation, or if an lhs contracting
      dimension is one of its dense dimensions.
  """
  if lhs_data.dtype != rhs.dtype:
    raise ValueError("bcoo_dot_general requires arguments to have matching dtypes; "
                     f"got lhs.dtype={lhs_data.dtype}, rhs.dtype={rhs.dtype}")

  (lhs_contracting, _), (lhs_batch, _) = dimension_numbers
  n_batch, n_sparse, _, _ = _validate_bcoo(lhs_data, lhs_indices, lhs_spinfo.shape)
  out_shape = _dot_general_validated_shape(lhs_spinfo.shape, rhs.shape,
                                           dimension_numbers)

  # dot_general batch dimensions of the lhs must be BCOO batch dimensions
  # (dimensions 0 .. n_batch-1 of the logical shape).
  if lhs_batch and max(lhs_batch) >= n_batch:
    raise NotImplementedError(
      "bcoo_dot_general batch dimensions must be among the batch dimensions in the sparse representation.\n"
      f"got {lhs_batch=}, {n_batch=}")

  # Dimensions at index >= n_batch + n_sparse are dense dimensions.
  # TODO: support contraction of dense dimensions?
  if any(d >= n_batch + n_sparse for d in lhs_contracting):
    raise NotImplementedError("bcoo_dot_general: contracting over dense dimensions.")

  return core.ShapedArray(out_shape, lhs_data.dtype)
# Generic (non-cusparse) lowering rule, built from the pure-JAX implementation
# `_bcoo_dot_general_impl` via `mlir.lower_fun`. The GPU lowering below falls
# back to this whenever its specialized path does not apply.
_bcoo_dot_general_default_lowering = mlir.lower_fun(
    _bcoo_dot_general_impl,
    multiple_results=False,
)
def _collapse_mhlo(x, start, end):
  """Merge dimensions ``start`` through ``end`` (inclusive) of ``x`` into one.

  Emits an ``mhlo.ReshapeOp`` whose result has the same element type as ``x``.
  The merged dimension's size is the product of the merged dimensions.
  All other dimensions are unchanged.
  """
  tensor_type = ir.RankedTensorType(x.type)
  dims = tensor_type.shape
  merged_size = functools.reduce(operator.mul, dims[start:end + 1])
  new_dims = dims[:start] + [merged_size] + dims[end + 1:]
  result_type = ir.RankedTensorType.get(new_dims, tensor_type.element_type)
  return mhlo.ReshapeOp(result_type, x).result
def _bcoo_dot_general_cuda_lowering(
    coo_matvec_lowering, coo_matmat_lowering, ctx, lhs_data, lhs_indices, rhs,
    *, dimension_numbers, lhs_spinfo: BCOOInfo):
  """Lower bcoo_dot_general to a COO matvec/matmat lowering callback.

  The BCOO indices are split into separate ``row`` / ``col`` index vectors.
  ``coo_matvec_lowering`` is used when rhs is 1-D, and ``coo_matmat_lowering``
  when it is 2-D.

  The supported configurations are enforced with ``assert``s below, so callers
  must check eligibility before calling this:
    * no dense dimensions;
    * (n_batch, n_sparse, rhs.ndim) among the listed combinations;
    * no dot_general batch dimensions;
    * exactly one contracting dimension;
    * float32/float64/complex64/complex128 data matching the rhs dtype;
    * int32 indices.

  Returns:
    A single-element list holding the lowered result value.

  Raises:
    ValueError: for an rhs rank, sparse rank, or n_batch outside the handled
      cases.
  """
  (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
  lhs_data_aval, lhs_indices_aval, rhs_aval, = ctx.avals_in
  props = _validate_bcoo_indices(lhs_indices_aval, lhs_spinfo.shape)

  rhs_ndim = len(ir.RankedTensorType(rhs.type).shape)

  # Checks the shapes of lhs and rhs.
  assert props.n_dense == 0
  assert (props.n_batch, props.n_sparse, rhs_ndim) in [
      (0, 1, 1), (0, 1, 2), (0, 2, 1), (0, 2, 2), (1, 2, 2)]

  # Checks the operation dimensions.
  assert len(lhs_batch) == 0
  assert len(rhs_batch) == 0
  assert len(lhs_contract) == 1

  # Checks the dtype.
  assert lhs_data_aval.dtype in [np.float32, np.float64, np.complex64,
                                 np.complex128]
  assert lhs_data_aval.dtype == rhs_aval.dtype
  assert lhs_indices_aval.dtype == np.int32

  if rhs_ndim == 1:
    bcoo_dot_general_fn = coo_matvec_lowering
  elif rhs_ndim == 2:
    bcoo_dot_general_fn = coo_matmat_lowering
    # Transpose a 2-D rhs so that its contracting dimension becomes dim 0.
    if rhs_contract[0] == 1:
      rhs = mhlo.TransposeOp(
          rhs, permutation=mlir.dense_int_elements([1, 0])).result
  else:
    raise ValueError(f"rhs has to be 1d or 2d; get {rhs_ndim}d.")

  if props.n_batch == 0:
    # non-batch mode.
    lhs_transpose = False
    if props.n_sparse == 1:
      # Converts lhs to a row vector: every nonzero gets row index 0, and the
      # flattened sparse indices serve as column indices of a
      # (1, lhs_spinfo.shape[0]) matrix.
      col = _collapse_mhlo(lhs_indices, start=0, end=1)
      row = mlir.full_like_aval(
          ctx, 0, core.ShapedArray(ir.RankedTensorType(col.type).shape,
                                   np.dtype(np.int32)))
      lhs_shape = (1, lhs_spinfo.shape[0])
      dot_product = bcoo_dot_general_fn(
          lhs_data, row, col, rhs, shape=lhs_shape, transpose=lhs_transpose,
          data_dtype=lhs_data_aval.dtype, index_dtype=lhs_indices_aval.dtype,
          x_dtype=rhs_aval.dtype)

      # Strip the artificial leading row dimension introduced above.
      if rhs_ndim == 1:
        # Transforms a single-element array to a scalar.
        return [mhlo.ReshapeOp(
            ir.RankedTensorType.get(
                [], ir.RankedTensorType(dot_product.type).element_type),
            dot_product).result]
      else:
        return [_collapse_mhlo(dot_product, start=0, end=1)]
    elif props.n_sparse == 2:
      # Split the indices into a row vector (column 0 of lhs_indices) and a
      # col vector (column 1), each flattened to 1-D.
      lhs_indices_shape = ir.RankedTensorType(lhs_indices.type).shape
      row = _collapse_mhlo(
          mhlo.SliceOp(
              lhs_indices,
              start_indices=mlir.dense_int_elements([0, 0]),
              limit_indices=mlir.dense_int_elements([lhs_indices_shape[0], 1]),
              strides=mlir.dense_int_elements([1, 1])).result,
          start=0, end=1)
      col = _collapse_mhlo(
          mhlo.SliceOp(
              lhs_indices,
              start_indices=mlir.dense_int_elements([0, 1]),
              limit_indices=mlir.dense_int_elements([lhs_indices_shape[0], 2]),
              strides=mlir.dense_int_elements([1, 1])).result,
          start=0, end=1)

      # Contracting over lhs dimension 0 (the rows) means computing with the
      # transposed lhs matrix.
      if lhs_contract[0] == 0:
        lhs_transpose = True

      return [bcoo_dot_general_fn(
          lhs_data, row, col, rhs, shape=lhs_spinfo.shape,
          transpose=lhs_transpose, data_dtype=lhs_data_aval.dtype,
          index_dtype=lhs_indices_aval.dtype,
          x_dtype=rhs_aval.dtype)]
    else:
      raise ValueError(f"lhs has to be 1d or 2d; get {props.n_sparse}d.")
  elif props.n_batch == 1:
    # batch mode.
    # lhs_indices is rank 3 here (unpacked as batch_count, _, _ below). The
    # batch axis is folded into the nonzero axis so that row/col become flat
    # 1-D vectors over all batches.
    lhs_indices_shape = ir.RankedTensorType(lhs_indices.type).shape
    lhs_data_shape = ir.RankedTensorType(lhs_data.type).shape
    batch_count, _, _ = lhs_indices_shape
    rhs_shape = ir.RankedTensorType(rhs.type).shape

    # Squeeze the batch dimension for both indices and data.
    lhs_indices_2d_shape = (np.prod(np.array(lhs_indices_shape)[:-1]),
                            lhs_indices_shape[-1])
    lhs_data_1d_shape = (np.prod(np.array(lhs_data_shape)), )

    lhs_indices_2d = mhlo.ReshapeOp(
        ir.RankedTensorType.get(
            lhs_indices_2d_shape,
            ir.RankedTensorType(lhs_indices.type).element_type),
        lhs_indices).result

    lhs_data_1d = mhlo.ReshapeOp(
        ir.RankedTensorType.get(
            lhs_data_1d_shape,
            ir.RankedTensorType(lhs_data.type).element_type),
        lhs_data).result

    row = _collapse_mhlo(
        mhlo.SliceOp(
            lhs_indices_2d,
            start_indices=mlir.dense_int_elements([0, 0]),
            limit_indices=mlir.dense_int_elements([lhs_indices_2d_shape[0], 1]),
            strides=mlir.dense_int_elements([1, 1])).result,
        start=0, end=1)

    col = _collapse_mhlo(
        mhlo.SliceOp(
            lhs_indices_2d,
            start_indices=mlir.dense_int_elements([0, 1]),
            limit_indices=mlir.dense_int_elements([lhs_indices_2d_shape[0], 2]),
            strides=mlir.dense_int_elements([1, 1])).result,
        start=0, end=1)

    # Broadcast rhs to have the same batch size as lhs.
    # TODO(tianjianlu): remove broadcasting.
    # Use batch_stride = 0 for non-batch.
    # The issue (https://github.com/NVIDIA/CUDALibrarySamples/issues/81#issuecomment-1205562643)
    # in cusparse library does not allow batch_stride = 0 for a non-batched rhs.
    # (broadcast_dimensions=[1, 2] maps the 2-D rhs onto the trailing two axes.)
    batched_rhs_shape = (batch_count,) + tuple(rhs_shape)
    batched_rhs = mhlo.BroadcastInDimOp(
        ir.RankedTensorType.get(batched_rhs_shape,
                                ir.RankedTensorType(rhs.type).element_type),
        rhs,
        broadcast_dimensions=mlir.dense_int_elements([1, 2])).result
    # Flatten the broadcast rhs to 2-D by merging all leading dimensions.
    batched_rhs_2d_shape = (np.prod(np.array(batched_rhs_shape)[:-1]), batched_rhs_shape[-1])
    batched_rhs_2d = mhlo.ReshapeOp(
        ir.RankedTensorType.get(
            batched_rhs_2d_shape,
            ir.RankedTensorType(batched_rhs.type).element_type),
        batched_rhs).result

    # Contracting over the first sparse dimension (the rows) means using the
    # transposed lhs matrix.
    lhs_transpose = True if lhs_contract[0] == props.n_batch else False

    # NOTE(review): the full 3-D lhs_spinfo.shape is passed as `shape`; the
    # callback presumably derives the batch count from it -- confirm.
    return [bcoo_dot_general_fn(
        lhs_data_1d, row, col, batched_rhs_2d, shape=lhs_spinfo.shape,
        transpose=lhs_transpose, data_dtype=lhs_data_aval.dtype,
        index_dtype=lhs_indices_aval.dtype,
        x_dtype=rhs_aval.dtype)]
  else:
    raise ValueError(f"n_batch has to be 0 or 1; get {props.n_batch}.")
def _bcoo_dot_general_gpu_lowering(
    coo_matvec_lowering, coo_matmat_lowering,
    ctx, lhs_data, lhs_indices, rhs, *, dimension_numbers,
    lhs_spinfo: BCOOInfo):
  """GPU lowering rule for bcoo_dot_general.

  Uses ``_bcoo_dot_general_cuda_lowering`` with the given matvec/matmat
  lowerings when the operands fall within what that path supports. Otherwise
  it uses ``_bcoo_dot_general_default_lowering``. Several of the fallback
  cases emit a ``CuSparseEfficiencyWarning`` explaining why.
  """
  def fallback():
    # Route to the generic lowering with the original operands.
    return _bcoo_dot_general_default_lowering(
        ctx, lhs_data, lhs_indices, rhs,
        dimension_numbers=dimension_numbers, lhs_spinfo=lhs_spinfo)

  if not config.jax_bcoo_cusparse_lowering:
    return fallback()

  (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
  lhs_data_aval, lhs_indices_aval, rhs_aval = ctx.avals_in
  n_batch, n_sparse, n_dense, _ = _validate_bcoo(
      lhs_data_aval, lhs_indices_aval, lhs_spinfo.shape)

  dtype = lhs_data_aval.dtype
  if dtype not in [np.float32, np.float64, np.complex64, np.complex128]:
    warnings.warn(f'bcoo_dot_general cusparse/hipsparse lowering not available '
                  f'for {dtype=}. Falling back to default implementation.',
                  CuSparseEfficiencyWarning)
    return fallback()

  # Layouts the specialized path cannot handle fall back without a warning.
  layout_unsupported = (n_batch > 1 or n_dense or
                        n_sparse not in [1, 2] or rhs_aval.ndim not in [1, 2] or
                        lhs_batch or rhs_batch or len(lhs_contract) != 1)
  if layout_unsupported:
    return fallback()

  if not lhs_spinfo.indices_sorted:
    warnings.warn("bcoo_dot_general GPU lowering requires matrices with "
                  "sorted indices. To sort the rows in your matrix, use e.g. "
                  "mat = mat.sort_indices(). Falling back to the default "
                  "implementation.", CuSparseEfficiencyWarning)
    return fallback()

  if n_batch == 1:
    # The support for batched computation in cusparseSpMM COO was added in
    # 11.6.1: https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html#cusparse-11.6.1
    platform_cuda_version = int(
        xla_bridge.get_backend().platform_version.split()[-1])
    # TODO(tianjianlu): enable the batch mode of cusparseSpMv.
    batch_mode_ok = (
        n_sparse == 2 and rhs_aval.ndim == 2 and
        len(lhs_contract) == 1 and lhs_contract[0] in [1, 2] and
        len(rhs_contract) == 1 and rhs_contract[0] in [0, 1] and
        platform_cuda_version >= 11061)
    if not batch_mode_ok:
      warnings.warn("bcoo_dot_general GPU lowering currently does not "
                    "support this batch-mode computation. Falling back to "
                    "the default implementation.", CuSparseEfficiencyWarning)
      return fallback()

  return _bcoo_dot_general_cuda_lowering(
      coo_matvec_lowering, coo_matmat_lowering, ctx, lhs_data, lhs_indices, rhs,
      dimension_numbers=dimension_numbers, lhs_spinfo=lhs_spinfo)
def _bcoo_dot_general_jvp_lhs(lhs_data_dot, lhs_data, lhs_indices, rhs, *, dimension_numbers, lhs_spinfo: BCOOInfo):
  """JVP rule with respect to ``lhs_data``.

  The tangent is the same sparse-dense product with ``lhs_data_dot`` used in
  place of the primal data. The sparsity pattern and ``rhs`` are unchanged.
  """
  del lhs_data  # The primal data does not enter the tangent.
  return _bcoo_dot_general(
      lhs_data_dot, lhs_indices, rhs,
      lhs_spinfo=lhs_spinfo, dimension_numbers=dimension_numbers)
def _bcoo_dot_general_jvp_rhs(rhs_dot, lhs_data, lhs_indices, rhs, *, dimension_numbers, lhs_spinfo: BCOOInfo):
  """JVP rule with respect to ``rhs``.

  The tangent is the same sparse-dense product applied to ``rhs_dot`` instead
  of ``rhs``.
  """
  del rhs  # Only the tangent of rhs enters the result.
  return _bcoo_dot_general(
      lhs_data, lhs_indices, rhs_dot,
      lhs_spinfo=lhs_spinfo, dimension_numbers=dimension_numbers)
def _bcoo_dot_general_transpose(ct, lhs_data, lhs_indices, rhs, *, dimension_numbers, lhs_spinfo: BCOOInfo):
  """Transpose rule for bcoo_dot_general.

  Handles the case where either ``lhs_data`` or ``rhs`` is the undefined
  (linear) primal; ``lhs_indices`` must always be known.

  * When ``lhs_data`` is undefined, its cotangent is computed by sampling
    ``ct`` contracted with ``rhs`` at the lhs sparsity pattern
    (``bcoo_dot_general_sampled``).
  * Otherwise, the rhs cotangent is a sparse-dense product of the lhs with
    ``ct``, transposed back to the rhs axis order.

  In both cases the remaining operands are returned unchanged alongside the
  computed cotangent.
  """
  assert not ad.is_undefined_primal(lhs_indices)
  # NOTE(review): this returns the `ad.Zero` class itself rather than
  # per-operand zero cotangents; confirm this is what the caller expects.
  if type(ct) is ad.Zero:
    return ad.Zero
  (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
  lhs_ndim = len(lhs_spinfo.shape)
  rhs_ndim = rhs.aval.ndim if ad.is_undefined_primal(rhs) else rhs.ndim
  # Free (non-contracting, non-batch) dimensions of each operand.
  lhs_kept = remaining(range(lhs_ndim), lhs_contract, lhs_batch)
  rhs_kept = remaining(range(rhs_ndim), rhs_contract, rhs_batch)
  # Positions of the batch / lhs-free / rhs-free axes within the forward
  # output, and hence within ct. Presumably ranges_like returns consecutive
  # ranges with the lengths of its arguments; confirm.
  ans_batch, ans_lhs, ans_rhs = map(list, ranges_like(lhs_batch, lhs_kept, rhs_kept))
  if ad.is_undefined_primal(lhs_data):
    # Contract ct's rhs-free axes with rhs's free dims, batching over the batch axes.
    dims: DotDimensionNumbers = ((ans_rhs, rhs_kept), (ans_batch, rhs_batch))  # type: ignore[assignment]
    lhs_contract_sorted_by_rhs = list(np.take(lhs_contract, np.argsort(rhs_contract)))
    # Axis order of the dense product relative to the lhs layout, and its inverse.
    permutation = list(lhs_batch) + lhs_kept + lhs_contract_sorted_by_rhs
    out_axes = list(np.argsort(permutation))

    # What follows is essentially this, but computed in terms of dot_general_sampled:
    # out_dense_T = lax.dot_general(ct, rhs, dimension_numbers=dims)
    # out_dense = lax.transpose(out_dense_T, out_axes)
    # result = bcoo_extract(lhs_indices, out_dense)
    # Instead we (1) un-transpose indices, (2) compute SDDMM, (3) re-transpose result
    # The dummy data/spinfo exist only so _bcoo_transpose can permute the
    # indices; the dummy data output of the first call is discarded.
    dummy_data = jnp.ones([1 for i in range(lhs_indices.ndim - 2)] + [lhs_indices.shape[-2]])
    dummy_spinfo = BCOOInfo(tuple(lhs_indices.shape[:-2]) + tuple(1 for i in range(lhs_indices.shape[-1])))
    _, lhs_indices_T = _bcoo_transpose(dummy_data, lhs_indices, permutation=permutation, spinfo=dummy_spinfo)
    result_T = bcoo_dot_general_sampled(ct, rhs, lhs_indices_T, dimension_numbers=dims)
    result, _ = _bcoo_transpose(result_T, lhs_indices_T, permutation=out_axes, spinfo=dummy_spinfo)

    return result, lhs_indices, rhs
  else:
    # Contract the lhs free dims with ct's lhs-free axes, batching over batch axes.
    dims = ((lhs_kept, ans_lhs), (lhs_batch, ans_batch))  # type: ignore[assignment]
    rhs_contract_sorted_by_lhs = list(np.take(rhs_contract, np.argsort(lhs_contract)))
    # Inverse permutation that restores the rhs axis order on the product.
    out_axes = list(np.argsort(list(rhs_batch) + rhs_contract_sorted_by_lhs + rhs_kept))
    result = _bcoo_dot_general(lhs_data, lhs_indices, ct, lhs_spinfo=lhs_spinfo, dimension_numbers=dims)
    return lhs_data, lhs_indices, lax.transpose(result, out_axes)
def _bcoo_dot_general_batch_rule(batched_args, batch_dims, *, dimension_numbers, lhs_spinfo: BCOOInfo):
lhs_data, lhs_indices, rhs = batched_args
batch_dims = list(batch_dims)
batch_size = max(0 if dim is None else arg.shape[dim]
for arg, dim in zip(batched_args, batch_dims))
if batch_dims[0] is None:
lhs_data = lhs_data[None]
batch_dims[0] = 0
if batch_dims[1] is None: