From 80339c342767eb09ca92a86031205800ca334878 Mon Sep 17 00:00:00 2001
From: Rong Ou
Date: Mon, 30 May 2022 13:09:45 -0700
Subject: [PATCH] Enable distributed GPU training over Rabit (#7930)

---
 doc/build.rst                              |   4 +-
 src/common/common.h                        |  18 +
 src/common/device_helpers.cu               |  63 ++-
 src/common/device_helpers.cuh              | 454 +++++++++++++++++----
 tests/cpp/common/test_quantile.cu          |  10 -
 tests/cpp/common/test_transform_range.cu   |   4 +-
 tests/cpp/metric/test_multiclass_metric.cc |   4 +-
 tests/distributed/runtests-federated.sh    |   8 +-
 tests/distributed/test_federated.py        |  10 +-
 9 files changed, 452 insertions(+), 123 deletions(-)

diff --git a/doc/build.rst b/doc/build.rst
index 195b6b1f0922..b27d55930212 100644
--- a/doc/build.rst
+++ b/doc/build.rst
@@ -136,9 +136,9 @@ From the command line on Linux starting from the XGBoost directory:
 To speed up compilation, the compute version specific to your GPU could be passed to cmake as, e.g., ``-DGPU_COMPUTE_VER=50``. A quick explanation and numbers for some architectures can be found `in this page `_.
 
-.. note:: Enabling distributed GPU training
+.. note:: Faster distributed GPU training with NCCL
 
-  By default, distributed GPU training is disabled and only a single GPU will be used. To enable distributed GPU training, set the option ``USE_NCCL=ON``. Distributed GPU training depends on NCCL2, available at `this link <https://developer.nvidia.com/nccl>`_. Since NCCL2 is only available for Linux machines, **distributed GPU training is available only for Linux**.
+  By default, distributed GPU training is enabled and uses Rabit for communication. For faster training, set the option ``USE_NCCL=ON``. Faster distributed GPU training depends on NCCL2, available at `this link <https://developer.nvidia.com/nccl>`_. Since NCCL2 is only available for Linux machines, **faster distributed GPU training is available only for Linux**.
 
 .. code-block:: bash
 
diff --git a/src/common/common.h b/src/common/common.h
index 1f36f3d8d687..0f21739876b2 100644
--- a/src/common/common.h
+++ b/src/common/common.h
@@ -274,6 +274,24 @@ template <typename Indexable>
 XGBOOST_DEVICE size_t LastOf(size_t group, Indexable const &indptr) {
   return indptr[group + 1] - 1;
 }
+
+/**
+ * @brief A CRTP (curiously recurring template pattern) helper function.
+ *
+ * https://www.fluentcpp.com/2017/05/19/crtp-helper/
+ *
+ * Does two things:
+ * 1. Makes "crtp" explicit in the inheritance structure of a CRTP base class.
+ * 2. Avoids having to `static_cast` in a lot of places.
+ *
+ * @tparam T The derived class in a CRTP hierarchy.
+ */
+template <typename T>
+struct Crtp {
+  T &Underlying() { return static_cast<T &>(*this); }
+  T const &Underlying() const { return static_cast<T const &>(*this); }
+};
+
 }  // namespace common
 }  // namespace xgboost
 #endif  // XGBOOST_COMMON_COMMON_H_
diff --git a/src/common/device_helpers.cu b/src/common/device_helpers.cu
index ec69bc900780..4c7b6b90c045 100644
--- a/src/common/device_helpers.cu
+++ b/src/common/device_helpers.cu
@@ -30,19 +30,15 @@ std::string PrintUUID(xgboost::common::Span<uint64_t, kUuidLength> uuid) {
   return ss.str();
 }
 
-
-void AllReducer::Init(int _device_ordinal) {
 #ifdef XGBOOST_USE_NCCL
-  device_ordinal_ = _device_ordinal;
-  dh::safe_cuda(cudaSetDevice(device_ordinal_));
-
+void NcclAllReducer::DoInit(int _device_ordinal) {
   int32_t const rank = rabit::GetRank();
   int32_t const world = rabit::GetWorldSize();
 
   std::vector<uint64_t> uuids(world * kUuidLength, 0);
   auto s_uuid = xgboost::common::Span<uint64_t>{uuids.data(), uuids.size()};
   auto s_this_uuid = s_uuid.subspan(rank * kUuidLength, kUuidLength);
-  GetCudaUUID(world, rank, device_ordinal_, s_this_uuid);
+  GetCudaUUID(world, rank, _device_ordinal, s_this_uuid);
 
   // No allgather yet.
   rabit::Allreduce<rabit::op::Sum>(uuids.data(), uuids.size());
 
@@ -66,20 +62,11 @@ void AllReducer::Init(int _device_ordinal) {
   id_ = GetUniqueId();
   dh::safe_nccl(ncclCommInitRank(&comm_, rabit::GetWorldSize(), id_, rank));
   safe_cuda(cudaStreamCreate(&stream_));
-  initialised_ = true;
-#else
-  if (rabit::IsDistributed()) {
-    LOG(FATAL) << "XGBoost is not compiled with NCCL.";
-  }
-#endif  // XGBOOST_USE_NCCL
 }
 
-void AllReducer::AllGather(void const *data, size_t length_bytes,
-                           std::vector<size_t> *segments,
-                           dh::caching_device_vector<char> *recvbuf) {
-#ifdef XGBOOST_USE_NCCL
-  CHECK(initialised_);
-  dh::safe_cuda(cudaSetDevice(device_ordinal_));
+void NcclAllReducer::DoAllGather(void const *data, size_t length_bytes,
+                                 std::vector<size_t> *segments,
+                                 dh::caching_device_vector<char> *recvbuf) {
   size_t world = rabit::GetWorldSize();
   segments->clear();
   segments->resize(world, 0);
@@ -98,11 +85,9 @@ void AllReducer::AllGather(void const *data, size_t length_bytes,
     offset += as_bytes;
   }
   safe_nccl(ncclGroupEnd());
-#endif  // XGBOOST_USE_NCCL
 }
 
-AllReducer::~AllReducer() {
-#ifdef XGBOOST_USE_NCCL
+NcclAllReducer::~NcclAllReducer() {
   if (initialised_) {
     dh::safe_cuda(cudaStreamDestroy(stream_));
     ncclCommDestroy(comm_);
@@ -112,7 +97,41 @@ AllReducer::~AllReducer() {
     LOG(CONSOLE) << "AllReduce calls: " << allreduce_calls_;
     LOG(CONSOLE) << "AllReduce total MiB communicated: " << allreduce_bytes_/1048576;
   }
-#endif  // XGBOOST_USE_NCCL
 }
+#else
+void RabitAllReducer::DoInit(int _device_ordinal) {
+#if !defined(XGBOOST_USE_FEDERATED)
+  if (rabit::IsDistributed()) {
+    LOG(CONSOLE) << "XGBoost is not compiled with NCCL, falling back to Rabit.";
+  }
+#endif
+}
+
+void RabitAllReducer::DoAllGather(void const *data, size_t length_bytes,
+                                  std::vector<size_t> *segments,
+                                  dh::caching_device_vector<char> *recvbuf) {
+  size_t world = rabit::GetWorldSize();
+  segments->clear();
+  segments->resize(world, 0);
+  segments->at(rabit::GetRank()) = length_bytes;
+  rabit::Allreduce<rabit::op::Max>(segments->data(), segments->size());
+  auto total_bytes = std::accumulate(segments->cbegin(), segments->cend(), 0UL);
+  recvbuf->resize(total_bytes);
+
+  sendrecvbuf_.resize(total_bytes);
+  auto rank = rabit::GetRank();
+  size_t offset = 0;
+  for (int32_t i = 0; i < world; ++i) {
+    size_t as_bytes = segments->at(i);
+    if (i == rank) {
+      safe_cuda(
+          cudaMemcpy(sendrecvbuf_.data() + offset, data, segments->at(rank), cudaMemcpyDefault));
+    }
+    rabit::Broadcast(sendrecvbuf_.data() + offset, as_bytes, i);
+    offset += as_bytes;
+  }
+  safe_cuda(
+      cudaMemcpy(recvbuf->data().get(), sendrecvbuf_.data(), total_bytes, cudaMemcpyDefault));
+}
+#endif  // XGBOOST_USE_NCCL
 
 }  // namespace dh
diff --git a/src/common/device_helpers.cuh b/src/common/device_helpers.cuh
index 334e3b4f89bf..123dc14e57be 100644
--- a/src/common/device_helpers.cuh
+++ b/src/common/device_helpers.cuh
@@ -738,32 +738,56 @@ using TypedDiscard =
 /**
  * \class AllReducer
  *
  * \brief All reducer class that manages its own communication group and
- * streams. Must be initialised before use. If XGBoost is compiled without NCCL
- * this is a dummy class that will error if used with more than one GPU.
+ * streams. Must be initialised before use. If XGBoost is compiled without NCCL,
+ * this falls back to Rabit.
  */
-class AllReducer {
-  bool initialised_ {false};
-  size_t allreduce_bytes_ {0};  // Keep statistics of the number of bytes communicated
-  size_t allreduce_calls_ {0};  // Keep statistics of the number of reduce calls
-#ifdef XGBOOST_USE_NCCL
-  ncclComm_t comm_;
-  cudaStream_t stream_;
-  int device_ordinal_;
-  ncclUniqueId id_;
-#endif
-
+template <typename T>
+class AllReducerBase : public xgboost::common::Crtp<T> {
  public:
-  AllReducer() = default;
+  virtual ~AllReducerBase() = default;
 
   /**
-   * \brief Initialise with the desired device ordinal for this communication
-   * group.
+   * \brief Initialise with the desired device ordinal for this allreducer.
    *
    * \param device_ordinal The device ordinal.
    */
-  void Init(int _device_ordinal);
+  void Init(int _device_ordinal) {
+    device_ordinal_ = _device_ordinal;
+    dh::safe_cuda(cudaSetDevice(device_ordinal_));
+    this->Underlying().DoInit(_device_ordinal);
+    initialised_ = true;
+  }
+
+  /**
+   * \brief Allgather implemented as grouped calls to Broadcast. This way we can accept
+   * different sizes of data on different workers.
+   *
+   * \param data Buffer storing the input data.
+   * \param length_bytes Size of input data in bytes.
+   * \param segments Size of data on each worker.
+   * \param recvbuf Buffer storing the result of data from all workers.
+   */
+  void AllGather(void const *data, size_t length_bytes, std::vector<size_t> *segments,
+                 dh::caching_device_vector<char> *recvbuf) {
+    CHECK(initialised_);
+    dh::safe_cuda(cudaSetDevice(device_ordinal_));
+    this->Underlying().DoAllGather(data, length_bytes, segments, recvbuf);
+  }
 
-  ~AllReducer();
+  /**
+   * \brief Allgather. Use in exactly the same way as NCCL but without needing
+   * streams or comms.
+   *
+   * \param data Buffer storing the input data.
+   * \param length Size of input data in elements.
+   * \param recvbuf Buffer storing the result of data from all workers.
+   */
+  void AllGather(uint32_t const *data, size_t length,
+                 dh::caching_device_vector<uint32_t> *recvbuf) {
+    CHECK(initialised_);
+    dh::safe_cuda(cudaSetDevice(device_ordinal_));
+    this->Underlying().DoAllGather(data, length, recvbuf);
+  }
 
   /**
    * \brief Allreduce. Use in exactly the same way as NCCL but without needing
@@ -773,36 +797,12 @@ class AllReducer {
    * streams or comms.
    *
    * \param sendbuff The sendbuff.
    * \param recvbuff The recvbuff.
    * \param count Number of elements.
    */
   void AllReduceSum(const double *sendbuff, double *recvbuff, int count) {
-#ifdef XGBOOST_USE_NCCL
     CHECK(initialised_);
     dh::safe_cuda(cudaSetDevice(device_ordinal_));
-    dh::safe_nccl(ncclAllReduce(sendbuff, recvbuff, count, ncclDouble, ncclSum, comm_, stream_));
+    this->Underlying().DoAllReduceSum(sendbuff, recvbuff, count);
     allreduce_bytes_ += count * sizeof(double);
     allreduce_calls_ += 1;
-#endif
   }
 
-  /**
-   * \brief Allgather implemented as grouped calls to Broadcast. This way we can accept
-   * different size of data on different workers.
-   * \param length_bytes Size of input data in bytes.
-   * \param segments Size of data on each worker.
-   * \param recvbuf Buffer storing the result of data from all workers.
-   */
-  void AllGather(void const* data, size_t length_bytes,
-                 std::vector<size_t>* segments, dh::caching_device_vector<char>* recvbuf);
-
-  void AllGather(uint32_t const* data, size_t length,
-                 dh::caching_device_vector<uint32_t>* recvbuf) {
-#ifdef XGBOOST_USE_NCCL
-    CHECK(initialised_);
-    size_t world = rabit::GetWorldSize();
-    recvbuf->resize(length * world);
-    safe_nccl(ncclAllGather(data, recvbuf->data().get(), length, ncclUint32,
-                            comm_, stream_));
-#endif  // XGBOOST_USE_NCCL
-  }
 
   /**
    * \brief Allreduce. Use in exactly the same way as NCCL but without needing
    * streams or comms.
    *
@@ -813,15 +813,12 @@ class AllReducer {
    * \param sendbuff The sendbuff.
    * \param recvbuff The recvbuff.
    * \param count Number of elements.
    */
   void AllReduceSum(const float *sendbuff, float *recvbuff, int count) {
-#ifdef XGBOOST_USE_NCCL
     CHECK(initialised_);
     dh::safe_cuda(cudaSetDevice(device_ordinal_));
-    dh::safe_nccl(ncclAllReduce(sendbuff, recvbuff, count, ncclFloat, ncclSum, comm_, stream_));
+    this->Underlying().DoAllReduceSum(sendbuff, recvbuff, count);
     allreduce_bytes_ += count * sizeof(float);
     allreduce_calls_ += 1;
-#endif
   }
 
   /**
@@ -833,48 +830,68 @@ class AllReducer {
    * \brief Allreduce. Use in exactly the same way as NCCL but without needing
    * streams or comms.
    *
    * \param sendbuff The sendbuff.
    * \param recvbuff The recvbuff.
    * \param count Number of elements.
    */
   void AllReduceSum(const int64_t *sendbuff, int64_t *recvbuff, int count) {
-#ifdef XGBOOST_USE_NCCL
     CHECK(initialised_);
-    dh::safe_cuda(cudaSetDevice(device_ordinal_));
-    dh::safe_nccl(ncclAllReduce(sendbuff, recvbuff, count, ncclInt64, ncclSum, comm_, stream_));
-#endif
+    this->Underlying().DoAllReduceSum(sendbuff, recvbuff, count);
+    allreduce_bytes_ += count * sizeof(int64_t);
+    allreduce_calls_ += 1;
   }
 
+  /**
+   * \brief Allreduce. Use in exactly the same way as NCCL but without needing
+   * streams or comms.
+   *
+   * \param sendbuff The sendbuff.
+   * \param recvbuff The recvbuff.
+   * \param count Number of elements.
+   */
   void AllReduceSum(const uint32_t *sendbuff, uint32_t *recvbuff, int count) {
-#ifdef XGBOOST_USE_NCCL
     CHECK(initialised_);
-    dh::safe_cuda(cudaSetDevice(device_ordinal_));
-    dh::safe_nccl(ncclAllReduce(sendbuff, recvbuff, count, ncclUint32, ncclSum, comm_, stream_));
-#endif
+    this->Underlying().DoAllReduceSum(sendbuff, recvbuff, count);
+    allreduce_bytes_ += count * sizeof(uint32_t);
+    allreduce_calls_ += 1;
   }
 
+  /**
+   * \brief Allreduce. Use in exactly the same way as NCCL but without needing
+   * streams or comms.
+   *
+   * \param sendbuff The sendbuff.
+   * \param recvbuff The recvbuff.
+   * \param count Number of elements.
+   */
   void AllReduceSum(const uint64_t *sendbuff, uint64_t *recvbuff, int count) {
-#ifdef XGBOOST_USE_NCCL
     CHECK(initialised_);
-    dh::safe_cuda(cudaSetDevice(device_ordinal_));
-    dh::safe_nccl(ncclAllReduce(sendbuff, recvbuff, count, ncclUint64, ncclSum, comm_, stream_));
-#endif
+    this->Underlying().DoAllReduceSum(sendbuff, recvbuff, count);
+    allreduce_bytes_ += count * sizeof(uint64_t);
+    allreduce_calls_ += 1;
   }
 
-  // Specialization for size_t, which is implementation defined so it might or might not
-  // be one of uint64_t/uint32_t/unsigned long long/unsigned long.
+  /**
+   * \brief Allreduce. Use in exactly the same way as NCCL but without needing
+   * streams or comms.
+   *
+   * Specialization for size_t, which is implementation defined so it might or might not
+   * be one of uint64_t/uint32_t/unsigned long long/unsigned long.
+   *
+   * \param sendbuff The sendbuff.
+   * \param recvbuff The recvbuff.
+   * \param count Number of elements.
+   */
   template <typename T,
             std::enable_if_t<std::is_same<size_t, T>::value &&
                              !std::is_same<size_t, unsigned long long>::value>  // NOLINT
                 * = nullptr>
-  void AllReduceSum(const T *sendbuff, T *recvbuff, int count) { // NOLINT
-#ifdef XGBOOST_USE_NCCL
+  void AllReduceSum(const T *sendbuff, T *recvbuff, int count) {  // NOLINT
     CHECK(initialised_);
-    dh::safe_cuda(cudaSetDevice(device_ordinal_));
-    static_assert(sizeof(unsigned long long) == sizeof(uint64_t), "");  // NOLINT
-    dh::safe_nccl(ncclAllReduce(sendbuff, recvbuff, count, ncclUint64, ncclSum, comm_, stream_));
-#endif
+    static_assert(sizeof(unsigned long long) == sizeof(uint64_t), "");  // NOLINT
+    this->Underlying().DoAllReduceSum(sendbuff, recvbuff, count);
+    allreduce_bytes_ += count * sizeof(T);
+    allreduce_calls_ += 1;
   }
 
   /**
@@ -883,13 +900,148 @@ class AllReducer {
    * \brief Synchronizes the entire communication group.
    */
   void Synchronize() {
-#ifdef XGBOOST_USE_NCCL
+    CHECK(initialised_);
     dh::safe_cuda(cudaSetDevice(device_ordinal_));
-    dh::safe_cuda(cudaStreamSynchronize(stream_));
-#endif
-  };
+    this->Underlying().DoSynchronize();
+  }
+
+ protected:
+  bool initialised_{false};
+  size_t allreduce_bytes_{0};  // Keep statistics of the number of bytes communicated.
+  size_t allreduce_calls_{0};  // Keep statistics of the number of reduce calls.
+
+ private:
+  int device_ordinal_{-1};
+};
 
 #ifdef XGBOOST_USE_NCCL
+class NcclAllReducer : public AllReducerBase<NcclAllReducer> {
+ public:
+  friend class AllReducerBase<NcclAllReducer>;
+
+  ~NcclAllReducer() override;
+
+ private:
+  /**
+   * \brief Initialise with the desired device ordinal for this communication
+   * group.
+   *
+   * \param device_ordinal The device ordinal.
+   */
+  void DoInit(int _device_ordinal);
+
+  /**
+   * \brief Allgather implemented as grouped calls to Broadcast. This way we can accept
+   * different sizes of data on different workers.
+   *
+   * \param data Buffer storing the input data.
+   * \param length_bytes Size of input data in bytes.
+   * \param segments Size of data on each worker.
+   * \param recvbuf Buffer storing the result of data from all workers.
+   */
+  void DoAllGather(void const *data, size_t length_bytes, std::vector<size_t> *segments,
+                   dh::caching_device_vector<char> *recvbuf);
+
+  /**
+   * \brief Allgather. Use in exactly the same way as NCCL but without needing
+   * streams or comms.
+   *
+   * \param data Buffer storing the input data.
+   * \param length Size of input data in elements.
+   * \param recvbuf Buffer storing the result of data from all workers.
+   */
+  void DoAllGather(uint32_t const *data, size_t length,
+                   dh::caching_device_vector<uint32_t> *recvbuf) {
+    size_t world = rabit::GetWorldSize();
+    recvbuf->resize(length * world);
+    safe_nccl(ncclAllGather(data, recvbuf->data().get(), length, ncclUint32, comm_, stream_));
+  }
+
+  /**
+   * \brief Allreduce. Use in exactly the same way as NCCL but without needing
+   * streams or comms.
+   *
+   * \param sendbuff The sendbuff.
+   * \param recvbuff The recvbuff.
+   * \param count Number of elements.
+   */
+  void DoAllReduceSum(const double *sendbuff, double *recvbuff, int count) {
+    dh::safe_nccl(ncclAllReduce(sendbuff, recvbuff, count, ncclDouble, ncclSum, comm_, stream_));
+  }
+
+  /**
+   * \brief Allreduce. Use in exactly the same way as NCCL but without needing
+   * streams or comms.
+   *
+   * \param sendbuff The sendbuff.
+   * \param recvbuff The recvbuff.
+   * \param count Number of elements.
+   */
+  void DoAllReduceSum(const float *sendbuff, float *recvbuff, int count) {
+    dh::safe_nccl(ncclAllReduce(sendbuff, recvbuff, count, ncclFloat, ncclSum, comm_, stream_));
+  }
+
+  /**
+   * \brief Allreduce. Use in exactly the same way as NCCL but without needing
+   * streams or comms.
+   *
+   * \param sendbuff The sendbuff.
+   * \param recvbuff The recvbuff.
+   * \param count Number of elements.
+   */
+  void DoAllReduceSum(const int64_t *sendbuff, int64_t *recvbuff, int count) {
+    dh::safe_nccl(ncclAllReduce(sendbuff, recvbuff, count, ncclInt64, ncclSum, comm_, stream_));
+  }
+
+  /**
+   * \brief Allreduce. Use in exactly the same way as NCCL but without needing
+   * streams or comms.
+   *
+   * \param sendbuff The sendbuff.
+   * \param recvbuff The recvbuff.
+   * \param count Number of elements.
+   */
+  void DoAllReduceSum(const uint32_t *sendbuff, uint32_t *recvbuff, int count) {
+    dh::safe_nccl(ncclAllReduce(sendbuff, recvbuff, count, ncclUint32, ncclSum, comm_, stream_));
+  }
+
+  /**
+   * \brief Allreduce. Use in exactly the same way as NCCL but without needing
+   * streams or comms.
+   *
+   * \param sendbuff The sendbuff.
+   * \param recvbuff The recvbuff.
+   * \param count Number of elements.
+   */
+  void DoAllReduceSum(const uint64_t *sendbuff, uint64_t *recvbuff, int count) {
+    dh::safe_nccl(ncclAllReduce(sendbuff, recvbuff, count, ncclUint64, ncclSum, comm_, stream_));
+  }
+
+  /**
+   * \brief Allreduce. Use in exactly the same way as NCCL but without needing
+   * streams or comms.
+   *
+   * Specialization for size_t, which is implementation defined so it might or might not
+   * be one of uint64_t/uint32_t/unsigned long long/unsigned long.
+   *
+   * \param sendbuff The sendbuff.
+   * \param recvbuff The recvbuff.
+   * \param count Number of elements.
+   */
+  template <typename T,
+            std::enable_if_t<std::is_same<size_t, T>::value &&
+                             !std::is_same<size_t, unsigned long long>::value>  // NOLINT
+                * = nullptr>
+  void DoAllReduceSum(const T *sendbuff, T *recvbuff, int count) {  // NOLINT
+    dh::safe_nccl(ncclAllReduce(sendbuff, recvbuff, count, ncclUint64, ncclSum, comm_, stream_));
+  }
+
+  /**
+   * \brief Synchronizes the entire communication group.
+   */
+  void DoSynchronize() { dh::safe_cuda(cudaStreamSynchronize(stream_)); }
+
   /**
    * \fn  ncclUniqueId GetUniqueId()
    *
@@ -904,15 +1056,163 @@ class AllReducer {
     if (rabit::GetRank() == kRootRank) {
       dh::safe_nccl(ncclGetUniqueId(&id));
     }
-    rabit::Broadcast(
-        static_cast<void *>(&id),
-        sizeof(ncclUniqueId),
-        static_cast<int>(kRootRank));
+    rabit::Broadcast(static_cast<void *>(&id), sizeof(ncclUniqueId), static_cast<int>(kRootRank));
     return id;
   }
-#endif
+
+  ncclComm_t comm_;
+  cudaStream_t stream_;
+  ncclUniqueId id_;
+};
+
+using AllReducer = NcclAllReducer;
+#else
+class RabitAllReducer : public AllReducerBase<RabitAllReducer> {
+ public:
+  friend class AllReducerBase<RabitAllReducer>;
+
+ private:
+  /**
+   * \brief Initialise with the desired device ordinal for this allreducer.
+   *
+   * \param device_ordinal The device ordinal.
+   */
+  static void DoInit(int _device_ordinal);
+
+  /**
+   * \brief Allgather implemented as grouped calls to Broadcast. This way we can accept
+   * different sizes of data on different workers.
+   *
+   * \param data Buffer storing the input data.
+   * \param length_bytes Size of input data in bytes.
+   * \param segments Size of data on each worker.
+   * \param recvbuf Buffer storing the result of data from all workers.
+   */
+  void DoAllGather(void const *data, size_t length_bytes, std::vector<size_t> *segments,
+                   dh::caching_device_vector<char> *recvbuf);
+
+  /**
+   * \brief Allgather. Use in exactly the same way as NCCL.
+   *
+   * \param data Buffer storing the input data.
+   * \param length Size of input data in elements.
+   * \param recvbuf Buffer storing the result of data from all workers.
+   */
+  void DoAllGather(uint32_t const *data, size_t length,
+                   dh::caching_device_vector<uint32_t> *recvbuf) {
+    size_t world = rabit::GetWorldSize();
+    auto total_size = length * world;
+    recvbuf->resize(total_size);
+    auto length_bytes = length * sizeof(uint32_t);
+    auto total_bytes = total_size * sizeof(uint32_t);
+    sendrecvbuf_.resize(total_bytes);
+    auto rank = rabit::GetRank();
+    safe_cuda(cudaMemcpy(sendrecvbuf_.data() + rank * length_bytes, data, length_bytes,
+                         cudaMemcpyDefault));
+    rabit::Allgather(sendrecvbuf_.data(), total_bytes, rank * length_bytes, length_bytes,
+                     length_bytes);
+    safe_cuda(
+        cudaMemcpy(recvbuf->data().get(), sendrecvbuf_.data(), total_bytes, cudaMemcpyDefault));
+  }
+
+  /**
+   * \brief Allreduce. Use in exactly the same way as NCCL.
+   *
+   * \param sendbuff The sendbuff.
+   * \param recvbuff The recvbuff.
+   * \param count Number of elements.
+   */
+  void DoAllReduceSum(const double *sendbuff, double *recvbuff, int count) {
+    RabitAllReduceSum(sendbuff, recvbuff, count);
+  }
+
+  /**
+   * \brief Allreduce. Use in exactly the same way as NCCL.
+   *
+   * \param sendbuff The sendbuff.
+   * \param recvbuff The recvbuff.
+   * \param count Number of elements.
+   */
+  void DoAllReduceSum(const float *sendbuff, float *recvbuff, int count) {
+    RabitAllReduceSum(sendbuff, recvbuff, count);
+  }
+
+  /**
+   * \brief Allreduce. Use in exactly the same way as NCCL.
+   *
+   * \param sendbuff The sendbuff.
+   * \param recvbuff The recvbuff.
+   * \param count Number of elements.
+   */
+  void DoAllReduceSum(const int64_t *sendbuff, int64_t *recvbuff, int count) {
+    RabitAllReduceSum(sendbuff, recvbuff, count);
+  }
+
+  /**
+   * \brief Allreduce. Use in exactly the same way as NCCL.
+   *
+   * \param sendbuff The sendbuff.
+   * \param recvbuff The recvbuff.
+   * \param count Number of elements.
+   */
+  void DoAllReduceSum(const uint32_t *sendbuff, uint32_t *recvbuff, int count) {
+    RabitAllReduceSum(sendbuff, recvbuff, count);
+  }
+
+  /**
+   * \brief Allreduce. Use in exactly the same way as NCCL.
+   *
+   * \param sendbuff The sendbuff.
+   * \param recvbuff The recvbuff.
+   * \param count Number of elements.
+   */
+  void DoAllReduceSum(const uint64_t *sendbuff, uint64_t *recvbuff, int count) {
+    RabitAllReduceSum(sendbuff, recvbuff, count);
+  }
+
+  /**
+   * \brief Allreduce. Use in exactly the same way as NCCL.
+   *
+   * Specialization for size_t, which is implementation defined so it might or might not
+   * be one of uint64_t/uint32_t/unsigned long long/unsigned long.
+   *
+   * \param sendbuff The sendbuff.
+   * \param recvbuff The recvbuff.
+   * \param count Number of elements.
+   */
+  template <typename T,
+            std::enable_if_t<std::is_same<size_t, T>::value &&
+                             !std::is_same<size_t, unsigned long long>::value>  // NOLINT
+                * = nullptr>
+  void DoAllReduceSum(const T *sendbuff, T *recvbuff, int count) {  // NOLINT
+    RabitAllReduceSum(sendbuff, recvbuff, count);
+  }
+
+  /**
+   * \brief Synchronizes the allreducer.
+   */
+  void DoSynchronize() {}
+
+  /**
+   * \brief Allreduce. Use in exactly the same way as NCCL.
+   *
+   * Copy the device buffer to host, call rabit allreduce, then copy the buffer back
+   * to device.
+   *
+   * \param sendbuff The sendbuff.
+   * \param recvbuff The recvbuff.
+   * \param count Number of elements.
+   */
+  template <typename T>
+  void RabitAllReduceSum(const T *sendbuff, T *recvbuff, int count) {
+    auto total_size = count * sizeof(T);
+    sendrecvbuf_.resize(total_size);
+    safe_cuda(cudaMemcpy(sendrecvbuf_.data(), sendbuff, total_size, cudaMemcpyDefault));
+    rabit::Allreduce<rabit::op::Sum>(reinterpret_cast<T *>(sendrecvbuf_.data()), count);
+    safe_cuda(cudaMemcpy(recvbuff, sendrecvbuf_.data(), total_size, cudaMemcpyDefault));
+  }
+
+  /// Host buffer used to call rabit functions.
+  std::vector<char> sendrecvbuf_{};
+};
+
+using AllReducer = RabitAllReducer;
+#endif  // XGBOOST_USE_NCCL
 
 template <typename VectorT, typename T = typename VectorT::value_type,
           typename IndexT = typename xgboost::common::Span<T>::index_type>
 xgboost::common::Span<T> ToSpan(
diff --git a/tests/cpp/common/test_quantile.cu b/tests/cpp/common/test_quantile.cu
index c124ab5055e6..dcb82d29c8d8 100644
--- a/tests/cpp/common/test_quantile.cu
+++ b/tests/cpp/common/test_quantile.cu
@@ -339,7 +339,6 @@ TEST(GPUQuantile, MultiMerge) {
 TEST(GPUQuantile, AllReduceBasic) {
   // This test is supposed to run by a Python test that sets up the environment.
   std::string msg {"Skipping AllReduce test"};
-#if defined(__linux__) && defined(XGBOOST_USE_NCCL)
   auto n_gpus = AllVisibleGPUs();
   InitRabitContext(msg, n_gpus);
   auto world = rabit::GetWorldSize();
@@ -420,15 +419,10 @@ TEST(GPUQuantile, AllReduceBasic) {
     }
   });
   rabit::Finalize();
-#else
-  LOG(WARNING) << msg;
-  return;
-#endif  // !defined(__linux__) && defined(XGBOOST_USE_NCCL)
 }
 
 TEST(GPUQuantile, SameOnAllWorkers) {
   std::string msg {"Skipping SameOnAllWorkers test"};
-#if defined(__linux__) && defined(XGBOOST_USE_NCCL)
   auto n_gpus = AllVisibleGPUs();
   InitRabitContext(msg, n_gpus);
   auto world = rabit::GetWorldSize();
@@ -495,10 +489,6 @@ TEST(GPUQuantile, SameOnAllWorkers) {
       offset += size_as_float;
     }
   });
-#else
-  LOG(WARNING) << msg;
-  return;
-#endif  // !defined(__linux__) && defined(XGBOOST_USE_NCCL)
 }
 
 TEST(GPUQuantile, Push) {
diff --git a/tests/cpp/common/test_transform_range.cu b/tests/cpp/common/test_transform_range.cu
index c1609312728e..172d7aeb36af 100644
--- a/tests/cpp/common/test_transform_range.cu
+++ b/tests/cpp/common/test_transform_range.cu
@@ -4,7 +4,6 @@
  */
 #include "test_transform_range.cc"
 
-#if defined(XGBOOST_USE_NCCL)
 namespace xgboost {
 namespace common {
 
@@ -15,7 +14,7 @@ TEST(Transform, MGPU_SpecifiedGpuId) {  // NOLINT
   }
   // Use 1 GPU, Numbering of GPU starts from 1
   auto device = 1;
-  const size_t size {256};
+  auto const size {256};
   std::vector<bst_float> h_in(size);
   std::vector<bst_float> h_out(size);
   std::iota(h_in.begin(), h_in.end(), 0);
@@ -34,4 +33,3 @@ TEST(Transform, MGPU_SpecifiedGpuId) {  // NOLINT
 
 }  // namespace common
 }  // namespace xgboost
-#endif
diff --git a/tests/cpp/metric/test_multiclass_metric.cc b/tests/cpp/metric/test_multiclass_metric.cc
index 5a2c939e9315..2d9721ca53e8 100644
--- a/tests/cpp/metric/test_multiclass_metric.cc
+++ b/tests/cpp/metric/test_multiclass_metric.cc
@@ -85,7 +85,7 @@ TEST(Metric, DeclareUnifiedTest(MultiClassLogLoss)) {
   xgboost::CheckDeterministicMetricMultiClass(xgboost::StringView{"mlogloss"}, GPUIDX);
 }
 
-#if defined(XGBOOST_USE_NCCL) && defined(__CUDACC__)
+#if defined(__CUDACC__)
 namespace xgboost {
 namespace common {
 TEST(Metric, MGPU_MultiClassError) {
@@ -109,4 +109,4 @@ TEST(Metric, MGPU_MultiClassError) {
 }
 }  // namespace common
 }  // namespace xgboost
-#endif  // defined(XGBOOST_USE_NCCL)
+#endif  // defined(__CUDACC__)
diff --git a/tests/distributed/runtests-federated.sh b/tests/distributed/runtests-federated.sh
index 77724aa969ea..81a40c3505f4 100755
--- a/tests/distributed/runtests-federated.sh
+++ b/tests/distributed/runtests-federated.sh
@@ -4,14 +4,14 @@ set -e
 
 rm -f ./*.model* ./agaricus* ./*.pem
 
-world_size=3
+world_size=$(nvidia-smi -L | wc -l)
 
 # Generate server and client certificates.
 openssl req -x509 -newkey rsa:2048 -days 7 -nodes -keyout server-key.pem -out server-cert.pem -subj "/C=US/CN=localhost"
 openssl req -x509 -newkey rsa:2048 -days 7 -nodes -keyout client-key.pem -out client-cert.pem -subj "/C=US/CN=localhost"
 
 # Split train and test files manually to simulate a federated environment.
-split -n l/${world_size} -d ../../demo/data/agaricus.txt.train agaricus.txt.train-
-split -n l/${world_size} -d ../../demo/data/agaricus.txt.test agaricus.txt.test-
+split -n l/"${world_size}" -d ../../demo/data/agaricus.txt.train agaricus.txt.train-
+split -n l/"${world_size}" -d ../../demo/data/agaricus.txt.test agaricus.txt.test-
 
-python test_federated.py ${world_size}
+python test_federated.py "${world_size}"
diff --git a/tests/distributed/test_federated.py b/tests/distributed/test_federated.py
index a3cdbc1e2912..cddd104e922c 100644
--- a/tests/distributed/test_federated.py
+++ b/tests/distributed/test_federated.py
@@ -17,7 +17,7 @@ def run_server(port: int, world_size: int) -> None:
                            CLIENT_CERT)
 
 
-def run_worker(port: int, world_size: int, rank: int) -> None:
+def run_worker(port: int, world_size: int, rank: int, with_gpu: bool) -> None:
     # Always call this before using distributed module
     rabit_env = [
         f'federated_server_address=localhost:{port}',
@@ -34,6 +34,9 @@ def run_worker(port: int, world_size: int, rank: int) -> None:
 
     # Specify parameters via map, definition are same as c++ version
     param = {'max_depth': 2, 'eta': 1, 'objective': 'binary:logistic'}
+    if with_gpu:
+        param['tree_method'] = 'gpu_hist'
+        param['gpu_id'] = rank
 
     # Specify validations set to watch performance
     watchlist = [(dtest, 'eval'), (dtrain, 'train')]
@@ -49,7 +52,7 @@ def run_worker(port: int, world_size: int, rank: int) -> None:
     xgb.rabit.tracker_print("Finished training\n")
 
 
-def run_test() -> None:
+def run_test(with_gpu: bool = False) -> None:
     port = 9091
     world_size = int(sys.argv[1])
 
@@ -61,7 +64,7 @@ def run_test() -> None:
     workers = []
     for rank in range(world_size):
-        worker = multiprocessing.Process(target=run_worker, args=(port, world_size, rank))
+        worker = multiprocessing.Process(target=run_worker, args=(port, world_size, rank, with_gpu))
         workers.append(worker)
         worker.start()
     for worker in workers:
@@ -71,3 +74,4 @@ def run_test() -> None:
 
 if __name__ == '__main__':
     run_test()
+    run_test(with_gpu=True)
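
Note on usage (an illustrative sketch, not part of the patch): call sites do not change with this refactor, since `dh::AllReducer` now aliases `NcclAllReducer` when built with `USE_NCCL=ON` and `RabitAllReducer` otherwise, so the same code compiles against either backend. The helper below is hypothetical (the name `SumHistograms` and its arguments are invented for illustration); only the `AllReducer` interface comes from the patch, and it assumes rabit has already been initialised on every worker.

    // Hypothetical caller: sum one worker's partial histogram across all workers.
    // dh::AllReducer resolves to NcclAllReducer or RabitAllReducer at build time.
    void SumHistograms(int device_ordinal, double *hist, int n_bins) {
      dh::AllReducer reducer;
      reducer.Init(device_ordinal);              // binds this worker's GPU
      reducer.AllReduceSum(hist, hist, n_bins);  // in-place sum over all workers
      reducer.Synchronize();                     // NCCL stream sync; no-op over Rabit
    }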