From 48fc02495bb30dcc8cf1cdf6647684b15e712f09 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Mon, 4 Apr 2022 11:34:43 -0700 Subject: [PATCH 01/21] add federated plugin --- CMakeLists.txt | 1 + plugin/CMakeLists.txt | 4 ++++ plugin/federated/CMakeLists.txt | 0 3 files changed, 5 insertions(+) create mode 100644 plugin/federated/CMakeLists.txt diff --git a/CMakeLists.txt b/CMakeLists.txt index 11178c449119..f3e66f77997a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -66,6 +66,7 @@ address, leak, undefined and thread.") ## Plugins option(PLUGIN_DENSE_PARSER "Build dense parser plugin" OFF) option(PLUGIN_RMM "Build with RAPIDS Memory Manager (RMM)" OFF) +option(PLUGIN_FEDERATED "Build with Federated Learning" OFF) ## TODO: 1. Add check if DPC++ compiler is used for building option(PLUGIN_UPDATER_ONEAPI "DPC++ updater" OFF) option(ADD_PKGCONFIG "Add xgboost.pc into system." ON) diff --git a/plugin/CMakeLists.txt b/plugin/CMakeLists.txt index 97d1af190cfe..fdecf41e6e9f 100644 --- a/plugin/CMakeLists.txt +++ b/plugin/CMakeLists.txt @@ -40,3 +40,7 @@ if (PLUGIN_UPDATER_ONEAPI) # Add all objects of oneapi_plugin to objxgboost target_sources(objxgboost INTERFACE $) endif (PLUGIN_UPDATER_ONEAPI) + +if (PLUGIN_FEDERATED) + add_subdirectory(federated) +endif (PLUGIN_FEDERATED) diff --git a/plugin/federated/CMakeLists.txt b/plugin/federated/CMakeLists.txt new file mode 100644 index 000000000000..e69de29bb2d1 From 5e1ff7eb7b0048b98c30da76728774d1c5707d6a Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Thu, 7 Apr 2022 18:31:16 -0700 Subject: [PATCH 02/21] add federation server and test client --- plugin/federated/CMakeLists.txt | 21 ++++ plugin/federated/README.md | 16 +++ plugin/federated/federation.proto | 36 ++++++ plugin/federated/federation_client.cc | 51 ++++++++ plugin/federated/federation_server.cc | 160 ++++++++++++++++++++++++++ 5 files changed, 284 insertions(+) create mode 100644 plugin/federated/README.md create mode 100644 plugin/federated/federation.proto create mode 100644 plugin/federated/federation_client.cc create mode 100644 plugin/federated/federation_server.cc diff --git a/plugin/federated/CMakeLists.txt b/plugin/federated/CMakeLists.txt index e69de29bb2d1..c17bc38a9d77 100644 --- a/plugin/federated/CMakeLists.txt +++ b/plugin/federated/CMakeLists.txt @@ -0,0 +1,21 @@ +find_package(protobuf CONFIG REQUIRED) +find_package(gRPC CONFIG REQUIRED) +find_package(Threads) + +add_library(federation_proto OBJECT federation.proto) +target_link_libraries(federation_proto PUBLIC protobuf::libprotobuf gRPC::grpc gRPC::grpc++) +target_include_directories(federation_proto PUBLIC ${CMAKE_CURRENT_BINARY_DIR}) + +get_target_property(grpc_cpp_plugin_location gRPC::grpc_cpp_plugin LOCATION) +protobuf_generate(TARGET federation_proto LANGUAGE cpp) +protobuf_generate( + TARGET federation_proto + LANGUAGE grpc + GENERATE_EXTENSIONS .grpc.pb.h .grpc.pb.cc + PLUGIN "protoc-gen-grpc=${grpc_cpp_plugin_location}") + +add_executable(federation_server federation_server.cc) +target_link_libraries(federation_server PRIVATE federation_proto) + +add_executable(federation_client federation_client.cc) +target_link_libraries(federation_client PRIVATE federation_proto) diff --git a/plugin/federated/README.md b/plugin/federated/README.md new file mode 100644 index 000000000000..dfe9e5592464 --- /dev/null +++ b/plugin/federated/README.md @@ -0,0 +1,16 @@ +XGBoost Plugin for Federated Learning +===================================== + +This folder contains the plugin for federated learning. + +Install gRPC +------------ +```shell +sudo apt-get install build-essential autoconf libtool pkg-config cmake ninja-build +git clone -b v1.45.1 https://github.com/grpc/grpc +cd grpc +git submodule update --init +cmake -S . -B build -GNinja\ + -DABSL_PROPAGATE_CXX_STD=ON +cmake --build build --target install +``` diff --git a/plugin/federated/federation.proto b/plugin/federated/federation.proto new file mode 100644 index 000000000000..bdd73d46362d --- /dev/null +++ b/plugin/federated/federation.proto @@ -0,0 +1,36 @@ +syntax = "proto3"; + +package xgboost.federated; + +service Federation { + rpc Allreduce(AllreduceRequest) returns (AllreduceReply) {} +} + +enum DataType { + CHAR = 0; + UCHAR = 1; + INT = 2; + UINT = 3; + LONG = 4; + ULONG = 5; + FLOAT = 6; + DOUBLE = 7; + LONGLONG = 8; + ULONGLONG = 9; +} + +enum ReduceOperation { + MAX = 0; + MIN = 1; + SUM = 2; +} + +message AllreduceRequest { + bytes send_buffer = 1; + DataType data_type = 2; + ReduceOperation reduce_operation = 3; +} + +message AllreduceReply { + bytes receive_buffer = 1; +} diff --git a/plugin/federated/federation_client.cc b/plugin/federated/federation_client.cc new file mode 100644 index 000000000000..7497afce2199 --- /dev/null +++ b/plugin/federated/federation_client.cc @@ -0,0 +1,51 @@ +#include +#include +#include + +#include +#include +#include + +class FederationClient { + public: + explicit FederationClient(const std::shared_ptr &channel) + : stub_(xgboost::federated::Federation::NewStub(channel)) {} + + std::string Allreduce(const std::string &send_buffer) { + xgboost::federated::AllreduceRequest request; + request.set_send_buffer(send_buffer); + request.set_data_type(xgboost::federated::INT); + request.set_reduce_operation(xgboost::federated::SUM); + + xgboost::federated::AllreduceReply reply; + grpc::ClientContext context; + grpc::Status status = stub_->Allreduce(&context, request, &reply); + + if (status.ok()) { + return reply.receive_buffer(); + } else { + std::cout << status.error_code() << ": " << status.error_message() << std::endl; + return "RPC failed"; + } + } + + private: + std::unique_ptr stub_; +}; + +int main(int argc, char **argv) { + FederationClient client( + grpc::CreateChannel("localhost:50051", grpc::InsecureChannelCredentials())); + + for (int i = 1; i <= 10; i++) { + int data[] = {1 * i, 2 * i, 3 * i, 4 * i, 5 * i}; + int n = sizeof(data) / sizeof(data[0]); + std::string send_buffer(reinterpret_cast(data), sizeof(data)); + std::string receive_buffer = client.Allreduce(send_buffer); + int *result = reinterpret_cast(receive_buffer.data()); + std::copy(result, result + n, std::ostream_iterator(std::cout, " ")); + std::cout << '\n'; + } + + return 0; +} diff --git a/plugin/federated/federation_server.cc b/plugin/federated/federation_server.cc new file mode 100644 index 000000000000..07ad5b205393 --- /dev/null +++ b/plugin/federated/federation_server.cc @@ -0,0 +1,160 @@ +#include +#include +#include + +#include +#include + +namespace xgboost::federated { + +class FederationService final : public Federation::Service { + public: + explicit FederationService(int const world_size) : world_size_(world_size) {} + + grpc::Status Allreduce(grpc::ServerContext* context, AllreduceRequest const* request, + AllreduceReply* reply) override { + // Pass through if there is only 1 client. + if (world_size_ == 1) { + reply->set_receive_buffer(request->send_buffer()); + return grpc::Status::OK; + } + + std::unique_lock lock(mutex_); + + // Wait for all previous replies have been sent. + cv_.wait(lock, [this] { return sent_ == 0; }); + + if (received_ == 0) { + // Copy the send_buffer if this is the first client. + buffer_ = request->send_buffer(); + } else { + // Accumulate the send_buffer into the common buffer. + Accumulate(request->send_buffer(), request->data_type(), request->reduce_operation()); + } + received_++; + // If all clients have been received, send the reply and notify all. + if (received_ == world_size_) { + received_ = 0; + sent_++; + reply->set_receive_buffer(buffer_); + lock.unlock(); + cv_.notify_all(); + return grpc::Status::OK; + } + + // Wait for all the clients to be received. + cv_.wait(lock, [this] { return received_ == 0; }); + sent_++; + reply->set_receive_buffer(buffer_); + if (sent_ == world_size_) { + sent_ = 0; + lock.unlock(); + cv_.notify_all(); + } + return grpc::Status::OK; + } + + private: + template + void Accumulate(T* buffer, T const* input, std::size_t n, ReduceOperation reduce_operation) { + switch (reduce_operation) { + case ReduceOperation::MAX: + std::transform(buffer, buffer + n, input, buffer, [](T a, T b) { return std::max(a, b); }); + break; + case ReduceOperation::MIN: + std::transform(buffer, buffer + n, input, buffer, [](T a, T b) { return std::min(a, b); }); + break; + case ReduceOperation::SUM: + std::transform(buffer, buffer + n, input, buffer, std::plus()); + break; + default: + throw std::invalid_argument("Invalid reduce operation"); + } + } + + void Accumulate(std::string const& input, DataType data_type, ReduceOperation reduce_operation) { + switch (data_type) { + case DataType::CHAR: + Accumulate(buffer_.data(), input.data(), buffer_.size(), reduce_operation); + break; + case DataType::UCHAR: + Accumulate(reinterpret_cast(buffer_.data()), + reinterpret_cast(input.data()), buffer_.size(), + reduce_operation); + break; + case DataType::INT: + Accumulate(reinterpret_cast(buffer_.data()), + reinterpret_cast(input.data()), buffer_.size() / sizeof(int), + reduce_operation); + break; + case DataType::UINT: + Accumulate(reinterpret_cast(buffer_.data()), + reinterpret_cast(input.data()), + buffer_.size() / sizeof(unsigned int), reduce_operation); + break; + case DataType::LONG: + Accumulate(reinterpret_cast(buffer_.data()), + reinterpret_cast(input.data()), buffer_.size() / sizeof(long), + reduce_operation); + break; + case DataType::ULONG: + Accumulate(reinterpret_cast(buffer_.data()), + reinterpret_cast(input.data()), + buffer_.size() / sizeof(unsigned long), reduce_operation); + break; + case DataType::FLOAT: + Accumulate(reinterpret_cast(buffer_.data()), + reinterpret_cast(input.data()), buffer_.size() / sizeof(float), + reduce_operation); + break; + case DataType::DOUBLE: + Accumulate(reinterpret_cast(buffer_.data()), + reinterpret_cast(input.data()), buffer_.size() / sizeof(double), + reduce_operation); + break; + case DataType::LONGLONG: + Accumulate(reinterpret_cast(buffer_.data()), + reinterpret_cast(input.data()), + buffer_.size() / sizeof(long long), reduce_operation); + break; + case DataType::ULONGLONG: + Accumulate(reinterpret_cast(buffer_.data()), + reinterpret_cast(input.data()), + buffer_.size() / sizeof(unsigned long long), reduce_operation); + break; + default: + throw std::invalid_argument("Invalid data type"); + } + } + + int const world_size_; + int received_{}; + int sent_{}; + std::string buffer_{}; + mutable std::mutex mutex_; + mutable std::condition_variable cv_; +}; + +void RunServer(int world_size) { + std::string const server_address{"0.0.0.0:50051"}; + FederationService service{world_size}; + + grpc::ServerBuilder builder; + builder.AddListeningPort(server_address, grpc::InsecureServerCredentials()); + builder.RegisterService(&service); + std::unique_ptr server(builder.BuildAndStart()); + std::cout << "Server listening on " << server_address << " with world size " << world_size + << '\n'; + + server->Wait(); +} +} // namespace xgboost::federated + +int main(int argc, char** argv) { + auto world_size{1}; + if (argc > 1) { + world_size = std::stoi(argv[1]); + } + xgboost::federated::RunServer(world_size); + return 0; +} From cd52ceb9fb079100705d3cd5d4ac37cd80a59b6d Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Thu, 7 Apr 2022 18:37:22 -0700 Subject: [PATCH 03/21] minor cleanup --- plugin/federated/federation_client.cc | 17 ++++++++++------- plugin/federated/federation_server.cc | 4 +++- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/plugin/federated/federation_client.cc b/plugin/federated/federation_client.cc index 7497afce2199..d2bc5b37a7eb 100644 --- a/plugin/federated/federation_client.cc +++ b/plugin/federated/federation_client.cc @@ -6,18 +6,20 @@ #include #include +namespace xgboost::federated { + class FederationClient { public: explicit FederationClient(const std::shared_ptr &channel) - : stub_(xgboost::federated::Federation::NewStub(channel)) {} + : stub_(Federation::NewStub(channel)) {} std::string Allreduce(const std::string &send_buffer) { - xgboost::federated::AllreduceRequest request; + AllreduceRequest request; request.set_send_buffer(send_buffer); - request.set_data_type(xgboost::federated::INT); - request.set_reduce_operation(xgboost::federated::SUM); + request.set_data_type(DataType::INT); + request.set_reduce_operation(ReduceOperation::SUM); - xgboost::federated::AllreduceReply reply; + AllreduceReply reply; grpc::ClientContext context; grpc::Status status = stub_->Allreduce(&context, request, &reply); @@ -30,11 +32,12 @@ class FederationClient { } private: - std::unique_ptr stub_; + std::unique_ptr stub_; }; +} // namespace xgboost::federated int main(int argc, char **argv) { - FederationClient client( + xgboost::federated::FederationClient client( grpc::CreateChannel("localhost:50051", grpc::InsecureChannelCredentials())); for (int i = 1; i <= 10; i++) { diff --git a/plugin/federated/federation_server.cc b/plugin/federated/federation_server.cc index 07ad5b205393..edcbd78c607a 100644 --- a/plugin/federated/federation_server.cc +++ b/plugin/federated/federation_server.cc @@ -22,7 +22,9 @@ class FederationService final : public Federation::Service { std::unique_lock lock(mutex_); // Wait for all previous replies have been sent. - cv_.wait(lock, [this] { return sent_ == 0; }); + if (sent_ != 0) { + cv_.wait(lock, [this] { return sent_ == 0; }); + } if (received_ == 0) { // Copy the send_buffer if this is the first client. From b63395a30eda7f493de22568fa25e962efcdc5b6 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Fri, 15 Apr 2022 19:31:21 -0700 Subject: [PATCH 04/21] implemented allreduce/allgather/broadcast --- plugin/federated/CMakeLists.txt | 25 +- plugin/federated/engine_federated.cc | 285 ++++++++++++++++++ plugin/federated/federated.proto | 58 ++++ plugin/federated/federated_client.h | 81 +++++ ...deration_server.cc => federated_server.cc} | 143 +++++++-- plugin/federated/federation.proto | 36 --- plugin/federated/federation_client.cc | 54 ---- plugin/federated/test_client.cc | 41 +++ rabit/CMakeLists.txt | 3 +- tests/distributed/runtests-federated.sh | 21 ++ 10 files changed, 623 insertions(+), 124 deletions(-) create mode 100644 plugin/federated/engine_federated.cc create mode 100644 plugin/federated/federated.proto create mode 100644 plugin/federated/federated_client.h rename plugin/federated/{federation_server.cc => federated_server.cc} (53%) delete mode 100644 plugin/federated/federation.proto delete mode 100644 plugin/federated/federation_client.cc create mode 100644 plugin/federated/test_client.cc create mode 100755 tests/distributed/runtests-federated.sh diff --git a/plugin/federated/CMakeLists.txt b/plugin/federated/CMakeLists.txt index c17bc38a9d77..789c181cf441 100644 --- a/plugin/federated/CMakeLists.txt +++ b/plugin/federated/CMakeLists.txt @@ -2,20 +2,27 @@ find_package(protobuf CONFIG REQUIRED) find_package(gRPC CONFIG REQUIRED) find_package(Threads) -add_library(federation_proto OBJECT federation.proto) -target_link_libraries(federation_proto PUBLIC protobuf::libprotobuf gRPC::grpc gRPC::grpc++) -target_include_directories(federation_proto PUBLIC ${CMAKE_CURRENT_BINARY_DIR}) +add_library(federated_proto federated.proto) +target_link_libraries(federated_proto PUBLIC protobuf::libprotobuf gRPC::grpc gRPC::grpc++) +target_include_directories(federated_proto PUBLIC ${CMAKE_CURRENT_BINARY_DIR}) +set_property(TARGET federated_proto PROPERTY POSITION_INDEPENDENT_CODE ON) get_target_property(grpc_cpp_plugin_location gRPC::grpc_cpp_plugin LOCATION) -protobuf_generate(TARGET federation_proto LANGUAGE cpp) +protobuf_generate(TARGET federated_proto LANGUAGE cpp) protobuf_generate( - TARGET federation_proto + TARGET federated_proto LANGUAGE grpc GENERATE_EXTENSIONS .grpc.pb.h .grpc.pb.cc PLUGIN "protoc-gen-grpc=${grpc_cpp_plugin_location}") -add_executable(federation_server federation_server.cc) -target_link_libraries(federation_server PRIVATE federation_proto) +add_library(federated_client INTERFACE federated_client.h) +target_link_libraries(federated_client INTERFACE federated_proto) -add_executable(federation_client federation_client.cc) -target_link_libraries(federation_client PRIVATE federation_proto) +add_executable(federated_server federated_server.cc) +target_link_libraries(federated_server PRIVATE federated_proto) + +add_executable(test_client test_client.cc) +target_link_libraries(test_client PRIVATE federated_client) + +target_sources(objxgboost PRIVATE engine_federated.cc) +target_link_libraries(objxgboost PRIVATE federated_client) diff --git a/plugin/federated/engine_federated.cc b/plugin/federated/engine_federated.cc new file mode 100644 index 000000000000..bcf3b9bbb4d1 --- /dev/null +++ b/plugin/federated/engine_federated.cc @@ -0,0 +1,285 @@ +#define NOMINMAX +#include +#include + +#include +#include +#include + +#include "federated_client.h" +#include "rabit/internal/engine.h" +#include "rabit/internal/utils.h" + +namespace rabit { +namespace engine { + +/*! \brief implementation of engine using federated learning */ +class FederatedEngine : public IEngine { + public: + void Init(int argc, char *argv[]) { + // Parse environment variables first. + for (auto const &env_var : env_vars_) { + char const *value = getenv(env_var.c_str()); + if (value != nullptr) { + SetParam(env_var, value); + } + } + // Command line argument override. + for (int i = 0; i < argc; ++i) { + std::string key_value = argv[i]; + auto delimiter = key_value.find('='); + if (delimiter != std::string::npos) { + SetParam(key_value.substr(0, delimiter), key_value.substr(delimiter + 1)); + } + } + utils::Printf("Connecting to federated server %s, world size %d, rank %d", + server_address_.c_str(), world_size_, rank_); + client_.reset(new xgboost::federated::FederatedClient(server_address_, rank_)); + } + + void Allgather(void *sendrecvbuf_, size_t total_size, size_t slice_begin, size_t slice_end, + size_t size_prev_slice) override { + throw std::logic_error("FederatedEngine:: Allgather is not supported"); + } + + void Allreduce(void *sendrecvbuf_, size_t type_nbytes, size_t count, ReduceFunction reducer, + PreprocFunction prepare_fun, void *prepare_arg) override { + throw std::logic_error("FederatedEngine:: Allreduce is not supported, use Allreduce_ instead"); + } + + // NOLINTNEXTLINE(readability-identifier-naming) + void Allreduce_(void *sendrecvbuf, size_t size, mpi::DataType dtype, mpi::OpType op) { + auto *buffer = reinterpret_cast(sendrecvbuf); + std::string send_buffer(buffer, size); + auto const receive_buffer = client_->Allreduce(send_buffer, GetDataType(dtype), GetOp(op)); + receive_buffer.copy(buffer, size); + } + + int GetRingPrevRank() const override { + throw std::logic_error("FederatedEngine:: GetRingPrevRank is not supported"); + } + + void Broadcast(void *sendrecvbuf, size_t size, int root) override { + auto *buffer = reinterpret_cast(sendrecvbuf); + std::string send_buffer(buffer, size); + auto const receive_buffer = client_->Broadcast(send_buffer, root); + if (rank_ != root) { + receive_buffer.copy(buffer, size); + } + } + + int LoadCheckPoint(Serializable *global_model, Serializable *local_model = nullptr) override { + return 0; + } + + void CheckPoint(const Serializable *global_model, + const Serializable *local_model = nullptr) override { + version_number_ += 1; + } + + void LazyCheckPoint(const Serializable *global_model) override { version_number_ += 1; } + + int VersionNumber() const override { return version_number_; } + + /*! \brief get rank of current node */ + int GetRank() const override { return rank_; } + + /*! \brief get total number of */ + int GetWorldSize() const override { return world_size_; } + + /*! \brief whether it is distributed */ + bool IsDistributed() const override { return true; } + + /*! \brief get the host name of current node */ + std::string GetHost() const override { return "rank" + std::to_string(rank_); } + + void TrackerPrint(const std::string &msg) override { + // simply print information into the tracker + if (GetRank() == 0) { + utils::Printf("%s", msg.c_str()); + } + } + + private: + /** @brief Transform mpi::DataType to xgboost::federated::DataType. */ + static xgboost::federated::DataType GetDataType(mpi::DataType data_type) { + switch (data_type) { + case mpi::kChar: + return xgboost::federated::CHAR; + case mpi::kUChar: + return xgboost::federated::UCHAR; + case mpi::kInt: + return xgboost::federated::INT; + case mpi::kUInt: + return xgboost::federated::UINT; + case mpi::kLong: + return xgboost::federated::LONG; + case mpi::kULong: + return xgboost::federated::ULONG; + case mpi::kFloat: + return xgboost::federated::FLOAT; + case mpi::kDouble: + return xgboost::federated::DOUBLE; + case mpi::kLongLong: + return xgboost::federated::LONGLONG; + case mpi::kULongLong: + return xgboost::federated::ULONGLONG; + } + utils::Error("unknown mpi::DataType"); + return xgboost::federated::CHAR; + } + + /** @brief Transform mpi::OpType to enum to MPI OP */ + static xgboost::federated::ReduceOperation GetOp(mpi::OpType op_type) { + switch (op_type) { + case mpi::kMax: + return xgboost::federated::MAX; + case mpi::kMin: + return xgboost::federated::MIN; + case mpi::kSum: + return xgboost::federated::SUM; + case mpi::kBitwiseOR: + utils::Error("Bitwise OR is not supported"); + return xgboost::federated::MAX; + } + utils::Error("unknown mpi::OpType"); + return xgboost::federated::MAX; + } + + void SetParam(std::string const &name, std::string const &val) { + if (!strcasecmp(name.c_str(), "FEDERATED_SERVER_ADDRESS")) { + server_address_ = val; + } else if (!strcasecmp(name.c_str(), "FEDERATED_WORLD_SIZE")) { + world_size_ = std::stoi(val); + } else if (!strcasecmp(name.c_str(), "FEDERATED_RANK")) { + rank_ = std::stoi(val); + } + } + + std::vector const env_vars_{"FEDERATED_SERVER_ADDRESS", "FEDERATED_WORLD_SIZE", + "FEDERATED_RANK"}; + std::string server_address_{"localhost:9091"}; + int world_size_{1}; + int rank_{0}; + std::unique_ptr client_{}; + int version_number_{0}; +}; + +// Singleton federated engine. +FederatedEngine engine; // NOLINT(cert-err58-cpp) + +/*! \brief initialize the synchronization module */ +bool Init(int argc, char *argv[]) { + try { + engine.Init(argc, argv); + return true; + } catch (std::exception const &e) { + fprintf(stderr, " failed in Federated Init %s\n", e.what()); + return false; + } +} + +/*! \brief finalize synchronization module */ +bool Finalize() { return true; } + +/*! \brief singleton method to get engine */ +IEngine *GetEngine() { return &engine; } + +// perform in-place allreduce, on sendrecvbuf +void Allreduce_(void *sendrecvbuf, size_t type_nbytes, size_t count, IEngine::ReduceFunction red, + mpi::DataType dtype, mpi::OpType op, IEngine::PreprocFunction prepare_fun, + void *prepare_arg) { + if (prepare_fun != nullptr) prepare_fun(prepare_arg); + engine.Allreduce_(sendrecvbuf, type_nbytes * count, dtype, op); +} + +// code for reduce handle +ReduceHandle::ReduceHandle(void) : handle_(NULL), redfunc_(NULL), htype_(NULL) {} + +ReduceHandle::~ReduceHandle(void) { + /* !WARNING! + + A handle can be held by a tree method/Learner from xgboost. The booster might not be + freed until program exit, while (good) users call rabit.finalize() before reaching + the end of program. So op->Free() might be called after finalization and results + into following error: + + ``` + Attempting to use an MPI routine after finalizing MPICH + ``` + + Here we skip calling Free if MPI has already been finalized to workaround the issue. + It can be a potential leak of memory. The best way to resolve it is to eliminate all + use of long living handle. + */ + int finalized = 0; + // CHECK_EQ(MPI_Finalized(&finalized), MPI_SUCCESS); + if (handle_ != NULL) { + // MPI::Op *op = reinterpret_cast(handle_); + // if (!finalized) { + // op->Free(); + // } + // delete op; + } + if (htype_ != NULL) { + MPI::Datatype *dtype = reinterpret_cast(htype_); + if (!finalized) { + // dtype->Free(); + } + // delete dtype; + } +} + +int ReduceHandle::TypeSize(const MPI::Datatype &dtype) { + return 0; + // return dtype.Get_size(); +} + +void ReduceHandle::Init(IEngine::ReduceFunction redfunc, size_t type_nbytes) { + utils::Assert(handle_ == NULL, "cannot initialize reduce handle twice"); + if (type_nbytes != 0) { + // MPI::Datatype *dtype = new MPI::Datatype(); + // if (type_nbytes % 8 == 0) { + // *dtype = MPI::LONG.Create_contiguous(type_nbytes / sizeof(long)); // NOLINT(*) + // } else if (type_nbytes % 4 == 0) { + // *dtype = MPI::INT.Create_contiguous(type_nbytes / sizeof(int)); + // } else { + // *dtype = MPI::CHAR.Create_contiguous(type_nbytes); + // } + // dtype->Commit(); + created_type_nbytes_ = type_nbytes; + // htype_ = dtype; + } + // MPI::Op *op = new MPI::Op(); + // MPI::User_function *pf = redfunc; + // op->Init(pf, true); + // handle_ = op; +} + +void ReduceHandle::Allreduce(void *sendrecvbuf, size_t type_nbytes, size_t count, + IEngine::PreprocFunction prepare_fun, void *prepare_arg) { + utils::Assert(handle_ != NULL, "must initialize handle to call AllReduce"); + // MPI::Op *op = reinterpret_cast(handle_); + // MPI::Datatype *dtype = reinterpret_cast(htype_); + // if (created_type_nbytes_ != type_nbytes || dtype == NULL) { + // if (dtype == NULL) { + // dtype = new MPI::Datatype(); + // } else { + // dtype->Free(); + // } + // if (type_nbytes % 8 == 0) { + // *dtype = MPI::LONG.Create_contiguous(type_nbytes / sizeof(long)); // NOLINT(*) + // } else if (type_nbytes % 4 == 0) { + // *dtype = MPI::INT.Create_contiguous(type_nbytes / sizeof(int)); + // } else { + // *dtype = MPI::CHAR.Create_contiguous(type_nbytes); + // } + // dtype->Commit(); + // created_type_nbytes_ = type_nbytes; + // } + // if (prepare_fun != NULL) prepare_fun(prepare_arg); + // MPI::COMM_WORLD.Allreduce(MPI_IN_PLACE, sendrecvbuf, count, *dtype, *op); +} + +} // namespace engine +} // namespace rabit diff --git a/plugin/federated/federated.proto b/plugin/federated/federated.proto new file mode 100644 index 000000000000..fc09b0f5c73e --- /dev/null +++ b/plugin/federated/federated.proto @@ -0,0 +1,58 @@ +syntax = "proto3"; + +package xgboost.federated; + +service Federated { + rpc Allgather(AllgatherRequest) returns (AllgatherReply) {} + rpc Allreduce(AllreduceRequest) returns (AllreduceReply) {} + rpc Broadcast(BroadcastRequest) returns (BroadcastReply) {} +} + +enum DataType { + CHAR = 0; + UCHAR = 1; + INT = 2; + UINT = 3; + LONG = 4; + ULONG = 5; + FLOAT = 6; + DOUBLE = 7; + LONGLONG = 8; + ULONGLONG = 9; +} + +enum ReduceOperation { + MAX = 0; + MIN = 1; + SUM = 2; +} + +message AllgatherRequest { + int32 rank = 1; + bytes send_buffer = 2; +} + +message AllgatherReply { + bytes receive_buffer = 1; +} + +message AllreduceRequest { + int32 rank = 1; + bytes send_buffer = 2; + DataType data_type = 3; + ReduceOperation reduce_operation = 4; +} + +message AllreduceReply { + bytes receive_buffer = 1; +} + +message BroadcastRequest { + int32 rank = 1; + bytes send_buffer = 2; + int32 root = 3; +} + +message BroadcastReply { + bytes receive_buffer = 1; +} diff --git a/plugin/federated/federated_client.h b/plugin/federated/federated_client.h new file mode 100644 index 000000000000..5bfc6afb445e --- /dev/null +++ b/plugin/federated/federated_client.h @@ -0,0 +1,81 @@ +#pragma once +#include +#include +#include + +#include +#include +#include + +namespace xgboost { +namespace federated { + +class FederatedClient { + public: + explicit FederatedClient(std::string const &server_address, int rank) + : stub_{Federated::NewStub( + grpc::CreateChannel(server_address, grpc::InsecureChannelCredentials()))}, + rank_{rank} {} + + std::string Allgather(std::string const &send_buffer) { + AllgatherRequest request; + request.set_rank(rank_); + request.set_send_buffer(send_buffer); + + AllgatherReply reply; + grpc::ClientContext context; + grpc::Status status = stub_->Allgather(&context, request, &reply); + + if (status.ok()) { + return reply.receive_buffer(); + } else { + std::cout << status.error_code() << ": " << status.error_message() << '\n'; + throw std::runtime_error("Allgather RPC failed"); + } + } + + std::string Allreduce(std::string const &send_buffer, DataType data_type, + ReduceOperation reduce_operation) { + AllreduceRequest request; + request.set_rank(rank_); + request.set_send_buffer(send_buffer); + request.set_data_type(data_type); + request.set_reduce_operation(reduce_operation); + + AllreduceReply reply; + grpc::ClientContext context; + grpc::Status status = stub_->Allreduce(&context, request, &reply); + + if (status.ok()) { + return reply.receive_buffer(); + } else { + std::cout << status.error_code() << ": " << status.error_message() << '\n'; + throw std::runtime_error("Allreduce RPC failed"); + } + } + + std::string Broadcast(std::string const &send_buffer, int root) { + BroadcastRequest request; + request.set_rank(rank_); + request.set_send_buffer(send_buffer); + request.set_root(root); + + BroadcastReply reply; + grpc::ClientContext context; + grpc::Status status = stub_->Broadcast(&context, request, &reply); + + if (status.ok()) { + return reply.receive_buffer(); + } else { + std::cout << status.error_code() << ": " << status.error_message() << '\n'; + throw std::runtime_error("Broadcast RPC failed"); + } + } + + private: + std::unique_ptr const stub_; + int const rank_; +}; + +} // namespace federated +} // namespace xgboost diff --git a/plugin/federated/federation_server.cc b/plugin/federated/federated_server.cc similarity index 53% rename from plugin/federated/federation_server.cc rename to plugin/federated/federated_server.cc index edcbd78c607a..7dc693c62fa3 100644 --- a/plugin/federated/federation_server.cc +++ b/plugin/federated/federated_server.cc @@ -1,15 +1,65 @@ -#include -#include +#include +#include #include #include #include +#include namespace xgboost::federated { -class FederationService final : public Federation::Service { +class FederatedService final : public Federated::Service { public: - explicit FederationService(int const world_size) : world_size_(world_size) {} + explicit FederatedService(int const world_size) : world_size_(world_size) {} + + grpc::Status Allgather(grpc::ServerContext* context, AllgatherRequest const* request, + AllgatherReply* reply) override { + // Pass through if there is only 1 client. + if (world_size_ == 1) { + reply->set_receive_buffer(request->send_buffer()); + return grpc::Status::OK; + } + + std::unique_lock lock(mutex_); + + auto const rank = request->rank(); + auto const& send_buffer = request->send_buffer(); + auto const buffer_size = send_buffer.size(); + + if (received_ == 0) { + std::cout << "Allgather rank " << rank << ": first request, resizing buffer\n"; + buffer_.resize(buffer_size * world_size_); + } + std::cout << "Allgather rank " << rank << ": copying send buffer into common buffer\n"; + buffer_.replace(rank * buffer_size, buffer_size, send_buffer); + received_++; + std::cout << "Allgather rank " << rank << ": received=" << received_ << '\n'; + + if (received_ == world_size_) { + std::cout << "Allgather rank " << rank << ": all requests received, sending reply\n"; + reply->set_receive_buffer(buffer_); + sent_++; + lock.unlock(); + cv_.notify_all(); + return grpc::Status::OK; + } + + std::cout << "Allgather rank " << rank << ": waiting for all clients\n"; + cv_.wait(lock, [this] { return received_ == world_size_; }); + + std::cout << "Allgather rank " << rank << ": sending reply\n"; + reply->set_receive_buffer(buffer_); + sent_++; + + if (sent_ == world_size_) { + std::cout << "Allgather rank " << rank << ": all replies sent\n"; + sent_ = 0; + received_ = 0; + buffer_.clear(); + } + + return grpc::Status::OK; + } grpc::Status Allreduce(grpc::ServerContext* context, AllreduceRequest const* request, AllreduceReply* reply) override { @@ -21,38 +71,81 @@ class FederationService final : public Federation::Service { std::unique_lock lock(mutex_); - // Wait for all previous replies have been sent. - if (sent_ != 0) { - cv_.wait(lock, [this] { return sent_ == 0; }); - } + auto const rank = request->rank(); if (received_ == 0) { - // Copy the send_buffer if this is the first client. + std::cout << "Allreduce rank " << rank << ": first request, copying send buffer\n"; buffer_ = request->send_buffer(); } else { - // Accumulate the send_buffer into the common buffer. + std::cout << "Allreduce rank " << rank << ": accumulating send buffer into common buffer\n"; Accumulate(request->send_buffer(), request->data_type(), request->reduce_operation()); } received_++; - // If all clients have been received, send the reply and notify all. + if (received_ == world_size_) { - received_ = 0; - sent_++; + std::cout << "Allreduce rank " << rank << ": all requests received\n"; reply->set_receive_buffer(buffer_); + sent_++; lock.unlock(); cv_.notify_all(); return grpc::Status::OK; } - // Wait for all the clients to be received. - cv_.wait(lock, [this] { return received_ == 0; }); - sent_++; + std::cout << "Allreduce rank " << rank << ": waiting for all clients\n"; + cv_.wait(lock, [this] { return received_ == world_size_; }); + + std::cout << "Allreduce rank " << rank << ": sending reply\n"; reply->set_receive_buffer(buffer_); + sent_++; + if (sent_ == world_size_) { + std::cout << "Allreduce rank " << rank << ": all replies sent\n"; sent_ = 0; + received_ = 0; + buffer_.clear(); + } + + return grpc::Status::OK; + } + + grpc::Status Broadcast(grpc::ServerContext* context, + xgboost::federated::BroadcastRequest const* request, + xgboost::federated::BroadcastReply* reply) override { + // Pass through if there is only 1 client. + if (world_size_ == 1) { + reply->set_receive_buffer(request->send_buffer()); + return grpc::Status::OK; + } + + std::unique_lock lock(mutex_); + + auto const rank = request->rank(); + + if (request->rank() == request->root()) { + std::cout << "Broadcast rank " << rank << ": root copying send buffer to common buffer\n"; + buffer_ = request->send_buffer(); + received_ = world_size_; + reply->set_receive_buffer(buffer_); + sent_++; lock.unlock(); cv_.notify_all(); + return grpc::Status::OK; } + + std::cout << "Broadcast rank " << rank << ": waiting for the root\n"; + cv_.wait(lock, [this] { return received_ == world_size_; }); + + std::cout << "Broadcast rank " << rank << ": sending reply\n"; + reply->set_receive_buffer(buffer_); + sent_++; + + if (sent_ == world_size_) { + std::cout << "Broadcast rank " << rank << ": all replies sent\n"; + sent_ = 0; + received_ = 0; + buffer_.clear(); + } + return grpc::Status::OK; } @@ -137,15 +230,15 @@ class FederationService final : public Federation::Service { mutable std::condition_variable cv_; }; -void RunServer(int world_size) { - std::string const server_address{"0.0.0.0:50051"}; - FederationService service{world_size}; +void RunServer(int port, int world_size) { + std::string const server_address = "0.0.0.0:" + std::to_string(port); + FederatedService service{world_size}; grpc::ServerBuilder builder; builder.AddListeningPort(server_address, grpc::InsecureServerCredentials()); builder.RegisterService(&service); std::unique_ptr server(builder.BuildAndStart()); - std::cout << "Server listening on " << server_address << " with world size " << world_size + std::cout << "Federated server listening on " << server_address << ", world size " << world_size << '\n'; server->Wait(); @@ -153,10 +246,12 @@ void RunServer(int world_size) { } // namespace xgboost::federated int main(int argc, char** argv) { - auto world_size{1}; - if (argc > 1) { - world_size = std::stoi(argv[1]); + if (argc != 3) { + std::cerr << "Usage: federated_server port world_size" << '\n'; + return 1; } - xgboost::federated::RunServer(world_size); + auto port = std::stoi(argv[1]); + auto world_size = std::stoi(argv[2]); + xgboost::federated::RunServer(port, world_size); return 0; } diff --git a/plugin/federated/federation.proto b/plugin/federated/federation.proto deleted file mode 100644 index bdd73d46362d..000000000000 --- a/plugin/federated/federation.proto +++ /dev/null @@ -1,36 +0,0 @@ -syntax = "proto3"; - -package xgboost.federated; - -service Federation { - rpc Allreduce(AllreduceRequest) returns (AllreduceReply) {} -} - -enum DataType { - CHAR = 0; - UCHAR = 1; - INT = 2; - UINT = 3; - LONG = 4; - ULONG = 5; - FLOAT = 6; - DOUBLE = 7; - LONGLONG = 8; - ULONGLONG = 9; -} - -enum ReduceOperation { - MAX = 0; - MIN = 1; - SUM = 2; -} - -message AllreduceRequest { - bytes send_buffer = 1; - DataType data_type = 2; - ReduceOperation reduce_operation = 3; -} - -message AllreduceReply { - bytes receive_buffer = 1; -} diff --git a/plugin/federated/federation_client.cc b/plugin/federated/federation_client.cc deleted file mode 100644 index d2bc5b37a7eb..000000000000 --- a/plugin/federated/federation_client.cc +++ /dev/null @@ -1,54 +0,0 @@ -#include -#include -#include - -#include -#include -#include - -namespace xgboost::federated { - -class FederationClient { - public: - explicit FederationClient(const std::shared_ptr &channel) - : stub_(Federation::NewStub(channel)) {} - - std::string Allreduce(const std::string &send_buffer) { - AllreduceRequest request; - request.set_send_buffer(send_buffer); - request.set_data_type(DataType::INT); - request.set_reduce_operation(ReduceOperation::SUM); - - AllreduceReply reply; - grpc::ClientContext context; - grpc::Status status = stub_->Allreduce(&context, request, &reply); - - if (status.ok()) { - return reply.receive_buffer(); - } else { - std::cout << status.error_code() << ": " << status.error_message() << std::endl; - return "RPC failed"; - } - } - - private: - std::unique_ptr stub_; -}; -} // namespace xgboost::federated - -int main(int argc, char **argv) { - xgboost::federated::FederationClient client( - grpc::CreateChannel("localhost:50051", grpc::InsecureChannelCredentials())); - - for (int i = 1; i <= 10; i++) { - int data[] = {1 * i, 2 * i, 3 * i, 4 * i, 5 * i}; - int n = sizeof(data) / sizeof(data[0]); - std::string send_buffer(reinterpret_cast(data), sizeof(data)); - std::string receive_buffer = client.Allreduce(send_buffer); - int *result = reinterpret_cast(receive_buffer.data()); - std::copy(result, result + n, std::ostream_iterator(std::cout, " ")); - std::cout << '\n'; - } - - return 0; -} diff --git a/plugin/federated/test_client.cc b/plugin/federated/test_client.cc new file mode 100644 index 000000000000..701432da72a9 --- /dev/null +++ b/plugin/federated/test_client.cc @@ -0,0 +1,41 @@ +#include + +#include "federated_client.h" + +int main(int argc, char **argv) { + if (argc != 3) { + std::cerr << "Usage: federated_client server_address(host:port) rank" << '\n'; + return 1; + } + auto const server_address = argv[1]; + auto const rank = std::stoi(argv[2]); + xgboost::federated::FederatedClient client(server_address, rank); + + for (int i = 1; i <= 10; i++) { + // Allgather. + std::string allgather_send = "hello " + std::to_string(rank) + ":" + std::to_string(i) + " "; + auto const allgather_receive = client.Allgather(allgather_send); + std::cout << "Allgather rank " << rank << ": " << allgather_receive << '\n'; + + // Allreduce. + int data[] = {1 * i, 2 * i, 3 * i, 4 * i, 5 * i}; + int n = sizeof(data) / sizeof(data[0]); + std::string send_buffer(reinterpret_cast(data), sizeof(data)); + auto receive_buffer = + client.Allreduce(send_buffer, xgboost::federated::INT, xgboost::federated::SUM); + auto *result = reinterpret_cast(receive_buffer.data()); + std::cout << "Allreduce rank " << rank << ": "; + std::copy(result, result + n, std::ostream_iterator(std::cout, " ")); + std::cout << '\n'; + + // Broadcast. + std::string broadcast_send{}; + if (rank == 0) { + broadcast_send = "hello " + std::to_string(i); + } + auto const broadcast_receive = client.Broadcast(broadcast_send, 0); + std::cout << "Broadcast rank " << rank << ": " << broadcast_receive << '\n'; + } + + return 0; +} diff --git a/rabit/CMakeLists.txt b/rabit/CMakeLists.txt index ad39fb249791..594b6f15dda8 100644 --- a/rabit/CMakeLists.txt +++ b/rabit/CMakeLists.txt @@ -6,7 +6,8 @@ set(RABIT_SOURCES ${CMAKE_CURRENT_LIST_DIR}/src/allreduce_base.cc ${CMAKE_CURRENT_LIST_DIR}/src/rabit_c_api.cc) -if (RABIT_BUILD_MPI) +if (PLUGIN_FEDERATED) +elseif (RABIT_BUILD_MPI) list(APPEND RABIT_SOURCES ${CMAKE_CURRENT_LIST_DIR}/src/engine_mpi.cc) elseif (RABIT_MOCK) list(APPEND RABIT_SOURCES ${CMAKE_CURRENT_LIST_DIR}/src/engine_mock.cc) diff --git a/tests/distributed/runtests-federated.sh b/tests/distributed/runtests-federated.sh new file mode 100755 index 000000000000..7f757aa7084a --- /dev/null +++ b/tests/distributed/runtests-federated.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +trap "kill 0" EXIT + +rm -f ./*.model* + +port=9091 +world_size=3 + +../../build/plugin/federated/federated_server ${port} ${world_size} & + +export FEDERATED_SERVER_ADDRESS="localhost:${port}" +export FEDERATED_WORLD_SIZE=${world_size} +for ((rank = 0; rank < world_size; rank++)); do + FEDERATED_RANK=${rank} python test_basic.py & + pids[${rank}]=$! +done + +for pid in ${pids[*]}; do + wait $pid +done From 66bf977fbfd5adb66d28255ac26f8ecf1226e7a5 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Fri, 15 Apr 2022 20:05:39 -0700 Subject: [PATCH 05/21] refactor federated server --- plugin/federated/federated_server.cc | 181 ++++++++++++--------------- 1 file changed, 80 insertions(+), 101 deletions(-) diff --git a/plugin/federated/federated_server.cc b/plugin/federated/federated_server.cc index 7dc693c62fa3..f2d40042a61f 100644 --- a/plugin/federated/federated_server.cc +++ b/plugin/federated/federated_server.cc @@ -8,12 +8,13 @@ namespace xgboost::federated { -class FederatedService final : public Federated::Service { +template +class Operation { public: - explicit FederatedService(int const world_size) : world_size_(world_size) {} + Operation(std::string name, int const world_size) + : name_{std::move(name)}, world_size_{world_size} {} - grpc::Status Allgather(grpc::ServerContext* context, AllgatherRequest const* request, - AllgatherReply* reply) override { + grpc::Status Operate(Request const* request, Reply* reply) { // Pass through if there is only 1 client. if (world_size_ == 1) { reply->set_receive_buffer(request->send_buffer()); @@ -23,20 +24,13 @@ class FederatedService final : public Federated::Service { std::unique_lock lock(mutex_); auto const rank = request->rank(); - auto const& send_buffer = request->send_buffer(); - auto const buffer_size = send_buffer.size(); - if (received_ == 0) { - std::cout << "Allgather rank " << rank << ": first request, resizing buffer\n"; - buffer_.resize(buffer_size * world_size_); - } - std::cout << "Allgather rank " << rank << ": copying send buffer into common buffer\n"; - buffer_.replace(rank * buffer_size, buffer_size, send_buffer); + std::cout << name_ << " rank " << rank << ": on request\n"; + OnRequest(request); received_++; - std::cout << "Allgather rank " << rank << ": received=" << received_ << '\n'; if (received_ == world_size_) { - std::cout << "Allgather rank " << rank << ": all requests received, sending reply\n"; + std::cout << name_ << " rank " << rank << ": all requests received\n"; reply->set_receive_buffer(buffer_); sent_++; lock.unlock(); @@ -44,15 +38,15 @@ class FederatedService final : public Federated::Service { return grpc::Status::OK; } - std::cout << "Allgather rank " << rank << ": waiting for all clients\n"; + std::cout << name_ << " rank " << rank << ": waiting for all clients\n"; cv_.wait(lock, [this] { return received_ == world_size_; }); - std::cout << "Allgather rank " << rank << ": sending reply\n"; + std::cout << name_ << " rank " << rank << ": sending reply\n"; reply->set_receive_buffer(buffer_); sent_++; if (sent_ == world_size_) { - std::cout << "Allgather rank " << rank << ": all replies sent\n"; + std::cout << name_ << " rank " << rank << ": all replies sent\n"; sent_ = 0; received_ = 0; buffer_.clear(); @@ -61,92 +55,45 @@ class FederatedService final : public Federated::Service { return grpc::Status::OK; } - grpc::Status Allreduce(grpc::ServerContext* context, AllreduceRequest const* request, - AllreduceReply* reply) override { - // Pass through if there is only 1 client. - if (world_size_ == 1) { - reply->set_receive_buffer(request->send_buffer()); - return grpc::Status::OK; - } - - std::unique_lock lock(mutex_); - - auto const rank = request->rank(); - - if (received_ == 0) { - std::cout << "Allreduce rank " << rank << ": first request, copying send buffer\n"; - buffer_ = request->send_buffer(); - } else { - std::cout << "Allreduce rank " << rank << ": accumulating send buffer into common buffer\n"; - Accumulate(request->send_buffer(), request->data_type(), request->reduce_operation()); - } - received_++; - - if (received_ == world_size_) { - std::cout << "Allreduce rank " << rank << ": all requests received\n"; - reply->set_receive_buffer(buffer_); - sent_++; - lock.unlock(); - cv_.notify_all(); - return grpc::Status::OK; - } - - std::cout << "Allreduce rank " << rank << ": waiting for all clients\n"; - cv_.wait(lock, [this] { return received_ == world_size_; }); + protected: + virtual void OnRequest(Request const* request) = 0; - std::cout << "Allreduce rank " << rank << ": sending reply\n"; - reply->set_receive_buffer(buffer_); - sent_++; + std::string const name_; + int const world_size_; + int received_{}; + int sent_{}; + std::string buffer_{}; + mutable std::mutex mutex_; + mutable std::condition_variable cv_; +}; - if (sent_ == world_size_) { - std::cout << "Allreduce rank " << rank << ": all replies sent\n"; - sent_ = 0; - received_ = 0; - buffer_.clear(); - } +class AllgatherOp : public Operation { + public: + explicit AllgatherOp(int const world_size) + : Operation("Allgather", world_size) {} - return grpc::Status::OK; + protected: + void OnRequest(AllgatherRequest const* request) override { + auto const rank = request->rank(); + auto const& send_buffer = request->send_buffer(); + auto const buffer_size = send_buffer.size(); + buffer_.resize(buffer_size * world_size_); + buffer_.replace(rank * buffer_size, buffer_size, send_buffer); } +}; - grpc::Status Broadcast(grpc::ServerContext* context, - xgboost::federated::BroadcastRequest const* request, - xgboost::federated::BroadcastReply* reply) override { - // Pass through if there is only 1 client. - if (world_size_ == 1) { - reply->set_receive_buffer(request->send_buffer()); - return grpc::Status::OK; - } - - std::unique_lock lock(mutex_); - - auto const rank = request->rank(); +class AllreduceOp : public Operation { + public: + explicit AllreduceOp(int const world_size) + : Operation("Allreduce", world_size) {} - if (request->rank() == request->root()) { - std::cout << "Broadcast rank " << rank << ": root copying send buffer to common buffer\n"; + protected: + void OnRequest(AllreduceRequest const* request) override { + if (buffer_.empty()) { buffer_ = request->send_buffer(); - received_ = world_size_; - reply->set_receive_buffer(buffer_); - sent_++; - lock.unlock(); - cv_.notify_all(); - return grpc::Status::OK; - } - - std::cout << "Broadcast rank " << rank << ": waiting for the root\n"; - cv_.wait(lock, [this] { return received_ == world_size_; }); - - std::cout << "Broadcast rank " << rank << ": sending reply\n"; - reply->set_receive_buffer(buffer_); - sent_++; - - if (sent_ == world_size_) { - std::cout << "Broadcast rank " << rank << ": all replies sent\n"; - sent_ = 0; - received_ = 0; - buffer_.clear(); + } else { + Accumulate(request->send_buffer(), request->data_type(), request->reduce_operation()); } - - return grpc::Status::OK; } private: @@ -221,13 +168,45 @@ class FederatedService final : public Federated::Service { throw std::invalid_argument("Invalid data type"); } } +}; - int const world_size_; - int received_{}; - int sent_{}; - std::string buffer_{}; - mutable std::mutex mutex_; - mutable std::condition_variable cv_; +class BroadcastOp : public Operation { + public: + explicit BroadcastOp(int const world_size) + : Operation("Broadcast", world_size) {} + + protected: + void OnRequest(BroadcastRequest const* request) override { + if (request->rank() == request->root()) { + buffer_ = request->send_buffer(); + } + } +}; + +class FederatedService final : public Federated::Service { + public: + explicit FederatedService(int const world_size) + : allgather_op_{world_size}, allreduce_op_{world_size}, broadcast_op_{world_size} {} + + grpc::Status Allgather(grpc::ServerContext* context, AllgatherRequest const* request, + AllgatherReply* reply) override { + return allgather_op_.Operate(request, reply); + } + + grpc::Status Allreduce(grpc::ServerContext* context, AllreduceRequest const* request, + AllreduceReply* reply) override { + return allreduce_op_.Operate(request, reply); + } + + grpc::Status Broadcast(grpc::ServerContext* context, BroadcastRequest const* request, + BroadcastReply* reply) override { + return broadcast_op_.Operate(request, reply); + } + + private: + AllgatherOp allgather_op_; + AllreduceOp allreduce_op_; + BroadcastOp broadcast_op_; }; void RunServer(int port, int world_size) { From 549fbdb094f8e699aff3fe238feb1b52f979ea91 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Sat, 16 Apr 2022 00:01:35 -0700 Subject: [PATCH 06/21] more refactoring of the federated server --- plugin/federated/federated.proto | 21 +-- plugin/federated/federated_client.h | 4 + plugin/federated/federated_server.cc | 217 ++++++++++++++------------- 3 files changed, 125 insertions(+), 117 deletions(-) diff --git a/plugin/federated/federated.proto b/plugin/federated/federated.proto index fc09b0f5c73e..7fd35f489899 100644 --- a/plugin/federated/federated.proto +++ b/plugin/federated/federated.proto @@ -28,8 +28,9 @@ enum ReduceOperation { } message AllgatherRequest { - int32 rank = 1; - bytes send_buffer = 2; + uint64 sequence_number = 1; + int32 rank = 2; + bytes send_buffer = 3; } message AllgatherReply { @@ -37,10 +38,11 @@ message AllgatherReply { } message AllreduceRequest { - int32 rank = 1; - bytes send_buffer = 2; - DataType data_type = 3; - ReduceOperation reduce_operation = 4; + uint64 sequence_number = 1; + int32 rank = 2; + bytes send_buffer = 3; + DataType data_type = 4; + ReduceOperation reduce_operation = 5; } message AllreduceReply { @@ -48,9 +50,10 @@ message AllreduceReply { } message BroadcastRequest { - int32 rank = 1; - bytes send_buffer = 2; - int32 root = 3; + uint64 sequence_number = 1; + int32 rank = 2; + bytes send_buffer = 3; + int32 root = 4; } message BroadcastReply { diff --git a/plugin/federated/federated_client.h b/plugin/federated/federated_client.h index 5bfc6afb445e..60c808bae30a 100644 --- a/plugin/federated/federated_client.h +++ b/plugin/federated/federated_client.h @@ -19,6 +19,7 @@ class FederatedClient { std::string Allgather(std::string const &send_buffer) { AllgatherRequest request; + request.set_sequence_number(sequence_number_++); request.set_rank(rank_); request.set_send_buffer(send_buffer); @@ -37,6 +38,7 @@ class FederatedClient { std::string Allreduce(std::string const &send_buffer, DataType data_type, ReduceOperation reduce_operation) { AllreduceRequest request; + request.set_sequence_number(sequence_number_++); request.set_rank(rank_); request.set_send_buffer(send_buffer); request.set_data_type(data_type); @@ -56,6 +58,7 @@ class FederatedClient { std::string Broadcast(std::string const &send_buffer, int root) { BroadcastRequest request; + request.set_sequence_number(sequence_number_++); request.set_rank(rank_); request.set_send_buffer(send_buffer); request.set_root(root); @@ -75,6 +78,7 @@ class FederatedClient { private: std::unique_ptr const stub_; int const rank_; + uint64_t sequence_number_{}; }; } // namespace federated diff --git a/plugin/federated/federated_server.cc b/plugin/federated/federated_server.cc index f2d40042a61f..f10f61a52712 100644 --- a/plugin/federated/federated_server.cc +++ b/plugin/federated/federated_server.cc @@ -1,104 +1,47 @@ #include -#include #include #include #include -#include namespace xgboost::federated { -template -class Operation { +class AllgatherHandler { public: - Operation(std::string name, int const world_size) - : name_{std::move(name)}, world_size_{world_size} {} + std::string const name{"Allgather"}; - grpc::Status Operate(Request const* request, Reply* reply) { - // Pass through if there is only 1 client. - if (world_size_ == 1) { - reply->set_receive_buffer(request->send_buffer()); - return grpc::Status::OK; - } - - std::unique_lock lock(mutex_); + explicit AllgatherHandler(int const world_size) : world_size_{world_size} {} + void Handle(AllgatherRequest const* request, std::string& buffer) const { auto const rank = request->rank(); - - std::cout << name_ << " rank " << rank << ": on request\n"; - OnRequest(request); - received_++; - - if (received_ == world_size_) { - std::cout << name_ << " rank " << rank << ": all requests received\n"; - reply->set_receive_buffer(buffer_); - sent_++; - lock.unlock(); - cv_.notify_all(); - return grpc::Status::OK; - } - - std::cout << name_ << " rank " << rank << ": waiting for all clients\n"; - cv_.wait(lock, [this] { return received_ == world_size_; }); - - std::cout << name_ << " rank " << rank << ": sending reply\n"; - reply->set_receive_buffer(buffer_); - sent_++; - - if (sent_ == world_size_) { - std::cout << name_ << " rank " << rank << ": all replies sent\n"; - sent_ = 0; - received_ = 0; - buffer_.clear(); + auto const& send_buffer = request->send_buffer(); + auto const send_size = send_buffer.size(); + if (buffer.size() != send_size * world_size_) { + buffer.resize(send_size * world_size_); } - - return grpc::Status::OK; + buffer.replace(rank * send_size, send_size, send_buffer); } - protected: - virtual void OnRequest(Request const* request) = 0; - - std::string const name_; + private: int const world_size_; - int received_{}; - int sent_{}; - std::string buffer_{}; - mutable std::mutex mutex_; - mutable std::condition_variable cv_; -}; - -class AllgatherOp : public Operation { - public: - explicit AllgatherOp(int const world_size) - : Operation("Allgather", world_size) {} - - protected: - void OnRequest(AllgatherRequest const* request) override { - auto const rank = request->rank(); - auto const& send_buffer = request->send_buffer(); - auto const buffer_size = send_buffer.size(); - buffer_.resize(buffer_size * world_size_); - buffer_.replace(rank * buffer_size, buffer_size, send_buffer); - } }; -class AllreduceOp : public Operation { +class AllreduceHandler { public: - explicit AllreduceOp(int const world_size) - : Operation("Allreduce", world_size) {} + std::string const name{"Allreduce"}; - protected: - void OnRequest(AllreduceRequest const* request) override { - if (buffer_.empty()) { - buffer_ = request->send_buffer(); + void Handle(AllreduceRequest const* request, std::string& buffer) const { + if (buffer.empty()) { + buffer = request->send_buffer(); } else { - Accumulate(request->send_buffer(), request->data_type(), request->reduce_operation()); + Accumulate(buffer, request->send_buffer(), request->data_type(), request->reduce_operation()); } } private: template - void Accumulate(T* buffer, T const* input, std::size_t n, ReduceOperation reduce_operation) { + void Accumulate(T* buffer, T const* input, std::size_t n, + ReduceOperation reduce_operation) const { switch (reduce_operation) { case ReduceOperation::MAX: std::transform(buffer, buffer + n, input, buffer, [](T a, T b) { return std::max(a, b); }); @@ -114,55 +57,56 @@ class AllreduceOp : public Operation { } } - void Accumulate(std::string const& input, DataType data_type, ReduceOperation reduce_operation) { + void Accumulate(std::string& buffer, std::string const& input, DataType data_type, + ReduceOperation reduce_operation) const { switch (data_type) { case DataType::CHAR: - Accumulate(buffer_.data(), input.data(), buffer_.size(), reduce_operation); + Accumulate(buffer.data(), input.data(), buffer.size(), reduce_operation); break; case DataType::UCHAR: - Accumulate(reinterpret_cast(buffer_.data()), - reinterpret_cast(input.data()), buffer_.size(), + Accumulate(reinterpret_cast(buffer.data()), + reinterpret_cast(input.data()), buffer.size(), reduce_operation); break; case DataType::INT: - Accumulate(reinterpret_cast(buffer_.data()), - reinterpret_cast(input.data()), buffer_.size() / sizeof(int), + Accumulate(reinterpret_cast(buffer.data()), + reinterpret_cast(input.data()), buffer.size() / sizeof(int), reduce_operation); break; case DataType::UINT: - Accumulate(reinterpret_cast(buffer_.data()), + Accumulate(reinterpret_cast(buffer.data()), reinterpret_cast(input.data()), - buffer_.size() / sizeof(unsigned int), reduce_operation); + buffer.size() / sizeof(unsigned int), reduce_operation); break; case DataType::LONG: - Accumulate(reinterpret_cast(buffer_.data()), - reinterpret_cast(input.data()), buffer_.size() / sizeof(long), + Accumulate(reinterpret_cast(buffer.data()), + reinterpret_cast(input.data()), buffer.size() / sizeof(long), reduce_operation); break; case DataType::ULONG: - Accumulate(reinterpret_cast(buffer_.data()), + Accumulate(reinterpret_cast(buffer.data()), reinterpret_cast(input.data()), - buffer_.size() / sizeof(unsigned long), reduce_operation); + buffer.size() / sizeof(unsigned long), reduce_operation); break; case DataType::FLOAT: - Accumulate(reinterpret_cast(buffer_.data()), - reinterpret_cast(input.data()), buffer_.size() / sizeof(float), + Accumulate(reinterpret_cast(buffer.data()), + reinterpret_cast(input.data()), buffer.size() / sizeof(float), reduce_operation); break; case DataType::DOUBLE: - Accumulate(reinterpret_cast(buffer_.data()), - reinterpret_cast(input.data()), buffer_.size() / sizeof(double), + Accumulate(reinterpret_cast(buffer.data()), + reinterpret_cast(input.data()), buffer.size() / sizeof(double), reduce_operation); break; case DataType::LONGLONG: - Accumulate(reinterpret_cast(buffer_.data()), + Accumulate(reinterpret_cast(buffer.data()), reinterpret_cast(input.data()), - buffer_.size() / sizeof(long long), reduce_operation); + buffer.size() / sizeof(long long), reduce_operation); break; case DataType::ULONGLONG: - Accumulate(reinterpret_cast(buffer_.data()), + Accumulate(reinterpret_cast(buffer.data()), reinterpret_cast(input.data()), - buffer_.size() / sizeof(unsigned long long), reduce_operation); + buffer.size() / sizeof(unsigned long long), reduce_operation); break; default: throw std::invalid_argument("Invalid data type"); @@ -170,15 +114,13 @@ class AllreduceOp : public Operation { } }; -class BroadcastOp : public Operation { +class BroadcastHandler { public: - explicit BroadcastOp(int const world_size) - : Operation("Broadcast", world_size) {} + std::string const name{"Broadcast"}; - protected: - void OnRequest(BroadcastRequest const* request) override { + static void Handle(BroadcastRequest const* request, std::string& buffer) { if (request->rank() == request->root()) { - buffer_ = request->send_buffer(); + buffer = request->send_buffer(); } } }; @@ -186,27 +128,86 @@ class BroadcastOp : public Operation { class FederatedService final : public Federated::Service { public: explicit FederatedService(int const world_size) - : allgather_op_{world_size}, allreduce_op_{world_size}, broadcast_op_{world_size} {} + : world_size_{world_size}, + allgather_handler_{world_size}, + allreduce_handler_{}, + broadcast_handler_{} {} grpc::Status Allgather(grpc::ServerContext* context, AllgatherRequest const* request, AllgatherReply* reply) override { - return allgather_op_.Operate(request, reply); + return Handle(request, reply, allgather_handler_); } grpc::Status Allreduce(grpc::ServerContext* context, AllreduceRequest const* request, AllreduceReply* reply) override { - return allreduce_op_.Operate(request, reply); + return Handle(request, reply, allreduce_handler_); } grpc::Status Broadcast(grpc::ServerContext* context, BroadcastRequest const* request, BroadcastReply* reply) override { - return broadcast_op_.Operate(request, reply); + return Handle(request, reply, broadcast_handler_); } private: - AllgatherOp allgather_op_; - AllreduceOp allreduce_op_; - BroadcastOp broadcast_op_; + template + grpc::Status Handle(Request const* request, Reply* reply, RequestHandler const& handler) { + // Pass through if there is only 1 client. + if (world_size_ == 1) { + reply->set_receive_buffer(request->send_buffer()); + return grpc::Status::OK; + } + + std::unique_lock lock(mutex_); + + auto const sequence_number = request->sequence_number(); + auto const rank = request->rank(); + + std::cout << handler.name << " rank " << rank << ": waiting for current sequence number\n"; + cv_.wait(lock, [this, sequence_number] { return sequence_number_ == sequence_number; }); + + std::cout << handler.name << " rank " << rank << ": handling request\n"; + handler.Handle(request, buffer_); + received_++; + + if (received_ == world_size_) { + std::cout << handler.name << " rank " << rank << ": all requests received\n"; + reply->set_receive_buffer(buffer_); + sent_++; + lock.unlock(); + cv_.notify_all(); + return grpc::Status::OK; + } + + std::cout << handler.name << " rank " << rank << ": waiting for all clients\n"; + cv_.wait(lock, [this] { return received_ == world_size_; }); + + std::cout << handler.name << " rank " << rank << ": sending reply\n"; + reply->set_receive_buffer(buffer_); + sent_++; + + if (sent_ == world_size_) { + std::cout << handler.name << " rank " << rank << ": all replies sent\n"; + sent_ = 0; + received_ = 0; + buffer_.clear(); + sequence_number_++; + lock.unlock(); + cv_.notify_all(); + } + + return grpc::Status::OK; + } + + int const world_size_; + AllgatherHandler allgather_handler_; + AllreduceHandler allreduce_handler_; + BroadcastHandler broadcast_handler_; + int received_{}; + int sent_{}; + std::string buffer_{}; + uint64_t sequence_number_{}; + mutable std::mutex mutex_; + mutable std::condition_variable cv_; }; void RunServer(int port, int world_size) { From 4723c00b507c31b7dd264bb958e317d1385c2cd8 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Mon, 18 Apr 2022 12:50:55 -0700 Subject: [PATCH 07/21] support custom reduction --- plugin/federated/engine_federated.cc | 132 ++++++++---------------- tests/distributed/runtests-federated.sh | 2 +- 2 files changed, 44 insertions(+), 90 deletions(-) diff --git a/plugin/federated/engine_federated.cc b/plugin/federated/engine_federated.cc index bcf3b9bbb4d1..ef272712de35 100644 --- a/plugin/federated/engine_federated.cc +++ b/plugin/federated/engine_federated.cc @@ -10,6 +10,15 @@ #include "rabit/internal/engine.h" #include "rabit/internal/utils.h" +namespace MPI { // NOLINT +// MPI data type to be compatible with existing MPI interface +class Datatype { + public: + size_t type_size; + explicit Datatype(size_t type_size) : type_size(type_size) {} +}; +} // namespace MPI + namespace rabit { namespace engine { @@ -24,10 +33,10 @@ class FederatedEngine : public IEngine { SetParam(env_var, value); } } - // Command line argument override. + // Command line argument overrides. for (int i = 0; i < argc; ++i) { - std::string key_value = argv[i]; - auto delimiter = key_value.find('='); + std::string const key_value = argv[i]; + auto const delimiter = key_value.find('='); if (delimiter != std::string::npos) { SetParam(key_value.substr(0, delimiter), key_value.substr(delimiter + 1)); } @@ -37,20 +46,24 @@ class FederatedEngine : public IEngine { client_.reset(new xgboost::federated::FederatedClient(server_address_, rank_)); } - void Allgather(void *sendrecvbuf_, size_t total_size, size_t slice_begin, size_t slice_end, + void Allgather(void *sendrecvbuf, size_t total_size, size_t slice_begin, size_t slice_end, size_t size_prev_slice) override { throw std::logic_error("FederatedEngine:: Allgather is not supported"); } - void Allreduce(void *sendrecvbuf_, size_t type_nbytes, size_t count, ReduceFunction reducer, + std::string Allgather(void *sendbuf, size_t total_size) { + std::string const send_buffer(reinterpret_cast(sendbuf), total_size); + return client_->Allgather(send_buffer); + } + + void Allreduce(void *sendrecvbuf, size_t type_nbytes, size_t count, ReduceFunction reducer, PreprocFunction prepare_fun, void *prepare_arg) override { throw std::logic_error("FederatedEngine:: Allreduce is not supported, use Allreduce_ instead"); } - // NOLINTNEXTLINE(readability-identifier-naming) - void Allreduce_(void *sendrecvbuf, size_t size, mpi::DataType dtype, mpi::OpType op) { + void Allreduce(void *sendrecvbuf, size_t size, mpi::DataType dtype, mpi::OpType op) { auto *buffer = reinterpret_cast(sendrecvbuf); - std::string send_buffer(buffer, size); + std::string const send_buffer(buffer, size); auto const receive_buffer = client_->Allreduce(send_buffer, GetDataType(dtype), GetOp(op)); receive_buffer.copy(buffer, size); } @@ -60,8 +73,9 @@ class FederatedEngine : public IEngine { } void Broadcast(void *sendrecvbuf, size_t size, int root) override { + if (world_size_ == 1) return; auto *buffer = reinterpret_cast(sendrecvbuf); - std::string send_buffer(buffer, size); + std::string const send_buffer(buffer, size); auto const receive_buffer = client_->Broadcast(send_buffer, root); if (rank_ != root) { receive_buffer.copy(buffer, size); @@ -190,95 +204,35 @@ void Allreduce_(void *sendrecvbuf, size_t type_nbytes, size_t count, IEngine::Re mpi::DataType dtype, mpi::OpType op, IEngine::PreprocFunction prepare_fun, void *prepare_arg) { if (prepare_fun != nullptr) prepare_fun(prepare_arg); - engine.Allreduce_(sendrecvbuf, type_nbytes * count, dtype, op); + if (engine.GetWorldSize() == 1) return; + engine.Allreduce(sendrecvbuf, type_nbytes * count, dtype, op); } -// code for reduce handle -ReduceHandle::ReduceHandle(void) : handle_(NULL), redfunc_(NULL), htype_(NULL) {} - -ReduceHandle::~ReduceHandle(void) { - /* !WARNING! - - A handle can be held by a tree method/Learner from xgboost. The booster might not be - freed until program exit, while (good) users call rabit.finalize() before reaching - the end of program. So op->Free() might be called after finalization and results - into following error: - - ``` - Attempting to use an MPI routine after finalizing MPICH - ``` - - Here we skip calling Free if MPI has already been finalized to workaround the issue. - It can be a potential leak of memory. The best way to resolve it is to eliminate all - use of long living handle. - */ - int finalized = 0; - // CHECK_EQ(MPI_Finalized(&finalized), MPI_SUCCESS); - if (handle_ != NULL) { - // MPI::Op *op = reinterpret_cast(handle_); - // if (!finalized) { - // op->Free(); - // } - // delete op; - } - if (htype_ != NULL) { - MPI::Datatype *dtype = reinterpret_cast(htype_); - if (!finalized) { - // dtype->Free(); - } - // delete dtype; - } -} +ReduceHandle::ReduceHandle() : created_type_nbytes_{} {} +ReduceHandle::~ReduceHandle() = default; -int ReduceHandle::TypeSize(const MPI::Datatype &dtype) { - return 0; - // return dtype.Get_size(); -} +int ReduceHandle::TypeSize(const MPI::Datatype &dtype) { return static_cast(dtype.type_size); } void ReduceHandle::Init(IEngine::ReduceFunction redfunc, size_t type_nbytes) { - utils::Assert(handle_ == NULL, "cannot initialize reduce handle twice"); - if (type_nbytes != 0) { - // MPI::Datatype *dtype = new MPI::Datatype(); - // if (type_nbytes % 8 == 0) { - // *dtype = MPI::LONG.Create_contiguous(type_nbytes / sizeof(long)); // NOLINT(*) - // } else if (type_nbytes % 4 == 0) { - // *dtype = MPI::INT.Create_contiguous(type_nbytes / sizeof(int)); - // } else { - // *dtype = MPI::CHAR.Create_contiguous(type_nbytes); - // } - // dtype->Commit(); - created_type_nbytes_ = type_nbytes; - // htype_ = dtype; - } - // MPI::Op *op = new MPI::Op(); - // MPI::User_function *pf = redfunc; - // op->Init(pf, true); - // handle_ = op; + utils::Assert(redfunc_ == nullptr, "cannot initialize reduce handle twice"); + redfunc_ = redfunc; } void ReduceHandle::Allreduce(void *sendrecvbuf, size_t type_nbytes, size_t count, IEngine::PreprocFunction prepare_fun, void *prepare_arg) { - utils::Assert(handle_ != NULL, "must initialize handle to call AllReduce"); - // MPI::Op *op = reinterpret_cast(handle_); - // MPI::Datatype *dtype = reinterpret_cast(htype_); - // if (created_type_nbytes_ != type_nbytes || dtype == NULL) { - // if (dtype == NULL) { - // dtype = new MPI::Datatype(); - // } else { - // dtype->Free(); - // } - // if (type_nbytes % 8 == 0) { - // *dtype = MPI::LONG.Create_contiguous(type_nbytes / sizeof(long)); // NOLINT(*) - // } else if (type_nbytes % 4 == 0) { - // *dtype = MPI::INT.Create_contiguous(type_nbytes / sizeof(int)); - // } else { - // *dtype = MPI::CHAR.Create_contiguous(type_nbytes); - // } - // dtype->Commit(); - // created_type_nbytes_ = type_nbytes; - // } - // if (prepare_fun != NULL) prepare_fun(prepare_arg); - // MPI::COMM_WORLD.Allreduce(MPI_IN_PLACE, sendrecvbuf, count, *dtype, *op); + utils::Assert(redfunc_ != nullptr, "must initialize handle to call AllReduce"); + if (prepare_fun != nullptr) prepare_fun(prepare_arg); + if (engine.GetWorldSize() == 1) return; + + auto const buffer_size = type_nbytes * count; + auto const gathered = engine.Allgather(sendrecvbuf, buffer_size); + auto const *data = gathered.data(); + for (int i = 0; i < engine.GetWorldSize(); i++) { + if (i != engine.GetRank()) { + redfunc_(data + buffer_size * i, sendrecvbuf, static_cast(count), + MPI::Datatype(type_nbytes)); + } + } } } // namespace engine diff --git a/tests/distributed/runtests-federated.sh b/tests/distributed/runtests-federated.sh index 7f757aa7084a..fc43ed587353 100755 --- a/tests/distributed/runtests-federated.sh +++ b/tests/distributed/runtests-federated.sh @@ -7,7 +7,7 @@ rm -f ./*.model* port=9091 world_size=3 -../../build/plugin/federated/federated_server ${port} ${world_size} & +../../build/plugin/federated/federated_server ${port} ${world_size} > /dev/null & export FEDERATED_SERVER_ADDRESS="localhost:${port}" export FEDERATED_WORLD_SIZE=${world_size} From 457f690d036d977082ef57a367f2350ab17d1480 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Tue, 19 Apr 2022 12:57:13 -0700 Subject: [PATCH 08/21] remove unused includes --- plugin/federated/engine_federated.cc | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/plugin/federated/engine_federated.cc b/plugin/federated/engine_federated.cc index ef272712de35..e4521646d520 100644 --- a/plugin/federated/engine_federated.cc +++ b/plugin/federated/engine_federated.cc @@ -1,7 +1,3 @@ -#define NOMINMAX -#include -#include - #include #include #include @@ -208,7 +204,7 @@ void Allreduce_(void *sendrecvbuf, size_t type_nbytes, size_t count, IEngine::Re engine.Allreduce(sendrecvbuf, type_nbytes * count, dtype, op); } -ReduceHandle::ReduceHandle() : created_type_nbytes_{} {} +ReduceHandle::ReduceHandle() = default; ReduceHandle::~ReduceHandle() = default; int ReduceHandle::TypeSize(const MPI::Datatype &dtype) { return static_cast(dtype.type_size); } From 66e942580ede30a0e6124d85f1cf8a99df7f939c Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Tue, 19 Apr 2022 13:19:51 -0700 Subject: [PATCH 09/21] fix finalize --- plugin/federated/engine_federated.cc | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/plugin/federated/engine_federated.cc b/plugin/federated/engine_federated.cc index e4521646d520..3ab9c10e4cca 100644 --- a/plugin/federated/engine_federated.cc +++ b/plugin/federated/engine_federated.cc @@ -42,6 +42,8 @@ class FederatedEngine : public IEngine { client_.reset(new xgboost::federated::FederatedClient(server_address_, rank_)); } + void Finalize() { client_.reset(); } + void Allgather(void *sendrecvbuf, size_t total_size, size_t slice_begin, size_t slice_end, size_t size_prev_slice) override { throw std::logic_error("FederatedEngine:: Allgather is not supported"); @@ -184,13 +186,21 @@ bool Init(int argc, char *argv[]) { engine.Init(argc, argv); return true; } catch (std::exception const &e) { - fprintf(stderr, " failed in Federated Init %s\n", e.what()); + fprintf(stderr, " failed in federated Init %s\n", e.what()); return false; } } /*! \brief finalize synchronization module */ -bool Finalize() { return true; } +bool Finalize() { + try { + engine.Finalize(); + return true; + } catch (const std::exception &e) { + fprintf(stderr, "failed in federated shutdown %s\n", e.what()); + return false; + } +} /*! \brief singleton method to get engine */ IEngine *GetEngine() { return &engine; } From 4dc81dfd1c0cf4d74807155133ef5c6794578918 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Tue, 19 Apr 2022 15:20:34 -0700 Subject: [PATCH 10/21] no splitting data in federated mode --- plugin/federated/CMakeLists.txt | 1 + src/c_api/c_api.cc | 4 ++++ tests/distributed/runtests-federated.sh | 10 +++++--- tests/distributed/test_federated.py | 32 +++++++++++++++++++++++++ 4 files changed, 44 insertions(+), 3 deletions(-) create mode 100644 tests/distributed/test_federated.py diff --git a/plugin/federated/CMakeLists.txt b/plugin/federated/CMakeLists.txt index 789c181cf441..72d2f06f20f6 100644 --- a/plugin/federated/CMakeLists.txt +++ b/plugin/federated/CMakeLists.txt @@ -26,3 +26,4 @@ target_link_libraries(test_client PRIVATE federated_client) target_sources(objxgboost PRIVATE engine_federated.cc) target_link_libraries(objxgboost PRIVATE federated_client) +target_compile_definitions(objxgboost PUBLIC -DXGBOOST_USE_FEDERATED=1) diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index a11602a56610..20970f82d0ee 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -198,11 +198,15 @@ XGB_DLL int XGDMatrixCreateFromFile(const char *fname, DMatrixHandle *out) { API_BEGIN(); bool load_row_split = false; +#if defined(XGBOOST_USE_FEDERATED) + LOG(CONSOLE) << "XGBoost federated mode detected, not splitting data among workers"; +#else if (rabit::IsDistributed()) { LOG(CONSOLE) << "XGBoost distributed mode detected, " << "will split data among workers"; load_row_split = true; } +#endif *out = new std::shared_ptr(DMatrix::Load(fname, silent != 0, load_row_split)); API_END(); } diff --git a/tests/distributed/runtests-federated.sh b/tests/distributed/runtests-federated.sh index fc43ed587353..deb50daee782 100755 --- a/tests/distributed/runtests-federated.sh +++ b/tests/distributed/runtests-federated.sh @@ -2,17 +2,21 @@ trap "kill 0" EXIT -rm -f ./*.model* +rm -f ./*.model* ./agaricus* port=9091 world_size=3 -../../build/plugin/federated/federated_server ${port} ${world_size} > /dev/null & +../../build/plugin/federated/federated_server ${port} ${world_size} >/dev/null & + +# Split train and test files: +split -n l/${world_size} -d -a 1 ../../demo/data/agaricus.txt.train agaricus.txt.train- +split -n l/${world_size} -d -a 1 ../../demo/data/agaricus.txt.test agaricus.txt.test- export FEDERATED_SERVER_ADDRESS="localhost:${port}" export FEDERATED_WORLD_SIZE=${world_size} for ((rank = 0; rank < world_size; rank++)); do - FEDERATED_RANK=${rank} python test_basic.py & + FEDERATED_RANK=${rank} python test_federated.py & pids[${rank}]=$! done diff --git a/tests/distributed/test_federated.py b/tests/distributed/test_federated.py new file mode 100644 index 000000000000..da3c372c1185 --- /dev/null +++ b/tests/distributed/test_federated.py @@ -0,0 +1,32 @@ +#!/usr/bin/python +import os + +import xgboost as xgb + +# Always call this before using distributed module +xgb.rabit.init() + +# Load file, file will not be sharded in federated mode. +rank = os.getenv('FEDERATED_RANK') +dtrain = xgb.DMatrix('agaricus.txt.train-%s' % rank) +dtest = xgb.DMatrix('agaricus.txt.test-%s' % rank) + +# Specify parameters via map, definition are same as c++ version +param = {'max_depth': 2, 'eta': 1, 'objective': 'binary:logistic'} + +# Specify validations set to watch performance +watchlist = [(dtest, 'eval'), (dtrain, 'train')] +num_round = 20 + +# Run training, all the features in training API is available. +# Currently, this script only support calling train once for fault recovery purpose. +bst = xgb.train(param, dtrain, num_round, evals=watchlist, early_stopping_rounds=2) + +# Save the model, only ask process 0 to save the model. +if xgb.rabit.get_rank() == 0: + bst.save_model("test.model.json") + xgb.rabit.tracker_print("Finished training\n") + +# Notify the tracker all training has been successful +# This is only needed in distributed training. +xgb.rabit.finalize() From e40ca5d9d11ab62f03e57d404e1ab42c59c76870 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Tue, 19 Apr 2022 15:25:46 -0700 Subject: [PATCH 11/21] update readme --- plugin/federated/README.md | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/plugin/federated/README.md b/plugin/federated/README.md index dfe9e5592464..06f8ca3fb1fb 100644 --- a/plugin/federated/README.md +++ b/plugin/federated/README.md @@ -14,3 +14,23 @@ cmake -S . -B build -GNinja\ -DABSL_PROPAGATE_CXX_STD=ON cmake --build build --target install ``` + +Build the Plugin +---------------- +```shell +# Under xgboost source tree. +mkdir build +cd build +cmake .. -DPLUGIN_FEDERATED=ON +make -j$(nproc) +cd ../python-package +pip install -e . # or equivalently python setup.py develop +``` + +Test Federated XGBoost +---------------------- +```shell +# Under xgboost source tree. +cd tests/distributed +./runtests-federated.sh +``` From c99883cfdf0f45f0fd84663bbb39fd3dd6e00ab3 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Tue, 19 Apr 2022 21:42:22 -0700 Subject: [PATCH 12/21] support more than 10 workers --- tests/distributed/runtests-federated.sh | 4 ++-- tests/distributed/test_federated.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/distributed/runtests-federated.sh b/tests/distributed/runtests-federated.sh index deb50daee782..536c58705ae2 100755 --- a/tests/distributed/runtests-federated.sh +++ b/tests/distributed/runtests-federated.sh @@ -10,8 +10,8 @@ world_size=3 ../../build/plugin/federated/federated_server ${port} ${world_size} >/dev/null & # Split train and test files: -split -n l/${world_size} -d -a 1 ../../demo/data/agaricus.txt.train agaricus.txt.train- -split -n l/${world_size} -d -a 1 ../../demo/data/agaricus.txt.test agaricus.txt.test- +split -n l/${world_size} -d ../../demo/data/agaricus.txt.train agaricus.txt.train- +split -n l/${world_size} -d ../../demo/data/agaricus.txt.test agaricus.txt.test- export FEDERATED_SERVER_ADDRESS="localhost:${port}" export FEDERATED_WORLD_SIZE=${world_size} diff --git a/tests/distributed/test_federated.py b/tests/distributed/test_federated.py index da3c372c1185..4043b62a94ee 100644 --- a/tests/distributed/test_federated.py +++ b/tests/distributed/test_federated.py @@ -7,9 +7,9 @@ xgb.rabit.init() # Load file, file will not be sharded in federated mode. -rank = os.getenv('FEDERATED_RANK') -dtrain = xgb.DMatrix('agaricus.txt.train-%s' % rank) -dtest = xgb.DMatrix('agaricus.txt.test-%s' % rank) +rank = int(os.getenv('FEDERATED_RANK')) +dtrain = xgb.DMatrix('agaricus.txt.train-%02d' % rank) +dtest = xgb.DMatrix('agaricus.txt.test-%02d' % rank) # Specify parameters via map, definition are same as c++ version param = {'max_depth': 2, 'eta': 1, 'objective': 'binary:logistic'} From 60ef12eb0b090fd448d025f3d69fab89606d79da Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Thu, 21 Apr 2022 15:10:56 -0700 Subject: [PATCH 13/21] add some comments and copyright headers --- plugin/CMakeLists.txt | 1 + plugin/federated/CMakeLists.txt | 7 +++++++ plugin/federated/README.md | 11 +++++------ plugin/federated/engine_federated.cc | 9 ++++++--- plugin/federated/federated.proto | 7 +++++++ plugin/federated/federated_client.h | 6 ++++++ plugin/federated/federated_server.cc | 8 ++++++++ plugin/federated/test_client.cc | 3 +++ rabit/CMakeLists.txt | 1 + rabit/include/rabit/internal/engine.h | 2 +- tests/distributed/runtests-federated.sh | 3 ++- 11 files changed, 47 insertions(+), 11 deletions(-) diff --git a/plugin/CMakeLists.txt b/plugin/CMakeLists.txt index fdecf41e6e9f..9f59c68f14e0 100644 --- a/plugin/CMakeLists.txt +++ b/plugin/CMakeLists.txt @@ -41,6 +41,7 @@ if (PLUGIN_UPDATER_ONEAPI) target_sources(objxgboost INTERFACE $) endif (PLUGIN_UPDATER_ONEAPI) +# Add the Federate Learning plugin if enabled. if (PLUGIN_FEDERATED) add_subdirectory(federated) endif (PLUGIN_FEDERATED) diff --git a/plugin/federated/CMakeLists.txt b/plugin/federated/CMakeLists.txt index 72d2f06f20f6..e338b2c49ab0 100644 --- a/plugin/federated/CMakeLists.txt +++ b/plugin/federated/CMakeLists.txt @@ -1,7 +1,9 @@ +# gRPC needs to be installed first. See README.md. find_package(protobuf CONFIG REQUIRED) find_package(gRPC CONFIG REQUIRED) find_package(Threads) +# Generated code from the protobuf definition. add_library(federated_proto federated.proto) target_link_libraries(federated_proto PUBLIC protobuf::libprotobuf gRPC::grpc gRPC::grpc++) target_include_directories(federated_proto PUBLIC ${CMAKE_CURRENT_BINARY_DIR}) @@ -15,15 +17,20 @@ protobuf_generate( GENERATE_EXTENSIONS .grpc.pb.h .grpc.pb.cc PLUGIN "protoc-gen-grpc=${grpc_cpp_plugin_location}") +# Wrapper for the gRPC client. add_library(federated_client INTERFACE federated_client.h) target_link_libraries(federated_client INTERFACE federated_proto) +# Federated Learning gRPC server. add_executable(federated_server federated_server.cc) target_link_libraries(federated_server PRIVATE federated_proto) +# A test client to exercise the gRPC calls. +# TODO(rongou): add unit tests and get rid of this. add_executable(test_client test_client.cc) target_link_libraries(test_client PRIVATE federated_client) +# Rabit engine for Federated Learning. target_sources(objxgboost PRIVATE engine_federated.cc) target_link_libraries(objxgboost PRIVATE federated_client) target_compile_definitions(objxgboost PUBLIC -DXGBOOST_USE_FEDERATED=1) diff --git a/plugin/federated/README.md b/plugin/federated/README.md index 06f8ca3fb1fb..a5fa95e0c140 100644 --- a/plugin/federated/README.md +++ b/plugin/federated/README.md @@ -1,17 +1,16 @@ XGBoost Plugin for Federated Learning ===================================== -This folder contains the plugin for federated learning. +This folder contains the plugin for federated learning. Follow these steps to build and test it. Install gRPC ------------ ```shell sudo apt-get install build-essential autoconf libtool pkg-config cmake ninja-build -git clone -b v1.45.1 https://github.com/grpc/grpc +git clone -b v1.45.2 https://github.com/grpc/grpc cd grpc git submodule update --init -cmake -S . -B build -GNinja\ - -DABSL_PROPAGATE_CXX_STD=ON +cmake -S . -B build -GNinja -DABSL_PROPAGATE_CXX_STD=ON cmake --build build --target install ``` @@ -21,8 +20,8 @@ Build the Plugin # Under xgboost source tree. mkdir build cd build -cmake .. -DPLUGIN_FEDERATED=ON -make -j$(nproc) +cmake .. -GNinja -DPLUGIN_FEDERATED=ON +ninja cd ../python-package pip install -e . # or equivalently python setup.py develop ``` diff --git a/plugin/federated/engine_federated.cc b/plugin/federated/engine_federated.cc index 3ab9c10e4cca..febaef87969a 100644 --- a/plugin/federated/engine_federated.cc +++ b/plugin/federated/engine_federated.cc @@ -1,6 +1,7 @@ +/*! + * Copyright 2022 XGBoost contributors + */ #include -#include -#include #include "federated_client.h" #include "rabit/internal/engine.h" @@ -219,7 +220,8 @@ ReduceHandle::~ReduceHandle() = default; int ReduceHandle::TypeSize(const MPI::Datatype &dtype) { return static_cast(dtype.type_size); } -void ReduceHandle::Init(IEngine::ReduceFunction redfunc, size_t type_nbytes) { +void ReduceHandle::Init(IEngine::ReduceFunction redfunc, + __attribute__((unused)) size_t type_nbytes) { utils::Assert(redfunc_ == nullptr, "cannot initialize reduce handle twice"); redfunc_ = redfunc; } @@ -230,6 +232,7 @@ void ReduceHandle::Allreduce(void *sendrecvbuf, size_t type_nbytes, size_t count if (prepare_fun != nullptr) prepare_fun(prepare_arg); if (engine.GetWorldSize() == 1) return; + // Gather all the buffers and call the reduce function locally. auto const buffer_size = type_nbytes * count; auto const gathered = engine.Allgather(sendrecvbuf, buffer_size); auto const *data = gathered.data(); diff --git a/plugin/federated/federated.proto b/plugin/federated/federated.proto index 7fd35f489899..cba897c0ea81 100644 --- a/plugin/federated/federated.proto +++ b/plugin/federated/federated.proto @@ -1,3 +1,6 @@ +/*! + * Copyright 2022 XGBoost contributors + */ syntax = "proto3"; package xgboost.federated; @@ -28,6 +31,7 @@ enum ReduceOperation { } message AllgatherRequest { + // An incrementing counter that is unique to each round to operations. uint64 sequence_number = 1; int32 rank = 2; bytes send_buffer = 3; @@ -38,6 +42,7 @@ message AllgatherReply { } message AllreduceRequest { + // An incrementing counter that is unique to each round to operations. uint64 sequence_number = 1; int32 rank = 2; bytes send_buffer = 3; @@ -50,9 +55,11 @@ message AllreduceReply { } message BroadcastRequest { + // An incrementing counter that is unique to each round to operations. uint64 sequence_number = 1; int32 rank = 2; bytes send_buffer = 3; + // The root rank to broadcast from. int32 root = 4; } diff --git a/plugin/federated/federated_client.h b/plugin/federated/federated_client.h index 60c808bae30a..300b95daff43 100644 --- a/plugin/federated/federated_client.h +++ b/plugin/federated/federated_client.h @@ -1,3 +1,6 @@ +/*! + * Copyright 2022 XGBoost contributors + */ #pragma once #include #include @@ -10,6 +13,9 @@ namespace xgboost { namespace federated { +/** + * @brief A wrapper around the gRPC client. + */ class FederatedClient { public: explicit FederatedClient(std::string const &server_address, int rank) diff --git a/plugin/federated/federated_server.cc b/plugin/federated/federated_server.cc index f10f61a52712..edf6fd75a2b7 100644 --- a/plugin/federated/federated_server.cc +++ b/plugin/federated/federated_server.cc @@ -1,3 +1,6 @@ +/*! + * Copyright 2022 XGBoost contributors + */ #include #include @@ -16,9 +19,11 @@ class AllgatherHandler { auto const rank = request->rank(); auto const& send_buffer = request->send_buffer(); auto const send_size = send_buffer.size(); + // Resize the buffer if this is the first request. if (buffer.size() != send_size * world_size_) { buffer.resize(send_size * world_size_); } + // Splice the send_buffer into the common buffer. buffer.replace(rank * send_size, send_size, send_buffer); } @@ -32,8 +37,10 @@ class AllreduceHandler { void Handle(AllreduceRequest const* request, std::string& buffer) const { if (buffer.empty()) { + // Copy the send_buffer if this is the first request. buffer = request->send_buffer(); } else { + // Apply the reduce_operation to the send_buffer and the common buffer. Accumulate(buffer, request->send_buffer(), request->data_type(), request->reduce_operation()); } } @@ -120,6 +127,7 @@ class BroadcastHandler { static void Handle(BroadcastRequest const* request, std::string& buffer) { if (request->rank() == request->root()) { + // Copy the send_buffer if this is the root. buffer = request->send_buffer(); } } diff --git a/plugin/federated/test_client.cc b/plugin/federated/test_client.cc index 701432da72a9..16c080cdb28e 100644 --- a/plugin/federated/test_client.cc +++ b/plugin/federated/test_client.cc @@ -1,3 +1,6 @@ +/*! + * Copyright 2022 XGBoost contributors + */ #include #include "federated_client.h" diff --git a/rabit/CMakeLists.txt b/rabit/CMakeLists.txt index 594b6f15dda8..3a76794f5f58 100644 --- a/rabit/CMakeLists.txt +++ b/rabit/CMakeLists.txt @@ -7,6 +7,7 @@ set(RABIT_SOURCES ${CMAKE_CURRENT_LIST_DIR}/src/rabit_c_api.cc) if (PLUGIN_FEDERATED) + # Skip the engine if the Federated Learning plugin is enabled. elseif (RABIT_BUILD_MPI) list(APPEND RABIT_SOURCES ${CMAKE_CURRENT_LIST_DIR}/src/engine_mpi.cc) elseif (RABIT_MOCK) diff --git a/rabit/include/rabit/internal/engine.h b/rabit/include/rabit/internal/engine.h index 50b452f8db1a..88dd263f491f 100644 --- a/rabit/include/rabit/internal/engine.h +++ b/rabit/include/rabit/internal/engine.h @@ -260,7 +260,7 @@ class ReduceHandle { * with the type the reduce function needs to deal with * the reduce function MUST be communicative */ - void Init(IEngine::ReduceFunction redfunc, size_t type_nbytes); + void Init(IEngine::ReduceFunction redfunc, __attribute__((unused)) size_t type_nbytes); /*! * \brief customized in-place all reduce operation * \param sendrecvbuf the in place send-recv buffer diff --git a/tests/distributed/runtests-federated.sh b/tests/distributed/runtests-federated.sh index 536c58705ae2..d750682b5bbe 100755 --- a/tests/distributed/runtests-federated.sh +++ b/tests/distributed/runtests-federated.sh @@ -7,9 +7,10 @@ rm -f ./*.model* ./agaricus* port=9091 world_size=3 +# Start the federated server. ../../build/plugin/federated/federated_server ${port} ${world_size} >/dev/null & -# Split train and test files: +# Split train and test files manually to simulate a federated environment. split -n l/${world_size} -d ../../demo/data/agaricus.txt.train agaricus.txt.train- split -n l/${world_size} -d ../../demo/data/agaricus.txt.test agaricus.txt.test- From dca5902daf1cecf8fe1c76b7f43f6077046c8524 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Sat, 23 Apr 2022 01:07:24 -0700 Subject: [PATCH 14/21] add mutual ssl/tls authentication --- plugin/federated/CMakeLists.txt | 4 ++-- plugin/federated/engine_federated.cc | 32 ++++++++++++++++++++++--- plugin/federated/federated_client.h | 13 +++++++--- plugin/federated/federated_server.cc | 31 ++++++++++++++++++++---- tests/distributed/runtests-federated.sh | 25 +++++++++++++++++-- 5 files changed, 90 insertions(+), 15 deletions(-) diff --git a/plugin/federated/CMakeLists.txt b/plugin/federated/CMakeLists.txt index e338b2c49ab0..ddabe4451ce3 100644 --- a/plugin/federated/CMakeLists.txt +++ b/plugin/federated/CMakeLists.txt @@ -27,8 +27,8 @@ target_link_libraries(federated_server PRIVATE federated_proto) # A test client to exercise the gRPC calls. # TODO(rongou): add unit tests and get rid of this. -add_executable(test_client test_client.cc) -target_link_libraries(test_client PRIVATE federated_client) +#add_executable(test_client test_client.cc) +#target_link_libraries(test_client PRIVATE federated_client) # Rabit engine for Federated Learning. target_sources(objxgboost PRIVATE engine_federated.cc) diff --git a/plugin/federated/engine_federated.cc b/plugin/federated/engine_federated.cc index febaef87969a..a9ea5d757aa7 100644 --- a/plugin/federated/engine_federated.cc +++ b/plugin/federated/engine_federated.cc @@ -2,6 +2,8 @@ * Copyright 2022 XGBoost contributors */ #include +#include +#include #include "federated_client.h" #include "rabit/internal/engine.h" @@ -40,7 +42,8 @@ class FederatedEngine : public IEngine { } utils::Printf("Connecting to federated server %s, world size %d, rank %d", server_address_.c_str(), world_size_, rank_); - client_.reset(new xgboost::federated::FederatedClient(server_address_, rank_)); + client_.reset(new xgboost::federated::FederatedClient(server_address_, rank_, ca_cert_, + client_key_, client_cert_)); } void Finalize() { client_.reset(); } @@ -166,14 +169,37 @@ class FederatedEngine : public IEngine { world_size_ = std::stoi(val); } else if (!strcasecmp(name.c_str(), "FEDERATED_RANK")) { rank_ = std::stoi(val); + } else if (!strcasecmp(name.c_str(), "FEDERATED_CA_CERT")) { + ca_cert_ = ReadFile(val); + } else if (!strcasecmp(name.c_str(), "FEDERATED_CLIENT_KEY")) { + client_key_ = ReadFile(val); + } else if (!strcasecmp(name.c_str(), "FEDERATED_CLIENT_CERT")) { + client_cert_ = ReadFile(val); } } - std::vector const env_vars_{"FEDERATED_SERVER_ADDRESS", "FEDERATED_WORLD_SIZE", - "FEDERATED_RANK"}; + static std::string ReadFile(std::string const &path) { + auto stream = std::ifstream(path.data()); + std::ostringstream out; + out << stream.rdbuf(); + return out.str(); + } + + // clang-format off + std::vector const env_vars_{ + "FEDERATED_SERVER_ADDRESS", + "FEDERATED_WORLD_SIZE", + "FEDERATED_RANK", + "FEDERATED_CA_CERT", + "FEDERATED_CLIENT_KEY", + "FEDERATED_CLIENT_CERT" }; + // clang-format on std::string server_address_{"localhost:9091"}; int world_size_{1}; int rank_{0}; + std::string ca_cert_{}; + std::string client_key_{}; + std::string client_cert_{}; std::unique_ptr client_{}; int version_number_{0}; }; diff --git a/plugin/federated/federated_client.h b/plugin/federated/federated_client.h index 300b95daff43..046ad47d9b62 100644 --- a/plugin/federated/federated_client.h +++ b/plugin/federated/federated_client.h @@ -18,9 +18,16 @@ namespace federated { */ class FederatedClient { public: - explicit FederatedClient(std::string const &server_address, int rank) - : stub_{Federated::NewStub( - grpc::CreateChannel(server_address, grpc::InsecureChannelCredentials()))}, + explicit FederatedClient(std::string const &server_address, int rank, std::string const &ca_cert, + std::string const &client_key, std::string const &client_cert) + : stub_{[&] { + grpc::SslCredentialsOptions options; + options.pem_root_certs = ca_cert; + options.pem_private_key = client_key; + options.pem_cert_chain = client_cert; + return Federated::NewStub( + grpc::CreateChannel(server_address, grpc::SslCredentials(options))); + }()}, rank_{rank} {} std::string Allgather(std::string const &send_buffer) { diff --git a/plugin/federated/federated_server.cc b/plugin/federated/federated_server.cc index edf6fd75a2b7..80be27c99873 100644 --- a/plugin/federated/federated_server.cc +++ b/plugin/federated/federated_server.cc @@ -2,10 +2,13 @@ * Copyright 2022 XGBoost contributors */ #include +#include #include #include +#include #include +#include namespace xgboost::federated { @@ -218,12 +221,27 @@ class FederatedService final : public Federated::Service { mutable std::condition_variable cv_; }; -void RunServer(int port, int world_size) { +std::string ReadFile(std::string const& path) { + auto stream = std::ifstream(path.data()); + std::ostringstream out; + out << stream.rdbuf(); + return out.str(); +} + +void RunServer(int port, int world_size, std::string const& ca_cert_file, + std::string const& key_file, std::string const& cert_file) { std::string const server_address = "0.0.0.0:" + std::to_string(port); FederatedService service{world_size}; grpc::ServerBuilder builder; - builder.AddListeningPort(server_address, grpc::InsecureServerCredentials()); + auto options = + grpc::SslServerCredentialsOptions(GRPC_SSL_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY); + options.pem_root_certs = ReadFile(ca_cert_file); + auto key = grpc::SslServerCredentialsOptions::PemKeyCertPair(); + key.private_key = ReadFile(key_file); + key.cert_chain = ReadFile(cert_file); + options.pem_key_cert_pairs.push_back(key); + builder.AddListeningPort(server_address, grpc::SslServerCredentials(options)); builder.RegisterService(&service); std::unique_ptr server(builder.BuildAndStart()); std::cout << "Federated server listening on " << server_address << ", world size " << world_size @@ -234,12 +252,15 @@ void RunServer(int port, int world_size) { } // namespace xgboost::federated int main(int argc, char** argv) { - if (argc != 3) { - std::cerr << "Usage: federated_server port world_size" << '\n'; + if (argc != 6) { + std::cerr << "Usage: federated_server port world_size ca_cert_file key_file cert_file" << '\n'; return 1; } auto port = std::stoi(argv[1]); auto world_size = std::stoi(argv[2]); - xgboost::federated::RunServer(port, world_size); + std::string ca_cert_file = argv[3]; + std::string key_file = argv[4]; + std::string cert_file = argv[5]; + xgboost::federated::RunServer(port, world_size, ca_cert_file, key_file, cert_file); return 0; } diff --git a/tests/distributed/runtests-federated.sh b/tests/distributed/runtests-federated.sh index d750682b5bbe..9622189302ed 100755 --- a/tests/distributed/runtests-federated.sh +++ b/tests/distributed/runtests-federated.sh @@ -1,14 +1,32 @@ #!/bin/bash +set -e + trap "kill 0" EXIT -rm -f ./*.model* ./agaricus* +rm -f ./*.model* ./agaricus* ./*.pem ./.*.srl port=9091 world_size=3 +# Generate server and client certificates. +# 1. Generate CA's private key and self-signed certificate. +openssl req -x509 -newkey rsa:2048 -days 7 -nodes -keyout ca-key.pem -out ca-cert.pem -subj "/C=US/CN=localhost" + +# 2. Generate gRPC server's private key and certificate signing request (CSR). +openssl req -newkey rsa:2048 -nodes -keyout server-key.pem -out server-req.pem -subj "/C=US/CN=localhost" + +# 3. Use CA's private key to sign gRPC server's CSR and get back the signed certificate. +openssl x509 -req -in server-req.pem -days 7 -CA ca-cert.pem -CAkey ca-key.pem -CAcreateserial -out server-cert.pem + +# 4. Generate client's private key and certificate signing request (CSR). +openssl req -newkey rsa:2048 -nodes -keyout "client-key.pem" -out "client-req.pem" -subj "/C=US/CN=localhost" + +# 5. Use CA's private key to sign client's CSR and get back the signed certificate. +openssl x509 -req -in "client-req.pem" -days 7 -CA ca-cert.pem -CAkey ca-key.pem -CAcreateserial -out "client-cert.pem" + # Start the federated server. -../../build/plugin/federated/federated_server ${port} ${world_size} >/dev/null & +../../build/plugin/federated/federated_server ${port} ${world_size} client-cert.pem server-key.pem server-cert.pem >/dev/null & # Split train and test files manually to simulate a federated environment. split -n l/${world_size} -d ../../demo/data/agaricus.txt.train agaricus.txt.train- @@ -16,6 +34,9 @@ split -n l/${world_size} -d ../../demo/data/agaricus.txt.test agaricus.txt.test- export FEDERATED_SERVER_ADDRESS="localhost:${port}" export FEDERATED_WORLD_SIZE=${world_size} +export FEDERATED_CA_CERT=server-cert.pem +export FEDERATED_CLIENT_KEY=client-key.pem +export FEDERATED_CLIENT_CERT=client-cert.pem for ((rank = 0; rank < world_size; rank++)); do FEDERATED_RANK=${rank} python test_federated.py & pids[${rank}]=$! From 05fce5860e371ff0747f1bb7d0631ed4f2b2a2c1 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Sat, 23 Apr 2022 01:39:54 -0700 Subject: [PATCH 15/21] simplify cert generation --- plugin/federated/engine_federated.cc | 10 +++++----- plugin/federated/federated_client.h | 7 ++++--- plugin/federated/federated_server.cc | 17 +++++++++-------- tests/distributed/runtests-federated.sh | 22 +++++----------------- 4 files changed, 23 insertions(+), 33 deletions(-) diff --git a/plugin/federated/engine_federated.cc b/plugin/federated/engine_federated.cc index a9ea5d757aa7..ed7252ba117c 100644 --- a/plugin/federated/engine_federated.cc +++ b/plugin/federated/engine_federated.cc @@ -42,7 +42,7 @@ class FederatedEngine : public IEngine { } utils::Printf("Connecting to federated server %s, world size %d, rank %d", server_address_.c_str(), world_size_, rank_); - client_.reset(new xgboost::federated::FederatedClient(server_address_, rank_, ca_cert_, + client_.reset(new xgboost::federated::FederatedClient(server_address_, rank_, server_cert_, client_key_, client_cert_)); } @@ -169,8 +169,8 @@ class FederatedEngine : public IEngine { world_size_ = std::stoi(val); } else if (!strcasecmp(name.c_str(), "FEDERATED_RANK")) { rank_ = std::stoi(val); - } else if (!strcasecmp(name.c_str(), "FEDERATED_CA_CERT")) { - ca_cert_ = ReadFile(val); + } else if (!strcasecmp(name.c_str(), "FEDERATED_SERVER_CERT")) { + server_cert_ = ReadFile(val); } else if (!strcasecmp(name.c_str(), "FEDERATED_CLIENT_KEY")) { client_key_ = ReadFile(val); } else if (!strcasecmp(name.c_str(), "FEDERATED_CLIENT_CERT")) { @@ -190,14 +190,14 @@ class FederatedEngine : public IEngine { "FEDERATED_SERVER_ADDRESS", "FEDERATED_WORLD_SIZE", "FEDERATED_RANK", - "FEDERATED_CA_CERT", + "FEDERATED_SERVER_CERT", "FEDERATED_CLIENT_KEY", "FEDERATED_CLIENT_CERT" }; // clang-format on std::string server_address_{"localhost:9091"}; int world_size_{1}; int rank_{0}; - std::string ca_cert_{}; + std::string server_cert_{}; std::string client_key_{}; std::string client_cert_{}; std::unique_ptr client_{}; diff --git a/plugin/federated/federated_client.h b/plugin/federated/federated_client.h index 046ad47d9b62..c97a1ff7ab71 100644 --- a/plugin/federated/federated_client.h +++ b/plugin/federated/federated_client.h @@ -18,11 +18,12 @@ namespace federated { */ class FederatedClient { public: - explicit FederatedClient(std::string const &server_address, int rank, std::string const &ca_cert, - std::string const &client_key, std::string const &client_cert) + explicit FederatedClient(std::string const &server_address, int rank, + std::string const &server_cert, std::string const &client_key, + std::string const &client_cert) : stub_{[&] { grpc::SslCredentialsOptions options; - options.pem_root_certs = ca_cert; + options.pem_root_certs = server_cert; options.pem_private_key = client_key; options.pem_cert_chain = client_cert; return Federated::NewStub( diff --git a/plugin/federated/federated_server.cc b/plugin/federated/federated_server.cc index 80be27c99873..a8b62ba7fd3f 100644 --- a/plugin/federated/federated_server.cc +++ b/plugin/federated/federated_server.cc @@ -228,15 +228,15 @@ std::string ReadFile(std::string const& path) { return out.str(); } -void RunServer(int port, int world_size, std::string const& ca_cert_file, - std::string const& key_file, std::string const& cert_file) { +void RunServer(int port, int world_size, std::string const& key_file, std::string const& cert_file, + std::string const& client_cert_file) { std::string const server_address = "0.0.0.0:" + std::to_string(port); FederatedService service{world_size}; grpc::ServerBuilder builder; auto options = grpc::SslServerCredentialsOptions(GRPC_SSL_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY); - options.pem_root_certs = ReadFile(ca_cert_file); + options.pem_root_certs = ReadFile(client_cert_file); auto key = grpc::SslServerCredentialsOptions::PemKeyCertPair(); key.private_key = ReadFile(key_file); key.cert_chain = ReadFile(cert_file); @@ -253,14 +253,15 @@ void RunServer(int port, int world_size, std::string const& ca_cert_file, int main(int argc, char** argv) { if (argc != 6) { - std::cerr << "Usage: federated_server port world_size ca_cert_file key_file cert_file" << '\n'; + std::cerr << "Usage: federated_server port world_size key_file cert_file client_cert_file" + << '\n'; return 1; } auto port = std::stoi(argv[1]); auto world_size = std::stoi(argv[2]); - std::string ca_cert_file = argv[3]; - std::string key_file = argv[4]; - std::string cert_file = argv[5]; - xgboost::federated::RunServer(port, world_size, ca_cert_file, key_file, cert_file); + std::string key_file = argv[3]; + std::string cert_file = argv[4]; + std::string client_cert_file = argv[5]; + xgboost::federated::RunServer(port, world_size, key_file, cert_file, client_cert_file); return 0; } diff --git a/tests/distributed/runtests-federated.sh b/tests/distributed/runtests-federated.sh index 9622189302ed..559b8762a01c 100755 --- a/tests/distributed/runtests-federated.sh +++ b/tests/distributed/runtests-federated.sh @@ -4,29 +4,17 @@ set -e trap "kill 0" EXIT -rm -f ./*.model* ./agaricus* ./*.pem ./.*.srl +rm -f ./*.model* ./agaricus* ./*.pem port=9091 world_size=3 # Generate server and client certificates. -# 1. Generate CA's private key and self-signed certificate. -openssl req -x509 -newkey rsa:2048 -days 7 -nodes -keyout ca-key.pem -out ca-cert.pem -subj "/C=US/CN=localhost" - -# 2. Generate gRPC server's private key and certificate signing request (CSR). -openssl req -newkey rsa:2048 -nodes -keyout server-key.pem -out server-req.pem -subj "/C=US/CN=localhost" - -# 3. Use CA's private key to sign gRPC server's CSR and get back the signed certificate. -openssl x509 -req -in server-req.pem -days 7 -CA ca-cert.pem -CAkey ca-key.pem -CAcreateserial -out server-cert.pem - -# 4. Generate client's private key and certificate signing request (CSR). -openssl req -newkey rsa:2048 -nodes -keyout "client-key.pem" -out "client-req.pem" -subj "/C=US/CN=localhost" - -# 5. Use CA's private key to sign client's CSR and get back the signed certificate. -openssl x509 -req -in "client-req.pem" -days 7 -CA ca-cert.pem -CAkey ca-key.pem -CAcreateserial -out "client-cert.pem" +openssl req -x509 -newkey rsa:2048 -days 7 -nodes -keyout server-key.pem -out server-cert.pem -subj "/C=US/CN=localhost" +openssl req -x509 -newkey rsa:2048 -days 7 -nodes -keyout client-key.pem -out client-cert.pem -subj "/C=US/CN=localhost" # Start the federated server. -../../build/plugin/federated/federated_server ${port} ${world_size} client-cert.pem server-key.pem server-cert.pem >/dev/null & +../../build/plugin/federated/federated_server ${port} ${world_size} server-key.pem server-cert.pem client-cert.pem >/dev/null & # Split train and test files manually to simulate a federated environment. split -n l/${world_size} -d ../../demo/data/agaricus.txt.train agaricus.txt.train- @@ -34,7 +22,7 @@ split -n l/${world_size} -d ../../demo/data/agaricus.txt.test agaricus.txt.test- export FEDERATED_SERVER_ADDRESS="localhost:${port}" export FEDERATED_WORLD_SIZE=${world_size} -export FEDERATED_CA_CERT=server-cert.pem +export FEDERATED_SERVER_CERT=server-cert.pem export FEDERATED_CLIENT_KEY=client-key.pem export FEDERATED_CLIENT_CERT=client-cert.pem for ((rank = 0; rank < world_size; rank++)); do From 6c163194ba9db4cde34b0f5fb2ccf0b59c33d5f9 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Mon, 25 Apr 2022 19:53:35 -0700 Subject: [PATCH 16/21] change to functors --- plugin/federated/federated_server.cc | 50 +++++++++++------------ tests/cpp/plugin/test_federated_server.cc | 0 2 files changed, 25 insertions(+), 25 deletions(-) create mode 100644 tests/cpp/plugin/test_federated_server.cc diff --git a/plugin/federated/federated_server.cc b/plugin/federated/federated_server.cc index a8b62ba7fd3f..a88f9d199a0a 100644 --- a/plugin/federated/federated_server.cc +++ b/plugin/federated/federated_server.cc @@ -12,13 +12,13 @@ namespace xgboost::federated { -class AllgatherHandler { +class AllgatherFunctor { public: std::string const name{"Allgather"}; - explicit AllgatherHandler(int const world_size) : world_size_{world_size} {} + explicit AllgatherFunctor(int const world_size) : world_size_{world_size} {} - void Handle(AllgatherRequest const* request, std::string& buffer) const { + void operator()(AllgatherRequest const* request, std::string& buffer) const { auto const rank = request->rank(); auto const& send_buffer = request->send_buffer(); auto const send_size = send_buffer.size(); @@ -34,11 +34,11 @@ class AllgatherHandler { int const world_size_; }; -class AllreduceHandler { +class AllreduceFunctor { public: std::string const name{"Allreduce"}; - void Handle(AllreduceRequest const* request, std::string& buffer) const { + void operator()(AllreduceRequest const* request, std::string& buffer) const { if (buffer.empty()) { // Copy the send_buffer if this is the first request. buffer = request->send_buffer(); @@ -124,11 +124,11 @@ class AllreduceHandler { } }; -class BroadcastHandler { +class BroadcastFunctor { public: std::string const name{"Broadcast"}; - static void Handle(BroadcastRequest const* request, std::string& buffer) { + void operator()(BroadcastRequest const* request, std::string& buffer) const { if (request->rank() == request->root()) { // Copy the send_buffer if this is the root. buffer = request->send_buffer(); @@ -140,28 +140,28 @@ class FederatedService final : public Federated::Service { public: explicit FederatedService(int const world_size) : world_size_{world_size}, - allgather_handler_{world_size}, - allreduce_handler_{}, - broadcast_handler_{} {} + allgather_functor_{world_size}, + allreduce_functor_{}, + broadcast_functor_{} {} grpc::Status Allgather(grpc::ServerContext* context, AllgatherRequest const* request, AllgatherReply* reply) override { - return Handle(request, reply, allgather_handler_); + return Handle(request, reply, allgather_functor_); } grpc::Status Allreduce(grpc::ServerContext* context, AllreduceRequest const* request, AllreduceReply* reply) override { - return Handle(request, reply, allreduce_handler_); + return Handle(request, reply, allreduce_functor_); } grpc::Status Broadcast(grpc::ServerContext* context, BroadcastRequest const* request, BroadcastReply* reply) override { - return Handle(request, reply, broadcast_handler_); + return Handle(request, reply, broadcast_functor_); } private: - template - grpc::Status Handle(Request const* request, Reply* reply, RequestHandler const& handler) { + template + grpc::Status Handle(Request const* request, Reply* reply, RequestFunctor const& functor) { // Pass through if there is only 1 client. if (world_size_ == 1) { reply->set_receive_buffer(request->send_buffer()); @@ -173,15 +173,15 @@ class FederatedService final : public Federated::Service { auto const sequence_number = request->sequence_number(); auto const rank = request->rank(); - std::cout << handler.name << " rank " << rank << ": waiting for current sequence number\n"; + std::cout << functor.name << " rank " << rank << ": waiting for current sequence number\n"; cv_.wait(lock, [this, sequence_number] { return sequence_number_ == sequence_number; }); - std::cout << handler.name << " rank " << rank << ": handling request\n"; - handler.Handle(request, buffer_); + std::cout << functor.name << " rank " << rank << ": handling request\n"; + functor(request, buffer_); received_++; if (received_ == world_size_) { - std::cout << handler.name << " rank " << rank << ": all requests received\n"; + std::cout << functor.name << " rank " << rank << ": all requests received\n"; reply->set_receive_buffer(buffer_); sent_++; lock.unlock(); @@ -189,15 +189,15 @@ class FederatedService final : public Federated::Service { return grpc::Status::OK; } - std::cout << handler.name << " rank " << rank << ": waiting for all clients\n"; + std::cout << functor.name << " rank " << rank << ": waiting for all clients\n"; cv_.wait(lock, [this] { return received_ == world_size_; }); - std::cout << handler.name << " rank " << rank << ": sending reply\n"; + std::cout << functor.name << " rank " << rank << ": sending reply\n"; reply->set_receive_buffer(buffer_); sent_++; if (sent_ == world_size_) { - std::cout << handler.name << " rank " << rank << ": all replies sent\n"; + std::cout << functor.name << " rank " << rank << ": all replies sent\n"; sent_ = 0; received_ = 0; buffer_.clear(); @@ -210,9 +210,9 @@ class FederatedService final : public Federated::Service { } int const world_size_; - AllgatherHandler allgather_handler_; - AllreduceHandler allreduce_handler_; - BroadcastHandler broadcast_handler_; + AllgatherFunctor allgather_functor_; + AllreduceFunctor allreduce_functor_; + BroadcastFunctor broadcast_functor_; int received_{}; int sent_{}; std::string buffer_{}; diff --git a/tests/cpp/plugin/test_federated_server.cc b/tests/cpp/plugin/test_federated_server.cc new file mode 100644 index 000000000000..e69de29bb2d1 From 851ccdedc715a3a26d4874fd3c51742c3cae7b39 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Wed, 27 Apr 2022 11:55:36 -0700 Subject: [PATCH 17/21] add c api --- plugin/federated/CMakeLists.txt | 11 +--- plugin/federated/federated_server.cc | 80 +++++++++++------------ plugin/federated/federated_server.h | 11 ++++ python-package/xgboost/__init__.py | 3 + python-package/xgboost/federated.py | 30 +++++++++ src/c_api/c_api.cc | 13 ++++ tests/distributed/runtests-federated.sh | 20 +----- tests/distributed/test_federated.py | 84 ++++++++++++++++++------- 8 files changed, 156 insertions(+), 96 deletions(-) create mode 100644 plugin/federated/federated_server.h create mode 100644 python-package/xgboost/federated.py diff --git a/plugin/federated/CMakeLists.txt b/plugin/federated/CMakeLists.txt index ddabe4451ce3..a72bd3ea0d1e 100644 --- a/plugin/federated/CMakeLists.txt +++ b/plugin/federated/CMakeLists.txt @@ -21,16 +21,7 @@ protobuf_generate( add_library(federated_client INTERFACE federated_client.h) target_link_libraries(federated_client INTERFACE federated_proto) -# Federated Learning gRPC server. -add_executable(federated_server federated_server.cc) -target_link_libraries(federated_server PRIVATE federated_proto) - -# A test client to exercise the gRPC calls. -# TODO(rongou): add unit tests and get rid of this. -#add_executable(test_client test_client.cc) -#target_link_libraries(test_client PRIVATE federated_client) - # Rabit engine for Federated Learning. -target_sources(objxgboost PRIVATE engine_federated.cc) +target_sources(objxgboost PRIVATE federated_server.cc engine_federated.cc) target_link_libraries(objxgboost PRIVATE federated_client) target_compile_definitions(objxgboost PUBLIC -DXGBOOST_USE_FEDERATED=1) diff --git a/plugin/federated/federated_server.cc b/plugin/federated/federated_server.cc index a88f9d199a0a..0d95559de850 100644 --- a/plugin/federated/federated_server.cc +++ b/plugin/federated/federated_server.cc @@ -1,16 +1,20 @@ /*! * Copyright 2022 XGBoost contributors */ +#include "federated_server.h" + #include #include #include +#include #include #include #include #include -namespace xgboost::federated { +namespace xgboost { +namespace federated { class AllgatherFunctor { public: @@ -71,50 +75,49 @@ class AllreduceFunctor { ReduceOperation reduce_operation) const { switch (data_type) { case DataType::CHAR: - Accumulate(buffer.data(), input.data(), buffer.size(), reduce_operation); + Accumulate(&buffer[0], reinterpret_cast(input.data()), buffer.size(), + reduce_operation); break; case DataType::UCHAR: - Accumulate(reinterpret_cast(buffer.data()), + Accumulate(reinterpret_cast(&buffer[0]), reinterpret_cast(input.data()), buffer.size(), reduce_operation); break; case DataType::INT: - Accumulate(reinterpret_cast(buffer.data()), - reinterpret_cast(input.data()), buffer.size() / sizeof(int), - reduce_operation); + Accumulate(reinterpret_cast(&buffer[0]), reinterpret_cast(input.data()), + buffer.size() / sizeof(int), reduce_operation); break; case DataType::UINT: - Accumulate(reinterpret_cast(buffer.data()), + Accumulate(reinterpret_cast(&buffer[0]), reinterpret_cast(input.data()), buffer.size() / sizeof(unsigned int), reduce_operation); break; case DataType::LONG: - Accumulate(reinterpret_cast(buffer.data()), - reinterpret_cast(input.data()), buffer.size() / sizeof(long), - reduce_operation); + Accumulate(reinterpret_cast(&buffer[0]), reinterpret_cast(input.data()), + buffer.size() / sizeof(long), reduce_operation); break; case DataType::ULONG: - Accumulate(reinterpret_cast(buffer.data()), + Accumulate(reinterpret_cast(&buffer[0]), reinterpret_cast(input.data()), buffer.size() / sizeof(unsigned long), reduce_operation); break; case DataType::FLOAT: - Accumulate(reinterpret_cast(buffer.data()), + Accumulate(reinterpret_cast(&buffer[0]), reinterpret_cast(input.data()), buffer.size() / sizeof(float), reduce_operation); break; case DataType::DOUBLE: - Accumulate(reinterpret_cast(buffer.data()), + Accumulate(reinterpret_cast(&buffer[0]), reinterpret_cast(input.data()), buffer.size() / sizeof(double), reduce_operation); break; case DataType::LONGLONG: - Accumulate(reinterpret_cast(buffer.data()), + Accumulate(reinterpret_cast(&buffer[0]), reinterpret_cast(input.data()), buffer.size() / sizeof(long long), reduce_operation); break; case DataType::ULONGLONG: - Accumulate(reinterpret_cast(buffer.data()), + Accumulate(reinterpret_cast(&buffer[0]), reinterpret_cast(input.data()), buffer.size() / sizeof(unsigned long long), reduce_operation); break; @@ -168,20 +171,20 @@ class FederatedService final : public Federated::Service { return grpc::Status::OK; } - std::unique_lock lock(mutex_); + std::unique_lock lock(mutex_); auto const sequence_number = request->sequence_number(); auto const rank = request->rank(); - std::cout << functor.name << " rank " << rank << ": waiting for current sequence number\n"; + LOG(INFO) << functor.name << " rank " << rank << ": waiting for current sequence number\n"; cv_.wait(lock, [this, sequence_number] { return sequence_number_ == sequence_number; }); - std::cout << functor.name << " rank " << rank << ": handling request\n"; + LOG(INFO) << functor.name << " rank " << rank << ": handling request\n"; functor(request, buffer_); received_++; if (received_ == world_size_) { - std::cout << functor.name << " rank " << rank << ": all requests received\n"; + LOG(INFO) << functor.name << " rank " << rank << ": all requests received\n"; reply->set_receive_buffer(buffer_); sent_++; lock.unlock(); @@ -189,15 +192,15 @@ class FederatedService final : public Federated::Service { return grpc::Status::OK; } - std::cout << functor.name << " rank " << rank << ": waiting for all clients\n"; + LOG(INFO) << functor.name << " rank " << rank << ": waiting for all clients\n"; cv_.wait(lock, [this] { return received_ == world_size_; }); - std::cout << functor.name << " rank " << rank << ": sending reply\n"; + LOG(INFO) << functor.name << " rank " << rank << ": sending reply\n"; reply->set_receive_buffer(buffer_); sent_++; if (sent_ == world_size_) { - std::cout << functor.name << " rank " << rank << ": all replies sent\n"; + LOG(INFO) << functor.name << " rank " << rank << ": all replies sent\n"; sent_ = 0; received_ = 0; buffer_.clear(); @@ -221,15 +224,15 @@ class FederatedService final : public Federated::Service { mutable std::condition_variable cv_; }; -std::string ReadFile(std::string const& path) { - auto stream = std::ifstream(path.data()); +std::string ReadFile(char const* path) { + auto stream = std::ifstream(path); std::ostringstream out; out << stream.rdbuf(); return out.str(); } -void RunServer(int port, int world_size, std::string const& key_file, std::string const& cert_file, - std::string const& client_cert_file) { +void RunServer(int port, int world_size, char const* server_key_file, char const* server_cert_file, + char const* client_cert_file) { std::string const server_address = "0.0.0.0:" + std::to_string(port); FederatedService service{world_size}; @@ -238,30 +241,17 @@ void RunServer(int port, int world_size, std::string const& key_file, std::strin grpc::SslServerCredentialsOptions(GRPC_SSL_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY); options.pem_root_certs = ReadFile(client_cert_file); auto key = grpc::SslServerCredentialsOptions::PemKeyCertPair(); - key.private_key = ReadFile(key_file); - key.cert_chain = ReadFile(cert_file); + key.private_key = ReadFile(server_key_file); + key.cert_chain = ReadFile(server_cert_file); options.pem_key_cert_pairs.push_back(key); builder.AddListeningPort(server_address, grpc::SslServerCredentials(options)); builder.RegisterService(&service); std::unique_ptr server(builder.BuildAndStart()); - std::cout << "Federated server listening on " << server_address << ", world size " << world_size - << '\n'; + LOG(CONSOLE) << "Federated server listening on " << server_address << ", world size " + << world_size << '\n'; server->Wait(); } -} // namespace xgboost::federated -int main(int argc, char** argv) { - if (argc != 6) { - std::cerr << "Usage: federated_server port world_size key_file cert_file client_cert_file" - << '\n'; - return 1; - } - auto port = std::stoi(argv[1]); - auto world_size = std::stoi(argv[2]); - std::string key_file = argv[3]; - std::string cert_file = argv[4]; - std::string client_cert_file = argv[5]; - xgboost::federated::RunServer(port, world_size, key_file, cert_file, client_cert_file); - return 0; -} +} // namespace federated +} // namespace xgboost diff --git a/plugin/federated/federated_server.h b/plugin/federated/federated_server.h new file mode 100644 index 000000000000..7679e8779438 --- /dev/null +++ b/plugin/federated/federated_server.h @@ -0,0 +1,11 @@ +/*! + * Copyright 2022 XGBoost contributors + */ +namespace xgboost { +namespace federated { + +void RunServer(int port, int world_size, char const* server_key_file, char const* server_cert_file, + char const* client_cert_file); + +} // namespace federated +} // namespace xgboost diff --git a/python-package/xgboost/__init__.py b/python-package/xgboost/__init__.py index aacd88dd26a1..bb6a8ea921a1 100644 --- a/python-package/xgboost/__init__.py +++ b/python-package/xgboost/__init__.py @@ -12,6 +12,7 @@ from . import tracker # noqa from .tracker import RabitTracker # noqa from . import dask +from .federated import run_federated_server try: from .sklearn import XGBModel, XGBClassifier, XGBRegressor, XGBRanker @@ -51,4 +52,6 @@ "XGBRFRegressor", # dask "dask", + # federated + "run_federated_server", ] diff --git a/python-package/xgboost/federated.py b/python-package/xgboost/federated.py new file mode 100644 index 000000000000..0a28885c06ae --- /dev/null +++ b/python-package/xgboost/federated.py @@ -0,0 +1,30 @@ +"""XGBoost Federated Learning related API.""" + +from .core import _LIB, _check_call, c_str + + +def run_federated_server(port: int, + world_size: int, + server_key_path: str, + server_cert_path: str, + client_cert_path: str) -> None: + """Run the Federated Learning server. + + Parameters + ---------- + port : int + The port to listen on. + world_size: int + The number of federated workers. + server_key_path: str + Path to the server private key file. + server_cert_path: str + Path to the server certificate file. + client_cert_path: str + Path to the client certificate file. + """ + _check_call(_LIB.XGBRunFederatedServer(port, + world_size, + c_str(server_key_path), + c_str(server_cert_path), + c_str(client_cert_path))) diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 20970f82d0ee..ae98a9309e70 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -28,6 +28,10 @@ #include "../data/simple_dmatrix.h" #include "../data/proxy_dmatrix.h" +#if defined(XGBOOST_USE_FEDERATED) +#include "../../../plugin/federated/federated_server.h" +#endif + using namespace xgboost; // NOLINT(*); XGB_DLL void XGBoostVersion(int* major, int* minor, int* patch) { @@ -1346,5 +1350,14 @@ XGB_DLL int XGBoosterFeatureScore(BoosterHandle handle, char const *json_config, API_END(); } +#if defined(XGBOOST_USE_FEDERATED) +XGB_DLL int XGBRunFederatedServer(int port, int world_size, char const *server_key_path, + char const *server_cert_path, char const *client_cert_path) { + API_BEGIN(); + federated::RunServer(port, world_size, server_key_path, server_cert_path, client_cert_path); + API_END(); +} +#endif + // force link rabit static DMLC_ATTRIBUTE_UNUSED int XGBOOST_LINK_RABIT_C_API_ = RabitLinkTag(); diff --git a/tests/distributed/runtests-federated.sh b/tests/distributed/runtests-federated.sh index 559b8762a01c..77724aa969ea 100755 --- a/tests/distributed/runtests-federated.sh +++ b/tests/distributed/runtests-federated.sh @@ -2,34 +2,16 @@ set -e -trap "kill 0" EXIT - rm -f ./*.model* ./agaricus* ./*.pem -port=9091 world_size=3 # Generate server and client certificates. openssl req -x509 -newkey rsa:2048 -days 7 -nodes -keyout server-key.pem -out server-cert.pem -subj "/C=US/CN=localhost" openssl req -x509 -newkey rsa:2048 -days 7 -nodes -keyout client-key.pem -out client-cert.pem -subj "/C=US/CN=localhost" -# Start the federated server. -../../build/plugin/federated/federated_server ${port} ${world_size} server-key.pem server-cert.pem client-cert.pem >/dev/null & - # Split train and test files manually to simulate a federated environment. split -n l/${world_size} -d ../../demo/data/agaricus.txt.train agaricus.txt.train- split -n l/${world_size} -d ../../demo/data/agaricus.txt.test agaricus.txt.test- -export FEDERATED_SERVER_ADDRESS="localhost:${port}" -export FEDERATED_WORLD_SIZE=${world_size} -export FEDERATED_SERVER_CERT=server-cert.pem -export FEDERATED_CLIENT_KEY=client-key.pem -export FEDERATED_CLIENT_CERT=client-cert.pem -for ((rank = 0; rank < world_size; rank++)); do - FEDERATED_RANK=${rank} python test_federated.py & - pids[${rank}]=$! -done - -for pid in ${pids[*]}; do - wait $pid -done +python test_federated.py ${world_size} diff --git a/tests/distributed/test_federated.py b/tests/distributed/test_federated.py index 4043b62a94ee..9f97c227959e 100644 --- a/tests/distributed/test_federated.py +++ b/tests/distributed/test_federated.py @@ -1,32 +1,72 @@ #!/usr/bin/python -import os +import multiprocessing +import sys import xgboost as xgb -# Always call this before using distributed module -xgb.rabit.init() +SERVER_KEY = 'server-key.pem' +SERVER_CERT = 'server-cert.pem' +CLIENT_KEY = 'client-key.pem' +CLIENT_CERT = 'client-cert.pem' -# Load file, file will not be sharded in federated mode. -rank = int(os.getenv('FEDERATED_RANK')) -dtrain = xgb.DMatrix('agaricus.txt.train-%02d' % rank) -dtest = xgb.DMatrix('agaricus.txt.test-%02d' % rank) -# Specify parameters via map, definition are same as c++ version -param = {'max_depth': 2, 'eta': 1, 'objective': 'binary:logistic'} +def run_server(port: int, world_size: int) -> None: + xgb.run_federated_server(port, world_size, SERVER_KEY, SERVER_CERT, CLIENT_CERT) -# Specify validations set to watch performance -watchlist = [(dtest, 'eval'), (dtrain, 'train')] -num_round = 20 -# Run training, all the features in training API is available. -# Currently, this script only support calling train once for fault recovery purpose. -bst = xgb.train(param, dtrain, num_round, evals=watchlist, early_stopping_rounds=2) +def run_worker(port: int, world_size: int, rank: int) -> None: + # Always call this before using distributed module + rabit_env = [ + f'federated_server_address=localhost:{port}', + f'federated_world_size={world_size}', + f'federated_rank={rank}', + f'federated_server_cert={SERVER_CERT}', + f'federated_client_key={CLIENT_KEY}', + f'federated_client_cert={CLIENT_CERT}' + ] + xgb.rabit.init([e.encode() for e in rabit_env]) -# Save the model, only ask process 0 to save the model. -if xgb.rabit.get_rank() == 0: - bst.save_model("test.model.json") - xgb.rabit.tracker_print("Finished training\n") + # Load file, file will not be sharded in federated mode. + dtrain = xgb.DMatrix('agaricus.txt.train-%02d' % rank) + dtest = xgb.DMatrix('agaricus.txt.test-%02d' % rank) -# Notify the tracker all training has been successful -# This is only needed in distributed training. -xgb.rabit.finalize() + # Specify parameters via map, definition are same as c++ version + param = {'max_depth': 2, 'eta': 1, 'objective': 'binary:logistic'} + + # Specify validations set to watch performance + watchlist = [(dtest, 'eval'), (dtrain, 'train')] + num_round = 20 + + # Run training, all the features in training API is available. + # Currently, this script only support calling train once for fault recovery purpose. + bst = xgb.train(param, dtrain, num_round, evals=watchlist, early_stopping_rounds=2) + + # Save the model, only ask process 0 to save the model. + if xgb.rabit.get_rank() == 0: + bst.save_model("test.model.json") + xgb.rabit.tracker_print("Finished training\n") + + # Notify the tracker all training has been successful + # This is only needed in distributed training. + xgb.rabit.finalize() + + +def run_test() -> None: + port = 9091 + world_size = int(sys.argv[1]) + + server = multiprocessing.Process(target=run_server, args=(port, world_size)) + server.start() + + workers = [] + for rank in range(world_size): + worker = multiprocessing.Process(target=run_worker, args=(port, world_size, rank)) + workers.append(worker) + worker.start() + for worker in workers: + worker.join() + server.terminate() + + +if __name__ == '__main__': + run_test() From b7ba8ae1e52e871afccef25fdfc340c0599d1066 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Thu, 28 Apr 2022 14:09:57 -0700 Subject: [PATCH 18/21] add federated server unit tests --- plugin/federated/federated_client.h | 11 +- plugin/federated/federated_server.cc | 125 +++++++++------------ plugin/federated/federated_server.h | 33 ++++++ plugin/federated/test_client.cc | 44 -------- tests/cpp/CMakeLists.txt | 7 ++ tests/cpp/plugin/test_federated_server.cc | 130 ++++++++++++++++++++++ 6 files changed, 229 insertions(+), 121 deletions(-) delete mode 100644 plugin/federated/test_client.cc diff --git a/plugin/federated/federated_client.h b/plugin/federated/federated_client.h index c97a1ff7ab71..5aacef3a621a 100644 --- a/plugin/federated/federated_client.h +++ b/plugin/federated/federated_client.h @@ -18,9 +18,8 @@ namespace federated { */ class FederatedClient { public: - explicit FederatedClient(std::string const &server_address, int rank, - std::string const &server_cert, std::string const &client_key, - std::string const &client_cert) + FederatedClient(std::string const &server_address, int rank, std::string const &server_cert, + std::string const &client_key, std::string const &client_cert) : stub_{[&] { grpc::SslCredentialsOptions options; options.pem_root_certs = server_cert; @@ -31,6 +30,12 @@ class FederatedClient { }()}, rank_{rank} {} + /** @brief Insecure client for testing only. */ + FederatedClient(std::string const &server_address, int rank) + : stub_{Federated::NewStub( + grpc::CreateChannel(server_address, grpc::InsecureChannelCredentials()))}, + rank_{rank} {} + std::string Allgather(std::string const &send_buffer) { AllgatherRequest request; request.set_sequence_number(sequence_number_++); diff --git a/plugin/federated/federated_server.cc b/plugin/federated/federated_server.cc index 0d95559de850..c9651797a42c 100644 --- a/plugin/federated/federated_server.cc +++ b/plugin/federated/federated_server.cc @@ -3,14 +3,11 @@ */ #include "federated_server.h" -#include #include #include #include -#include #include -#include #include namespace xgboost { @@ -139,90 +136,70 @@ class BroadcastFunctor { } }; -class FederatedService final : public Federated::Service { - public: - explicit FederatedService(int const world_size) - : world_size_{world_size}, - allgather_functor_{world_size}, - allreduce_functor_{}, - broadcast_functor_{} {} +grpc::Status FederatedService::Allgather(grpc::ServerContext* context, + AllgatherRequest const* request, AllgatherReply* reply) { + return Handle(request, reply, AllgatherFunctor{world_size_}); +} - grpc::Status Allgather(grpc::ServerContext* context, AllgatherRequest const* request, - AllgatherReply* reply) override { - return Handle(request, reply, allgather_functor_); - } +grpc::Status FederatedService::Allreduce(grpc::ServerContext* context, + AllreduceRequest const* request, AllreduceReply* reply) { + return Handle(request, reply, AllreduceFunctor{}); +} - grpc::Status Allreduce(grpc::ServerContext* context, AllreduceRequest const* request, - AllreduceReply* reply) override { - return Handle(request, reply, allreduce_functor_); - } +grpc::Status FederatedService::Broadcast(grpc::ServerContext* context, + BroadcastRequest const* request, BroadcastReply* reply) { + return Handle(request, reply, BroadcastFunctor{}); +} - grpc::Status Broadcast(grpc::ServerContext* context, BroadcastRequest const* request, - BroadcastReply* reply) override { - return Handle(request, reply, broadcast_functor_); +template +grpc::Status FederatedService::Handle(Request const* request, Reply* reply, + RequestFunctor const& functor) { + // Pass through if there is only 1 client. + if (world_size_ == 1) { + reply->set_receive_buffer(request->send_buffer()); + return grpc::Status::OK; } - private: - template - grpc::Status Handle(Request const* request, Reply* reply, RequestFunctor const& functor) { - // Pass through if there is only 1 client. - if (world_size_ == 1) { - reply->set_receive_buffer(request->send_buffer()); - return grpc::Status::OK; - } - - std::unique_lock lock(mutex_); - - auto const sequence_number = request->sequence_number(); - auto const rank = request->rank(); - - LOG(INFO) << functor.name << " rank " << rank << ": waiting for current sequence number\n"; - cv_.wait(lock, [this, sequence_number] { return sequence_number_ == sequence_number; }); + std::unique_lock lock(mutex_); - LOG(INFO) << functor.name << " rank " << rank << ": handling request\n"; - functor(request, buffer_); - received_++; + auto const sequence_number = request->sequence_number(); + auto const rank = request->rank(); - if (received_ == world_size_) { - LOG(INFO) << functor.name << " rank " << rank << ": all requests received\n"; - reply->set_receive_buffer(buffer_); - sent_++; - lock.unlock(); - cv_.notify_all(); - return grpc::Status::OK; - } + LOG(INFO) << functor.name << " rank " << rank << ": waiting for current sequence number"; + cv_.wait(lock, [this, sequence_number] { return sequence_number_ == sequence_number; }); - LOG(INFO) << functor.name << " rank " << rank << ": waiting for all clients\n"; - cv_.wait(lock, [this] { return received_ == world_size_; }); + LOG(INFO) << functor.name << " rank " << rank << ": handling request"; + functor(request, buffer_); + received_++; - LOG(INFO) << functor.name << " rank " << rank << ": sending reply\n"; + if (received_ == world_size_) { + LOG(INFO) << functor.name << " rank " << rank << ": all requests received"; reply->set_receive_buffer(buffer_); sent_++; - - if (sent_ == world_size_) { - LOG(INFO) << functor.name << " rank " << rank << ": all replies sent\n"; - sent_ = 0; - received_ = 0; - buffer_.clear(); - sequence_number_++; - lock.unlock(); - cv_.notify_all(); - } - + lock.unlock(); + cv_.notify_all(); return grpc::Status::OK; } - int const world_size_; - AllgatherFunctor allgather_functor_; - AllreduceFunctor allreduce_functor_; - BroadcastFunctor broadcast_functor_; - int received_{}; - int sent_{}; - std::string buffer_{}; - uint64_t sequence_number_{}; - mutable std::mutex mutex_; - mutable std::condition_variable cv_; -}; + LOG(INFO) << functor.name << " rank " << rank << ": waiting for all clients"; + cv_.wait(lock, [this] { return received_ == world_size_; }); + + LOG(INFO) << functor.name << " rank " << rank << ": sending reply"; + reply->set_receive_buffer(buffer_); + sent_++; + + if (sent_ == world_size_) { + LOG(INFO) << functor.name << " rank " << rank << ": all replies sent"; + sent_ = 0; + received_ = 0; + buffer_.clear(); + sequence_number_++; + lock.unlock(); + cv_.notify_all(); + } + + return grpc::Status::OK; +} std::string ReadFile(char const* path) { auto stream = std::ifstream(path); @@ -248,7 +225,7 @@ void RunServer(int port, int world_size, char const* server_key_file, char const builder.RegisterService(&service); std::unique_ptr server(builder.BuildAndStart()); LOG(CONSOLE) << "Federated server listening on " << server_address << ", world size " - << world_size << '\n'; + << world_size; server->Wait(); } diff --git a/plugin/federated/federated_server.h b/plugin/federated/federated_server.h index 7679e8779438..108d78d47a96 100644 --- a/plugin/federated/federated_server.h +++ b/plugin/federated/federated_server.h @@ -1,9 +1,42 @@ /*! * Copyright 2022 XGBoost contributors */ +#pragma once + +#include + +#include +#include + namespace xgboost { namespace federated { +class FederatedService final : public Federated::Service { + public: + explicit FederatedService(int const world_size) : world_size_{world_size} {} + + grpc::Status Allgather(grpc::ServerContext* context, AllgatherRequest const* request, + AllgatherReply* reply) override; + + grpc::Status Allreduce(grpc::ServerContext* context, AllreduceRequest const* request, + AllreduceReply* reply) override; + + grpc::Status Broadcast(grpc::ServerContext* context, BroadcastRequest const* request, + BroadcastReply* reply) override; + + private: + template + grpc::Status Handle(Request const* request, Reply* reply, RequestFunctor const& functor); + + int const world_size_; + int received_{}; + int sent_{}; + std::string buffer_{}; + uint64_t sequence_number_{}; + mutable std::mutex mutex_; + mutable std::condition_variable cv_; +}; + void RunServer(int port, int world_size, char const* server_key_file, char const* server_cert_file, char const* client_cert_file); diff --git a/plugin/federated/test_client.cc b/plugin/federated/test_client.cc deleted file mode 100644 index 16c080cdb28e..000000000000 --- a/plugin/federated/test_client.cc +++ /dev/null @@ -1,44 +0,0 @@ -/*! - * Copyright 2022 XGBoost contributors - */ -#include - -#include "federated_client.h" - -int main(int argc, char **argv) { - if (argc != 3) { - std::cerr << "Usage: federated_client server_address(host:port) rank" << '\n'; - return 1; - } - auto const server_address = argv[1]; - auto const rank = std::stoi(argv[2]); - xgboost::federated::FederatedClient client(server_address, rank); - - for (int i = 1; i <= 10; i++) { - // Allgather. - std::string allgather_send = "hello " + std::to_string(rank) + ":" + std::to_string(i) + " "; - auto const allgather_receive = client.Allgather(allgather_send); - std::cout << "Allgather rank " << rank << ": " << allgather_receive << '\n'; - - // Allreduce. - int data[] = {1 * i, 2 * i, 3 * i, 4 * i, 5 * i}; - int n = sizeof(data) / sizeof(data[0]); - std::string send_buffer(reinterpret_cast(data), sizeof(data)); - auto receive_buffer = - client.Allreduce(send_buffer, xgboost::federated::INT, xgboost::federated::SUM); - auto *result = reinterpret_cast(receive_buffer.data()); - std::cout << "Allreduce rank " << rank << ": "; - std::copy(result, result + n, std::ostream_iterator(std::cout, " ")); - std::cout << '\n'; - - // Broadcast. - std::string broadcast_send{}; - if (rank == 0) { - broadcast_send = "hello " + std::to_string(i); - } - auto const broadcast_receive = client.Broadcast(broadcast_send, 0); - std::cout << "Broadcast rank " << rank << ": " << broadcast_receive << '\n'; - } - - return 0; -} diff --git a/tests/cpp/CMakeLists.txt b/tests/cpp/CMakeLists.txt index a346e0f03c98..c2bb353841f4 100644 --- a/tests/cpp/CMakeLists.txt +++ b/tests/cpp/CMakeLists.txt @@ -25,6 +25,13 @@ if (USE_CUDA AND PLUGIN_RMM) target_include_directories(testxgboost PRIVATE ${CUDA_INCLUDE_DIRS}) endif (USE_CUDA AND PLUGIN_RMM) +if (PLUGIN_FEDERATED) + target_include_directories(testxgboost PRIVATE ${xgboost_SOURCE_DIR}/plugin/federated) + target_link_libraries(testxgboost PRIVATE federated_client) +else(PLUGIN_FEDERATED) + list(REMOVE_ITEM TEST_SOURCES "plugin/test_federated_server.cc") +endif(PLUGIN_FEDERATED) + target_include_directories(testxgboost PRIVATE ${GTEST_INCLUDE_DIRS} diff --git a/tests/cpp/plugin/test_federated_server.cc b/tests/cpp/plugin/test_federated_server.cc index e69de29bb2d1..b20c17b09de5 100644 --- a/tests/cpp/plugin/test_federated_server.cc +++ b/tests/cpp/plugin/test_federated_server.cc @@ -0,0 +1,130 @@ +/*! + * Copyright 2017-2020 XGBoost contributors + */ +#include +#include + +#include + +#include "federated_client.h" +#include "federated_server.h" + +namespace xgboost { + +class FederatedServerTest : public ::testing::Test { + public: + static void VerifyAllgather(int rank) { + federated::FederatedClient client{kServerAddress, rank}; + CheckAllgather(client, rank); + } + + static void VerifyAllreduce(int rank) { + federated::FederatedClient client{kServerAddress, rank}; + CheckAllreduce(client); + } + + static void VerifyBroadcast(int rank) { + federated::FederatedClient client{kServerAddress, rank}; + CheckBroadcast(client, rank); + } + + static void VerifyMixture(int rank) { + federated::FederatedClient client{kServerAddress, rank}; + for (auto i = 0; i < 10; i++) { + CheckAllgather(client, rank); + CheckAllreduce(client); + CheckBroadcast(client, rank); + } + } + + protected: + void SetUp() override { + server_thread_.reset(new std::thread([this] { + grpc::ServerBuilder builder; + federated::FederatedService service{kWorldSize}; + builder.AddListeningPort(kServerAddress, grpc::InsecureServerCredentials()); + builder.RegisterService(&service); + server_ = builder.BuildAndStart(); + server_->Wait(); + })); + } + + void TearDown() override { + server_->Shutdown(); + server_thread_->join(); + } + + static void CheckAllgather(federated::FederatedClient& client, int rank) { + auto reply = client.Allgather("hello " + std::to_string(rank) + " "); + EXPECT_EQ(reply, "hello 0 hello 1 hello 2 "); + } + + static void CheckAllreduce(federated::FederatedClient& client) { + int data[] = {1, 2, 3, 4, 5}; + std::string send_buffer(reinterpret_cast(data), sizeof(data)); + auto reply = client.Allreduce(send_buffer, federated::INT, federated::SUM); + auto const* result = reinterpret_cast(reply.data()); + int expected[] = {3, 6, 9, 12, 15}; + for (auto i = 0; i < 5; i++) { + EXPECT_EQ(result[i], expected[i]); + } + } + + static void CheckBroadcast(federated::FederatedClient& client, int rank) { + std::string send_buffer{}; + if (rank == 0) { + send_buffer = "hello broadcast"; + } + auto reply = client.Broadcast(send_buffer, 0); + EXPECT_EQ(reply, "hello broadcast"); + } + + static int const kWorldSize{3}; + static std::string const kServerAddress; + std::unique_ptr server_thread_; + std::unique_ptr server_; +}; + +std::string const FederatedServerTest::kServerAddress{"localhost:56789"}; // NOLINT(cert-err58-cpp) + +TEST_F(FederatedServerTest, Allgather) { + std::vector threads; + for (auto rank = 0; rank < kWorldSize; rank++) { + threads.emplace_back(std::thread(&FederatedServerTest::VerifyAllgather, rank)); + } + for (auto& thread : threads) { + thread.join(); + } +} + +TEST_F(FederatedServerTest, Allreduce) { + std::vector threads; + for (auto rank = 0; rank < kWorldSize; rank++) { + threads.emplace_back(std::thread(&FederatedServerTest::VerifyAllreduce, rank)); + } + for (auto& thread : threads) { + thread.join(); + } +} + +TEST_F(FederatedServerTest, Broadcast) { + std::vector threads; + for (auto rank = 0; rank < kWorldSize; rank++) { + threads.emplace_back(std::thread(&FederatedServerTest::VerifyBroadcast, rank)); + } + for (auto& thread : threads) { + thread.join(); + } +} + +TEST_F(FederatedServerTest, Mixture) { + std::vector threads; + for (auto rank = 0; rank < kWorldSize; rank++) { + threads.emplace_back(std::thread(&FederatedServerTest::VerifyMixture, rank)); + } + for (auto& thread : threads) { + thread.join(); + } +} + +} // namespace xgboost From f2164c68bb8ddf3eb57cbbce6e4e69ea399c863f Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Thu, 28 Apr 2022 17:11:54 -0700 Subject: [PATCH 19/21] exclude federated tests when plugin not enabled --- tests/cpp/CMakeLists.txt | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/tests/cpp/CMakeLists.txt b/tests/cpp/CMakeLists.txt index c2bb353841f4..9dfd429d01e2 100644 --- a/tests/cpp/CMakeLists.txt +++ b/tests/cpp/CMakeLists.txt @@ -18,6 +18,14 @@ if (NOT PLUGIN_UPDATER_ONEAPI) list(REMOVE_ITEM TEST_SOURCES ${ONEAPI_TEST_SOURCES}) endif (NOT PLUGIN_UPDATER_ONEAPI) +if (PLUGIN_FEDERATED) + target_include_directories(testxgboost PRIVATE ${xgboost_SOURCE_DIR}/plugin/federated) + target_link_libraries(testxgboost PRIVATE federated_client) +else (PLUGIN_FEDERATED) + file(GLOB_RECURSE FEDERATED_TEST_SOURCES "plugin/*_federated_*.cc") + list(REMOVE_ITEM TEST_SOURCES ${FEDERATED_TEST_SOURCES}) +endif (PLUGIN_FEDERATED) + target_sources(testxgboost PRIVATE ${TEST_SOURCES} ${xgboost_SOURCE_DIR}/plugin/example/custom_obj.cc) if (USE_CUDA AND PLUGIN_RMM) @@ -25,13 +33,6 @@ if (USE_CUDA AND PLUGIN_RMM) target_include_directories(testxgboost PRIVATE ${CUDA_INCLUDE_DIRS}) endif (USE_CUDA AND PLUGIN_RMM) -if (PLUGIN_FEDERATED) - target_include_directories(testxgboost PRIVATE ${xgboost_SOURCE_DIR}/plugin/federated) - target_link_libraries(testxgboost PRIVATE federated_client) -else(PLUGIN_FEDERATED) - list(REMOVE_ITEM TEST_SOURCES "plugin/test_federated_server.cc") -endif(PLUGIN_FEDERATED) - target_include_directories(testxgboost PRIVATE ${GTEST_INCLUDE_DIRS} From bb0896e3ae8fea71a863bc3e01b89800ed38d97f Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Fri, 29 Apr 2022 09:02:35 -0700 Subject: [PATCH 20/21] revert accidiental change --- rabit/include/rabit/internal/engine.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rabit/include/rabit/internal/engine.h b/rabit/include/rabit/internal/engine.h index 88dd263f491f..50b452f8db1a 100644 --- a/rabit/include/rabit/internal/engine.h +++ b/rabit/include/rabit/internal/engine.h @@ -260,7 +260,7 @@ class ReduceHandle { * with the type the reduce function needs to deal with * the reduce function MUST be communicative */ - void Init(IEngine::ReduceFunction redfunc, __attribute__((unused)) size_t type_nbytes); + void Init(IEngine::ReduceFunction redfunc, size_t type_nbytes); /*! * \brief customized in-place all reduce operation * \param sendrecvbuf the in place send-recv buffer From ba5202156aef7b110b773e0c1809a6daa2350907 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Mon, 2 May 2022 13:11:21 -0700 Subject: [PATCH 21/21] address review comments --- plugin/federated/CMakeLists.txt | 4 ++-- python-package/xgboost/__init__.py | 3 --- python-package/xgboost/federated.py | 18 ++++++++++++------ src/c_api/c_api.cc | 8 +++++++- tests/distributed/test_federated.py | 8 +++++++- 5 files changed, 28 insertions(+), 13 deletions(-) diff --git a/plugin/federated/CMakeLists.txt b/plugin/federated/CMakeLists.txt index a72bd3ea0d1e..b84fbb7592a1 100644 --- a/plugin/federated/CMakeLists.txt +++ b/plugin/federated/CMakeLists.txt @@ -1,6 +1,6 @@ # gRPC needs to be installed first. See README.md. -find_package(protobuf CONFIG REQUIRED) -find_package(gRPC CONFIG REQUIRED) +find_package(Protobuf REQUIRED) +find_package(gRPC REQUIRED) find_package(Threads) # Generated code from the protobuf definition. diff --git a/python-package/xgboost/__init__.py b/python-package/xgboost/__init__.py index bb6a8ea921a1..aacd88dd26a1 100644 --- a/python-package/xgboost/__init__.py +++ b/python-package/xgboost/__init__.py @@ -12,7 +12,6 @@ from . import tracker # noqa from .tracker import RabitTracker # noqa from . import dask -from .federated import run_federated_server try: from .sklearn import XGBModel, XGBClassifier, XGBRegressor, XGBRanker @@ -52,6 +51,4 @@ "XGBRFRegressor", # dask "dask", - # federated - "run_federated_server", ] diff --git a/python-package/xgboost/federated.py b/python-package/xgboost/federated.py index 0a28885c06ae..369f6790f583 100644 --- a/python-package/xgboost/federated.py +++ b/python-package/xgboost/federated.py @@ -1,6 +1,6 @@ """XGBoost Federated Learning related API.""" -from .core import _LIB, _check_call, c_str +from .core import _LIB, _check_call, c_str, build_info, XGBoostError def run_federated_server(port: int, @@ -23,8 +23,14 @@ def run_federated_server(port: int, client_cert_path: str Path to the client certificate file. """ - _check_call(_LIB.XGBRunFederatedServer(port, - world_size, - c_str(server_key_path), - c_str(server_cert_path), - c_str(client_cert_path))) + if build_info()['USE_FEDERATED']: + _check_call(_LIB.XGBRunFederatedServer(port, + world_size, + c_str(server_key_path), + c_str(server_cert_path), + c_str(client_cert_path))) + else: + raise XGBoostError( + "XGBoost needs to be built with the federated learning plugin " + "enabled in order to use this module" + ) diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index ae98a9309e70..3c7c539802fa 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -29,7 +29,7 @@ #include "../data/proxy_dmatrix.h" #if defined(XGBOOST_USE_FEDERATED) -#include "../../../plugin/federated/federated_server.h" +#include "../../plugin/federated/federated_server.h" #endif using namespace xgboost; // NOLINT(*); @@ -99,6 +99,12 @@ XGB_DLL int XGBuildInfo(char const **out) { info["DEBUG"] = Boolean{false}; #endif +#if defined(XGBOOST_USE_FEDERATED) + info["USE_FEDERATED"] = Boolean{true}; +#else + info["USE_FEDERATED"] = Boolean{false}; +#endif + XGBBuildInfoDevice(&info); auto &out_str = GlobalConfigAPIThreadLocalStore::Get()->ret_str; diff --git a/tests/distributed/test_federated.py b/tests/distributed/test_federated.py index 9f97c227959e..5b5b167fcd32 100644 --- a/tests/distributed/test_federated.py +++ b/tests/distributed/test_federated.py @@ -1,8 +1,10 @@ #!/usr/bin/python import multiprocessing import sys +import time import xgboost as xgb +import xgboost.federated SERVER_KEY = 'server-key.pem' SERVER_CERT = 'server-cert.pem' @@ -11,7 +13,8 @@ def run_server(port: int, world_size: int) -> None: - xgb.run_federated_server(port, world_size, SERVER_KEY, SERVER_CERT, CLIENT_CERT) + xgboost.federated.run_federated_server(port, world_size, SERVER_KEY, SERVER_CERT, + CLIENT_CERT) def run_worker(port: int, world_size: int, rank: int) -> None: @@ -57,6 +60,9 @@ def run_test() -> None: server = multiprocessing.Process(target=run_server, args=(port, world_size)) server.start() + time.sleep(1) + if not server.is_alive(): + raise Exception("Error starting Federated Learning server") workers = [] for rank in range(world_size):