A better way to handle NCCL errors in elastic scenarios
Signed-off-by: guoze.lin <guozelin@tencent.com>
guoze.lin committed Aug 16, 2021
1 parent 1359e3a commit 9aa5d55
Showing 8 changed files with 99 additions and 37 deletions.
1 change: 1 addition & 0 deletions horovod/common/operations.cc
@@ -442,6 +442,7 @@ void BackgroundThreadLoop(HorovodGlobalState& state) {

#if HAVE_NCCL
nccl_context.nccl_comms.resize(state.num_nccl_streams);
SetBoolFromEnv(HOROVOD_ELASTIC, nccl_context.elastic, true);
#endif
gpu_context.streams.resize(state.num_nccl_streams);

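The one-line addition above is what wires elasticity into the NCCL layer: it reads the HOROVOD_ELASTIC environment variable into the new nccl_context.elastic flag (declared in nccl_operations.h at the bottom of this commit). A minimal sketch of the helper's assumed semantics (set val to value_if_set only when the variable is present and truthy), following the usual Horovod utils pattern:

    #include <cstdlib>

    // Assumed semantics, for illustration: if the environment variable is set
    // to a positive integer, write value_if_set into val; otherwise leave val
    // unchanged.
    void SetBoolFromEnv(const char* env, bool& val, bool value_if_set) {
      const char* env_value = std::getenv(env);
      if (env_value != nullptr && std::strtol(env_value, nullptr, 10) > 0) {
        val = value_if_set;
      }
    }

With that in place, running under the elastic driver (which exports HOROVOD_ELASTIC) flips NCCLContext::elastic to true for the whole process.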
60 changes: 38 additions & 22 deletions horovod/common/ops/cuda_operations.cc
@@ -101,8 +101,34 @@ class GPUContext::impl {

void WaitForEvents(std::queue<std::pair<std::string, Event>>& event_queue,
const std::vector<TensorTableEntry>& entries, Timeline& timeline,
const std::function<void()>& error_check_callback,
bool elastic) {
const std::function<void()>& error_check_callback) {
while (!event_queue.empty()) {
std::string name;
Event event;
std::tie(name, event) = event_queue.front();
event_queue.pop();
if (name != "") {
timeline.ActivityStartAll(entries, name);
}

cudaError_t cuda_result = cudaEventSynchronize(*(event.event));
if (cuda_result != cudaSuccess) {
throw std::logic_error(std::string("cudaEventSynchronize failed: ") + cudaGetErrorString(cuda_result));
}
if (error_check_callback) {
error_check_callback();
}

if (name != "") {
timeline.ActivityEndAll(entries);
}
ErrorCheck("ReleaseGpuEvent", ReleaseGpuEvent(event));
}
}

void WaitForEventsElastic(std::queue<std::pair<std::string, Event>>& event_queue,
const std::vector<TensorTableEntry>& entries, Timeline& timeline,
const std::function<void()>& error_check_callback) {
while (!event_queue.empty()) {
std::string name;
Event event;
@@ -113,31 +139,21 @@
}

// Check for async (networking) errors while waiting for the event to complete
if (elastic) {
cudaError_t cuda_result;
while (true) {
cuda_result = cudaEventQuery(*(event.event));
if (cuda_result == cudaSuccess) {
break;
}

if (cuda_result != cudaErrorNotReady) {
throw std::logic_error(std::string("cudaEventQuery failed: ") + cudaGetErrorString(cuda_result));
}

if (error_check_callback) {
error_check_callback();
}
std::this_thread::yield();
cudaError_t cuda_result;
while (true) {
cuda_result = cudaEventQuery(*(event.event));
if (cuda_result == cudaSuccess) {
break;
}
} else {
cudaError_t cuda_result = cudaEventSynchronize(*(event.event));
if (cuda_result != cudaSuccess) {
throw std::logic_error(std::string("cudaEventSynchronize failed: ") + cudaGetErrorString(cuda_result));

if (cuda_result != cudaErrorNotReady) {
throw std::logic_error(std::string("cudaEventQuery failed: ") + cudaGetErrorString(cuda_result));
}

if (error_check_callback) {
error_check_callback();
}
std::this_thread::yield();
}

if (name != "") {
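The net effect of the two hunks above: the old WaitForEvents(..., bool elastic) is split into WaitForEvents, which blocks inside cudaEventSynchronize, and WaitForEventsElastic, which polls cudaEventQuery in a loop so the error callback gets a chance to run (and throw) while the event is still pending. A self-contained sketch of the polling pattern; the helper name is hypothetical, not Horovod API:

    #include <cuda_runtime.h>
    #include <functional>
    #include <stdexcept>
    #include <string>
    #include <thread>

    // Wait for a CUDA event without blocking inside the driver, so that
    // error_check_callback can surface NCCL async errors between polls.
    void PollEventWithErrorChecks(cudaEvent_t event,
                                  const std::function<void()>& error_check_callback) {
      while (true) {
        cudaError_t result = cudaEventQuery(event);
        if (result == cudaSuccess) {
          return;  // The GPU work recorded before the event has completed.
        }
        if (result != cudaErrorNotReady) {
          throw std::logic_error(std::string("cudaEventQuery failed: ") +
                                 cudaGetErrorString(result));
        }
        if (error_check_callback) {
          error_check_callback();  // May throw, e.g. on an NCCL async error.
        }
        std::this_thread::yield();  // Be polite to other threads while polling.
      }
    }

The trade-off is a busy-wait: the elastic path spends CPU polling in exchange for the ability to bail out when a peer disappears, which is why the blocking cudaEventSynchronize variant is kept for the non-elastic case.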
10 changes: 7 additions & 3 deletions horovod/common/ops/gpu_context_impl.cc
@@ -22,9 +22,13 @@ void GPUContext::ReleaseEvent(Event event) {
}

void GPUContext::WaitForEvents(std::queue<std::pair<std::string, Event>>& event_queue, const std::vector<TensorTableEntry>& entries,
Timeline& timeline, const std::function<void()>& error_check_callback,
bool elastic) {
pimpl->WaitForEvents(event_queue, entries, timeline, error_check_callback, elastic);
Timeline& timeline, const std::function<void()>& error_check_callback) {
pimpl->WaitForEvents(event_queue, entries, timeline, error_check_callback);
}

void GPUContext::WaitForEventsElastic(std::queue<std::pair<std::string, Event>>& event_queue, const std::vector<TensorTableEntry>& entries,
Timeline& timeline, const std::function<void()>& error_check_callback) {
pimpl->WaitForEventsElastic(event_queue, entries, timeline, error_check_callback);
}

void GPUContext::ClearEvents(std::queue<std::pair<std::string, Event>>& event_queue, const std::vector<TensorTableEntry>& entries,
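These are pure forwarding stubs: GPUContext hides its backend behind a pimpl, so the new WaitForEventsElastic entry point needs one public wrapper here plus one impl method per backend (CUDA above, HIP below). A generic sketch of the idiom with placeholder names:

    #include <memory>

    class Context {
     public:
      Context();
      ~Context();
      void WaitForWork();          // public, backend-agnostic API
     private:
      class impl;                  // defined per backend in its own .cc file
      std::unique_ptr<impl> pimpl;
    };

    class Context::impl {
     public:
      void WaitForWork() { /* CUDA- or HIP-specific wait */ }
    };

    Context::Context() : pimpl(new impl()) {}
    Context::~Context() = default;
    void Context::WaitForWork() { pimpl->WaitForWork(); }  // forwards only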
28 changes: 24 additions & 4 deletions horovod/common/ops/gpu_operations.cc
@@ -60,6 +60,7 @@ Status GPUOpContext::FinalizeGPUQueue(std::vector<TensorTableEntry>& entries, bo
auto& evt_queue = event_queue;
auto& timeline = global_state_->timeline;
auto& gpu_context = gpu_context_;
auto& global_state = global_state_;

// Claim a std::shared_ptr to the fusion buffer to prevent its memory from being reclaimed
// during finalization.
@@ -71,13 +72,26 @@
auto current_stream = *stream;
gpu_context_->finalizer_thread_pool.execute([entries, first_entry, cpu_buffer, fusion_buffer, free_host_buffer,
evt_queue, &timeline, &gpu_context, error_check_callback,
elastic, enable_async_completion, current_stream]() mutable {
elastic, enable_async_completion, current_stream, &global_state]() mutable {
gpu_context->SetDevice(first_entry.device);

Event event;
bool gpu_evt_failed = false;
std::string gpu_evt_err_msg;
if (!enable_async_completion || timeline.Initialized()) {
// If timeline is enabled, wait for events on CPU for accurate timings.
gpu_context->WaitForEvents(evt_queue, entries, timeline, error_check_callback, elastic);
if (elastic) {
try {
gpu_context->WaitForEventsElastic(evt_queue, entries, timeline, error_check_callback);
} catch (std::exception& e) {
// Notify the background loop to exit and reinit rather than just aborting the program.
global_state->shut_down = true;
gpu_evt_failed = true;
gpu_evt_err_msg = e.what();
}
} else {
gpu_context->WaitForEvents(evt_queue, entries, timeline, error_check_callback);
}
} else {
gpu_context->ClearEvents(evt_queue, entries, timeline, error_check_callback, elastic);
event = gpu_context->RecordEvent(current_stream);
@@ -87,10 +101,16 @@ Status GPUOpContext::FinalizeGPUQueue(std::vector<TensorTableEntry>& entries, bo
free(cpu_buffer);
}

Status status;
if (gpu_evt_failed) {
status = Status::UnknownError(gpu_evt_err_msg);
} else {
status = Status::OK();
status.event = event;
}

for (auto& e : entries) {
timeline.End(e.tensor_name, e.output);
auto status = Status::OK();
status.event = event;
e.FinishWithCallback(status);
}
if (enable_async_completion) {
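The pattern introduced in this file: an exception from WaitForEventsElastic is caught on the finalizer thread instead of being allowed to escape and terminate the process. The handler sets global_state->shut_down so the background loop exits and reinitializes, and converts the message into a Status::UnknownError that every pending entry's callback receives. A self-contained sketch with simplified stand-in types (Status and Entry here are not the Horovod definitions):

    #include <atomic>
    #include <iostream>
    #include <stdexcept>
    #include <string>
    #include <vector>

    struct Status {  // Simplified stand-in for horovod::common::Status.
      bool ok = true;
      std::string msg;
      static Status OK() { return {}; }
      static Status UnknownError(std::string m) { return {false, std::move(m)}; }
    };

    struct Entry {  // Simplified stand-in for TensorTableEntry.
      std::string name;
      void FinishWithCallback(const Status& s) {
        std::cout << name << ": " << (s.ok ? "ok" : s.msg) << "\n";
      }
    };

    int main() {
      std::atomic<bool> shut_down{false};
      std::vector<Entry> entries{{"tensor_a"}, {"tensor_b"}};

      Status status = Status::OK();
      try {
        // Simulate WaitForEventsElastic throwing on an NCCL async error.
        throw std::logic_error("NCCL async error: unhandled system error");
      } catch (std::exception& e) {
        shut_down = true;  // Ask the background loop to exit and reinit.
        status = Status::UnknownError(e.what());
      }
      for (auto& e : entries) {
        e.FinishWithCallback(status);  // Every in-flight tensor sees the failure.
      }
      return shut_down ? 0 : 1;
    }

This is what turns a fatal NCCL failure into an ordinary per-tensor error that elastic mode can recover from by re-forming the communicator.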
7 changes: 5 additions & 2 deletions horovod/common/ops/gpu_operations.h
@@ -75,8 +75,11 @@ class GPUContext {

void WaitForEvents(std::queue<std::pair<std::string, Event>>& event_queue,
const std::vector<TensorTableEntry>& entries, Timeline& timeline,
const std::function<void()>& error_check_callback = nullptr,
bool elastic = false);
const std::function<void()>& error_check_callback = nullptr);

void WaitForEventsElastic(std::queue<std::pair<std::string, Event>>& event_queue,
const std::vector<TensorTableEntry>& entries, Timeline& timeline,
const std::function<void()>& error_check_callback = nullptr);

void ClearEvents(std::queue<std::pair<std::string, Event>>& event_queue,
const std::vector<TensorTableEntry>& entries, Timeline& timeline,
9 changes: 7 additions & 2 deletions horovod/common/ops/hip_operations.cc
@@ -77,8 +77,7 @@ class GPUContext::impl {

void WaitForEvents(std::queue<std::pair<std::string, hipEvent_t>>& event_queue,
const std::vector<TensorTableEntry>& entries, Timeline& timeline,
const std::function<void()>& error_check_callback,
bool elastic) {
const std::function<void()>& error_check_callback) {
while (!event_queue.empty()) {
std::string name;
hipEvent_t event;
@@ -113,6 +112,12 @@
}
}

void WaitForEventsElastic(std::queue<std::pair<std::string, hipEvent_t>>& event_queue,
const std::vector<TensorTableEntry>& entries, Timeline& timeline,
const std::function<void()>& error_check_callback) {
WaitForEvents(event_queue, entries, timeline, error_check_callback);
}

void StreamCreate(hipStream_t *stream) {
int greatest_priority;
ErrorCheck("hipDeviceGetStreamPriorityRange",
19 changes: 15 additions & 4 deletions horovod/common/ops/nccl_operations.cc
@@ -46,6 +46,19 @@ ncclDataType_t GetNCCLDataType(const std::shared_ptr<Tensor> tensor) {
}
}

void commDestroyOrAbort(ncclComm_t& nccl_comm, bool elastic) {
ncclResult_t nccl_async_err;
auto nccl_err = ncclCommGetAsyncError(nccl_comm, &nccl_async_err);
if (nccl_err != ncclSuccess) {
return;
}
if (nccl_async_err == ncclSuccess && !elastic) {
ncclCommDestroy(nccl_comm);
} else {
ncclCommAbort(nccl_comm);
}
}

void NCCLContext::ErrorCheck(std::string op_name, ncclResult_t nccl_result, ncclComm_t& nccl_comm) {
if (nccl_result != ncclSuccess) {
ncclCommAbort(nccl_comm);
@@ -56,7 +69,7 @@ void NCCLContext::ErrorCheck(std::string op_name, ncclResult_t nccl_result, nccl
void NCCLContext::ShutDown(){
for(auto it = nccl_comms.begin(); it != nccl_comms.end(); ++it) {
for (auto entry = it->begin(); entry != it->end(); ++entry) {
ncclCommDestroy(entry->second);
commDestroyOrAbort(entry->second, elastic);
}
}
nccl_comms.clear();
@@ -114,11 +127,9 @@ void NCCLOpContext::AsyncErrorCheck() {
}

if (nccl_async_err != ncclSuccess) {
ncclCommAbort(*nccl_comm_);
// Do not call ncclCommAbort(*nccl_comm_) from the event polling thread, to avoid a race condition.
throw std::logic_error(std::string("NCCL async error: ") + ncclGetErrorString(nccl_async_err));
}


}

void NCCLOpContext::PopulateNCCLCommStrategy(int& nccl_rank, int& nccl_size,
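Why abort rather than destroy: ncclCommDestroy is the graceful teardown and can wait on the communicator's outstanding operations, which may never complete once a peer has died, while ncclCommAbort (available since NCCL 2.4, like ncclCommGetAsyncError) tears the communicator down unconditionally. So the shutdown path destroys only a healthy communicator in non-elastic mode and aborts in every other case. A usage sketch under those assumptions:

    #include <nccl.h>

    void commDestroyOrAbort(ncclComm_t& nccl_comm, bool elastic);  // from the diff above

    // Hypothetical shutdown loop showing how the elastic flag drives the choice:
    // healthy comm + non-elastic -> ncclCommDestroy; anything else -> ncclCommAbort.
    void ShutDownComms(ncclComm_t* comms, int n, bool elastic) {
      for (int i = 0; i < n; ++i) {
        commDestroyOrAbort(comms[i], elastic);
      }
    }

The related change in AsyncErrorCheck above is the other half of this design: the event-polling thread now only throws on an async error, and the abort happens once, on this shutdown path, avoiding a race between the two threads.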
2 changes: 2 additions & 0 deletions horovod/common/ops/nccl_operations.h
@@ -46,6 +46,8 @@ struct NCCLContext {
void ErrorCheck(std::string op_name, ncclResult_t nccl_result, ncclComm_t& nccl_comm);

void ShutDown();

bool elastic;
};

class NCCLOpContext {
