Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Common interface for collective communication #8057

Merged
merged 58 commits into from Sep 12, 2022
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
58 commits
Select commit Hold shift + click to select a range
a2df84d
implement broadcast for federated communicator
rongou Jul 6, 2022
4894334
implement allreduce
rongou Jul 6, 2022
c5fded6
add communicator factory
rongou Jul 11, 2022
f6b7259
Merge remote-tracking branch 'upstream/master' into communicator
rongou Jul 11, 2022
b0831a0
add device adapter
rongou Jul 12, 2022
05b2ceb
add device communicator to factory
rongou Jul 13, 2022
f39d0a9
add rabit communicator
rongou Jul 16, 2022
f9319d7
Merge remote-tracking branch 'upstream/master' into communicator
rongou Jul 16, 2022
0db9fda
Merge remote-tracking branch 'upstream/master' into communicator
rongou Jul 18, 2022
cd0098d
add rabit communicator to the factory
rongou Jul 18, 2022
198ac94
add nccl device communicator
rongou Jul 18, 2022
4585888
Merge remote-tracking branch 'upstream/master' into communicator
rongou Jul 19, 2022
8ae0d7a
add synchronize to device communicator
rongou Jul 19, 2022
33ebb0a
Merge remote-tracking branch 'upstream/master' into communicator
rongou Jul 20, 2022
de9580c
add back print and getprocessorname
rongou Jul 21, 2022
69ea687
add python wrapper and c api
rongou Jul 22, 2022
76b0b0a
Merge remote-tracking branch 'upstream/master' into communicator
rongou Jul 22, 2022
695de5f
clean up types
rongou Jul 22, 2022
e4d0029
fix non-gpu build
rongou Jul 22, 2022
d656387
try to fix ci
rongou Jul 22, 2022
92ae35e
fix std::size_t
rongou Jul 22, 2022
de52150
portable string compare ignore case
rongou Jul 23, 2022
9bc3df6
Merge remote-tracking branch 'upstream/master' into communicator
rongou Aug 15, 2022
4e2a5b8
c style size_t
rongou Aug 15, 2022
131f4b1
fix lint errors
rongou Aug 15, 2022
88510a3
cross platform setenv
rongou Aug 15, 2022
e758449
fix memory leak
rongou Aug 15, 2022
8923609
fix lint errors
rongou Aug 15, 2022
22d4536
Merge remote-tracking branch 'upstream/master' into communicator
rongou Aug 22, 2022
e4dfe18
Merge remote-tracking branch 'upstream/master' into communicator
rongou Aug 23, 2022
3aeae65
address review feedback
rongou Aug 25, 2022
adeab7f
Merge remote-tracking branch 'upstream/master' into communicator
rongou Aug 25, 2022
cfb7496
Merge remote-tracking branch 'upstream/master' into communicator
rongou Aug 26, 2022
183ab75
add python test for rabit communicator
rongou Aug 27, 2022
e3c87e0
fix failing gtest
rongou Aug 27, 2022
fcfb1d5
use json to configure communicators
rongou Aug 30, 2022
b57a1be
fix lint error
rongou Aug 31, 2022
2d52b57
Merge remote-tracking branch 'upstream/master' into communicator
rongou Aug 31, 2022
57dfe8e
Merge remote-tracking branch 'upstream/master' into communicator
rongou Sep 1, 2022
ba5a6e1
get rid of factories
rongou Sep 1, 2022
2039858
fix cpu build
rongou Sep 1, 2022
c3a42fb
fix include
rongou Sep 1, 2022
0a115ff
fix python import
rongou Sep 1, 2022
63ae4e8
don't export collective.py yet
rongou Sep 1, 2022
fe79c38
skip collective communicator pytest on windows
rongou Sep 2, 2022
60da353
Merge remote-tracking branch 'upstream/master' into communicator
rongou Sep 2, 2022
52c0ff3
Merge remote-tracking branch 'upstream/master' into communicator
rongou Sep 2, 2022
56dec26
Merge remote-tracking branch 'upstream/master' into communicator
rongou Sep 6, 2022
ed8fed2
add review feedback
rongou Sep 6, 2022
057bc49
Merge remote-tracking branch 'upstream/master' into communicator
rongou Sep 7, 2022
615e621
Merge remote-tracking branch 'upstream/master' into communicator
rongou Sep 7, 2022
d418563
update documentation
rongou Sep 7, 2022
c4cf82c
remove mpi communicator type
rongou Sep 7, 2022
5f0daa0
fix tests
rongou Sep 7, 2022
7a39ce3
Merge remote-tracking branch 'upstream/master' into communicator
rongou Sep 8, 2022
e20b3db
Merge remote-tracking branch 'upstream/master' into communicator
rongou Sep 9, 2022
cb5f4ad
shutdown the communicator separately
rongou Sep 10, 2022
06809c6
Merge remote-tracking branch 'upstream/master' into communicator
hcho3 Sep 11, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
180 changes: 180 additions & 0 deletions plugin/federated/federated_communicator.h
@@ -0,0 +1,180 @@
/*!
* Copyright 2022 XGBoost contributors
*/
#pragma once
#include <cstddef>
#include <cstdlib>
#include <exception>
#include <memory>
#include <string>
#include <vector>

#include "../../src/collective/communicator.h"
#include "../../src/common/io.h"
#include "federated_client.h"

namespace xgboost {
namespace collective {

/**
 * @brief A federated learning communicator class that handles collective communication.
 *
 * Every collective operation is forwarded to a `xgboost::federated::FederatedClient`, which is
 * created in the constructor and destroyed together with the communicator.
 */
class FederatedCommunicator : public Communicator {
 public:
  /**
   * @brief Construct a new federated communicator from certificate/key files.
   *
   * The three files are read in full and their contents (not the paths) are handed to the client.
   *
   * @param world_size       Total number of processes.
   * @param rank             Rank of the current process.
   * @param server_address   Address of the federated server.
   * @param server_cert_path Path to the server certificate file.
   * @param client_key_path  Path to the client key file.
   * @param client_cert_path Path to the client certificate file.
   */
  FederatedCommunicator(int world_size, int rank, std::string const &server_address,
                        std::string const &server_cert_path, std::string const &client_key_path,
                        std::string const &client_cert_path)
      : Communicator{world_size, rank} {
    client_.reset(new xgboost::federated::FederatedClient(
        server_address, rank, xgboost::common::ReadAll(server_cert_path),
        xgboost::common::ReadAll(client_key_path), xgboost::common::ReadAll(client_cert_path)));
  }

  /**
   * @brief Insecure communicator for testing only.
   *
   * @param world_size     Total number of processes.
   * @param rank           Rank of the current process.
   * @param server_address Address of the federated server.
   */
  FederatedCommunicator(int world_size, int rank, std::string const &server_address)
      : Communicator{world_size, rank} {
    client_.reset(new xgboost::federated::FederatedClient(server_address, rank));
  }

  ~FederatedCommunicator() override { client_.reset(); }

  /**
   * @brief Reduce the buffer across all processes through the federated client, in place.
   *
   * @param send_receive_buffer Buffer holding the local values; overwritten with the reply.
   * @param count               Number of elements in the buffer.
   * @param data_type           Element type stored in the buffer.
   * @param op                  Reduction operation to perform.
   */
  void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
                 Operation op) override {
    std::size_t const num_bytes = count * GetTypeSize(data_type);
    std::string const send_buffer(reinterpret_cast<char const *>(send_receive_buffer), num_bytes);
    auto const received =
        client_->Allreduce(send_buffer, ConvertDataType(data_type), ConvertOperation(op));
    // std::string::copy writes at most `num_bytes` bytes, so a longer reply cannot overrun the
    // caller's buffer.
    received.copy(reinterpret_cast<char *>(send_receive_buffer), num_bytes);
  }

  /**
   * @brief Broadcast `size` bytes from the process with rank `root` to every other process.
   *
   * Does nothing when the world size is 1. The root sends its buffer; every other rank sends an
   * empty payload and copies the reply into its buffer.
   *
   * @param send_receive_buffer Buffer storing the data.
   * @param size                Size of the data in bytes.
   * @param root                Rank of broadcast root.
   */
  void Broadcast(void *send_receive_buffer, std::size_t size, int root) override {
    if (GetWorldSize() == 1) return;
    if (GetRank() == root) {
      std::string const send_buffer(reinterpret_cast<char const *>(send_receive_buffer), size);
      client_->Broadcast(send_buffer, root);
    } else {
      auto const received = client_->Broadcast("", root);
      received.copy(reinterpret_cast<char *>(send_receive_buffer), size);
    }
  }

 private:
  /** @brief Map a collective data type onto the federated one. */
  static xgboost::federated::DataType ConvertDataType(DataType data_type) {
    xgboost::federated::DataType result{};
    switch (data_type) {
      case DataType::kInt:
        result = xgboost::federated::DataType::INT;
        break;
      case DataType::kFloat:
        result = xgboost::federated::DataType::FLOAT;
        break;
      case DataType::kDouble:
        result = xgboost::federated::DataType::DOUBLE;
        break;
      case DataType::kSizeT:
        // No federated type is mapped for std::size_t. Fail loudly instead of silently sending
        // the data tagged with the value-initialized federated type.
        LOG(FATAL) << "Data type std::size_t is not supported by the federated communicator.";
        break;
    }
    return result;
  }

  /** @brief Map a collective reduction operation onto the federated one. */
  static xgboost::federated::ReduceOperation ConvertOperation(Operation operation) {
    xgboost::federated::ReduceOperation result{};
    switch (operation) {
      case Operation::kMax:
        result = xgboost::federated::ReduceOperation::MAX;
        break;
      case Operation::kSum:
        result = xgboost::federated::ReduceOperation::SUM;
        break;
    }
    return result;
  }

  /** Client used for all communication with the federated server. */
  std::unique_ptr<xgboost::federated::FederatedClient> client_{};
};

class FederatedCommunicatorFactory {
public:
FederatedCommunicatorFactory(int argc, char *argv[]) {
// Parse environment variables first.
for (auto const &env_var : env_vars_) {
char const *value = getenv(env_var.c_str());
if (value != nullptr) {
SetParam(env_var, value);
}
}

// Command line argument overrides.
for (int i = 0; i < argc; ++i) {
std::string const key_value = argv[i];
auto const delimiter = key_value.find('=');
if (delimiter != std::string::npos) {
SetParam(key_value.substr(0, delimiter), key_value.substr(delimiter + 1));
}
}
}

Communicator *Create() {
if (server_address_.empty()) {
LOG(FATAL) << "Federated server address must be set.";
}
if (world_size_ == 0) {
LOG(FATAL) << "Federated world size must be set.";
}
if (rank_ == -1) {
LOG(FATAL) << "Federated rank must be set.";
}
if (server_cert_.empty()) {
LOG(FATAL) << "Federated server cert must be set.";
}
if (client_key_.empty()) {
LOG(FATAL) << "Federated client key must be set.";
}
if (client_cert_.empty()) {
LOG(FATAL) << "Federated client cert must be set.";
}
return new FederatedCommunicator(world_size_, rank_, server_address_, server_cert_, client_key_,
client_cert_);
}

std::string const &GetServerAddress() const { return server_address_; }
int GetWorldSize() const { return world_size_; }
int GetRank() const { return rank_; }
std::string const &GetServerCert() const { return server_cert_; }
std::string const &GetClientKey() const { return client_key_; }
std::string const &GetClientCert() const { return client_cert_; }

private:
void SetParam(std::string const &name, std::string const &val) {
if (!strcasecmp(name.c_str(), "FEDERATED_SERVER_ADDRESS")) {
server_address_ = val;
} else if (!strcasecmp(name.c_str(), "FEDERATED_WORLD_SIZE")) {
world_size_ = std::stoi(val);
} else if (!strcasecmp(name.c_str(), "FEDERATED_RANK")) {
rank_ = std::stoi(val);
} else if (!strcasecmp(name.c_str(), "FEDERATED_SERVER_CERT")) {
server_cert_ = val;
} else if (!strcasecmp(name.c_str(), "FEDERATED_CLIENT_KEY")) {
client_key_ = val;
} else if (!strcasecmp(name.c_str(), "FEDERATED_CLIENT_CERT")) {
client_cert_ = val;
}
}

// clang-format off
std::vector<std::string> const env_vars_{
"FEDERATED_SERVER_ADDRESS",
"FEDERATED_WORLD_SIZE",
"FEDERATED_RANK",
"FEDERATED_SERVER_CERT",
"FEDERATED_CLIENT_KEY",
"FEDERATED_CLIENT_CERT" };
// clang-format on

std::string server_address_{};
int world_size_{0};
int rank_{-1};
std::string server_cert_{};
std::string client_key_{};
std::string client_cert_{};
};

} // namespace collective
} // namespace xgboost
15 changes: 5 additions & 10 deletions plugin/federated/federated_server.cc
Expand Up @@ -10,6 +10,8 @@
#include <fstream>
#include <sstream>

#include "../../src/common/io.h"

namespace xgboost {
namespace federated {

Expand Down Expand Up @@ -201,13 +203,6 @@ grpc::Status FederatedService::Handle(Request const* request, Reply* reply,
return grpc::Status::OK;
}

/**
 * @brief Read the entire content of the file at `path` into a string.
 *
 * @param path Path of the file to read.
 * @return The file content; empty when nothing could be read from the file.
 */
std::string ReadFile(char const* path) {
  std::ifstream input{path};
  std::ostringstream contents;
  contents << input.rdbuf();
  return contents.str();
}

void RunServer(int port, int world_size, char const* server_key_file, char const* server_cert_file,
char const* client_cert_file) {
std::string const server_address = "0.0.0.0:" + std::to_string(port);
Expand All @@ -216,10 +211,10 @@ void RunServer(int port, int world_size, char const* server_key_file, char const
grpc::ServerBuilder builder;
auto options =
grpc::SslServerCredentialsOptions(GRPC_SSL_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY);
options.pem_root_certs = ReadFile(client_cert_file);
options.pem_root_certs = xgboost::common::ReadAll(client_cert_file);
auto key = grpc::SslServerCredentialsOptions::PemKeyCertPair();
key.private_key = ReadFile(server_key_file);
key.cert_chain = ReadFile(server_cert_file);
key.private_key = xgboost::common::ReadAll(server_key_file);
key.cert_chain = xgboost::common::ReadAll(server_cert_file);
options.pem_key_cert_pairs.push_back(key);
builder.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
builder.AddListeningPort(server_address, grpc::SslServerCredentials(options));
Expand Down
96 changes: 96 additions & 0 deletions src/collective/communicator.h
@@ -0,0 +1,96 @@
/*!
* Copyright 2022 XGBoost contributors
*/
#pragma once
#include <xgboost/logging.h>

namespace xgboost {
namespace collective {

/** @brief The integral and floating point element types a communication buffer can hold. */
enum class DataType { kInt, kFloat, kDouble, kSizeT };

/**
 * @brief Get the size in bytes of a single element of the given data type.
 *
 * @param data_type The element type to look up.
 * @return `sizeof` the corresponding C++ type, or 0 if `data_type` is not one of the enumerators.
 */
inline std::size_t GetTypeSize(DataType data_type) {
  switch (data_type) {
    case DataType::kInt:
      return sizeof(int);
    case DataType::kFloat:
      return sizeof(float);
    case DataType::kDouble:
      return sizeof(double);
    case DataType::kSizeT:
      return sizeof(std::size_t);
  }
  // Only reachable for a value outside the enumerator list.
  return 0;
}

/**
 * @brief Defines the reduction operation passed to Communicator::AllReduce.
 *
 * Going by the enumerator names, kMax selects a maximum and kSum a sum; how each is carried out
 * is up to the communicator implementation.
 */
enum class Operation { kMax, kSum };

/**
 * @brief Abstract base class for collective communication.
 *
 * Stores the immutable world size and rank of the current process and declares the collective
 * operations that concrete communicators implement.
 */
class Communicator {
 public:
  /**
   * @brief Construct a new communicator, validating that `0 <= rank < world_size`.
   *
   * Any violation is reported through LOG(FATAL).
   *
   * @param world_size Total number of processes.
   * @param rank       Rank of the current process.
   */
  Communicator(int world_size, int rank) : world_size_{world_size}, rank_{rank} {
    if (world_size < 1) {
      LOG(FATAL) << "World size " << world_size << " is less than 1.";
    }
    if (rank < 0) {
      LOG(FATAL) << "Rank " << rank << " is less than 0.";
    }
    if (!(rank < world_size)) {
      LOG(FATAL) << "Rank " << rank << " is greater than world_size - 1: " << world_size - 1 << ".";
    }
  }

  virtual ~Communicator() = default;

  /** @brief Total number of processes taking part in the communication. */
  int GetWorldSize() const { return world_size_; }

  /** @brief Rank of the current process. */
  int GetRank() const { return rank_; }

  /** @brief True when more than one process takes part, i.e. running in distributed mode. */
  bool IsDistributed() const { return GetWorldSize() > 1; }

  /**
   * @brief Combines values from all processes and distributes the result back to all processes.
   *
   * @param send_receive_buffer Buffer storing the data.
   * @param count               Number of elements in the buffer.
   * @param data_type           Data type stored in the buffer.
   * @param op                  The operation to perform.
   */
  virtual void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
                         Operation op) = 0;

  /**
   * @brief Broadcasts a message from the process with rank `root` to all other processes of the
   * group.
   *
   * @param send_receive_buffer Buffer storing the data.
   * @param size                Size of the data in bytes.
   * @param root                Rank of broadcast root.
   */
  virtual void Broadcast(void *send_receive_buffer, std::size_t size, int root) = 0;

 private:
  int const world_size_;
  int const rank_;
};

} // namespace collective
} // namespace xgboost
68 changes: 68 additions & 0 deletions src/collective/communicator_factory.cu
@@ -0,0 +1,68 @@
/*!
* Copyright 2022 XGBoost contributors
*/
#include "communicator_factory.h"
#include "device_communicator_adapter.cuh"

namespace xgboost {
namespace collective {

// Definition of the static factory instance. It is thread_local, so each thread has its own
// instance, which starts out empty until Init() populates it.
thread_local std::unique_ptr<CommunicatorFactory> CommunicatorFactory::instance_{};

/**
 * @brief Create the factory instance for the calling thread.
 *
 * The communicator type is taken from the environment and, when the arguments name a known type,
 * overridden by them. Only the federated type is implemented here; every other type, as well as
 * a second initialization, is reported through LOG(FATAL).
 *
 * @param argc Number of entries in `argv`.
 * @param argv Arguments forwarded to the type lookup and to the federated factory.
 */
void CommunicatorFactory::Init(int argc, char* argv[]) {
  if (instance_) {
    LOG(FATAL) << "Communicator factory can only be initialized once.";
  }

  auto const env_type = GetTypeFromEnv();
  auto const arg_type = GetTypeFromArgs(argc, argv);
  // A type given in the arguments wins over the one from the environment.
  auto const type = (arg_type == CommunicatorType::kUnknown) ? env_type : arg_type;

  switch (type) {
    case CommunicatorType::kRabit:
    case CommunicatorType::kMPI:
      LOG(FATAL) << "Not implemented yet.";
      break;
    case CommunicatorType::kFederated: {
#if defined(XGBOOST_USE_FEDERATED)
      FederatedCommunicatorFactory federated_factory{argc, argv};
      auto* communicator = federated_factory.Create();
      instance_.reset(new CommunicatorFactory(type, communicator));
#else
      LOG(FATAL) << "XGBoost is not compiled with Federated Learning support.";
#endif
      break;
    }
    case CommunicatorType::kUnknown:
      LOG(FATAL) << "Unknown communicator type.";
      break;
  }
}

/** @brief Destroy the calling thread's factory instance, if there is one. */
void CommunicatorFactory::Finalize() {
  instance_ = nullptr;
}

/**
 * @brief Construct a factory around an already created communicator.
 *
 * @param type         The type of the communicator.
 * @param communicator The communicator to store in `communicator_`.
 */
CommunicatorFactory::CommunicatorFactory(CommunicatorType type, Communicator* communicator)
    : type_(type), communicator_(communicator) {}

/**
 * @brief Get the device communicator, creating it on the first call.
 *
 * The communicator is cached: later calls return the existing object and do not look at
 * `device_ordinal` again.
 *
 * @param device_ordinal Device ordinal passed to the adapter when it is created.
 * @return Pointer to the cached device communicator; ownership stays with the factory.
 */
DeviceCommunicator* CommunicatorFactory::GetDeviceCommunicator(int device_ordinal) {
  if (device_communicator_) {
    return device_communicator_.get();
  }

#ifdef XGBOOST_USE_NCCL
  if (type_ == CommunicatorType::kFederated) {
    device_communicator_.reset(
        new DeviceCommunicatorAdapter(device_ordinal, communicator_.get()));
  } else {
    // Use NCCL communicator.
    LOG(FATAL) << "Not implemented yet.";
  }
#else
  device_communicator_.reset(new DeviceCommunicatorAdapter(device_ordinal, communicator_.get()));
#endif
  return device_communicator_.get();
}

} // namespace collective
} // namespace xgboost