Skip to content

Commit

Permalink
a better way to handle nccl error under elastic scenario (horovod#3112)
Browse files Browse the repository at this point in the history
Signed-off-by: guoze.lin <guozelin@tencent.com>
Signed-off-by: weihanmines <weihan13@amd.com>
  • Loading branch information
woodlgz authored and weihanmines committed Dec 10, 2021
1 parent cc1fda8 commit 8809220
Show file tree
Hide file tree
Showing 9 changed files with 872 additions and 587 deletions.
226 changes: 114 additions & 112 deletions horovod/common/operations.cc

Large diffs are not rendered by default.

80 changes: 45 additions & 35 deletions horovod/common/ops/adasum_gpu_operations.cc
Expand Up @@ -49,7 +49,8 @@ Status AdasumGpuAllreduceOp::Execute(std::vector<TensorTableEntry>& entries,

WaitForData(entries);

// Lazily initialize reduction communicators for VHDD algorithm when Adasum reduction is actually called.
// Lazily initialize reduction communicators for VHDD algorithm when Adasum
// reduction is actually called.
if (!reduction_comms_initialized) {
InitializeVHDDReductionComms();
}
Expand All @@ -66,7 +67,7 @@ AdasumGpuAllreduceOp::NcclHierarchical(std::vector<TensorTableEntry>& entries,
const Response& response) {
assert(!entries.empty());
auto& first_entry = entries[0];
assert(first_entry.process_set_id == 0); // TODO: generalize
assert(first_entry.process_set_id == 0); // TODO: generalize
auto& process_set =
global_state_->process_set_table.Get(first_entry.process_set_id);

Expand All @@ -88,20 +89,22 @@ AdasumGpuAllreduceOp::NcclHierarchical(std::vector<TensorTableEntry>& entries,
MemcpyInFusionBuffer(entries, fused_input_data, buffer_data, buffer_len);
if (global_state_->timeline.Initialized()) {
gpu_context_->RecordEvent(gpu_op_context_.event_queue,
MEMCPY_IN_FUSION_BUFFER,
*gpu_op_context_.stream);
MEMCPY_IN_FUSION_BUFFER,
*gpu_op_context_.stream);
}
} else {
fused_input_data = first_entry.tensor->data();
buffer_data = (void*)first_entry.output->data();
buffer_len = (size_t)first_entry.output->size();
}

int64_t num_elements = buffer_len / DataType_Size(first_entry.tensor->dtype());
int64_t num_elements =
buffer_len / DataType_Size(first_entry.tensor->dtype());

if (response.prescale_factor() != 1.0) {
// Execute prescaling op
ScaleBuffer(response.prescale_factor(), entries, fused_input_data, buffer_data, num_elements);
ScaleBuffer(response.prescale_factor(), entries, fused_input_data,
buffer_data, num_elements);
fused_input_data = buffer_data; // for unfused, scale is done out of place
}

Expand Down Expand Up @@ -134,9 +137,8 @@ AdasumGpuAllreduceOp::NcclHierarchical(std::vector<TensorTableEntry>& entries,
// non-divisible part (if any), do NCCL Reduce (at rank local_size-1),
// MPI Allreduce (across rank (local_size-1)'s), and NCCL Bcast

int64_t num_elements_per_rank = process_set.controller->IsHomogeneous()
? num_elements / local_size
: 0;
int64_t num_elements_per_rank =
process_set.controller->IsHomogeneous() ? num_elements / local_size : 0;

size_t buffer_len_per_rank = element_size * num_elements_per_rank;

Expand Down Expand Up @@ -172,25 +174,28 @@ AdasumGpuAllreduceOp::NcclHierarchical(std::vector<TensorTableEntry>& entries,
(size_t)num_elements_per_rank, GetNCCLDataType(first_entry.tensor),
ncclSum, *nccl_op_context_.nccl_comm_, *gpu_op_context_.stream);

nccl_context_->ErrorCheck("ncclReduceScatter", nccl_result, *nccl_op_context_.nccl_comm_);
nccl_context_->ErrorCheck("ncclReduceScatter", nccl_result,
*nccl_op_context_.nccl_comm_);
if (global_state_->timeline.Initialized()) {
gpu_context_->RecordEvent(gpu_op_context_.event_queue,
NCCL_REDUCESCATTER, *gpu_op_context_.stream);
gpu_context_->RecordEvent(gpu_op_context_.event_queue, NCCL_REDUCESCATTER,
*gpu_op_context_.stream);
}
}

if (num_elements_remaining > 0) {
// Reduce the remaining data at local_size-1 to append to
// existing buffer
auto nccl_result = ncclReduce(
fused_input_data_remainder, buffer_data_remainder,
(size_t)num_elements_remaining, GetNCCLDataType(first_entry.tensor),
ncclSum, root_rank, *nccl_op_context_.nccl_comm_, *gpu_op_context_.stream);

nccl_context_->ErrorCheck("ncclReduce", nccl_result, *nccl_op_context_.nccl_comm_);
auto nccl_result =
ncclReduce(fused_input_data_remainder, buffer_data_remainder,
(size_t)num_elements_remaining,
GetNCCLDataType(first_entry.tensor), ncclSum, root_rank,
*nccl_op_context_.nccl_comm_, *gpu_op_context_.stream);

nccl_context_->ErrorCheck("ncclReduce", nccl_result,
*nccl_op_context_.nccl_comm_);
if (global_state_->timeline.Initialized()) {
gpu_context_->RecordEvent(gpu_op_context_.event_queue, NCCL_REDUCE,
*gpu_op_context_.stream);
*gpu_op_context_.stream);
}
}

Expand All @@ -199,8 +204,13 @@ AdasumGpuAllreduceOp::NcclHierarchical(std::vector<TensorTableEntry>& entries,
// a buffer is not safe since the tensor can be arbitrarily large.
host_buffer = GetHostBuffer((uint64_t)total_buffer_len);
// Synchronize.
gpu_context_->WaitForEvents(gpu_op_context_.event_queue, entries,
timeline, nullptr, global_state_->elastic_enabled);
if (global_state_->elastic_enabled) {
gpu_context_->WaitForEventsElastic(gpu_op_context_.event_queue, entries,
timeline, nullptr);
} else {
gpu_context_->WaitForEvents(gpu_op_context_.event_queue, entries,
timeline, nullptr);
}

// According to https://docs.nvidia.com/cuda/cuda-runtime-api/
// api-sync-behavior.html#api-sync-behavior__memcpy-async,
Expand Down Expand Up @@ -263,24 +273,24 @@ AdasumGpuAllreduceOp::NcclHierarchical(std::vector<TensorTableEntry>& entries,
entries, (void*)host_buffer, (void*)recv_buffer, tensor_counts,
local_size, // start_level
mpi_context_->GetMPICommunicator(process_set.controller->IsHomogeneous()
? Communicator::GLOBAL
: Communicator::CROSS),
? Communicator::GLOBAL
: Communicator::CROSS),
0, reduction_comms_, first_entry.tensor->dtype(), global_state_);
timeline.ActivityEndAll(entries);

timeline.ActivityStartAll(entries, MEMCPY_OUT_HOST_BUFFER);
gpu_context_->MemcpyAsyncH2D(buffer_data_at_rank_offset,
host_buffer, total_buffer_len,
*gpu_op_context_.stream);
gpu_context_->MemcpyAsyncH2D(buffer_data_at_rank_offset, host_buffer,
total_buffer_len, *gpu_op_context_.stream);
timeline.ActivityEndAll(entries);
}

if (num_elements_per_rank > 0) {
nccl_context_->ErrorCheck(
"ncclAllGather", ncclAllGather(buffer_data_at_rank_offset, buffer_data,
(size_t)num_elements_per_rank,
GetNCCLDataType(first_entry.tensor),
*nccl_op_context_.nccl_comm_, *gpu_op_context_.stream),
"ncclAllGather",
ncclAllGather(buffer_data_at_rank_offset, buffer_data,
(size_t)num_elements_per_rank,
GetNCCLDataType(first_entry.tensor),
*nccl_op_context_.nccl_comm_, *gpu_op_context_.stream),
*nccl_op_context_.nccl_comm_);
if (global_state_->timeline.Initialized()) {
gpu_context_->RecordEvent(gpu_op_context_.event_queue, NCCL_ALLGATHER,
Expand Down Expand Up @@ -312,7 +322,8 @@ AdasumGpuAllreduceOp::NcclHierarchical(std::vector<TensorTableEntry>& entries,

if (response.postscale_factor() != 1.0) {
// Execute postscaling op
ScaleBuffer(response.postscale_factor(), entries, buffer_data, buffer_data, num_elements);
ScaleBuffer(response.postscale_factor(), entries, buffer_data, buffer_data,
num_elements);
}

// Copy memory out of the fusion buffer.
Expand All @@ -329,10 +340,9 @@ AdasumGpuAllreduceOp::NcclHierarchical(std::vector<TensorTableEntry>& entries,
return gpu_op_context_.FinalizeGPUQueue(entries, false);
}

bool AdasumGpuAllreduceOp::Enabled(
const ParameterManager& param_manager,
const std::vector<TensorTableEntry>& entries,
const Response& response) const {
bool AdasumGpuAllreduceOp::Enabled(const ParameterManager& param_manager,
const std::vector<TensorTableEntry>& entries,
const Response& response) const {
return entries[0].device != CPU_DEVICE_ID;
}
} // namespace common
Expand Down
142 changes: 92 additions & 50 deletions horovod/common/ops/cuda_operations.cc
Expand Up @@ -14,10 +14,10 @@
// limitations under the License.
// =============================================================================

#include "gpu_operations.h"
#include "cuda/cuda_kernels.h"
#include "../message.h"
#include "../hashes.h"
#include "../message.h"
#include "cuda/cuda_kernels.h"
#include "gpu_operations.h"

#include <thread>

Expand All @@ -39,7 +39,8 @@ class GPUContext::impl {
auto& queue = cuda_events[key];
if (!prepopulated[key]) {
// On first call for device and stream pair, prepopulate event queue.
// This is to minimize event reuse of callback events passed to framework.
// This is to minimize event reuse of callback events passed to
// framework.
for (int i = 0; i < N_CUDA_EVENTS_PREPOPULATE; ++i) {
cudaEvent_t ev;
status = cudaEventCreateWithFlags(&ev, cudaEventDisableTiming);
Expand Down Expand Up @@ -81,28 +82,32 @@ class GPUContext::impl {

void ErrorCheck(std::string op_name, cudaError_t cuda_result) {
if (cuda_result != cudaSuccess) {
throw std::logic_error(std::string(op_name) + " failed: " + cudaGetErrorString(cuda_result));
throw std::logic_error(std::string(op_name) +
" failed: " + cudaGetErrorString(cuda_result));
}
}

void RecordEvent(std::queue<std::pair<std::string, Event>>& event_queue, std::string name, cudaStream_t& stream) {
void RecordEvent(std::queue<std::pair<std::string, Event>>& event_queue,
std::string name, cudaStream_t& stream) {
Event event;
ErrorCheck("GetGpuEvent", GetGpuEvent(&event, stream));
ErrorCheck("cudaEventRecord", cudaEventRecord(*(event.event), event.stream));
ErrorCheck("cudaEventRecord",
cudaEventRecord(*(event.event), event.stream));
event_queue.emplace(name, event);
}

Event RecordEvent(cudaStream_t& stream) {
Event event;
ErrorCheck("GetGpuEvent", GetGpuEvent(&event, stream));
ErrorCheck("cudaEventRecord", cudaEventRecord(*(event.event), event.stream));
ErrorCheck("cudaEventRecord",
cudaEventRecord(*(event.event), event.stream));
return event;
}

void WaitForEvents(std::queue<std::pair<std::string, Event>>& event_queue,
const std::vector<TensorTableEntry>& entries, Timeline& timeline,
const std::function<void()>& error_check_callback,
bool elastic) {
const std::vector<TensorTableEntry>& entries,
Timeline& timeline,
const std::function<void()>& error_check_callback) {
while (!event_queue.empty()) {
std::string name;
Event event;
Expand All @@ -112,32 +117,54 @@ class GPUContext::impl {
timeline.ActivityStartAll(entries, name);
}

// Check for async (networking) errors while waiting for the event to complete
if (elastic) {
cudaError_t cuda_result;
while (true) {
cuda_result = cudaEventQuery(*(event.event));
if (cuda_result == cudaSuccess) {
break;
}

if (cuda_result != cudaErrorNotReady) {
throw std::logic_error(std::string("cudaEventQuery failed: ") + cudaGetErrorString(cuda_result));
}

if (error_check_callback) {
error_check_callback();
}
std::this_thread::yield();
cudaError_t cuda_result = cudaEventSynchronize(*(event.event));
if (cuda_result != cudaSuccess) {
throw std::logic_error(std::string("cudaEventSynchronize failed: ") +
cudaGetErrorString(cuda_result));
}
if (error_check_callback) {
error_check_callback();
}

if (name != "") {
timeline.ActivityEndAll(entries);
}
ErrorCheck("ReleaseGpuEvent", ReleaseGpuEvent(event));
}
}

void
WaitForEventsElastic(std::queue<std::pair<std::string, Event>>& event_queue,
const std::vector<TensorTableEntry>& entries,
Timeline& timeline,
const std::function<void()>& error_check_callback) {
while (!event_queue.empty()) {
std::string name;
Event event;
std::tie(name, event) = event_queue.front();
event_queue.pop();
if (name != "") {
timeline.ActivityStartAll(entries, name);
}

// Check for async (networking) errors while waiting for the event to
// complete
cudaError_t cuda_result;
while (true) {
cuda_result = cudaEventQuery(*(event.event));
if (cuda_result == cudaSuccess) {
break;
}
} else {
cudaError_t cuda_result = cudaEventSynchronize(*(event.event));
if (cuda_result != cudaSuccess) {
throw std::logic_error(std::string("cudaEventSynchronize failed: ") + cudaGetErrorString(cuda_result));

if (cuda_result != cudaErrorNotReady) {
throw std::logic_error(std::string("cudaEventQuery failed: ") +
cudaGetErrorString(cuda_result));
}

if (error_check_callback) {
error_check_callback();
}
std::this_thread::yield();
}

if (name != "") {
Expand All @@ -148,9 +175,10 @@ class GPUContext::impl {
}

void ClearEvents(std::queue<std::pair<std::string, Event>>& event_queue,
const std::vector<TensorTableEntry>& entries, Timeline& timeline,
const std::function<void()>& error_check_callback,
bool elastic) {
const std::vector<TensorTableEntry>& entries,
Timeline& timeline,
const std::function<void()>& error_check_callback,
bool elastic) {
while (!event_queue.empty()) {
std::string name;
Event event;
Expand All @@ -167,12 +195,13 @@ class GPUContext::impl {
}
}

void StreamCreate(cudaStream_t *stream) {
void StreamCreate(cudaStream_t* stream) {
int greatest_priority;
ErrorCheck("cudaDeviceGetStreamPriorityRange",
cudaDeviceGetStreamPriorityRange(NULL, &greatest_priority));
cudaDeviceGetStreamPriorityRange(NULL, &greatest_priority));
ErrorCheck("cudaStreamCreateWithPriority",
cudaStreamCreateWithPriority(stream, cudaStreamNonBlocking, greatest_priority));
cudaStreamCreateWithPriority(stream, cudaStreamNonBlocking,
greatest_priority));
}

void StreamSynchronize(cudaStream_t stream) {
Expand All @@ -189,29 +218,42 @@ class GPUContext::impl {
ErrorCheck("cudaSetDevice", cudaSetDevice(device));
}

void MemcpyAsyncD2D(void* dst, const void* src, size_t count, cudaStream_t stream) {
ErrorCheck("cudaMemcpyAsync", cudaMemcpyAsync(dst, src, count, cudaMemcpyDeviceToDevice, stream));
void MemcpyAsyncD2D(void* dst, const void* src, size_t count,
cudaStream_t stream) {
ErrorCheck(
"cudaMemcpyAsync",
cudaMemcpyAsync(dst, src, count, cudaMemcpyDeviceToDevice, stream));
}

void MemcpyAsyncH2D(void* dst, const void* src, size_t count, cudaStream_t stream) {
ErrorCheck("cudaMemcpyAsync", cudaMemcpyAsync(dst, src, count, cudaMemcpyHostToDevice, stream));
void MemcpyAsyncH2D(void* dst, const void* src, size_t count,
cudaStream_t stream) {
ErrorCheck(
"cudaMemcpyAsync",
cudaMemcpyAsync(dst, src, count, cudaMemcpyHostToDevice, stream));
}

void MemcpyAsyncD2H(void* dst, const void* src, size_t count, cudaStream_t stream) {
ErrorCheck("cudaMemcpyAsync", cudaMemcpyAsync(dst, src, count, cudaMemcpyDeviceToHost, stream));
void MemcpyAsyncD2H(void* dst, const void* src, size_t count,
cudaStream_t stream) {
ErrorCheck(
"cudaMemcpyAsync",
cudaMemcpyAsync(dst, src, count, cudaMemcpyDeviceToHost, stream));
}

void ScaleBufferImpl(const void* fused_input_data, void* buffer_data, int64_t num_elements,
double scale_factor, DataType dtype, cudaStream_t stream) {
ScaleBufferCudaImpl(fused_input_data, buffer_data, num_elements, scale_factor, dtype, stream);
void ScaleBufferImpl(const void* fused_input_data, void* buffer_data,
int64_t num_elements, double scale_factor,
DataType dtype, cudaStream_t stream) {
ScaleBufferCudaImpl(fused_input_data, buffer_data, num_elements,
scale_factor, dtype, stream);

// TODO: https://github.com/horovod/horovod/issues/2230
//ErrorCheck("ScaleBufferCudaImpl", cudaGetLastError());
// ErrorCheck("ScaleBufferCudaImpl", cudaGetLastError());
}

private:
// We reuse CUDA events as it appears that their creation carries non-zero cost.
std::unordered_map<std::pair<int, cudaStream_t>, std::queue<Event>> cuda_events;
// We reuse CUDA events as it appears that their creation carries non-zero
// cost.
std::unordered_map<std::pair<int, cudaStream_t>, std::queue<Event>>
cuda_events;
std::unordered_map<std::pair<int, cudaStream_t>, bool> prepopulated;
std::mutex cuda_events_mutex;

Expand Down

0 comments on commit 8809220

Please sign in to comment.