Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Common interface for collective communication #8057

Merged
merged 58 commits into from Sep 12, 2022
Merged
Show file tree
Hide file tree
Changes from 47 commits
Commits
Show all changes
58 commits
Select commit Hold shift + click to select a range
a2df84d
implement broadcast for federated communicator
rongou Jul 6, 2022
4894334
implement allreduce
rongou Jul 6, 2022
c5fded6
add communicator factory
rongou Jul 11, 2022
f6b7259
Merge remote-tracking branch 'upstream/master' into communicator
rongou Jul 11, 2022
b0831a0
add device adapter
rongou Jul 12, 2022
05b2ceb
add device communicator to factory
rongou Jul 13, 2022
f39d0a9
add rabit communicator
rongou Jul 16, 2022
f9319d7
Merge remote-tracking branch 'upstream/master' into communicator
rongou Jul 16, 2022
0db9fda
Merge remote-tracking branch 'upstream/master' into communicator
rongou Jul 18, 2022
cd0098d
add rabit communicator to the factory
rongou Jul 18, 2022
198ac94
add nccl device communicator
rongou Jul 18, 2022
4585888
Merge remote-tracking branch 'upstream/master' into communicator
rongou Jul 19, 2022
8ae0d7a
add synchronize to device communicator
rongou Jul 19, 2022
33ebb0a
Merge remote-tracking branch 'upstream/master' into communicator
rongou Jul 20, 2022
de9580c
add back print and getprocessorname
rongou Jul 21, 2022
69ea687
add python wrapper and c api
rongou Jul 22, 2022
76b0b0a
Merge remote-tracking branch 'upstream/master' into communicator
rongou Jul 22, 2022
695de5f
clean up types
rongou Jul 22, 2022
e4d0029
fix non-gpu build
rongou Jul 22, 2022
d656387
try to fix ci
rongou Jul 22, 2022
92ae35e
fix std::size_t
rongou Jul 22, 2022
de52150
portable string compare ignore case
rongou Jul 23, 2022
9bc3df6
Merge remote-tracking branch 'upstream/master' into communicator
rongou Aug 15, 2022
4e2a5b8
c style size_t
rongou Aug 15, 2022
131f4b1
fix lint errors
rongou Aug 15, 2022
88510a3
cross platform setenv
rongou Aug 15, 2022
e758449
fix memory leak
rongou Aug 15, 2022
8923609
fix lint errors
rongou Aug 15, 2022
22d4536
Merge remote-tracking branch 'upstream/master' into communicator
rongou Aug 22, 2022
e4dfe18
Merge remote-tracking branch 'upstream/master' into communicator
rongou Aug 23, 2022
3aeae65
address review feedback
rongou Aug 25, 2022
adeab7f
Merge remote-tracking branch 'upstream/master' into communicator
rongou Aug 25, 2022
cfb7496
Merge remote-tracking branch 'upstream/master' into communicator
rongou Aug 26, 2022
183ab75
add python test for rabit communicator
rongou Aug 27, 2022
e3c87e0
fix failing gtest
rongou Aug 27, 2022
fcfb1d5
use json to configure communicators
rongou Aug 30, 2022
b57a1be
fix lint error
rongou Aug 31, 2022
2d52b57
Merge remote-tracking branch 'upstream/master' into communicator
rongou Aug 31, 2022
57dfe8e
Merge remote-tracking branch 'upstream/master' into communicator
rongou Sep 1, 2022
ba5a6e1
get rid of factories
rongou Sep 1, 2022
2039858
fix cpu build
rongou Sep 1, 2022
c3a42fb
fix include
rongou Sep 1, 2022
0a115ff
fix python import
rongou Sep 1, 2022
63ae4e8
don't export collective.py yet
rongou Sep 1, 2022
fe79c38
skip collective communicator pytest on windows
rongou Sep 2, 2022
60da353
Merge remote-tracking branch 'upstream/master' into communicator
rongou Sep 2, 2022
52c0ff3
Merge remote-tracking branch 'upstream/master' into communicator
rongou Sep 2, 2022
56dec26
Merge remote-tracking branch 'upstream/master' into communicator
rongou Sep 6, 2022
ed8fed2
add review feedback
rongou Sep 6, 2022
057bc49
Merge remote-tracking branch 'upstream/master' into communicator
rongou Sep 7, 2022
615e621
Merge remote-tracking branch 'upstream/master' into communicator
rongou Sep 7, 2022
d418563
update documentation
rongou Sep 7, 2022
c4cf82c
remove mpi communicator type
rongou Sep 7, 2022
5f0daa0
fix tests
rongou Sep 7, 2022
7a39ce3
Merge remote-tracking branch 'upstream/master' into communicator
rongou Sep 8, 2022
e20b3db
Merge remote-tracking branch 'upstream/master' into communicator
rongou Sep 9, 2022
cb5f4ad
shutdown the communicator separately
rongou Sep 10, 2022
06809c6
Merge remote-tracking branch 'upstream/master' into communicator
hcho3 Sep 11, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 3 additions & 0 deletions amalgamation/xgboost-all0.cc
Expand Up @@ -71,6 +71,9 @@
#include "../src/logging.cc"
#include "../src/global_config.cc"

// collective
#include "../src/collective/communicator.cc"

// common
#include "../src/common/common.cc"
#include "../src/common/column_matrix.cc"
Expand Down
116 changes: 116 additions & 0 deletions include/xgboost/c_api.h
Expand Up @@ -9,10 +9,12 @@

#ifdef __cplusplus
#define XGB_EXTERN_C extern "C"
#include <cstddef>
#include <cstdio>
#include <cstdint>
#else
#define XGB_EXTERN_C
#include <stddef.h>
#include <stdio.h>
#include <stdint.h>
#endif // __cplusplus
Expand Down Expand Up @@ -1386,4 +1388,118 @@ XGB_DLL int XGBoosterFeatureScore(BoosterHandle handle, const char *json_config,
bst_ulong *out_dim,
bst_ulong const **out_shape,
float const **out_scores);

/*!
* \brief Initialize the collective communicator.
rongou marked this conversation as resolved.
Show resolved Hide resolved
*
* Call this once before using anything.
*
* The additional configuration is not required. Usually the communicator will detect settings
* from environment variables.
*
* \param json_config JSON encoded configuration. Accepted JSON keys are:
* - xgboost_communicator: The type of the communicator. Can be set as an environment variable.
* * rabit: Use Rabit. This is the default if the type is unspecified.
* * mpi: Use MPI.
* * federated: Use the gRPC interface for Federated Learning.
* Only applicable to the Rabit communicator (these are case-sensitive):
* - rabit_tracker_uri: Hostname of the tracker.
* - rabit_tracker_port: Port number of the tracker.
* - rabit_task_id: ID of the current task, can be used to obtain deterministic rank assignment.
* - rabit_world_size: Total number of workers.
* - rabit_hadoop_mode: Enable Hadoop support.
* - rabit_tree_reduce_minsize: Minimal size for tree reduce.
* - rabit_reduce_ring_mincount: Minimal count to perform ring reduce.
* - rabit_reduce_buffer: Size of the reduce buffer.
* - rabit_bootstrap_cache: Size of the bootstrap cache.
* - rabit_debug: Enable debugging.
* - rabit_timeout: Enable timeout.
* - rabit_timeout_sec: Timeout in seconds.
* - rabit_enable_tcp_no_delay: Enable TCP no delay on Unix platforms.
* Only applicable to the Rabit communicator (these are case-sensitive, and can be set as
* environment variables):
* - DMLC_TRACKER_URI: Hostname of the tracker.
* - DMLC_TRACKER_PORT: Port number of the tracker.
* - DMLC_TASK_ID: ID of the current task, can be used to obtain deterministic rank assignment.
* - DMLC_ROLE: Role of the current task, "worker" or "server".
* - DMLC_NUM_ATTEMPT: Number of attempts after task failure.
* - DMLC_WORKER_CONNECT_RETRY: Number of retries to connect to the tracker.
* Only applicable to the Federated communicator (these are not case-sensitive):
* - federated_server_address: Address of the federated server.
* - federated_world_size: Number of federated workers.
* - federated_rank: Rank of the current worker.
* - federated_server_cert: Server certificate file path. Only needed for the SSL mode.
* - federated_client_key: Client key file path. Only needed for the SSL mode.
* - federated_client_cert: Client certificate file path. Only needed for the SSL mode.
*/
XGB_DLL int XGCommunicatorInit(char const* json_config);

/*!
* \brief finalize the collective communicator,
* call this function after you finished all jobs.
* \return true if the communicator is finalized successfully otherwise false
*/
XGB_DLL int XGCommunicatorFinalize(void);

/*!
* \brief get rank of current process
* \return rank number of worker
* */
XGB_DLL int XGCommunicatorGetRank(void);

/*!
* \brief get total number of process
* \return total world size
* */
XGB_DLL int XGCommunicatorGetWorldSize(void);

/*!
* \brief get if the communicator is distributed
* \return if the communicator is distributed
* */
XGB_DLL int XGCommunicatorIsDistributed(void);

/*!
* \brief print the msg to the communicator,
* this function can be used to communicate the information of the progress to
* the user who monitors the communicator
* \param message the message to be printed
*/
XGB_DLL int XGCommunicatorPrint(char const *message);

/*!
* \brief get name of processor
* \param name_str pointer to received returned processor name.
* \return 0 for success, -1 for failure
*/
XGB_DLL int XGCommunicatorGetProcessorName(const char** name_str);

/*!
* \brief broadcast an memory region to all others from root
*
* Example: int a = 1; Broadcast(&a, sizeof(a), root);
* \param send_receive_buffer the pointer to send or receive buffer,
* \param size the size of the data
* \param root the root of process
*/
XGB_DLL int XGCommunicatorBroadcast(void *send_receive_buffer, size_t size, int root);

/*!
* \brief perform in-place allreduce, on sendrecvbuf
* this function is NOT thread-safe
*
* Example Usage: the following code gives sum of the result
* vector<int> data(10);
* ...
* Allreduce<op::Sum>(&data[0], data.size());
* ...
* \param send_receive_buffer buffer for both sending and receiving data
* \param count number of elements to be reduced
* \param enum_dtype the enumeration of data type, see xgboost::collective::DataType in communicator.h
* \param enum_op the enumeration of operation type, see xgboost::collective::Operation in communicator.h
*/
XGB_DLL int XGCommunicatorAllreduce(void *send_receive_buffer, size_t count, int enum_dtype,
int enum_op);


#endif // XGBOOST_C_API_H_
49 changes: 3 additions & 46 deletions plugin/federated/engine_federated.cc
Expand Up @@ -71,7 +71,9 @@ class FederatedEngine : public IEngine {
void Allreduce(void *sendrecvbuf, size_t size, mpi::DataType dtype, mpi::OpType op) {
auto *buffer = reinterpret_cast<char *>(sendrecvbuf);
std::string const send_buffer(buffer, size);
auto const receive_buffer = client_->Allreduce(send_buffer, GetDataType(dtype), GetOp(op));
auto const receive_buffer =
client_->Allreduce(send_buffer, static_cast<xgboost::federated::DataType>(dtype),
static_cast<xgboost::federated::ReduceOperation>(op));
receive_buffer.copy(buffer, size);
}

Expand Down Expand Up @@ -113,51 +115,6 @@ class FederatedEngine : public IEngine {
}

private:
/** @brief Transform mpi::DataType to xgboost::federated::DataType. */
static xgboost::federated::DataType GetDataType(mpi::DataType data_type) {
switch (data_type) {
case mpi::kChar:
return xgboost::federated::CHAR;
case mpi::kUChar:
return xgboost::federated::UCHAR;
case mpi::kInt:
return xgboost::federated::INT;
case mpi::kUInt:
return xgboost::federated::UINT;
case mpi::kLong:
return xgboost::federated::LONG;
case mpi::kULong:
return xgboost::federated::ULONG;
case mpi::kFloat:
return xgboost::federated::FLOAT;
case mpi::kDouble:
return xgboost::federated::DOUBLE;
case mpi::kLongLong:
return xgboost::federated::LONGLONG;
case mpi::kULongLong:
return xgboost::federated::ULONGLONG;
}
utils::Error("unknown mpi::DataType");
return xgboost::federated::CHAR;
}

/** @brief Transform mpi::OpType to enum to MPI OP */
static xgboost::federated::ReduceOperation GetOp(mpi::OpType op_type) {
switch (op_type) {
case mpi::kMax:
return xgboost::federated::MAX;
case mpi::kMin:
return xgboost::federated::MIN;
case mpi::kSum:
return xgboost::federated::SUM;
case mpi::kBitwiseOR:
utils::Error("Bitwise OR is not supported");
return xgboost::federated::MAX;
}
utils::Error("unknown mpi::OpType");
return xgboost::federated::MAX;
}

void SetParam(std::string const &name, std::string const &val) {
if (!strcasecmp(name.c_str(), "FEDERATED_SERVER_ADDRESS")) {
server_address_ = val;
Expand Down
14 changes: 6 additions & 8 deletions plugin/federated/federated.proto
Expand Up @@ -12,16 +12,14 @@ service Federated {
}

enum DataType {
CHAR = 0;
UCHAR = 1;
INT = 2;
UINT = 3;
LONG = 4;
ULONG = 5;
INT8 = 0;
UINT8 = 1;
INT32 = 2;
UINT32 = 3;
INT64 = 4;
UINT64 = 5;
FLOAT = 6;
DOUBLE = 7;
LONGLONG = 8;
ULONGLONG = 9;
}

enum ReduceOperation {
Expand Down
150 changes: 150 additions & 0 deletions plugin/federated/federated_communicator.h
@@ -0,0 +1,150 @@
/*!
* Copyright 2022 XGBoost contributors
*/
#pragma once
#include <xgboost/json.h>

#include "../../src/collective/communicator.h"
#include "../../src/common/io.h"
#include "federated_client.h"

namespace xgboost {
namespace collective {

/**
* @brief A federated learning communicator class that handles collective communication.
*/
class FederatedCommunicator : public Communicator {
public:
static Communicator *Create(Json const &config) {
std::string server_address{};
int world_size{0};
int rank{-1};
std::string server_cert{};
std::string client_key{};
std::string client_cert{};

// Parse environment variables first.
auto *value = getenv("FEDERATED_SERVER_ADDRESS");
if (value != nullptr) {
server_address = value;
}
value = getenv("FEDERATED_WORLD_SIZE");
if (value != nullptr) {
world_size = std::stoi(value);
}
value = getenv("FEDERATED_RANK");
if (value != nullptr) {
rank = std::stoi(value);
}
value = getenv("FEDERATED_SERVER_CERT");
if (value != nullptr) {
server_cert = value;
}
value = getenv("FEDERATED_CLIENT_KEY");
if (value != nullptr) {
client_key = value;
}
value = getenv("FEDERATED_CLIENT_CERT");
if (value != nullptr) {
client_cert = value;
}

// Runtime configuration overrides.
auto const &j_server_address = config["federated_server_address"];
if (IsA<String const>(j_server_address)) {
server_address = get<String const>(j_server_address);
}
auto const &j_world_size = config["federated_world_size"];
if (IsA<Integer const>(j_world_size)) {
world_size = static_cast<int>(get<Integer const>(j_world_size));
}
auto const &j_rank = config["federated_rank"];
if (IsA<Integer const>(j_rank)) {
rank = static_cast<int>(get<Integer const>(j_rank));
}
auto const &j_server_cert = config["federated_server_cert"];
if (IsA<String const>(j_server_cert)) {
server_cert = get<String const>(j_server_cert);
}
auto const &j_client_key = config["federated_client_key"];
if (IsA<String const>(j_client_key)) {
client_key = get<String const>(j_client_key);
}
auto const &j_client_cert = config["federated_client_cert"];
if (IsA<String const>(j_client_cert)) {
client_cert = get<String const>(j_client_cert);
}

if (server_address.empty()) {
LOG(FATAL) << "Federated server address must be set.";
}
if (world_size == 0) {
LOG(FATAL) << "Federated world size must be set.";
}
if (rank == -1) {
LOG(FATAL) << "Federated rank must be set.";
}
return new FederatedCommunicator(world_size, rank, server_address, server_cert, client_key,
client_cert);
}

/**
* @brief Construct a new federated communicator.
*
* @param world_size Total number of processes.
* @param rank Rank of the current process.
*/
FederatedCommunicator(int world_size, int rank, std::string const &server_address,
std::string const &server_cert_path, std::string const &client_key_path,
std::string const &client_cert_path)
: Communicator{world_size, rank} {
if (server_cert_path.empty() || client_key_path.empty() || client_cert_path.empty()) {
client_.reset(new xgboost::federated::FederatedClient(server_address, rank));
} else {
client_.reset(new xgboost::federated::FederatedClient(
server_address, rank, xgboost::common::ReadAll(server_cert_path),
xgboost::common::ReadAll(client_key_path), xgboost::common::ReadAll(client_cert_path)));
}
}

/** @brief Insecure communicator for testing only. */
FederatedCommunicator(int world_size, int rank, std::string const &server_address)
: Communicator{world_size, rank} {
client_.reset(new xgboost::federated::FederatedClient(server_address, rank));
}

~FederatedCommunicator() override { client_.reset(); }

bool IsDistributed() const override { return true; }

void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
Operation op) override {
std::string const send_buffer(reinterpret_cast<char const *>(send_receive_buffer),
count * GetTypeSize(data_type));
auto const received =
client_->Allreduce(send_buffer, static_cast<xgboost::federated::DataType>(data_type),
static_cast<xgboost::federated::ReduceOperation>(op));
received.copy(reinterpret_cast<char *>(send_receive_buffer), count * GetTypeSize(data_type));
}

void Broadcast(void *send_receive_buffer, std::size_t size, int root) override {
if (GetWorldSize() == 1) return;
if (GetRank() == root) {
std::string const send_buffer(reinterpret_cast<char const *>(send_receive_buffer), size);
client_->Broadcast(send_buffer, root);
} else {
auto const received = client_->Broadcast("", root);
received.copy(reinterpret_cast<char *>(send_receive_buffer), size);
}
}

std::string GetProcessorName() override { return "rank" + std::to_string(GetRank()); }

void Print(const std::string &message) override { LOG(CONSOLE) << message; }

private:
std::unique_ptr<xgboost::federated::FederatedClient> client_{};
};
} // namespace collective
} // namespace xgboost