From a2df84de449b603af3a6b00d5d9bb47a37bfbbf8 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Wed, 6 Jul 2022 11:42:36 -0700 Subject: [PATCH 01/37] implement broadcast for federated communicator --- plugin/federated/federated_communicator.h | 59 +++++++++++ plugin/federated/federated_server.cc | 15 +-- src/collective/communicator.h | 77 ++++++++++++++ src/common/io.h | 17 +++ .../cpp/plugin/test_federated_communicator.cc | 100 ++++++++++++++++++ 5 files changed, 258 insertions(+), 10 deletions(-) create mode 100644 plugin/federated/federated_communicator.h create mode 100644 src/collective/communicator.h create mode 100644 tests/cpp/plugin/test_federated_communicator.cc diff --git a/plugin/federated/federated_communicator.h b/plugin/federated/federated_communicator.h new file mode 100644 index 000000000000..2e8897c184bc --- /dev/null +++ b/plugin/federated/federated_communicator.h @@ -0,0 +1,59 @@ +/*! + * Copyright 2022 XGBoost contributors + */ +#pragma once +#include "../../src/collective/communicator.h" +#include "../../src/common/io.h" +#include "federated_client.h" + +namespace xgboost { +namespace collective { + +/** + * @brief A federated learning communicator class that handles collective communication . + */ +class FederatedCommunicator : public Communicator { + public: + /** + * @brief Construct a new federated communicator. + * + * @param world_size Total number of processes. + * @param rank Rank of the current process. + */ + FederatedCommunicator(int world_size, int rank, std::string const &server_address, + std::string const &server_cert_path, std::string const &client_key_path, + std::string const &client_cert_path) + : Communicator{world_size, rank} { + client_.reset(new xgboost::federated::FederatedClient( + server_address, rank, xgboost::common::ReadAll(server_cert_path), + xgboost::common::ReadAll(client_key_path), xgboost::common::ReadAll(client_cert_path))); + } + + /** @brief Insecure communicator for testing only. 
*/ + FederatedCommunicator(int world_size, int rank, std::string const &server_address) + : Communicator{world_size, rank} { + client_.reset(new xgboost::federated::FederatedClient(server_address, rank)); + } + + ~FederatedCommunicator() override { client_.reset(); } + + void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type, + Operation op) override {} + + void Broadcast(void *send_receive_buffer, std::size_t size, int root) override { + if (GetWorldSize() == 1) return; + if (GetRank() == root) { + std::string const send_buffer(reinterpret_cast(send_receive_buffer), size); + client_->Broadcast(send_buffer, root); + } else { + auto const received = client_->Broadcast("", root); + received.copy(reinterpret_cast(send_receive_buffer), size); + } + } + + private: + std::unique_ptr client_{}; +}; + +} // namespace collective +} // namespace xgboost diff --git a/plugin/federated/federated_server.cc b/plugin/federated/federated_server.cc index c38bdc4b8171..b569cd33df2f 100644 --- a/plugin/federated/federated_server.cc +++ b/plugin/federated/federated_server.cc @@ -10,6 +10,8 @@ #include #include +#include "../../src/common/io.h" + namespace xgboost { namespace federated { @@ -201,13 +203,6 @@ grpc::Status FederatedService::Handle(Request const* request, Reply* reply, return grpc::Status::OK; } -std::string ReadFile(char const* path) { - auto stream = std::ifstream(path); - std::ostringstream out; - out << stream.rdbuf(); - return out.str(); -} - void RunServer(int port, int world_size, char const* server_key_file, char const* server_cert_file, char const* client_cert_file) { std::string const server_address = "0.0.0.0:" + std::to_string(port); @@ -216,10 +211,10 @@ void RunServer(int port, int world_size, char const* server_key_file, char const grpc::ServerBuilder builder; auto options = grpc::SslServerCredentialsOptions(GRPC_SSL_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY); - options.pem_root_certs = ReadFile(client_cert_file); + 
options.pem_root_certs = xgboost::common::ReadAll(client_cert_file); auto key = grpc::SslServerCredentialsOptions::PemKeyCertPair(); - key.private_key = ReadFile(server_key_file); - key.cert_chain = ReadFile(server_cert_file); + key.private_key = xgboost::common::ReadAll(server_key_file); + key.cert_chain = xgboost::common::ReadAll(server_cert_file); options.pem_key_cert_pairs.push_back(key); builder.SetMaxReceiveMessageSize(std::numeric_limits::max()); builder.AddListeningPort(server_address, grpc::SslServerCredentials(options)); diff --git a/src/collective/communicator.h b/src/collective/communicator.h new file mode 100644 index 000000000000..28e78f501123 --- /dev/null +++ b/src/collective/communicator.h @@ -0,0 +1,77 @@ +/*! + * Copyright 2022 XGBoost contributors + */ +#pragma once +#include + +namespace xgboost { +namespace collective { + +/** @brief Defines the integral and floating data types. */ +enum class DataType { kInt, kFloat, kDouble }; + +/** @brief Defines the reduction operation. */ +enum class Operation { kMax, kSum }; + +/** + * @brief A communicator class that handles collective communication. + */ +class Communicator { + public: + /** + * @brief Construct a new communicator. + * + * @param world_size Total number of processes. + * @param rank Rank of the current process. + */ + Communicator(int world_size, int rank) : world_size_(world_size), rank_(rank) { + if (world_size < 1) { + LOG(FATAL) << "World size " << world_size << " is less than 1."; + } + if (rank < 0) { + LOG(FATAL) << "Rank " << rank << " is less than 0."; + } + if (rank >= world_size) { + LOG(FATAL) << "Rank " << rank << " is greater than world_size - 1: " << world_size - 1 << "."; + } + } + + virtual ~Communicator() = default; + + /** @brief Get the total number of processes. */ + int GetWorldSize() const { return world_size_; } + + /** @brief Get the rank of the current processes. 
*/ + int GetRank() const { return rank_; } + + /** @brief Whether the communicator is running in distributed mode. */ + bool IsDistributed() const { return world_size_ > 1; }; + + /** + * @brief Combines values from all processes and distributes the result back to all processes. + * + * @param send_receive_buffer Buffer storing the data. + * @param count Number of elements in the buffer. + * @param data_type Data type stored in the buffer. + * @param op The operation to perform. + */ + virtual void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type, + Operation op) = 0; + + /** + * @brief Broadcasts a message from the process with rank `root` to all other processes of the + * group. + * + * @param send_receive_buffer Buffer storing the data. + * @param size Size of the data in bytes. + * @param root Rank of broadcast root. + */ + virtual void Broadcast(void *send_receive_buffer, std::size_t size, int root) = 0; + + private: + int const world_size_; + int const rank_; +}; + +} // namespace collective +} // namespace xgboost diff --git a/src/common/io.h b/src/common/io.h index b377623ea138..bcc6c4704a5d 100644 --- a/src/common/io.h +++ b/src/common/io.h @@ -12,6 +12,7 @@ #include #include #include +#include #include "common.h" @@ -111,6 +112,22 @@ inline std::string ReadAll(dmlc::Stream* fi, PeekableInStream* fp) { } return buffer; } + +/** + * \brief Read the whole file content into a string. 
+ */ +inline std::string ReadAll(std::string const &path) { + std::ifstream stream(path); + if (!stream.is_open()) { + LOG(FATAL) << "Could not open file " << path; + } + std::string content{std::istreambuf_iterator(stream), std::istreambuf_iterator()}; + if (content.empty()) { + LOG(FATAL) << "Empty file " << path; + } + return content; +} + } // namespace common } // namespace xgboost #endif // XGBOOST_COMMON_IO_H_ diff --git a/tests/cpp/plugin/test_federated_communicator.cc b/tests/cpp/plugin/test_federated_communicator.cc new file mode 100644 index 000000000000..772e5085a51b --- /dev/null +++ b/tests/cpp/plugin/test_federated_communicator.cc @@ -0,0 +1,100 @@ +/*! + * Copyright 2022 XGBoost contributors + */ +#include +#include + +#include + +#include "../../../plugin/federated/federated_communicator.h" +#include "../../../plugin/federated/federated_server.h" + +namespace xgboost { +namespace collective { + +std::string const kServerAddress{"localhost:56789"}; // NOLINT(cert-err58-cpp) + +class FederatedCommunicatorTest : public ::testing::Test { + public: + static void VerifyBroadcast(int rank) { + FederatedCommunicator comm{kWorldSize, rank, kServerAddress}; + CheckBroadcast(comm, rank); + } + + protected: + void SetUp() override { + server_thread_.reset(new std::thread([this] { + grpc::ServerBuilder builder; + federated::FederatedService service{kWorldSize}; + builder.AddListeningPort(kServerAddress, grpc::InsecureServerCredentials()); + builder.RegisterService(&service); + server_ = builder.BuildAndStart(); + server_->Wait(); + })); + } + + void TearDown() override { + server_->Shutdown(); + server_thread_->join(); + } + + static void CheckBroadcast(FederatedCommunicator& comm, int rank) { + if (rank == 0) { + std::string buffer{"hello"}; + comm.Broadcast(&buffer[0], buffer.size(), 0); + EXPECT_EQ(buffer, "hello"); + } else { + std::string buffer{" "}; + comm.Broadcast(&buffer[0], buffer.size(), 0); + EXPECT_EQ(buffer, "hello"); + } + } + + static int const 
kWorldSize{3}; + std::unique_ptr server_thread_; + std::unique_ptr server_; +}; + +TEST(FederatedCommunicatorSimpleTest, ThrowOnWorldSizeTooSmall) { + auto construct = []() { FederatedCommunicator comm{0, 0, kServerAddress, "", "", ""}; }; + EXPECT_THROW(construct(), dmlc::Error); +} + +TEST(FederatedCommunicatorSimpleTest, ThrowOnRankTooSmall) { + auto construct = []() { FederatedCommunicator comm{1, -1, kServerAddress, "", "", ""}; }; + EXPECT_THROW(construct(), dmlc::Error); +} + +TEST(FederatedCommunicatorSimpleTest, ThrowOnRankTooBig) { + auto construct = []() { FederatedCommunicator comm{1, 1, kServerAddress, "", "", ""}; }; + EXPECT_THROW(construct(), dmlc::Error); +} + +TEST(FederatedCommunicatorSimpleTest, GetWorldSizeAndRank) { + FederatedCommunicator comm{6, 3, kServerAddress}; + EXPECT_EQ(comm.GetWorldSize(), 6); + EXPECT_EQ(comm.GetRank(), 3); +} + +TEST(FederatedCommunicatorSimpleTest, IsNotDistributed) { + FederatedCommunicator comm{1, 0, kServerAddress}; + EXPECT_FALSE(comm.IsDistributed()); +} + +TEST(FederatedCommunicatorSimpleTest, IsDistributed) { + FederatedCommunicator comm{2, 1, kServerAddress}; + EXPECT_TRUE(comm.IsDistributed()); +} + +TEST_F(FederatedCommunicatorTest, Broadcast) { + std::vector threads; + for (auto rank = 0; rank < kWorldSize; rank++) { + threads.emplace_back(std::thread(&FederatedCommunicatorTest::VerifyBroadcast, rank)); + } + for (auto& thread : threads) { + thread.join(); + } +} + +} // namespace collective +} // namespace xgboost From 489433432077d9958e776ebf47879ee74786cfb4 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Wed, 6 Jul 2022 15:39:14 -0700 Subject: [PATCH 02/37] implement allreduce --- plugin/federated/federated_communicator.h | 41 +++++++++++++++++-- src/collective/communicator.h | 16 ++++++++ .../cpp/plugin/test_federated_communicator.cc | 24 +++++++++++ 3 files changed, 78 insertions(+), 3 deletions(-) diff --git a/plugin/federated/federated_communicator.h b/plugin/federated/federated_communicator.h index 
2e8897c184bc..c27c104ae047 100644 --- a/plugin/federated/federated_communicator.h +++ b/plugin/federated/federated_communicator.h @@ -10,7 +10,7 @@ namespace xgboost { namespace collective { /** - * @brief A federated learning communicator class that handles collective communication . + * @brief A federated learning communicator class that handles collective communication. */ class FederatedCommunicator : public Communicator { public: @@ -38,12 +38,18 @@ class FederatedCommunicator : public Communicator { ~FederatedCommunicator() override { client_.reset(); } void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type, - Operation op) override {} + Operation op) override { + std::string const send_buffer(reinterpret_cast(send_receive_buffer), + count * GetTypeSize(data_type)); + auto const received = + client_->Allreduce(send_buffer, ConvertDataType(data_type), ConvertOperation(op)); + received.copy(reinterpret_cast(send_receive_buffer), count * GetTypeSize(data_type)); + } void Broadcast(void *send_receive_buffer, std::size_t size, int root) override { if (GetWorldSize() == 1) return; if (GetRank() == root) { - std::string const send_buffer(reinterpret_cast(send_receive_buffer), size); + std::string const send_buffer(reinterpret_cast(send_receive_buffer), size); client_->Broadcast(send_buffer, root); } else { auto const received = client_->Broadcast("", root); @@ -52,6 +58,35 @@ class FederatedCommunicator : public Communicator { } private: + static xgboost::federated::DataType ConvertDataType(DataType data_type) { + xgboost::federated::DataType result{}; + switch (data_type) { + case DataType::kInt: + result = xgboost::federated::DataType::INT; + break; + case DataType::kFloat: + result = xgboost::federated::DataType::FLOAT; + break; + case DataType::kDouble: + result = xgboost::federated::DataType::DOUBLE; + break; + } + return result; + } + + static xgboost::federated::ReduceOperation ConvertOperation(Operation operation) { + 
xgboost::federated::ReduceOperation result{}; + switch (operation) { + case Operation::kMax: + result = xgboost::federated::ReduceOperation::MAX; + break; + case Operation::kSum: + result = xgboost::federated::ReduceOperation::SUM; + break; + } + return result; + } + std::unique_ptr client_{}; }; diff --git a/src/collective/communicator.h b/src/collective/communicator.h index 28e78f501123..81e8d9ef10d3 100644 --- a/src/collective/communicator.h +++ b/src/collective/communicator.h @@ -10,6 +10,22 @@ namespace collective { /** @brief Defines the integral and floating data types. */ enum class DataType { kInt, kFloat, kDouble }; +inline std::size_t GetTypeSize(DataType data_type) { + std::size_t size{0}; + switch (data_type) { + case DataType::kInt: + size = sizeof(int); + break; + case DataType::kFloat: + size = sizeof(float); + break; + case DataType::kDouble: + size = sizeof(double); + break; + } + return size; +} + /** @brief Defines the reduction operation. */ enum class Operation { kMax, kSum }; diff --git a/tests/cpp/plugin/test_federated_communicator.cc b/tests/cpp/plugin/test_federated_communicator.cc index 772e5085a51b..d4aa08ee17ee 100644 --- a/tests/cpp/plugin/test_federated_communicator.cc +++ b/tests/cpp/plugin/test_federated_communicator.cc @@ -16,6 +16,11 @@ std::string const kServerAddress{"localhost:56789"}; // NOLINT(cert-err58-cpp) class FederatedCommunicatorTest : public ::testing::Test { public: + static void VerifyAllreduce(int rank) { + FederatedCommunicator comm{kWorldSize, rank, kServerAddress}; + CheckAllreduce(comm); + } + static void VerifyBroadcast(int rank) { FederatedCommunicator comm{kWorldSize, rank, kServerAddress}; CheckBroadcast(comm, rank); @@ -38,6 +43,15 @@ class FederatedCommunicatorTest : public ::testing::Test { server_thread_->join(); } + static void CheckAllreduce(FederatedCommunicator& comm) { + int buffer[] = {1, 2, 3, 4, 5}; + comm.AllReduce(buffer, sizeof(buffer)/sizeof(buffer[0]), DataType::kInt, Operation::kSum); + 
int expected[] = {3, 6, 9, 12, 15}; + for (auto i = 0; i < 5; i++) { + EXPECT_EQ(buffer[i], expected[i]); + } + } + static void CheckBroadcast(FederatedCommunicator& comm, int rank) { if (rank == 0) { std::string buffer{"hello"}; @@ -86,6 +100,16 @@ TEST(FederatedCommunicatorSimpleTest, IsDistributed) { EXPECT_TRUE(comm.IsDistributed()); } +TEST_F(FederatedCommunicatorTest, Allreduce) { + std::vector threads; + for (auto rank = 0; rank < kWorldSize; rank++) { + threads.emplace_back(std::thread(&FederatedCommunicatorTest::VerifyAllreduce, rank)); + } + for (auto& thread : threads) { + thread.join(); + } +} + TEST_F(FederatedCommunicatorTest, Broadcast) { std::vector threads; for (auto rank = 0; rank < kWorldSize; rank++) { From c5fded6a9a885dc40215a357cd97a61f57d10ea7 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Mon, 11 Jul 2022 15:49:20 -0700 Subject: [PATCH 03/37] add communicator factory --- plugin/federated/federated_communicator.h | 86 +++++++++++++++ src/collective/communicator_factory.h | 100 ++++++++++++++++++ .../collective/test_communicator_factory.cc | 47 ++++++++ .../cpp/plugin/test_federated_communicator.cc | 86 ++++++++++++++- tests/distributed/test_federated.py | 1 + 5 files changed, 319 insertions(+), 1 deletion(-) create mode 100644 src/collective/communicator_factory.h create mode 100644 tests/cpp/collective/test_communicator_factory.cc diff --git a/plugin/federated/federated_communicator.h b/plugin/federated/federated_communicator.h index c27c104ae047..11430136e37b 100644 --- a/plugin/federated/federated_communicator.h +++ b/plugin/federated/federated_communicator.h @@ -90,5 +90,91 @@ class FederatedCommunicator : public Communicator { std::unique_ptr client_{}; }; +class FederatedCommunicatorFactory { + public: + FederatedCommunicatorFactory(int argc, char *argv[]) { + // Parse environment variables first. 
+ for (auto const &env_var : env_vars_) { + char const *value = getenv(env_var.c_str()); + if (value != nullptr) { + SetParam(env_var, value); + } + } + + // Command line argument overrides. + for (int i = 0; i < argc; ++i) { + std::string const key_value = argv[i]; + auto const delimiter = key_value.find('='); + if (delimiter != std::string::npos) { + SetParam(key_value.substr(0, delimiter), key_value.substr(delimiter + 1)); + } + } + } + + Communicator *Create() { + if (server_address_.empty()) { + LOG(FATAL) << "Federated server address must be set."; + } + if (world_size_ == 0) { + LOG(FATAL) << "Federated world size must be set."; + } + if (rank_ == -1) { + LOG(FATAL) << "Federated rank must be set."; + } + if (server_cert_.empty()) { + LOG(FATAL) << "Federated server cert must be set."; + } + if (client_key_.empty()) { + LOG(FATAL) << "Federated client key must be set."; + } + if (client_cert_.empty()) { + LOG(FATAL) << "Federated client cert must be set."; + } + return new FederatedCommunicator(world_size_, rank_, server_address_, server_cert_, client_key_, + client_cert_); + } + + std::string const &GetServerAddress() const { return server_address_; } + int GetWorldSize() const { return world_size_; } + int GetRank() const { return rank_; } + std::string const &GetServerCert() const { return server_cert_; } + std::string const &GetClientKey() const { return client_key_; } + std::string const &GetClientCert() const { return client_cert_; } + + private: + void SetParam(std::string const &name, std::string const &val) { + if (!strcasecmp(name.c_str(), "FEDERATED_SERVER_ADDRESS")) { + server_address_ = val; + } else if (!strcasecmp(name.c_str(), "FEDERATED_WORLD_SIZE")) { + world_size_ = std::stoi(val); + } else if (!strcasecmp(name.c_str(), "FEDERATED_RANK")) { + rank_ = std::stoi(val); + } else if (!strcasecmp(name.c_str(), "FEDERATED_SERVER_CERT")) { + server_cert_ = val; + } else if (!strcasecmp(name.c_str(), "FEDERATED_CLIENT_KEY")) { + client_key_ = val; 
+ } else if (!strcasecmp(name.c_str(), "FEDERATED_CLIENT_CERT")) { + client_cert_ = val; + } + } + + // clang-format off + std::vector const env_vars_{ + "FEDERATED_SERVER_ADDRESS", + "FEDERATED_WORLD_SIZE", + "FEDERATED_RANK", + "FEDERATED_SERVER_CERT", + "FEDERATED_CLIENT_KEY", + "FEDERATED_CLIENT_CERT" }; + // clang-format on + + std::string server_address_{}; + int world_size_{0}; + int rank_{-1}; + std::string server_cert_{}; + std::string client_key_{}; + std::string client_cert_{}; +}; + } // namespace collective } // namespace xgboost diff --git a/src/collective/communicator_factory.h b/src/collective/communicator_factory.h new file mode 100644 index 000000000000..5502058490a1 --- /dev/null +++ b/src/collective/communicator_factory.h @@ -0,0 +1,100 @@ +/*! + * Copyright 2022 XGBoost contributors + */ +#pragma once +#include "communicator.h" + +#if defined(XGBOOST_USE_FEDERATED) +#include "../../plugin/federated/federated_communicator.h" +#endif + +namespace xgboost { +namespace collective { + +enum class CommunicatorType { kUnknown, kRabit, kMPI, kFederated }; + +class CommunicatorFactory { + public: + static constexpr const char* kCommunicatorKey = "XGBOOST_COMMUNICATOR"; + + static void Init(int argc, char* argv[]) { + if (communicator_) { + LOG(FATAL) << "Communicator can only be initialized once."; + } + + auto type = GetTypeFromEnv(); + auto const arg = GetTypeFromArgs(argc, argv); + if (arg != CommunicatorType::kUnknown) { + type = arg; + } + switch (type) { + case CommunicatorType::kRabit: + LOG(FATAL) << "Not implemented yet."; + break; + case CommunicatorType::kMPI: + LOG(FATAL) << "Not implemented yet."; + break; + case CommunicatorType::kFederated: { +#if defined(XGBOOST_USE_FEDERATED) + FederatedCommunicatorFactory factory{argc, argv}; + communicator_.reset(factory.Create()); +#else + LOG(FATAL) << "XGBoost is not compiled with Federated Learning support."; +#endif + break; + } + case CommunicatorType::kUnknown: + LOG(FATAL) << "Unknown 
communicator type."; + break; + } + } + + static void Finalize() { communicator_.reset(); } + + static Communicator* GetCommunicator() { return communicator_.get(); } + + static CommunicatorType GetTypeFromEnv() { + auto* env = std::getenv(kCommunicatorKey); + if (env != nullptr) { + return StringToType(env); + } else { + return CommunicatorType::kUnknown; + } + } + + static CommunicatorType GetTypeFromArgs(int argc, char* argv[]) { + for (int i = 0; i < argc; ++i) { + std::string const key_value = argv[i]; + auto const delimiter = key_value.find('='); + if (delimiter != std::string::npos) { + auto const key = key_value.substr(0, delimiter); + auto const value = key_value.substr(delimiter + 1); + if (!strcasecmp(key.c_str(), kCommunicatorKey)) { + return StringToType(value.c_str()); + } + } + } + + return CommunicatorType::kUnknown; + } + + private: + static CommunicatorType StringToType(char const* str) { + CommunicatorType result = CommunicatorType::kUnknown; + if (!strcasecmp("rabit", str)) { + result = CommunicatorType::kRabit; + } else if (!strcasecmp("mpi", str)) { + result = CommunicatorType::kMPI; + } else if (!strcasecmp("federated", str)) { + result = CommunicatorType::kFederated; + } else { + LOG(FATAL) << "Unknown communicator type " << str; + } + return result; + } + + static thread_local std::unique_ptr communicator_; +}; + +} // namespace collective +} // namespace xgboost diff --git a/tests/cpp/collective/test_communicator_factory.cc b/tests/cpp/collective/test_communicator_factory.cc new file mode 100644 index 000000000000..3381e78a5b7a --- /dev/null +++ b/tests/cpp/collective/test_communicator_factory.cc @@ -0,0 +1,47 @@ +/*! 
+ * Copyright 2022 XGBoost contributors + */ +#include + +#include "../../../src/collective/communicator_factory.h" + +namespace xgboost { +namespace collective { + +TEST(CommunicatorFactory, TypeFromEnv) { + unsetenv(CommunicatorFactory::kCommunicatorKey); + EXPECT_EQ(CommunicatorType::kUnknown, CommunicatorFactory::GetTypeFromEnv()); + + setenv(CommunicatorFactory::kCommunicatorKey, "rabit", 1); + EXPECT_EQ(CommunicatorType::kRabit, CommunicatorFactory::GetTypeFromEnv()); + + setenv(CommunicatorFactory::kCommunicatorKey, "MPI", 1); + EXPECT_EQ(CommunicatorType::kMPI, CommunicatorFactory::GetTypeFromEnv()); + + setenv(CommunicatorFactory::kCommunicatorKey, "Federated", 1); + EXPECT_EQ(CommunicatorType::kFederated, CommunicatorFactory::GetTypeFromEnv()); + + setenv(CommunicatorFactory::kCommunicatorKey, "foo", 1); + EXPECT_THROW(CommunicatorFactory::GetTypeFromEnv(), dmlc::Error); +} + +TEST(CommunicatorFactory, TypeFromArgs) { + char *args[1]; + args[0] = strdup("foo=bar"); + EXPECT_EQ(CommunicatorType::kUnknown, CommunicatorFactory::GetTypeFromArgs(1, args)); + + args[0] = strdup("xgboost_communicator=rabit"); + EXPECT_EQ(CommunicatorType::kRabit, CommunicatorFactory::GetTypeFromArgs(1, args)); + + args[0] = strdup("xgboost_communicator=MPI"); + EXPECT_EQ(CommunicatorType::kMPI, CommunicatorFactory::GetTypeFromArgs(1, args)); + + args[0] = strdup("xgboost_communicator=Federated"); + EXPECT_EQ(CommunicatorType::kFederated, CommunicatorFactory::GetTypeFromArgs(1, args)); + + args[0] = strdup("xgboost_communicator=foo"); + EXPECT_THROW(CommunicatorFactory::GetTypeFromArgs(1, args), dmlc::Error); +} + +} // namespace collective +} // namespace xgboost diff --git a/tests/cpp/plugin/test_federated_communicator.cc b/tests/cpp/plugin/test_federated_communicator.cc index d4aa08ee17ee..6465ccb207f2 100644 --- a/tests/cpp/plugin/test_federated_communicator.cc +++ b/tests/cpp/plugin/test_federated_communicator.cc @@ -45,7 +45,7 @@ class FederatedCommunicatorTest : public 
::testing::Test { static void CheckAllreduce(FederatedCommunicator& comm) { int buffer[] = {1, 2, 3, 4, 5}; - comm.AllReduce(buffer, sizeof(buffer)/sizeof(buffer[0]), DataType::kInt, Operation::kSum); + comm.AllReduce(buffer, sizeof(buffer) / sizeof(buffer[0]), DataType::kInt, Operation::kSum); int expected[] = {3, 6, 9, 12, 15}; for (auto i = 0; i < 5; i++) { EXPECT_EQ(buffer[i], expected[i]); @@ -120,5 +120,89 @@ TEST_F(FederatedCommunicatorTest, Broadcast) { } } +TEST(FederatedCommunicatorFactoryTest, ServerAddress) { + FederatedCommunicatorFactory factory0{0, nullptr}; + EXPECT_EQ(factory0.GetServerAddress(), ""); + + setenv("FEDERATED_SERVER_ADDRESS", "localhost:9091", 1); + FederatedCommunicatorFactory factory1{0, nullptr}; + EXPECT_EQ(factory1.GetServerAddress(), "localhost:9091"); + + char *args[1]; + args[0] = strdup("federated_server_address=foo:9091"); + FederatedCommunicatorFactory factory2{1, args}; + EXPECT_EQ(factory2.GetServerAddress(), "foo:9091"); +} + +TEST(FederatedCommunicatorFactoryTest, WorldSize) { + FederatedCommunicatorFactory factory0{0, nullptr}; + EXPECT_EQ(factory0.GetWorldSize(), 0); + + setenv("FEDERATED_WORLD_SIZE", "2", 1); + FederatedCommunicatorFactory factory1{0, nullptr}; + EXPECT_EQ(factory1.GetWorldSize(), 2); + + char *args[1]; + args[0] = strdup("federated_world_size=3"); + FederatedCommunicatorFactory factory2{1, args}; + EXPECT_EQ(factory2.GetWorldSize(), 3); +} + +TEST(FederatedCommunicatorFactoryTest, Rank) { + FederatedCommunicatorFactory factory0{0, nullptr}; + EXPECT_EQ(factory0.GetRank(), -1); + + setenv("FEDERATED_RANK", "1", 1); + FederatedCommunicatorFactory factory1{0, nullptr}; + EXPECT_EQ(factory1.GetRank(), 1); + + char *args[1]; + args[0] = strdup("federated_rank=2"); + FederatedCommunicatorFactory factory2{1, args}; + EXPECT_EQ(factory2.GetRank(), 2); +} + +TEST(FederatedCommunicatorFactoryTest, ServerCert) { + FederatedCommunicatorFactory factory0{0, nullptr}; + EXPECT_EQ(factory0.GetServerCert(), ""); + + 
setenv("FEDERATED_SERVER_CERT", "foo/server-cert.pem", 1); + FederatedCommunicatorFactory factory1{0, nullptr}; + EXPECT_EQ(factory1.GetServerCert(), "foo/server-cert.pem"); + + char *args[1]; + args[0] = strdup("federated_server_cert=bar/server-cert.pem"); + FederatedCommunicatorFactory factory2{1, args}; + EXPECT_EQ(factory2.GetServerCert(), "bar/server-cert.pem"); +} + +TEST(FederatedCommunicatorFactoryTest, ClientKey) { + FederatedCommunicatorFactory factory0{0, nullptr}; + EXPECT_EQ(factory0.GetClientKey(), ""); + + setenv("FEDERATED_CLIENT_KEY", "foo/client-key.pem", 1); + FederatedCommunicatorFactory factory1{0, nullptr}; + EXPECT_EQ(factory1.GetClientKey(), "foo/client-key.pem"); + + char *args[1]; + args[0] = strdup("federated_client_key=bar/client-key.pem"); + FederatedCommunicatorFactory factory2{1, args}; + EXPECT_EQ(factory2.GetClientKey(), "bar/client-key.pem"); +} + +TEST(FederatedCommunicatorFactoryTest, ClientCert) { + FederatedCommunicatorFactory factory0{0, nullptr}; + EXPECT_EQ(factory0.GetClientCert(), ""); + + setenv("FEDERATED_CLIENT_CERT", "foo/client-cert.pem", 1); + FederatedCommunicatorFactory factory1{0, nullptr}; + EXPECT_EQ(factory1.GetClientCert(), "foo/client-cert.pem"); + + char *args[1]; + args[0] = strdup("federated_client_cert=bar/client-cert.pem"); + FederatedCommunicatorFactory factory2{1, args}; + EXPECT_EQ(factory2.GetClientCert(), "bar/client-cert.pem"); +} + } // namespace collective } // namespace xgboost diff --git a/tests/distributed/test_federated.py b/tests/distributed/test_federated.py index cddd104e922c..4b3504d636d2 100644 --- a/tests/distributed/test_federated.py +++ b/tests/distributed/test_federated.py @@ -20,6 +20,7 @@ def run_server(port: int, world_size: int) -> None: def run_worker(port: int, world_size: int, rank: int, with_gpu: bool) -> None: # Always call this before using distributed module rabit_env = [ + 'xgboost_communicator=federated', f'federated_server_address=localhost:{port}', 
f'federated_world_size={world_size}', f'federated_rank={rank}', From b0831a0d2bd9c9fbdc23a4533024317d1f74cd63 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Tue, 12 Jul 2022 11:37:14 -0700 Subject: [PATCH 04/37] add device adapter --- src/collective/communicator.h | 5 +- src/collective/communicator_factory.h | 2 +- src/collective/device_communicator.cuh | 24 ++++ .../device_communicator_adapter.cuh | 72 ++++++++++++ tests/cpp/plugin/test_federated_adapter.cu | 106 ++++++++++++++++++ 5 files changed, 207 insertions(+), 2 deletions(-) create mode 100644 src/collective/device_communicator.cuh create mode 100644 src/collective/device_communicator_adapter.cuh create mode 100644 tests/cpp/plugin/test_federated_adapter.cu diff --git a/src/collective/communicator.h b/src/collective/communicator.h index 81e8d9ef10d3..c08e886b5e71 100644 --- a/src/collective/communicator.h +++ b/src/collective/communicator.h @@ -8,7 +8,7 @@ namespace xgboost { namespace collective { /** @brief Defines the integral and floating data types. 
*/ -enum class DataType { kInt, kFloat, kDouble }; +enum class DataType { kInt, kFloat, kDouble, kSizeT }; inline std::size_t GetTypeSize(DataType data_type) { std::size_t size{0}; @@ -22,6 +22,9 @@ inline std::size_t GetTypeSize(DataType data_type) { case DataType::kDouble: size = sizeof(double); break; + case DataType::kSizeT: + size = sizeof(std::size_t); + break; } return size; } diff --git a/src/collective/communicator_factory.h b/src/collective/communicator_factory.h index 5502058490a1..6345276cd5a2 100644 --- a/src/collective/communicator_factory.h +++ b/src/collective/communicator_factory.h @@ -15,7 +15,7 @@ enum class CommunicatorType { kUnknown, kRabit, kMPI, kFederated }; class CommunicatorFactory { public: - static constexpr const char* kCommunicatorKey = "XGBOOST_COMMUNICATOR"; + static constexpr char const* kCommunicatorKey = "XGBOOST_COMMUNICATOR"; static void Init(int argc, char* argv[]) { if (communicator_) { diff --git a/src/collective/device_communicator.cuh b/src/collective/device_communicator.cuh new file mode 100644 index 000000000000..af47a29e7f6a --- /dev/null +++ b/src/collective/device_communicator.cuh @@ -0,0 +1,24 @@ +/*! + * Copyright 2022 XGBoost contributors + */ +#pragma once +#include + +#include "../common/device_helpers.cuh" + +namespace xgboost { +namespace collective { + +class DeviceCommunicator { + public: + virtual ~DeviceCommunicator() = default; + + virtual void DeviceAllReduceSum(double *send_receive_buffer, int count) = 0; + + virtual void DeviceAllGatherV(void const *send_buffer, size_t length_bytes, + std::vector *segments, + dh::caching_device_vector *receive_buffer) = 0; +}; + +} // namespace collective +} // namespace xgboost diff --git a/src/collective/device_communicator_adapter.cuh b/src/collective/device_communicator_adapter.cuh new file mode 100644 index 000000000000..1fdb9e786830 --- /dev/null +++ b/src/collective/device_communicator_adapter.cuh @@ -0,0 +1,72 @@ +/*! 
+ * Copyright 2022 XGBoost contributors + */ +#pragma once + +#include "communicator.h" +#include "device_communicator.cuh" + +namespace xgboost { +namespace collective { + +class DeviceCommunicatorAdapter : public DeviceCommunicator { + public: + DeviceCommunicatorAdapter(int device_ordinal, Communicator *communicator) + : device_ordinal_{device_ordinal}, communicator_{communicator} { + if (device_ordinal_ < 0) { + LOG(FATAL) << "Invalid device ordinal: " << device_ordinal_; + } + if (communicator_ == nullptr) { + LOG(FATAL) << "Communicator cannot be null."; + } + } + + ~DeviceCommunicatorAdapter() override = default; + + void DeviceAllReduceSum(double *send_receive_buffer, int count) override { + dh::safe_cuda(cudaSetDevice(device_ordinal_)); + auto size = count * sizeof(double); + host_buffer_.reserve(size); + dh::safe_cuda(cudaMemcpy(host_buffer_.data(), send_receive_buffer, size, cudaMemcpyDefault)); + communicator_->AllReduce(host_buffer_.data(), count, DataType::kDouble, Operation::kSum); + dh::safe_cuda(cudaMemcpy(send_receive_buffer, host_buffer_.data(), size, cudaMemcpyDefault)); + } + + void DeviceAllGatherV(void const *send_buffer, size_t length_bytes, + std::vector *segments, + dh::caching_device_vector *receive_buffer) override { + dh::safe_cuda(cudaSetDevice(device_ordinal_)); + int const world_size = communicator_->GetWorldSize(); + int const rank = communicator_->GetRank(); + + segments->clear(); + segments->resize(world_size, 0); + segments->at(rank) = length_bytes; + communicator_->AllReduce(segments->data(), segments->size(), DataType::kSizeT, Operation::kMax); + auto total_bytes = std::accumulate(segments->cbegin(), segments->cend(), 0UL); + receive_buffer->resize(total_bytes); + + host_buffer_.reserve(total_bytes); + size_t offset = 0; + for (int32_t i = 0; i < world_size; ++i) { + size_t as_bytes = segments->at(i); + if (i == rank) { + dh::safe_cuda(cudaMemcpy(host_buffer_.data() + offset, send_buffer, + segments->at(rank), 
cudaMemcpyDefault)); + } + communicator_->Broadcast(host_buffer_.data() + offset, as_bytes, i); + offset += as_bytes; + } + dh::safe_cuda(cudaMemcpy(receive_buffer->data().get(), host_buffer_.data(), total_bytes, + cudaMemcpyDefault)); + } + + private: + int const device_ordinal_; + Communicator *communicator_; + /// Host buffer used to call communicator functions. + std::vector host_buffer_{}; +}; + +} // namespace collective +} // namespace xgboost diff --git a/tests/cpp/plugin/test_federated_adapter.cu b/tests/cpp/plugin/test_federated_adapter.cu new file mode 100644 index 000000000000..da0c82844c4f --- /dev/null +++ b/tests/cpp/plugin/test_federated_adapter.cu @@ -0,0 +1,106 @@ +/*! + * Copyright 2022 XGBoost contributors + */ +#include +#include +#include + +#include + +#include "../../../plugin/federated/federated_communicator.h" +#include "../../../plugin/federated/federated_server.h" +#include "../../../src/collective/device_communicator_adapter.cuh" + +namespace xgboost { +namespace collective { + +std::string const kServerAddress{"localhost:56789"}; // NOLINT(cert-err58-cpp) + +class FederatedAdapterTest : public ::testing::Test { + protected: + void SetUp() override { + server_thread_.reset(new std::thread([this] { + grpc::ServerBuilder builder; + federated::FederatedService service{kWorldSize}; + builder.AddListeningPort(kServerAddress, grpc::InsecureServerCredentials()); + builder.RegisterService(&service); + server_ = builder.BuildAndStart(); + server_->Wait(); + })); + } + + void TearDown() override { + server_->Shutdown(); + server_thread_->join(); + } + + static int const kWorldSize{2}; + std::unique_ptr server_thread_; + std::unique_ptr server_; +}; + +TEST(FederatedAdapterSimpleTest, ThrowOnInvalidDeviceOrdinal) { + auto construct = []() { DeviceCommunicatorAdapter adapter{-1, nullptr}; }; + EXPECT_THROW(construct(), dmlc::Error); +} + +TEST(FederatedAdapterSimpleTest, ThrowOnInvalidCommunicator) { + auto construct = []() { 
DeviceCommunicatorAdapter adapter{0, nullptr}; }; + EXPECT_THROW(construct(), dmlc::Error); +} + +TEST_F(FederatedAdapterTest, DeviceAllReduceSum) { + std::vector threads; + for (auto rank = 0; rank < kWorldSize; rank++) { + threads.emplace_back(std::thread([rank] { + FederatedCommunicator comm{kWorldSize, rank, kServerAddress}; + DeviceCommunicatorAdapter adapter{rank, &comm}; + int const count = 3; + thrust::device_vector buffer(count, 0); + thrust::sequence(buffer.begin(), buffer.end()); + adapter.DeviceAllReduceSum(buffer.data().get(), count); + thrust::host_vector host_buffer = buffer; + EXPECT_EQ(host_buffer.size(), count); + for (auto i = 0; i < count; i++) { + EXPECT_EQ(host_buffer[i], i * 2); + } + })); + } + for (auto& thread : threads) { + thread.join(); + } +} + +TEST_F(FederatedAdapterTest, DeviceAllGatherV) { + std::vector threads; + for (auto rank = 0; rank < kWorldSize; rank++) { + threads.emplace_back(std::thread([rank] { + FederatedCommunicator comm{kWorldSize, rank, kServerAddress}; + DeviceCommunicatorAdapter adapter{rank, &comm}; + + int const count = rank + 2; + thrust::device_vector buffer(count, 0); + thrust::sequence(buffer.begin(), buffer.end()); + std::vector segments(kWorldSize); + dh::caching_device_vector receive_buffer{}; + + adapter.DeviceAllGatherV(buffer.data().get(), count, &segments, + &receive_buffer); + + EXPECT_EQ(segments[0], 2); + EXPECT_EQ(segments[1], 3); + thrust::host_vector host_buffer = receive_buffer; + EXPECT_EQ(host_buffer.size(), 5); + int expected[] = {0, 1, 0, 1, 2}; + for (auto i = 0; i < 5; i++) { + EXPECT_EQ(host_buffer[i], expected[i]); + } + })); + } + for (auto& thread : threads) { + thread.join(); + } +} + +} // namespace collective +} // namespace xgboost From 05b2cebd3a60867114da4c1577d97384989bc98a Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Tue, 12 Jul 2022 17:02:34 -0700 Subject: [PATCH 05/37] add device communicator to factory --- src/collective/communicator_factory.cu | 68 +++++++++++++++++++ 
src/collective/communicator_factory.h | 51 +++++--------- .../device_communicator_adapter.cuh | 4 +- 3 files changed, 87 insertions(+), 36 deletions(-) create mode 100644 src/collective/communicator_factory.cu diff --git a/src/collective/communicator_factory.cu b/src/collective/communicator_factory.cu new file mode 100644 index 000000000000..a6de8d4b8e61 --- /dev/null +++ b/src/collective/communicator_factory.cu @@ -0,0 +1,68 @@ +/*! + * Copyright 2022 XGBoost contributors + */ +#include "communicator_factory.h" +#include "device_communicator_adapter.cuh" + +namespace xgboost { +namespace collective { + +thread_local std::unique_ptr CommunicatorFactory::instance_{}; + +void CommunicatorFactory::Init(int argc, char* argv[]) { + if (instance_) { + LOG(FATAL) << "Communicator factory can only be initialized once."; + } + + auto type = GetTypeFromEnv(); + auto const arg = GetTypeFromArgs(argc, argv); + if (arg != CommunicatorType::kUnknown) { + type = arg; + } + switch (type) { + case CommunicatorType::kRabit: + LOG(FATAL) << "Not implemented yet."; + break; + case CommunicatorType::kMPI: + LOG(FATAL) << "Not implemented yet."; + break; + case CommunicatorType::kFederated: { +#if defined(XGBOOST_USE_FEDERATED) + FederatedCommunicatorFactory factory{argc, argv}; + auto* comm = factory.Create(); + instance_.reset(new CommunicatorFactory(type, comm)); +#else + LOG(FATAL) << "XGBoost is not compiled with Federated Learning support."; +#endif + break; + } + case CommunicatorType::kUnknown: + LOG(FATAL) << "Unknown communicator type."; + break; + } +} + +void CommunicatorFactory::Finalize() { instance_.reset(); } + +CommunicatorFactory::CommunicatorFactory(CommunicatorType type, Communicator* communicator) + : type_{type}, communicator_{communicator} {} + +DeviceCommunicator* CommunicatorFactory::GetDeviceCommunicator(int device_ordinal) { + if (!device_communicator_) { +#ifdef XGBOOST_USE_NCCL + if (type_ != CommunicatorType::kFederated) { + // Use NCCL communicator. 
+ LOG(FATAL) << "Not implemented yet."; + } else { + device_communicator_.reset( + new DeviceCommunicatorAdapter(device_ordinal, communicator_.get())); + } +#else + device_communicator_.reset(new DeviceCommunicatorAdapter(device_ordinal, communicator_.get())); +#endif + } + return device_communicator_.get(); +} + +} // namespace collective +} // namespace xgboost diff --git a/src/collective/communicator_factory.h b/src/collective/communicator_factory.h index 6345276cd5a2..882f54d36891 100644 --- a/src/collective/communicator_factory.h +++ b/src/collective/communicator_factory.h @@ -13,46 +13,23 @@ namespace collective { enum class CommunicatorType { kUnknown, kRabit, kMPI, kFederated }; +class DeviceCommunicator; + class CommunicatorFactory { public: static constexpr char const* kCommunicatorKey = "XGBOOST_COMMUNICATOR"; - static void Init(int argc, char* argv[]) { - if (communicator_) { - LOG(FATAL) << "Communicator can only be initialized once."; - } + static void Init(int argc, char* argv[]); - auto type = GetTypeFromEnv(); - auto const arg = GetTypeFromArgs(argc, argv); - if (arg != CommunicatorType::kUnknown) { - type = arg; - } - switch (type) { - case CommunicatorType::kRabit: - LOG(FATAL) << "Not implemented yet."; - break; - case CommunicatorType::kMPI: - LOG(FATAL) << "Not implemented yet."; - break; - case CommunicatorType::kFederated: { -#if defined(XGBOOST_USE_FEDERATED) - FederatedCommunicatorFactory factory{argc, argv}; - communicator_.reset(factory.Create()); -#else - LOG(FATAL) << "XGBoost is not compiled with Federated Learning support."; -#endif - break; - } - case CommunicatorType::kUnknown: - LOG(FATAL) << "Unknown communicator type."; - break; - } - } + static void Finalize(); + + static CommunicatorFactory* GetInstance() { return instance_.get(); } - static void Finalize() { communicator_.reset(); } + Communicator* GetCommunicator() { return communicator_.get(); } - static Communicator* GetCommunicator() { return communicator_.get(); } + 
DeviceCommunicator* GetDeviceCommunicator(int device_ordinal); + /** @brief Get the communicator type from environment variables. Visible for testing. */ static CommunicatorType GetTypeFromEnv() { auto* env = std::getenv(kCommunicatorKey); if (env != nullptr) { @@ -62,6 +39,7 @@ class CommunicatorFactory { } } + /** @brief Get the communicator type from arguments. Visible for testing. */ static CommunicatorType GetTypeFromArgs(int argc, char* argv[]) { for (int i = 0; i < argc; ++i) { std::string const key_value = argv[i]; @@ -74,10 +52,12 @@ class CommunicatorFactory { } } } - return CommunicatorType::kUnknown; } + private: + CommunicatorFactory(CommunicatorType type, Communicator* communicator); + private: static CommunicatorType StringToType(char const* str) { CommunicatorType result = CommunicatorType::kUnknown; @@ -93,7 +73,10 @@ class CommunicatorFactory { return result; } - static thread_local std::unique_ptr communicator_; + static thread_local std::unique_ptr instance_; + CommunicatorType type_; + std::unique_ptr communicator_; + std::unique_ptr device_communicator_; }; } // namespace collective diff --git a/src/collective/device_communicator_adapter.cuh b/src/collective/device_communicator_adapter.cuh index 1fdb9e786830..e884b9114746 100644 --- a/src/collective/device_communicator_adapter.cuh +++ b/src/collective/device_communicator_adapter.cuh @@ -51,8 +51,8 @@ class DeviceCommunicatorAdapter : public DeviceCommunicator { for (int32_t i = 0; i < world_size; ++i) { size_t as_bytes = segments->at(i); if (i == rank) { - dh::safe_cuda(cudaMemcpy(host_buffer_.data() + offset, send_buffer, - segments->at(rank), cudaMemcpyDefault)); + dh::safe_cuda(cudaMemcpy(host_buffer_.data() + offset, send_buffer, segments->at(rank), + cudaMemcpyDefault)); } communicator_->Broadcast(host_buffer_.data() + offset, as_bytes, i); offset += as_bytes; From f39d0a9f6c3a2aeaa899b1b31b045bf82ee48ad5 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Fri, 15 Jul 2022 18:09:55 -0700 
Subject: [PATCH 06/37] add rabit communicator --- src/collective/rabit_communicator.h | 74 +++++++++++++++++++ tests/cpp/CMakeLists.txt | 2 +- .../cpp/collective/test_rabit_communicator.cc | 45 +++++++++++ 3 files changed, 120 insertions(+), 1 deletion(-) create mode 100644 src/collective/rabit_communicator.h create mode 100644 tests/cpp/collective/test_rabit_communicator.cc diff --git a/src/collective/rabit_communicator.h b/src/collective/rabit_communicator.h new file mode 100644 index 000000000000..193a5423251e --- /dev/null +++ b/src/collective/rabit_communicator.h @@ -0,0 +1,74 @@ +/*! + * Copyright 2022 XGBoost contributors + */ +#pragma once +#include + +#include "communicator.h" + +namespace xgboost { +namespace collective { + +class RabitCommunicator : public Communicator { + public: + RabitCommunicator(int world_size, int rank) : Communicator(world_size, rank) {} + + ~RabitCommunicator() override { rabit::Finalize(); } + + void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type, + Operation op) override { + switch (data_type) { + case DataType::kInt: + DoAllReduce(send_receive_buffer, count, op); + break; + case DataType::kFloat: + DoAllReduce(send_receive_buffer, count, op); + break; + case DataType::kDouble: + DoAllReduce(send_receive_buffer, count, op); + break; + case DataType::kSizeT: + DoAllReduce(send_receive_buffer, count, op); + break; + default: + LOG(FATAL) << "Unknown data type"; + } + } + + void Broadcast(void *send_receive_buffer, std::size_t size, int root) override { + rabit::Broadcast(send_receive_buffer, size, root); + } + + private: + template + void DoAllReduce(void *send_receive_buffer, std::size_t count, Operation op) { + switch (op) { + case Operation::kMax: + rabit::Allreduce(static_cast(send_receive_buffer), count); + break; + case Operation::kSum: + rabit::Allreduce(static_cast(send_receive_buffer), count); + break; + default: + LOG(FATAL) << "Unknown allreduce operation"; + } + } +}; + +class 
RabitCommunicatorFactory { + public: + RabitCommunicatorFactory(int argc, char *argv[]) { + rabit::Init(argc, argv); + world_size_ = rabit::GetWorldSize(); + rank_ = rabit::GetRank(); + } + + Communicator *Create() const { return new RabitCommunicator(world_size_, rank_); } + + private: + int world_size_; + int rank_; +}; + +} // namespace collective +} // namespace xgboost diff --git a/tests/cpp/CMakeLists.txt b/tests/cpp/CMakeLists.txt index 9dfd429d01e2..51cdecd9d4be 100644 --- a/tests/cpp/CMakeLists.txt +++ b/tests/cpp/CMakeLists.txt @@ -22,7 +22,7 @@ if (PLUGIN_FEDERATED) target_include_directories(testxgboost PRIVATE ${xgboost_SOURCE_DIR}/plugin/federated) target_link_libraries(testxgboost PRIVATE federated_client) else (PLUGIN_FEDERATED) - file(GLOB_RECURSE FEDERATED_TEST_SOURCES "plugin/*_federated_*.cc") + file(GLOB_RECURSE FEDERATED_TEST_SOURCES "plugin/*_federated_*.*") list(REMOVE_ITEM TEST_SOURCES ${FEDERATED_TEST_SOURCES}) endif (PLUGIN_FEDERATED) diff --git a/tests/cpp/collective/test_rabit_communicator.cc b/tests/cpp/collective/test_rabit_communicator.cc new file mode 100644 index 000000000000..cd605e090d0f --- /dev/null +++ b/tests/cpp/collective/test_rabit_communicator.cc @@ -0,0 +1,45 @@ +/*! 
+ * Copyright 2022 XGBoost contributors + */ +#include + +#include + +#include "../../../src/collective/rabit_communicator.h" + +namespace xgboost { +namespace collective { + +TEST(RabitCommunicatorSimpleTest, ThrowOnWorldSizeTooSmall) { + auto construct = []() { RabitCommunicator comm{0, 0}; }; + EXPECT_THROW(construct(), dmlc::Error); +} + +TEST(RabitCommunicatorSimpleTest, ThrowOnRankTooSmall) { + auto construct = []() { RabitCommunicator comm{1, -1}; }; + EXPECT_THROW(construct(), dmlc::Error); +} + +TEST(RabitCommunicatorSimpleTest, ThrowOnRankTooBig) { + auto construct = []() { RabitCommunicator comm{1, 1}; }; + EXPECT_THROW(construct(), dmlc::Error); +} + +TEST(RabitCommunicatorSimpleTest, GetWorldSizeAndRank) { + RabitCommunicator comm{6, 3}; + EXPECT_EQ(comm.GetWorldSize(), 6); + EXPECT_EQ(comm.GetRank(), 3); +} + +TEST(RabitCommunicatorSimpleTest, IsNotDistributed) { + RabitCommunicator comm{1, 0}; + EXPECT_FALSE(comm.IsDistributed()); +} + +TEST(RabitCommunicatorSimpleTest, IsDistributed) { + RabitCommunicator comm{2, 1}; + EXPECT_TRUE(comm.IsDistributed()); +} + +} // namespace collective +} // namespace xgboost From cd0098d0df917556653e310f0675eeb03124942f Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Mon, 18 Jul 2022 14:43:08 -0700 Subject: [PATCH 07/37] add rabit communicator to the factory --- src/collective/communicator_factory.cu | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/collective/communicator_factory.cu b/src/collective/communicator_factory.cu index a6de8d4b8e61..1c301a2a3492 100644 --- a/src/collective/communicator_factory.cu +++ b/src/collective/communicator_factory.cu @@ -1,8 +1,11 @@ /*! 
* Copyright 2022 XGBoost contributors */ +#include + #include "communicator_factory.h" #include "device_communicator_adapter.cuh" +#include "rabit_communicator.h" namespace xgboost { namespace collective { @@ -20,9 +23,12 @@ void CommunicatorFactory::Init(int argc, char* argv[]) { type = arg; } switch (type) { - case CommunicatorType::kRabit: - LOG(FATAL) << "Not implemented yet."; + case CommunicatorType::kRabit: { + RabitCommunicatorFactory factory{argc, argv}; + auto* comm = factory.Create(); + instance_.reset(new CommunicatorFactory(type, comm)); break; + } case CommunicatorType::kMPI: LOG(FATAL) << "Not implemented yet."; break; From 198ac9434e16c4526218c0bf75623e91adedb10f Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Mon, 18 Jul 2022 16:25:51 -0700 Subject: [PATCH 08/37] add nccl device communicator --- src/collective/communicator.h | 7 +- src/collective/communicator_factory.cu | 4 +- src/collective/nccl_device_communicator.cuh | 144 ++++++++++++++++++ .../test_nccl_device_communicator.cu | 22 +++ .../cpp/collective/test_rabit_communicator.cc | 2 - 5 files changed, 174 insertions(+), 5 deletions(-) create mode 100644 src/collective/nccl_device_communicator.cuh create mode 100644 tests/cpp/collective/test_nccl_device_communicator.cu diff --git a/src/collective/communicator.h b/src/collective/communicator.h index c08e886b5e71..4324ef9ab318 100644 --- a/src/collective/communicator.h +++ b/src/collective/communicator.h @@ -8,7 +8,7 @@ namespace xgboost { namespace collective { /** @brief Defines the integral and floating data types. 
*/ -enum class DataType { kInt, kFloat, kDouble, kSizeT }; +enum class DataType { kInt, kFloat, kDouble, kSizeT, kUInt64 }; inline std::size_t GetTypeSize(DataType data_type) { std::size_t size{0}; @@ -25,6 +25,11 @@ inline std::size_t GetTypeSize(DataType data_type) { case DataType::kSizeT: size = sizeof(std::size_t); break; + case DataType::kUInt64: + size = sizeof(uint64_t); + break; + default: + LOG(FATAL) << "Unknown data type."; } return size; } diff --git a/src/collective/communicator_factory.cu b/src/collective/communicator_factory.cu index 1c301a2a3492..3efaed4fefbd 100644 --- a/src/collective/communicator_factory.cu +++ b/src/collective/communicator_factory.cu @@ -6,6 +6,7 @@ #include "communicator_factory.h" #include "device_communicator_adapter.cuh" #include "rabit_communicator.h" +#include "nccl_device_communicator.cuh" namespace xgboost { namespace collective { @@ -57,8 +58,7 @@ DeviceCommunicator* CommunicatorFactory::GetDeviceCommunicator(int device_ordina if (!device_communicator_) { #ifdef XGBOOST_USE_NCCL if (type_ != CommunicatorType::kFederated) { - // Use NCCL communicator. - LOG(FATAL) << "Not implemented yet."; + device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal, communicator_.get())); } else { device_communicator_.reset( new DeviceCommunicatorAdapter(device_ordinal, communicator_.get())); diff --git a/src/collective/nccl_device_communicator.cuh b/src/collective/nccl_device_communicator.cuh new file mode 100644 index 000000000000..bcb153a153c9 --- /dev/null +++ b/src/collective/nccl_device_communicator.cuh @@ -0,0 +1,144 @@ +/*! 
+ * Copyright 2022 XGBoost contributors + */ +#pragma once + +#include "../common/device_helpers.cuh" +#include "communicator.h" +#include "device_communicator.cuh" + +namespace xgboost { +namespace collective { + +class NcclDeviceCommunicator : public DeviceCommunicator { + public: + NcclDeviceCommunicator(int device_ordinal, Communicator *communicator) + : device_ordinal_{device_ordinal}, communicator_{communicator} { + if (device_ordinal_ < 0) { + LOG(FATAL) << "Invalid device ordinal: " << device_ordinal_; + } + if (communicator_ == nullptr) { + LOG(FATAL) << "Communicator cannot be null."; + } + + int32_t const rank = communicator_->GetRank(); + int32_t const world = communicator_->GetWorldSize(); + + std::vector uuids(world * kUuidLength, 0); + auto s_uuid = xgboost::common::Span{uuids.data(), uuids.size()}; + auto s_this_uuid = s_uuid.subspan(rank * kUuidLength, kUuidLength); + GetCudaUUID(s_this_uuid); + + // TODO(rongou): replace this with allgather. + communicator_->AllReduce(uuids.data(), uuids.size(), DataType::kUInt64, Operation::kSum); + + std::vector> converted(world); + size_t j = 0; + for (size_t i = 0; i < uuids.size(); i += kUuidLength) { + converted[j] = xgboost::common::Span{uuids.data() + i, kUuidLength}; + j++; + } + + auto iter = std::unique(converted.begin(), converted.end()); + auto n_uniques = std::distance(converted.begin(), iter); + + CHECK_EQ(n_uniques, world) + << "Multiple processes within communication group running on same CUDA " + << "device is not supported. 
" << PrintUUID(s_this_uuid) << "\n"; + + nccl_unique_id_ = GetUniqueId(); + dh::safe_nccl(ncclCommInitRank(&nccl_comm_, world, nccl_unique_id_, rank)); + dh::safe_cuda(cudaStreamCreate(&cuda_stream_)); + } + + ~NcclDeviceCommunicator() override { + dh::safe_cuda(cudaStreamDestroy(cuda_stream_)); + ncclCommDestroy(nccl_comm_); + if (xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug)) { + LOG(CONSOLE) << "======== NCCL Statistics========"; + LOG(CONSOLE) << "AllReduce calls: " << allreduce_calls_; + LOG(CONSOLE) << "AllReduce total MiB communicated: " << allreduce_bytes_ / 1048576; + } + } + + void DeviceAllReduceSum(double *send_receive_buffer, int count) override { + dh::safe_cuda(cudaSetDevice(device_ordinal_)); + dh::safe_nccl(ncclAllReduce(send_receive_buffer, send_receive_buffer, count, ncclDouble, + ncclSum, nccl_comm_, cuda_stream_)); + allreduce_bytes_ += count * sizeof(double); + allreduce_calls_ += 1; + } + + void DeviceAllGatherV(void const *send_buffer, size_t length_bytes, + std::vector *segments, + dh::caching_device_vector *receive_buffer) override { + dh::safe_cuda(cudaSetDevice(device_ordinal_)); + int const world_size = communicator_->GetWorldSize(); + int const rank = communicator_->GetRank(); + + segments->clear(); + segments->resize(world_size, 0); + segments->at(rank) = length_bytes; + communicator_->AllReduce(segments->data(), segments->size(), DataType::kSizeT, Operation::kMax); + auto total_bytes = std::accumulate(segments->cbegin(), segments->cend(), 0UL); + receive_buffer->resize(total_bytes); + + size_t offset = 0; + dh::safe_nccl(ncclGroupStart()); + for (int32_t i = 0; i < world_size; ++i) { + size_t as_bytes = segments->at(i); + dh::safe_nccl(ncclBroadcast(send_buffer, receive_buffer->data().get() + offset, as_bytes, + ncclChar, i, nccl_comm_, cuda_stream_)); + offset += as_bytes; + } + dh::safe_nccl(ncclGroupEnd()); + } + + private: + static constexpr std::size_t kUuidLength = + sizeof(std::declval().uuid) / 
sizeof(uint64_t); + + void GetCudaUUID(xgboost::common::Span const &uuid) const { + cudaDeviceProp prob{}; + dh::safe_cuda(cudaGetDeviceProperties(&prob, device_ordinal_)); + std::memcpy(uuid.data(), static_cast(&(prob.uuid)), sizeof(prob.uuid)); + } + + static std::string PrintUUID(xgboost::common::Span const &uuid) { + std::stringstream ss; + for (auto v : uuid) { + ss << std::hex << v; + } + return ss.str(); + } + + /** + * \fn ncclUniqueId GetUniqueId() + * + * \brief Gets the Unique ID from NCCL to be used in setting up interprocess + * communication + * + * \return the Unique ID + */ + ncclUniqueId GetUniqueId() { + static const int kRootRank = 0; + ncclUniqueId id; + if (communicator_->GetRank() == kRootRank) { + dh::safe_nccl(ncclGetUniqueId(&id)); + } + communicator_->Broadcast(static_cast(&id), sizeof(ncclUniqueId), + static_cast(kRootRank)); + return id; + } + + int const device_ordinal_; + Communicator *communicator_; + ncclComm_t nccl_comm_{}; + cudaStream_t cuda_stream_{}; + ncclUniqueId nccl_unique_id_{}; + size_t allreduce_bytes_{0}; // Keep statistics of the number of bytes communicated. + size_t allreduce_calls_{0}; // Keep statistics of the number of reduce calls. +}; + +} // namespace collective +} // namespace xgboost diff --git a/tests/cpp/collective/test_nccl_device_communicator.cu b/tests/cpp/collective/test_nccl_device_communicator.cu new file mode 100644 index 000000000000..168d6f60b844 --- /dev/null +++ b/tests/cpp/collective/test_nccl_device_communicator.cu @@ -0,0 +1,22 @@ +/*! 
+ * Copyright 2022 XGBoost contributors + */ +#include + +#include "../../../src/collective/nccl_device_communicator.cuh" + +namespace xgboost { +namespace collective { + +TEST(NcclDeviceCommunicatorSimpleTest, ThrowOnInvalidDeviceOrdinal) { + auto construct = []() { NcclDeviceCommunicator comm{-1, nullptr}; }; + EXPECT_THROW(construct(), dmlc::Error); +} + +TEST(NcclDeviceCommunicatorSimpleTest, ThrowOnInvalidCommunicator) { + auto construct = []() { NcclDeviceCommunicator comm{0, nullptr}; }; + EXPECT_THROW(construct(), dmlc::Error); +} + +} // namespace collective +} // namespace xgboost diff --git a/tests/cpp/collective/test_rabit_communicator.cc b/tests/cpp/collective/test_rabit_communicator.cc index cd605e090d0f..2b7dc17edfcb 100644 --- a/tests/cpp/collective/test_rabit_communicator.cc +++ b/tests/cpp/collective/test_rabit_communicator.cc @@ -3,8 +3,6 @@ */ #include -#include - #include "../../../src/collective/rabit_communicator.h" namespace xgboost { From 8ae0d7af9909c86b6a363f982cd942ae61a63139 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Tue, 19 Jul 2022 10:14:11 -0700 Subject: [PATCH 09/37] add synchronize to device communicator --- src/collective/communicator_factory.cu | 4 +--- src/collective/device_communicator.cuh | 10 ++++++---- src/collective/device_communicator_adapter.cuh | 11 +++++++---- src/collective/nccl_device_communicator.cuh | 12 ++++++++---- tests/cpp/plugin/test_federated_adapter.cu | 5 ++--- 5 files changed, 24 insertions(+), 18 deletions(-) diff --git a/src/collective/communicator_factory.cu b/src/collective/communicator_factory.cu index 3efaed4fefbd..8ba388342a86 100644 --- a/src/collective/communicator_factory.cu +++ b/src/collective/communicator_factory.cu @@ -1,12 +1,10 @@ /*! 
* Copyright 2022 XGBoost contributors */ -#include - #include "communicator_factory.h" #include "device_communicator_adapter.cuh" -#include "rabit_communicator.h" #include "nccl_device_communicator.cuh" +#include "rabit_communicator.h" namespace xgboost { namespace collective { diff --git a/src/collective/device_communicator.cuh b/src/collective/device_communicator.cuh index af47a29e7f6a..9efb978897ca 100644 --- a/src/collective/device_communicator.cuh +++ b/src/collective/device_communicator.cuh @@ -13,11 +13,13 @@ class DeviceCommunicator { public: virtual ~DeviceCommunicator() = default; - virtual void DeviceAllReduceSum(double *send_receive_buffer, int count) = 0; + virtual void AllReduceSum(double *send_receive_buffer, int count) = 0; - virtual void DeviceAllGatherV(void const *send_buffer, size_t length_bytes, - std::vector *segments, - dh::caching_device_vector *receive_buffer) = 0; + virtual void AllGatherV(void const *send_buffer, size_t length_bytes, + std::vector *segments, + dh::caching_device_vector *receive_buffer) = 0; + + virtual void Synchronize() = 0; }; } // namespace collective diff --git a/src/collective/device_communicator_adapter.cuh b/src/collective/device_communicator_adapter.cuh index e884b9114746..266dc1d0f9ca 100644 --- a/src/collective/device_communicator_adapter.cuh +++ b/src/collective/device_communicator_adapter.cuh @@ -23,7 +23,7 @@ class DeviceCommunicatorAdapter : public DeviceCommunicator { ~DeviceCommunicatorAdapter() override = default; - void DeviceAllReduceSum(double *send_receive_buffer, int count) override { + void AllReduceSum(double *send_receive_buffer, int count) override { dh::safe_cuda(cudaSetDevice(device_ordinal_)); auto size = count * sizeof(double); host_buffer_.reserve(size); @@ -32,9 +32,8 @@ class DeviceCommunicatorAdapter : public DeviceCommunicator { dh::safe_cuda(cudaMemcpy(send_receive_buffer, host_buffer_.data(), size, cudaMemcpyDefault)); } - void DeviceAllGatherV(void const *send_buffer, size_t 
length_bytes, - std::vector *segments, - dh::caching_device_vector *receive_buffer) override { + void AllGatherV(void const *send_buffer, size_t length_bytes, std::vector *segments, + dh::caching_device_vector *receive_buffer) override { dh::safe_cuda(cudaSetDevice(device_ordinal_)); int const world_size = communicator_->GetWorldSize(); int const rank = communicator_->GetRank(); @@ -61,6 +60,10 @@ class DeviceCommunicatorAdapter : public DeviceCommunicator { cudaMemcpyDefault)); } + void Synchronize() override { + // Noop. + } + private: int const device_ordinal_; Communicator *communicator_; diff --git a/src/collective/nccl_device_communicator.cuh b/src/collective/nccl_device_communicator.cuh index bcb153a153c9..54bff82f9dbe 100644 --- a/src/collective/nccl_device_communicator.cuh +++ b/src/collective/nccl_device_communicator.cuh @@ -61,7 +61,7 @@ class NcclDeviceCommunicator : public DeviceCommunicator { } } - void DeviceAllReduceSum(double *send_receive_buffer, int count) override { + void AllReduceSum(double *send_receive_buffer, int count) override { dh::safe_cuda(cudaSetDevice(device_ordinal_)); dh::safe_nccl(ncclAllReduce(send_receive_buffer, send_receive_buffer, count, ncclDouble, ncclSum, nccl_comm_, cuda_stream_)); @@ -69,9 +69,8 @@ class NcclDeviceCommunicator : public DeviceCommunicator { allreduce_calls_ += 1; } - void DeviceAllGatherV(void const *send_buffer, size_t length_bytes, - std::vector *segments, - dh::caching_device_vector *receive_buffer) override { + void AllGatherV(void const *send_buffer, size_t length_bytes, std::vector *segments, + dh::caching_device_vector *receive_buffer) override { dh::safe_cuda(cudaSetDevice(device_ordinal_)); int const world_size = communicator_->GetWorldSize(); int const rank = communicator_->GetRank(); @@ -94,6 +93,11 @@ class NcclDeviceCommunicator : public DeviceCommunicator { dh::safe_nccl(ncclGroupEnd()); } + void Synchronize() override { + dh::safe_cuda(cudaSetDevice(device_ordinal_)); + 
dh::safe_cuda(cudaStreamSynchronize(cuda_stream_)); + } + private: static constexpr std::size_t kUuidLength = sizeof(std::declval().uuid) / sizeof(uint64_t); diff --git a/tests/cpp/plugin/test_federated_adapter.cu b/tests/cpp/plugin/test_federated_adapter.cu index da0c82844c4f..09187f940c5f 100644 --- a/tests/cpp/plugin/test_federated_adapter.cu +++ b/tests/cpp/plugin/test_federated_adapter.cu @@ -58,7 +58,7 @@ TEST_F(FederatedAdapterTest, DeviceAllReduceSum) { int const count = 3; thrust::device_vector buffer(count, 0); thrust::sequence(buffer.begin(), buffer.end()); - adapter.DeviceAllReduceSum(buffer.data().get(), count); + adapter.AllReduceSum(buffer.data().get(), count); thrust::host_vector host_buffer = buffer; EXPECT_EQ(host_buffer.size(), count); for (auto i = 0; i < count; i++) { @@ -84,8 +84,7 @@ TEST_F(FederatedAdapterTest, DeviceAllGatherV) { std::vector segments(kWorldSize); dh::caching_device_vector receive_buffer{}; - adapter.DeviceAllGatherV(buffer.data().get(), count, &segments, - &receive_buffer); + adapter.AllGatherV(buffer.data().get(), count, &segments, &receive_buffer); EXPECT_EQ(segments[0], 2); EXPECT_EQ(segments[1], 3); From de9580cebbc8ce4600a52af8e4f6ae9bfc5654f7 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Thu, 21 Jul 2022 09:59:27 -0700 Subject: [PATCH 10/37] add back print and getprocessorname --- plugin/federated/federated_communicator.h | 4 ++++ src/collective/communicator.h | 10 ++++++++++ src/collective/rabit_communicator.h | 4 ++++ 3 files changed, 18 insertions(+) diff --git a/plugin/federated/federated_communicator.h b/plugin/federated/federated_communicator.h index 11430136e37b..bf6e15c24970 100644 --- a/plugin/federated/federated_communicator.h +++ b/plugin/federated/federated_communicator.h @@ -57,6 +57,10 @@ class FederatedCommunicator : public Communicator { } } + std::string GetProcessorName() override { return "rank" + std::to_string(GetRank()); } + + void Print(const std::string &message) override { LOG(CONSOLE) << 
message; } + private: static xgboost::federated::DataType ConvertDataType(DataType data_type) { xgboost::federated::DataType result{}; diff --git a/src/collective/communicator.h b/src/collective/communicator.h index 4324ef9ab318..7d96b1b4aec2 100644 --- a/src/collective/communicator.h +++ b/src/collective/communicator.h @@ -92,6 +92,16 @@ class Communicator { */ virtual void Broadcast(void *send_receive_buffer, std::size_t size, int root) = 0; + /** + * @brief Gets the name of the processor. + */ + virtual std::string GetProcessorName() = 0; + + /** + * @brief Prints the message. + */ + virtual void Print(std::string const &message) = 0; + private: int const world_size_; int const rank_; diff --git a/src/collective/rabit_communicator.h b/src/collective/rabit_communicator.h index 193a5423251e..a4c9dfab9f94 100644 --- a/src/collective/rabit_communicator.h +++ b/src/collective/rabit_communicator.h @@ -39,6 +39,10 @@ class RabitCommunicator : public Communicator { rabit::Broadcast(send_receive_buffer, size, root); } + std::string GetProcessorName() override { return rabit::GetProcessorName(); } + + void Print(const std::string &message) override { rabit::TrackerPrint(message); } + private: template void DoAllReduce(void *send_receive_buffer, std::size_t count, Operation op) { From 69ea687a3f9e43ff0c4e2b0e57fa2e6101e59efc Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Fri, 22 Jul 2022 11:25:09 -0700 Subject: [PATCH 11/37] add python wrapper and c api --- include/xgboost/c_api.h | 82 +++++++ plugin/federated/federated_communicator.h | 24 +- python-package/xgboost/collective.py | 217 ++++++++++++++++++ src/c_api/c_api.cc | 60 +++++ src/collective/communicator.h | 38 ++- .../device_communicator_adapter.cuh | 3 +- src/collective/nccl_device_communicator.cuh | 3 +- src/collective/rabit_communicator.h | 22 +- .../cpp/plugin/test_federated_communicator.cc | 2 +- 9 files changed, 432 insertions(+), 19 deletions(-) create mode 100644 python-package/xgboost/collective.py diff --git 
a/include/xgboost/c_api.h b/include/xgboost/c_api.h index 17cd5f4af36d..c8bb61d8d5cd 100644 --- a/include/xgboost/c_api.h +++ b/include/xgboost/c_api.h @@ -1378,4 +1378,86 @@ XGB_DLL int XGBoosterFeatureScore(BoosterHandle handle, const char *json_config, bst_ulong *out_dim, bst_ulong const **out_shape, float const **out_scores); + +/*! + * \brief initialize the collective communicator, + * call this once before using anything + * The additional arguments are not necessary. + * Usually the communicator will detect settings + * from environment variables. + * \param argc number of arguments in argv + * \param argv the array of input arguments + */ +XGB_DLL int XGCommunicatorInit(int argc, char *argv[]); + +/*! + * \brief finalize the collective communicator, + * call this function after you finished all jobs. + * \return true if the communicator is finalized successfully otherwise false + */ +XGB_DLL int XGCommunicatorFinalize(void); + +/*! + * \brief get rank of current process + * \return rank number of worker + * */ +XGB_DLL int XGCommunicatorGetRank(void); + +/*! + * \brief get total number of processes + * \return total world size + * */ +XGB_DLL int XGCommunicatorGetWorldSize(void); + +/*! + * \brief get if the communicator is distributed + * \return if the communicator is distributed + * */ +XGB_DLL int XGCommunicatorIsDistributed(void); + +/*! + * \brief print the msg to the communicator, + * this function can be used to communicate the information of the progress to + * the user who monitors the communicator + * \param message the message to be printed + */ +XGB_DLL int XGCommunicatorPrint(char const *message); + +/*! + * \brief get name of processor + * \param out_name hold output string + * \param out_len hold length of output string + * \param max_len maximum buffer length of input + */ +XGB_DLL int XGCommunicatorGetProcessorName(char *out_name, + bst_ulong *out_len, + bst_ulong max_len); +/*! 
+ * \brief broadcast a memory region to all others from root + * + * Example: int a = 1; Broadcast(&a, sizeof(a), root); + * \param send_receive_buffer the pointer to send or receive buffer, + * \param size the size of the data + * \param root the root of process + */ +XGB_DLL int XGCommunicatorBroadcast(void *send_receive_buffer, std::size_t size, int root); + +/*! + * \brief perform in-place allreduce, on send_receive_buffer + * this function is NOT thread-safe + * + * Example Usage: the following code gives sum of the result + * vector data(10); + * ... + * Allreduce(&data[0], data.size()); + * ... + * \param send_receive_buffer buffer for both sending and receiving data + * \param count number of elements to be reduced + * \param enum_dtype the enumeration of data type, see xgboost::collective::DataType in communicator.h + * \param enum_op the enumeration of operation type, see xgboost::collective::Operation in communicator.h + */ +XGB_DLL int XGCommunicatorAllreduce(void *send_receive_buffer, std::size_t count, int enum_dtype, + int enum_op); + + #endif // XGBOOST_C_API_H_ diff --git a/plugin/federated/federated_communicator.h b/plugin/federated/federated_communicator.h index bf6e15c24970..524c5b3bce78 100644 --- a/plugin/federated/federated_communicator.h +++ b/plugin/federated/federated_communicator.h @@ -65,15 +65,32 @@ class FederatedCommunicator : public Communicator { static xgboost::federated::DataType ConvertDataType(DataType data_type) { xgboost::federated::DataType result{}; switch (data_type) { - case DataType::kInt: + case DataType::kInt8: + result = xgboost::federated::DataType::CHAR; + break; + case DataType::kUInt8: + result = xgboost::federated::DataType::UCHAR; + break; + case DataType::kInt32: + result = xgboost::federated::DataType::INT; + break; + case DataType::kUInt32: + result = xgboost::federated::DataType::UINT; + break; + case DataType::kInt64: + result = xgboost::federated::DataType::LONG; + break; + case DataType::kUInt64: + result = 
xgboost::federated::DataType::ULONG; + break; case DataType::kFloat: result = xgboost::federated::DataType::FLOAT; break; case DataType::kDouble: result = xgboost::federated::DataType::DOUBLE; break; + default: + LOG(FATAL) << "Unknown data type."; } return result; } @@ -84,9 +101,14 @@ class FederatedCommunicator : public Communicator { case Operation::kMax: result = xgboost::federated::ReduceOperation::MAX; break; + case Operation::kMin: + result = xgboost::federated::ReduceOperation::MIN; + break; case Operation::kSum: result = xgboost::federated::ReduceOperation::SUM; break; + default: + LOG(FATAL) << "Unknown reduce operation."; } return result; } diff --git a/python-package/xgboost/collective.py b/python-package/xgboost/collective.py new file mode 100644 index 000000000000..26a4d2e28fc4 --- /dev/null +++ b/python-package/xgboost/collective.py @@ -0,0 +1,217 @@ +"""XGBoost collective communication related API.""" +import ctypes +from enum import IntEnum, unique +import logging +import pickle +from typing import Any, TypeVar, Optional, cast, List, Union + +import numpy as np + +from .core import _LIB, c_str, _check_call + +LOGGER = logging.getLogger("[xgboost.collective]") + + +def _init_collective() -> None: + """internal library initializer.""" + if _LIB is not None: + _LIB.XGCommunicatorGetRank.restype = ctypes.c_int + _LIB.XGCommunicatorGetWorldSize.restype = ctypes.c_int + _LIB.XGCommunicatorIsDistributed.restype = ctypes.c_int + + +def init(args: Optional[List[bytes]] = None) -> None: + """Initialize the collective library with arguments""" + if args is None: + args = [] + arr = (ctypes.c_char_p * len(args))() + arr[:] = cast(List[Union[ctypes.c_char_p, bytes, None, int]], args) + _LIB.XGCommunicatorInit(len(arr), arr) + + +def finalize() -> None: + """Finalize the process, notify tracker everything is done.""" + _LIB.XGCommunicatorFinalize() + + +def get_rank() -> int: + """Get rank of current process. 
+ + Returns + ------- + rank : int + Rank of current process. + """ + ret = _LIB.XGCommunicatorGetRank() + return ret + + +def get_world_size() -> int: + """Get total number of workers. + + Returns + ------- + n : int + Total number of processes. + """ + ret = _LIB.XGCommunicatorGetWorldSize() + return ret + + +def is_distributed() -> int: + """If the collective communicator is distributed.""" + is_dist = _LIB.XGCommunicatorIsDistributed() + return is_dist + + +def communicator_print(msg: Any) -> None: + """Print message to the communicator. + + This function can be used to communicate the information of + the progress to the communicator. + + Parameters + ---------- + msg : str + The message to be printed to the communicator. + """ + if not isinstance(msg, str): + msg = str(msg) + is_dist = _LIB.XGCommunicatorIsDistributed() + if is_dist != 0: + _check_call(_LIB.XGCommunicatorPrint(c_str(msg))) + else: + print(msg.strip(), flush=True) + + +def get_processor_name() -> bytes: + """Get the processor name. + + Returns + ------- + name : bytes + the name of processor(host) + """ + mxlen = 256 + length = ctypes.c_ulong() + buf = ctypes.create_string_buffer(mxlen) + _LIB.XGCommunicatorGetProcessorName(buf, ctypes.byref(length), mxlen) + return buf.value + + +T = TypeVar("T") # pylint:disable=invalid-name + + +def broadcast(data: T, root: int) -> T: + """Broadcast object from one node to all other nodes. + + Parameters + ---------- + data : any type that can be pickled + Input data, if current rank does not equal root, this can be None + root : int + Rank of the node to broadcast data from. + + Returns + ------- + object : T + the result of broadcast. 
+ """ + rank = get_rank() + length = ctypes.c_ulong() + if root == rank: + assert data is not None, 'need to pass in data when broadcasting' + s = pickle.dumps(data, protocol=pickle.HIGHEST_PROTOCOL) + length.value = len(s) + # run first broadcast + _check_call(_LIB.XGCommunicatorBroadcast(ctypes.byref(length), + ctypes.sizeof(ctypes.c_ulong), root)) + if root != rank: + dptr = (ctypes.c_char * length.value)() + # run second + _check_call(_LIB.XGCommunicatorBroadcast(ctypes.cast(dptr, ctypes.c_void_p), + length.value, root)) + data = pickle.loads(dptr.raw) + del dptr + else: + _check_call(_LIB.XGCommunicatorBroadcast(ctypes.cast(ctypes.c_char_p(s), ctypes.c_void_p), + length.value, root)) + del s + return data + + +# enumeration of dtypes +DTYPE_ENUM__ = { + np.dtype('int8'): 0, + np.dtype('uint8'): 1, + np.dtype('int32'): 2, + np.dtype('uint32'): 3, + np.dtype('int64'): 4, + np.dtype('uint64'): 5, + np.dtype('float32'): 6, + np.dtype('float64'): 7 +} + + +@unique +class Op(IntEnum): + """Supported operations for rabit.""" + MAX = 0 + MIN = 1 + SUM = 2 + + +def allreduce( # pylint:disable=invalid-name + data: np.ndarray, op: Op +) -> np.ndarray: + """Perform allreduce, return the result. + + Parameters + ---------- + data : + Input data. + op : + Reduction operators, can be MAX or SUM + + Returns + ------- + result : + The result of allreduce, have same shape as data + + Notes + ----- + This function is not thread-safe. 
+ """ + if not isinstance(data, np.ndarray): + raise Exception('allreduce only takes in numpy.ndarray') + buf = data.ravel() + if buf.base is data.base: + buf = buf.copy() + if buf.dtype not in DTYPE_ENUM__: + raise Exception(f"data type {buf.dtype} not supported") + _check_call(_LIB.XGCommunicatorAllreduce(buf.ctypes.data_as(ctypes.c_void_p), + buf.size, DTYPE_ENUM__[buf.dtype], + int(op), None, None)) + return buf + + +class CommunicatorContext: + """A context controlling collective communicator initialization and finalization.""" + + def __init__(self, args: List[bytes] = None) -> None: + if args is None: + args = [] + self.args = args + + def __enter__(self) -> None: + init(self.args) + assert is_distributed() + LOGGER.debug("-------------- communicator say hello ------------------") + + def __exit__(self, *args: List) -> None: + finalize() + LOGGER.debug("--------------- communicator say bye ------------------") + + +# initialization script +_init_collective() diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index daabc45d26e7..56f31c0b0597 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -22,6 +22,8 @@ #include "c_api_error.h" #include "c_api_utils.h" +#include "../collective/communicator.h" +#include "../collective/communicator_factory.h" #include "../common/io.h" #include "../common/charconv.h" #include "../data/adapter.h" @@ -1346,6 +1348,64 @@ XGB_DLL int XGBoosterFeatureScore(BoosterHandle handle, char const *json_config, API_END(); } +using xgboost::collective::CommunicatorFactory; + +XGB_DLL int XGCommunicatorInit(int argc, char *argv[]) { + API_BEGIN(); + CommunicatorFactory::Init(argc, argv); + API_END(); +} + +XGB_DLL int XGCommunicatorFinalize(void) { + API_BEGIN(); + CommunicatorFactory::Finalize(); + API_END(); +} + +XGB_DLL int XGCommunicatorGetRank(void) { + return CommunicatorFactory::GetInstance()->GetCommunicator()->GetRank(); +} + +XGB_DLL int XGCommunicatorGetWorldSize(void) { + return 
CommunicatorFactory::GetInstance()->GetCommunicator()->GetWorldSize(); +} + +XGB_DLL int XGCommunicatorIsDistributed(void) { + return CommunicatorFactory::GetInstance()->GetCommunicator()->IsDistributed(); +} + +XGB_DLL int XGCommunicatorPrint(char const *message) { + API_BEGIN(); + CommunicatorFactory::GetInstance()->GetCommunicator()->Print(message); + API_END(); +} + +XGB_DLL int XGCommunicatorGetProcessorName(char *out_name, bst_ulong *out_len, bst_ulong max_len) { + API_BEGIN(); + auto s = CommunicatorFactory::GetInstance()->GetCommunicator()->GetProcessorName(); + if (s.length() > max_len) { + s.resize(max_len - 1); + } + s.copy(out_name, s.length()); + *out_len = static_cast(s.length()); + API_END(); +} + +XGB_DLL int XGCommunicatorBroadcast(void *send_receive_buffer, std::size_t size, int root) { + API_BEGIN(); + CommunicatorFactory::GetInstance()->GetCommunicator()->Broadcast(send_receive_buffer, size, root); + API_END(); +} + +XGB_DLL int XGCommunicatorAllreduce(void *send_receive_buffer, std::size_t count, int enum_dtype, + int enum_op) { + API_BEGIN(); + CommunicatorFactory::GetInstance()->GetCommunicator()->AllReduce( + send_receive_buffer, count, static_cast(enum_dtype), + static_cast(enum_op)); + API_END(); +} + #if defined(XGBOOST_USE_FEDERATED) XGB_DLL int XGBRunFederatedServer(int port, int world_size, char const *server_key_path, char const *server_cert_path, char const *client_cert_path) { diff --git a/src/collective/communicator.h b/src/collective/communicator.h index 7d96b1b4aec2..50c182745f50 100644 --- a/src/collective/communicator.h +++ b/src/collective/communicator.h @@ -8,13 +8,37 @@ namespace xgboost { namespace collective { /** @brief Defines the integral and floating data types. 
*/ -enum class DataType { kInt, kFloat, kDouble, kSizeT, kUInt64 }; +enum class DataType { + kInt8 = 0, + kUInt8 = 1, + kInt32 = 2, + kUInt32 = 3, + kInt64 = 4, + kUInt64 = 5, + kFloat = 6, + kDouble = 7 +}; inline std::size_t GetTypeSize(DataType data_type) { std::size_t size{0}; switch (data_type) { - case DataType::kInt: - size = sizeof(int); + case DataType::kInt8: + size = sizeof(std::int8_t); + break; + case DataType::kUInt8: + size = sizeof(std::uint8_t); + break; + case DataType::kInt32: + size = sizeof(std::int32_t); + break; + case DataType::kUInt32: + size = sizeof(std::uint32_t); + break; + case DataType::kInt64: + size = sizeof(std::int64_t); + break; + case DataType::kUInt64: + size = sizeof(std::uint64_t); break; case DataType::kFloat: size = sizeof(float); @@ -22,12 +46,6 @@ inline std::size_t GetTypeSize(DataType data_type) { case DataType::kDouble: size = sizeof(double); break; - case DataType::kSizeT: - size = sizeof(std::size_t); - break; - case DataType::kUInt64: - size = sizeof(uint64_t); - break; default: LOG(FATAL) << "Unknown data type."; } @@ -35,7 +53,7 @@ inline std::size_t GetTypeSize(DataType data_type) { } /** @brief Defines the reduction operation. */ -enum class Operation { kMax, kSum }; +enum class Operation { kMax = 0, kMin = 1, kSum = 2 }; /** * @brief A communicator class that handles collective communication. 
diff --git a/src/collective/device_communicator_adapter.cuh b/src/collective/device_communicator_adapter.cuh index 266dc1d0f9ca..794049bfcb1c 100644 --- a/src/collective/device_communicator_adapter.cuh +++ b/src/collective/device_communicator_adapter.cuh @@ -41,7 +41,8 @@ class DeviceCommunicatorAdapter : public DeviceCommunicator { segments->clear(); segments->resize(world_size, 0); segments->at(rank) = length_bytes; - communicator_->AllReduce(segments->data(), segments->size(), DataType::kSizeT, Operation::kMax); + communicator_->AllReduce(segments->data(), segments->size(), DataType::kUInt64, + Operation::kMax); auto total_bytes = std::accumulate(segments->cbegin(), segments->cend(), 0UL); receive_buffer->resize(total_bytes); diff --git a/src/collective/nccl_device_communicator.cuh b/src/collective/nccl_device_communicator.cuh index 54bff82f9dbe..ad9f57589c53 100644 --- a/src/collective/nccl_device_communicator.cuh +++ b/src/collective/nccl_device_communicator.cuh @@ -78,7 +78,8 @@ class NcclDeviceCommunicator : public DeviceCommunicator { segments->clear(); segments->resize(world_size, 0); segments->at(rank) = length_bytes; - communicator_->AllReduce(segments->data(), segments->size(), DataType::kSizeT, Operation::kMax); + communicator_->AllReduce(segments->data(), segments->size(), DataType::kUInt64, + Operation::kMax); auto total_bytes = std::accumulate(segments->cbegin(), segments->cend(), 0UL); receive_buffer->resize(total_bytes); diff --git a/src/collective/rabit_communicator.h b/src/collective/rabit_communicator.h index a4c9dfab9f94..d7d2a460d3bb 100644 --- a/src/collective/rabit_communicator.h +++ b/src/collective/rabit_communicator.h @@ -18,8 +18,23 @@ class RabitCommunicator : public Communicator { void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type, Operation op) override { switch (data_type) { - case DataType::kInt: - DoAllReduce(send_receive_buffer, count, op); + case DataType::kInt8: + DoAllReduce(send_receive_buffer, 
count, op); + break; + case DataType::kUInt8: + DoAllReduce(send_receive_buffer, count, op); + break; + case DataType::kInt32: + DoAllReduce(send_receive_buffer, count, op); + break; + case DataType::kUInt32: + DoAllReduce(send_receive_buffer, count, op); + break; + case DataType::kInt64: + DoAllReduce(send_receive_buffer, count, op); + break; + case DataType::kUInt64: + DoAllReduce(send_receive_buffer, count, op); break; case DataType::kFloat: DoAllReduce(send_receive_buffer, count, op); @@ -27,9 +42,6 @@ class RabitCommunicator : public Communicator { case DataType::kDouble: DoAllReduce(send_receive_buffer, count, op); break; - case DataType::kSizeT: - DoAllReduce(send_receive_buffer, count, op); - break; default: LOG(FATAL) << "Unknown data type"; } diff --git a/tests/cpp/plugin/test_federated_communicator.cc b/tests/cpp/plugin/test_federated_communicator.cc index 6465ccb207f2..805d2580c4ac 100644 --- a/tests/cpp/plugin/test_federated_communicator.cc +++ b/tests/cpp/plugin/test_federated_communicator.cc @@ -45,7 +45,7 @@ class FederatedCommunicatorTest : public ::testing::Test { static void CheckAllreduce(FederatedCommunicator& comm) { int buffer[] = {1, 2, 3, 4, 5}; - comm.AllReduce(buffer, sizeof(buffer) / sizeof(buffer[0]), DataType::kInt, Operation::kSum); + comm.AllReduce(buffer, sizeof(buffer) / sizeof(buffer[0]), DataType::kInt32, Operation::kSum); int expected[] = {3, 6, 9, 12, 15}; for (auto i = 0; i < 5; i++) { EXPECT_EQ(buffer[i], expected[i]); From 695de5f09d4c7c39e5143794b13cc6c0f07193ab Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Fri, 22 Jul 2022 14:41:46 -0700 Subject: [PATCH 12/37] clean up types --- plugin/federated/engine_federated.cc | 57 ++----------------- plugin/federated/federated.proto | 14 ++--- plugin/federated/federated_communicator.h | 54 +----------------- plugin/federated/federated_server.cc | 51 +++++++---------- src/collective/communicator_factory.cu | 2 + .../test_nccl_device_communicator.cu | 4 ++ 
tests/cpp/plugin/test_federated_server.cc | 2 +- 7 files changed, 42 insertions(+), 142 deletions(-) diff --git a/plugin/federated/engine_federated.cc b/plugin/federated/engine_federated.cc index 9b43c3997cc3..f767e2f88edb 100644 --- a/plugin/federated/engine_federated.cc +++ b/plugin/federated/engine_federated.cc @@ -66,7 +66,9 @@ class FederatedEngine : public IEngine { void Allreduce(void *sendrecvbuf, size_t size, mpi::DataType dtype, mpi::OpType op) { auto *buffer = reinterpret_cast(sendrecvbuf); std::string const send_buffer(buffer, size); - auto const receive_buffer = client_->Allreduce(send_buffer, GetDataType(dtype), GetOp(op)); + auto const receive_buffer = + client_->Allreduce(send_buffer, static_cast(dtype), + static_cast(op)); receive_buffer.copy(buffer, size); } @@ -84,13 +86,9 @@ class FederatedEngine : public IEngine { } } - int LoadCheckPoint() override { - return 0; - } + int LoadCheckPoint() override { return 0; } - void CheckPoint() override { - version_number_ += 1; - } + void CheckPoint() override { version_number_ += 1; } int VersionNumber() const override { return version_number_; } @@ -112,51 +110,6 @@ class FederatedEngine : public IEngine { } private: - /** @brief Transform mpi::DataType to xgboost::federated::DataType. 
*/ - static xgboost::federated::DataType GetDataType(mpi::DataType data_type) { - switch (data_type) { - case mpi::kChar: - return xgboost::federated::CHAR; - case mpi::kUChar: - return xgboost::federated::UCHAR; - case mpi::kInt: - return xgboost::federated::INT; - case mpi::kUInt: - return xgboost::federated::UINT; - case mpi::kLong: - return xgboost::federated::LONG; - case mpi::kULong: - return xgboost::federated::ULONG; - case mpi::kFloat: - return xgboost::federated::FLOAT; - case mpi::kDouble: - return xgboost::federated::DOUBLE; - case mpi::kLongLong: - return xgboost::federated::LONGLONG; - case mpi::kULongLong: - return xgboost::federated::ULONGLONG; - } - utils::Error("unknown mpi::DataType"); - return xgboost::federated::CHAR; - } - - /** @brief Transform mpi::OpType to enum to MPI OP */ - static xgboost::federated::ReduceOperation GetOp(mpi::OpType op_type) { - switch (op_type) { - case mpi::kMax: - return xgboost::federated::MAX; - case mpi::kMin: - return xgboost::federated::MIN; - case mpi::kSum: - return xgboost::federated::SUM; - case mpi::kBitwiseOR: - utils::Error("Bitwise OR is not supported"); - return xgboost::federated::MAX; - } - utils::Error("unknown mpi::OpType"); - return xgboost::federated::MAX; - } - void SetParam(std::string const &name, std::string const &val) { if (!strcasecmp(name.c_str(), "FEDERATED_SERVER_ADDRESS")) { server_address_ = val; diff --git a/plugin/federated/federated.proto b/plugin/federated/federated.proto index cba897c0ea81..5a338ba0d8f9 100644 --- a/plugin/federated/federated.proto +++ b/plugin/federated/federated.proto @@ -12,16 +12,14 @@ service Federated { } enum DataType { - CHAR = 0; - UCHAR = 1; - INT = 2; - UINT = 3; - LONG = 4; - ULONG = 5; + INT8 = 0; + UINT8 = 1; + INT32 = 2; + UINT32 = 3; + INT64 = 4; + UINT64 = 5; FLOAT = 6; DOUBLE = 7; - LONGLONG = 8; - ULONGLONG = 9; } enum ReduceOperation { diff --git a/plugin/federated/federated_communicator.h b/plugin/federated/federated_communicator.h index 
524c5b3bce78..469db9d7fec0 100644 --- a/plugin/federated/federated_communicator.h +++ b/plugin/federated/federated_communicator.h @@ -42,7 +42,8 @@ class FederatedCommunicator : public Communicator { std::string const send_buffer(reinterpret_cast(send_receive_buffer), count * GetTypeSize(data_type)); auto const received = - client_->Allreduce(send_buffer, ConvertDataType(data_type), ConvertOperation(op)); + client_->Allreduce(send_buffer, static_cast(data_type), + static_cast(op)); received.copy(reinterpret_cast(send_receive_buffer), count * GetTypeSize(data_type)); } @@ -62,57 +63,6 @@ class FederatedCommunicator : public Communicator { void Print(const std::string &message) override { LOG(CONSOLE) << message; } private: - static xgboost::federated::DataType ConvertDataType(DataType data_type) { - xgboost::federated::DataType result{}; - switch (data_type) { - case DataType::kInt8: - result = xgboost::federated::DataType::CHAR; - break; - case DataType::kUInt8: - result = xgboost::federated::DataType::UCHAR; - break; - case DataType::kInt32: - result = xgboost::federated::DataType::INT; - break; - case DataType::kUInt32: - result = xgboost::federated::DataType::UINT; - break; - case DataType::kInt64: - result = xgboost::federated::DataType::LONG; - break; - case DataType::kUInt64: - result = xgboost::federated::DataType::ULONG; - break; - case DataType::kFloat: - result = xgboost::federated::DataType::FLOAT; - break; - case DataType::kDouble: - result = xgboost::federated::DataType::DOUBLE; - break; - default: - LOG(FATAL) << "Unknown data type."; - } - return result; - } - - static xgboost::federated::ReduceOperation ConvertOperation(Operation operation) { - xgboost::federated::ReduceOperation result{}; - switch (operation) { - case Operation::kMax: - result = xgboost::federated::ReduceOperation::MAX; - break; - case Operation::kMin: - result = xgboost::federated::ReduceOperation::MIN; - break; - case Operation::kSum: - result = 
xgboost::federated::ReduceOperation::SUM; - break; - default: - LOG(FATAL) << "Unknown reduce operation."; - } - return result; - } - std::unique_ptr client_{}; }; diff --git a/plugin/federated/federated_server.cc b/plugin/federated/federated_server.cc index b569cd33df2f..8d39a06afc4f 100644 --- a/plugin/federated/federated_server.cc +++ b/plugin/federated/federated_server.cc @@ -73,32 +73,35 @@ class AllreduceFunctor { void Accumulate(std::string& buffer, std::string const& input, DataType data_type, ReduceOperation reduce_operation) const { switch (data_type) { - case DataType::CHAR: - Accumulate(&buffer[0], reinterpret_cast(input.data()), buffer.size(), + case DataType::INT8: + Accumulate(reinterpret_cast(&buffer[0]), + reinterpret_cast(input.data()), buffer.size(), reduce_operation); break; - case DataType::UCHAR: - Accumulate(reinterpret_cast(&buffer[0]), - reinterpret_cast(input.data()), buffer.size(), + case DataType::UINT8: + Accumulate(reinterpret_cast(&buffer[0]), + reinterpret_cast(input.data()), buffer.size(), reduce_operation); break; - case DataType::INT: - Accumulate(reinterpret_cast(&buffer[0]), reinterpret_cast(input.data()), - buffer.size() / sizeof(int), reduce_operation); + case DataType::INT32: + Accumulate(reinterpret_cast(&buffer[0]), + reinterpret_cast(input.data()), + buffer.size() / sizeof(std::uint32_t), reduce_operation); break; - case DataType::UINT: - Accumulate(reinterpret_cast(&buffer[0]), - reinterpret_cast(input.data()), - buffer.size() / sizeof(unsigned int), reduce_operation); + case DataType::UINT32: + Accumulate(reinterpret_cast(&buffer[0]), + reinterpret_cast(input.data()), + buffer.size() / sizeof(std::uint32_t), reduce_operation); break; - case DataType::LONG: - Accumulate(reinterpret_cast(&buffer[0]), reinterpret_cast(input.data()), - buffer.size() / sizeof(long), reduce_operation); + case DataType::INT64: + Accumulate(reinterpret_cast(&buffer[0]), + reinterpret_cast(input.data()), + buffer.size() / sizeof(std::int64_t), 
reduce_operation); break; - case DataType::ULONG: - Accumulate(reinterpret_cast(&buffer[0]), - reinterpret_cast(input.data()), - buffer.size() / sizeof(unsigned long), reduce_operation); + case DataType::UINT64: + Accumulate(reinterpret_cast(&buffer[0]), + reinterpret_cast(input.data()), + buffer.size() / sizeof(std::uint64_t), reduce_operation); break; case DataType::FLOAT: Accumulate(reinterpret_cast(&buffer[0]), @@ -110,16 +113,6 @@ class AllreduceFunctor { reinterpret_cast(input.data()), buffer.size() / sizeof(double), reduce_operation); break; - case DataType::LONGLONG: - Accumulate(reinterpret_cast(&buffer[0]), - reinterpret_cast(input.data()), - buffer.size() / sizeof(long long), reduce_operation); - break; - case DataType::ULONGLONG: - Accumulate(reinterpret_cast(&buffer[0]), - reinterpret_cast(input.data()), - buffer.size() / sizeof(unsigned long long), reduce_operation); - break; default: throw std::invalid_argument("Invalid data type"); } diff --git a/src/collective/communicator_factory.cu b/src/collective/communicator_factory.cu index 8ba388342a86..fb93b31b0430 100644 --- a/src/collective/communicator_factory.cu +++ b/src/collective/communicator_factory.cu @@ -3,7 +3,9 @@ */ #include "communicator_factory.h" #include "device_communicator_adapter.cuh" +#ifdef XGBOOST_USE_NCCL #include "nccl_device_communicator.cuh" +#endif #include "rabit_communicator.h" namespace xgboost { diff --git a/tests/cpp/collective/test_nccl_device_communicator.cu b/tests/cpp/collective/test_nccl_device_communicator.cu index 168d6f60b844..47de054c6d4b 100644 --- a/tests/cpp/collective/test_nccl_device_communicator.cu +++ b/tests/cpp/collective/test_nccl_device_communicator.cu @@ -1,6 +1,8 @@ /*! 
* Copyright 2022 XGBoost contributors */ +#ifdef XGBOOST_USE_NCCL + #include #include "../../../src/collective/nccl_device_communicator.cuh" @@ -20,3 +22,5 @@ TEST(NcclDeviceCommunicatorSimpleTest, ThrowOnInvalidCommunicator) { } // namespace collective } // namespace xgboost + +#endif diff --git a/tests/cpp/plugin/test_federated_server.cc b/tests/cpp/plugin/test_federated_server.cc index b20c17b09de5..1c3e4f0bc84c 100644 --- a/tests/cpp/plugin/test_federated_server.cc +++ b/tests/cpp/plugin/test_federated_server.cc @@ -62,7 +62,7 @@ class FederatedServerTest : public ::testing::Test { static void CheckAllreduce(federated::FederatedClient& client) { int data[] = {1, 2, 3, 4, 5}; std::string send_buffer(reinterpret_cast(data), sizeof(data)); - auto reply = client.Allreduce(send_buffer, federated::INT, federated::SUM); + auto reply = client.Allreduce(send_buffer, federated::INT32, federated::SUM); auto const* result = reinterpret_cast(reply.data()); int expected[] = {3, 6, 9, 12, 15}; for (auto i = 0; i < 5; i++) { From e4d00296f373a74f6f341ade48e35407c1116a55 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Fri, 22 Jul 2022 15:28:59 -0700 Subject: [PATCH 13/37] fix non-gpu build --- src/collective/communicator_factory.cc | 61 ++++++++++++++++++++++++++ src/collective/communicator_factory.cu | 11 +++-- src/collective/communicator_factory.h | 8 ++-- 3 files changed, 72 insertions(+), 8 deletions(-) create mode 100644 src/collective/communicator_factory.cc diff --git a/src/collective/communicator_factory.cc b/src/collective/communicator_factory.cc new file mode 100644 index 000000000000..fab6f7298346 --- /dev/null +++ b/src/collective/communicator_factory.cc @@ -0,0 +1,61 @@ +/*! 
+ * Copyright 2022 XGBoost contributors + */ +#include "communicator_factory.h" + +#include "rabit_communicator.h" + +#if defined(XGBOOST_USE_FEDERATED) +#include "../../plugin/federated/federated_communicator.h" +#endif + +namespace xgboost { +namespace collective { + +#ifndef XGBOOST_USE_CUDA +thread_local std::unique_ptr CommunicatorFactory::instance_{}; + +CommunicatorFactory::CommunicatorFactory(CommunicatorType type, Communicator* communicator) + : type_{type}, communicator_{communicator} {} + +void CommunicatorFactory::Init(int argc, char* argv[]) { + if (instance_) { + LOG(FATAL) << "Communicator factory can only be initialized once."; + } + + auto type = GetTypeFromEnv(); + auto const arg = GetTypeFromArgs(argc, argv); + if (arg != CommunicatorType::kUnknown) { + type = arg; + } + switch (type) { + case CommunicatorType::kRabit: { + RabitCommunicatorFactory factory{argc, argv}; + auto* comm = factory.Create(); + instance_.reset(new CommunicatorFactory(type, comm)); + break; + } + case CommunicatorType::kMPI: + LOG(FATAL) << "Not implemented yet."; + break; + case CommunicatorType::kFederated: { +#if defined(XGBOOST_USE_FEDERATED) + FederatedCommunicatorFactory factory{argc, argv}; + auto* comm = factory.Create(); + instance_.reset(new CommunicatorFactory(type, comm)); +#else + LOG(FATAL) << "XGBoost is not compiled with Federated Learning support."; +#endif + break; + } + case CommunicatorType::kUnknown: + LOG(FATAL) << "Unknown communicator type."; + break; + } +} + +void CommunicatorFactory::Finalize() { instance_.reset(); } +#endif + +} // namespace collective +} // namespace xgboost diff --git a/src/collective/communicator_factory.cu b/src/collective/communicator_factory.cu index fb93b31b0430..1cec0a9045a2 100644 --- a/src/collective/communicator_factory.cu +++ b/src/collective/communicator_factory.cu @@ -3,16 +3,22 @@ */ #include "communicator_factory.h" #include "device_communicator_adapter.cuh" +#include "rabit_communicator.h" #ifdef 
XGBOOST_USE_NCCL #include "nccl_device_communicator.cuh" #endif -#include "rabit_communicator.h" +#if defined(XGBOOST_USE_FEDERATED) +#include "../../plugin/federated/federated_communicator.h" +#endif namespace xgboost { namespace collective { thread_local std::unique_ptr CommunicatorFactory::instance_{}; +CommunicatorFactory::CommunicatorFactory(CommunicatorType type, Communicator* communicator) + : type_{type}, communicator_{communicator} {} + void CommunicatorFactory::Init(int argc, char* argv[]) { if (instance_) { LOG(FATAL) << "Communicator factory can only be initialized once."; @@ -51,9 +57,6 @@ void CommunicatorFactory::Init(int argc, char* argv[]) { void CommunicatorFactory::Finalize() { instance_.reset(); } -CommunicatorFactory::CommunicatorFactory(CommunicatorType type, Communicator* communicator) - : type_{type}, communicator_{communicator} {} - DeviceCommunicator* CommunicatorFactory::GetDeviceCommunicator(int device_ordinal) { if (!device_communicator_) { #ifdef XGBOOST_USE_NCCL diff --git a/src/collective/communicator_factory.h b/src/collective/communicator_factory.h index 882f54d36891..d969b647260f 100644 --- a/src/collective/communicator_factory.h +++ b/src/collective/communicator_factory.h @@ -4,10 +4,6 @@ #pragma once #include "communicator.h" -#if defined(XGBOOST_USE_FEDERATED) -#include "../../plugin/federated/federated_communicator.h" -#endif - namespace xgboost { namespace collective { @@ -27,7 +23,9 @@ class CommunicatorFactory { Communicator* GetCommunicator() { return communicator_.get(); } +#if defined(XGBOOST_USE_CUDA) DeviceCommunicator* GetDeviceCommunicator(int device_ordinal); +#endif /** @brief Get the communicator type from environment variables. Visible for testing. 
*/ static CommunicatorType GetTypeFromEnv() { @@ -76,7 +74,9 @@ class CommunicatorFactory { static thread_local std::unique_ptr instance_; CommunicatorType type_; std::unique_ptr communicator_; +#if defined(XGBOOST_USE_CUDA) std::unique_ptr device_communicator_; +#endif }; } // namespace collective From d656387719f7728df54778e3004c3dc2460d6337 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Fri, 22 Jul 2022 16:02:31 -0700 Subject: [PATCH 14/37] try to fix ci --- amalgamation/xgboost-all0.cc | 3 +++ src/collective/communicator_factory.h | 2 ++ 2 files changed, 5 insertions(+) diff --git a/amalgamation/xgboost-all0.cc b/amalgamation/xgboost-all0.cc index 2cbde50a0f41..a67e097f5b6c 100644 --- a/amalgamation/xgboost-all0.cc +++ b/amalgamation/xgboost-all0.cc @@ -70,6 +70,9 @@ #include "../src/logging.cc" #include "../src/global_config.cc" +// collective +#include "../src/collective/communicator_factory.cc" + // common #include "../src/common/common.cc" #include "../src/common/column_matrix.cc" diff --git a/src/collective/communicator_factory.h b/src/collective/communicator_factory.h index d969b647260f..3ea157c8f894 100644 --- a/src/collective/communicator_factory.h +++ b/src/collective/communicator_factory.h @@ -2,6 +2,8 @@ * Copyright 2022 XGBoost contributors */ #pragma once +#include + #include "communicator.h" namespace xgboost { From 92ae35e81fb571002949f498835d060cc884a94d Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Fri, 22 Jul 2022 16:10:14 -0700 Subject: [PATCH 15/37] fix std::size_t --- include/xgboost/c_api.h | 2 ++ 1 file changed, 2 insertions(+) diff --git a/include/xgboost/c_api.h b/include/xgboost/c_api.h index c8bb61d8d5cd..57a2b7b7086f 100644 --- a/include/xgboost/c_api.h +++ b/include/xgboost/c_api.h @@ -9,10 +9,12 @@ #ifdef __cplusplus #define XGB_EXTERN_C extern "C" +#include #include #include #else #define XGB_EXTERN_C +#include #include #include #endif // __cplusplus From de52150fad41294a0b616c9cf6e8c6e8c7706ea5 Mon Sep 17 00:00:00 2001 From: Rong Ou 
Date: Fri, 22 Jul 2022 22:38:42 -0700 Subject: [PATCH 16/37] portable string compare ignore case --- plugin/federated/federated_communicator.h | 12 ++++++------ src/collective/communicator.h | 9 +++++++++ src/collective/communicator_factory.h | 10 ++++------ 3 files changed, 19 insertions(+), 12 deletions(-) diff --git a/plugin/federated/federated_communicator.h b/plugin/federated/federated_communicator.h index 469db9d7fec0..babace2cdf5a 100644 --- a/plugin/federated/federated_communicator.h +++ b/plugin/federated/federated_communicator.h @@ -119,17 +119,17 @@ class FederatedCommunicatorFactory { private: void SetParam(std::string const &name, std::string const &val) { - if (!strcasecmp(name.c_str(), "FEDERATED_SERVER_ADDRESS")) { + if (!CompareStringsCaseInsensitive(name.c_str(), "FEDERATED_SERVER_ADDRESS")) { server_address_ = val; - } else if (!strcasecmp(name.c_str(), "FEDERATED_WORLD_SIZE")) { + } else if (!CompareStringsCaseInsensitive(name.c_str(), "FEDERATED_WORLD_SIZE")) { world_size_ = std::stoi(val); - } else if (!strcasecmp(name.c_str(), "FEDERATED_RANK")) { + } else if (!CompareStringsCaseInsensitive(name.c_str(), "FEDERATED_RANK")) { rank_ = std::stoi(val); - } else if (!strcasecmp(name.c_str(), "FEDERATED_SERVER_CERT")) { + } else if (!CompareStringsCaseInsensitive(name.c_str(), "FEDERATED_SERVER_CERT")) { server_cert_ = val; - } else if (!strcasecmp(name.c_str(), "FEDERATED_CLIENT_KEY")) { + } else if (!CompareStringsCaseInsensitive(name.c_str(), "FEDERATED_CLIENT_KEY")) { client_key_ = val; - } else if (!strcasecmp(name.c_str(), "FEDERATED_CLIENT_CERT")) { + } else if (!CompareStringsCaseInsensitive(name.c_str(), "FEDERATED_CLIENT_CERT")) { client_cert_ = val; } } diff --git a/src/collective/communicator.h b/src/collective/communicator.h index 50c182745f50..77d7fafb3987 100644 --- a/src/collective/communicator.h +++ b/src/collective/communicator.h @@ -125,5 +125,14 @@ class Communicator { int const rank_; }; +/* \brief Case-insensitive string 
comparison */ +inline int CompareStringsCaseInsensitive(const char* s1, const char* s2) { +#ifdef _MSC_VER + return _stricmp(s1, s2); +#else // _MSC_VER + return strcasecmp(s1, s2); +#endif // _MSC_VER +} + } // namespace collective } // namespace xgboost diff --git a/src/collective/communicator_factory.h b/src/collective/communicator_factory.h index 3ea157c8f894..bbd63b4337c3 100644 --- a/src/collective/communicator_factory.h +++ b/src/collective/communicator_factory.h @@ -2,8 +2,6 @@ * Copyright 2022 XGBoost contributors */ #pragma once -#include - #include "communicator.h" namespace xgboost { @@ -47,7 +45,7 @@ class CommunicatorFactory { if (delimiter != std::string::npos) { auto const key = key_value.substr(0, delimiter); auto const value = key_value.substr(delimiter + 1); - if (!strcasecmp(key.c_str(), kCommunicatorKey)) { + if (!CompareStringsCaseInsensitive(key.c_str(), kCommunicatorKey)) { return StringToType(value.c_str()); } } @@ -61,11 +59,11 @@ class CommunicatorFactory { private: static CommunicatorType StringToType(char const* str) { CommunicatorType result = CommunicatorType::kUnknown; - if (!strcasecmp("rabit", str)) { + if (!CompareStringsCaseInsensitive("rabit", str)) { result = CommunicatorType::kRabit; - } else if (!strcasecmp("mpi", str)) { + } else if (!CompareStringsCaseInsensitive("mpi", str)) { result = CommunicatorType::kMPI; - } else if (!strcasecmp("federated", str)) { + } else if (!CompareStringsCaseInsensitive("federated", str)) { result = CommunicatorType::kFederated; } else { LOG(FATAL) << "Unknown communicator type " << str; From 4e2a5b8b294d2ae29f4e41823d7869e9a941e3f3 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Mon, 15 Aug 2022 10:29:15 -0700 Subject: [PATCH 17/37] c style size_t --- include/xgboost/c_api.h | 4 ++-- src/c_api/c_api.cc | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/include/xgboost/c_api.h b/include/xgboost/c_api.h index 5b2ee606291e..31f0990e6db6 100644 --- a/include/xgboost/c_api.h +++ 
b/include/xgboost/c_api.h @@ -1450,7 +1450,7 @@ XGB_DLL int XGCommunicatorGetProcessorName(char *out_name, * \param size the size of the data * \param root the root of process */ -XGB_DLL int XGCommunicatorBroadcast(void *send_receive_buffer, std::size_t size, int root); +XGB_DLL int XGCommunicatorBroadcast(void *send_receive_buffer, size_t size, int root); /*! * \brief perform in-place allreduce, on sendrecvbuf @@ -1466,7 +1466,7 @@ XGB_DLL int XGCommunicatorBroadcast(void *send_receive_buffer, std::size_t size, * \param enum_dtype the enumeration of data type, see xgboost::collective::DataType in communicator.h * \param enum_op the enumeration of operation type, see xgboost::collective::Operation in communicator.h */ -XGB_DLL int XGCommunicatorAllreduce(void *send_receive_buffer, std::size_t count, int enum_dtype, +XGB_DLL int XGCommunicatorAllreduce(void *send_receive_buffer, size_t count, int enum_dtype, int enum_op); diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 074c20e3477f..1d50636a5fb2 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -1415,13 +1415,13 @@ XGB_DLL int XGCommunicatorGetProcessorName(char *out_name, bst_ulong *out_len, b API_END(); } -XGB_DLL int XGCommunicatorBroadcast(void *send_receive_buffer, std::size_t size, int root) { +XGB_DLL int XGCommunicatorBroadcast(void *send_receive_buffer, size_t size, int root) { API_BEGIN(); CommunicatorFactory::GetInstance()->GetCommunicator()->Broadcast(send_receive_buffer, size, root); API_END(); } -XGB_DLL int XGCommunicatorAllreduce(void *send_receive_buffer, std::size_t count, int enum_dtype, +XGB_DLL int XGCommunicatorAllreduce(void *send_receive_buffer, size_t count, int enum_dtype, int enum_op) { API_BEGIN(); CommunicatorFactory::GetInstance()->GetCommunicator()->AllReduce( From 131f4b1260087f349415ba701e9909a2c2a2572c Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Mon, 15 Aug 2022 11:00:22 -0700 Subject: [PATCH 18/37] fix lint errors --- src/collective/communicator.h | 8 
+++++--- src/collective/communicator_factory.h | 3 +++ src/collective/rabit_communicator.h | 2 ++ 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/src/collective/communicator.h b/src/collective/communicator.h index 77d7fafb3987..cfa3d6863d3e 100644 --- a/src/collective/communicator.h +++ b/src/collective/communicator.h @@ -4,6 +4,8 @@ #pragma once #include +#include + namespace xgboost { namespace collective { @@ -87,7 +89,7 @@ class Communicator { int GetRank() const { return rank_; } /** @brief Whether the communicator is running in distributed mode. */ - bool IsDistributed() const { return world_size_ > 1; }; + bool IsDistributed() const { return world_size_ > 1; } /** * @brief Combines values from all processes and distributes the result back to all processes. @@ -126,10 +128,10 @@ class Communicator { }; /* \brief Case-insensitive string comparison */ -inline int CompareStringsCaseInsensitive(const char* s1, const char* s2) { +inline int CompareStringsCaseInsensitive(const char *s1, const char *s2) { #ifdef _MSC_VER return _stricmp(s1, s2); -#else // _MSC_VER +#else // _MSC_VER return strcasecmp(s1, s2); #endif // _MSC_VER } diff --git a/src/collective/communicator_factory.h b/src/collective/communicator_factory.h index bbd63b4337c3..d4568d95be5a 100644 --- a/src/collective/communicator_factory.h +++ b/src/collective/communicator_factory.h @@ -2,6 +2,9 @@ * Copyright 2022 XGBoost contributors */ #pragma once +#include +#include + #include "communicator.h" namespace xgboost { diff --git a/src/collective/rabit_communicator.h b/src/collective/rabit_communicator.h index d7d2a460d3bb..d6e930258fdb 100644 --- a/src/collective/rabit_communicator.h +++ b/src/collective/rabit_communicator.h @@ -4,6 +4,8 @@ #pragma once #include +#include + #include "communicator.h" namespace xgboost { From 88510a3c4ce1b6fb92e6abb1dca823fbb9125255 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Mon, 15 Aug 2022 11:44:34 -0700 Subject: [PATCH 19/37] cross platform setenv --- 
.../collective/test_communicator_factory.cc | 10 +- .../cpp/plugin/test_federated_communicator.cc | 117 +++++++++--------- 2 files changed, 64 insertions(+), 63 deletions(-) diff --git a/tests/cpp/collective/test_communicator_factory.cc b/tests/cpp/collective/test_communicator_factory.cc index 3381e78a5b7a..62473bade099 100644 --- a/tests/cpp/collective/test_communicator_factory.cc +++ b/tests/cpp/collective/test_communicator_factory.cc @@ -1,6 +1,7 @@ /*! * Copyright 2022 XGBoost contributors */ +#include #include #include "../../../src/collective/communicator_factory.h" @@ -9,19 +10,18 @@ namespace xgboost { namespace collective { TEST(CommunicatorFactory, TypeFromEnv) { - unsetenv(CommunicatorFactory::kCommunicatorKey); EXPECT_EQ(CommunicatorType::kUnknown, CommunicatorFactory::GetTypeFromEnv()); - setenv(CommunicatorFactory::kCommunicatorKey, "rabit", 1); + dmlc::SetEnv(CommunicatorFactory::kCommunicatorKey, "rabit"); EXPECT_EQ(CommunicatorType::kRabit, CommunicatorFactory::GetTypeFromEnv()); - setenv(CommunicatorFactory::kCommunicatorKey, "MPI", 1); + dmlc::SetEnv(CommunicatorFactory::kCommunicatorKey, "MPI"); EXPECT_EQ(CommunicatorType::kMPI, CommunicatorFactory::GetTypeFromEnv()); - setenv(CommunicatorFactory::kCommunicatorKey, "Federated", 1); + dmlc::SetEnv(CommunicatorFactory::kCommunicatorKey, "Federated"); EXPECT_EQ(CommunicatorType::kFederated, CommunicatorFactory::GetTypeFromEnv()); - setenv(CommunicatorFactory::kCommunicatorKey, "foo", 1); + dmlc::SetEnv(CommunicatorFactory::kCommunicatorKey, "foo"); EXPECT_THROW(CommunicatorFactory::GetTypeFromEnv(), dmlc::Error); } diff --git a/tests/cpp/plugin/test_federated_communicator.cc b/tests/cpp/plugin/test_federated_communicator.cc index 805d2580c4ac..2bf9d5102284 100644 --- a/tests/cpp/plugin/test_federated_communicator.cc +++ b/tests/cpp/plugin/test_federated_communicator.cc @@ -1,6 +1,7 @@ /*! 
* Copyright 2022 XGBoost contributors */ +#include #include #include @@ -43,7 +44,7 @@ class FederatedCommunicatorTest : public ::testing::Test { server_thread_->join(); } - static void CheckAllreduce(FederatedCommunicator& comm) { + static void CheckAllreduce(FederatedCommunicator &comm) { int buffer[] = {1, 2, 3, 4, 5}; comm.AllReduce(buffer, sizeof(buffer) / sizeof(buffer[0]), DataType::kInt32, Operation::kSum); int expected[] = {3, 6, 9, 12, 15}; @@ -52,7 +53,7 @@ class FederatedCommunicatorTest : public ::testing::Test { } } - static void CheckBroadcast(FederatedCommunicator& comm, int rank) { + static void CheckBroadcast(FederatedCommunicator &comm, int rank) { if (rank == 0) { std::string buffer{"hello"}; comm.Broadcast(&buffer[0], buffer.size(), 0); @@ -105,7 +106,7 @@ TEST_F(FederatedCommunicatorTest, Allreduce) { for (auto rank = 0; rank < kWorldSize; rank++) { threads.emplace_back(std::thread(&FederatedCommunicatorTest::VerifyAllreduce, rank)); } - for (auto& thread : threads) { + for (auto &thread : threads) { thread.join(); } } @@ -115,93 +116,93 @@ TEST_F(FederatedCommunicatorTest, Broadcast) { for (auto rank = 0; rank < kWorldSize; rank++) { threads.emplace_back(std::thread(&FederatedCommunicatorTest::VerifyBroadcast, rank)); } - for (auto& thread : threads) { + for (auto &thread : threads) { thread.join(); } } TEST(FederatedCommunicatorFactoryTest, ServerAddress) { - FederatedCommunicatorFactory factory0{0, nullptr}; - EXPECT_EQ(factory0.GetServerAddress(), ""); + FederatedCommunicatorFactory factory0{0, nullptr}; + EXPECT_EQ(factory0.GetServerAddress(), ""); - setenv("FEDERATED_SERVER_ADDRESS", "localhost:9091", 1); - FederatedCommunicatorFactory factory1{0, nullptr}; - EXPECT_EQ(factory1.GetServerAddress(), "localhost:9091"); + dmlc::SetEnv("FEDERATED_SERVER_ADDRESS", "localhost:9091"); + FederatedCommunicatorFactory factory1{0, nullptr}; + EXPECT_EQ(factory1.GetServerAddress(), "localhost:9091"); - char *args[1]; - args[0] = 
strdup("federated_server_address=foo:9091"); - FederatedCommunicatorFactory factory2{1, args}; - EXPECT_EQ(factory2.GetServerAddress(), "foo:9091"); + char *args[1]; + args[0] = strdup("federated_server_address=foo:9091"); + FederatedCommunicatorFactory factory2{1, args}; + EXPECT_EQ(factory2.GetServerAddress(), "foo:9091"); } TEST(FederatedCommunicatorFactoryTest, WorldSize) { - FederatedCommunicatorFactory factory0{0, nullptr}; - EXPECT_EQ(factory0.GetWorldSize(), 0); + FederatedCommunicatorFactory factory0{0, nullptr}; + EXPECT_EQ(factory0.GetWorldSize(), 0); - setenv("FEDERATED_WORLD_SIZE", "2", 1); - FederatedCommunicatorFactory factory1{0, nullptr}; - EXPECT_EQ(factory1.GetWorldSize(), 2); + dmlc::SetEnv("FEDERATED_WORLD_SIZE", "2"); + FederatedCommunicatorFactory factory1{0, nullptr}; + EXPECT_EQ(factory1.GetWorldSize(), 2); - char *args[1]; - args[0] = strdup("federated_world_size=3"); - FederatedCommunicatorFactory factory2{1, args}; - EXPECT_EQ(factory2.GetWorldSize(), 3); + char *args[1]; + args[0] = strdup("federated_world_size=3"); + FederatedCommunicatorFactory factory2{1, args}; + EXPECT_EQ(factory2.GetWorldSize(), 3); } TEST(FederatedCommunicatorFactoryTest, Rank) { - FederatedCommunicatorFactory factory0{0, nullptr}; - EXPECT_EQ(factory0.GetRank(), -1); + FederatedCommunicatorFactory factory0{0, nullptr}; + EXPECT_EQ(factory0.GetRank(), -1); - setenv("FEDERATED_RANK", "1", 1); - FederatedCommunicatorFactory factory1{0, nullptr}; - EXPECT_EQ(factory1.GetRank(), 1); + dmlc::SetEnv("FEDERATED_RANK", "1"); + FederatedCommunicatorFactory factory1{0, nullptr}; + EXPECT_EQ(factory1.GetRank(), 1); - char *args[1]; - args[0] = strdup("federated_rank=2"); - FederatedCommunicatorFactory factory2{1, args}; - EXPECT_EQ(factory2.GetRank(), 2); + char *args[1]; + args[0] = strdup("federated_rank=2"); + FederatedCommunicatorFactory factory2{1, args}; + EXPECT_EQ(factory2.GetRank(), 2); } TEST(FederatedCommunicatorFactoryTest, ServerCert) { - 
FederatedCommunicatorFactory factory0{0, nullptr}; - EXPECT_EQ(factory0.GetServerCert(), ""); + FederatedCommunicatorFactory factory0{0, nullptr}; + EXPECT_EQ(factory0.GetServerCert(), ""); - setenv("FEDERATED_SERVER_CERT", "foo/server-cert.pem", 1); - FederatedCommunicatorFactory factory1{0, nullptr}; - EXPECT_EQ(factory1.GetServerCert(), "foo/server-cert.pem"); + dmlc::SetEnv("FEDERATED_SERVER_CERT", "foo/server-cert.pem"); + FederatedCommunicatorFactory factory1{0, nullptr}; + EXPECT_EQ(factory1.GetServerCert(), "foo/server-cert.pem"); - char *args[1]; - args[0] = strdup("federated_server_cert=bar/server-cert.pem"); - FederatedCommunicatorFactory factory2{1, args}; - EXPECT_EQ(factory2.GetServerCert(), "bar/server-cert.pem"); + char *args[1]; + args[0] = strdup("federated_server_cert=bar/server-cert.pem"); + FederatedCommunicatorFactory factory2{1, args}; + EXPECT_EQ(factory2.GetServerCert(), "bar/server-cert.pem"); } TEST(FederatedCommunicatorFactoryTest, ClientKey) { - FederatedCommunicatorFactory factory0{0, nullptr}; - EXPECT_EQ(factory0.GetClientKey(), ""); + FederatedCommunicatorFactory factory0{0, nullptr}; + EXPECT_EQ(factory0.GetClientKey(), ""); - setenv("FEDERATED_CLIENT_KEY", "foo/client-key.pem", 1); - FederatedCommunicatorFactory factory1{0, nullptr}; - EXPECT_EQ(factory1.GetClientKey(), "foo/client-key.pem"); + dmlc::SetEnv("FEDERATED_CLIENT_KEY", "foo/client-key.pem"); + FederatedCommunicatorFactory factory1{0, nullptr}; + EXPECT_EQ(factory1.GetClientKey(), "foo/client-key.pem"); - char *args[1]; - args[0] = strdup("federated_client_key=bar/client-key.pem"); - FederatedCommunicatorFactory factory2{1, args}; - EXPECT_EQ(factory2.GetClientKey(), "bar/client-key.pem"); + char *args[1]; + args[0] = strdup("federated_client_key=bar/client-key.pem"); + FederatedCommunicatorFactory factory2{1, args}; + EXPECT_EQ(factory2.GetClientKey(), "bar/client-key.pem"); } TEST(FederatedCommunicatorFactoryTest, ClientCert) { - FederatedCommunicatorFactory 
factory0{0, nullptr}; - EXPECT_EQ(factory0.GetClientCert(), ""); + FederatedCommunicatorFactory factory0{0, nullptr}; + EXPECT_EQ(factory0.GetClientCert(), ""); - setenv("FEDERATED_CLIENT_CERT", "foo/client-cert.pem", 1); - FederatedCommunicatorFactory factory1{0, nullptr}; - EXPECT_EQ(factory1.GetClientCert(), "foo/client-cert.pem"); + dmlc::SetEnv("FEDERATED_CLIENT_CERT", "foo/client-cert.pem"); + FederatedCommunicatorFactory factory1{0, nullptr}; + EXPECT_EQ(factory1.GetClientCert(), "foo/client-cert.pem"); - char *args[1]; - args[0] = strdup("federated_client_cert=bar/client-cert.pem"); - FederatedCommunicatorFactory factory2{1, args}; - EXPECT_EQ(factory2.GetClientCert(), "bar/client-cert.pem"); + char *args[1]; + args[0] = strdup("federated_client_cert=bar/client-cert.pem"); + FederatedCommunicatorFactory factory2{1, args}; + EXPECT_EQ(factory2.GetClientCert(), "bar/client-cert.pem"); } } // namespace collective From e758449c0d10c0be7cb7eb4a20f46b1fb884bfe9 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Mon, 15 Aug 2022 12:29:10 -0700 Subject: [PATCH 20/37] fix memory leak --- .../collective/test_communicator_factory.cc | 21 +++++++++---------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/tests/cpp/collective/test_communicator_factory.cc b/tests/cpp/collective/test_communicator_factory.cc index 62473bade099..2b6b39c342a6 100644 --- a/tests/cpp/collective/test_communicator_factory.cc +++ b/tests/cpp/collective/test_communicator_factory.cc @@ -26,21 +26,20 @@ TEST(CommunicatorFactory, TypeFromEnv) { } TEST(CommunicatorFactory, TypeFromArgs) { - char *args[1]; - args[0] = strdup("foo=bar"); - EXPECT_EQ(CommunicatorType::kUnknown, CommunicatorFactory::GetTypeFromArgs(1, args)); + char *args0[] = {(char *)("foo=bar")}; + EXPECT_EQ(CommunicatorType::kUnknown, CommunicatorFactory::GetTypeFromArgs(1, args0)); - args[0] = strdup("xgboost_communicator=rabit"); - EXPECT_EQ(CommunicatorType::kRabit, CommunicatorFactory::GetTypeFromArgs(1, args)); + char 
*args1[] = {(char *)("xgboost_communicator=rabit")}; + EXPECT_EQ(CommunicatorType::kRabit, CommunicatorFactory::GetTypeFromArgs(1, args1)); - args[0] = strdup("xgboost_communicator=MPI"); - EXPECT_EQ(CommunicatorType::kMPI, CommunicatorFactory::GetTypeFromArgs(1, args)); + char *args2[] = {(char *)("xgboost_communicator=MPI")}; + EXPECT_EQ(CommunicatorType::kMPI, CommunicatorFactory::GetTypeFromArgs(1, args2)); - args[0] = strdup("xgboost_communicator=Federated"); - EXPECT_EQ(CommunicatorType::kFederated, CommunicatorFactory::GetTypeFromArgs(1, args)); + char *args3[] = {(char *)("xgboost_communicator=federated")}; + EXPECT_EQ(CommunicatorType::kFederated, CommunicatorFactory::GetTypeFromArgs(1, args3)); - args[0] = strdup("xgboost_communicator=foo"); - EXPECT_THROW(CommunicatorFactory::GetTypeFromArgs(1, args), dmlc::Error); + char *args4[] = {(char *)("xgboost_communicator=foo")}; + EXPECT_THROW(CommunicatorFactory::GetTypeFromArgs(1, args4), dmlc::Error); } } // namespace collective From 89236093eb604dd26f7eb8991331f862532704d0 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Mon, 15 Aug 2022 13:43:22 -0700 Subject: [PATCH 21/37] fix lint errors --- tests/cpp/collective/test_communicator_factory.cc | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/cpp/collective/test_communicator_factory.cc b/tests/cpp/collective/test_communicator_factory.cc index 2b6b39c342a6..b43a9046aab3 100644 --- a/tests/cpp/collective/test_communicator_factory.cc +++ b/tests/cpp/collective/test_communicator_factory.cc @@ -26,19 +26,20 @@ TEST(CommunicatorFactory, TypeFromEnv) { } TEST(CommunicatorFactory, TypeFromArgs) { - char *args0[] = {(char *)("foo=bar")}; + char *args0[] = {(char *)("foo=bar")}; // NOLINT(google-readability-casting) EXPECT_EQ(CommunicatorType::kUnknown, CommunicatorFactory::GetTypeFromArgs(1, args0)); - char *args1[] = {(char *)("xgboost_communicator=rabit")}; + char *args1[] = {(char *)("xgboost_communicator=rabit")}; // 
NOLINT(google-readability-casting) EXPECT_EQ(CommunicatorType::kRabit, CommunicatorFactory::GetTypeFromArgs(1, args1)); - char *args2[] = {(char *)("xgboost_communicator=MPI")}; + char *args2[] = {(char *)("xgboost_communicator=MPI")}; // NOLINT(google-readability-casting) EXPECT_EQ(CommunicatorType::kMPI, CommunicatorFactory::GetTypeFromArgs(1, args2)); - char *args3[] = {(char *)("xgboost_communicator=federated")}; + char *args3[] = { + (char *)("xgboost_communicator=federated")}; // NOLINT(google-readability-casting) EXPECT_EQ(CommunicatorType::kFederated, CommunicatorFactory::GetTypeFromArgs(1, args3)); - char *args4[] = {(char *)("xgboost_communicator=foo")}; + char *args4[] = {(char *)("xgboost_communicator=foo")}; // NOLINT(google-readability-casting) EXPECT_THROW(CommunicatorFactory::GetTypeFromArgs(1, args4), dmlc::Error); } From 3aeae6578995ec8cd36086c73bf6169f0916bd47 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Thu, 25 Aug 2022 10:25:59 -0700 Subject: [PATCH 22/37] address review feedback --- include/xgboost/c_api.h | 10 +++---- python-package/xgboost/collective.py | 40 +++++++++------------------- src/c_api/c_api.cc | 11 +++----- 3 files changed, 21 insertions(+), 40 deletions(-) diff --git a/include/xgboost/c_api.h b/include/xgboost/c_api.h index 31f0990e6db6..d450ad01d59b 100644 --- a/include/xgboost/c_api.h +++ b/include/xgboost/c_api.h @@ -1435,13 +1435,11 @@ XGB_DLL int XGCommunicatorPrint(char const *message); /*! * \brief get name of processor - * \param out_name hold output string - * \param out_len hold length of output string - * \param max_len maximum buffer length of input + * \param name_str pointer to received returned processor name. + * \return 0 for success, -1 for failure */ -XGB_DLL int XGCommunicatorGetProcessorName(char *out_name, - bst_ulong *out_len, - bst_ulong max_len); +XGB_DLL int XGCommunicatorGetProcessorName(const char** name_str); + /*! 
* \brief broadcast an memory region to all others from root * diff --git a/python-package/xgboost/collective.py b/python-package/xgboost/collective.py index 26a4d2e28fc4..862192bc9fe8 100644 --- a/python-package/xgboost/collective.py +++ b/python-package/xgboost/collective.py @@ -1,25 +1,18 @@ """XGBoost collective communication related API.""" import ctypes -from enum import IntEnum, unique import logging import pickle -from typing import Any, TypeVar, Optional, cast, List, Union +from enum import IntEnum, unique +from typing import Any, Optional, cast, List, Union import numpy as np -from .core import _LIB, c_str, _check_call +from ._typing import _T +from .core import _LIB, _check_call, c_str, py_str LOGGER = logging.getLogger("[xgboost.collective]") -def _init_collective() -> None: - """internal library initializer.""" - if _LIB is not None: - _LIB.XGCommunicatorGetRank.restype = ctypes.c_int - _LIB.XGCommunicatorGetWorldSize.restype = ctypes.c_int - _LIB.XGCommunicatorIsDistributed.restype = ctypes.c_int - - def init(args: Optional[List[bytes]] = None) -> None: """Initialize the collective library with arguments""" if args is None: @@ -84,7 +77,7 @@ def communicator_print(msg: Any) -> None: print(msg.strip(), flush=True) -def get_processor_name() -> bytes: +def get_processor_name() -> str: """Get the processor name. Returns @@ -92,17 +85,14 @@ def get_processor_name() -> bytes: name : str the name of processor(host) """ - mxlen = 256 - length = ctypes.c_ulong() - buf = ctypes.create_string_buffer(mxlen) - _LIB.XGCommunicatorGetProcessorName(buf, ctypes.byref(length), mxlen) - return buf.value + name_str = ctypes.c_char_p() + _check_call(_LIB.XGCommunicatorGetProcessorName(ctypes.byref(name_str))) + value = name_str.value + assert value + return py_str(value) -T = TypeVar("T") # pylint:disable=invalid-name - - -def broadcast(data: T, root: int) -> T: +def broadcast(data: _T, root: int) -> _T: """Broadcast object from one node to all other nodes. 
Parameters @@ -161,8 +151,8 @@ class Op(IntEnum): SUM = 2 -def allreduce( # pylint:disable=invalid-name - data: np.ndarray, op: Op +def allreduce( # pylint:disable=invalid-name + data: np.ndarray, op: Op ) -> np.ndarray: """Perform allreduce, return the result. @@ -211,7 +201,3 @@ def __enter__(self) -> None: def __exit__(self, *args: List) -> None: finalize() LOGGER.debug("--------------- communicator say bye ------------------") - - -# initialization script -_init_collective() diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 931ddab2b1ab..7608298a420e 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -1404,14 +1404,11 @@ XGB_DLL int XGCommunicatorPrint(char const *message) { API_END(); } -XGB_DLL int XGCommunicatorGetProcessorName(char *out_name, bst_ulong *out_len, bst_ulong max_len) { +XGB_DLL int XGCommunicatorGetProcessorName(char const **name_str) { API_BEGIN(); - auto s = CommunicatorFactory::GetInstance()->GetCommunicator()->GetProcessorName(); - if (s.length() > max_len) { - s.resize(max_len - 1); - } - s.copy(out_name, s.length()); - *out_len = static_cast(s.length()); + auto& local = *GlobalConfigAPIThreadLocalStore::Get(); + local.ret_str = CommunicatorFactory::GetInstance()->GetCommunicator()->GetProcessorName(); + *name_str = local.ret_str.c_str(); API_END(); } From 183ab75b3a6c47798bc4ee4be8fcd8660bb61f82 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Fri, 26 Aug 2022 17:05:07 -0700 Subject: [PATCH 23/37] add python test for rabit communicator --- plugin/federated/federated_communicator.h | 2 ++ python-package/xgboost/__init__.py | 3 ++ python-package/xgboost/collective.py | 4 +-- src/collective/communicator.h | 2 +- src/collective/communicator_factory.cc | 5 ++- src/collective/communicator_factory.cu | 5 ++- src/collective/rabit_communicator.h | 2 ++ tests/python/test_collective.py | 39 +++++++++++++++++++++++ 8 files changed, 57 insertions(+), 5 deletions(-) create mode 100644 tests/python/test_collective.py diff --git 
a/plugin/federated/federated_communicator.h b/plugin/federated/federated_communicator.h index babace2cdf5a..df5d21aaea5e 100644 --- a/plugin/federated/federated_communicator.h +++ b/plugin/federated/federated_communicator.h @@ -37,6 +37,8 @@ class FederatedCommunicator : public Communicator { ~FederatedCommunicator() override { client_.reset(); } + bool IsDistributed() const override { return true; } + void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type, Operation op) override { std::string const send_buffer(reinterpret_cast(send_receive_buffer), diff --git a/python-package/xgboost/__init__.py b/python-package/xgboost/__init__.py index 6c29de98d9dc..84bd40fb790f 100644 --- a/python-package/xgboost/__init__.py +++ b/python-package/xgboost/__init__.py @@ -5,6 +5,7 @@ from . import rabit # noqa from . import tracker # noqa +from . import collective from . import dask from .core import ( Booster, @@ -63,4 +64,6 @@ "XGBRFRegressor", # dask "dask", + # collective + "collective", ] diff --git a/python-package/xgboost/collective.py b/python-package/xgboost/collective.py index 862192bc9fe8..f6ccd716d75b 100644 --- a/python-package/xgboost/collective.py +++ b/python-package/xgboost/collective.py @@ -19,12 +19,12 @@ def init(args: Optional[List[bytes]] = None) -> None: args = [] arr = (ctypes.c_char_p * len(args))() arr[:] = cast(List[Union[ctypes.c_char_p, bytes, None, int]], args) - _LIB.XGCommunicatorInit(len(arr), arr) + _check_call(_LIB.XGCommunicatorInit(len(arr), arr)) def finalize() -> None: """Finalize the process, notify tracker everything is done.""" - _LIB.XGCommunicatorFinalize() + _check_call(_LIB.XGCommunicatorFinalize()) def get_rank() -> int: diff --git a/src/collective/communicator.h b/src/collective/communicator.h index cfa3d6863d3e..05162f9c4a90 100644 --- a/src/collective/communicator.h +++ b/src/collective/communicator.h @@ -89,7 +89,7 @@ class Communicator { int GetRank() const { return rank_; } /** @brief Whether the 
communicator is running in distributed mode. */ - bool IsDistributed() const { return world_size_ > 1; } + virtual bool IsDistributed() const = 0; /** * @brief Combines values from all processes and distributes the result back to all processes. diff --git a/src/collective/communicator_factory.cc b/src/collective/communicator_factory.cc index fab6f7298346..40072a652ea8 100644 --- a/src/collective/communicator_factory.cc +++ b/src/collective/communicator_factory.cc @@ -28,6 +28,10 @@ void CommunicatorFactory::Init(int argc, char* argv[]) { if (arg != CommunicatorType::kUnknown) { type = arg; } + if (type == CommunicatorType::kUnknown) { + // Default to Rabit if unspecified. + type = CommunicatorType::kRabit; + } switch (type) { case CommunicatorType::kRabit: { RabitCommunicatorFactory factory{argc, argv}; @@ -50,7 +54,6 @@ void CommunicatorFactory::Init(int argc, char* argv[]) { } case CommunicatorType::kUnknown: LOG(FATAL) << "Unknown communicator type."; - break; } } diff --git a/src/collective/communicator_factory.cu b/src/collective/communicator_factory.cu index 1cec0a9045a2..fc5bb4ae2785 100644 --- a/src/collective/communicator_factory.cu +++ b/src/collective/communicator_factory.cu @@ -29,6 +29,10 @@ void CommunicatorFactory::Init(int argc, char* argv[]) { if (arg != CommunicatorType::kUnknown) { type = arg; } + if (type == CommunicatorType::kUnknown) { + // Default to Rabit if unspecified. 
+ type = CommunicatorType::kRabit; + } switch (type) { case CommunicatorType::kRabit: { RabitCommunicatorFactory factory{argc, argv}; @@ -51,7 +55,6 @@ void CommunicatorFactory::Init(int argc, char* argv[]) { } case CommunicatorType::kUnknown: LOG(FATAL) << "Unknown communicator type."; - break; } } diff --git a/src/collective/rabit_communicator.h b/src/collective/rabit_communicator.h index d6e930258fdb..b5adf2954986 100644 --- a/src/collective/rabit_communicator.h +++ b/src/collective/rabit_communicator.h @@ -17,6 +17,8 @@ class RabitCommunicator : public Communicator { ~RabitCommunicator() override { rabit::Finalize(); } + bool IsDistributed() const override { return rabit::IsDistributed(); } + void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type, Operation op) override { switch (data_type) { diff --git a/tests/python/test_collective.py b/tests/python/test_collective.py new file mode 100644 index 000000000000..74810bbe3afa --- /dev/null +++ b/tests/python/test_collective.py @@ -0,0 +1,39 @@ +import multiprocessing +import socket + +import numpy as np + +import xgboost as xgb +from xgboost import RabitTracker + + +def run_rabit_worker(rabit_env, world_size, rank): + with xgb.collective.CommunicatorContext(rabit_env): + assert xgb.collective.get_world_size() == world_size + assert xgb.collective.get_rank() == rank + assert xgb.collective.is_distributed() + assert xgb.collective.get_processor_name() == socket.gethostname() + ret = xgb.collective.broadcast('test1234', 0) + assert str(ret) == 'test1234' + ret = xgb.collective.allreduce(np.asarray([1, 2, 3]), xgb.collective.Op.SUM) + assert np.array_equal(ret, np.asarray([2, 4, 6])) + + +def test_rabit_communicator(): + world_size = 2 + tracker = RabitTracker(host_ip='127.0.0.1', n_workers=world_size) + tracker.start(world_size) + worker_env = tracker.worker_envs() + rabit_env = [] + for k, v in worker_env.items(): + rabit_env.append(f"{k}={v}".encode()) + + workers = [] + for rank in 
reversed(range(world_size)): + worker = multiprocessing.Process(target=run_rabit_worker, + args=(rabit_env, world_size, rank)) + workers.append(worker) + worker.start() + for worker in workers: + worker.join() + assert worker.exitcode == 0 From e3c87e0780598950c64287070cae0a20dc3833ad Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Fri, 26 Aug 2022 18:46:58 -0700 Subject: [PATCH 24/37] fix failing gtest --- tests/cpp/collective/test_rabit_communicator.cc | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/tests/cpp/collective/test_rabit_communicator.cc b/tests/cpp/collective/test_rabit_communicator.cc index 2b7dc17edfcb..ba22d8fdb84f 100644 --- a/tests/cpp/collective/test_rabit_communicator.cc +++ b/tests/cpp/collective/test_rabit_communicator.cc @@ -30,13 +30,9 @@ TEST(RabitCommunicatorSimpleTest, GetWorldSizeAndRank) { } TEST(RabitCommunicatorSimpleTest, IsNotDistributed) { - RabitCommunicator comm{1, 0}; - EXPECT_FALSE(comm.IsDistributed()); -} - -TEST(RabitCommunicatorSimpleTest, IsDistributed) { RabitCommunicator comm{2, 1}; - EXPECT_TRUE(comm.IsDistributed()); + // Rabit is only distributed with a tracker. 
+ EXPECT_FALSE(comm.IsDistributed()); } } // namespace collective From fcfb1d5a9e3e87d7a6b17c189ee14c8f5bb0456a Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Tue, 30 Aug 2022 16:56:46 -0700 Subject: [PATCH 25/37] use json to configure communicators --- include/xgboost/c_api.h | 50 +++++++++++--- plugin/federated/federated_communicator.h | 35 +++++++--- python-package/xgboost/collective.py | 67 +++++++++++++++---- src/c_api/c_api.cc | 5 +- src/collective/communicator_factory.cc | 8 +-- src/collective/communicator_factory.cu | 8 +-- src/collective/communicator_factory.h | 29 ++++---- src/collective/rabit_communicator.h | 32 ++++++++- .../collective/test_communicator_factory.cc | 46 ++++++++----- .../cpp/plugin/test_federated_communicator.cc | 65 +++++++++--------- tests/python/test_collective.py | 9 +-- 11 files changed, 240 insertions(+), 114 deletions(-) diff --git a/include/xgboost/c_api.h b/include/xgboost/c_api.h index d450ad01d59b..3d63b93ac322 100644 --- a/include/xgboost/c_api.h +++ b/include/xgboost/c_api.h @@ -1390,15 +1390,49 @@ XGB_DLL int XGBoosterFeatureScore(BoosterHandle handle, const char *json_config, float const **out_scores); /*! - * \brief initialize the collective communicator, - * call this once before using anything - * The additional arguments is not necessary. - * Usually the communicator will detect settings + * \brief Initialize the collective communicator. + * + * Call this once before using anything. + * + * The additional configuration is not required. Usually the communicator will detect settings * from environment variables. - * \param argc number of arguments in argv - * \param argv the array of input arguments - */ -XGB_DLL int XGCommunicatorInit(int argc, char *argv[]); + * + * \param json_config JSON encoded configuration. Accepted JSON keys are: + * - xgboost_communicator: The type of the communicator. Can be set as an environment variable. + * * rabit: Use Rabit. This is the default if the type is unspecified. + * * mpi: Use MPI. 
+ * * federated: Use the gRPC interface for Federated Learning. + * Only applicable to the Rabit communicator (these are case-sensitive): + * - rabit_tracker_uri: Hostname of the tracker. + * - rabit_tracker_port: Port number of the tracker. + * - rabit_task_id: ID of the current task, can be used to obtain deterministic rank assignment. + * - rabit_world_size: Total number of workers. + * - rabit_hadoop_mode: Enable Hadoop support. + * - rabit_tree_reduce_minsize: Minimal size for tree reduce. + * - rabit_reduce_ring_mincount: Minimal count to perform ring reduce. + * - rabit_reduce_buffer: Size of the reduce buffer. + * - rabit_bootstrap_cache: Size of the bootstrap cache. + * - rabit_debug: Enable debugging. + * - rabit_timeout: Enable timeout. + * - rabit_timeout_sec: Timeout in seconds. + * - rabit_enable_tcp_no_delay: Enable TCP no delay on Unix platforms. + * Only applicable to the Rabit communicator (these are case-sensitive, and can be set as + * environment variables): + * - DMLC_TRACKER_URI: Hostname of the tracker. + * - DMLC_TRACKER_PORT: Port number of the tracker. + * - DMLC_TASK_ID: ID of the current task, can be used to obtain deterministic rank assignment. + * - DMLC_ROLE: Role of the current task, "worker" or "server". + * - DMLC_NUM_ATTEMPT: Number of attempts after task failure. + * - DMLC_WORKER_CONNECT_RETRY: Number of retries to connect to the tracker. + * Only applicable to the Federated communicator (these are not case-sensitive): + * - federated_server_address: Address of the federated server. + * - federated_world_size: Number of federated workers. + * - federated_rank: Rank of the current worker. + * - federated_server_cert: Server certificate file path. Only needed for the SSL mode. + * - federated_client_key: Client key file path. Only needed for the SSL mode. + * - federated_client_cert: Client certificate file path. Only needed for the SSL mode. + */ +XGB_DLL int XGCommunicatorInit(char const* json_config); /*! 
* \brief finalize the collective communicator, diff --git a/plugin/federated/federated_communicator.h b/plugin/federated/federated_communicator.h index df5d21aaea5e..0fcd192d9d63 100644 --- a/plugin/federated/federated_communicator.h +++ b/plugin/federated/federated_communicator.h @@ -2,6 +2,8 @@ * Copyright 2022 XGBoost contributors */ #pragma once +#include + #include "../../src/collective/communicator.h" #include "../../src/common/io.h" #include "federated_client.h" @@ -70,7 +72,7 @@ class FederatedCommunicator : public Communicator { class FederatedCommunicatorFactory { public: - FederatedCommunicatorFactory(int argc, char *argv[]) { + explicit FederatedCommunicatorFactory(Json const &config) { // Parse environment variables first. for (auto const &env_var : env_vars_) { char const *value = getenv(env_var.c_str()); @@ -79,13 +81,30 @@ class FederatedCommunicatorFactory { } } - // Command line argument overrides. - for (int i = 0; i < argc; ++i) { - std::string const key_value = argv[i]; - auto const delimiter = key_value.find('='); - if (delimiter != std::string::npos) { - SetParam(key_value.substr(0, delimiter), key_value.substr(delimiter + 1)); - } + // Runtime configuration overrides. 
+ auto const &j_server_address = config["federated_server_address"]; + if (IsA(j_server_address)) { + server_address_ = get(j_server_address); + } + auto const &j_world_size = config["federated_world_size"]; + if (IsA(j_world_size)) { + world_size_ = static_cast(get(j_world_size)); + } + auto const &j_rank = config["federated_rank"]; + if (IsA(j_rank)) { + rank_ = static_cast(get(j_rank)); + } + auto const &j_server_cert = config["federated_server_cert"]; + if (IsA(j_server_cert)) { + server_cert_ = get(j_server_cert); + } + auto const &j_client_key = config["federated_client_key"]; + if (IsA(j_client_key)) { + client_key_ = get(j_client_key); + } + auto const &j_client_cert = config["federated_client_cert"]; + if (IsA(j_client_cert)) { + client_cert_ = get(j_client_cert); } } diff --git a/python-package/xgboost/collective.py b/python-package/xgboost/collective.py index f6ccd716d75b..829184970d32 100644 --- a/python-package/xgboost/collective.py +++ b/python-package/xgboost/collective.py @@ -1,25 +1,68 @@ """XGBoost collective communication related API.""" import ctypes +import json import logging import pickle from enum import IntEnum, unique -from typing import Any, Optional, cast, List, Union +from typing import Any, List import numpy as np from ._typing import _T -from .core import _LIB, _check_call, c_str, py_str +from .core import _LIB, _check_call, c_str, py_str, from_pystr_to_cstr LOGGER = logging.getLogger("[xgboost.collective]") -def init(args: Optional[List[bytes]] = None) -> None: - """Initialize the collective library with arguments""" - if args is None: - args = [] - arr = (ctypes.c_char_p * len(args))() - arr[:] = cast(List[Union[ctypes.c_char_p, bytes, None, int]], args) - _check_call(_LIB.XGCommunicatorInit(len(arr), arr)) +def init(**args: Any) -> None: + """Initialize the collective library with arguments. + + Parameters + ---------- + args: Dict[str, Any] + Keyword arguments representing the parameters and their values. 
+ + Accepted parameters: + - xgboost_communicator: The type of the communicator. Can be set as an environment + variable. + * rabit: Use Rabit. This is the default if the type is unspecified. + * mpi: Use MPI. + * federated: Use the gRPC interface for Federated Learning. + Only applicable to the Rabit communicator (these are case sensitive): + -- rabit_tracker_uri: Hostname of the tracker. + -- rabit_tracker_port: Port number of the tracker. + -- rabit_task_id: ID of the current task, can be used to obtain deterministic rank + assignment. + -- rabit_world_size: Total number of workers. + -- rabit_hadoop_mode: Enable Hadoop support. + -- rabit_tree_reduce_minsize: Minimal size for tree reduce. + -- rabit_reduce_ring_mincount: Minimal count to perform ring reduce. + -- rabit_reduce_buffer: Size of the reduce buffer. + -- rabit_bootstrap_cache: Size of the bootstrap cache. + -- rabit_debug: Enable debugging. + -- rabit_timeout: Enable timeout. + -- rabit_timeout_sec: Timeout in seconds. + -- rabit_enable_tcp_no_delay: Enable TCP no delay on Unix platforms. + Only applicable to the Rabit communicator (these are case-sensitive, and can be set as + environment variables): + -- DMLC_TRACKER_URI: Hostname of the tracker. + -- DMLC_TRACKER_PORT: Port number of the tracker. + -- DMLC_TASK_ID: ID of the current task, can be used to obtain deterministic rank + assignment. + -- DMLC_ROLE: Role of the current task, "worker" or "server". + -- DMLC_NUM_ATTEMPT: Number of attempts after task failure. + -- DMLC_WORKER_CONNECT_RETRY: Number of retries to connect to the tracker. + Only applicable to the Federated communicator (use upper case for environment variables, use + lower case for runtime configuration): + -- federated_server_address: Address of the federated server. + -- federated_world_size: Number of federated workers. + -- federated_rank: Rank of the current worker. + -- federated_server_cert: Server certificate file path. Only needed for the SSL mode. 
+ -- federated_client_key: Client key file path. Only needed for the SSL mode. + -- federated_client_cert: Client certificate file path. Only needed for the SSL mode. + """ + config = from_pystr_to_cstr(json.dumps(args)) + _check_call(_LIB.XGCommunicatorInit(config)) def finalize() -> None: @@ -188,13 +231,11 @@ def allreduce( # pylint:disable=invalid-name class CommunicatorContext: """A context controlling collective communicator initialization and finalization.""" - def __init__(self, args: List[bytes] = None) -> None: - if args is None: - args = [] + def __init__(self, **args: Any) -> None: self.args = args def __enter__(self) -> None: - init(self.args) + init(**self.args) assert is_distributed() LOGGER.debug("-------------- communicator say hello ------------------") diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 7608298a420e..535213fe82ef 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -1374,9 +1374,10 @@ XGB_DLL int XGBoosterFeatureScore(BoosterHandle handle, char const *json_config, using xgboost::collective::CommunicatorFactory; -XGB_DLL int XGCommunicatorInit(int argc, char *argv[]) { +XGB_DLL int XGCommunicatorInit(char const* json_config) { API_BEGIN(); - CommunicatorFactory::Init(argc, argv); + Json config { Json::Load(StringView{json_config}) }; + CommunicatorFactory::Init(config); API_END(); } diff --git a/src/collective/communicator_factory.cc b/src/collective/communicator_factory.cc index 40072a652ea8..44f6df9d3fd8 100644 --- a/src/collective/communicator_factory.cc +++ b/src/collective/communicator_factory.cc @@ -18,13 +18,13 @@ thread_local std::unique_ptr CommunicatorFactory::instance_ CommunicatorFactory::CommunicatorFactory(CommunicatorType type, Communicator* communicator) : type_{type}, communicator_{communicator} {} -void CommunicatorFactory::Init(int argc, char* argv[]) { +void CommunicatorFactory::Init(Json const& config) { if (instance_) { LOG(FATAL) << "Communicator factory can only be initialized once."; } auto 
type = GetTypeFromEnv(); - auto const arg = GetTypeFromArgs(argc, argv); + auto const arg = GetTypeFromConfig(config); if (arg != CommunicatorType::kUnknown) { type = arg; } @@ -34,7 +34,7 @@ void CommunicatorFactory::Init(int argc, char* argv[]) { } switch (type) { case CommunicatorType::kRabit: { - RabitCommunicatorFactory factory{argc, argv}; + RabitCommunicatorFactory factory{config}; auto* comm = factory.Create(); instance_.reset(new CommunicatorFactory(type, comm)); break; @@ -44,7 +44,7 @@ void CommunicatorFactory::Init(int argc, char* argv[]) { break; case CommunicatorType::kFederated: { #if defined(XGBOOST_USE_FEDERATED) - FederatedCommunicatorFactory factory{argc, argv}; + FederatedCommunicatorFactory factory{config}; auto* comm = factory.Create(); instance_.reset(new CommunicatorFactory(type, comm)); #else diff --git a/src/collective/communicator_factory.cu b/src/collective/communicator_factory.cu index fc5bb4ae2785..ae194c925af3 100644 --- a/src/collective/communicator_factory.cu +++ b/src/collective/communicator_factory.cu @@ -19,13 +19,13 @@ thread_local std::unique_ptr CommunicatorFactory::instance_ CommunicatorFactory::CommunicatorFactory(CommunicatorType type, Communicator* communicator) : type_{type}, communicator_{communicator} {} -void CommunicatorFactory::Init(int argc, char* argv[]) { +void CommunicatorFactory::Init(Json const& config) { if (instance_) { LOG(FATAL) << "Communicator factory can only be initialized once."; } auto type = GetTypeFromEnv(); - auto const arg = GetTypeFromArgs(argc, argv); + auto const arg = GetTypeFromConfig(config); if (arg != CommunicatorType::kUnknown) { type = arg; } @@ -35,7 +35,7 @@ void CommunicatorFactory::Init(int argc, char* argv[]) { } switch (type) { case CommunicatorType::kRabit: { - RabitCommunicatorFactory factory{argc, argv}; + RabitCommunicatorFactory factory{config}; auto* comm = factory.Create(); instance_.reset(new CommunicatorFactory(type, comm)); break; @@ -45,7 +45,7 @@ void 
CommunicatorFactory::Init(int argc, char* argv[]) { break; case CommunicatorType::kFederated: { #if defined(XGBOOST_USE_FEDERATED) - FederatedCommunicatorFactory factory{argc, argv}; + FederatedCommunicatorFactory factory{config}; auto* comm = factory.Create(); instance_.reset(new CommunicatorFactory(type, comm)); #else diff --git a/src/collective/communicator_factory.h b/src/collective/communicator_factory.h index d4568d95be5a..632a0f3a9198 100644 --- a/src/collective/communicator_factory.h +++ b/src/collective/communicator_factory.h @@ -2,6 +2,8 @@ * Copyright 2022 XGBoost contributors */ #pragma once +#include + #include #include @@ -16,9 +18,7 @@ class DeviceCommunicator; class CommunicatorFactory { public: - static constexpr char const* kCommunicatorKey = "XGBOOST_COMMUNICATOR"; - - static void Init(int argc, char* argv[]); + static void Init(Json const& config); static void Finalize(); @@ -32,7 +32,7 @@ class CommunicatorFactory { /** @brief Get the communicator type from environment variables. Visible for testing. */ static CommunicatorType GetTypeFromEnv() { - auto* env = std::getenv(kCommunicatorKey); + auto* env = std::getenv("XGBOOST_COMMUNICATOR"); if (env != nullptr) { return StringToType(env); } else { @@ -40,18 +40,15 @@ class CommunicatorFactory { } } - /** @brief Get the communicator type from arguments. Visible for testing. */ - static CommunicatorType GetTypeFromArgs(int argc, char* argv[]) { - for (int i = 0; i < argc; ++i) { - std::string const key_value = argv[i]; - auto const delimiter = key_value.find('='); - if (delimiter != std::string::npos) { - auto const key = key_value.substr(0, delimiter); - auto const value = key_value.substr(delimiter + 1); - if (!CompareStringsCaseInsensitive(key.c_str(), kCommunicatorKey)) { - return StringToType(value.c_str()); - } - } + /** @brief Get the communicator type from runtime configuration. Visible for testing. 
*/ + static CommunicatorType GetTypeFromConfig(Json const& config) { + auto const& j_upper = config["XGBOOST_COMMUNICATOR"]; + if (IsA(j_upper)) { + return StringToType(get(j_upper).c_str()); + } + auto const& j_lower = config["xgboost_communicator"]; + if (IsA(j_lower)) { + return StringToType(get(j_lower).c_str()); } return CommunicatorType::kUnknown; } diff --git a/src/collective/rabit_communicator.h b/src/collective/rabit_communicator.h index b5adf2954986..afe84cd5f0e2 100644 --- a/src/collective/rabit_communicator.h +++ b/src/collective/rabit_communicator.h @@ -7,6 +7,7 @@ #include #include "communicator.h" +#include "xgboost/json.h" namespace xgboost { namespace collective { @@ -77,8 +78,35 @@ class RabitCommunicator : public Communicator { class RabitCommunicatorFactory { public: - RabitCommunicatorFactory(int argc, char *argv[]) { - rabit::Init(argc, argv); + explicit RabitCommunicatorFactory(Json const &config) { + std::vector args_str; + for (auto &items : get(config)) { + switch (items.second.GetValue().Type()) { + case xgboost::Value::ValueKind::kString: { + args_str.push_back(items.first + "=" + get(items.second)); + break; + } + case xgboost::Value::ValueKind::kInteger: { + args_str.push_back(items.first + "=" + std::to_string(get(items.second))); + break; + } + case xgboost::Value::ValueKind::kBoolean: { + if (get(items.second)) { + args_str.push_back(items.first + "=1"); + } else { + args_str.push_back(items.first + "=0"); + } + break; + } + default: + break; + } + } + std::vector args; + for (auto &key_value : args_str) { + args.push_back(&key_value[0]); + } + rabit::Init(static_cast(args.size()), &args[0]); world_size_ = rabit::GetWorldSize(); rank_ = rabit::GetRank(); } diff --git a/tests/cpp/collective/test_communicator_factory.cc b/tests/cpp/collective/test_communicator_factory.cc index b43a9046aab3..ec94a42870cb 100644 --- a/tests/cpp/collective/test_communicator_factory.cc +++ b/tests/cpp/collective/test_communicator_factory.cc @@ -12,35 
+12,51 @@ namespace collective { TEST(CommunicatorFactory, TypeFromEnv) { EXPECT_EQ(CommunicatorType::kUnknown, CommunicatorFactory::GetTypeFromEnv()); - dmlc::SetEnv(CommunicatorFactory::kCommunicatorKey, "rabit"); + dmlc::SetEnv("XGBOOST_COMMUNICATOR", "rabit"); EXPECT_EQ(CommunicatorType::kRabit, CommunicatorFactory::GetTypeFromEnv()); - dmlc::SetEnv(CommunicatorFactory::kCommunicatorKey, "MPI"); + dmlc::SetEnv("XGBOOST_COMMUNICATOR", "MPI"); EXPECT_EQ(CommunicatorType::kMPI, CommunicatorFactory::GetTypeFromEnv()); - dmlc::SetEnv(CommunicatorFactory::kCommunicatorKey, "Federated"); + dmlc::SetEnv("XGBOOST_COMMUNICATOR", "Federated"); EXPECT_EQ(CommunicatorType::kFederated, CommunicatorFactory::GetTypeFromEnv()); - dmlc::SetEnv(CommunicatorFactory::kCommunicatorKey, "foo"); + dmlc::SetEnv("XGBOOST_COMMUNICATOR", "foo"); EXPECT_THROW(CommunicatorFactory::GetTypeFromEnv(), dmlc::Error); } TEST(CommunicatorFactory, TypeFromArgs) { - char *args0[] = {(char *)("foo=bar")}; // NOLINT(google-readability-casting) - EXPECT_EQ(CommunicatorType::kUnknown, CommunicatorFactory::GetTypeFromArgs(1, args0)); + Json config{JsonObject()}; + EXPECT_EQ(CommunicatorType::kUnknown, CommunicatorFactory::GetTypeFromConfig(config)); - char *args1[] = {(char *)("xgboost_communicator=rabit")}; // NOLINT(google-readability-casting) - EXPECT_EQ(CommunicatorType::kRabit, CommunicatorFactory::GetTypeFromArgs(1, args1)); + config["xgboost_communicator"] = String("rabit"); + EXPECT_EQ(CommunicatorType::kRabit, CommunicatorFactory::GetTypeFromConfig(config)); - char *args2[] = {(char *)("xgboost_communicator=MPI")}; // NOLINT(google-readability-casting) - EXPECT_EQ(CommunicatorType::kMPI, CommunicatorFactory::GetTypeFromArgs(1, args2)); + config["xgboost_communicator"] = String("MPI"); + EXPECT_EQ(CommunicatorType::kMPI, CommunicatorFactory::GetTypeFromConfig(config)); - char *args3[] = { - (char *)("xgboost_communicator=federated")}; // NOLINT(google-readability-casting) - 
EXPECT_EQ(CommunicatorType::kFederated, CommunicatorFactory::GetTypeFromArgs(1, args3)); + config["xgboost_communicator"] = String("federated"); + EXPECT_EQ(CommunicatorType::kFederated, CommunicatorFactory::GetTypeFromConfig(config)); - char *args4[] = {(char *)("xgboost_communicator=foo")}; // NOLINT(google-readability-casting) - EXPECT_THROW(CommunicatorFactory::GetTypeFromArgs(1, args4), dmlc::Error); + config["xgboost_communicator"] = String("foo"); + EXPECT_THROW(CommunicatorFactory::GetTypeFromConfig(config), dmlc::Error); +} + +TEST(CommunicatorFactory, TypeFromArgsUpperCase) { + Json config{JsonObject()}; + EXPECT_EQ(CommunicatorType::kUnknown, CommunicatorFactory::GetTypeFromConfig(config)); + + config["XGBOOST_COMMUNICATOR"] = String("rabit"); + EXPECT_EQ(CommunicatorType::kRabit, CommunicatorFactory::GetTypeFromConfig(config)); + + config["XGBOOST_COMMUNICATOR"] = String("MPI"); + EXPECT_EQ(CommunicatorType::kMPI, CommunicatorFactory::GetTypeFromConfig(config)); + + config["XGBOOST_COMMUNICATOR"] = String("federated"); + EXPECT_EQ(CommunicatorType::kFederated, CommunicatorFactory::GetTypeFromConfig(config)); + + config["XGBOOST_COMMUNICATOR"] = String("foo"); + EXPECT_THROW(CommunicatorFactory::GetTypeFromConfig(config), dmlc::Error); } } // namespace collective diff --git a/tests/cpp/plugin/test_federated_communicator.cc b/tests/cpp/plugin/test_federated_communicator.cc index 2bf9d5102284..225f49ed5235 100644 --- a/tests/cpp/plugin/test_federated_communicator.cc +++ b/tests/cpp/plugin/test_federated_communicator.cc @@ -91,11 +91,6 @@ TEST(FederatedCommunicatorSimpleTest, GetWorldSizeAndRank) { EXPECT_EQ(comm.GetRank(), 3); } -TEST(FederatedCommunicatorSimpleTest, IsNotDistributed) { - FederatedCommunicator comm{1, 0, kServerAddress}; - EXPECT_FALSE(comm.IsDistributed()); -} - TEST(FederatedCommunicatorSimpleTest, IsDistributed) { FederatedCommunicator comm{2, 1, kServerAddress}; EXPECT_TRUE(comm.IsDistributed()); @@ -122,86 +117,86 @@ 
TEST_F(FederatedCommunicatorTest, Broadcast) { } TEST(FederatedCommunicatorFactoryTest, ServerAddress) { - FederatedCommunicatorFactory factory0{0, nullptr}; + Json config{JsonObject()}; + FederatedCommunicatorFactory factory0{config}; EXPECT_EQ(factory0.GetServerAddress(), ""); dmlc::SetEnv("FEDERATED_SERVER_ADDRESS", "localhost:9091"); - FederatedCommunicatorFactory factory1{0, nullptr}; + FederatedCommunicatorFactory factory1{config}; EXPECT_EQ(factory1.GetServerAddress(), "localhost:9091"); - char *args[1]; - args[0] = strdup("federated_server_address=foo:9091"); - FederatedCommunicatorFactory factory2{1, args}; + config["federated_server_address"] = String("foo:9091"); + FederatedCommunicatorFactory factory2{config}; EXPECT_EQ(factory2.GetServerAddress(), "foo:9091"); } TEST(FederatedCommunicatorFactoryTest, WorldSize) { - FederatedCommunicatorFactory factory0{0, nullptr}; + Json config{JsonObject()}; + FederatedCommunicatorFactory factory0{config}; EXPECT_EQ(factory0.GetWorldSize(), 0); dmlc::SetEnv("FEDERATED_WORLD_SIZE", "2"); - FederatedCommunicatorFactory factory1{0, nullptr}; + FederatedCommunicatorFactory factory1{config}; EXPECT_EQ(factory1.GetWorldSize(), 2); - char *args[1]; - args[0] = strdup("federated_world_size=3"); - FederatedCommunicatorFactory factory2{1, args}; + config["federated_world_size"] = Integer(3); + FederatedCommunicatorFactory factory2{config}; EXPECT_EQ(factory2.GetWorldSize(), 3); } TEST(FederatedCommunicatorFactoryTest, Rank) { - FederatedCommunicatorFactory factory0{0, nullptr}; + Json config{JsonObject()}; + FederatedCommunicatorFactory factory0{config}; EXPECT_EQ(factory0.GetRank(), -1); dmlc::SetEnv("FEDERATED_RANK", "1"); - FederatedCommunicatorFactory factory1{0, nullptr}; + FederatedCommunicatorFactory factory1{config}; EXPECT_EQ(factory1.GetRank(), 1); - char *args[1]; - args[0] = strdup("federated_rank=2"); - FederatedCommunicatorFactory factory2{1, args}; + config["federated_rank"] = Integer(2); + 
FederatedCommunicatorFactory factory2{config}; EXPECT_EQ(factory2.GetRank(), 2); } TEST(FederatedCommunicatorFactoryTest, ServerCert) { - FederatedCommunicatorFactory factory0{0, nullptr}; + Json config{JsonObject()}; + FederatedCommunicatorFactory factory0{config}; EXPECT_EQ(factory0.GetServerCert(), ""); dmlc::SetEnv("FEDERATED_SERVER_CERT", "foo/server-cert.pem"); - FederatedCommunicatorFactory factory1{0, nullptr}; + FederatedCommunicatorFactory factory1{config}; EXPECT_EQ(factory1.GetServerCert(), "foo/server-cert.pem"); - char *args[1]; - args[0] = strdup("federated_server_cert=bar/server-cert.pem"); - FederatedCommunicatorFactory factory2{1, args}; + config["federated_server_cert"] = String("bar/server-cert.pem"); + FederatedCommunicatorFactory factory2{config}; EXPECT_EQ(factory2.GetServerCert(), "bar/server-cert.pem"); } TEST(FederatedCommunicatorFactoryTest, ClientKey) { - FederatedCommunicatorFactory factory0{0, nullptr}; + Json config{JsonObject()}; + FederatedCommunicatorFactory factory0{config}; EXPECT_EQ(factory0.GetClientKey(), ""); dmlc::SetEnv("FEDERATED_CLIENT_KEY", "foo/client-key.pem"); - FederatedCommunicatorFactory factory1{0, nullptr}; + FederatedCommunicatorFactory factory1{config}; EXPECT_EQ(factory1.GetClientKey(), "foo/client-key.pem"); - char *args[1]; - args[0] = strdup("federated_client_key=bar/client-key.pem"); - FederatedCommunicatorFactory factory2{1, args}; + config["federated_client_key"] = String("bar/client-key.pem"); + FederatedCommunicatorFactory factory2{config}; EXPECT_EQ(factory2.GetClientKey(), "bar/client-key.pem"); } TEST(FederatedCommunicatorFactoryTest, ClientCert) { - FederatedCommunicatorFactory factory0{0, nullptr}; + Json config{JsonObject()}; + FederatedCommunicatorFactory factory0{config}; EXPECT_EQ(factory0.GetClientCert(), ""); dmlc::SetEnv("FEDERATED_CLIENT_CERT", "foo/client-cert.pem"); - FederatedCommunicatorFactory factory1{0, nullptr}; + FederatedCommunicatorFactory factory1{config}; 
EXPECT_EQ(factory1.GetClientCert(), "foo/client-cert.pem"); - char *args[1]; - args[0] = strdup("federated_client_cert=bar/client-cert.pem"); - FederatedCommunicatorFactory factory2{1, args}; + config["federated_client_cert"] = String("bar/client-cert.pem"); + FederatedCommunicatorFactory factory2{config}; EXPECT_EQ(factory2.GetClientCert(), "bar/client-cert.pem"); } diff --git a/tests/python/test_collective.py b/tests/python/test_collective.py index 74810bbe3afa..01e2b01fb6e7 100644 --- a/tests/python/test_collective.py +++ b/tests/python/test_collective.py @@ -8,7 +8,7 @@ def run_rabit_worker(rabit_env, world_size, rank): - with xgb.collective.CommunicatorContext(rabit_env): + with xgb.collective.CommunicatorContext(**rabit_env): assert xgb.collective.get_world_size() == world_size assert xgb.collective.get_rank() == rank assert xgb.collective.is_distributed() @@ -23,15 +23,10 @@ def test_rabit_communicator(): world_size = 2 tracker = RabitTracker(host_ip='127.0.0.1', n_workers=world_size) tracker.start(world_size) - worker_env = tracker.worker_envs() - rabit_env = [] - for k, v in worker_env.items(): - rabit_env.append(f"{k}={v}".encode()) - workers = [] for rank in reversed(range(world_size)): worker = multiprocessing.Process(target=run_rabit_worker, - args=(rabit_env, world_size, rank)) + args=(tracker.worker_envs(), world_size, rank)) workers.append(worker) worker.start() for worker in workers: From b57a1be4a09f532f3c7b837643b189265295f6c6 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Tue, 30 Aug 2022 17:17:34 -0700 Subject: [PATCH 26/37] fix lint error --- src/collective/rabit_communicator.h | 1 + 1 file changed, 1 insertion(+) diff --git a/src/collective/rabit_communicator.h b/src/collective/rabit_communicator.h index afe84cd5f0e2..78f31f2d3a32 100644 --- a/src/collective/rabit_communicator.h +++ b/src/collective/rabit_communicator.h @@ -5,6 +5,7 @@ #include #include +#include #include "communicator.h" #include "xgboost/json.h" From 
ba5a6e1a24a3a17f0519cc328c8eb4dcabebf868 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Thu, 1 Sep 2022 16:04:33 -0700 Subject: [PATCH 27/37] get rid of factories --- amalgamation/xgboost-all0.cc | 2 +- plugin/federated/federated_communicator.h | 187 ++++++++---------- src/c_api/c_api.cc | 21 +- ...ommunicator_factory.cc => communicator.cc} | 27 ++- src/collective/communicator.cu | 40 ++++ src/collective/communicator.h | 119 ++++++++--- src/collective/communicator_factory.cu | 80 -------- src/collective/communicator_factory.h | 83 -------- src/collective/rabit_communicator.h | 75 +++---- tests/cpp/collective/test_communicator.cc | 63 ++++++ .../collective/test_communicator_factory.cc | 63 ------ .../cpp/plugin/test_federated_communicator.cc | 85 -------- 12 files changed, 333 insertions(+), 512 deletions(-) rename src/collective/{communicator_factory.cc => communicator.cc} (57%) create mode 100644 src/collective/communicator.cu delete mode 100644 src/collective/communicator_factory.cu delete mode 100644 src/collective/communicator_factory.h create mode 100644 tests/cpp/collective/test_communicator.cc delete mode 100644 tests/cpp/collective/test_communicator_factory.cc diff --git a/amalgamation/xgboost-all0.cc b/amalgamation/xgboost-all0.cc index 1dc154fb9f02..ded96bcbab4c 100644 --- a/amalgamation/xgboost-all0.cc +++ b/amalgamation/xgboost-all0.cc @@ -72,7 +72,7 @@ #include "../src/global_config.cc" // collective -#include "../src/collective/communicator_factory.cc" +#include "../src/collective/communicator.cc" // common #include "../src/common/common.cc" diff --git a/plugin/federated/federated_communicator.h b/plugin/federated/federated_communicator.h index 0fcd192d9d63..1ae5cafba249 100644 --- a/plugin/federated/federated_communicator.h +++ b/plugin/federated/federated_communicator.h @@ -16,6 +16,79 @@ namespace collective { */ class FederatedCommunicator : public Communicator { public: + static Communicator *Create(Json const &config) { + std::string 
server_address{}; + int world_size{0}; + int rank{-1}; + std::string server_cert{}; + std::string client_key{}; + std::string client_cert{}; + + // Parse environment variables first. + auto *value = getenv("FEDERATED_SERVER_ADDRESS"); + if (value != nullptr) { + server_address = value; + } + value = getenv("FEDERATED_WORLD_SIZE"); + if (value != nullptr) { + world_size = std::stoi(value); + } + value = getenv("FEDERATED_RANK"); + if (value != nullptr) { + rank = std::stoi(value); + } + value = getenv("FEDERATED_SERVER_CERT"); + if (value != nullptr) { + server_cert = value; + } + value = getenv("FEDERATED_CLIENT_KEY"); + if (value != nullptr) { + client_key = value; + } + value = getenv("FEDERATED_CLIENT_CERT"); + if (value != nullptr) { + client_cert = value; + } + + // Runtime configuration overrides. + auto const &j_server_address = config["federated_server_address"]; + if (IsA(j_server_address)) { + server_address = get(j_server_address); + } + auto const &j_world_size = config["federated_world_size"]; + if (IsA(j_world_size)) { + world_size = static_cast(get(j_world_size)); + } + auto const &j_rank = config["federated_rank"]; + if (IsA(j_rank)) { + rank = static_cast(get(j_rank)); + } + auto const &j_server_cert = config["federated_server_cert"]; + if (IsA(j_server_cert)) { + server_cert = get(j_server_cert); + } + auto const &j_client_key = config["federated_client_key"]; + if (IsA(j_client_key)) { + client_key = get(j_client_key); + } + auto const &j_client_cert = config["federated_client_cert"]; + if (IsA(j_client_cert)) { + client_cert = get(j_client_cert); + } + + if (server_address.empty()) { + LOG(FATAL) << "Federated server address must be set."; + } + if (world_size == 0) { + LOG(FATAL) << "Federated world size must be set."; + } + if (rank == -1) { + LOG(FATAL) << "Federated rank must be set."; + } + return new FederatedCommunicator(world_size, rank, server_address, server_cert, client_key, + client_cert); + } + /** * @brief Construct a new federated 
communicator. * @@ -26,9 +99,13 @@ class FederatedCommunicator : public Communicator { std::string const &server_cert_path, std::string const &client_key_path, std::string const &client_cert_path) : Communicator{world_size, rank} { - client_.reset(new xgboost::federated::FederatedClient( - server_address, rank, xgboost::common::ReadAll(server_cert_path), - xgboost::common::ReadAll(client_key_path), xgboost::common::ReadAll(client_cert_path))); + if (server_cert_path.empty() || client_key_path.empty() || client_cert_path.empty()) { + client_.reset(new xgboost::federated::FederatedClient(server_address, rank)); + } else { + client_.reset(new xgboost::federated::FederatedClient( + server_address, rank, xgboost::common::ReadAll(server_cert_path), + xgboost::common::ReadAll(client_key_path), xgboost::common::ReadAll(client_cert_path))); + } } /** @brief Insecure communicator for testing only. */ @@ -69,109 +146,5 @@ class FederatedCommunicator : public Communicator { private: std::unique_ptr client_{}; }; - -class FederatedCommunicatorFactory { - public: - explicit FederatedCommunicatorFactory(Json const &config) { - // Parse environment variables first. - for (auto const &env_var : env_vars_) { - char const *value = getenv(env_var.c_str()); - if (value != nullptr) { - SetParam(env_var, value); - } - } - - // Runtime configuration overrides. 
- auto const &j_server_address = config["federated_server_address"]; - if (IsA(j_server_address)) { - server_address_ = get(j_server_address); - } - auto const &j_world_size = config["federated_world_size"]; - if (IsA(j_world_size)) { - world_size_ = static_cast(get(j_world_size)); - } - auto const &j_rank = config["federated_rank"]; - if (IsA(j_rank)) { - rank_ = static_cast(get(j_rank)); - } - auto const &j_server_cert = config["federated_server_cert"]; - if (IsA(j_server_cert)) { - server_cert_ = get(j_server_cert); - } - auto const &j_client_key = config["federated_client_key"]; - if (IsA(j_client_key)) { - client_key_ = get(j_client_key); - } - auto const &j_client_cert = config["federated_client_cert"]; - if (IsA(j_client_cert)) { - client_cert_ = get(j_client_cert); - } - } - - Communicator *Create() { - if (server_address_.empty()) { - LOG(FATAL) << "Federated server address must be set."; - } - if (world_size_ == 0) { - LOG(FATAL) << "Federated world size must be set."; - } - if (rank_ == -1) { - LOG(FATAL) << "Federated rank must be set."; - } - if (server_cert_.empty()) { - LOG(FATAL) << "Federated server cert must be set."; - } - if (client_key_.empty()) { - LOG(FATAL) << "Federated client key must be set."; - } - if (client_cert_.empty()) { - LOG(FATAL) << "Federated client cert must be set."; - } - return new FederatedCommunicator(world_size_, rank_, server_address_, server_cert_, client_key_, - client_cert_); - } - - std::string const &GetServerAddress() const { return server_address_; } - int GetWorldSize() const { return world_size_; } - int GetRank() const { return rank_; } - std::string const &GetServerCert() const { return server_cert_; } - std::string const &GetClientKey() const { return client_key_; } - std::string const &GetClientCert() const { return client_cert_; } - - private: - void SetParam(std::string const &name, std::string const &val) { - if (!CompareStringsCaseInsensitive(name.c_str(), "FEDERATED_SERVER_ADDRESS")) { - 
server_address_ = val; - } else if (!CompareStringsCaseInsensitive(name.c_str(), "FEDERATED_WORLD_SIZE")) { - world_size_ = std::stoi(val); - } else if (!CompareStringsCaseInsensitive(name.c_str(), "FEDERATED_RANK")) { - rank_ = std::stoi(val); - } else if (!CompareStringsCaseInsensitive(name.c_str(), "FEDERATED_SERVER_CERT")) { - server_cert_ = val; - } else if (!CompareStringsCaseInsensitive(name.c_str(), "FEDERATED_CLIENT_KEY")) { - client_key_ = val; - } else if (!CompareStringsCaseInsensitive(name.c_str(), "FEDERATED_CLIENT_CERT")) { - client_cert_ = val; - } - } - - // clang-format off - std::vector const env_vars_{ - "FEDERATED_SERVER_ADDRESS", - "FEDERATED_WORLD_SIZE", - "FEDERATED_RANK", - "FEDERATED_SERVER_CERT", - "FEDERATED_CLIENT_KEY", - "FEDERATED_CLIENT_CERT" }; - // clang-format on - - std::string server_address_{}; - int world_size_{0}; - int rank_{-1}; - std::string server_cert_{}; - std::string client_key_{}; - std::string client_cert_{}; -}; - } // namespace collective } // namespace xgboost diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 535213fe82ef..9b5bea3acd00 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -23,7 +23,6 @@ #include "c_api_error.h" #include "c_api_utils.h" #include "../collective/communicator.h" -#include "../collective/communicator_factory.h" #include "../common/io.h" #include "../common/charconv.h" #include "../data/adapter.h" @@ -1372,57 +1371,57 @@ XGB_DLL int XGBoosterFeatureScore(BoosterHandle handle, char const *json_config, API_END(); } -using xgboost::collective::CommunicatorFactory; +using xgboost::collective::Communicator; XGB_DLL int XGCommunicatorInit(char const* json_config) { API_BEGIN(); Json config { Json::Load(StringView{json_config}) }; - CommunicatorFactory::Init(config); + Communicator::Init(config); API_END(); } XGB_DLL int XGCommunicatorFinalize(void) { API_BEGIN(); - CommunicatorFactory::Finalize(); + Communicator::Finalize(); API_END(); } XGB_DLL int XGCommunicatorGetRank(void) 
{ - return CommunicatorFactory::GetInstance()->GetCommunicator()->GetRank(); + return Communicator::Get()->GetRank(); } XGB_DLL int XGCommunicatorGetWorldSize(void) { - return CommunicatorFactory::GetInstance()->GetCommunicator()->GetWorldSize(); + return Communicator::Get()->GetWorldSize(); } XGB_DLL int XGCommunicatorIsDistributed(void) { - return CommunicatorFactory::GetInstance()->GetCommunicator()->IsDistributed(); + return Communicator::Get()->IsDistributed(); } XGB_DLL int XGCommunicatorPrint(char const *message) { API_BEGIN(); - CommunicatorFactory::GetInstance()->GetCommunicator()->Print(message); + Communicator::Get()->Print(message); API_END(); } XGB_DLL int XGCommunicatorGetProcessorName(char const **name_str) { API_BEGIN(); auto& local = *GlobalConfigAPIThreadLocalStore::Get(); - local.ret_str = CommunicatorFactory::GetInstance()->GetCommunicator()->GetProcessorName(); + local.ret_str = Communicator::Get()->GetProcessorName(); *name_str = local.ret_str.c_str(); API_END(); } XGB_DLL int XGCommunicatorBroadcast(void *send_receive_buffer, size_t size, int root) { API_BEGIN(); - CommunicatorFactory::GetInstance()->GetCommunicator()->Broadcast(send_receive_buffer, size, root); + Communicator::Get()->Broadcast(send_receive_buffer, size, root); API_END(); } XGB_DLL int XGCommunicatorAllreduce(void *send_receive_buffer, size_t count, int enum_dtype, int enum_op) { API_BEGIN(); - CommunicatorFactory::GetInstance()->GetCommunicator()->AllReduce( + Communicator::Get()->AllReduce( send_receive_buffer, count, static_cast(enum_dtype), static_cast(enum_op)); API_END(); diff --git a/src/collective/communicator_factory.cc b/src/collective/communicator.cc similarity index 57% rename from src/collective/communicator_factory.cc rename to src/collective/communicator.cc index 44f6df9d3fd8..47e7d1ac581a 100644 --- a/src/collective/communicator_factory.cc +++ b/src/collective/communicator.cc @@ -1,7 +1,7 @@ /*! 
* Copyright 2022 XGBoost contributors */ -#include "communicator_factory.h" +#include "communicator.h" #include "rabit_communicator.h" @@ -12,15 +12,12 @@ namespace xgboost { namespace collective { -#ifndef XGBOOST_USE_CUDA -thread_local std::unique_ptr CommunicatorFactory::instance_{}; - -CommunicatorFactory::CommunicatorFactory(CommunicatorType type, Communicator* communicator) - : type_{type}, communicator_{communicator} {} +thread_local std::unique_ptr Communicator::communicator_{}; +thread_local CommunicatorType Communicator::type_{}; -void CommunicatorFactory::Init(Json const& config) { - if (instance_) { - LOG(FATAL) << "Communicator factory can only be initialized once."; +void Communicator::Init(Json const& config) { + if (communicator_) { + LOG(FATAL) << "Communicator can only be initialized once."; } auto type = GetTypeFromEnv(); @@ -32,11 +29,10 @@ void CommunicatorFactory::Init(Json const& config) { // Default to Rabit if unspecified. type = CommunicatorType::kRabit; } + type_ = type; switch (type) { case CommunicatorType::kRabit: { - RabitCommunicatorFactory factory{config}; - auto* comm = factory.Create(); - instance_.reset(new CommunicatorFactory(type, comm)); + communicator_.reset(RabitCommunicator::Create(config)); break; } case CommunicatorType::kMPI: @@ -44,9 +40,7 @@ void CommunicatorFactory::Init(Json const& config) { break; case CommunicatorType::kFederated: { #if defined(XGBOOST_USE_FEDERATED) - FederatedCommunicatorFactory factory{config}; - auto* comm = factory.Create(); - instance_.reset(new CommunicatorFactory(type, comm)); + communicator_.reset(FederatedCommunicator::Create(config)); #else LOG(FATAL) << "XGBoost is not compiled with Federated Learning support."; #endif @@ -57,7 +51,8 @@ void CommunicatorFactory::Init(Json const& config) { } } -void CommunicatorFactory::Finalize() { instance_.reset(); } +#ifndef XGBOOST_USE_CUDA +void CommunicatorFactory::Finalize() { communicator_.reset(); } #endif } // namespace collective diff --git 
a/src/collective/communicator.cu b/src/collective/communicator.cu new file mode 100644 index 000000000000..9495cf6ae5ba --- /dev/null +++ b/src/collective/communicator.cu @@ -0,0 +1,40 @@ +/*! + * Copyright 2022 XGBoost contributors + */ +#include "communicator.h" +#include "device_communicator.cuh" +#include "device_communicator_adapter.cuh" +#ifdef XGBOOST_USE_NCCL +#include "nccl_device_communicator.cuh" +#endif + +namespace xgboost { +namespace collective { + +thread_local int Communicator::device_ordinal_{-1}; +thread_local std::unique_ptr Communicator::device_communicator_{}; + +void Communicator::Finalize() { + communicator_.reset(); + device_ordinal_ = -1; + device_communicator_.reset(); +} + +DeviceCommunicator* Communicator::GetDevice(int device_ordinal) { + if (!device_communicator_ || device_ordinal_ != device_ordinal) { + device_ordinal_ = device_ordinal; +#ifdef XGBOOST_USE_NCCL + if (type_ != CommunicatorType::kFederated) { + device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal, Get())); + } else { + device_communicator_.reset(new DeviceCommunicatorAdapter(device_ordinal, Get())); + } +#else + device_communicator_.reset(new DeviceCommunicatorAdapter(device_ordinal, Get())); +#endif + } + return device_communicator_.get(); +} + +} // namespace collective +} // namespace xgboost diff --git a/src/collective/communicator.h b/src/collective/communicator.h index 05162f9c4a90..2e5b10e5afc7 100644 --- a/src/collective/communicator.h +++ b/src/collective/communicator.h @@ -2,6 +2,7 @@ * Copyright 2022 XGBoost contributors */ #pragma once +#include #include #include @@ -57,28 +58,46 @@ inline std::size_t GetTypeSize(DataType data_type) { /** @brief Defines the reduction operation. 
*/ enum class Operation { kMax = 0, kMin = 1, kSum = 2 }; +class DeviceCommunicator; + +enum class CommunicatorType { kUnknown, kRabit, kMPI, kFederated }; + +/* \brief Case-insensitive string comparison */ +inline int CompareStringsCaseInsensitive(const char *s1, const char *s2) { +#ifdef _MSC_VER + return _stricmp(s1, s2); +#else // _MSC_VER + return strcasecmp(s1, s2); +#endif // _MSC_VER +} + /** * @brief A communicator class that handles collective communication. */ class Communicator { public: /** - * @brief Construct a new communicator. + * @brief Initialize the communicator. This can only be done once. * - * @param world_size Total number of processes. - * @param rank Rank of the current process. + * @param config JSON configuration for the communicator. */ - Communicator(int world_size, int rank) : world_size_(world_size), rank_(rank) { - if (world_size < 1) { - LOG(FATAL) << "World size " << world_size << " is less than 1."; - } - if (rank < 0) { - LOG(FATAL) << "Rank " << rank << " is less than 0."; - } - if (rank >= world_size) { - LOG(FATAL) << "Rank " << rank << " is greater than world_size - 1: " << world_size - 1 << "."; - } - } + static void Init(Json const &config); + + /** @brief Finalize the communicator. */ + static void Finalize(); + + /** @brief Get the communicator instance. */ + static Communicator *Get() { return communicator_.get(); } + +#if defined(XGBOOST_USE_CUDA) + /** + * @brief Get the device communicator. + * + * @param device_ordinal ID of the device. + * @return An instance of device communicator. + */ + static DeviceCommunicator *GetDevice(int device_ordinal); +#endif virtual ~Communicator() = default; @@ -122,19 +141,73 @@ class Communicator { */ virtual void Print(std::string const &message) = 0; + /** @brief Get the communicator type from environment variables. Visible for testing. 
*/ + static CommunicatorType GetTypeFromEnv() { + auto *env = std::getenv("XGBOOST_COMMUNICATOR"); + if (env != nullptr) { + return StringToType(env); + } else { + return CommunicatorType::kUnknown; + } + } + + /** @brief Get the communicator type from runtime configuration. Visible for testing. */ + static CommunicatorType GetTypeFromConfig(Json const &config) { + auto const &j_upper = config["XGBOOST_COMMUNICATOR"]; + if (IsA(j_upper)) { + return StringToType(get(j_upper).c_str()); + } + auto const &j_lower = config["xgboost_communicator"]; + if (IsA(j_lower)) { + return StringToType(get(j_lower).c_str()); + } + return CommunicatorType::kUnknown; + } + + protected: + /** + * @brief Construct a new communicator. + * + * @param world_size Total number of processes. + * @param rank Rank of the current process. + */ + Communicator(int world_size, int rank) : world_size_(world_size), rank_(rank) { + if (world_size < 1) { + LOG(FATAL) << "World size " << world_size << " is less than 1."; + } + if (rank < 0) { + LOG(FATAL) << "Rank " << rank << " is less than 0."; + } + if (rank >= world_size) { + LOG(FATAL) << "Rank " << rank << " is greater than world_size - 1: " << world_size - 1 << "."; + } + } + private: + static CommunicatorType StringToType(char const *str) { + CommunicatorType result = CommunicatorType::kUnknown; + if (!CompareStringsCaseInsensitive("rabit", str)) { + result = CommunicatorType::kRabit; + } else if (!CompareStringsCaseInsensitive("mpi", str)) { + result = CommunicatorType::kMPI; + } else if (!CompareStringsCaseInsensitive("federated", str)) { + result = CommunicatorType::kFederated; + } else { + LOG(FATAL) << "Unknown communicator type " << str; + } + return result; + } + + static thread_local std::unique_ptr communicator_; + static thread_local CommunicatorType type_; +#if defined(XGBOOST_USE_CUDA) + static thread_local int device_ordinal_; + static thread_local std::unique_ptr device_communicator_; +#endif + int const world_size_; int const 
rank_; }; -/* \brief Case-insensitive string comparison */ -inline int CompareStringsCaseInsensitive(const char *s1, const char *s2) { -#ifdef _MSC_VER - return _stricmp(s1, s2); -#else // _MSC_VER - return strcasecmp(s1, s2); -#endif // _MSC_VER -} - } // namespace collective } // namespace xgboost diff --git a/src/collective/communicator_factory.cu b/src/collective/communicator_factory.cu deleted file mode 100644 index ae194c925af3..000000000000 --- a/src/collective/communicator_factory.cu +++ /dev/null @@ -1,80 +0,0 @@ -/*! - * Copyright 2022 XGBoost contributors - */ -#include "communicator_factory.h" -#include "device_communicator_adapter.cuh" -#include "rabit_communicator.h" -#ifdef XGBOOST_USE_NCCL -#include "nccl_device_communicator.cuh" -#endif -#if defined(XGBOOST_USE_FEDERATED) -#include "../../plugin/federated/federated_communicator.h" -#endif - -namespace xgboost { -namespace collective { - -thread_local std::unique_ptr CommunicatorFactory::instance_{}; - -CommunicatorFactory::CommunicatorFactory(CommunicatorType type, Communicator* communicator) - : type_{type}, communicator_{communicator} {} - -void CommunicatorFactory::Init(Json const& config) { - if (instance_) { - LOG(FATAL) << "Communicator factory can only be initialized once."; - } - - auto type = GetTypeFromEnv(); - auto const arg = GetTypeFromConfig(config); - if (arg != CommunicatorType::kUnknown) { - type = arg; - } - if (type == CommunicatorType::kUnknown) { - // Default to Rabit if unspecified. 
- type = CommunicatorType::kRabit; - } - switch (type) { - case CommunicatorType::kRabit: { - RabitCommunicatorFactory factory{config}; - auto* comm = factory.Create(); - instance_.reset(new CommunicatorFactory(type, comm)); - break; - } - case CommunicatorType::kMPI: - LOG(FATAL) << "Not implemented yet."; - break; - case CommunicatorType::kFederated: { -#if defined(XGBOOST_USE_FEDERATED) - FederatedCommunicatorFactory factory{config}; - auto* comm = factory.Create(); - instance_.reset(new CommunicatorFactory(type, comm)); -#else - LOG(FATAL) << "XGBoost is not compiled with Federated Learning support."; -#endif - break; - } - case CommunicatorType::kUnknown: - LOG(FATAL) << "Unknown communicator type."; - } -} - -void CommunicatorFactory::Finalize() { instance_.reset(); } - -DeviceCommunicator* CommunicatorFactory::GetDeviceCommunicator(int device_ordinal) { - if (!device_communicator_) { -#ifdef XGBOOST_USE_NCCL - if (type_ != CommunicatorType::kFederated) { - device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal, communicator_.get())); - } else { - device_communicator_.reset( - new DeviceCommunicatorAdapter(device_ordinal, communicator_.get())); - } -#else - device_communicator_.reset(new DeviceCommunicatorAdapter(device_ordinal, communicator_.get())); -#endif - } - return device_communicator_.get(); -} - -} // namespace collective -} // namespace xgboost diff --git a/src/collective/communicator_factory.h b/src/collective/communicator_factory.h deleted file mode 100644 index 632a0f3a9198..000000000000 --- a/src/collective/communicator_factory.h +++ /dev/null @@ -1,83 +0,0 @@ -/*! 
- * Copyright 2022 XGBoost contributors - */ -#pragma once -#include - -#include -#include - -#include "communicator.h" - -namespace xgboost { -namespace collective { - -enum class CommunicatorType { kUnknown, kRabit, kMPI, kFederated }; - -class DeviceCommunicator; - -class CommunicatorFactory { - public: - static void Init(Json const& config); - - static void Finalize(); - - static CommunicatorFactory* GetInstance() { return instance_.get(); } - - Communicator* GetCommunicator() { return communicator_.get(); } - -#if defined(XGBOOST_USE_CUDA) - DeviceCommunicator* GetDeviceCommunicator(int device_ordinal); -#endif - - /** @brief Get the communicator type from environment variables. Visible for testing. */ - static CommunicatorType GetTypeFromEnv() { - auto* env = std::getenv("XGBOOST_COMMUNICATOR"); - if (env != nullptr) { - return StringToType(env); - } else { - return CommunicatorType::kUnknown; - } - } - - /** @brief Get the communicator type from runtime configuration. Visible for testing. 
*/ - static CommunicatorType GetTypeFromConfig(Json const& config) { - auto const& j_upper = config["XGBOOST_COMMUNICATOR"]; - if (IsA(j_upper)) { - return StringToType(get(j_upper).c_str()); - } - auto const& j_lower = config["xgboost_communicator"]; - if (IsA(j_lower)) { - return StringToType(get(j_lower).c_str()); - } - return CommunicatorType::kUnknown; - } - - private: - CommunicatorFactory(CommunicatorType type, Communicator* communicator); - - private: - static CommunicatorType StringToType(char const* str) { - CommunicatorType result = CommunicatorType::kUnknown; - if (!CompareStringsCaseInsensitive("rabit", str)) { - result = CommunicatorType::kRabit; - } else if (!CompareStringsCaseInsensitive("mpi", str)) { - result = CommunicatorType::kMPI; - } else if (!CompareStringsCaseInsensitive("federated", str)) { - result = CommunicatorType::kFederated; - } else { - LOG(FATAL) << "Unknown communicator type " << str; - } - return result; - } - - static thread_local std::unique_ptr instance_; - CommunicatorType type_; - std::unique_ptr communicator_; -#if defined(XGBOOST_USE_CUDA) - std::unique_ptr device_communicator_; -#endif -}; - -} // namespace collective -} // namespace xgboost diff --git a/src/collective/rabit_communicator.h b/src/collective/rabit_communicator.h index 78f31f2d3a32..248bf3df555b 100644 --- a/src/collective/rabit_communicator.h +++ b/src/collective/rabit_communicator.h @@ -15,6 +15,38 @@ namespace collective { class RabitCommunicator : public Communicator { public: + static Communicator *Create(Json const &config) { + std::vector args_str; + for (auto &items : get(config)) { + switch (items.second.GetValue().Type()) { + case xgboost::Value::ValueKind::kString: { + args_str.push_back(items.first + "=" + get(items.second)); + break; + } + case xgboost::Value::ValueKind::kInteger: { + args_str.push_back(items.first + "=" + std::to_string(get(items.second))); + break; + } + case xgboost::Value::ValueKind::kBoolean: { + if (get(items.second)) { + 
args_str.push_back(items.first + "=1"); + } else { + args_str.push_back(items.first + "=0"); + } + break; + } + default: + break; + } + } + std::vector args; + for (auto &key_value : args_str) { + args.push_back(&key_value[0]); + } + rabit::Init(static_cast(args.size()), &args[0]); + return new RabitCommunicator(rabit::GetWorldSize(), rabit::GetRank()); + } + RabitCommunicator(int world_size, int rank) : Communicator(world_size, rank) {} ~RabitCommunicator() override { rabit::Finalize(); } @@ -76,48 +108,5 @@ class RabitCommunicator : public Communicator { } } }; - -class RabitCommunicatorFactory { - public: - explicit RabitCommunicatorFactory(Json const &config) { - std::vector args_str; - for (auto &items : get(config)) { - switch (items.second.GetValue().Type()) { - case xgboost::Value::ValueKind::kString: { - args_str.push_back(items.first + "=" + get(items.second)); - break; - } - case xgboost::Value::ValueKind::kInteger: { - args_str.push_back(items.first + "=" + std::to_string(get(items.second))); - break; - } - case xgboost::Value::ValueKind::kBoolean: { - if (get(items.second)) { - args_str.push_back(items.first + "=1"); - } else { - args_str.push_back(items.first + "=0"); - } - break; - } - default: - break; - } - } - std::vector args; - for (auto &key_value : args_str) { - args.push_back(&key_value[0]); - } - rabit::Init(static_cast(args.size()), &args[0]); - world_size_ = rabit::GetWorldSize(); - rank_ = rabit::GetRank(); - } - - Communicator *Create() const { return new RabitCommunicator(world_size_, rank_); } - - private: - int world_size_; - int rank_; -}; - } // namespace collective } // namespace xgboost diff --git a/tests/cpp/collective/test_communicator.cc b/tests/cpp/collective/test_communicator.cc new file mode 100644 index 000000000000..00c7fccfe582 --- /dev/null +++ b/tests/cpp/collective/test_communicator.cc @@ -0,0 +1,63 @@ +/*! 
+ * Copyright 2022 XGBoost contributors + */ +#include +#include + +#include "../../../src/collective/communicator.h" + +namespace xgboost { +namespace collective { + +TEST(CommunicatorFactory, TypeFromEnv) { + EXPECT_EQ(CommunicatorType::kUnknown, Communicator::GetTypeFromEnv()); + + dmlc::SetEnv("XGBOOST_COMMUNICATOR", "rabit"); + EXPECT_EQ(CommunicatorType::kRabit, Communicator::GetTypeFromEnv()); + + dmlc::SetEnv("XGBOOST_COMMUNICATOR", "MPI"); + EXPECT_EQ(CommunicatorType::kMPI, Communicator::GetTypeFromEnv()); + + dmlc::SetEnv("XGBOOST_COMMUNICATOR", "Federated"); + EXPECT_EQ(CommunicatorType::kFederated, Communicator::GetTypeFromEnv()); + + dmlc::SetEnv("XGBOOST_COMMUNICATOR", "foo"); + EXPECT_THROW(Communicator::GetTypeFromEnv(), dmlc::Error); +} + +TEST(CommunicatorFactory, TypeFromArgs) { + Json config{JsonObject()}; + EXPECT_EQ(CommunicatorType::kUnknown, Communicator::GetTypeFromConfig(config)); + + config["xgboost_communicator"] = String("rabit"); + EXPECT_EQ(CommunicatorType::kRabit, Communicator::GetTypeFromConfig(config)); + + config["xgboost_communicator"] = String("MPI"); + EXPECT_EQ(CommunicatorType::kMPI, Communicator::GetTypeFromConfig(config)); + + config["xgboost_communicator"] = String("federated"); + EXPECT_EQ(CommunicatorType::kFederated, Communicator::GetTypeFromConfig(config)); + + config["xgboost_communicator"] = String("foo"); + EXPECT_THROW(Communicator::GetTypeFromConfig(config), dmlc::Error); +} + +TEST(CommunicatorFactory, TypeFromArgsUpperCase) { + Json config{JsonObject()}; + EXPECT_EQ(CommunicatorType::kUnknown, Communicator::GetTypeFromConfig(config)); + + config["XGBOOST_COMMUNICATOR"] = String("rabit"); + EXPECT_EQ(CommunicatorType::kRabit, Communicator::GetTypeFromConfig(config)); + + config["XGBOOST_COMMUNICATOR"] = String("MPI"); + EXPECT_EQ(CommunicatorType::kMPI, Communicator::GetTypeFromConfig(config)); + + config["XGBOOST_COMMUNICATOR"] = String("federated"); + EXPECT_EQ(CommunicatorType::kFederated, 
Communicator::GetTypeFromConfig(config)); + + config["XGBOOST_COMMUNICATOR"] = String("foo"); + EXPECT_THROW(Communicator::GetTypeFromConfig(config), dmlc::Error); +} + +} // namespace collective +} // namespace xgboost diff --git a/tests/cpp/collective/test_communicator_factory.cc b/tests/cpp/collective/test_communicator_factory.cc deleted file mode 100644 index ec94a42870cb..000000000000 --- a/tests/cpp/collective/test_communicator_factory.cc +++ /dev/null @@ -1,63 +0,0 @@ -/*! - * Copyright 2022 XGBoost contributors - */ -#include -#include - -#include "../../../src/collective/communicator_factory.h" - -namespace xgboost { -namespace collective { - -TEST(CommunicatorFactory, TypeFromEnv) { - EXPECT_EQ(CommunicatorType::kUnknown, CommunicatorFactory::GetTypeFromEnv()); - - dmlc::SetEnv("XGBOOST_COMMUNICATOR", "rabit"); - EXPECT_EQ(CommunicatorType::kRabit, CommunicatorFactory::GetTypeFromEnv()); - - dmlc::SetEnv("XGBOOST_COMMUNICATOR", "MPI"); - EXPECT_EQ(CommunicatorType::kMPI, CommunicatorFactory::GetTypeFromEnv()); - - dmlc::SetEnv("XGBOOST_COMMUNICATOR", "Federated"); - EXPECT_EQ(CommunicatorType::kFederated, CommunicatorFactory::GetTypeFromEnv()); - - dmlc::SetEnv("XGBOOST_COMMUNICATOR", "foo"); - EXPECT_THROW(CommunicatorFactory::GetTypeFromEnv(), dmlc::Error); -} - -TEST(CommunicatorFactory, TypeFromArgs) { - Json config{JsonObject()}; - EXPECT_EQ(CommunicatorType::kUnknown, CommunicatorFactory::GetTypeFromConfig(config)); - - config["xgboost_communicator"] = String("rabit"); - EXPECT_EQ(CommunicatorType::kRabit, CommunicatorFactory::GetTypeFromConfig(config)); - - config["xgboost_communicator"] = String("MPI"); - EXPECT_EQ(CommunicatorType::kMPI, CommunicatorFactory::GetTypeFromConfig(config)); - - config["xgboost_communicator"] = String("federated"); - EXPECT_EQ(CommunicatorType::kFederated, CommunicatorFactory::GetTypeFromConfig(config)); - - config["xgboost_communicator"] = String("foo"); - EXPECT_THROW(CommunicatorFactory::GetTypeFromConfig(config), 
dmlc::Error); -} - -TEST(CommunicatorFactory, TypeFromArgsUpperCase) { - Json config{JsonObject()}; - EXPECT_EQ(CommunicatorType::kUnknown, CommunicatorFactory::GetTypeFromConfig(config)); - - config["XGBOOST_COMMUNICATOR"] = String("rabit"); - EXPECT_EQ(CommunicatorType::kRabit, CommunicatorFactory::GetTypeFromConfig(config)); - - config["XGBOOST_COMMUNICATOR"] = String("MPI"); - EXPECT_EQ(CommunicatorType::kMPI, CommunicatorFactory::GetTypeFromConfig(config)); - - config["XGBOOST_COMMUNICATOR"] = String("federated"); - EXPECT_EQ(CommunicatorType::kFederated, CommunicatorFactory::GetTypeFromConfig(config)); - - config["XGBOOST_COMMUNICATOR"] = String("foo"); - EXPECT_THROW(CommunicatorFactory::GetTypeFromConfig(config), dmlc::Error); -} - -} // namespace collective -} // namespace xgboost diff --git a/tests/cpp/plugin/test_federated_communicator.cc b/tests/cpp/plugin/test_federated_communicator.cc index 225f49ed5235..2d9f233db573 100644 --- a/tests/cpp/plugin/test_federated_communicator.cc +++ b/tests/cpp/plugin/test_federated_communicator.cc @@ -115,90 +115,5 @@ TEST_F(FederatedCommunicatorTest, Broadcast) { thread.join(); } } - -TEST(FederatedCommunicatorFactoryTest, ServerAddress) { - Json config{JsonObject()}; - FederatedCommunicatorFactory factory0{config}; - EXPECT_EQ(factory0.GetServerAddress(), ""); - - dmlc::SetEnv("FEDERATED_SERVER_ADDRESS", "localhost:9091"); - FederatedCommunicatorFactory factory1{config}; - EXPECT_EQ(factory1.GetServerAddress(), "localhost:9091"); - - config["federated_server_address"] = String("foo:9091"); - FederatedCommunicatorFactory factory2{config}; - EXPECT_EQ(factory2.GetServerAddress(), "foo:9091"); -} - -TEST(FederatedCommunicatorFactoryTest, WorldSize) { - Json config{JsonObject()}; - FederatedCommunicatorFactory factory0{config}; - EXPECT_EQ(factory0.GetWorldSize(), 0); - - dmlc::SetEnv("FEDERATED_WORLD_SIZE", "2"); - FederatedCommunicatorFactory factory1{config}; - EXPECT_EQ(factory1.GetWorldSize(), 2); - - 
config["federated_world_size"] = Integer(3); - FederatedCommunicatorFactory factory2{config}; - EXPECT_EQ(factory2.GetWorldSize(), 3); -} - -TEST(FederatedCommunicatorFactoryTest, Rank) { - Json config{JsonObject()}; - FederatedCommunicatorFactory factory0{config}; - EXPECT_EQ(factory0.GetRank(), -1); - - dmlc::SetEnv("FEDERATED_RANK", "1"); - FederatedCommunicatorFactory factory1{config}; - EXPECT_EQ(factory1.GetRank(), 1); - - config["federated_rank"] = Integer(2); - FederatedCommunicatorFactory factory2{config}; - EXPECT_EQ(factory2.GetRank(), 2); -} - -TEST(FederatedCommunicatorFactoryTest, ServerCert) { - Json config{JsonObject()}; - FederatedCommunicatorFactory factory0{config}; - EXPECT_EQ(factory0.GetServerCert(), ""); - - dmlc::SetEnv("FEDERATED_SERVER_CERT", "foo/server-cert.pem"); - FederatedCommunicatorFactory factory1{config}; - EXPECT_EQ(factory1.GetServerCert(), "foo/server-cert.pem"); - - config["federated_server_cert"] = String("bar/server-cert.pem"); - FederatedCommunicatorFactory factory2{config}; - EXPECT_EQ(factory2.GetServerCert(), "bar/server-cert.pem"); -} - -TEST(FederatedCommunicatorFactoryTest, ClientKey) { - Json config{JsonObject()}; - FederatedCommunicatorFactory factory0{config}; - EXPECT_EQ(factory0.GetClientKey(), ""); - - dmlc::SetEnv("FEDERATED_CLIENT_KEY", "foo/client-key.pem"); - FederatedCommunicatorFactory factory1{config}; - EXPECT_EQ(factory1.GetClientKey(), "foo/client-key.pem"); - - config["federated_client_key"] = String("bar/client-key.pem"); - FederatedCommunicatorFactory factory2{config}; - EXPECT_EQ(factory2.GetClientKey(), "bar/client-key.pem"); -} - -TEST(FederatedCommunicatorFactoryTest, ClientCert) { - Json config{JsonObject()}; - FederatedCommunicatorFactory factory0{config}; - EXPECT_EQ(factory0.GetClientCert(), ""); - - dmlc::SetEnv("FEDERATED_CLIENT_CERT", "foo/client-cert.pem"); - FederatedCommunicatorFactory factory1{config}; - EXPECT_EQ(factory1.GetClientCert(), "foo/client-cert.pem"); - - 
config["federated_client_cert"] = String("bar/client-cert.pem"); - FederatedCommunicatorFactory factory2{config}; - EXPECT_EQ(factory2.GetClientCert(), "bar/client-cert.pem"); -} - } // namespace collective } // namespace xgboost From 2039858016086553c91b40d8fdb39f2f81821da2 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Thu, 1 Sep 2022 16:15:38 -0700 Subject: [PATCH 28/37] fix cpu build --- src/collective/communicator.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/collective/communicator.cc b/src/collective/communicator.cc index 47e7d1ac581a..af59f00b8fa5 100644 --- a/src/collective/communicator.cc +++ b/src/collective/communicator.cc @@ -52,7 +52,7 @@ void Communicator::Init(Json const& config) { } #ifndef XGBOOST_USE_CUDA -void CommunicatorFactory::Finalize() { communicator_.reset(); } +void Communicator::Finalize() { communicator_.reset(); } #endif } // namespace collective From c3a42fb9d19dc57baaaf81bcfd24fb9723329bf0 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Thu, 1 Sep 2022 16:20:03 -0700 Subject: [PATCH 29/37] fix include --- src/collective/communicator.h | 1 + 1 file changed, 1 insertion(+) diff --git a/src/collective/communicator.h b/src/collective/communicator.h index 2e5b10e5afc7..8957ca74a73d 100644 --- a/src/collective/communicator.h +++ b/src/collective/communicator.h @@ -5,6 +5,7 @@ #include #include +#include #include namespace xgboost { From 0a115ff7ecf99949463616dc0e2c0ed0023adc8b Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Thu, 1 Sep 2022 16:23:23 -0700 Subject: [PATCH 30/37] fix python import --- python-package/xgboost/__init__.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python-package/xgboost/__init__.py b/python-package/xgboost/__init__.py index 84bd40fb790f..c5d32a62beb3 100644 --- a/python-package/xgboost/__init__.py +++ b/python-package/xgboost/__init__.py @@ -5,8 +5,7 @@ from . import rabit # noqa from . import tracker # noqa -from . import collective -from . import dask +from . 
import collective, dask from .core import ( Booster, DataIter, From 63ae4e8161e910802361c7408d5368f7a15c60c7 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Thu, 1 Sep 2022 16:46:13 -0700 Subject: [PATCH 31/37] don't export collective.py yet --- python-package/xgboost/__init__.py | 4 +--- tests/python/test_collective.py | 8 ++++---- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/python-package/xgboost/__init__.py b/python-package/xgboost/__init__.py index c5d32a62beb3..6c29de98d9dc 100644 --- a/python-package/xgboost/__init__.py +++ b/python-package/xgboost/__init__.py @@ -5,7 +5,7 @@ from . import rabit # noqa from . import tracker # noqa -from . import collective, dask +from . import dask from .core import ( Booster, DataIter, @@ -63,6 +63,4 @@ "XGBRFRegressor", # dask "dask", - # collective - "collective", ] diff --git a/tests/python/test_collective.py b/tests/python/test_collective.py index 01e2b01fb6e7..09a8323dc541 100644 --- a/tests/python/test_collective.py +++ b/tests/python/test_collective.py @@ -5,12 +5,12 @@ import xgboost as xgb from xgboost import RabitTracker +from xgboost import collective -def run_rabit_worker(rabit_env, world_size, rank): +def run_rabit_worker(rabit_env, world_size): with xgb.collective.CommunicatorContext(**rabit_env): assert xgb.collective.get_world_size() == world_size - assert xgb.collective.get_rank() == rank assert xgb.collective.is_distributed() assert xgb.collective.get_processor_name() == socket.gethostname() ret = xgb.collective.broadcast('test1234', 0) @@ -24,9 +24,9 @@ def test_rabit_communicator(): tracker = RabitTracker(host_ip='127.0.0.1', n_workers=world_size) tracker.start(world_size) workers = [] - for rank in reversed(range(world_size)): + for _ in range(world_size): worker = multiprocessing.Process(target=run_rabit_worker, - args=(tracker.worker_envs(), world_size, rank)) + args=(tracker.worker_envs(), world_size)) workers.append(worker) worker.start() for worker in workers: From 
fe79c38013d4cf3d76ff2901ec39cbd13ae4bc29 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Thu, 1 Sep 2022 17:26:43 -0700 Subject: [PATCH 32/37] skip collective communicator pytest on windows --- tests/python/test_collective.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/python/test_collective.py b/tests/python/test_collective.py index 09a8323dc541..1b9727ebf05b 100644 --- a/tests/python/test_collective.py +++ b/tests/python/test_collective.py @@ -1,12 +1,17 @@ import multiprocessing import socket +import sys import numpy as np +import pytest import xgboost as xgb from xgboost import RabitTracker from xgboost import collective +if sys.platform.startswith("win"): + pytest.skip("Skipping collective tests on Windows", allow_module_level=True) + def run_rabit_worker(rabit_env, world_size): with xgb.collective.CommunicatorContext(**rabit_env): From ed8fed20d319fbc05c15f9c2a70d58210eb0b65b Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Tue, 6 Sep 2022 14:33:00 -0700 Subject: [PATCH 33/37] add review feedback --- include/xgboost/c_api.h | 3 +++ python-package/xgboost/collective.py | 6 +++--- src/collective/rabit_communicator.h | 3 +++ 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/include/xgboost/c_api.h b/include/xgboost/c_api.h index 3d63b93ac322..dc0aa54113c0 100644 --- a/include/xgboost/c_api.h +++ b/include/xgboost/c_api.h @@ -1392,6 +1392,9 @@ XGB_DLL int XGBoosterFeatureScore(BoosterHandle handle, const char *json_config, /*! * \brief Initialize the collective communicator. * + * Currently the communicator API is experimental, function signatures may change in the future + * without notice. + * * Call this once before using anything. * * The additional configuration is not required. 
Usually the communicator will detect settings diff --git a/python-package/xgboost/collective.py b/python-package/xgboost/collective.py index 829184970d32..8519085ce987 100644 --- a/python-package/xgboost/collective.py +++ b/python-package/xgboost/collective.py @@ -188,7 +188,7 @@ def broadcast(data: _T, root: int) -> _T: @unique class Op(IntEnum): - """Supported operations for rabit.""" + """Supported operations for allreduce.""" MAX = 0 MIN = 1 SUM = 2 @@ -204,7 +204,7 @@ def allreduce( # pylint:disable=invalid-name data : Input data. op : - Reduction operators, can be MAX or SUM + Reduction operator. Returns ------- @@ -216,7 +216,7 @@ def allreduce( # pylint:disable=invalid-name This function is not thread-safe. """ if not isinstance(data, np.ndarray): - raise Exception('allreduce only takes in numpy.ndarray') + raise TypeError('allreduce only takes in numpy.ndarray') buf = data.ravel() if buf.base is data.base: buf = buf.copy() diff --git a/src/collective/rabit_communicator.h b/src/collective/rabit_communicator.h index 248bf3df555b..5708d1bbc4fe 100644 --- a/src/collective/rabit_communicator.h +++ b/src/collective/rabit_communicator.h @@ -100,6 +100,9 @@ class RabitCommunicator : public Communicator { case Operation::kMax: rabit::Allreduce(static_cast(send_receive_buffer), count); break; + case Operation::kMin: + rabit::Allreduce(static_cast(send_receive_buffer), count); + break; case Operation::kSum: rabit::Allreduce(static_cast(send_receive_buffer), count); break; From d418563ef0254afcd8ec69c40074437e1c9c26c8 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Wed, 7 Sep 2022 13:44:34 -0700 Subject: [PATCH 34/37] update documentation --- include/xgboost/c_api.h | 82 +++++++++++++---------- plugin/federated/federated_communicator.h | 47 +++++++++++-- python-package/xgboost/collective.py | 2 +- src/collective/communicator.h | 3 +- src/collective/device_communicator.cuh | 16 +++++ 5 files changed, 110 insertions(+), 40 deletions(-) diff --git a/include/xgboost/c_api.h 
b/include/xgboost/c_api.h index dc0aa54113c0..92d0e3c3bfae 100644 --- a/include/xgboost/c_api.h +++ b/include/xgboost/c_api.h @@ -1427,82 +1427,96 @@ XGB_DLL int XGBoosterFeatureScore(BoosterHandle handle, const char *json_config, * - DMLC_ROLE: Role of the current task, "worker" or "server". * - DMLC_NUM_ATTEMPT: Number of attempts after task failure. * - DMLC_WORKER_CONNECT_RETRY: Number of retries to connect to the tracker. - * Only applicable to the Federated communicator (these are not case-sensitive): + * Only applicable to the Federated communicator (use upper case for environment variables, use + * lower case for runtime configuration): * - federated_server_address: Address of the federated server. * - federated_world_size: Number of federated workers. * - federated_rank: Rank of the current worker. * - federated_server_cert: Server certificate file path. Only needed for the SSL mode. * - federated_client_key: Client key file path. Only needed for the SSL mode. * - federated_client_cert: Client certificate file path. Only needed for the SSL mode. + * \return 0 for success, -1 for failure. */ XGB_DLL int XGCommunicatorInit(char const* json_config); /*! - * \brief finalize the collective communicator, - * call this function after you finished all jobs. - * \return true if the communicator is finalized successfully otherwise false + * \brief Finalize the collective communicator. + * + * Call this function after you finished all jobs. + * + * \return 0 for success, -1 for failure. */ XGB_DLL int XGCommunicatorFinalize(void); /*! - * \brief get rank of current process - * \return rank number of worker - * */ + * \brief Get rank of current process. + * + * \return Rank of the worker. + */ XGB_DLL int XGCommunicatorGetRank(void); /*! - * \brief get total number of process - * \return total world size - * */ + * \brief Get total number of processes. + * + * \return Total world size. + */ XGB_DLL int XGCommunicatorGetWorldSize(void); /*! 
- * \brief get if the communicator is distributed - * \return if the communicator is distributed - * */ + * \brief Get if the communicator is distributed. + * + * \return True if the communicator is distributed. + */ XGB_DLL int XGCommunicatorIsDistributed(void); /*! - * \brief print the msg to the communicator, - * this function can be used to communicate the information of the progress to - * the user who monitors the communicator - * \param message the message to be printed + * \brief Print the message to the communicator. + * + * This function can be used to communicate the information of the progress to the user who monitors + * the communicator. + * + * \param message The message to be printed. + * \return 0 for success, -1 for failure. */ XGB_DLL int XGCommunicatorPrint(char const *message); /*! - * \brief get name of processor - * \param name_str pointer to received returned processor name. - * \return 0 for success, -1 for failure + * \brief Get the name of the processor. + * + * \param name_str Pointer to receive the returned processor name. + * \return 0 for success, -1 for failure. */ XGB_DLL int XGCommunicatorGetProcessorName(const char** name_str); /*! - * \brief broadcast an memory region to all others from root + * \brief Broadcast a memory region to all others from root. This function is NOT thread-safe. + * + * Example: + * int a = 1; + * Broadcast(&a, sizeof(a), root); * - * Example: int a = 1; Broadcast(&a, sizeof(a), root); - * \param send_receive_buffer the pointer to send or receive buffer, - * \param size the size of the data - * \param root the root of process + * \param send_receive_buffer Pointer to the send or receive buffer. + * \param size Size of the data. + * \param root The process rank to broadcast from. + * \return 0 for success, -1 for failure. */ XGB_DLL int XGCommunicatorBroadcast(void *send_receive_buffer, size_t size, int root); /*!
- * \brief perform in-place allreduce, on sendrecvbuf - * this function is NOT thread-safe + * \brief Perform in-place allreduce. This function is NOT thread-safe. * * Example Usage: the following code gives sum of the result * vector data(10); * ... - * Allreduce(&data[0], data.size()); + * Allreduce(&data[0], data.size(), DataType::kInt32, Op::kSum); * ... - * \param send_receive_buffer buffer for both sending and receiving data - * \param count number of elements to be reduced - * \param enum_dtype the enumeration of data type, see xgboost::collective::DataType in communicator.h - * \param enum_op the enumeration of operation type, see xgboost::collective::Operation in communicator.h + * \param send_receive_buffer Buffer for both sending and receiving data. + * \param count Number of elements to be reduced. + * \param data_type Enumeration of data type, see xgboost::collective::DataType in communicator.h. + * \param op Enumeration of operation type, see xgboost::collective::Operation in communicator.h. + * \return 0 for success, -1 for failure. */ -XGB_DLL int XGCommunicatorAllreduce(void *send_receive_buffer, size_t count, int enum_dtype, - int enum_op); +XGB_DLL int XGCommunicatorAllreduce(void *send_receive_buffer, size_t count, int data_type, int op); #endif // XGBOOST_C_API_H_ diff --git a/plugin/federated/federated_communicator.h b/plugin/federated/federated_communicator.h index 1ae5cafba249..fa502aff6157 100644 --- a/plugin/federated/federated_communicator.h +++ b/plugin/federated/federated_communicator.h @@ -12,10 +12,15 @@ namespace xgboost { namespace collective { /** - * @brief A federated learning communicator class that handles collective communication. + * @brief A Federated Learning communicator class that handles collective communication. */ class FederatedCommunicator : public Communicator { public: + /** + * @brief Create a new communicator based on JSON configuration. + * @param config JSON configuration.
+ * @return Communicator as specified by the JSON configuration. + */ static Communicator *Create(Json const &config) { std::string server_address{}; int world_size{0}; @@ -92,8 +97,12 @@ class FederatedCommunicator : public Communicator { /** * @brief Construct a new federated communicator. * - * @param world_size Total number of processes. - * @param rank Rank of the current process. + * @param world_size Total number of processes. + * @param rank Rank of the current process. + * @param server_address Address of the federated server (host:port). + * @param server_cert_path Path to the server cert file. + * @param client_key_path Path to the client key file. + * @param client_cert_path Path to the client cert file. */ FederatedCommunicator(int world_size, int rank, std::string const &server_address, std::string const &server_cert_path, std::string const &client_key_path, @@ -108,7 +117,12 @@ class FederatedCommunicator : public Communicator { } } - /** @brief Insecure communicator for testing only. */ + /** + * @brief Construct an insecure federated communicator without using SSL. + * @param world_size Total number of processes. + * @param rank Rank of the current process. + * @param server_address Address of the federated server (host:port). + */ FederatedCommunicator(int world_size, int rank, std::string const &server_address) : Communicator{world_size, rank} { client_.reset(new xgboost::federated::FederatedClient(server_address, rank)); @@ -116,8 +130,19 @@ class FederatedCommunicator : public Communicator { ~FederatedCommunicator() override { client_.reset(); } + /** + * \brief Get if the communicator is distributed. + * \return True. + */ bool IsDistributed() const override { return true; } + /** + * \brief Perform in-place allreduce. + * \param send_receive_buffer Buffer for both sending and receiving data. + * \param count Number of elements to be reduced. + * \param data_type Enumeration of data type. + * \param op Enumeration of operation type. 
+ */ void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type, Operation op) override { std::string const send_buffer(reinterpret_cast(send_receive_buffer), @@ -128,6 +153,12 @@ class FederatedCommunicator : public Communicator { received.copy(reinterpret_cast(send_receive_buffer), count * GetTypeSize(data_type)); } + /** + * \brief Broadcast a memory region to all others from root. + * \param send_receive_buffer Pointer to the send or receive buffer. + * \param size Size of the data. + * \param root The process rank to broadcast from. + */ void Broadcast(void *send_receive_buffer, std::size_t size, int root) override { if (GetWorldSize() == 1) return; if (GetRank() == root) { @@ -139,8 +170,16 @@ class FederatedCommunicator : public Communicator { } } + /** + * \brief Get the name of the processor. + * \return Name of the processor. + */ std::string GetProcessorName() override { return "rank" + std::to_string(GetRank()); } + /** + * \brief Print the message to the communicator. + * \param message The message to be printed. + */ void Print(const std::string &message) override { LOG(CONSOLE) << message; } private: diff --git a/python-package/xgboost/collective.py b/python-package/xgboost/collective.py index 8519085ce987..161690a9d007 100644 --- a/python-package/xgboost/collective.py +++ b/python-package/xgboost/collective.py @@ -66,7 +66,7 @@ def init(**args: Any) -> None: def finalize() -> None: - """Finalize the process, notify tracker everything is done.""" + """Finalize the communicator.""" _check_call(_LIB.XGCommunicatorFinalize()) diff --git a/src/collective/communicator.h b/src/collective/communicator.h index 8957ca74a73d..cc14cee5d0fc 100644 --- a/src/collective/communicator.h +++ b/src/collective/communicator.h @@ -23,6 +23,7 @@ enum class DataType { kDouble = 7 }; +/** @brief Get the size of the data type. 
*/ inline std::size_t GetTypeSize(DataType data_type) { std::size_t size{0}; switch (data_type) { @@ -63,7 +64,7 @@ class DeviceCommunicator; enum class CommunicatorType { kUnknown, kRabit, kMPI, kFederated }; -/* \brief Case-insensitive string comparison */ +/** \brief Case-insensitive string comparison. */ inline int CompareStringsCaseInsensitive(const char *s1, const char *s2) { #ifdef _MSC_VER return _stricmp(s1, s2); diff --git a/src/collective/device_communicator.cuh b/src/collective/device_communicator.cuh index 9efb978897ca..15d18cead02f 100644 --- a/src/collective/device_communicator.cuh +++ b/src/collective/device_communicator.cuh @@ -9,16 +9,32 @@ namespace xgboost { namespace collective { +/** + * @brief Collective communicator for device buffers. + */ class DeviceCommunicator { public: virtual ~DeviceCommunicator() = default; + /** + * @brief Sum values from all processes and distribute the result back to all processes. + * @param send_receive_buffer Buffer storing the data. + * @param count Number of elements in the buffer. + */ virtual void AllReduceSum(double *send_receive_buffer, int count) = 0; + /** + * @brief Gather variable-length values from all processes. + * @param send_buffer Buffer storing the input data. + * @param length_bytes Length in bytes of the input data. + * @param segments Size of each segment. + * @param receive_buffer Buffer storing the output data. + */ virtual void AllGatherV(void const *send_buffer, size_t length_bytes, std::vector *segments, dh::caching_device_vector *receive_buffer) = 0; + /** @brief Synchronize device operations. 
*/ virtual void Synchronize() = 0; }; From c4cf82c7604f003096f7a86d9c000b65c323273c Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Wed, 7 Sep 2022 14:48:04 -0700 Subject: [PATCH 35/37] remove mpi communicator type --- python-package/xgboost/collective.py | 1 - src/collective/communicator.cc | 3 --- src/collective/communicator.h | 4 +--- 3 files changed, 1 insertion(+), 7 deletions(-) diff --git a/python-package/xgboost/collective.py b/python-package/xgboost/collective.py index 161690a9d007..e4662d744e50 100644 --- a/python-package/xgboost/collective.py +++ b/python-package/xgboost/collective.py @@ -26,7 +26,6 @@ def init(**args: Any) -> None: - xgboost_communicator: The type of the communicator. Can be set as an environment variable. * rabit: Use Rabit. This is the default if the type is unspecified. - * mpi: Use MPI. * federated: Use the gRPC interface for Federated Learning. Only applicable to the Rabit communicator (these are case sensitive): -- rabit_tracker_uri: Hostname of the tracker. 
diff --git a/src/collective/communicator.cc b/src/collective/communicator.cc index af59f00b8fa5..dd21e3b699c4 100644 --- a/src/collective/communicator.cc +++ b/src/collective/communicator.cc @@ -35,9 +35,6 @@ void Communicator::Init(Json const& config) { communicator_.reset(RabitCommunicator::Create(config)); break; } - case CommunicatorType::kMPI: - LOG(FATAL) << "Not implemented yet."; - break; case CommunicatorType::kFederated: { #if defined(XGBOOST_USE_FEDERATED) communicator_.reset(FederatedCommunicator::Create(config)); diff --git a/src/collective/communicator.h b/src/collective/communicator.h index cc14cee5d0fc..196f03160cc6 100644 --- a/src/collective/communicator.h +++ b/src/collective/communicator.h @@ -62,7 +62,7 @@ enum class Operation { kMax = 0, kMin = 1, kSum = 2 }; class DeviceCommunicator; -enum class CommunicatorType { kUnknown, kRabit, kMPI, kFederated }; +enum class CommunicatorType { kUnknown, kRabit, kFederated }; /** \brief Case-insensitive string comparison. */ inline int CompareStringsCaseInsensitive(const char *s1, const char *s2) { @@ -190,8 +190,6 @@ class Communicator { CommunicatorType result = CommunicatorType::kUnknown; if (!CompareStringsCaseInsensitive("rabit", str)) { result = CommunicatorType::kRabit; - } else if (!CompareStringsCaseInsensitive("mpi", str)) { - result = CommunicatorType::kMPI; } else if (!CompareStringsCaseInsensitive("federated", str)) { result = CommunicatorType::kFederated; } else { From 5f0daa0ff7efabd6b6877a64a12f3bb62f4963b0 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Wed, 7 Sep 2022 14:56:06 -0700 Subject: [PATCH 36/37] fix tests --- tests/cpp/collective/test_communicator.cc | 9 --------- 1 file changed, 9 deletions(-) diff --git a/tests/cpp/collective/test_communicator.cc b/tests/cpp/collective/test_communicator.cc index 00c7fccfe582..e66e38255345 100644 --- a/tests/cpp/collective/test_communicator.cc +++ b/tests/cpp/collective/test_communicator.cc @@ -15,9 +15,6 @@ TEST(CommunicatorFactory, TypeFromEnv) 
{ dmlc::SetEnv("XGBOOST_COMMUNICATOR", "rabit"); EXPECT_EQ(CommunicatorType::kRabit, Communicator::GetTypeFromEnv()); - dmlc::SetEnv("XGBOOST_COMMUNICATOR", "MPI"); - EXPECT_EQ(CommunicatorType::kMPI, Communicator::GetTypeFromEnv()); - dmlc::SetEnv("XGBOOST_COMMUNICATOR", "Federated"); EXPECT_EQ(CommunicatorType::kFederated, Communicator::GetTypeFromEnv()); @@ -32,9 +29,6 @@ TEST(CommunicatorFactory, TypeFromArgs) { config["xgboost_communicator"] = String("rabit"); EXPECT_EQ(CommunicatorType::kRabit, Communicator::GetTypeFromConfig(config)); - config["xgboost_communicator"] = String("MPI"); - EXPECT_EQ(CommunicatorType::kMPI, Communicator::GetTypeFromConfig(config)); - config["xgboost_communicator"] = String("federated"); EXPECT_EQ(CommunicatorType::kFederated, Communicator::GetTypeFromConfig(config)); @@ -49,9 +43,6 @@ TEST(CommunicatorFactory, TypeFromArgsUpperCase) { config["XGBOOST_COMMUNICATOR"] = String("rabit"); EXPECT_EQ(CommunicatorType::kRabit, Communicator::GetTypeFromConfig(config)); - config["XGBOOST_COMMUNICATOR"] = String("MPI"); - EXPECT_EQ(CommunicatorType::kMPI, Communicator::GetTypeFromConfig(config)); - config["XGBOOST_COMMUNICATOR"] = String("federated"); EXPECT_EQ(CommunicatorType::kFederated, Communicator::GetTypeFromConfig(config)); From cb5f4ad58b307a7082f132455e4d828005e401cc Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Fri, 9 Sep 2022 18:59:51 -0700 Subject: [PATCH 37/37] shutdown the communicator separately --- plugin/federated/federated_communicator.h | 3 +++ src/collective/communicator.cc | 5 ++++- src/collective/communicator.cu | 5 +++-- src/collective/communicator.h | 5 +++++ src/collective/rabit_communicator.h | 11 ++++++++--- 5 files changed, 23 insertions(+), 6 deletions(-) diff --git a/plugin/federated/federated_communicator.h b/plugin/federated/federated_communicator.h index fa502aff6157..6a3186b4f608 100644 --- a/plugin/federated/federated_communicator.h +++ b/plugin/federated/federated_communicator.h @@ -182,6 +182,9 @@ class 
FederatedCommunicator : public Communicator { */ void Print(const std::string &message) override { LOG(CONSOLE) << message; } + protected: + void Shutdown() override {} + private: std::unique_ptr client_{}; }; diff --git a/src/collective/communicator.cc b/src/collective/communicator.cc index dd21e3b699c4..73765223b225 100644 --- a/src/collective/communicator.cc +++ b/src/collective/communicator.cc @@ -49,7 +49,10 @@ void Communicator::Init(Json const& config) { } #ifndef XGBOOST_USE_CUDA -void Communicator::Finalize() { communicator_.reset(); } +void Communicator::Finalize() { + communicator_->Shutdown(); + communicator_.reset(nullptr); +} #endif } // namespace collective diff --git a/src/collective/communicator.cu b/src/collective/communicator.cu index 9495cf6ae5ba..2485000d9ad4 100644 --- a/src/collective/communicator.cu +++ b/src/collective/communicator.cu @@ -15,9 +15,10 @@ thread_local int Communicator::device_ordinal_{-1}; thread_local std::unique_ptr Communicator::device_communicator_{}; void Communicator::Finalize() { - communicator_.reset(); + communicator_->Shutdown(); + communicator_.reset(nullptr); device_ordinal_ = -1; - device_communicator_.reset(); + device_communicator_.reset(nullptr); } DeviceCommunicator* Communicator::GetDevice(int device_ordinal) { diff --git a/src/collective/communicator.h b/src/collective/communicator.h index 196f03160cc6..14ee201618cd 100644 --- a/src/collective/communicator.h +++ b/src/collective/communicator.h @@ -185,6 +185,11 @@ class Communicator { } } + /** + * @brief Shuts down the communicator. 
+ */ + virtual void Shutdown() = 0; + private: static CommunicatorType StringToType(char const *str) { CommunicatorType result = CommunicatorType::kUnknown; diff --git a/src/collective/rabit_communicator.h b/src/collective/rabit_communicator.h index 5708d1bbc4fe..3b145e6577c5 100644 --- a/src/collective/rabit_communicator.h +++ b/src/collective/rabit_communicator.h @@ -43,14 +43,14 @@ class RabitCommunicator : public Communicator { for (auto &key_value : args_str) { args.push_back(&key_value[0]); } - rabit::Init(static_cast(args.size()), &args[0]); + if (!rabit::Init(static_cast(args.size()), &args[0])) { + LOG(FATAL) << "Failed to initialize Rabit"; + } return new RabitCommunicator(rabit::GetWorldSize(), rabit::GetRank()); } RabitCommunicator(int world_size, int rank) : Communicator(world_size, rank) {} - ~RabitCommunicator() override { rabit::Finalize(); } - bool IsDistributed() const override { return rabit::IsDistributed(); } void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type, @@ -93,6 +93,11 @@ class RabitCommunicator : public Communicator { void Print(const std::string &message) override { rabit::TrackerPrint(message); } + protected: + void Shutdown() override { + rabit::Finalize(); + } + private: template void DoAllReduce(void *send_receive_buffer, std::size_t count, Operation op) {