Skip to content

Commit

Permalink
CUDA 11.0
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Feb 13, 2022
1 parent 2951520 commit 9b3cd3e
Showing 1 changed file with 11 additions and 4 deletions.
15 changes: 11 additions & 4 deletions src/common/device_helpers.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1558,17 +1558,24 @@ class CUDAEvent {
cudaEvent_t Get() const { return event_; }
};

#ifndef cudaEventWaitDefault
#define cudaEventWaitDefault 0x00
#endif

class CUDAStreamView {
cudaStream_t stream_{nullptr};

public:
explicit CUDAStreamView(cudaStream_t s) : stream_{s} {}
void Wait(CUDAEvent const &e) {
// #if CUDA
#if defined(__CUDACC_VER_MAJOR__)
#if __CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ == 0
// CUDA == 11.0
dh::safe_cuda(cudaStreamWaitEvent(stream_, e.Get(), 0));
#else
// CUDA > 11.0
dh::safe_cuda(cudaStreamWaitEvent(stream_, e.Get(), cudaEventWaitDefault));
#endif // __CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ == 0:
#else // clang
dh::safe_cuda(cudaStreamWaitEvent(stream_, e.Get(), cudaEventWaitDefault));
#endif // defined(__CUDACC_VER_MAJOR__)
}
operator cudaStream_t() const { // NOLINT
return stream_;
Expand Down

0 comments on commit 9b3cd3e

Please sign in to comment.