diff --git a/CMakeLists.txt b/CMakeLists.txt index 11178c449119..f3e66f77997a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -66,6 +66,7 @@ address, leak, undefined and thread.") ## Plugins option(PLUGIN_DENSE_PARSER "Build dense parser plugin" OFF) option(PLUGIN_RMM "Build with RAPIDS Memory Manager (RMM)" OFF) +option(PLUGIN_FEDERATED "Build with Federated Learning" OFF) ## TODO: 1. Add check if DPC++ compiler is used for building option(PLUGIN_UPDATER_ONEAPI "DPC++ updater" OFF) option(ADD_PKGCONFIG "Add xgboost.pc into system." ON) diff --git a/plugin/CMakeLists.txt b/plugin/CMakeLists.txt index 97d1af190cfe..9f59c68f14e0 100644 --- a/plugin/CMakeLists.txt +++ b/plugin/CMakeLists.txt @@ -40,3 +40,8 @@ if (PLUGIN_UPDATER_ONEAPI) # Add all objects of oneapi_plugin to objxgboost target_sources(objxgboost INTERFACE $) endif (PLUGIN_UPDATER_ONEAPI) + +# Add the Federate Learning plugin if enabled. +if (PLUGIN_FEDERATED) + add_subdirectory(federated) +endif (PLUGIN_FEDERATED) diff --git a/plugin/federated/CMakeLists.txt b/plugin/federated/CMakeLists.txt new file mode 100644 index 000000000000..b84fbb7592a1 --- /dev/null +++ b/plugin/federated/CMakeLists.txt @@ -0,0 +1,27 @@ +# gRPC needs to be installed first. See README.md. +find_package(Protobuf REQUIRED) +find_package(gRPC REQUIRED) +find_package(Threads) + +# Generated code from the protobuf definition. +add_library(federated_proto federated.proto) +target_link_libraries(federated_proto PUBLIC protobuf::libprotobuf gRPC::grpc gRPC::grpc++) +target_include_directories(federated_proto PUBLIC ${CMAKE_CURRENT_BINARY_DIR}) +set_property(TARGET federated_proto PROPERTY POSITION_INDEPENDENT_CODE ON) + +get_target_property(grpc_cpp_plugin_location gRPC::grpc_cpp_plugin LOCATION) +protobuf_generate(TARGET federated_proto LANGUAGE cpp) +protobuf_generate( + TARGET federated_proto + LANGUAGE grpc + GENERATE_EXTENSIONS .grpc.pb.h .grpc.pb.cc + PLUGIN "protoc-gen-grpc=${grpc_cpp_plugin_location}") + +# Wrapper for the gRPC client. +add_library(federated_client INTERFACE federated_client.h) +target_link_libraries(federated_client INTERFACE federated_proto) + +# Rabit engine for Federated Learning. +target_sources(objxgboost PRIVATE federated_server.cc engine_federated.cc) +target_link_libraries(objxgboost PRIVATE federated_client) +target_compile_definitions(objxgboost PUBLIC -DXGBOOST_USE_FEDERATED=1) diff --git a/plugin/federated/README.md b/plugin/federated/README.md new file mode 100644 index 000000000000..a5fa95e0c140 --- /dev/null +++ b/plugin/federated/README.md @@ -0,0 +1,35 @@ +XGBoost Plugin for Federated Learning +===================================== + +This folder contains the plugin for federated learning. Follow these steps to build and test it. + +Install gRPC +------------ +```shell +sudo apt-get install build-essential autoconf libtool pkg-config cmake ninja-build +git clone -b v1.45.2 https://github.com/grpc/grpc +cd grpc +git submodule update --init +cmake -S . -B build -GNinja -DABSL_PROPAGATE_CXX_STD=ON +cmake --build build --target install +``` + +Build the Plugin +---------------- +```shell +# Under xgboost source tree. +mkdir build +cd build +cmake .. -GNinja -DPLUGIN_FEDERATED=ON +ninja +cd ../python-package +pip install -e . # or equivalently python setup.py develop +``` + +Test Federated XGBoost +---------------------- +```shell +# Under xgboost source tree. +cd tests/distributed +./runtests-federated.sh +``` diff --git a/plugin/federated/engine_federated.cc b/plugin/federated/engine_federated.cc new file mode 100644 index 000000000000..ed7252ba117c --- /dev/null +++ b/plugin/federated/engine_federated.cc @@ -0,0 +1,274 @@ +/*! + * Copyright 2022 XGBoost contributors + */ +#include +#include +#include + +#include "federated_client.h" +#include "rabit/internal/engine.h" +#include "rabit/internal/utils.h" + +namespace MPI { // NOLINT +// MPI data type to be compatible with existing MPI interface +class Datatype { + public: + size_t type_size; + explicit Datatype(size_t type_size) : type_size(type_size) {} +}; +} // namespace MPI + +namespace rabit { +namespace engine { + +/*! \brief implementation of engine using federated learning */ +class FederatedEngine : public IEngine { + public: + void Init(int argc, char *argv[]) { + // Parse environment variables first. + for (auto const &env_var : env_vars_) { + char const *value = getenv(env_var.c_str()); + if (value != nullptr) { + SetParam(env_var, value); + } + } + // Command line argument overrides. + for (int i = 0; i < argc; ++i) { + std::string const key_value = argv[i]; + auto const delimiter = key_value.find('='); + if (delimiter != std::string::npos) { + SetParam(key_value.substr(0, delimiter), key_value.substr(delimiter + 1)); + } + } + utils::Printf("Connecting to federated server %s, world size %d, rank %d", + server_address_.c_str(), world_size_, rank_); + client_.reset(new xgboost::federated::FederatedClient(server_address_, rank_, server_cert_, + client_key_, client_cert_)); + } + + void Finalize() { client_.reset(); } + + void Allgather(void *sendrecvbuf, size_t total_size, size_t slice_begin, size_t slice_end, + size_t size_prev_slice) override { + throw std::logic_error("FederatedEngine:: Allgather is not supported"); + } + + std::string Allgather(void *sendbuf, size_t total_size) { + std::string const send_buffer(reinterpret_cast(sendbuf), total_size); + return client_->Allgather(send_buffer); + } + + void Allreduce(void *sendrecvbuf, size_t type_nbytes, size_t count, ReduceFunction reducer, + PreprocFunction prepare_fun, void *prepare_arg) override { + throw std::logic_error("FederatedEngine:: Allreduce is not supported, use Allreduce_ instead"); + } + + void Allreduce(void *sendrecvbuf, size_t size, mpi::DataType dtype, mpi::OpType op) { + auto *buffer = reinterpret_cast(sendrecvbuf); + std::string const send_buffer(buffer, size); + auto const receive_buffer = client_->Allreduce(send_buffer, GetDataType(dtype), GetOp(op)); + receive_buffer.copy(buffer, size); + } + + int GetRingPrevRank() const override { + throw std::logic_error("FederatedEngine:: GetRingPrevRank is not supported"); + } + + void Broadcast(void *sendrecvbuf, size_t size, int root) override { + if (world_size_ == 1) return; + auto *buffer = reinterpret_cast(sendrecvbuf); + std::string const send_buffer(buffer, size); + auto const receive_buffer = client_->Broadcast(send_buffer, root); + if (rank_ != root) { + receive_buffer.copy(buffer, size); + } + } + + int LoadCheckPoint(Serializable *global_model, Serializable *local_model = nullptr) override { + return 0; + } + + void CheckPoint(const Serializable *global_model, + const Serializable *local_model = nullptr) override { + version_number_ += 1; + } + + void LazyCheckPoint(const Serializable *global_model) override { version_number_ += 1; } + + int VersionNumber() const override { return version_number_; } + + /*! \brief get rank of current node */ + int GetRank() const override { return rank_; } + + /*! \brief get total number of */ + int GetWorldSize() const override { return world_size_; } + + /*! \brief whether it is distributed */ + bool IsDistributed() const override { return true; } + + /*! \brief get the host name of current node */ + std::string GetHost() const override { return "rank" + std::to_string(rank_); } + + void TrackerPrint(const std::string &msg) override { + // simply print information into the tracker + if (GetRank() == 0) { + utils::Printf("%s", msg.c_str()); + } + } + + private: + /** @brief Transform mpi::DataType to xgboost::federated::DataType. */ + static xgboost::federated::DataType GetDataType(mpi::DataType data_type) { + switch (data_type) { + case mpi::kChar: + return xgboost::federated::CHAR; + case mpi::kUChar: + return xgboost::federated::UCHAR; + case mpi::kInt: + return xgboost::federated::INT; + case mpi::kUInt: + return xgboost::federated::UINT; + case mpi::kLong: + return xgboost::federated::LONG; + case mpi::kULong: + return xgboost::federated::ULONG; + case mpi::kFloat: + return xgboost::federated::FLOAT; + case mpi::kDouble: + return xgboost::federated::DOUBLE; + case mpi::kLongLong: + return xgboost::federated::LONGLONG; + case mpi::kULongLong: + return xgboost::federated::ULONGLONG; + } + utils::Error("unknown mpi::DataType"); + return xgboost::federated::CHAR; + } + + /** @brief Transform mpi::OpType to enum to MPI OP */ + static xgboost::federated::ReduceOperation GetOp(mpi::OpType op_type) { + switch (op_type) { + case mpi::kMax: + return xgboost::federated::MAX; + case mpi::kMin: + return xgboost::federated::MIN; + case mpi::kSum: + return xgboost::federated::SUM; + case mpi::kBitwiseOR: + utils::Error("Bitwise OR is not supported"); + return xgboost::federated::MAX; + } + utils::Error("unknown mpi::OpType"); + return xgboost::federated::MAX; + } + + void SetParam(std::string const &name, std::string const &val) { + if (!strcasecmp(name.c_str(), "FEDERATED_SERVER_ADDRESS")) { + server_address_ = val; + } else if (!strcasecmp(name.c_str(), "FEDERATED_WORLD_SIZE")) { + world_size_ = std::stoi(val); + } else if (!strcasecmp(name.c_str(), "FEDERATED_RANK")) { + rank_ = std::stoi(val); + } else if (!strcasecmp(name.c_str(), "FEDERATED_SERVER_CERT")) { + server_cert_ = ReadFile(val); + } else if (!strcasecmp(name.c_str(), "FEDERATED_CLIENT_KEY")) { + client_key_ = ReadFile(val); + } else if (!strcasecmp(name.c_str(), "FEDERATED_CLIENT_CERT")) { + client_cert_ = ReadFile(val); + } + } + + static std::string ReadFile(std::string const &path) { + auto stream = std::ifstream(path.data()); + std::ostringstream out; + out << stream.rdbuf(); + return out.str(); + } + + // clang-format off + std::vector const env_vars_{ + "FEDERATED_SERVER_ADDRESS", + "FEDERATED_WORLD_SIZE", + "FEDERATED_RANK", + "FEDERATED_SERVER_CERT", + "FEDERATED_CLIENT_KEY", + "FEDERATED_CLIENT_CERT" }; + // clang-format on + std::string server_address_{"localhost:9091"}; + int world_size_{1}; + int rank_{0}; + std::string server_cert_{}; + std::string client_key_{}; + std::string client_cert_{}; + std::unique_ptr client_{}; + int version_number_{0}; +}; + +// Singleton federated engine. +FederatedEngine engine; // NOLINT(cert-err58-cpp) + +/*! \brief initialize the synchronization module */ +bool Init(int argc, char *argv[]) { + try { + engine.Init(argc, argv); + return true; + } catch (std::exception const &e) { + fprintf(stderr, " failed in federated Init %s\n", e.what()); + return false; + } +} + +/*! \brief finalize synchronization module */ +bool Finalize() { + try { + engine.Finalize(); + return true; + } catch (const std::exception &e) { + fprintf(stderr, "failed in federated shutdown %s\n", e.what()); + return false; + } +} + +/*! \brief singleton method to get engine */ +IEngine *GetEngine() { return &engine; } + +// perform in-place allreduce, on sendrecvbuf +void Allreduce_(void *sendrecvbuf, size_t type_nbytes, size_t count, IEngine::ReduceFunction red, + mpi::DataType dtype, mpi::OpType op, IEngine::PreprocFunction prepare_fun, + void *prepare_arg) { + if (prepare_fun != nullptr) prepare_fun(prepare_arg); + if (engine.GetWorldSize() == 1) return; + engine.Allreduce(sendrecvbuf, type_nbytes * count, dtype, op); +} + +ReduceHandle::ReduceHandle() = default; +ReduceHandle::~ReduceHandle() = default; + +int ReduceHandle::TypeSize(const MPI::Datatype &dtype) { return static_cast(dtype.type_size); } + +void ReduceHandle::Init(IEngine::ReduceFunction redfunc, + __attribute__((unused)) size_t type_nbytes) { + utils::Assert(redfunc_ == nullptr, "cannot initialize reduce handle twice"); + redfunc_ = redfunc; +} + +void ReduceHandle::Allreduce(void *sendrecvbuf, size_t type_nbytes, size_t count, + IEngine::PreprocFunction prepare_fun, void *prepare_arg) { + utils::Assert(redfunc_ != nullptr, "must initialize handle to call AllReduce"); + if (prepare_fun != nullptr) prepare_fun(prepare_arg); + if (engine.GetWorldSize() == 1) return; + + // Gather all the buffers and call the reduce function locally. + auto const buffer_size = type_nbytes * count; + auto const gathered = engine.Allgather(sendrecvbuf, buffer_size); + auto const *data = gathered.data(); + for (int i = 0; i < engine.GetWorldSize(); i++) { + if (i != engine.GetRank()) { + redfunc_(data + buffer_size * i, sendrecvbuf, static_cast(count), + MPI::Datatype(type_nbytes)); + } + } +} + +} // namespace engine +} // namespace rabit diff --git a/plugin/federated/federated.proto b/plugin/federated/federated.proto new file mode 100644 index 000000000000..cba897c0ea81 --- /dev/null +++ b/plugin/federated/federated.proto @@ -0,0 +1,68 @@ +/*! + * Copyright 2022 XGBoost contributors + */ +syntax = "proto3"; + +package xgboost.federated; + +service Federated { + rpc Allgather(AllgatherRequest) returns (AllgatherReply) {} + rpc Allreduce(AllreduceRequest) returns (AllreduceReply) {} + rpc Broadcast(BroadcastRequest) returns (BroadcastReply) {} +} + +enum DataType { + CHAR = 0; + UCHAR = 1; + INT = 2; + UINT = 3; + LONG = 4; + ULONG = 5; + FLOAT = 6; + DOUBLE = 7; + LONGLONG = 8; + ULONGLONG = 9; +} + +enum ReduceOperation { + MAX = 0; + MIN = 1; + SUM = 2; +} + +message AllgatherRequest { + // An incrementing counter that is unique to each round to operations. + uint64 sequence_number = 1; + int32 rank = 2; + bytes send_buffer = 3; +} + +message AllgatherReply { + bytes receive_buffer = 1; +} + +message AllreduceRequest { + // An incrementing counter that is unique to each round to operations. + uint64 sequence_number = 1; + int32 rank = 2; + bytes send_buffer = 3; + DataType data_type = 4; + ReduceOperation reduce_operation = 5; +} + +message AllreduceReply { + bytes receive_buffer = 1; +} + +message BroadcastRequest { + // An incrementing counter that is unique to each round to operations. + uint64 sequence_number = 1; + int32 rank = 2; + bytes send_buffer = 3; + // The root rank to broadcast from. + int32 root = 4; +} + +message BroadcastReply { + bytes receive_buffer = 1; +} diff --git a/plugin/federated/federated_client.h b/plugin/federated/federated_client.h new file mode 100644 index 000000000000..5aacef3a621a --- /dev/null +++ b/plugin/federated/federated_client.h @@ -0,0 +1,104 @@ +/*! + * Copyright 2022 XGBoost contributors + */ +#pragma once +#include +#include +#include + +#include +#include +#include + +namespace xgboost { +namespace federated { + +/** + * @brief A wrapper around the gRPC client. + */ +class FederatedClient { + public: + FederatedClient(std::string const &server_address, int rank, std::string const &server_cert, + std::string const &client_key, std::string const &client_cert) + : stub_{[&] { + grpc::SslCredentialsOptions options; + options.pem_root_certs = server_cert; + options.pem_private_key = client_key; + options.pem_cert_chain = client_cert; + return Federated::NewStub( + grpc::CreateChannel(server_address, grpc::SslCredentials(options))); + }()}, + rank_{rank} {} + + /** @brief Insecure client for testing only. */ + FederatedClient(std::string const &server_address, int rank) + : stub_{Federated::NewStub( + grpc::CreateChannel(server_address, grpc::InsecureChannelCredentials()))}, + rank_{rank} {} + + std::string Allgather(std::string const &send_buffer) { + AllgatherRequest request; + request.set_sequence_number(sequence_number_++); + request.set_rank(rank_); + request.set_send_buffer(send_buffer); + + AllgatherReply reply; + grpc::ClientContext context; + grpc::Status status = stub_->Allgather(&context, request, &reply); + + if (status.ok()) { + return reply.receive_buffer(); + } else { + std::cout << status.error_code() << ": " << status.error_message() << '\n'; + throw std::runtime_error("Allgather RPC failed"); + } + } + + std::string Allreduce(std::string const &send_buffer, DataType data_type, + ReduceOperation reduce_operation) { + AllreduceRequest request; + request.set_sequence_number(sequence_number_++); + request.set_rank(rank_); + request.set_send_buffer(send_buffer); + request.set_data_type(data_type); + request.set_reduce_operation(reduce_operation); + + AllreduceReply reply; + grpc::ClientContext context; + grpc::Status status = stub_->Allreduce(&context, request, &reply); + + if (status.ok()) { + return reply.receive_buffer(); + } else { + std::cout << status.error_code() << ": " << status.error_message() << '\n'; + throw std::runtime_error("Allreduce RPC failed"); + } + } + + std::string Broadcast(std::string const &send_buffer, int root) { + BroadcastRequest request; + request.set_sequence_number(sequence_number_++); + request.set_rank(rank_); + request.set_send_buffer(send_buffer); + request.set_root(root); + + BroadcastReply reply; + grpc::ClientContext context; + grpc::Status status = stub_->Broadcast(&context, request, &reply); + + if (status.ok()) { + return reply.receive_buffer(); + } else { + std::cout << status.error_code() << ": " << status.error_message() << '\n'; + throw std::runtime_error("Broadcast RPC failed"); + } + } + + private: + std::unique_ptr const stub_; + int const rank_; + uint64_t sequence_number_{}; +}; + +} // namespace federated +} // namespace xgboost diff --git a/plugin/federated/federated_server.cc b/plugin/federated/federated_server.cc new file mode 100644 index 000000000000..c9651797a42c --- /dev/null +++ b/plugin/federated/federated_server.cc @@ -0,0 +1,234 @@ +/*! + * Copyright 2022 XGBoost contributors + */ +#include "federated_server.h" + +#include +#include +#include + +#include +#include + +namespace xgboost { +namespace federated { + +class AllgatherFunctor { + public: + std::string const name{"Allgather"}; + + explicit AllgatherFunctor(int const world_size) : world_size_{world_size} {} + + void operator()(AllgatherRequest const* request, std::string& buffer) const { + auto const rank = request->rank(); + auto const& send_buffer = request->send_buffer(); + auto const send_size = send_buffer.size(); + // Resize the buffer if this is the first request. + if (buffer.size() != send_size * world_size_) { + buffer.resize(send_size * world_size_); + } + // Splice the send_buffer into the common buffer. + buffer.replace(rank * send_size, send_size, send_buffer); + } + + private: + int const world_size_; +}; + +class AllreduceFunctor { + public: + std::string const name{"Allreduce"}; + + void operator()(AllreduceRequest const* request, std::string& buffer) const { + if (buffer.empty()) { + // Copy the send_buffer if this is the first request. + buffer = request->send_buffer(); + } else { + // Apply the reduce_operation to the send_buffer and the common buffer. + Accumulate(buffer, request->send_buffer(), request->data_type(), request->reduce_operation()); + } + } + + private: + template + void Accumulate(T* buffer, T const* input, std::size_t n, + ReduceOperation reduce_operation) const { + switch (reduce_operation) { + case ReduceOperation::MAX: + std::transform(buffer, buffer + n, input, buffer, [](T a, T b) { return std::max(a, b); }); + break; + case ReduceOperation::MIN: + std::transform(buffer, buffer + n, input, buffer, [](T a, T b) { return std::min(a, b); }); + break; + case ReduceOperation::SUM: + std::transform(buffer, buffer + n, input, buffer, std::plus()); + break; + default: + throw std::invalid_argument("Invalid reduce operation"); + } + } + + void Accumulate(std::string& buffer, std::string const& input, DataType data_type, + ReduceOperation reduce_operation) const { + switch (data_type) { + case DataType::CHAR: + Accumulate(&buffer[0], reinterpret_cast(input.data()), buffer.size(), + reduce_operation); + break; + case DataType::UCHAR: + Accumulate(reinterpret_cast(&buffer[0]), + reinterpret_cast(input.data()), buffer.size(), + reduce_operation); + break; + case DataType::INT: + Accumulate(reinterpret_cast(&buffer[0]), reinterpret_cast(input.data()), + buffer.size() / sizeof(int), reduce_operation); + break; + case DataType::UINT: + Accumulate(reinterpret_cast(&buffer[0]), + reinterpret_cast(input.data()), + buffer.size() / sizeof(unsigned int), reduce_operation); + break; + case DataType::LONG: + Accumulate(reinterpret_cast(&buffer[0]), reinterpret_cast(input.data()), + buffer.size() / sizeof(long), reduce_operation); + break; + case DataType::ULONG: + Accumulate(reinterpret_cast(&buffer[0]), + reinterpret_cast(input.data()), + buffer.size() / sizeof(unsigned long), reduce_operation); + break; + case DataType::FLOAT: + Accumulate(reinterpret_cast(&buffer[0]), + reinterpret_cast(input.data()), buffer.size() / sizeof(float), + reduce_operation); + break; + case DataType::DOUBLE: + Accumulate(reinterpret_cast(&buffer[0]), + reinterpret_cast(input.data()), buffer.size() / sizeof(double), + reduce_operation); + break; + case DataType::LONGLONG: + Accumulate(reinterpret_cast(&buffer[0]), + reinterpret_cast(input.data()), + buffer.size() / sizeof(long long), reduce_operation); + break; + case DataType::ULONGLONG: + Accumulate(reinterpret_cast(&buffer[0]), + reinterpret_cast(input.data()), + buffer.size() / sizeof(unsigned long long), reduce_operation); + break; + default: + throw std::invalid_argument("Invalid data type"); + } + } +}; + +class BroadcastFunctor { + public: + std::string const name{"Broadcast"}; + + void operator()(BroadcastRequest const* request, std::string& buffer) const { + if (request->rank() == request->root()) { + // Copy the send_buffer if this is the root. + buffer = request->send_buffer(); + } + } +}; + +grpc::Status FederatedService::Allgather(grpc::ServerContext* context, + AllgatherRequest const* request, AllgatherReply* reply) { + return Handle(request, reply, AllgatherFunctor{world_size_}); +} + +grpc::Status FederatedService::Allreduce(grpc::ServerContext* context, + AllreduceRequest const* request, AllreduceReply* reply) { + return Handle(request, reply, AllreduceFunctor{}); +} + +grpc::Status FederatedService::Broadcast(grpc::ServerContext* context, + BroadcastRequest const* request, BroadcastReply* reply) { + return Handle(request, reply, BroadcastFunctor{}); +} + +template +grpc::Status FederatedService::Handle(Request const* request, Reply* reply, + RequestFunctor const& functor) { + // Pass through if there is only 1 client. + if (world_size_ == 1) { + reply->set_receive_buffer(request->send_buffer()); + return grpc::Status::OK; + } + + std::unique_lock lock(mutex_); + + auto const sequence_number = request->sequence_number(); + auto const rank = request->rank(); + + LOG(INFO) << functor.name << " rank " << rank << ": waiting for current sequence number"; + cv_.wait(lock, [this, sequence_number] { return sequence_number_ == sequence_number; }); + + LOG(INFO) << functor.name << " rank " << rank << ": handling request"; + functor(request, buffer_); + received_++; + + if (received_ == world_size_) { + LOG(INFO) << functor.name << " rank " << rank << ": all requests received"; + reply->set_receive_buffer(buffer_); + sent_++; + lock.unlock(); + cv_.notify_all(); + return grpc::Status::OK; + } + + LOG(INFO) << functor.name << " rank " << rank << ": waiting for all clients"; + cv_.wait(lock, [this] { return received_ == world_size_; }); + + LOG(INFO) << functor.name << " rank " << rank << ": sending reply"; + reply->set_receive_buffer(buffer_); + sent_++; + + if (sent_ == world_size_) { + LOG(INFO) << functor.name << " rank " << rank << ": all replies sent"; + sent_ = 0; + received_ = 0; + buffer_.clear(); + sequence_number_++; + lock.unlock(); + cv_.notify_all(); + } + + return grpc::Status::OK; +} + +std::string ReadFile(char const* path) { + auto stream = std::ifstream(path); + std::ostringstream out; + out << stream.rdbuf(); + return out.str(); +} + +void RunServer(int port, int world_size, char const* server_key_file, char const* server_cert_file, + char const* client_cert_file) { + std::string const server_address = "0.0.0.0:" + std::to_string(port); + FederatedService service{world_size}; + + grpc::ServerBuilder builder; + auto options = + grpc::SslServerCredentialsOptions(GRPC_SSL_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY); + options.pem_root_certs = ReadFile(client_cert_file); + auto key = grpc::SslServerCredentialsOptions::PemKeyCertPair(); + key.private_key = ReadFile(server_key_file); + key.cert_chain = ReadFile(server_cert_file); + options.pem_key_cert_pairs.push_back(key); + builder.AddListeningPort(server_address, grpc::SslServerCredentials(options)); + builder.RegisterService(&service); + std::unique_ptr server(builder.BuildAndStart()); + LOG(CONSOLE) << "Federated server listening on " << server_address << ", world size " + << world_size; + + server->Wait(); +} + +} // namespace federated +} // namespace xgboost diff --git a/plugin/federated/federated_server.h b/plugin/federated/federated_server.h new file mode 100644 index 000000000000..108d78d47a96 --- /dev/null +++ b/plugin/federated/federated_server.h @@ -0,0 +1,44 @@ +/*! + * Copyright 2022 XGBoost contributors + */ +#pragma once + +#include + +#include +#include + +namespace xgboost { +namespace federated { + +class FederatedService final : public Federated::Service { + public: + explicit FederatedService(int const world_size) : world_size_{world_size} {} + + grpc::Status Allgather(grpc::ServerContext* context, AllgatherRequest const* request, + AllgatherReply* reply) override; + + grpc::Status Allreduce(grpc::ServerContext* context, AllreduceRequest const* request, + AllreduceReply* reply) override; + + grpc::Status Broadcast(grpc::ServerContext* context, BroadcastRequest const* request, + BroadcastReply* reply) override; + + private: + template + grpc::Status Handle(Request const* request, Reply* reply, RequestFunctor const& functor); + + int const world_size_; + int received_{}; + int sent_{}; + std::string buffer_{}; + uint64_t sequence_number_{}; + mutable std::mutex mutex_; + mutable std::condition_variable cv_; +}; + +void RunServer(int port, int world_size, char const* server_key_file, char const* server_cert_file, + char const* client_cert_file); + +} // namespace federated +} // namespace xgboost diff --git a/python-package/xgboost/federated.py b/python-package/xgboost/federated.py new file mode 100644 index 000000000000..369f6790f583 --- /dev/null +++ b/python-package/xgboost/federated.py @@ -0,0 +1,36 @@ +"""XGBoost Federated Learning related API.""" + +from .core import _LIB, _check_call, c_str, build_info, XGBoostError + + +def run_federated_server(port: int, + world_size: int, + server_key_path: str, + server_cert_path: str, + client_cert_path: str) -> None: + """Run the Federated Learning server. + + Parameters + ---------- + port : int + The port to listen on. + world_size: int + The number of federated workers. + server_key_path: str + Path to the server private key file. + server_cert_path: str + Path to the server certificate file. + client_cert_path: str + Path to the client certificate file. + """ + if build_info()['USE_FEDERATED']: + _check_call(_LIB.XGBRunFederatedServer(port, + world_size, + c_str(server_key_path), + c_str(server_cert_path), + c_str(client_cert_path))) + else: + raise XGBoostError( + "XGBoost needs to be built with the federated learning plugin " + "enabled in order to use this module" + ) diff --git a/rabit/CMakeLists.txt b/rabit/CMakeLists.txt index ad39fb249791..3a76794f5f58 100644 --- a/rabit/CMakeLists.txt +++ b/rabit/CMakeLists.txt @@ -6,7 +6,9 @@ set(RABIT_SOURCES ${CMAKE_CURRENT_LIST_DIR}/src/allreduce_base.cc ${CMAKE_CURRENT_LIST_DIR}/src/rabit_c_api.cc) -if (RABIT_BUILD_MPI) +if (PLUGIN_FEDERATED) + # Skip the engine if the Federated Learning plugin is enabled. +elseif (RABIT_BUILD_MPI) list(APPEND RABIT_SOURCES ${CMAKE_CURRENT_LIST_DIR}/src/engine_mpi.cc) elseif (RABIT_MOCK) list(APPEND RABIT_SOURCES ${CMAKE_CURRENT_LIST_DIR}/src/engine_mock.cc) diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index a11602a56610..3c7c539802fa 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -28,6 +28,10 @@ #include "../data/simple_dmatrix.h" #include "../data/proxy_dmatrix.h" +#if defined(XGBOOST_USE_FEDERATED) +#include "../../plugin/federated/federated_server.h" +#endif + using namespace xgboost; // NOLINT(*); XGB_DLL void XGBoostVersion(int* major, int* minor, int* patch) { @@ -95,6 +99,12 @@ XGB_DLL int XGBuildInfo(char const **out) { info["DEBUG"] = Boolean{false}; #endif +#if defined(XGBOOST_USE_FEDERATED) + info["USE_FEDERATED"] = Boolean{true}; +#else + info["USE_FEDERATED"] = Boolean{false}; +#endif + XGBBuildInfoDevice(&info); auto &out_str = GlobalConfigAPIThreadLocalStore::Get()->ret_str; @@ -198,11 +208,15 @@ XGB_DLL int XGDMatrixCreateFromFile(const char *fname, DMatrixHandle *out) { API_BEGIN(); bool load_row_split = false; +#if defined(XGBOOST_USE_FEDERATED) + LOG(CONSOLE) << "XGBoost federated mode detected, not splitting data among workers"; +#else if (rabit::IsDistributed()) { LOG(CONSOLE) << "XGBoost distributed mode detected, " << "will split data among workers"; load_row_split = true; } +#endif *out = new std::shared_ptr(DMatrix::Load(fname, silent != 0, load_row_split)); API_END(); } @@ -1342,5 +1356,14 @@ XGB_DLL int XGBoosterFeatureScore(BoosterHandle handle, char const *json_config, API_END(); } +#if defined(XGBOOST_USE_FEDERATED) +XGB_DLL int XGBRunFederatedServer(int port, int world_size, char const *server_key_path, + char const *server_cert_path, char const *client_cert_path) { + API_BEGIN(); + federated::RunServer(port, world_size, server_key_path, server_cert_path, client_cert_path); + API_END(); +} +#endif + // force link rabit static DMLC_ATTRIBUTE_UNUSED int XGBOOST_LINK_RABIT_C_API_ = RabitLinkTag(); diff --git a/tests/cpp/CMakeLists.txt b/tests/cpp/CMakeLists.txt index a346e0f03c98..9dfd429d01e2 100644 --- a/tests/cpp/CMakeLists.txt +++ b/tests/cpp/CMakeLists.txt @@ -18,6 +18,14 @@ if (NOT PLUGIN_UPDATER_ONEAPI) list(REMOVE_ITEM TEST_SOURCES ${ONEAPI_TEST_SOURCES}) endif (NOT PLUGIN_UPDATER_ONEAPI) +if (PLUGIN_FEDERATED) + target_include_directories(testxgboost PRIVATE ${xgboost_SOURCE_DIR}/plugin/federated) + target_link_libraries(testxgboost PRIVATE federated_client) +else (PLUGIN_FEDERATED) + file(GLOB_RECURSE FEDERATED_TEST_SOURCES "plugin/*_federated_*.cc") + list(REMOVE_ITEM TEST_SOURCES ${FEDERATED_TEST_SOURCES}) +endif (PLUGIN_FEDERATED) + target_sources(testxgboost PRIVATE ${TEST_SOURCES} ${xgboost_SOURCE_DIR}/plugin/example/custom_obj.cc) if (USE_CUDA AND PLUGIN_RMM) diff --git a/tests/cpp/plugin/test_federated_server.cc b/tests/cpp/plugin/test_federated_server.cc new file mode 100644 index 000000000000..b20c17b09de5 --- /dev/null +++ b/tests/cpp/plugin/test_federated_server.cc @@ -0,0 +1,130 @@ +/*! + * Copyright 2017-2020 XGBoost contributors + */ +#include +#include + +#include + +#include "federated_client.h" +#include "federated_server.h" + +namespace xgboost { + +class FederatedServerTest : public ::testing::Test { + public: + static void VerifyAllgather(int rank) { + federated::FederatedClient client{kServerAddress, rank}; + CheckAllgather(client, rank); + } + + static void VerifyAllreduce(int rank) { + federated::FederatedClient client{kServerAddress, rank}; + CheckAllreduce(client); + } + + static void VerifyBroadcast(int rank) { + federated::FederatedClient client{kServerAddress, rank}; + CheckBroadcast(client, rank); + } + + static void VerifyMixture(int rank) { + federated::FederatedClient client{kServerAddress, rank}; + for (auto i = 0; i < 10; i++) { + CheckAllgather(client, rank); + CheckAllreduce(client); + CheckBroadcast(client, rank); + } + } + + protected: + void SetUp() override { + server_thread_.reset(new std::thread([this] { + grpc::ServerBuilder builder; + federated::FederatedService service{kWorldSize}; + builder.AddListeningPort(kServerAddress, grpc::InsecureServerCredentials()); + builder.RegisterService(&service); + server_ = builder.BuildAndStart(); + server_->Wait(); + })); + } + + void TearDown() override { + server_->Shutdown(); + server_thread_->join(); + } + + static void CheckAllgather(federated::FederatedClient& client, int rank) { + auto reply = client.Allgather("hello " + std::to_string(rank) + " "); + EXPECT_EQ(reply, "hello 0 hello 1 hello 2 "); + } + + static void CheckAllreduce(federated::FederatedClient& client) { + int data[] = {1, 2, 3, 4, 5}; + std::string send_buffer(reinterpret_cast(data), sizeof(data)); + auto reply = client.Allreduce(send_buffer, federated::INT, federated::SUM); + auto const* result = reinterpret_cast(reply.data()); + int expected[] = {3, 6, 9, 12, 15}; + for (auto i = 0; i < 5; i++) { + EXPECT_EQ(result[i], expected[i]); + } + } + + static void CheckBroadcast(federated::FederatedClient& client, int rank) { + std::string send_buffer{}; + if (rank == 0) { + send_buffer = "hello broadcast"; + } + auto reply = client.Broadcast(send_buffer, 0); + EXPECT_EQ(reply, "hello broadcast"); + } + + static int const kWorldSize{3}; + static std::string const kServerAddress; + std::unique_ptr server_thread_; + std::unique_ptr server_; +}; + +std::string const FederatedServerTest::kServerAddress{"localhost:56789"}; // NOLINT(cert-err58-cpp) + +TEST_F(FederatedServerTest, Allgather) { + std::vector threads; + for (auto rank = 0; rank < kWorldSize; rank++) { + threads.emplace_back(std::thread(&FederatedServerTest::VerifyAllgather, rank)); + } + for (auto& thread : threads) { + thread.join(); + } +} + +TEST_F(FederatedServerTest, Allreduce) { + std::vector threads; + for (auto rank = 0; rank < kWorldSize; rank++) { + threads.emplace_back(std::thread(&FederatedServerTest::VerifyAllreduce, rank)); + } + for (auto& thread : threads) { + thread.join(); + } +} + +TEST_F(FederatedServerTest, Broadcast) { + std::vector threads; + for (auto rank = 0; rank < kWorldSize; rank++) { + threads.emplace_back(std::thread(&FederatedServerTest::VerifyBroadcast, rank)); + } + for (auto& thread : threads) { + thread.join(); + } +} + +TEST_F(FederatedServerTest, Mixture) { + std::vector threads; + for (auto rank = 0; rank < kWorldSize; rank++) { + threads.emplace_back(std::thread(&FederatedServerTest::VerifyMixture, rank)); + } + for (auto& thread : threads) { + thread.join(); + } +} + +} // namespace xgboost diff --git a/tests/distributed/runtests-federated.sh b/tests/distributed/runtests-federated.sh new file mode 100755 index 000000000000..77724aa969ea --- /dev/null +++ b/tests/distributed/runtests-federated.sh @@ -0,0 +1,17 @@ +#!/bin/bash + +set -e + +rm -f ./*.model* ./agaricus* ./*.pem + +world_size=3 + +# Generate server and client certificates. +openssl req -x509 -newkey rsa:2048 -days 7 -nodes -keyout server-key.pem -out server-cert.pem -subj "/C=US/CN=localhost" +openssl req -x509 -newkey rsa:2048 -days 7 -nodes -keyout client-key.pem -out client-cert.pem -subj "/C=US/CN=localhost" + +# Split train and test files manually to simulate a federated environment. +split -n l/${world_size} -d ../../demo/data/agaricus.txt.train agaricus.txt.train- +split -n l/${world_size} -d ../../demo/data/agaricus.txt.test agaricus.txt.test- + +python test_federated.py ${world_size} diff --git a/tests/distributed/test_federated.py b/tests/distributed/test_federated.py new file mode 100644 index 000000000000..5b5b167fcd32 --- /dev/null +++ b/tests/distributed/test_federated.py @@ -0,0 +1,78 @@ +#!/usr/bin/python +import multiprocessing +import sys +import time + +import xgboost as xgb +import xgboost.federated + +SERVER_KEY = 'server-key.pem' +SERVER_CERT = 'server-cert.pem' +CLIENT_KEY = 'client-key.pem' +CLIENT_CERT = 'client-cert.pem' + + +def run_server(port: int, world_size: int) -> None: + xgboost.federated.run_federated_server(port, world_size, SERVER_KEY, SERVER_CERT, + CLIENT_CERT) + + +def run_worker(port: int, world_size: int, rank: int) -> None: + # Always call this before using distributed module + rabit_env = [ + f'federated_server_address=localhost:{port}', + f'federated_world_size={world_size}', + f'federated_rank={rank}', + f'federated_server_cert={SERVER_CERT}', + f'federated_client_key={CLIENT_KEY}', + f'federated_client_cert={CLIENT_CERT}' + ] + xgb.rabit.init([e.encode() for e in rabit_env]) + + # Load file, file will not be sharded in federated mode. + dtrain = xgb.DMatrix('agaricus.txt.train-%02d' % rank) + dtest = xgb.DMatrix('agaricus.txt.test-%02d' % rank) + + # Specify parameters via map, definition are same as c++ version + param = {'max_depth': 2, 'eta': 1, 'objective': 'binary:logistic'} + + # Specify validations set to watch performance + watchlist = [(dtest, 'eval'), (dtrain, 'train')] + num_round = 20 + + # Run training, all the features in training API is available. + # Currently, this script only support calling train once for fault recovery purpose. + bst = xgb.train(param, dtrain, num_round, evals=watchlist, early_stopping_rounds=2) + + # Save the model, only ask process 0 to save the model. + if xgb.rabit.get_rank() == 0: + bst.save_model("test.model.json") + xgb.rabit.tracker_print("Finished training\n") + + # Notify the tracker all training has been successful + # This is only needed in distributed training. + xgb.rabit.finalize() + + +def run_test() -> None: + port = 9091 + world_size = int(sys.argv[1]) + + server = multiprocessing.Process(target=run_server, args=(port, world_size)) + server.start() + time.sleep(1) + if not server.is_alive(): + raise Exception("Error starting Federated Learning server") + + workers = [] + for rank in range(world_size): + worker = multiprocessing.Process(target=run_worker, args=(port, world_size, rank)) + workers.append(worker) + worker.start() + for worker in workers: + worker.join() + server.terminate() + + +if __name__ == '__main__': + run_test()