forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
fused_kernel.cpp
291 lines (264 loc) · 9.58 KB
/
fused_kernel.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
#include <torch/csrc/jit/codegen/fuser/cuda/fused_kernel.h>
#include <torch/csrc/jit/codegen/cuda/executor_utils.h>
#include <torch/csrc/jit/codegen/fuser/compiler.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAGeneratorImpl.h>
#include <ATen/cuda/nvrtc_stub/ATenNVRTC.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/csrc/jit/resource_guard.h>
#include <cuda_runtime.h>
#include <algorithm>
#include <cmath>
#include <sstream>
#include <stdexcept>
#include <tuple>
#include <vector>
namespace torch {
namespace jit {
namespace fuser {
namespace cuda {
// Returns the NVRTC / driver API function table owned by the global ATen
// context. Every NVRTC and driver call in this file goes through it.
// See NOTE [ USE OF NVRTC AND DRIVER API ]
const at::cuda::NVRTC& nvrtc() {
  const at::cuda::NVRTC& api = at::globalContext().getNVRTC();
  return api;
}
// Queries the codegen output architecture and target.
//
// Inputs:
//   prop             - properties of the device the kernel will run on
//                      (not read in the ROCm build)
// Outputs:
//   major, minor     - compute capability to compile for; in the ROCm build
//                      these receive the values reported by nvrtcVersion
//   compile_to_sass  - true when the device's own capability is targeted
//                      without clamping; always false in the ROCm build
//
// Fails a TORCH_CHECK when the NVRTC major version is below 6.
void codegenOutputQuery(
    const cudaDeviceProp* const prop,
    int& major,
    int& minor,
    bool& compile_to_sass) {
#ifdef USE_ROCM
  AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcVersion(&major, &minor));
  compile_to_sass = false;
#else
  using CudaVersion = std::pair<int, int>;

  CudaVersion toolkit;
  AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcVersion(&toolkit.first, &toolkit.second));
  TORCH_CHECK(
      toolkit.first >= 6,
      "NVRTC versions less than 6 are not supported. Is: ",
      toolkit.first);

  // Capability of the device itself. Usually any lower version works too but
  // is less efficient.
  const CudaVersion device_cc(prop->major, prop->minor);

  // Highest capability this NVRTC version is known to support; the device
  // capability is capped to it below.
  const CudaVersion ceiling = [&]() -> CudaVersion {
    const int rtc_major = toolkit.first;
    if (rtc_major <= 7) { // 7 supports 2-5.x
      return CudaVersion(5, 0);
    }
    if (rtc_major <= 8) { // 8 supports 2-6.x
      return CudaVersion(6, 0);
    }
    if (rtc_major <= 9) { // 9 supports 3-7.2
      return CudaVersion(7, 2);
    }
    if (rtc_major <= 10) { // 10 supports 3-7.5
      return CudaVersion(7, 5);
    }
    if (toolkit == CudaVersion(11, 0)) { // 11.0 supports 3-8.0
      return CudaVersion(8, 0);
    }
    if (rtc_major == 11 && toolkit.second < 8) {
      return CudaVersion(8, 6);
    }
    // Unknown version (i.e. newer than this code): assume it supports the
    // device as-is.
    return device_cc;
  }();

  // When major/minor have to be clamped, SASS is not compatible with the
  // device, so the caller must fall back to PTX.
  const bool clamped = device_cc > ceiling;
  const CudaVersion& target = clamped ? ceiling : device_cc;
  major = target.first;
  minor = target.second;
  compile_to_sass = !clamped;
#endif
}
// Compiles the specified kernel and stores the metadata required to run it.
//
// The constructor temporarily switches the current CUDA device to `device`,
// compiles `code` with NVRTC for the architecture chosen by
// codegenOutputQuery, loads the result as a driver module, looks up the
// kernel function called `name`, and caches an upper bound on the number of
// blocks to launch (maxBlocks_).
//
// Throws std::runtime_error carrying the NVRTC compile log when compilation
// fails; other NVRTC / driver failures are reported through
// AT_CUDA_NVRTC_CHECK / AT_CUDA_DRIVER_CHECK.
//
// NOTE(review): the previously current device is only restored on the
// success path; if anything below throws, the current device stays set to
// `device`. Confirm whether callers rely on / tolerate that.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
FusedKernelCUDA::FusedKernelCUDA(
    at::DeviceIndex device,
    std::string name,
    std::string code,
    std::vector<TensorDesc> input_desc,
    std::vector<TensorDesc> output_desc,
    std::vector<PartitionDesc> chunk_desc,
    std::vector<PartitionDesc> concat_desc,
    bool has_random)
    : FusedKernel(
          std::move(name),
          std::move(code),
          std::move(input_desc),
          std::move(output_desc),
          std::move(chunk_desc),
          std::move(concat_desc),
          has_random),
      device_(device) {
  // Initializes driver's API context (if necessary)
  executor_utils::initializeCudaContext();

  // Note: hacked at::DeviceGuard since at::DeviceGuard was failing to work
  // properly in some scenarios
  const auto prior_device = at::cuda::current_device();
  at::cuda::set_device(device_);

  // Acquires device and NVRTC properties (for compile arch and occupancy
  // calculations)
  prop_ = at::cuda::getCurrentDeviceProperties();
  int major = 0;
  int minor = 0;
  bool compile_to_sass = false;
  codegenOutputQuery(prop_, major, minor, compile_to_sass);

  // Creates the NVRTC program
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  nvrtcProgram program;
  AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcCreateProgram(
      &program, code_.c_str(), nullptr, 0, nullptr, nullptr));
  // Take ownership immediately so the program is destroyed on every exit
  // path, including the compile-failure throw below.
  ResourceGuard holdProgram(
      [&] { AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcDestroyProgram(&program)); });

#if defined(USE_ROCM)
  std::vector<const char*> args = {"--std=c++14"};
#if ROCM_VERSION >= 40200
  args.push_back("-hip-pch");
#endif
#else
  const std::string compute = std::string("--gpu-architecture=") +
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11010
      // CUDA 11.1 allows going directly to SASS (sm_) instead of PTX
      // (compute_), which gives better backwards compatibility to work on
      // older drivers (since an older driver doesn't necessarily recognize
      // PTX emitted by a new toolkit).
      // Meanwhile, for forward compatibility (future device with
      // `compile_to_sass==false`), since SASS is not necessarily compatible,
      // we fall back to PTX instead.
      (compile_to_sass ? "sm_" : "compute_") +
#else
      "compute_" +
#endif
      std::to_string(major) + std::to_string(minor);
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  const std::vector<const char*> args = {
      "--std=c++14", compute.c_str(), "-default-device"};
#endif

  const auto result =
      nvrtc().nvrtcCompileProgram(program, args.size(), args.data());
  if (result != NVRTC_SUCCESS) {
    size_t logsize = 0;
    AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcGetProgramLogSize(program, &logsize));
    // One extra zeroed byte keeps the buffer non-empty and NUL-terminated
    // regardless of what NVRTC reports / writes.
    std::vector<char> log(logsize + 1, '\0');
    AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcGetProgramLog(program, log.data()));
    throw std::runtime_error(log.data());
  }

  size_t ptx_size = 0;
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11010
  // compile_to_sass determines whether we are generating SASS or PTX, hence
  // the different API.
  const auto getSize = compile_to_sass ? nvrtc().nvrtcGetCUBINSize
                                       : nvrtc().nvrtcGetPTXSize;
  const auto getFunc =
      compile_to_sass ? nvrtc().nvrtcGetCUBIN : nvrtc().nvrtcGetPTX;
#else
  const auto getSize = nvrtc().nvrtcGetPTXSize;
  const auto getFunc = nvrtc().nvrtcGetPTX;
#endif
  AT_CUDA_NVRTC_CHECK(getSize(program, &ptx_size));
  ptx_.resize(ptx_size);
  AT_CUDA_NVRTC_CHECK(getFunc(program, ptx_.data()));

  AT_CUDA_DRIVER_CHECK(nvrtc().cuModuleLoadData(&module_, ptx_.data()));
  AT_CUDA_DRIVER_CHECK(
      nvrtc().cuModuleGetFunction(&function_, module_, name_.c_str()));

  // Computes max blocks: per-multiprocessor occupancy for a block size of
  // 128 threads and no dynamic shared memory, scaled by the number of
  // multiprocessors on the device.
#if defined(USE_ROCM) && ROCM_VERSION < 30500
  // HIP function signature is not compatible yet
  uint32_t max_blocks;
  AT_CUDA_DRIVER_CHECK(nvrtc().hipOccupancyMaxActiveBlocksPerMultiprocessor(
      &max_blocks, function_, 128, 0));
  maxBlocks_ = max_blocks;
#else
  AT_CUDA_DRIVER_CHECK(nvrtc().cuOccupancyMaxActiveBlocksPerMultiprocessor(
      &maxBlocks_, function_, 128, 0));
#endif
  maxBlocks_ *= prop_->multiProcessorCount;

  // Resets device (end of hacked at::DeviceGuard)
  at::cuda::set_device(prior_device);
}
// Integer ceiling division: the smallest integer not less than a / b.
// Requires b > 0.
//
// Computed as quotient plus a remainder test instead of (a + b - 1) / b so
// the intermediate sum cannot overflow when `a` is close to INT_MAX. For
// a >= 0 the two forms give the same result whenever the latter does not
// overflow.
static int ceilDiv(const int a, const int b) {
  return a / b + (a % b > 0 ? 1 : 0);
}
// Launches the compiled kernel for `numel` elements on the current CUDA
// stream, with device_ made current for the duration of the call.
//
// `arguments` is the kernel argument array handed to cuLaunchKernel. When the
// kernel uses random numbers, two extra entries are appended to it: pointers
// to the two philox engine input values. Those pointers refer to a local of
// this function, so the appended entries must not be dereferenced once the
// call has returned.
//
// NOTE(review): `numel` is narrowed from uint32_t to int when passed to
// ceilDiv; values above INT_MAX would not survive that conversion.
void FusedKernelCUDA::launch_raw(
    const uint32_t numel,
    std::vector<void*>& arguments) const {
  // NOLINTNEXTLINE(bugprone-unused-raii)
  at::cuda::CUDAGuard{device_};

  // Hacked at::DeviceGuard: remember the caller's device, switch to ours, and
  // switch back at the end of the function.
  const auto caller_device = at::cuda::current_device();
  at::cuda::set_device(device_);

  // One block per kBlockSize elements, capped at maxBlocks_.
  const int blocks_needed = ceilDiv(numel, kBlockSize);
  const auto nBlocks = std::min(maxBlocks_, blocks_needed);

  // Adds random state to arguments if necessary.
  // Note: philox_args is declared at function scope so that its lifetime
  // extends to the launch.
  std::pair<uint64_t, uint64_t> philox_args;
  if (has_random_) {
    const double elems_per_pass = 4.0 * kBlockSize * nBlocks;
    const auto rand_offset = 4 * (std::ceil(numel / elems_per_pass) + 1);

    auto gen = at::cuda::detail::getDefaultCUDAGenerator();
    {
      // See Note [Acquire lock when using random generators]
      std::lock_guard<std::mutex> lock(gen.mutex());
      auto impl = at::check_generator<at::CUDAGeneratorImpl>(gen);
      philox_args = impl->philox_engine_inputs(rand_offset);
    }

    arguments.push_back(&philox_args.first);
    arguments.push_back(&philox_args.second);
  }

  // Launches kernel on current stream (device was set by executor)
  auto stream = at::cuda::getCurrentCUDAStream();
  AT_CUDA_DRIVER_CHECK(nvrtc().cuLaunchKernel(
      function_,
      nBlocks, // grid dim x
      1, // grid dim y
      1, // grid dim z
      kBlockSize, // block dim x
      1, // block dim y
      1, // block dim z
      0, // dynamic shared memory bytes
      stream,
      arguments.data(),
      nullptr));

  // Resets device (end of hacked at::DeviceGuard)
  at::cuda::set_device(caller_device);
}
// Unloads the driver module that the constructor loaded for this kernel.
// The status returned by cuModuleUnload is not inspected.
FusedKernelCUDA::~FusedKernelCUDA() {
  const auto& driver = nvrtc();
  driver.cuModuleUnload(module_);
}
static std::shared_ptr<FusedKernel> createFusionKernel(
int16_t device,
std::string name,
std::string code,
std::vector<TensorDesc> input_desc,
std::vector<TensorDesc> output_desc,
std::vector<PartitionDesc> chunk_desc,
std::vector<PartitionDesc> concat_desc,
bool has_random) {
return std::make_shared<FusedKernelCUDA>(
static_cast<at::DeviceIndex>(device),
std::move(name),
std::move(code),
std::move(input_desc),
std::move(output_desc),
std::move(chunk_desc),
std::move(concat_desc),
has_random);
}
// Namespace-scope object constructed during static initialization with the
// CUDA device type and the createFusionKernel factory; presumably this is
// what makes the fuser use FusedKernelCUDA for CUDA devices (the
// RegisterFusionBackend constructor is defined elsewhere -- confirm there).
RegisterFusionBackend reg(DeviceType::CUDA, createFusionKernel);
} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch