Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement more asynchronous dependency handling on GPU #2963

Merged
merged 3 commits on Jun 22, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
95 changes: 93 additions & 2 deletions horovod/common/common.h
Expand Up @@ -22,10 +22,50 @@
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>

#include "message.h"
#include "nvtx_op_range.h"

#if HAVE_GPU
#if HAVE_CUDA
#include <cuda_runtime.h>
// Generic gpu* aliases mapped onto the CUDA runtime API so common code can
// be written once for both CUDA and ROCm builds.
using gpuError_t = cudaError_t;
using gpuEvent_t = cudaEvent_t;
using gpuStream_t = cudaStream_t;
#define gpuEventCreateWithFlags cudaEventCreateWithFlags
#define gpuEventDisableTiming cudaEventDisableTiming
#define gpuEventRecord cudaEventRecord
#define gpuEventSynchronize cudaEventSynchronize
#define gpuStreamWaitEvent cudaStreamWaitEvent
// Evaluates x once and throws std::logic_error if the CUDA runtime call did
// not return cudaSuccess.
#define HVD_GPU_CHECK(x)                                                                   \
  do {                                                                                     \
    cudaError_t cuda_result = x;                                                           \
    if (cuda_result != cudaSuccess) {                                                      \
      throw std::logic_error(std::string("GPU Error:") + cudaGetErrorString(cuda_result)); \
    }                                                                                      \
  } while (0)
#elif HAVE_ROCM
#include <hip/hip_runtime_api.h>
// Generic gpu* aliases mapped onto the HIP (ROCm) runtime API.
using gpuError_t = hipError_t;
using gpuEvent_t = hipEvent_t;
using gpuStream_t = hipStream_t;
#define gpuEventCreateWithFlags hipEventCreateWithFlags
#define gpuEventDisableTiming hipEventDisableTiming
#define gpuEventRecord hipEventRecord
#define gpuEventSynchronize hipEventSynchronize
#define gpuStreamWaitEvent hipStreamWaitEvent
// Evaluates x once and throws std::logic_error if the HIP runtime call did
// not return hipSuccess.
#define HVD_GPU_CHECK(x)                                                                \
  do {                                                                                  \
    hipError_t hip_result = x;                                                          \
    if (hip_result != hipSuccess) {                                                     \
      throw std::logic_error(std::string("GPU Error:") + hipGetErrorString(hip_result)); \
    }                                                                                   \
  } while (0)
#endif
// NOTE: the ROCm branch must be the #elif of the inner HAVE_CUDA conditional,
// nested inside HAVE_GPU. The previous arrangement closed the CUDA branch with
// #endif before "#elif HAVE_ROCM", which made the #elif pair with
// "#if HAVE_GPU" — so the ROCm aliases were only compiled when HAVE_GPU was
// *not* set, leaving ROCm GPU builds without gpu* aliases or HVD_GPU_CHECK.
#endif


namespace horovod {
namespace common {

Expand Down Expand Up @@ -62,6 +102,7 @@ namespace common {
#define GLOO_ALLREDUCE "GLOO_ALLREDUCE"
#define GLOO_ALLGATHER "GLOO_ALLGATHER"
#define GLOO_BCAST "GLOO_BCAST"
#define HOROVOD_ELASTIC "HOROVOD_ELASTIC"

// Horovod knobs.
#define HOROVOD_MPI_THREADS_DISABLE "HOROVOD_MPI_THREADS_DISABLE"
Expand Down Expand Up @@ -94,6 +135,7 @@ namespace common {
#define HOROVOD_THREAD_AFFINITY "HOROVOD_THREAD_AFFINITY"
#define HOROVOD_DISABLE_GROUP_FUSION "HOROVOD_DISABLE_GROUP_FUSION"
#define HOROVOD_DISABLE_NVTX_RANGES "HOROVOD_DISABLE_NVTX_RANGES"
#define HOROVOD_ENABLE_ASYNC_COMPLETION "HOROVOD_ENABLE_ASYNC_COMPLETION"

// String constant for gloo interface.
#define GLOO_DEFAULT_IFACE ""
Expand Down Expand Up @@ -135,6 +177,17 @@ inline std::string CommunicatorName(Communicator comm) {
}
}

// Associates a recorded GPU event with the stream it belongs to, so a Status
// can carry completion information for asynchronous GPU work. On CPU-only
// builds this is an empty placeholder type.
struct Event {
  Event() = default;
#if HAVE_GPU
  // Takes shared ownership of `event`; `stream` is the GPU stream the event
  // was (or will be) recorded on. The parameter is moved into the member to
  // avoid an extra atomic refcount increment on the shared_ptr.
  Event(std::shared_ptr<gpuEvent_t> event, gpuStream_t stream)
      : event(std::move(event)), stream(stream) {}
  std::shared_ptr<gpuEvent_t> event;
  gpuStream_t stream = nullptr;
#endif
};


class Status {
public:
Status();
Expand All @@ -148,6 +201,7 @@ class Status {
bool in_progress() const;
StatusType type() const;
const std::string& reason() const;
Event event;

private:
StatusType type_ = StatusType::OK;
Expand Down Expand Up @@ -198,6 +252,43 @@ class ReadyEvent {
public:
virtual bool Ready() const = 0;
virtual ~ReadyEvent() = default;
#if HAVE_GPU
virtual gpuEvent_t event() const = 0;
#endif

};

// Collects the framework-provided ReadyEvents for a tensor; data for the
// tensor is considered available only once every event in the list is ready.
class ReadyEventList {
public:
  // Returns true when every non-null event has signaled (an empty list is
  // trivially ready). Null entries are tolerated and skipped.
  bool Ready() const {
    for (auto& e : ready_events_) {
      if (e != nullptr && !e->Ready()) {
        return false;
      }
    }
    return true;
  }

  // Appends an event to the list; a null pointer is accepted (and ignored by
  // Ready()/PushEventsToSet).
  void AddReadyEvent(const std::shared_ptr<ReadyEvent>& e) {
    ready_events_.emplace_back(e);
  }

  // Number of registered events (including null entries).
  int size() const {
    return static_cast<int>(ready_events_.size());
  }

#if HAVE_GPU
  // Inserts the underlying GPU events into event_set. Skips null entries for
  // consistency with Ready(), which also tolerates them.
  void PushEventsToSet(std::unordered_set<gpuEvent_t>& event_set) {
    for (auto& e : ready_events_) {
      if (e != nullptr) {
        event_set.insert(e->event());
      }
    }
  }
#endif

  ~ReadyEventList() = default;

private:
  std::vector<std::shared_ptr<ReadyEvent>> ready_events_;
};

class OpContext;
Expand Down Expand Up @@ -257,8 +348,8 @@ struct TensorTableEntry {
std::shared_ptr<Tensor> output;
// Root rank for broadcast operation.
int root_rank = 0;
// Event indicating that data is ready.
std::shared_ptr<ReadyEvent> ready_event;
// List of events indicating that data is ready.
ReadyEventList ready_event_list;
// GPU to do reduction on, or CPU_DEVICE_ID in case of CPU.
int device = CPU_DEVICE_ID;
// A callback to call with the status.
Expand Down
6 changes: 6 additions & 0 deletions horovod/common/global_state.h
Expand Up @@ -57,6 +57,9 @@ struct HorovodGlobalState {
// Flag indicating whether timeline enabled.
bool timeline_enabled = false;

// Flag indicating whether running elastic.
bool elastic_enabled = false;

// Flag indicating whether to mark cycles in the timeline.
std::atomic_bool mark_cycles_in_timeline{false};

Expand Down Expand Up @@ -120,6 +123,9 @@ struct HorovodGlobalState {
// Flag indicating whether to prohibit groups from fusing
bool disable_group_fusion = false;

// Flag indicating whether to enable async completion
bool enable_async_completion = false;

~HorovodGlobalState() {
// Make sure that the destructor of the background thread is safe to
// call. If a thread is still joinable (not detached or complete) its
Expand Down
1 change: 0 additions & 1 deletion horovod/common/gloo/gloo_context.h
Expand Up @@ -40,7 +40,6 @@
#define HOROVOD_LOCAL_SIZE "HOROVOD_LOCAL_SIZE"
#define HOROVOD_CROSS_RANK "HOROVOD_CROSS_RANK"
#define HOROVOD_CROSS_SIZE "HOROVOD_CROSS_SIZE"
#define HOROVOD_ELASTIC "HOROVOD_ELASTIC"

namespace horovod {
namespace common {
Expand Down
12 changes: 12 additions & 0 deletions horovod/common/hashes.h
Expand Up @@ -73,6 +73,18 @@ template <typename U, typename V, typename W> struct hash<std::tuple<U, V, W>> {
}
};

// Hash specialization for std::pair: folds the hashes of both members into a
// single seed via hash_one.
template <typename U, typename V> struct hash<std::pair<U, V>> {
  // Must be std::pair (not std::tuple): declaring the argument as
  // std::tuple<U, V> made operator() take a converting temporary, copying
  // both elements on every hash call.
  using argument_type = std::pair<U, V>;
  using result_type = std::size_t;

  result_type operator()(argument_type const& in) const {
    result_type seed = 0;
    seed = hash_one<U>(in.first, seed);
    seed = hash_one<V>(in.second, seed);
    return seed;
  }
};

template <> struct hash<horovod::common::Framework> {
std::size_t operator()(horovod::common::Framework const& in) const {
return (std::size_t)in;
Expand Down
59 changes: 19 additions & 40 deletions horovod/common/operations.cc
Expand Up @@ -284,32 +284,6 @@ void PerformOperation(Response response, HorovodGlobalState& state) {
return;
}
}

// On GPU data readiness is signalled by ready_event.
std::vector<TensorTableEntry> waiting_tensors;
for (auto& e : entries) {
if (e.ready_event != nullptr) {
timeline.ActivityStart(e.tensor_name, WAIT_FOR_DATA);
waiting_tensors.push_back(e);
}
}
while (!waiting_tensors.empty()) {
for (auto it = waiting_tensors.begin(); it != waiting_tensors.end();) {
if (it->ready_event->Ready()) {
timeline.ActivityEnd(it->tensor_name);
timeline.ActivityStart(it->tensor_name, WAIT_FOR_OTHER_TENSOR_DATA);
it = waiting_tensors.erase(it);
} else {
++it;
}
}
std::this_thread::sleep_for(std::chrono::nanoseconds(100));
}
for (auto& e : entries) {
if (e.ready_event != nullptr) {
timeline.ActivityEnd(e.tensor_name);
}
}
}

Status status;
Expand Down Expand Up @@ -437,6 +411,8 @@ void BackgroundThreadLoop(HorovodGlobalState& state) {
}
state.controller->SetTimelineEnabled(should_enable_timeline);

SetBoolFromEnv(HOROVOD_ELASTIC, state.elastic_enabled, true);

ParseStallInspectorFromEnv(state.controller->GetStallInspector());
bool mark_cycles = false;
SetBoolFromEnv(HOROVOD_TIMELINE_MARK_CYCLES, mark_cycles,
Expand Down Expand Up @@ -521,6 +497,9 @@ void BackgroundThreadLoop(HorovodGlobalState& state) {
// Check if group fusion should be disabled
SetBoolFromEnv(HOROVOD_DISABLE_GROUP_FUSION, state.disable_group_fusion, true);

// Check if async completion should be enabled
SetBoolFromEnv(HOROVOD_ENABLE_ASYNC_COMPLETION, state.enable_async_completion, true);

// Enable auto-tuning.
auto horovod_autotune = std::getenv(HOROVOD_AUTOTUNE);
if (horovod_autotune != nullptr &&
Expand Down Expand Up @@ -919,7 +898,7 @@ int horovod_reduce_op_adasum() {
Status EnqueueTensorAllreduce(std::shared_ptr<OpContext> context,
std::shared_ptr<Tensor> tensor,
std::shared_ptr<Tensor> output,
std::shared_ptr<ReadyEvent> ready_event,
ReadyEventList ready_event_list,
std::string name, const int device,
StatusCallback callback,
ReduceOp reduce_op,
Expand All @@ -929,26 +908,26 @@ Status EnqueueTensorAllreduce(std::shared_ptr<OpContext> context,
std::vector<std::shared_ptr<OpContext>> contexts;
std::vector<std::shared_ptr<Tensor>> tensors;
std::vector<std::shared_ptr<Tensor>> outputs;
std::vector<std::shared_ptr<ReadyEvent>> ready_events;
std::vector<ReadyEventList> ready_event_lists;
std::vector<std::string> names;
std::vector<StatusCallback> callbacks;

contexts.emplace_back(std::move(context));
tensors.emplace_back(std::move(tensor));
outputs.emplace_back(std::move(output));
ready_events.emplace_back(std::move(ready_event));
ready_event_lists.emplace_back(std::move(ready_event_list));
names.emplace_back(std::move(name));
callbacks.emplace_back(std::move(callback));

return EnqueueTensorAllreduces(contexts, tensors, outputs, ready_events,
return EnqueueTensorAllreduces(contexts, tensors, outputs, ready_event_lists,
names, device, callbacks, reduce_op,
prescale_factor, postscale_factor);
}

Status EnqueueTensorAllreduces(std::vector<std::shared_ptr<OpContext>>& contexts,
std::vector<std::shared_ptr<Tensor>>& tensors,
std::vector<std::shared_ptr<Tensor>>& outputs,
std::vector<std::shared_ptr<ReadyEvent>>& ready_events,
std::vector<ReadyEventList>& ready_event_lists,
std::vector<std::string>& names,
const int device,
std::vector<StatusCallback>& callbacks,
Expand Down Expand Up @@ -1008,7 +987,7 @@ Status EnqueueTensorAllreduces(std::vector<std::shared_ptr<OpContext>>& contexts
e.tensor = tensors[n];
e.output = outputs[n];
}
e.ready_event = std::move(ready_events[n]);
e.ready_event_list = std::move(ready_event_lists[n]);
e.device = device;
e.callback = std::move(callbacks[n]);

Expand Down Expand Up @@ -1059,7 +1038,7 @@ Status EnqueueTensorAllreduces(std::vector<std::shared_ptr<OpContext>>& contexts
// must be running before this function is called.
Status EnqueueTensorAllgather(std::shared_ptr<OpContext> context,
std::shared_ptr<Tensor> tensor,
std::shared_ptr<ReadyEvent> ready_event,
ReadyEventList ready_event_list,
const std::string& name, const int device,
StatusCallback callback) {
Request message;
Expand All @@ -1076,7 +1055,7 @@ Status EnqueueTensorAllgather(std::shared_ptr<OpContext> context,
e.tensor_name = name;
e.context = context;
e.tensor = tensor;
e.ready_event = ready_event;
e.ready_event_list = ready_event_list;
e.device = device;
e.callback = callback;
e.nvtx_op_range.Start(RegisteredNvtxOp::HorovodAllgather, e.tensor->size());
Expand All @@ -1096,7 +1075,7 @@ Status EnqueueTensorAllgather(std::shared_ptr<OpContext> context,
Status EnqueueTensorBroadcast(std::shared_ptr<OpContext> context,
std::shared_ptr<Tensor> tensor,
std::shared_ptr<Tensor> output, int root_rank,
std::shared_ptr<ReadyEvent> ready_event,
ReadyEventList ready_event_list,
const std::string& name, const int device,
StatusCallback callback) {
Request message;
Expand All @@ -1116,7 +1095,7 @@ Status EnqueueTensorBroadcast(std::shared_ptr<OpContext> context,
e.tensor = tensor;
e.output = output;
e.root_rank = root_rank;
e.ready_event = ready_event;
e.ready_event_list = ready_event_list;
e.device = device;
e.callback = callback;
e.nvtx_op_range.Start(RegisteredNvtxOp::HorovodBroadcast, e.tensor->size());
Expand All @@ -1136,7 +1115,7 @@ Status EnqueueTensorBroadcast(std::shared_ptr<OpContext> context,
Status EnqueueTensorAlltoall(std::shared_ptr<OpContext> context,
std::shared_ptr<Tensor> tensor,
std::shared_ptr<Tensor> splits,
std::shared_ptr<ReadyEvent> ready_event,
ReadyEventList ready_event_list,
const std::string& name, const int device,
StatusCallback callback) {
// Check arguments
Expand All @@ -1161,7 +1140,7 @@ Status EnqueueTensorAlltoall(std::shared_ptr<OpContext> context,
e.tensor_name = name;
e.context = context;
e.tensor = tensor;
e.ready_event = ready_event;
e.ready_event_list = ready_event_list;
e.device = device;
e.callback = callback;
e.nvtx_op_range.Start(RegisteredNvtxOp::HorovodAlltoall, e.tensor->size());
Expand Down Expand Up @@ -1200,7 +1179,7 @@ Status EnqueueTensorAlltoall(std::shared_ptr<OpContext> context,
// Contexts and controller must be initialized and the background thread
// must be running before this function is called.
Status EnqueueJoin(std::shared_ptr<OpContext> context,
std::shared_ptr<ReadyEvent> ready_event,
ReadyEventList ready_event_list,
const std::string& name, const int device,
StatusCallback callback) {
Request message;
Expand All @@ -1211,7 +1190,7 @@ Status EnqueueJoin(std::shared_ptr<OpContext> context,
TensorTableEntry e;
e.tensor_name = name;
e.context = context;
e.ready_event = ready_event;
e.ready_event_list = ready_event_list;
e.device = device;
e.callback = callback;

Expand Down