Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Initial support for federated learning (#7831)
Federated learning plugin for xgboost: * A gRPC server to aggregate MPI-style requests (allgather, allreduce, broadcast) from federated workers. * A Rabit engine for the federated environment. * Integration test to simulate federated learning. Additional followups are needed to address GPU support, better security, and privacy, etc.
- Loading branch information
Showing
16 changed files
with
1,087 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
# gRPC needs to be installed first. See README.md. | ||
find_package(Protobuf REQUIRED) | ||
find_package(gRPC REQUIRED) | ||
find_package(Threads) | ||
|
||
# Generated code from the protobuf definition. | ||
add_library(federated_proto federated.proto) | ||
target_link_libraries(federated_proto PUBLIC protobuf::libprotobuf gRPC::grpc gRPC::grpc++) | ||
target_include_directories(federated_proto PUBLIC ${CMAKE_CURRENT_BINARY_DIR}) | ||
set_property(TARGET federated_proto PROPERTY POSITION_INDEPENDENT_CODE ON) | ||
|
||
get_target_property(grpc_cpp_plugin_location gRPC::grpc_cpp_plugin LOCATION) | ||
protobuf_generate(TARGET federated_proto LANGUAGE cpp) | ||
protobuf_generate( | ||
TARGET federated_proto | ||
LANGUAGE grpc | ||
GENERATE_EXTENSIONS .grpc.pb.h .grpc.pb.cc | ||
PLUGIN "protoc-gen-grpc=${grpc_cpp_plugin_location}") | ||
|
||
# Wrapper for the gRPC client. | ||
add_library(federated_client INTERFACE federated_client.h) | ||
target_link_libraries(federated_client INTERFACE federated_proto) | ||
|
||
# Rabit engine for Federated Learning. | ||
target_sources(objxgboost PRIVATE federated_server.cc engine_federated.cc) | ||
target_link_libraries(objxgboost PRIVATE federated_client) | ||
target_compile_definitions(objxgboost PUBLIC -DXGBOOST_USE_FEDERATED=1) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
XGBoost Plugin for Federated Learning | ||
===================================== | ||
|
||
This folder contains the plugin for federated learning. Follow these steps to build and test it. | ||
|
||
Install gRPC | ||
------------ | ||
```shell | ||
sudo apt-get install build-essential autoconf libtool pkg-config cmake ninja-build | ||
git clone -b v1.45.2 https://github.com/grpc/grpc | ||
cd grpc | ||
git submodule update --init | ||
cmake -S . -B build -GNinja -DABSL_PROPAGATE_CXX_STD=ON | ||
cmake --build build --target install | ||
``` | ||
|
||
Build the Plugin | ||
---------------- | ||
```shell | ||
# Under xgboost source tree. | ||
mkdir build | ||
cd build | ||
cmake .. -GNinja -DPLUGIN_FEDERATED=ON | ||
ninja | ||
cd ../python-package | ||
pip install -e . # or equivalently python setup.py develop | ||
``` | ||
|
||
Test Federated XGBoost | ||
---------------------- | ||
```shell | ||
# Under xgboost source tree. | ||
cd tests/distributed | ||
./runtests-federated.sh | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,274 @@ | ||
/*! | ||
* Copyright 2022 XGBoost contributors | ||
*/ | ||
#include <cstdio> | ||
#include <fstream> | ||
#include <sstream> | ||
|
||
#include "federated_client.h" | ||
#include "rabit/internal/engine.h" | ||
#include "rabit/internal/utils.h" | ||
|
||
namespace MPI { // NOLINT | ||
// MPI data type to be compatible with existing MPI interface | ||
class Datatype { | ||
public: | ||
size_t type_size; | ||
explicit Datatype(size_t type_size) : type_size(type_size) {} | ||
}; | ||
} // namespace MPI | ||
|
||
namespace rabit { | ||
namespace engine { | ||
|
||
/*! \brief implementation of engine using federated learning */ | ||
class FederatedEngine : public IEngine { | ||
public: | ||
void Init(int argc, char *argv[]) { | ||
// Parse environment variables first. | ||
for (auto const &env_var : env_vars_) { | ||
char const *value = getenv(env_var.c_str()); | ||
if (value != nullptr) { | ||
SetParam(env_var, value); | ||
} | ||
} | ||
// Command line argument overrides. | ||
for (int i = 0; i < argc; ++i) { | ||
std::string const key_value = argv[i]; | ||
auto const delimiter = key_value.find('='); | ||
if (delimiter != std::string::npos) { | ||
SetParam(key_value.substr(0, delimiter), key_value.substr(delimiter + 1)); | ||
} | ||
} | ||
utils::Printf("Connecting to federated server %s, world size %d, rank %d", | ||
server_address_.c_str(), world_size_, rank_); | ||
client_.reset(new xgboost::federated::FederatedClient(server_address_, rank_, server_cert_, | ||
client_key_, client_cert_)); | ||
} | ||
|
||
void Finalize() { client_.reset(); } | ||
|
||
void Allgather(void *sendrecvbuf, size_t total_size, size_t slice_begin, size_t slice_end, | ||
size_t size_prev_slice) override { | ||
throw std::logic_error("FederatedEngine:: Allgather is not supported"); | ||
} | ||
|
||
std::string Allgather(void *sendbuf, size_t total_size) { | ||
std::string const send_buffer(reinterpret_cast<char *>(sendbuf), total_size); | ||
return client_->Allgather(send_buffer); | ||
} | ||
|
||
void Allreduce(void *sendrecvbuf, size_t type_nbytes, size_t count, ReduceFunction reducer, | ||
PreprocFunction prepare_fun, void *prepare_arg) override { | ||
throw std::logic_error("FederatedEngine:: Allreduce is not supported, use Allreduce_ instead"); | ||
} | ||
|
||
void Allreduce(void *sendrecvbuf, size_t size, mpi::DataType dtype, mpi::OpType op) { | ||
auto *buffer = reinterpret_cast<char *>(sendrecvbuf); | ||
std::string const send_buffer(buffer, size); | ||
auto const receive_buffer = client_->Allreduce(send_buffer, GetDataType(dtype), GetOp(op)); | ||
receive_buffer.copy(buffer, size); | ||
} | ||
|
||
int GetRingPrevRank() const override { | ||
throw std::logic_error("FederatedEngine:: GetRingPrevRank is not supported"); | ||
} | ||
|
||
void Broadcast(void *sendrecvbuf, size_t size, int root) override { | ||
if (world_size_ == 1) return; | ||
auto *buffer = reinterpret_cast<char *>(sendrecvbuf); | ||
std::string const send_buffer(buffer, size); | ||
auto const receive_buffer = client_->Broadcast(send_buffer, root); | ||
if (rank_ != root) { | ||
receive_buffer.copy(buffer, size); | ||
} | ||
} | ||
|
||
int LoadCheckPoint(Serializable *global_model, Serializable *local_model = nullptr) override { | ||
return 0; | ||
} | ||
|
||
void CheckPoint(const Serializable *global_model, | ||
const Serializable *local_model = nullptr) override { | ||
version_number_ += 1; | ||
} | ||
|
||
void LazyCheckPoint(const Serializable *global_model) override { version_number_ += 1; } | ||
|
||
int VersionNumber() const override { return version_number_; } | ||
|
||
/*! \brief get rank of current node */ | ||
int GetRank() const override { return rank_; } | ||
|
||
/*! \brief get total number of */ | ||
int GetWorldSize() const override { return world_size_; } | ||
|
||
/*! \brief whether it is distributed */ | ||
bool IsDistributed() const override { return true; } | ||
|
||
/*! \brief get the host name of current node */ | ||
std::string GetHost() const override { return "rank" + std::to_string(rank_); } | ||
|
||
void TrackerPrint(const std::string &msg) override { | ||
// simply print information into the tracker | ||
if (GetRank() == 0) { | ||
utils::Printf("%s", msg.c_str()); | ||
} | ||
} | ||
|
||
private: | ||
/** @brief Transform mpi::DataType to xgboost::federated::DataType. */ | ||
static xgboost::federated::DataType GetDataType(mpi::DataType data_type) { | ||
switch (data_type) { | ||
case mpi::kChar: | ||
return xgboost::federated::CHAR; | ||
case mpi::kUChar: | ||
return xgboost::federated::UCHAR; | ||
case mpi::kInt: | ||
return xgboost::federated::INT; | ||
case mpi::kUInt: | ||
return xgboost::federated::UINT; | ||
case mpi::kLong: | ||
return xgboost::federated::LONG; | ||
case mpi::kULong: | ||
return xgboost::federated::ULONG; | ||
case mpi::kFloat: | ||
return xgboost::federated::FLOAT; | ||
case mpi::kDouble: | ||
return xgboost::federated::DOUBLE; | ||
case mpi::kLongLong: | ||
return xgboost::federated::LONGLONG; | ||
case mpi::kULongLong: | ||
return xgboost::federated::ULONGLONG; | ||
} | ||
utils::Error("unknown mpi::DataType"); | ||
return xgboost::federated::CHAR; | ||
} | ||
|
||
/** @brief Transform mpi::OpType to enum to MPI OP */ | ||
static xgboost::federated::ReduceOperation GetOp(mpi::OpType op_type) { | ||
switch (op_type) { | ||
case mpi::kMax: | ||
return xgboost::federated::MAX; | ||
case mpi::kMin: | ||
return xgboost::federated::MIN; | ||
case mpi::kSum: | ||
return xgboost::federated::SUM; | ||
case mpi::kBitwiseOR: | ||
utils::Error("Bitwise OR is not supported"); | ||
return xgboost::federated::MAX; | ||
} | ||
utils::Error("unknown mpi::OpType"); | ||
return xgboost::federated::MAX; | ||
} | ||
|
||
void SetParam(std::string const &name, std::string const &val) { | ||
if (!strcasecmp(name.c_str(), "FEDERATED_SERVER_ADDRESS")) { | ||
server_address_ = val; | ||
} else if (!strcasecmp(name.c_str(), "FEDERATED_WORLD_SIZE")) { | ||
world_size_ = std::stoi(val); | ||
} else if (!strcasecmp(name.c_str(), "FEDERATED_RANK")) { | ||
rank_ = std::stoi(val); | ||
} else if (!strcasecmp(name.c_str(), "FEDERATED_SERVER_CERT")) { | ||
server_cert_ = ReadFile(val); | ||
} else if (!strcasecmp(name.c_str(), "FEDERATED_CLIENT_KEY")) { | ||
client_key_ = ReadFile(val); | ||
} else if (!strcasecmp(name.c_str(), "FEDERATED_CLIENT_CERT")) { | ||
client_cert_ = ReadFile(val); | ||
} | ||
} | ||
|
||
static std::string ReadFile(std::string const &path) { | ||
auto stream = std::ifstream(path.data()); | ||
std::ostringstream out; | ||
out << stream.rdbuf(); | ||
return out.str(); | ||
} | ||
|
||
// clang-format off | ||
std::vector<std::string> const env_vars_{ | ||
"FEDERATED_SERVER_ADDRESS", | ||
"FEDERATED_WORLD_SIZE", | ||
"FEDERATED_RANK", | ||
"FEDERATED_SERVER_CERT", | ||
"FEDERATED_CLIENT_KEY", | ||
"FEDERATED_CLIENT_CERT" }; | ||
// clang-format on | ||
std::string server_address_{"localhost:9091"}; | ||
int world_size_{1}; | ||
int rank_{0}; | ||
std::string server_cert_{}; | ||
std::string client_key_{}; | ||
std::string client_cert_{}; | ||
std::unique_ptr<xgboost::federated::FederatedClient> client_{}; | ||
int version_number_{0}; | ||
}; | ||
|
||
// Singleton federated engine. | ||
FederatedEngine engine; // NOLINT(cert-err58-cpp) | ||
|
||
/*! \brief initialize the synchronization module */ | ||
bool Init(int argc, char *argv[]) { | ||
try { | ||
engine.Init(argc, argv); | ||
return true; | ||
} catch (std::exception const &e) { | ||
fprintf(stderr, " failed in federated Init %s\n", e.what()); | ||
return false; | ||
} | ||
} | ||
|
||
/*! \brief finalize synchronization module */ | ||
bool Finalize() { | ||
try { | ||
engine.Finalize(); | ||
return true; | ||
} catch (const std::exception &e) { | ||
fprintf(stderr, "failed in federated shutdown %s\n", e.what()); | ||
return false; | ||
} | ||
} | ||
|
||
/*! \brief singleton method to get engine */ | ||
IEngine *GetEngine() { return &engine; } | ||
|
||
// perform in-place allreduce, on sendrecvbuf | ||
void Allreduce_(void *sendrecvbuf, size_t type_nbytes, size_t count, IEngine::ReduceFunction red, | ||
mpi::DataType dtype, mpi::OpType op, IEngine::PreprocFunction prepare_fun, | ||
void *prepare_arg) { | ||
if (prepare_fun != nullptr) prepare_fun(prepare_arg); | ||
if (engine.GetWorldSize() == 1) return; | ||
engine.Allreduce(sendrecvbuf, type_nbytes * count, dtype, op); | ||
} | ||
|
||
ReduceHandle::ReduceHandle() = default; | ||
ReduceHandle::~ReduceHandle() = default; | ||
|
||
int ReduceHandle::TypeSize(const MPI::Datatype &dtype) { return static_cast<int>(dtype.type_size); } | ||
|
||
void ReduceHandle::Init(IEngine::ReduceFunction redfunc, | ||
__attribute__((unused)) size_t type_nbytes) { | ||
utils::Assert(redfunc_ == nullptr, "cannot initialize reduce handle twice"); | ||
redfunc_ = redfunc; | ||
} | ||
|
||
void ReduceHandle::Allreduce(void *sendrecvbuf, size_t type_nbytes, size_t count, | ||
IEngine::PreprocFunction prepare_fun, void *prepare_arg) { | ||
utils::Assert(redfunc_ != nullptr, "must initialize handle to call AllReduce"); | ||
if (prepare_fun != nullptr) prepare_fun(prepare_arg); | ||
if (engine.GetWorldSize() == 1) return; | ||
|
||
// Gather all the buffers and call the reduce function locally. | ||
auto const buffer_size = type_nbytes * count; | ||
auto const gathered = engine.Allgather(sendrecvbuf, buffer_size); | ||
auto const *data = gathered.data(); | ||
for (int i = 0; i < engine.GetWorldSize(); i++) { | ||
if (i != engine.GetRank()) { | ||
redfunc_(data + buffer_size * i, sendrecvbuf, static_cast<int>(count), | ||
MPI::Datatype(type_nbytes)); | ||
} | ||
} | ||
} | ||
|
||
} // namespace engine | ||
} // namespace rabit |
Oops, something went wrong.