Add in-place broadcast for TensorFlow (horovod#3128)
* Update comment in FindTensorflow.cmake

Signed-off-by: Max H. Gerlach <git@maxgerlach.de>

* Add in-place broadcast_() and broadcast_variables() for TF

Signed-off-by: Max H. Gerlach <git@maxgerlach.de>

* Include source files from TF in build to avoid missing symbol errors

Signed-off-by: Max H. Gerlach <git@maxgerlach.de>

* Limit build and test to TF 2.6+

Signed-off-by: Max H. Gerlach <git@maxgerlach.de>

* Remove source files copied from TensorFlow

The missing symbols are resolved by linking against _pywrap_tensorflow_internal.so,
which was introduced to Horovod with PR horovod#3053.

Signed-off-by: Max H. Gerlach <git@maxgerlach.de>

* Fix possible type attribute values for HorovodBroadcastInplace

Signed-off-by: Max H. Gerlach <git@maxgerlach.de>

* Add reference variables to test

Signed-off-by: Max H. Gerlach <git@maxgerlach.de>

* Update comments, doc strings, changelog

Signed-off-by: Max H. Gerlach <git@maxgerlach.de>
Signed-off-by: weihanmines <weihan13@amd.com>
maxhgerlach authored and weihanmines committed Dec 10, 2021
1 parent 2a340b9 commit 4d396a4
Showing 7 changed files with 505 additions and 8 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -8,6 +8,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).

### Added

- TensorFlow: Added in-place broadcasting of variables. ([#3128](https://github.com/horovod/horovod/pull/3128))

### Changed

### Deprecated
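A minimal usage sketch of the new option (hedged; it assumes a Horovod build against TensorFlow 2.6+ and is not part of this commit):

import tensorflow as tf
import horovod.tensorflow as hvd

hvd.init()

# Any collection of tf.Variable objects works; a Keras model's
# `model.variables` list is the typical input.
w = tf.Variable(tf.random.normal([4, 8]), name="w")
b = tf.Variable(tf.zeros([8]), name="b")

# New in this commit: overwrite the variables with rank 0's values in place
# instead of materializing temporary broadcast outputs.
hvd.broadcast_variables([w, b], root_rank=0, inplace=True)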
2 changes: 1 addition & 1 deletion cmake/Modules/FindTensorflow.cmake
@@ -23,7 +23,7 @@ if (LEN EQUAL "4")
list(GET Tensorflow_LIBRARIES_LIST 0 Tensorflow_LIB_PATH_ARGUMENT)
string(REGEX REPLACE "^-L" "" Tensorflow_LIB_PATH ${Tensorflow_LIB_PATH_ARGUMENT})
if (Tensorflow_VERSION VERSION_GREATER "2.6" OR Tensorflow_VERSION VERSION_EQUAL "2.6")
# # XLA implementations are in _pywrap_tensorflow_internal.so
# XLA implementations and helpers for resource variables are in _pywrap_tensorflow_internal.so
set(Tensorflow_LIBRARIES "${Tensorflow_LIBRARIES} ${Tensorflow_LIB_PATH}/python/_pywrap_tensorflow_internal.so")
endif()
message("Tensorflow_LIBRARIES := ${Tensorflow_LIBRARIES}")
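For orientation, a hedged Python sketch (not part of the commit) of the path this CMake logic constructs: the -L entry reported by TensorFlow is the library directory, and for TF 2.6+ the additional library sits under its python/ subdirectory.

import os
import tensorflow as tf

# e.g. ['-L/.../site-packages/tensorflow', '-l:libtensorflow_framework.so.2']
flags = tf.sysconfig.get_link_flags()
lib_path = next(f[len("-L"):] for f in flags if f.startswith("-L"))

# FindTensorflow.cmake appends this library for TF >= 2.6 so that XLA and
# resource-variable helper symbols resolve at link time.
pywrap = os.path.join(lib_path, "python", "_pywrap_tensorflow_internal.so")
print(tf.__version__, os.path.exists(pywrap))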
2 changes: 1 addition & 1 deletion horovod/tensorflow/__init__.py
@@ -26,7 +26,7 @@
from horovod.tensorflow import elastic
from horovod.tensorflow.compression import Compression
from horovod.tensorflow.functions import allgather_object, broadcast_object, broadcast_object_fn, broadcast_variables
from horovod.tensorflow.mpi_ops import allgather, broadcast, _allreduce, _grouped_allreduce, alltoall
from horovod.tensorflow.mpi_ops import allgather, broadcast, broadcast_, _allreduce, _grouped_allreduce, alltoall
from horovod.tensorflow.mpi_ops import init, shutdown
from horovod.tensorflow.mpi_ops import is_initialized, start_timeline, stop_timeline
from horovod.tensorflow.mpi_ops import size, local_size, cross_size, rank, local_rank, cross_rank, is_homogeneous
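With the export above, the in-place primitives are reachable from the public package; a small sanity check (hedged, assumes TF 2.6+):

import horovod.tensorflow as hvd

# broadcast_ is re-exported from horovod.tensorflow.mpi_ops by the import
# line added above; broadcast_variables gains the new `inplace` argument.
assert callable(hvd.broadcast_)
assert callable(hvd.broadcast_variables)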
41 changes: 37 additions & 4 deletions horovod/tensorflow/functions.py
@@ -21,7 +21,7 @@

from tensorflow.python.framework import ops

from horovod.tensorflow.mpi_ops import allgather, broadcast
from horovod.tensorflow.mpi_ops import allgather, broadcast, broadcast_
from horovod.tensorflow.mpi_ops import rank, size
from horovod.tensorflow.util import _cache, _executing_eagerly, _make_subgraph
from horovod.common.process_sets import ProcessSet, global_process_set
@@ -45,20 +45,53 @@ def broadcast_group(variables, root_rank, process_set: ProcessSet):
return broadcast_group


def broadcast_variables(variables, root_rank, process_set=global_process_set):
@_cache
def _make_inplace_broadcast_group_fn():
if _executing_eagerly():
# These are just a few calls of broadcast_, no need to aggregate them in a tf.function
def broadcast_group(variable_lists, root_rank, process_set: ProcessSet):
for variables in variable_lists:
broadcast_(variables, root_rank, process_set=process_set)

return broadcast_group
else:
# Graph mode requires an Op
def broadcast_group(variable_lists, root_rank, process_set: ProcessSet):
return tf.group(*[broadcast_(variables, root_rank, process_set=process_set)
for variables in variable_lists])

return broadcast_group


def broadcast_variables(variables, root_rank, process_set=global_process_set, inplace=False):
"""
Broadcasts variables from root rank to all other processes
in a process set (defaults to all Horovod processes).
Optionally, the broadcast may be performed in-place, which avoids
temporary memory allocations and fragmentation. This is only
supported with TensorFlow 2.6 or later. Reference variables
(legacy support in TF 2) must all be of the same data type. There
is no such restriction for resource variables (default in TF 2).
Arguments:
variables: variables for broadcast
root_rank: rank of the process from which global variables will be broadcasted
to all other processes.
process_set: Process set object to limit this operation to a subset of
Horovod processes. Default is the global process set.
inplace: whether to perform in-place broadcasts
"""
broadcast_group = _make_broadcast_group_fn()
return broadcast_group(variables, root_rank, process_set)
if inplace:
vars_by_device = {}
for var in variables:
vars_by_device.setdefault(var.device, []).append(var)

inplace_broadcast_group = _make_inplace_broadcast_group_fn()
return inplace_broadcast_group(vars_by_device.values(), root_rank, process_set)
else:
broadcast_group = _make_broadcast_group_fn()
return broadcast_group(variables, root_rank, process_set)


def broadcast_object(obj, root_rank=0, session=None, name=None, process_set=global_process_set):
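A hedged sketch of calling the lower-level broadcast_ directly (not part of this diff). Per the op documentation added in mpi_ops.cc below, all variables passed to one in-place broadcast must live on the same device, which is why broadcast_variables(..., inplace=True) groups them by var.device first.

import tensorflow as tf
import horovod.tensorflow as hvd

hvd.init()

w = tf.Variable(tf.random.normal([4, 4]))
b = tf.Variable(tf.zeros([4]))

# Both variables live on the same device, so they can share one in-place
# broadcast; their buffers are overwritten with the values from rank 0.
hvd.broadcast_([w, b], 0)  # second argument is the root rank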
269 changes: 267 additions & 2 deletions horovod/tensorflow/mpi_ops.cc
@@ -18,16 +18,27 @@

#include <memory>
#include <queue>
#include <regex>
#include <thread>
#include <unordered_map>

#define EIGEN_USE_THREADS
#if HAVE_CUDA || HAVE_ROCM
#define EIGEN_USE_GPU
#endif // HAVE_CUDA || HAVE_ROCM

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/common_shape_fns.h"

#include "../common/common.h"
#if TENSORFLOW_VERSION >= 2006000000
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/resource_var.h"
#include "tensorflow/core/kernels/training_op_helpers.h"
#endif // TENSORFLOW_VERSION >= 2006000000

#define EIGEN_USE_THREADS
#include "../common/common.h"

#if HAVE_GPU

@@ -831,6 +842,260 @@ Output
`tensor` on root rank.
)doc");

#if TENSORFLOW_VERSION >= 2006000000
namespace {
std::string NormalizeNameForTensorFlow(const std::string& name) {
static const std::regex normalize_re(R"regex([^a-zA-Z0-9_])regex");
return std::regex_replace(name, normalize_re, "_");
}

Status GetInputDataTypeFromVariable(OpKernelContext* ctx, int input,
DataType& out) {
if (ctx->input_dtype(input) == DT_RESOURCE) {
core::RefCountPtr<Var> var;
TF_RETURN_IF_ERROR(LookupResource(ctx, HandleFromInput(ctx, input), &var));
out = var->tensor()->dtype();
} else {
out = BaseType(ctx->input_dtype(input));
}
return Status::OK();
}

}

template <typename Device>
class HorovodBroadcastInplaceOp : public OpKernel {
public:
explicit HorovodBroadcastInplaceOp(OpKernelConstruction* context)
: OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("root_rank", &root_rank_));
OP_REQUIRES_OK(context,
context->GetAttr("process_set_id", &process_set_id_));
OP_REQUIRES_OK(context, context->GetAttr("num_variables", &num_variables_));
OP_REQUIRES_OK(context, context->GetAttr("variable_names", &variable_names_));
OP_REQUIRES(context, (int) variable_names_.size() == num_variables_,
errors::InvalidArgument(
"len(variable_names) needs to be equal to num_variables"));
}

void Compute(OpKernelContext* context) override {
OP_REQUIRES_OK(context, ConvertStatus(common::CheckInitialized()));

auto any_failures_and_tensors_done =
std::make_shared<std::pair<std::atomic<bool>, std::atomic<int>>>();
any_failures_and_tensors_done->first.store(false);
any_failures_and_tensors_done->second.store(0);

std::vector<VariableInputLockHolder> variable_locks;
variable_locks.reserve(num_variables_);

for (int tensor_index = 0; tensor_index < num_variables_; ++tensor_index) {
DataType dtype;
OP_REQUIRES_OK(
context, GetInputDataTypeFromVariable(context, tensor_index, dtype));

// Functions in tensorflow/core/kernels/training_op_helpers.h that deal
// with resource variables need a template type parameter. This requires
// us to branch out to different specializations of a templated helper
// function.
switch (dtype) {
#define PROCESS_CASE(DT, T) \
case DT: \
OP_REQUIRES_OK(context, Process<T>(context, tensor_index, variable_locks, \
any_failures_and_tensors_done)); \
break;
PROCESS_CASE(DT_UINT8, uint8)
PROCESS_CASE(DT_INT8, int8)
PROCESS_CASE(DT_INT32, int32)
PROCESS_CASE(DT_INT64, int64)
PROCESS_CASE(DT_HALF, Eigen::half)
PROCESS_CASE(DT_FLOAT, float)
PROCESS_CASE(DT_DOUBLE, double)
PROCESS_CASE(DT_BOOL, bool)
// no support for int16 and uint16 because there are no DenseUpdate
// kernels for them
default:
context->CtxFailure(__FILE__, __LINE__, errors::InvalidArgument(
"Horovod inplace broadcast does not support data type ",
DataTypeString(dtype)));
return;
}
#undef PROCESS_CASE
}

while (!any_failures_and_tensors_done->first.load() &&
any_failures_and_tensors_done->second.load() < num_variables_) {
std::this_thread::yield();
}
}

private:
int root_rank_ = 0;
int process_set_id_ = 0;
int num_variables_ = 0;
std::vector<std::string> variable_names_;

template <typename T>
Status
Process(OpKernelContext* context, int tensor_index,
std::vector<VariableInputLockHolder>& variable_locks,
const std::shared_ptr<std::pair<std::atomic<bool>, std::atomic<int>>>&
any_failures_and_tensors_done) {
const bool do_lock = true;
const bool sparse = false;
// Here we need to replicate the functionality provided by
// MaybeLockVariableInputMutexesInOrder(). That function currently does
// not work as intended for input_ids not starting at 0. See:
// https://github.com/tensorflow/tensorflow/issues/51686
{
Var* var;
mutex* mu = GetTrainingVariableMutex<Device, T>(context, tensor_index,
sparse, &var);
std::vector<Var*> vars;
if (var) {
vars.reserve(1);
vars.push_back(var);
}
std::vector<mutex*> mutexes{mu};
auto locks = absl::make_unique<std::vector<mutex_lock>>();
locks->reserve(1);
locks->emplace_back(*mu);
auto shared_locks = absl::make_unique<std::vector<tf_shared_lock>>();
variable_locks.emplace_back(std::move(vars), std::move(locks),
std::move(shared_locks));
}

Tensor tensor;
TF_RETURN_IF_ERROR(GetInputTensorFromVariable<Device, T>(
context, tensor_index, do_lock, sparse, &tensor));
Tensor* output = &tensor;
MaybeForwardRefInputToRefOutput(context, tensor_index, tensor_index);

std::string var_name = variable_names_[tensor_index];
if (context->input_dtype(tensor_index) == DT_RESOURCE && var_name.empty()) {
const ResourceHandle& handle = HandleFromInput(context, tensor_index);
// We use handle.name() as a fallback only when we do not have a proper
// name because typically it seems to be something like _AnonymousVar18.
// The Python name attribute of the variable does not appear to be passed
// through automatically.
var_name = handle.name();
}

auto device = GetDeviceID(context);
// ReadyEvent makes sure input tensor is ready, and output is allocated.
common::ReadyEventList ready_event_list;
#if HAVE_GPU
ready_event_list.AddReadyEvent(
std::shared_ptr<common::ReadyEvent>(RecordReadyEvent(context)));
#endif
auto hvd_context = std::make_shared<TFOpContext>(context);
auto hvd_tensor = std::make_shared<TFTensor>(tensor);
auto hvd_output = std::make_shared<TFTensor>(*output);
const std::string node_name =
name() + "_" + NormalizeNameForTensorFlow(var_name);
auto enqueue_result = EnqueueTensorBroadcast(
hvd_context, hvd_tensor, hvd_output, root_rank_, ready_event_list,
node_name, device,
[context, any_failures_and_tensors_done](const common::Status& status) {
#if HAVE_GPU
auto hvd_event = status.event;
if (hvd_event.event) {
auto device_context = context->op_device_context();
if (device_context != nullptr) {
auto stream = stream_executor::gpu::AsGpuStreamValue(
device_context->stream());
HVD_GPU_CHECK(gpuStreamWaitEvent(stream, *(hvd_event.event), 0));
}
}
#endif
if (!status.ok()) {
auto prev_failures = any_failures_and_tensors_done->first.load();
if (!prev_failures) {
// Only keeping failure status of the first broadcast that fails
context->SetStatus(ConvertStatus(status));
any_failures_and_tensors_done->first.store(true);
}
}
any_failures_and_tensors_done->second.fetch_add(1);
},
process_set_id_);
return ConvertStatus(enqueue_result);
}
};

REGISTER_KERNEL_BUILDER(Name("HorovodBroadcastInplace").Device(DEVICE_CPU),
HorovodBroadcastInplaceOp<CPUDevice>);
#if HOROVOD_GPU_BROADCAST
REGISTER_KERNEL_BUILDER(Name("HorovodBroadcastInplace").Device(DEVICE_GPU),
HorovodBroadcastInplaceOp<GPUDevice>);
#endif

REGISTER_OP("HorovodBroadcastInplace")
.Attr(
"T: {uint8, int8, int32, int64, float16, float32, float64, bool}")
.Attr("root_rank: int")
.Attr("process_set_id: int = 0")
.Attr("num_variables: int")
.Attr("variable_names: list(string)")
.Input("tensor_refs: Ref(num_variables * T)")
.Output("output_refs: Ref(num_variables * T)")
.SetShapeFn(shape_inference::UnchangedShape)
.Doc(R"doc(
Perform an in-place Broadcast on (TF1-style) reference variables. All other
processes that do a broadcast on variables with the same names must have the
same dimensions for those variables. All variables must be located on the same
device and they must be of the same data type.
This requires TensorFlow 2.6+.
Arguments
root_rank: Rank that will send data, other ranks will receive data.
variable_names: Names associated to the variables (obtained via Python
framework)
Input
tensor_refs: Variables to broadcast. They will be updated in-place
to the values from the root rank.
Output
output_refs: The updated variables.
)doc");

REGISTER_KERNEL_BUILDER(
Name("HorovodBroadcastInplaceResource").Device(DEVICE_CPU),
HorovodBroadcastInplaceOp<CPUDevice>);
#if HOROVOD_GPU_BROADCAST
REGISTER_KERNEL_BUILDER(Name("HorovodBroadcastInplaceResource")
.Device(DEVICE_GPU)
.HostMemory("resources"),
HorovodBroadcastInplaceOp<GPUDevice>);
#endif

REGISTER_OP("HorovodBroadcastInplaceResource")
.Attr("root_rank: int")
.Attr("process_set_id: int = 0")
.Attr("num_variables: int")
.Attr("variable_names: list(string)")
.Input("resources: num_variables * resource")
.SetShapeFn(shape_inference::NoOutputs)
.Doc(R"doc(
Perform an in-place Broadcast on (TF2-style) resource variables. All other
processes that do a broadcast on variables with the same names must have the
same dimensions for those variables. All variables must be located on the same
device.
This requires TensorFlow 2.6+.
Arguments
root_rank: Rank that will send data, other ranks will receive data.
variable_names: Names associated to the variables (obtained via Python
framework)
Input
resources: Variables to broadcast. They will be updated in-place
to the values from the root rank.
)doc");
#endif // TENSORFLOW_VERSION >= 2006000000

class HorovodJoinOp : public AsyncOpKernel {
public:
explicit HorovodJoinOp(OpKernelConstruction* context)
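Illustration only: a Python equivalent (not part of the commit) of the NormalizeNameForTensorFlow() helper above. Every character outside [a-zA-Z0-9_] in a variable name is replaced by an underscore before the name is appended to the node name passed to EnqueueTensorBroadcast.

import re

def normalize_name_for_tensorflow(name: str) -> str:
    # Mirrors the std::regex_replace call in mpi_ops.cc.
    return re.sub(r"[^a-zA-Z0-9_]", "_", name)

print(normalize_name_for_tensorflow("dense/kernel:0"))  # -> dense_kernel_0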
