-
Notifications
You must be signed in to change notification settings - Fork 21.3k
/
jit_utils.cpp
1660 lines (1467 loc) · 54.7 KB
/
jit_utils.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
#define TORCH_ASSERT_NO_OPERATORS
#include <c10/core/ScalarType.h>
#include <c10/util/irange.h>
#include <c10/util/hash.h>
#include <c10/util/Optional.h>
#include <ATen/jit_macros.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/OffsetCalculator.cuh>
#include <ATen/cuda/nvrtc_stub/ATenNVRTC.h>
#include <ATen/code_template.h>
#include <ATen/OpMathType.h>
#include <ATen/native/cuda/jit_utils.h>
#include <ATen/cuda/llvm_jit_strings.h>
#include <ATen/native/cuda/reduction_template.cuh>
#include <sstream>
#include <fstream>
#include <cstdio>
#include <iterator> // istreambuf_iterator
#include <cstdlib>
#include <string>
// TODO: C++17 has the <filesystem> header, which may replace these
#ifdef _WIN32
// On Windows, the POSIX implementations are considered deprecated. We simply map to the newer variant.
#include <process.h>
#include <direct.h>
#include <io.h>
#define access _access
#define getpid _getpid
#define R_OK 4
#define W_OK 2
#define F_OK 0
#else
#include <sys/types.h>
#include <sys/stat.h> // mkdir
#include <unistd.h>
#endif
namespace at { namespace cuda { namespace jit {
// Returns the C++ type-traits source text used as part of every jiterator
// kernel (presumably what is substituted for ${traits_string} in
// jit_common_types; the substitution site is not visible here).
//
// hiprtc already provides some traits. Under USE_ROCM this therefore returns a
// copy of the traits from aten/src/ATen/cuda/llvm_basic.cpp with the duplicate
// definitions of integral_constant, is_same, is_integral, enable_if,
// is_floating_point and is_arithmetic removed. The text below still *uses*
// false_type/true_type, is_same and is_arithmetic, relying on hiprtc to provide
// them.
// Without USE_ROCM it simply returns the original get_traits_string().
// NOTE: everything between R"ESCAPE( and )ESCAPE" is emitted verbatim into the
// kernel source; do not put host-side comments inside it.
std::string get_traits_string_but_hiprtc_safe() {
#ifdef USE_ROCM
return R"ESCAPE(
namespace std {
template <class _Tp>
_Tp&& __declval(int);
template <class _Tp>
_Tp __declval(long);
template <class _Tp>
decltype(__declval<_Tp>(0)) declval() noexcept;
template <class _Tp> struct remove_const {typedef _Tp type;};
template <class _Tp> struct remove_const<const _Tp> {typedef _Tp type;};
template <class _Tp> using remove_const_t = typename remove_const<_Tp>::type;
template <class _Tp> struct remove_volatile {typedef _Tp type;};
template <class _Tp> struct remove_volatile<volatile _Tp> {typedef _Tp type;};
template <class _Tp> using remove_volatile_t = typename remove_volatile<_Tp>::type;
template <class _Tp> struct remove_cv
{typedef typename remove_volatile<typename remove_const<_Tp>::type>::type type;};
template <class _Tp> using remove_cv_t = typename remove_cv<_Tp>::type;
template <class _Tp> struct __libcpp_is_floating_point : public false_type {};
template <> struct __libcpp_is_floating_point<float> : public true_type {};
template <> struct __libcpp_is_floating_point<double> : public true_type {};
template <> struct __libcpp_is_floating_point<long double> : public true_type {};
template <class _Tp>
inline constexpr bool is_arithmetic_v = is_arithmetic<_Tp>::value;
template <class _Tp>
struct __numeric_type
{
static void __test(...);
static float __test(float);
static double __test(char);
static double __test(int);
static double __test(unsigned);
static double __test(long);
static double __test(unsigned long);
static double __test(long long);
static double __test(unsigned long long);
static double __test(double);
static long double __test(long double);
typedef decltype(__test(declval<_Tp>())) type;
static const bool value = !is_same<type, void>::value;
};
template <>
struct __numeric_type<void>
{
static const bool value = true;
};
// __promote
template <class _A1, class _A2 = void, class _A3 = void,
bool = __numeric_type<_A1>::value &&
__numeric_type<_A2>::value &&
__numeric_type<_A3>::value>
class __promote_imp
{
public:
static const bool value = false;
};
template <class _A1, class _A2, class _A3>
class __promote_imp<_A1, _A2, _A3, true>
{
private:
typedef typename __promote_imp<_A1>::type __type1;
typedef typename __promote_imp<_A2>::type __type2;
typedef typename __promote_imp<_A3>::type __type3;
public:
typedef decltype(__type1() + __type2() + __type3()) type;
static const bool value = true;
};
template <class _A1, class _A2>
class __promote_imp<_A1, _A2, void, true>
{
private:
typedef typename __promote_imp<_A1>::type __type1;
typedef typename __promote_imp<_A2>::type __type2;
public:
typedef decltype(__type1() + __type2()) type;
static const bool value = true;
};
template <class _A1>
class __promote_imp<_A1, void, void, true>
{
public:
typedef typename __numeric_type<_A1>::type type;
static const bool value = true;
};
template <class _A1, class _A2 = void, class _A3 = void>
class __promote : public __promote_imp<_A1, _A2, _A3> {};
} // namespace std
)ESCAPE";
#else
// Non-ROCm: no hiprtc-provided traits to avoid, so use the full traits text.
return get_traits_string();
#endif
}
// Source text meant to bracket generated kernel code. Where it is used is not
// visible here; presumably jit_preamble goes first and jit_epilogue last.
// On ROCm the two strings open and close a `#pragma clang force_cuda_host_device`
// region.
// NOTE(review): judging by the pragma's name, this makes the enclosed functions
// __host__ __device__. Presumably that is what lets helpers without explicit
// execution-space annotations run on the device — confirm.
// On CUDA both strings are empty.
#ifdef USE_ROCM
const std::string jit_preamble = R"ESCAPE(
#pragma clang force_cuda_host_device begin
)ESCAPE";
const std::string jit_epilogue = R"ESCAPE(
#pragma clang force_cuda_host_device end
)ESCAPE";
#else
const std::string jit_preamble;
const std::string jit_epilogue;
#endif
// Common source text shared by jiterator kernels. It contains:
//  - platform macros: ERROR_UNSUPPORTED_CAST, CUDA_OR_ROCM_NUM_THREADS, MAX_DIMS
//    and the infinity/NaN constants;
//  - fixed-width integer typedefs (checked with static_asserts);
//  - the num_threads / thread_work_size / block_work_size constants;
//  - the scalar-type X-macros, a device-side ScalarType enum and Array<T, size>.
// The ${...} placeholders (traits, cmath, half, bfloat16, complex strings) are
// filled in by template substitution before compilation.
// NOTE(review): the ScalarType enum order presumably has to match c10's
// ScalarType numbering for dtypes passed in from the host — confirm.
// NOTE: a multi-line macro's last line must NOT end in '\'. Raw strings keep
// backslash-newline verbatim, so the kernel compiler's preprocessor would splice
// the following line (here, the next #define) into the macro body.
const std::string jit_common_types = R"ESCAPE(
#ifdef __HIPCC__
#define ERROR_UNSUPPORTED_CAST ;
// corresponds to aten/src/ATen/native/cuda/thread_constants.h
#define CUDA_OR_ROCM_NUM_THREADS 256
// corresponds to aten/src/ATen/cuda/detail/OffsetCalculator.cuh
#define MAX_DIMS 16
#ifndef __forceinline__
#define __forceinline__ inline __attribute__((always_inline))
#endif
#else
//TODO use _assert_fail, because assert is disabled in non-debug builds
#define ERROR_UNSUPPORTED_CAST assert(false);
#define CUDA_OR_ROCM_NUM_THREADS 128
#define MAX_DIMS 25
#endif
#define POS_INFINITY __int_as_float(0x7f800000)
#define INFINITY POS_INFINITY
#define NEG_INFINITY __int_as_float(0xff800000)
#define NAN __int_as_float(0x7fffffff)
typedef long long int int64_t;
typedef unsigned int uint32_t;
typedef signed char int8_t;
typedef unsigned char uint8_t; // NOTE: this MUST be "unsigned char"! "char" is equivalent to "signed char"
typedef short int16_t;
static_assert(sizeof(int64_t) == 8, "expected size does not match");
static_assert(sizeof(uint32_t) == 4, "expected size does not match");
static_assert(sizeof(int8_t) == 1, "expected size does not match");
constexpr int num_threads = CUDA_OR_ROCM_NUM_THREADS;
constexpr int thread_work_size = 4; // TODO: make template substitution once we decide where those vars live
constexpr int block_work_size = thread_work_size * num_threads;
${traits_string}
${cmath_string}
// NB: Order matters for this macro; it is relied upon in
// _promoteTypesLookup and the serialization format.
// Note, some types have ctype as void because we don't support them in codegen
#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(_) \
_(uint8_t, Byte) /* 0 */ \
_(int8_t, Char) /* 1 */ \
_(int16_t, Short) /* 2 */ \
_(int, Int) /* 3 */ \
_(int64_t, Long) /* 4 */ \
_(at::Half, Half) /* 5 */ \
_(float, Float) /* 6 */ \
_(double, Double) /* 7 */ \
_(std::complex<at::Half>, ComplexHalf) /* 8 */ \
_(std::complex<float>, ComplexFloat) /* 9 */ \
_(std::complex<double>, ComplexDouble) /* 10 */ \
_(bool, Bool) /* 11 */ \
_(void, QInt8) /* 12 */ \
_(void, QUInt8) /* 13 */ \
_(void, QInt32) /* 14 */ \
_(at::BFloat16, BFloat16) /* 15 */
#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_QINT(_) \
_(uint8_t, Byte) \
_(int8_t, Char) \
_(int16_t, Short) \
_(int, Int) \
_(int64_t, Long) \
_(at::Half, Half) \
_(float, Float) \
_(double, Double) \
_(std::complex<at::Half>, ComplexHalf) \
_(std::complex<float>, ComplexFloat) \
_(std::complex<double>, ComplexDouble) \
_(bool, Bool) \
_(at::BFloat16, BFloat16)
enum class ScalarType : int8_t {
#define DEFINE_ENUM(_1, n) n,
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_ENUM)
#undef DEFINE_ENUM
Undefined,
NumOptions
};
template <typename T, int size>
struct Array {
T data[size];
__device__ T operator[](int i) const {
return data[i];
}
__device__ T& operator[](int i) {
return data[i];
}
Array() = default;
Array(const Array&) = default;
Array& operator=(const Array&) = default;
__device__ Array(T x) {
for (int i = 0; i < size; i++) {
data[i] = x;
}
}
};
${half_string}
${bfloat16_string}
${complex_body_string}
${complex_half_body_string}
${complex_math_string}
)ESCAPE";
// The half, bfloat16 and complex support strings must be included in every
// kernel with half arguments AND in every kernel that does type casting,
// whether or not it has half arguments, because fetch_and_cast and
// cast_and_store loop over all types.
//
// Minimal device-side at::Half: `x` holds the raw 16-bit pattern.
//  - float -> Half: __float2half on HIP (__HIPCC__); PTX `cvt.rn.f16.f32` on CUDA.
//  - Half -> float: __half2float on HIP; PTX `cvt.f32.f16` on CUDA.
const std::string jiterator_half_support_literal = R"ESCAPE(
namespace at {
struct alignas(2) Half {
unsigned short x;
Half() = default;
inline __host__ __device__ Half(float value){
#ifdef __HIPCC__
x = __half_as_short(__float2half(value));
#else
asm("{ cvt.rn.f16.f32 %0, %1;}\n" : "=h"(x) : "f"(value));
#endif
}
inline __host__ __device__ operator float() const{
#ifdef __HIPCC__
return __half2float(*reinterpret_cast<const __half*>(&x));
#else
float val;
asm("{ cvt.f32.f16 %0, %1;}\n" : "=f"(val) : "h"(x)); // do we need const cast here?
//asm("{ cvt.f32.f16 %0, %1;}\n" : "=f"(val) : "h"(__HALF_TO_CUS(x)));
return val;
#endif
}
};
}
)ESCAPE";
// Minimal device-side at::BFloat16: `x` holds the upper 16 bits of a float.
//  - float -> BFloat16: PTX `cvt.rn.bf16.f32` when __CUDA_ARCH__ >= 800.
//    Otherwise a software path truncates via __internal_float2bfloat16 and then
//    rounds to nearest-even using the discarded low 16 bits (`remainder`).
//    NaN inputs map to 0x7fff.
//  - BFloat16 -> float: places `x` in the high 16 bits of a 32-bit float (a
//    union on HIP, a PTX `mov.b32` on CUDA).
// The literal is written as two adjacent raw strings, split at the #if/#else
// boundary. Adjacent literals are concatenated by the compiler, so the resulting
// text is one contiguous string.
const std::string jiterator_bfloat16_support_literal = R"ESCAPE(
namespace at {
struct alignas(2) BFloat16 {
unsigned short x;
__device__ unsigned short __internal_float2bfloat16(
const float f,
unsigned int& sign,
unsigned int& remainder) {
unsigned int x;
x = __float_as_uint(f);
if ((x & 0x7fffffffU) > 0x7f800000U) {
sign = 0U;
remainder = 0U;
return static_cast<unsigned short>(0x7fffU);
}
sign = x >> 31;
remainder = x << 16;
return static_cast<unsigned short>(x >> 16);
}
BFloat16() = default;
inline __host__ __device__ BFloat16(float value){
#if __CUDA_ARCH__ >= 800
asm("{ cvt.rn.bf16.f32 %0, %1;}\n" : "=h"(x) : "f"(value));
)ESCAPE"
R"ESCAPE(
#else
unsigned int sign;
unsigned int remainder;
x = __internal_float2bfloat16(value, sign, remainder);
if ((remainder > 0x80000000U) ||
((remainder == 0x80000000U) && ((x & 0x1U) != 0U))) {
x++;
}
#endif
}
inline __host__ __device__ operator float() const{
#ifdef __HIPCC__
union
{
uint32_t int32;
float fp32;
} u = {uint32_t(x) << 16};
return u.fp32;
#else
float val;
asm("{ mov.b32 %0, {0,%1};}\n" : "=f"(val) : "h"(x)); //do we need const cast here?
return val;
#endif
}
};
}
)ESCAPE";
// Device-side copy of c10::load (from c10/util/Load.h).
// load<T>(ptr) reinterprets the bytes at `ptr` as a T. The bool specialization
// reads the byte as a char and converts it to bool (nonzero -> true) instead of
// reading it directly as a bool.
const std::string load_support_literal = R"ESCAPE(
namespace c10 {
template <typename T>
struct LoadImpl {
__device__ static T apply(const void *src) {
return *reinterpret_cast<const T*>(src);
}
};
template <>
struct LoadImpl<bool> {
__device__ static bool apply(const void *src) {
static_assert(sizeof(bool) == sizeof(char), "");
return LoadImpl<char>::apply(src);
}
};
template <typename T>
__device__ T load(const void *src) {
return LoadImpl<T>::apply(src);
}
template <typename scalar_t>
__device__ scalar_t load(const scalar_t *src) {
return LoadImpl<scalar_t>::apply(src);
}
} // namespace c10
)ESCAPE";
// Device-side dynamic dtype casting, copy-pasted from c10/util/TypeCast.h and
// c10/core/DynamicCast.h. It provides:
//  - static_cast_with_inter_type<dest_t, src_t>::apply: a static_cast that takes
//    .real() when casting complex -> non-complex and routes casts to uint8_t
//    through int64_t. A few conversions to std::complex<at::Half> are
//    special-cased to go through float / std::complex<float>.
//  - fetch_and_cast<dest_t>(src_type, ptr): switches on a runtime ScalarType,
//    loads a value of that dtype with c10::load and converts it to dest_t.
//  - cast_and_store<src_t>(dest_type, ptr, value): the reverse direction.
//  - LoadWithCast<N> / StoreWithCast<N>: per-argument dtype and element-size
//    tables. The address is base_ptr + element_sizes[arg] * offset (a byte
//    offset computed from an element index).
// Unsupported dtypes hit ERROR_UNSUPPORTED_CAST, which jit_common_types in this
// file defines.
// NOTE(review): this text also expects Array, ScalarType, the
// AT_FORALL_SCALAR_TYPES_* macros and c10::load to be emitted earlier in the
// kernel source — confirm the ordering at the use site.
const std::string dynamic_cast_support_literal = R"ESCAPE(
template <typename T>
struct is_complex : public std::false_type {};
template <typename T>
struct is_complex<std::complex<T>> : public std::true_type {};
template <typename dest_t, typename src_t>
struct needs_real {
constexpr static bool value =
(is_complex<src_t>::value && !is_complex<dest_t>::value);
};
template <bool, typename src_t>
struct maybe_real {
static inline src_t apply(src_t src) {
return src;
}
};
template <typename src_t>
struct maybe_real<true, src_t> {
static inline decltype(auto) apply(src_t src) {
return src.real();
}
};
template <typename dest_t, typename src_t>
struct static_cast_with_inter_type {
static inline dest_t apply(
src_t src) {
constexpr bool real = needs_real<dest_t, src_t>::value;
return static_cast<dest_t>(maybe_real<real, src_t>::apply(src));
}
};
template <typename src_t>
struct static_cast_with_inter_type<uint8_t, src_t> {
static inline uint8_t apply(
src_t src) {
constexpr bool real = needs_real<uint8_t, src_t>::value;
return static_cast<uint8_t>(
static_cast<int64_t>(maybe_real<real, src_t>::apply(src)));
}
};
template <>
struct static_cast_with_inter_type<std::complex<at::Half>, at::BFloat16> {
static inline std::complex<at::Half> apply(at::BFloat16 src) {
return static_cast<std::complex<at::Half>>(float{src});
}
};
template <>
struct static_cast_with_inter_type<std::complex<at::Half>, at::Half> {
static inline std::complex<at::Half> apply(at::Half src) {
return static_cast<std::complex<at::Half>>(float{src});
}
};
template <>
struct static_cast_with_inter_type<
std::complex<at::Half>,
std::complex<double>> {
static inline std::complex<at::Half> apply(std::complex<double> src) {
return static_cast<std::complex<at::Half>>(static_cast<std::complex<float>>(src));
}
};
// Fetch a value with dynamic type src_type from ptr, and cast it to static type dest_t.
#define FETCH_AND_CAST_CASE(type, scalartype) \
case ScalarType::scalartype: \
return static_cast_with_inter_type<dest_t, type>::apply(c10::load<type>(ptr));
template<typename dest_t>
__device__ inline dest_t fetch_and_cast(const ScalarType src_type, const void *ptr) {
switch (src_type) {
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_QINT(FETCH_AND_CAST_CASE)
default:
ERROR_UNSUPPORTED_CAST
}
return dest_t(0); // just to avoid compiler warning
}
// Cast a value with static type src_t into dynamic dest_type, and store it to ptr.
#define CAST_AND_STORE_CASE(type, scalartype) \
case ScalarType::scalartype: \
*(type*)ptr = static_cast_with_inter_type<type, src_t>::apply(value); \
return;
template<typename src_t>
__device__ inline void cast_and_store(const ScalarType dest_type, void *ptr, src_t value) {
switch (dest_type) {
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_QINT(CAST_AND_STORE_CASE)
default:;
}
ERROR_UNSUPPORTED_CAST
}
template <int N>
struct LoadWithCast {
using array_t = Array<ScalarType, N==0? 1 : N>;
using size_array_t = Array<uint32_t, N==0? 1: N>;
array_t dtypes;
size_array_t element_sizes;
template <typename scalar_t>
__device__ scalar_t load(char* base_ptr, uint32_t offset, int arg) {
void* ptr = base_ptr + element_sizes[arg] * offset;
return fetch_and_cast<scalar_t>(dtypes[arg], ptr);
}
};
template <int N = 1>
struct StoreWithCast {
using array_t = Array<ScalarType, N==0? 1 : N>;
using size_array_t = Array<uint32_t, N==0? 1: N>;
array_t dtypes;
size_array_t element_sizes;
template<typename scalar_t>
__device__ void store(scalar_t value, char *base_ptr, uint32_t offset, int arg = 0) {
void *ptr = base_ptr + element_sizes[arg] * offset;
cast_and_store<scalar_t>(dtypes[arg], ptr, value);
}
};
)ESCAPE";
// Loader/storer used when no dtype conversion is needed. Values are read and
// written directly as scalar_t at reinterpret_cast<scalar_t*>(base_ptr) + offset,
// so `offset` counts elements, not bytes. The `arg` parameter is accepted for
// interface parity with LoadWithCast/StoreWithCast but is unused.
const std::string no_dynamic_cast_support_literal = R"ESCAPE(
struct LoadWithoutCast {
template <typename scalar_t>
__device__ scalar_t load(char* base_ptr, uint32_t offset, int arg=0) {
return c10::load(reinterpret_cast<scalar_t*>(base_ptr) + offset);
}
};
struct StoreWithoutCast {
template<typename scalar_t>
__device__ void store(scalar_t value, char *base_ptr, uint32_t offset, int arg=0) {
*(reinterpret_cast<scalar_t *>(base_ptr) + offset) = value;
}
};
)ESCAPE";
// Device-side offset calculators mirroring
// aten/src/ATen/cuda/detail/OffsetCalculator.cuh. ${index_type} is substituted
// at codegen time; generate_code in this file sets it to "unsigned int".
//  - IntDivider: division by a runtime divisor using __umulhi with the
//    precomputed magic numbers m1 and shift. The "above" references in its
//    member comments point at the derivation in the original .cuh header.
//  - TrivialOffsetCalculator: every argument's offset is linear_idx itself.
//  - OffsetCalculator: peels linear_idx apart over the first `dims` dimensions
//    (at most MAX_DIMS) with sizes_[dim].divmod and accumulates
//    remainder * strides_[dim][arg] into each argument's offset.
const std::string offset_calc_template = R"ESCAPE(
template <typename T>
struct DivMod {
T div;
T mod;
__device__ DivMod(T _div, T _mod) {
div = _div;
mod = _mod;
}
};
//<unsigned int>
struct IntDivider {
IntDivider() = default;
__device__ inline unsigned int div(unsigned int n) const {
unsigned int t = __umulhi(n, m1);
return (t + n) >> shift;
}
__device__ inline unsigned int mod(unsigned int n) const {
return n - div(n) * divisor;
}
__device__ inline DivMod<unsigned int> divmod(unsigned int n) const {
unsigned int q = div(n);
return DivMod<unsigned int>(q, n - q * divisor);
}
unsigned int divisor; // d above.
unsigned int m1; // Magic number: m' above.
unsigned int shift; // Shift amounts.
};
template <int NARGS>
struct TrivialOffsetCalculator {
// The offset for each argument. Wrapper around fixed-size array.
// The offsets are in # of elements, not in bytes.
Array<${index_type}, NARGS> get(${index_type} linear_idx) const {
Array<${index_type}, NARGS> offsets;
#pragma unroll
for (int arg = 0; arg < NARGS; arg++) {
offsets[arg] = linear_idx;
}
return offsets;
}
};
template<int NARGS>
struct OffsetCalculator {
OffsetCalculator() = default;
__device__ __forceinline__ Array<${index_type}, NARGS> get(${index_type} linear_idx) const {
Array<${index_type}, NARGS> offsets;
#pragma unroll
for (int arg = 0; arg < NARGS; ++arg) {
offsets[arg] = 0;
}
#pragma unroll
for (int dim = 0; dim < MAX_DIMS; ++dim) {
if (dim == dims) {
break;
}
auto divmod = sizes_[dim].divmod(linear_idx);
linear_idx = divmod.div;
#pragma unroll
for (int arg = 0; arg < NARGS; ++arg) {
offsets[arg] += divmod.mod * strides_[dim][arg];
}
//printf("offset calc thread dim size stride offset %d %d %d %d %d %d %d %d\n",
//threadIdx.x, dim, sizes_[dim].divisor, strides_[dim][0], offsets[0], linear_idx, divmod.div, divmod.mod);
}
return offsets;
}
int dims;
IntDivider sizes_[MAX_DIMS];
// NOTE: this approach will not support nInputs == 0
${index_type} strides_[MAX_DIMS][NARGS];
};
)ESCAPE";
// Kernel template for the non-vectorized elementwise jiterator kernel
// `${name}_kernel`. Block `blockIdx.x` covers block_work_size consecutive linear
// indices. Each thread handles up to thread_work_size of them, spaced
// num_threads apart, and skips any index at or beyond `remaining`. It runs in
// three phases:
//   1. load each input through input_calculator + the loader `l` (${load_inputs});
//   2. apply the functor (${call_functor});
//   3. store each output through output_calculator + the storer `s`
//      (${store_outputs}).
// All ${...} placeholders are filled by generate_code's template environment.
const std::string jit_code_template = R"ESCAPE(
${load_support}
${dynamic_casting_string}
${functor}
// TODO: setup grid-stride loop
extern "C" __global__
void ${name}_kernel(
const int numel,
Array<char*, ${nInputs}+${nOutputs}> data, //[${nInputs}+${nOutputs}],
${offset_calculator}<${nInputs}> input_calculator,
${offset_calculator}<${nOutputs}> output_calculator,
${loader} l,
${storer} s,
${compute_type} scalar_val${extra_params}) {
${declare_load_arrays}
${declare_store_arrays}
int idx = blockIdx.x;
int remaining = numel - block_work_size * idx;
int thread_idx = threadIdx.x;
#pragma unroll
for (int j = 0; j < thread_work_size; j++){
if (thread_idx >= remaining) {
break;
}
int linear_idx = thread_idx + block_work_size * idx;
auto input_offsets = input_calculator.get(linear_idx);
${load_inputs}
// printf(
// "thread %d a %f offsets %d\n", threadIdx.x, arg0[j], input_offsets[0]);
thread_idx += num_threads;
}
#pragma unroll
for (int j = 0; j < thread_work_size; j++) {
if ((threadIdx.x + j*num_threads) < remaining) {
${call_functor}
}
}
thread_idx = threadIdx.x;
#pragma unroll
for (int j = 0; j < thread_work_size; j++){
if (thread_idx >= remaining) {
break;
}
//TODO maybe think about unifying offset calculators and reuse
//offsets computed in the load loop
int linear_idx = thread_idx + block_work_size * idx;
auto output_offsets = output_calculator.get(linear_idx);
//printf("output thread %d offset %d\n", threadIdx.x, output_offsets[0]);
${store_outputs}
thread_idx += num_threads;
}
}
)ESCAPE";
// Kernel template for the vectorized jiterator kernel
// `${name}_vectorized${vec_size}_kernel`. It takes no offset calculators;
// elements are addressed directly as base pointer + linear index, which
// presumably means it is only used for contiguous tensors — confirm at the
// call site.
//  - Tail block (remaining < block_work_size): per-element load, compute and
//    store, each guarded by `remaining`, the same as the non-vectorized kernel.
//  - Full block: inputs are loaded and outputs stored as
//    aligned_vector<T, vec_size>. That takes loop_size = thread_work_size /
//    vec_size iterations per thread, spaced num_threads vectors apart.
// aligned_vector's alignas(sizeof(scalar_t) * vec_size) is what allows the
// compiler to emit wide loads/stores. The bool overload of load_vector loads
// bytes as uint8_t and converts each one to bool.
// NOTE: the `int thread_idx = threadIdx.x;` in the full-block branch declares a
// new variable that shadows the outer thread_idx; it restarts the index for
// the store loop.
const std::string jit_vectorized_code_template = R"ESCAPE(
${load_support}
template <typename scalar_t>
__device__ __inline__ scalar_t load(char* base_ptr, uint32_t offset) {
return c10::load(reinterpret_cast<scalar_t*>(base_ptr) + offset);
}
template<typename scalar_t>
__device__ __inline__ void store(scalar_t value, char *base_ptr, uint32_t offset) {
*(reinterpret_cast<scalar_t *>(base_ptr) + offset) = value;
}
// aligned vector generates vectorized load/store on CUDA
template<typename scalar_t, int vec_size>
struct alignas(sizeof(scalar_t) * vec_size) aligned_vector {
scalar_t val[vec_size];
};
template <int vec_size, typename scalar_t>
__device__ aligned_vector<scalar_t, vec_size> load_vector(const scalar_t *base_ptr, uint32_t offset) {
using vec_t = aligned_vector<scalar_t, vec_size>;
auto *from = reinterpret_cast<const vec_t *>(base_ptr);
return from[offset];
}
template <int vec_size>
__device__ aligned_vector<bool, vec_size> load_vector(const bool *base_ptr, uint32_t offset) {
// See NOTE [Loading boolean values]
auto tmp = load_vector<vec_size>(reinterpret_cast<const uint8_t*>(base_ptr), offset);
aligned_vector<bool, vec_size> ret;
for (int i = 0; i < vec_size; ++i) {
ret.val[i] = bool(tmp.val[i]);
}
return ret;
}
${functor}
// TODO: setup grid-stride loop
extern "C" __global__
void ${name}_vectorized${vec_size}_kernel(
const int N,
Array<char*, ${nInputs}+${nOutputs}> data,
${compute_type} scalar_val${extra_params}) //[${nInputs}+${nOutputs}],
{
constexpr int vec_size = ${vec_size};
using scalar_t = ${scalar_type};
int remaining = N - block_work_size * blockIdx.x;
int thread_idx = threadIdx.x;
int idx = blockIdx.x;
${declare_load_arrays}
${declare_store_arrays}
if (remaining < block_work_size) {
#pragma unroll
for (int j = 0; j < thread_work_size; j++){
if (thread_idx >= remaining) {
break;
}
int linear_idx = thread_idx + block_work_size * idx;
${load_unrolled_inputs}
thread_idx += num_threads;
}
#pragma unroll
for (int j = 0; j < thread_work_size; j++) {
if ((threadIdx.x + j*num_threads) < remaining) {
${call_functor}
}
}
thread_idx = threadIdx.x;
#pragma unroll
for (int j = 0; j < thread_work_size; j++) {
if (thread_idx >= remaining) {
break;
}
int linear_idx = thread_idx + block_work_size * idx;
${store_unrolled_outputs}
thread_idx += num_threads;
}
} else {
static constexpr int loop_size = thread_work_size / vec_size;
//actual loading
${vector_inputs}
#pragma unroll
for (int i = 0; i<loop_size; i++){
${load_vectorized_inputs}
thread_idx += num_threads;
}
#pragma unroll
for (int j = 0; j < thread_work_size; j++) {
${call_functor}
}
using vec_t_output = aligned_vector<${result_type}, vec_size>;
${vector_outputs}
int thread_idx = threadIdx.x;
#pragma unroll
for (int i = 0; i<loop_size; i++){
vec_t_output v;
${store_vectorized_outputs}
thread_idx += num_threads;
}
}
}
)ESCAPE";
// Replaces every non-overlapping occurrence of `to_replace` in `s` with
// `replace_with`, in place. The scan moves left to right and resumes after each
// match, so text produced by `replace_with` is never rescanned.
//
// An empty `to_replace` leaves `s` unchanged. An empty pattern matches at every
// position, and the previous implementation never advanced and looped forever.
static void replace_all(std::string& s, const std::string& to_replace, const std::string& replace_with) {
  if (to_replace.empty()) {
    return;
  }
  std::string result;
  result.reserve(s.size());
  std::size_t copied_up_to = 0;  // first index of `s` not yet copied to `result`
  for (std::size_t match = s.find(to_replace); match != std::string::npos;
       match = s.find(to_replace, copied_up_to)) {
    result.append(s, copied_up_to, match - copied_up_to);
    result += replace_with;
    copied_up_to = match + to_replace.size();
  }
  // Copy the remainder after the last match (or all of `s` if nothing matched).
  result.append(s, copied_up_to, std::string::npos);
  s.swap(result);
}
// hipify rewrites certain device math calls, e.g. std::max -> ::max
// (see torch/utils/hipify/cuda_to_hip_mappings.py). This undoes that rewrite
// for the functions listed below by turning " ::<name>" back into " std::<name>".
// The leading space in the search text is what prevents double replacement: an
// already-qualified "std::max" contains "::max" but not " ::max", so it is not
// turned into "stdstd::max".
// NOTE(review): matching is a plain substring search, so e.g. " ::log" also
// matches " ::log1p" / " ::log10" (which become " std::log1p" / " std::log10").
static std::string unhipify_math_functions(const std::string &original) {
  // Read-only table of {restored text, hipified text}; never modified, so const.
  static const std::vector<std::pair<std::string, std::string>> mappings = {
      {" std::max", " ::max"},
      {" std::min", " ::min"},
      {" std::ceil", " ::ceil"},
      {" std::floor", " ::floor"},
      {" std::exp", " ::exp"},
      {" std::log", " ::log"},
      {" std::pow", " ::pow"},
      {" std::fabs", " ::fabs"},
      {" std::fmod", " ::fmod"},
      {" std::remainder", " ::remainder"},
      {" std::frexp", " ::frexp"}
  };
  std::string ret = original;
  for (const auto& mapping : mappings) {
    replace_all(ret, mapping.second, mapping.first);
  }
  return ret;
}
// The following is copied from fused_kernel.cpp.
// TODO: refactor codegenOutputQuery into its own file that both files can
// include.
// See NOTE [ USE OF NVRTC AND DRIVER API ]
//
// Returns the NVRTC / CUDA-driver function table held by the global ATen
// context. Calls in this file (e.g. nvrtcVersion, cuCtxGetCurrent) go through
// it.
const at::cuda::NVRTC& nvrtc() {
  const at::cuda::NVRTC& fn_table = at::globalContext().getNVRTC();
  return fn_table;
}
// Queries which GPU architecture NVRTC should target and whether to compile
// straight to SASS.
// TODO refactor so this function is usable both from jit and from aten
//
// Outputs:
//   nvrtc_major / nvrtc_minor: the NVRTC version, from nvrtcVersion.
//   cuda_major / cuda_minor:   the compute capability to compile for.
//   compile_to_sass:           true if the device's own architecture can be
//                              targeted directly (presumably PTX is emitted
//                              otherwise — confirm at the call site).
//
// ROCm: reports the device's major/minor unchanged and never compiles to SASS.
// CUDA: requires NVRTC >= 6 (TORCH_CHECK). It then caps the device capability
// at the newest architecture the detected NVRTC version can target, using the
// table below. If the device had to be clamped, compile_to_sass is false.
// Builds with CUDA_VERSION < 11010 never compile to SASS.
void codegenOutputQuery(
const cudaDeviceProp* const prop,
int& cuda_major,
int& cuda_minor,
int& nvrtc_major,
int& nvrtc_minor,
bool& compile_to_sass) {
#ifdef USE_ROCM
AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcVersion(&nvrtc_major, &nvrtc_minor));
cuda_major = prop->major;
cuda_minor = prop->minor;
compile_to_sass = false;
#else
AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcVersion(&nvrtc_major, &nvrtc_minor));
TORCH_CHECK(
nvrtc_major >= 6, "NVRTC versions less than 6 are not supported. Is: ", nvrtc_major);
// Version supported by device
// Usually any lower version works too but is less efficient
using CUDAVersion = std::pair<int, int>;
const CUDAVersion nvrtc_version{nvrtc_major, nvrtc_minor};
const CUDAVersion dev_version{prop->major, prop->minor};
// Maximum architecture the detected NVRTC version can target; dev_version is
// capped to this below (std::pair compares lexicographically: major, then minor).
CUDAVersion max_dev_version;
if (nvrtc_major <= 7) { // 7 supports 2-5.x
max_dev_version = CUDAVersion(5, 0);
} else if (nvrtc_major <= 8) { // 8 supports 2-6.x
max_dev_version = CUDAVersion(6, 0);
} else if (nvrtc_major <= 9) { // 9 supports 3-7.2
max_dev_version = CUDAVersion(7, 2);
} else if (nvrtc_major <= 10) { // 10 supports 3-7.5
max_dev_version = CUDAVersion(7, 5);
} else if (nvrtc_version == CUDAVersion(11, 0)) { // 11.0 supports 3-8.0
max_dev_version = CUDAVersion(8, 0);
} else if (nvrtc_major == 11 && nvrtc_minor < 8) {
// 11.1 - 11.7: cap at 8.6
max_dev_version = CUDAVersion(8, 6);
} else {
// If the NVRTC version is newer than this table (i.e. newer than this code),
// assume it supports this device
max_dev_version = dev_version;
}
if (dev_version > max_dev_version) {
cuda_major = max_dev_version.first;
cuda_minor = max_dev_version.second;
// if we are clamping major/minor, sass is not compatible
compile_to_sass = false;
} else {
cuda_major = dev_version.first;
cuda_minor = dev_version.second;
compile_to_sass = true;
}
#if defined(CUDA_VERSION) && CUDA_VERSION < 11010
// compile to sass is not allowed prior to CUDA 11.1
compile_to_sass = false;
#endif
#endif
}
// TODO: another copy paste from jit, refactor so it's usable from both
// TODO: try making the CUcontext thread local to see if that improves performance - why is this slow?
void __inline__ initializeCudaContext() {
// lazily construct context if non-existing yet;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
CUcontext pctx = nullptr;
AT_CUDA_DRIVER_CHECK(at::globalContext().getNVRTC().cuCtxGetCurrent(&pctx));
if (!pctx) {
std::unique_lock<std::mutex> cudaFreeMutexLock(
*(c10::cuda::getFreeMutex()));
cudaFree(nullptr);
}
}
// Convenience overload: unpacks a KernelDescriptor into the string-typed
// arguments expected by the main generate_code overload.
//  - f_inputs_type is rendered with typeName().
//  - The compute type is typeName(toOpMathType(f_inputs_type)).
//  - The result type is typeName(result_type).
//  - Each extra-argument dtype is rendered with typeName(), preserving order.
// All remaining flags are forwarded unchanged.
std::string generate_code(
    const KernelDescriptor &desc,
    bool contiguous,
    bool dynamic_casting,
    BinaryFuncVariant scalar_pos,
    bool vectorized,
    int vec_size,
    bool return_by_ref) {
  c10::SmallVector<std::string> extra_args_typenames;
  extra_args_typenames.reserve(desc.extra_args_types.size());
  for (const auto& extra_type : desc.extra_args_types) {
    extra_args_typenames.push_back(typeName(extra_type));
  }
  const std::string inputs_type_name = typeName(desc.f_inputs_type);
  const std::string compute_type_name = typeName(toOpMathType(desc.f_inputs_type));
  const std::string result_type_name = typeName(desc.result_type);
  return generate_code(
      desc.nInputs, desc.nOutputs, desc.f, desc.name,
      inputs_type_name, compute_type_name, result_type_name,
      contiguous, dynamic_casting, scalar_pos, extra_args_typenames,
      vectorized, vec_size, return_by_ref);
}
// FIXME: these are defined in Loops.cuh, but including Loops.cuh here would lead
// to circular includes: Loops.cuh -> CUDALoops.cuh -> jit_utils.h -> Loops.cuh.
// NOTE(review): these values presumably must stay in sync with Loops.cuh and with
// the `thread_work_size = 4` hard-coded in jit_common_types — confirm.
#define THREAD_WORK_SIZE 4
constexpr int thread_work_size = THREAD_WORK_SIZE;
std::string generate_code(
int nInputs,
int nOutputs,
const std::string& func_,
const std::string& name,
const std::string& f_inputs_type,
const std::string& compute_type,
const std::string& result_type,
bool contiguous,
bool dynamic_casting,
BinaryFuncVariant scalar_pos,
c10::SmallVector<std::string>& extra_args_typenames,
bool vectorized,
int vec_size,
bool return_by_ref) {
std::string func = func_;
at::jit::TemplateEnv env;
env.s("index_type", "unsigned int");
env.s("nInputs", std::to_string(nInputs));
env.s("nOutputs", std::to_string(nOutputs));
env.s("scalar_type", f_inputs_type);
env.s("compute_type", compute_type);
env.s("functor", func);
env.s("name", name);
env.s("cmath_string", get_cmath_string());
// Generate `extra_params` for function signature
// and `extra_args` for computation call if
// extra arguments to capture runtime state are passed.
// (look at polygamma for example).