call_once (#43206)
xiaoxiaohehe001 committed Jun 8, 2022
1 parent 8dab269 commit cad139a
Showing 1 changed file with 82 additions and 40 deletions.
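This change removes the eager creation of the cuBLAS, cuBLASLt, cuDNN, cuSOLVER, and cuSPARSE handles from the GPUContext::Impl constructors. The getters shown below (and CublasCall / TensorCoreCublasCallIfAvailable) now create their handles on first use, each guarded by a std::once_flag so initialization runs exactly once even with concurrent callers; a minimal sketch of the pattern follows the diff.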
paddle/phi/backends/gpu/gpu_context.cc (82 additions, 40 deletions)
@@ -214,23 +214,6 @@ struct GPUContext::Impl {
                           &max_grid_dim_size_);
     phi::InitStream(&stream_);
     InitEigenDevice();
-    phi::InitBlasHandle(&blas_handle_, stream_);
-#ifdef PADDLE_WITH_CUDA
-#if CUDA_VERSION >= 9000
-    phi::InitBlasHandle(&blas_tensor_core_handle_, stream_);
-    PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasSetMathMode(
-        blas_tensor_core_handle_, CUBLAS_TENSOR_OP_MATH));
-#endif
-#if CUDA_VERSION >= 11000
-    phi::InitBlasHandle(&blas_tf32_tensor_core_handle_, stream_);
-    PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasSetMathMode(
-        blas_tf32_tensor_core_handle_, CUBLAS_TF32_TENSOR_OP_MATH));
-#endif
-#endif
-    phi::InitBlasLtHandle(&blaslt_handle_);
-    phi::InitDnnHandle(&dnn_handle_, stream_, place_);
-    phi::InitSolverHandle(&solver_handle_, stream_);
-    phi::InitSparseHandle(&sparse_handle_, stream_);
     InitDnnWorkspace();
   }

@@ -246,23 +229,6 @@ struct GPUContext::Impl {
                           &max_threads_per_block_,
                           &max_grid_dim_size_);
     phi::InitStream(&stream_);
-    phi::InitBlasHandle(&blas_handle_, stream_);
-#ifdef PADDLE_WITH_CUDA
-#if CUDA_VERSION >= 9000
-    phi::InitBlasHandle(&blas_tensor_core_handle_, stream_);
-    PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasSetMathMode(
-        blas_tensor_core_handle_, CUBLAS_TENSOR_OP_MATH));
-#endif
-#if CUDA_VERSION >= 11000
-    phi::InitBlasHandle(&blas_tf32_tensor_core_handle_, stream_);
-    PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasSetMathMode(
-        blas_tf32_tensor_core_handle_, CUBLAS_TF32_TENSOR_OP_MATH));
-#endif
-#endif
-    phi::InitBlasLtHandle(&blaslt_handle_);
-    phi::InitDnnHandle(&dnn_handle_, stream_, place_);
-    phi::InitSolverHandle(&solver_handle_, stream_);
-    phi::InitSparseHandle(&sparse_handle_, stream_);
   }
 
   void PartialInitWithAllocator() {
@@ -356,7 +322,28 @@ struct GPUContext::Impl {
     return eigen_device_;
   }
 
-  blasHandle_t GetBlasHandle() const {
+  blasHandle_t GetBlasHandle() {
+    std::call_once(flag_blas_, [=]() {
+      if (!blas_handle_) {
+        phi::InitBlasHandle(&blas_handle_, stream_);
+      }
+#ifdef PADDLE_WITH_CUDA
+#if CUDA_VERSION >= 9000
+      if (!blas_tensor_core_handle_) {
+        phi::InitBlasHandle(&blas_tensor_core_handle_, stream_);
+        PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasSetMathMode(
+            blas_tensor_core_handle_, CUBLAS_TENSOR_OP_MATH));
+      }
+#endif
+#if CUDA_VERSION >= 11000
+      if (!blas_tf32_tensor_core_handle_) {
+        phi::InitBlasHandle(&blas_tf32_tensor_core_handle_, stream_);
+        PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasSetMathMode(
+            blas_tf32_tensor_core_handle_, CUBLAS_TF32_TENSOR_OP_MATH));
+      }
+#endif
+#endif
+    });
     PD_CHECK(blas_handle_ != nullptr, "the gpu blas handle is nullptr.");
     return blas_handle_;
   }
@@ -373,12 +360,18 @@
 
   void SetBlasLtHandle(blasLtHandle_t blaslt) { blaslt_handle_ = blaslt; }
 
-  blasLtHandle_t GetBlasLtHandle() const {
+  blasLtHandle_t GetBlasLtHandle() {
+    std::call_once(flag_blaslt_, [=]() {
+      if (!blaslt_handle_) phi::InitBlasLtHandle(&blaslt_handle_);
+    });
     PD_CHECK(blaslt_handle_ != nullptr, "the gpu blasLt handle is nullptr.");
     return blaslt_handle_;
   }
 
   dnnHandle_t GetDnnHandle() {
+    std::call_once(flag_dnn_, [=]() {
+      if (!dnn_handle_) phi::InitDnnHandle(&dnn_handle_, stream_, place_);
+    });
     PD_CHECK(dnn_handle_ != nullptr, "the gpu dnn handle is nullptr.");
     return dnn_handle_;
   }
@@ -399,7 +392,10 @@
 
   void SetDnnHandle(dnnHandle_t handle) { dnn_handle_ = handle; }
 
-  solverHandle_t GetSolverHandle() const {
+  solverHandle_t GetSolverHandle() {
+    std::call_once(flag_solver_, [=]() {
+      if (!solver_handle_) phi::InitSolverHandle(&solver_handle_, stream_);
+    });
     PD_CHECK(solver_handle_ != nullptr, "the gpu solver handle is nullptr.");
     return solver_handle_;
   }
@@ -461,8 +457,28 @@
 #endif
   }
 
-  inline void CublasCall(
-      const std::function<void(blasHandle_t)>& callback) const {
+  inline void CublasCall(const std::function<void(blasHandle_t)>& callback) {
+    std::call_once(flag_cublas_, [=]() {
+      if (!blas_handle_) {
+        phi::InitBlasHandle(&blas_handle_, stream_);
+      }
+#ifdef PADDLE_WITH_CUDA
+#if CUDA_VERSION >= 9000
+      if (!blas_tensor_core_handle_) {
+        phi::InitBlasHandle(&blas_tensor_core_handle_, stream_);
+        PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasSetMathMode(
+            blas_tensor_core_handle_, CUBLAS_TENSOR_OP_MATH));
+      }
+#endif
+#if CUDA_VERSION >= 11000
+      if (!blas_tf32_tensor_core_handle_) {
+        phi::InitBlasHandle(&blas_tf32_tensor_core_handle_, stream_);
+        PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasSetMathMode(
+            blas_tf32_tensor_core_handle_, CUBLAS_TF32_TENSOR_OP_MATH));
+      }
+#endif
+#endif
+    });
     if (blas_tf32_tensor_core_handle_ != nullptr) {
       std::lock_guard<std::mutex> guard(blas_tf32_mtx_);
       callback(blas_tf32_tensor_core_handle_);
@@ -473,7 +489,26 @@
   }
 
   inline void TensorCoreCublasCallIfAvailable(
-      const std::function<void(blasHandle_t)>& callback) const {
+      const std::function<void(blasHandle_t)>& callback) {
+    std::call_once(flag_tensorcore_cublas_, [=]() {
+      if (!blas_handle_) phi::InitBlasHandle(&blas_handle_, stream_);
+#ifdef PADDLE_WITH_CUDA
+#if CUDA_VERSION >= 9000
+      if (!blas_tensor_core_handle_) {
+        phi::InitBlasHandle(&blas_tensor_core_handle_, stream_);
+        PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasSetMathMode(
+            blas_tensor_core_handle_, CUBLAS_TENSOR_OP_MATH));
+      }
+#endif
+#if CUDA_VERSION >= 11000
+      if (!blas_tf32_tensor_core_handle_) {
+        phi::InitBlasHandle(&blas_tf32_tensor_core_handle_, stream_);
+        PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasSetMathMode(
+            blas_tf32_tensor_core_handle_, CUBLAS_TF32_TENSOR_OP_MATH));
+      }
+#endif
+#endif
+    });
     if (blas_tensor_core_handle_ != nullptr) {
       std::lock_guard<std::mutex> guard(blas_tensor_core_mtx_);
       callback(blas_tensor_core_handle_);
@@ -563,6 +598,13 @@ struct GPUContext::Impl {
   sparseHandle_t sparse_handle_{nullptr};
   DnnWorkspaceHandle* workspace_{nullptr};
 
+  std::once_flag flag_blas_;
+  std::once_flag flag_blaslt_;
+  std::once_flag flag_dnn_;
+  std::once_flag flag_solver_;
+  std::once_flag flag_cublas_;
+  std::once_flag flag_tensorcore_cublas_;
+
 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
   // NCCL communicator (single process version) for NCCL collective operations.
   // NCCL collective operations provides fast collectives over multiple GPUs
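For reference, below is a minimal, self-contained sketch of the lazy-initialization pattern this commit applies. FakeHandle, CreateHandle, and Context are hypothetical stand-ins for blasHandle_t, phi::InitBlasHandle, and GPUContext::Impl; only the std::call_once structure mirrors the real code.

// lazy_handle_sketch.cc -- illustrative only, not part of the commit.
// Build with: g++ -std=c++14 -pthread lazy_handle_sketch.cc
#include <cassert>
#include <iostream>
#include <mutex>
#include <thread>
#include <vector>

using FakeHandle = int*;  // hypothetical stand-in for blasHandle_t

// Hypothetical stand-in for phi::InitBlasHandle.
void CreateHandle(FakeHandle* handle) {
  *handle = new int(42);
  std::cout << "handle created" << std::endl;
}

struct Context {
  FakeHandle GetHandle() {
    // The lambda body runs exactly once; concurrent callers block until
    // the first caller finishes, so the handle is never created twice.
    std::call_once(flag_, [this]() {
      if (!handle_) CreateHandle(&handle_);
    });
    assert(handle_ != nullptr);  // plays the role of PD_CHECK above
    return handle_;
  }
  ~Context() { delete handle_; }

 private:
  FakeHandle handle_{nullptr};
  std::once_flag flag_;
};

int main() {
  Context ctx;
  std::vector<std::thread> threads;
  for (int i = 0; i < 8; ++i)
    threads.emplace_back([&ctx] { (void)ctx.GetHandle(); });
  for (auto& t : threads) t.join();
  // "handle created" prints once despite eight concurrent callers, and a
  // process that never calls GetHandle() never pays the creation cost.
  return 0;
}

The null checks inside the call_once bodies in the diff appear to guard handles that may already have been injected through the Set*Handle setters before the first getter call; the sketch keeps the same check for that reason.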