Skip to content

Commit

Permalink
Initial support for federated learning (#7831)
Browse files Browse the repository at this point in the history
Federated learning plugin for xgboost:
* A gRPC server to aggregate MPI-style requests (allgather, allreduce, broadcast) from federated workers.
* A Rabit engine for the federated environment.
* Integration test to simulate federated learning.

Additional follow-ups are needed to add GPU support and to strengthen security and privacy guarantees.
  • Loading branch information
rongou committed May 5, 2022
1 parent 46e0bce commit 14ef38b
Show file tree
Hide file tree
Showing 16 changed files with 1,087 additions and 1 deletion.
1 change: 1 addition & 0 deletions CMakeLists.txt
Expand Up @@ -66,6 +66,7 @@ address, leak, undefined and thread.")
## Plugins
option(PLUGIN_DENSE_PARSER "Build dense parser plugin" OFF)
option(PLUGIN_RMM "Build with RAPIDS Memory Manager (RMM)" OFF)
# Requires gRPC and Protobuf to be installed; see plugin/federated/README.md.
option(PLUGIN_FEDERATED "Build with Federated Learning" OFF)
## TODO: 1. Add check if DPC++ compiler is used for building
option(PLUGIN_UPDATER_ONEAPI "DPC++ updater" OFF)
option(ADD_PKGCONFIG "Add xgboost.pc into system." ON)
Expand Down
5 changes: 5 additions & 0 deletions plugin/CMakeLists.txt
Expand Up @@ -40,3 +40,8 @@ if (PLUGIN_UPDATER_ONEAPI)
# Add all objects of oneapi_plugin to objxgboost
target_sources(objxgboost INTERFACE $<TARGET_OBJECTS:oneapi_plugin>)
endif (PLUGIN_UPDATER_ONEAPI)

# Add the Federated Learning plugin if enabled.
if (PLUGIN_FEDERATED)
add_subdirectory(federated)
endif (PLUGIN_FEDERATED)
27 changes: 27 additions & 0 deletions plugin/federated/CMakeLists.txt
@@ -0,0 +1,27 @@
# gRPC needs to be installed first. See README.md.
find_package(Protobuf REQUIRED)
find_package(gRPC REQUIRED)
# NOTE(review): Threads is neither REQUIRED nor linked explicitly below;
# presumably the gRPC imported targets carry Threads::Threads — confirm.
find_package(Threads)

# Generated code from the protobuf definition.
# The .proto file is listed as a "source"; protobuf_generate() below attaches
# the generated .pb.cc/.pb.h (and grpc variants) to this target.
add_library(federated_proto federated.proto)
target_link_libraries(federated_proto PUBLIC protobuf::libprotobuf gRPC::grpc gRPC::grpc++)
# Generated headers land in the build tree, so export it on the include path.
target_include_directories(federated_proto PUBLIC ${CMAKE_CURRENT_BINARY_DIR})
set_property(TARGET federated_proto PROPERTY POSITION_INDEPENDENT_CODE ON)

# Run protoc twice: once for plain protobuf messages, once with the gRPC
# plugin for the service stubs.
get_target_property(grpc_cpp_plugin_location gRPC::grpc_cpp_plugin LOCATION)
protobuf_generate(TARGET federated_proto LANGUAGE cpp)
protobuf_generate(
TARGET federated_proto
LANGUAGE grpc
GENERATE_EXTENSIONS .grpc.pb.h .grpc.pb.cc
PLUGIN "protoc-gen-grpc=${grpc_cpp_plugin_location}")

# Wrapper for the gRPC client.
add_library(federated_client INTERFACE federated_client.h)
target_link_libraries(federated_client INTERFACE federated_proto)

# Rabit engine for Federated Learning.
target_sources(objxgboost PRIVATE federated_server.cc engine_federated.cc)
target_link_libraries(objxgboost PRIVATE federated_client)
target_compile_definitions(objxgboost PUBLIC -DXGBOOST_USE_FEDERATED=1)
35 changes: 35 additions & 0 deletions plugin/federated/README.md
@@ -0,0 +1,35 @@
XGBoost Plugin for Federated Learning
=====================================

This folder contains the plugin for federated learning. Follow these steps to build and test it.

Install gRPC
------------
```shell
sudo apt-get install build-essential autoconf libtool pkg-config cmake ninja-build
git clone -b v1.45.2 https://github.com/grpc/grpc
cd grpc
git submodule update --init
cmake -S . -B build -GNinja -DABSL_PROPAGATE_CXX_STD=ON
cmake --build build --target install
```

Build the Plugin
----------------
```shell
# Under xgboost source tree.
mkdir build
cd build
cmake .. -GNinja -DPLUGIN_FEDERATED=ON
ninja
cd ../python-package
pip install -e . # or equivalently python setup.py develop
```

Test Federated XGBoost
----------------------
```shell
# Under xgboost source tree.
cd tests/distributed
./runtests-federated.sh
```
274 changes: 274 additions & 0 deletions plugin/federated/engine_federated.cc
@@ -0,0 +1,274 @@
/*!
* Copyright 2022 XGBoost contributors
*/
#include <strings.h>  // strcasecmp

#include <cstdio>
#include <cstdlib>  // getenv
#include <fstream>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

#include "federated_client.h"
#include "rabit/internal/engine.h"
#include "rabit/internal/utils.h"

namespace MPI {  // NOLINT
// Minimal stand-in for the MPI datatype handle so this engine stays
// compatible with the existing rabit engine interface; it only records
// the element size in bytes.
class Datatype {
 public:
  size_t type_size;
  explicit Datatype(size_t size) : type_size{size} {}
};
}  // namespace MPI

namespace rabit {
namespace engine {

/*! \brief implementation of engine using federated learning */
class FederatedEngine : public IEngine {
public:
void Init(int argc, char *argv[]) {
// Parse environment variables first.
for (auto const &env_var : env_vars_) {
char const *value = getenv(env_var.c_str());
if (value != nullptr) {
SetParam(env_var, value);
}
}
// Command line argument overrides.
for (int i = 0; i < argc; ++i) {
std::string const key_value = argv[i];
auto const delimiter = key_value.find('=');
if (delimiter != std::string::npos) {
SetParam(key_value.substr(0, delimiter), key_value.substr(delimiter + 1));
}
}
utils::Printf("Connecting to federated server %s, world size %d, rank %d",
server_address_.c_str(), world_size_, rank_);
client_.reset(new xgboost::federated::FederatedClient(server_address_, rank_, server_cert_,
client_key_, client_cert_));
}

void Finalize() { client_.reset(); }

void Allgather(void *sendrecvbuf, size_t total_size, size_t slice_begin, size_t slice_end,
size_t size_prev_slice) override {
throw std::logic_error("FederatedEngine:: Allgather is not supported");
}

std::string Allgather(void *sendbuf, size_t total_size) {
std::string const send_buffer(reinterpret_cast<char *>(sendbuf), total_size);
return client_->Allgather(send_buffer);
}

void Allreduce(void *sendrecvbuf, size_t type_nbytes, size_t count, ReduceFunction reducer,
PreprocFunction prepare_fun, void *prepare_arg) override {
throw std::logic_error("FederatedEngine:: Allreduce is not supported, use Allreduce_ instead");
}

void Allreduce(void *sendrecvbuf, size_t size, mpi::DataType dtype, mpi::OpType op) {
auto *buffer = reinterpret_cast<char *>(sendrecvbuf);
std::string const send_buffer(buffer, size);
auto const receive_buffer = client_->Allreduce(send_buffer, GetDataType(dtype), GetOp(op));
receive_buffer.copy(buffer, size);
}

int GetRingPrevRank() const override {
throw std::logic_error("FederatedEngine:: GetRingPrevRank is not supported");
}

void Broadcast(void *sendrecvbuf, size_t size, int root) override {
if (world_size_ == 1) return;
auto *buffer = reinterpret_cast<char *>(sendrecvbuf);
std::string const send_buffer(buffer, size);
auto const receive_buffer = client_->Broadcast(send_buffer, root);
if (rank_ != root) {
receive_buffer.copy(buffer, size);
}
}

int LoadCheckPoint(Serializable *global_model, Serializable *local_model = nullptr) override {
return 0;
}

void CheckPoint(const Serializable *global_model,
const Serializable *local_model = nullptr) override {
version_number_ += 1;
}

void LazyCheckPoint(const Serializable *global_model) override { version_number_ += 1; }

int VersionNumber() const override { return version_number_; }

/*! \brief get rank of current node */
int GetRank() const override { return rank_; }

/*! \brief get total number of */
int GetWorldSize() const override { return world_size_; }

/*! \brief whether it is distributed */
bool IsDistributed() const override { return true; }

/*! \brief get the host name of current node */
std::string GetHost() const override { return "rank" + std::to_string(rank_); }

void TrackerPrint(const std::string &msg) override {
// simply print information into the tracker
if (GetRank() == 0) {
utils::Printf("%s", msg.c_str());
}
}

private:
/** @brief Transform mpi::DataType to xgboost::federated::DataType. */
static xgboost::federated::DataType GetDataType(mpi::DataType data_type) {
switch (data_type) {
case mpi::kChar:
return xgboost::federated::CHAR;
case mpi::kUChar:
return xgboost::federated::UCHAR;
case mpi::kInt:
return xgboost::federated::INT;
case mpi::kUInt:
return xgboost::federated::UINT;
case mpi::kLong:
return xgboost::federated::LONG;
case mpi::kULong:
return xgboost::federated::ULONG;
case mpi::kFloat:
return xgboost::federated::FLOAT;
case mpi::kDouble:
return xgboost::federated::DOUBLE;
case mpi::kLongLong:
return xgboost::federated::LONGLONG;
case mpi::kULongLong:
return xgboost::federated::ULONGLONG;
}
utils::Error("unknown mpi::DataType");
return xgboost::federated::CHAR;
}

/** @brief Transform mpi::OpType to enum to MPI OP */
static xgboost::federated::ReduceOperation GetOp(mpi::OpType op_type) {
switch (op_type) {
case mpi::kMax:
return xgboost::federated::MAX;
case mpi::kMin:
return xgboost::federated::MIN;
case mpi::kSum:
return xgboost::federated::SUM;
case mpi::kBitwiseOR:
utils::Error("Bitwise OR is not supported");
return xgboost::federated::MAX;
}
utils::Error("unknown mpi::OpType");
return xgboost::federated::MAX;
}

void SetParam(std::string const &name, std::string const &val) {
if (!strcasecmp(name.c_str(), "FEDERATED_SERVER_ADDRESS")) {
server_address_ = val;
} else if (!strcasecmp(name.c_str(), "FEDERATED_WORLD_SIZE")) {
world_size_ = std::stoi(val);
} else if (!strcasecmp(name.c_str(), "FEDERATED_RANK")) {
rank_ = std::stoi(val);
} else if (!strcasecmp(name.c_str(), "FEDERATED_SERVER_CERT")) {
server_cert_ = ReadFile(val);
} else if (!strcasecmp(name.c_str(), "FEDERATED_CLIENT_KEY")) {
client_key_ = ReadFile(val);
} else if (!strcasecmp(name.c_str(), "FEDERATED_CLIENT_CERT")) {
client_cert_ = ReadFile(val);
}
}

static std::string ReadFile(std::string const &path) {
auto stream = std::ifstream(path.data());
std::ostringstream out;
out << stream.rdbuf();
return out.str();
}

// clang-format off
std::vector<std::string> const env_vars_{
"FEDERATED_SERVER_ADDRESS",
"FEDERATED_WORLD_SIZE",
"FEDERATED_RANK",
"FEDERATED_SERVER_CERT",
"FEDERATED_CLIENT_KEY",
"FEDERATED_CLIENT_CERT" };
// clang-format on
std::string server_address_{"localhost:9091"};
int world_size_{1};
int rank_{0};
std::string server_cert_{};
std::string client_key_{};
std::string client_cert_{};
std::unique_ptr<xgboost::federated::FederatedClient> client_{};
int version_number_{0};
};

// Process-wide singleton engine used by the free functions below.
FederatedEngine engine;  // NOLINT(cert-err58-cpp)

/*! \brief initialize the synchronization module */
bool Init(int argc, char *argv[]) {
  try {
    engine.Init(argc, argv);
  } catch (std::exception const &e) {
    fprintf(stderr, " failed in federated Init %s\n", e.what());
    return false;
  }
  return true;
}

/*! \brief finalize synchronization module */
bool Finalize() {
  try {
    engine.Finalize();
  } catch (const std::exception &e) {
    fprintf(stderr, "failed in federated shutdown %s\n", e.what());
    return false;
  }
  return true;
}

/*! \brief singleton method to get engine */
IEngine *GetEngine() { return &engine; }
// perform in-place allreduce, on sendrecvbuf
void Allreduce_(void *sendrecvbuf, size_t type_nbytes, size_t count, IEngine::ReduceFunction red,
mpi::DataType dtype, mpi::OpType op, IEngine::PreprocFunction prepare_fun,
void *prepare_arg) {
if (prepare_fun != nullptr) prepare_fun(prepare_arg);
if (engine.GetWorldSize() == 1) return;
engine.Allreduce(sendrecvbuf, type_nbytes * count, dtype, op);
}

// ReduceHandle keeps no engine-specific state here, so the defaults suffice.
ReduceHandle::ReduceHandle() = default;
ReduceHandle::~ReduceHandle() = default;

// Size in bytes of one element described by the MPI datatype wrapper.
int ReduceHandle::TypeSize(const MPI::Datatype &dtype) { return static_cast<int>(dtype.type_size); }

// Register the reduce function. type_nbytes is unused because the element
// size is supplied again on every Allreduce call.
void ReduceHandle::Init(IEngine::ReduceFunction redfunc,
                        __attribute__((unused)) size_t type_nbytes) {
  utils::Assert(redfunc_ == nullptr, "cannot initialize reduce handle twice");
  redfunc_ = redfunc;
}

// Emulate a custom-reducer allreduce on top of Allgather: every worker gathers
// all buffers, then folds the remote buffers into its own sendrecvbuf locally.
void ReduceHandle::Allreduce(void *sendrecvbuf, size_t type_nbytes, size_t count,
                             IEngine::PreprocFunction prepare_fun, void *prepare_arg) {
  utils::Assert(redfunc_ != nullptr, "must initialize handle to call AllReduce");
  if (prepare_fun != nullptr) prepare_fun(prepare_arg);
  if (engine.GetWorldSize() == 1) return;  // single worker: local buffer is already the result

  // Gather all the buffers and call the reduce function locally.
  // NOTE(review): indexing assumes the gathered string concatenates each
  // worker's buffer in rank order with fixed slots of buffer_size bytes —
  // confirm against the federated server's Allgather implementation.
  auto const buffer_size = type_nbytes * count;
  auto const gathered = engine.Allgather(sendrecvbuf, buffer_size);
  auto const *data = gathered.data();
  for (int i = 0; i < engine.GetWorldSize(); i++) {
    // Skip our own slot: sendrecvbuf already holds this worker's contribution.
    if (i != engine.GetRank()) {
      redfunc_(data + buffer_size * i, sendrecvbuf, static_cast<int>(count),
               MPI::Datatype(type_nbytes));
    }
  }
}

} // namespace engine
} // namespace rabit

0 comments on commit 14ef38b

Please sign in to comment.