Skip to content

Commit

Permalink
Include source files from TF in build to avoid missing symbol errors
Browse files Browse the repository at this point in the history
Signed-off-by: Max H. Gerlach <git@maxgerlach.de>
  • Loading branch information
maxhgerlach committed Aug 26, 2021
1 parent 72a65e1 commit 950da7b
Show file tree
Hide file tree
Showing 4 changed files with 255 additions and 0 deletions.
19 changes: 19 additions & 0 deletions horovod/tensorflow/CMakeLists.txt
Expand Up @@ -61,6 +61,25 @@ set(Tensorflow_CXX11 ${Tensorflow_CXX11} PARENT_SCOPE)
# Sources compiled into the TensorFlow plugin library: Horovod's own op
# implementations first, then the source files taken over from TensorFlow.
list(APPEND TF_SOURCES
     "${PROJECT_SOURCE_DIR}/horovod/tensorflow/mpi_ops.cc"
     "${PROJECT_SOURCE_DIR}/horovod/tensorflow/xla_mpi_ops.cc"
     "${PROJECT_SOURCE_DIR}/horovod/tensorflow/dense_update_functor.cc"
     "${PROJECT_SOURCE_DIR}/horovod/tensorflow/training_op_helpers.cc")

# Build the CUDA kernels as a separate static library when CUDA is enabled and
# add that library to TF_LINKER_LIBS.
if(HAVE_CUDA OR HAVE_SUB_PROJECT_CUDA)
    # The TensorFlow source file dense_update_functor_gpu.cu.cc is kept here
    # under the name dense_update_functor_gpu.cc.cu because it was not compiled
    # via nvcc otherwise.
    list(APPEND TF_CUDA_SOURCES "${PROJECT_SOURCE_DIR}/horovod/tensorflow/dense_update_functor_gpu.cc.cu")

    set(CUDA_HOST_COMPILER ${CMAKE_CXX_COMPILER})

    # Put ${PROJECT_SOURCE_DIR}/cmake on PYTHONPATH so that the build_utils
    # module can be imported, then query it for the nvcc flags.
    set(ENV{PYTHONPATH} "${PROJECT_SOURCE_DIR}/cmake:$ENV{PYTHONPATH}")
    execute_process(COMMAND ${PY_EXE} -c "import build_utils; print(' '.join(build_utils.get_nvcc_flags()))"
                    RESULT_VARIABLE HVD_NVCC_FLAGS_RESULT
                    OUTPUT_VARIABLE HVD_NVCC_COMPILE_FLAGS OUTPUT_STRIP_TRAILING_WHITESPACE)
    # A failing helper leaves HVD_NVCC_COMPILE_FLAGS empty; report that instead
    # of silently compiling the kernels without the intended flags. This stays
    # a warning so that configurations which worked before keep working.
    if(NOT "${HVD_NVCC_FLAGS_RESULT}" STREQUAL "0")
        message(WARNING "Could not query nvcc flags via build_utils.get_nvcc_flags() "
                        "(result: ${HVD_NVCC_FLAGS_RESULT}); continuing without them.")
    endif()

    list(APPEND CUDA_NVCC_FLAGS "${HVD_NVCC_COMPILE_FLAGS}")

    cuda_add_library(tensorflow_cuda_kernels ${TF_CUDA_SOURCES} STATIC OPTIONS -DGOOGLE_CUDA=1)
    list(APPEND TF_LINKER_LIBS tensorflow_cuda_kernels)
endif()

# Create library
# The shared library ${TF_TARGET_LIB} is built from the sources in SOURCES
# together with the TensorFlow-specific sources collected in TF_SOURCES.
# NOTE(review): set_output_dir() is defined elsewhere in the project; presumably
# it configures where the library is written - confirm.
set_output_dir()
add_library(${TF_TARGET_LIB} SHARED ${SOURCES} ${TF_SOURCES})
Expand Down
128 changes: 128 additions & 0 deletions horovod/tensorflow/dense_update_functor.cc
@@ -0,0 +1,128 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/dense_update_functor.h"

#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

// Short aliases for the Eigen device types. They are used as the device
// argument of the DenseUpdate and VariantCopyFn specializations in this file.
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

namespace functor {

// CPU ASSIGN functor for string tensors: copies every string in `update` into
// the corresponding slot of `params`, splitting the work with
// CPUDevice::parallelFor.
//
//   d       device whose parallelFor() schedules the copy work
//   params  flat destination tensor; each destination string is resized to the
//           size of its source string before the bytes are copied
//   update  flat source tensor
//
// NOTE(review): the specialization is keyed on `string` while both operands
// are TTypes<tstring>; confirm against the DenseUpdate primary template that
// this is the intended key.
template <>
struct DenseUpdate<CPUDevice, string, ASSIGN> {
  void operator()(const CPUDevice& d, typename TTypes<tstring>::Flat params,
                  typename TTypes<tstring>::ConstFlat update) {
    if (params.dimension(0) == 1) {
      // Single string: size the destination once, then let the workers copy
      // disjoint byte ranges [start, end) of that one string.
      params.data()->resize(update.data()->size());
      auto work = [&params, &update](int64_t start, int64_t end) {
        memmove(const_cast<char*>(params.data()->data()) + start,
                update.data()->data() + start, end - start);
      };
      d.parallelFor(update.data()->size(),
                    Eigen::TensorOpCost(.1,  // chosen to force large chunks
                                        .1, 0),
                    work);
    } else {
      // Several strings: each worker resizes and copies the whole strings with
      // indices in [start, end).
      auto work = [&params, &update](int64_t start, int64_t end) {
        // The index has the same 64-bit type as the range bounds; a plain
        // `int` would narrow them for ranges beyond INT_MAX.
        for (int64_t i = start; i < end; ++i) {
          params.data()[i].resize(update.data()[i].size());
          memmove(const_cast<char*>(params.data()[i].data()),
                  update.data()[i].data(), update.data()[i].size());
        }
      };
      int64_t estimated_string_size;
      if (update.size() > 0) {
        // first element of the tensor seems as good a guess as any of the sizes
        // of the strings contained within...
        estimated_string_size =
            std::max(update.data()[0].size(), sizeof(tstring));
      } else {
        estimated_string_size = sizeof(tstring);
      }
      // The estimate is passed as the first two TensorOpCost arguments, i.e.
      // as the per-element cost handed to parallelFor.
      d.parallelFor(
          params.dimension(0),
          Eigen::TensorOpCost(estimated_string_size, estimated_string_size, 0),
          work);
    }
  }
};

} // namespace functor

// Expands to one `case` label for the dtype enum value of T. The case runs
// DenseUpdate<CPUDevice, T, ASSIGN> to copy `from` into `tensor`.
// The names `context`, `tensor` and `from` are not parameters of the macro;
// they must be in scope at the expansion site (the switch statement inside
// INSTANTIATE_GET_VARIANT_COPY_FN).
#define CPU_DENSE_COPY(T) \
case DataTypeToEnum<T>::value: { \
functor::DenseUpdate<CPUDevice, T, ASSIGN> copy_functor_; \
copy_functor_(context->eigen_device<CPUDevice>(), tensor.flat<T>(), \
from.flat<T>()); \
break; \
}

// Defines the VariantCopyFn<DEVICE> specialization.
//
//   DEVICE           device type the specialization is generated for
//   TYPE_CALLER      macro that applies its argument to each supported type
//   TYPE_DENSE_COPY  macro that emits the `case` for a single type
//
// The generated function allocates a temporary tensor with the dtype and shape
// of `from` (allocator attributes gpu_compatible and nic_compatible are set),
// switches on from.dtype() over the cases produced by
// TYPE_CALLER(TYPE_DENSE_COPY), and finally assigns the temporary to *to.
// It returns an InvalidArgument error for any dtype without a generated case,
// and propagates a failure of allocate_temp via TF_RETURN_IF_ERROR.
#define INSTANTIATE_GET_VARIANT_COPY_FN(DEVICE, TYPE_CALLER, TYPE_DENSE_COPY) \
template <> \
Status VariantCopyFn<DEVICE>(OpKernelContext * context, const Tensor& from, \
Tensor* to) { \
Tensor tensor; \
AllocatorAttributes attr; \
attr.set_gpu_compatible(true); \
attr.set_nic_compatible(true); \
TF_RETURN_IF_ERROR( \
context->allocate_temp(from.dtype(), from.shape(), &tensor, attr)); \
switch (from.dtype()) { \
TYPE_CALLER(TYPE_DENSE_COPY); \
default: \
return errors::InvalidArgument( \
"VariantCopyFn: Could not perform a deep copy of variant " \
"element of type: ", \
DataTypeString(from.dtype()), \
" using device: ", context->device()->name()); \
} \
*to = tensor; \
return Status::OK(); \
}

// CPU version: one case per type enumerated by TF_CALL_ALL_TYPES.
INSTANTIATE_GET_VARIANT_COPY_FN(CPUDevice, TF_CALL_ALL_TYPES, CPU_DENSE_COPY);

// GPU version of the variant copy function; only compiled when GOOGLE_CUDA or
// TENSORFLOW_USE_ROCM is set.
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// Same shape as CPU_DENSE_COPY, but runs DenseUpdate<GPUDevice, T, ASSIGN> on
// the GPU Eigen device of `context`.
#define GPU_DENSE_COPY(T) \
case DataTypeToEnum<T>::value: { \
functor::DenseUpdate<GPUDevice, T, ASSIGN> copy_functor_; \
copy_functor_(context->eigen_device<GPUDevice>(), tensor.flat<T>(), \
from.flat<T>()); \
break; \
}
// Type list for the GPU switch: everything from TF_CALL_GPU_ALL_TYPES plus
// int32 and int64.
#define TF_CALL_GPU_AND_ADDITIONAL_TYPES(T) \
TF_CALL_GPU_ALL_TYPES(T); \
TF_CALL_int32(T); \
TF_CALL_int64(T);
INSTANTIATE_GET_VARIANT_COPY_FN(GPUDevice, TF_CALL_GPU_AND_ADDITIONAL_TYPES,
GPU_DENSE_COPY);
#undef TF_CALL_GPU_AND_ADDITIONAL_TYPES
#undef GPU_DENSE_COPY
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// Remove the remaining helper macros so they are not visible past this point.
#undef CPU_DENSE_COPY
#undef INSTANTIATE_GET_VARIANT_COPY_FN

} // namespace tensorflow
78 changes: 78 additions & 0 deletions horovod/tensorflow/dense_update_functor_gpu.cc.cu
@@ -0,0 +1,78 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define EIGEN_USE_GPU

#include "tensorflow/core/kernels/dense_update_functor.h"

#include "tensorflow/core/framework/register_types.h"

namespace tensorflow {

// Short alias for the Eigen GPU device type; it is the device argument of all
// DenseUpdate specializations and instantiations in this file.
typedef Eigen::GpuDevice GPUDevice;

namespace functor {

// GPU partial specialization of DenseUpdate for ASSIGN.
// operator() assigns `update` to `params`; the assignment is issued through
// params.device(d), i.e. it is evaluated on the device `d`.
template <typename T>
struct DenseUpdate<GPUDevice, T, ASSIGN> {
void operator()(const GPUDevice& d, typename TTypes<T>::Flat params,
typename TTypes<T>::ConstFlat update) {
params.device(d) = update;
}
};

// GPU partial specialization of DenseUpdate for ADD.
// operator() adds `update` to `params` in place (`+=`); the update is issued
// through params.device(d), i.e. it is evaluated on the device `d`.
template <typename T>
struct DenseUpdate<GPUDevice, T, ADD> {
void operator()(const GPUDevice& d, typename TTypes<T>::Flat params,
typename TTypes<T>::ConstFlat update) {
params.device(d) += update;
}
};

// GPU partial specialization of DenseUpdate for SUB.
// operator() subtracts `update` from `params` in place (`-=`); the update is
// issued through params.device(d), i.e. it is evaluated on the device `d`.
template <typename T>
struct DenseUpdate<GPUDevice, T, SUB> {
void operator()(const GPUDevice& d, typename TTypes<T>::Flat params,
typename TTypes<T>::ConstFlat update) {
params.device(d) -= update;
}
};

} // namespace functor

// Explicitly instantiate the ADD and SUB functors for the types enumerated by
// TF_CALL_GPU_NUMBER_TYPES as well as int32, int64, int8 and uint8.
// DEFINE_GPU_KERNELS is a temporary helper and is undefined again right after.
#define DEFINE_GPU_KERNELS(T) \
template struct functor::DenseUpdate<GPUDevice, T, ADD>; \
template struct functor::DenseUpdate<GPUDevice, T, SUB>;
TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS);
TF_CALL_int32(DEFINE_GPU_KERNELS);
TF_CALL_int64(DEFINE_GPU_KERNELS);
TF_CALL_int8(DEFINE_GPU_KERNELS);
TF_CALL_uint8(DEFINE_GPU_KERNELS);
#undef DEFINE_GPU_KERNELS

// Explicitly instantiate the ASSIGN functor. Its type list differs from the
// ADD/SUB one: it is driven by TF_CALL_GPU_ALL_TYPES and additionally covers
// uint32 besides int32, int64, int8 and uint8.
// DEFINE_GPU_KERNELS is a temporary helper and is undefined again right after.
#define DEFINE_GPU_KERNELS(T) \
template struct functor::DenseUpdate<GPUDevice, T, ASSIGN>;
TF_CALL_GPU_ALL_TYPES(DEFINE_GPU_KERNELS);
TF_CALL_int32(DEFINE_GPU_KERNELS);
TF_CALL_int64(DEFINE_GPU_KERNELS);
TF_CALL_int8(DEFINE_GPU_KERNELS);
TF_CALL_uint8(DEFINE_GPU_KERNELS);
TF_CALL_uint32(DEFINE_GPU_KERNELS);
#undef DEFINE_GPU_KERNELS

} // end namespace tensorflow

#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
30 changes: 30 additions & 0 deletions horovod/tensorflow/training_op_helpers.cc
@@ -0,0 +1,30 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/training_op_helpers.h"

#include "tensorflow/core/util/ptr_util.h"

namespace tensorflow {


// Forwards the ref input at index `input` to the ref output at index `output`,
// unless that input has dtype DT_RESOURCE, in which case nothing is done.
//
//   ctx     kernel context that owns the inputs and outputs
//   input   index of the input to inspect and forward
//   output  index of the output that receives the forwarded ref
void MaybeForwardRefInputToRefOutput(OpKernelContext* ctx, int input,
                                     int output) {
  // Inputs of dtype DT_RESOURCE are not forwarded.
  const bool is_resource_input = ctx->input_dtype(input) == DT_RESOURCE;
  if (is_resource_input) {
    return;
  }
  ctx->forward_ref_input_to_ref_output(input, output);
}

} // end namespace tensorflow

0 comments on commit 950da7b

Please sign in to comment.