Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Common interface for collective communication #8057

Merged
merged 58 commits into from Sep 12, 2022
Merged
Show file tree
Hide file tree
Changes from 56 commits
Commits
Show all changes
58 commits
Select commit Hold shift + click to select a range
a2df84d
implement broadcast for federated communicator
rongou Jul 6, 2022
4894334
implement allreduce
rongou Jul 6, 2022
c5fded6
add communicator factory
rongou Jul 11, 2022
f6b7259
Merge remote-tracking branch 'upstream/master' into communicator
rongou Jul 11, 2022
b0831a0
add device adapter
rongou Jul 12, 2022
05b2ceb
add device communicator to factory
rongou Jul 13, 2022
f39d0a9
add rabit communicator
rongou Jul 16, 2022
f9319d7
Merge remote-tracking branch 'upstream/master' into communicator
rongou Jul 16, 2022
0db9fda
Merge remote-tracking branch 'upstream/master' into communicator
rongou Jul 18, 2022
cd0098d
add rabit communicator to the factory
rongou Jul 18, 2022
198ac94
add nccl device communicator
rongou Jul 18, 2022
4585888
Merge remote-tracking branch 'upstream/master' into communicator
rongou Jul 19, 2022
8ae0d7a
add synchronize to device communicator
rongou Jul 19, 2022
33ebb0a
Merge remote-tracking branch 'upstream/master' into communicator
rongou Jul 20, 2022
de9580c
add back print and getprocessorname
rongou Jul 21, 2022
69ea687
add python wrapper and c api
rongou Jul 22, 2022
76b0b0a
Merge remote-tracking branch 'upstream/master' into communicator
rongou Jul 22, 2022
695de5f
clean up types
rongou Jul 22, 2022
e4d0029
fix non-gpu build
rongou Jul 22, 2022
d656387
try to fix ci
rongou Jul 22, 2022
92ae35e
fix std::size_t
rongou Jul 22, 2022
de52150
portable string compare ignore case
rongou Jul 23, 2022
9bc3df6
Merge remote-tracking branch 'upstream/master' into communicator
rongou Aug 15, 2022
4e2a5b8
c style size_t
rongou Aug 15, 2022
131f4b1
fix lint errors
rongou Aug 15, 2022
88510a3
cross platform setenv
rongou Aug 15, 2022
e758449
fix memory leak
rongou Aug 15, 2022
8923609
fix lint errors
rongou Aug 15, 2022
22d4536
Merge remote-tracking branch 'upstream/master' into communicator
rongou Aug 22, 2022
e4dfe18
Merge remote-tracking branch 'upstream/master' into communicator
rongou Aug 23, 2022
3aeae65
address review feedback
rongou Aug 25, 2022
adeab7f
Merge remote-tracking branch 'upstream/master' into communicator
rongou Aug 25, 2022
cfb7496
Merge remote-tracking branch 'upstream/master' into communicator
rongou Aug 26, 2022
183ab75
add python test for rabit communicator
rongou Aug 27, 2022
e3c87e0
fix failing gtest
rongou Aug 27, 2022
fcfb1d5
use json to configure communicators
rongou Aug 30, 2022
b57a1be
fix lint error
rongou Aug 31, 2022
2d52b57
Merge remote-tracking branch 'upstream/master' into communicator
rongou Aug 31, 2022
57dfe8e
Merge remote-tracking branch 'upstream/master' into communicator
rongou Sep 1, 2022
ba5a6e1
get rid of factories
rongou Sep 1, 2022
2039858
fix cpu build
rongou Sep 1, 2022
c3a42fb
fix include
rongou Sep 1, 2022
0a115ff
fix python import
rongou Sep 1, 2022
63ae4e8
don't export collective.py yet
rongou Sep 1, 2022
fe79c38
skip collective communicator pytest on windows
rongou Sep 2, 2022
60da353
Merge remote-tracking branch 'upstream/master' into communicator
rongou Sep 2, 2022
52c0ff3
Merge remote-tracking branch 'upstream/master' into communicator
rongou Sep 2, 2022
56dec26
Merge remote-tracking branch 'upstream/master' into communicator
rongou Sep 6, 2022
ed8fed2
add review feedback
rongou Sep 6, 2022
057bc49
Merge remote-tracking branch 'upstream/master' into communicator
rongou Sep 7, 2022
615e621
Merge remote-tracking branch 'upstream/master' into communicator
rongou Sep 7, 2022
d418563
update documentation
rongou Sep 7, 2022
c4cf82c
remove mpi communicator type
rongou Sep 7, 2022
5f0daa0
fix tests
rongou Sep 7, 2022
7a39ce3
Merge remote-tracking branch 'upstream/master' into communicator
rongou Sep 8, 2022
e20b3db
Merge remote-tracking branch 'upstream/master' into communicator
rongou Sep 9, 2022
cb5f4ad
shutdown the communicator separately
rongou Sep 10, 2022
06809c6
Merge remote-tracking branch 'upstream/master' into communicator
hcho3 Sep 11, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 3 additions & 0 deletions amalgamation/xgboost-all0.cc
Expand Up @@ -71,6 +71,9 @@
#include "../src/logging.cc"
#include "../src/global_config.cc"

// collective
#include "../src/collective/communicator.cc"

// common
#include "../src/common/common.cc"
#include "../src/common/column_matrix.cc"
Expand Down
133 changes: 133 additions & 0 deletions include/xgboost/c_api.h
Expand Up @@ -9,10 +9,12 @@

#ifdef __cplusplus
#define XGB_EXTERN_C extern "C"
#include <cstddef>
#include <cstdio>
#include <cstdint>
#else
#define XGB_EXTERN_C
#include <stddef.h>
#include <stdio.h>
#include <stdint.h>
#endif // __cplusplus
Expand Down Expand Up @@ -1386,4 +1388,135 @@ XGB_DLL int XGBoosterFeatureScore(BoosterHandle handle, const char *json_config,
bst_ulong *out_dim,
bst_ulong const **out_shape,
float const **out_scores);

/*!
* \brief Initialize the collective communicator.
*
* Currently the communicator API is experimental, function signatures may change in the future
* without notice.
*
* Call this once before using any other XGCommunicator function.
*
* The additional configuration is not required. Usually the communicator will detect settings
* from environment variables.
*
* \param json_config JSON encoded configuration. Accepted JSON keys are:
* - xgboost_communicator: The type of the communicator. Can be set as an environment variable.
* * rabit: Use Rabit. This is the default if the type is unspecified.
* * mpi: Use MPI.
* * federated: Use the gRPC interface for Federated Learning.
* Only applicable to the Rabit communicator (these are case-sensitive):
* - rabit_tracker_uri: Hostname of the tracker.
* - rabit_tracker_port: Port number of the tracker.
* - rabit_task_id: ID of the current task, can be used to obtain deterministic rank assignment.
* - rabit_world_size: Total number of workers.
* - rabit_hadoop_mode: Enable Hadoop support.
* - rabit_tree_reduce_minsize: Minimal size for tree reduce.
* - rabit_reduce_ring_mincount: Minimal count to perform ring reduce.
* - rabit_reduce_buffer: Size of the reduce buffer.
* - rabit_bootstrap_cache: Size of the bootstrap cache.
* - rabit_debug: Enable debugging.
* - rabit_timeout: Enable timeout.
* - rabit_timeout_sec: Timeout in seconds.
* - rabit_enable_tcp_no_delay: Enable TCP no delay on Unix platforms.
* Only applicable to the Rabit communicator (these are case-sensitive, and can be set as
* environment variables):
* - DMLC_TRACKER_URI: Hostname of the tracker.
* - DMLC_TRACKER_PORT: Port number of the tracker.
* - DMLC_TASK_ID: ID of the current task, can be used to obtain deterministic rank assignment.
* - DMLC_ROLE: Role of the current task, "worker" or "server".
* - DMLC_NUM_ATTEMPT: Number of attempts after task failure.
* - DMLC_WORKER_CONNECT_RETRY: Number of retries to connect to the tracker.
* Only applicable to the Federated communicator (use upper case for environment variables, use
* lower case for runtime configuration):
* - federated_server_address: Address of the federated server.
* - federated_world_size: Number of federated workers.
* - federated_rank: Rank of the current worker.
* - federated_server_cert: Server certificate file path. Only needed for the SSL mode.
* - federated_client_key: Client key file path. Only needed for the SSL mode.
* - federated_client_cert: Client certificate file path. Only needed for the SSL mode.
* \return 0 for success, -1 for failure.
*/
XGB_DLL int XGCommunicatorInit(char const* json_config);

/*!
* \brief Finalize the collective communicator.
*
* Call this function after you have finished all jobs.
*
* \return 0 for success, -1 for failure.
*/
XGB_DLL int XGCommunicatorFinalize(void);

/*!
* \brief Get the rank of the current process.
*
* \return Rank of the worker.
*/
XGB_DLL int XGCommunicatorGetRank(void);

/*!
* \brief Get the total number of processes.
*
* \return Total world size.
*/
XGB_DLL int XGCommunicatorGetWorldSize(void);

/*!
* \brief Check whether the communicator is distributed.
*
* \return Non-zero (true) if the communicator is distributed, 0 (false) otherwise.
*/
XGB_DLL int XGCommunicatorIsDistributed(void);

/*!
* \brief Print the message to the communicator.
*
* This function can be used to report progress information to the user who monitors the
* communicator.
*
* \param message The message to be printed.
* \return 0 for success, -1 for failure.
*/
XGB_DLL int XGCommunicatorPrint(char const *message);

/*!
* \brief Get the name of the processor.
*
* \param name_str Output pointer that receives the processor name.
*        NOTE(review): ownership and lifetime of the returned string are not stated here;
*        confirm before freeing or retaining it.
* \return 0 for success, -1 for failure.
*/
XGB_DLL int XGCommunicatorGetProcessorName(const char** name_str);

/*!
* \brief Broadcast a memory region from root to all others. This function is NOT thread-safe.
*
* Example:
* int a = 1;
* Broadcast(&a, sizeof(a), root);
*
* \param send_receive_buffer Pointer to the send or receive buffer.
* \param size Size of the data in bytes (the example above passes sizeof(a)).
* \param root The process rank to broadcast from.
* \return 0 for success, -1 for failure.
*/
XGB_DLL int XGCommunicatorBroadcast(void *send_receive_buffer, size_t size, int root);

/*!
* \brief Perform in-place allreduce. This function is NOT thread-safe.
*
* Example usage: the following code gives the sum of the results
* vector<int> data(10);
* ...
* Allreduce(&data[0], data.size(), DataType::kInt32, Op::kSum);
* ...
* \param send_receive_buffer Buffer for both sending and receiving data.
* \param count Number of elements to be reduced.
* \param data_type Enumeration of data type, see xgboost::collective::DataType in communicator.h.
* \param op Enumeration of operation type, see xgboost::collective::Operation in communicator.h.
* \return 0 for success, -1 for failure.
*/
XGB_DLL int XGCommunicatorAllreduce(void *send_receive_buffer, size_t count, int data_type, int op);


#endif // XGBOOST_C_API_H_
49 changes: 3 additions & 46 deletions plugin/federated/engine_federated.cc
Expand Up @@ -71,7 +71,9 @@ class FederatedEngine : public IEngine {
/**
 * @brief In-place allreduce routed through the federated client.
 *
 * The first `size` bytes at `sendrecvbuf` are copied into a string, handed to
 * client_->Allreduce together with the data type and reduce operation, and the
 * reply is copied back over the same bytes.
 *
 * @param sendrecvbuf Buffer holding the input; overwritten with the result.
 * @param size        Number of bytes read from, and written back to, sendrecvbuf.
 * @param dtype       Element type, forwarded as xgboost::federated::DataType.
 * @param op          Reduction, forwarded as xgboost::federated::ReduceOperation.
 */
void Allreduce(void *sendrecvbuf, size_t size, mpi::DataType dtype, mpi::OpType op) {
  auto *buffer = reinterpret_cast<char *>(sendrecvbuf);
  std::string const send_buffer(buffer, size);
  // receive_buffer must be declared exactly once in this scope.
  // NOTE(review): the casts assume the numeric values of mpi::DataType and mpi::OpType
  // line up with the federated enums; confirm against federated.proto.
  auto const receive_buffer =
      client_->Allreduce(send_buffer, static_cast<xgboost::federated::DataType>(dtype),
                         static_cast<xgboost::federated::ReduceOperation>(op));
  // Presumably a std::string, in which case copy() writes at most `size` bytes back.
  receive_buffer.copy(buffer, size);
}

Expand Down Expand Up @@ -113,51 +115,6 @@ class FederatedEngine : public IEngine {
}

private:
/**
 * @brief Map an mpi::DataType value onto the xgboost::federated::DataType of the same kind.
 *
 * Values not covered by the switch are reported through utils::Error; the trailing
 * return supplies a value on that path.
 */
static xgboost::federated::DataType GetDataType(mpi::DataType data_type) {
  namespace fed = xgboost::federated;
  switch (data_type) {
    case mpi::kChar:      return fed::CHAR;
    case mpi::kUChar:     return fed::UCHAR;
    case mpi::kInt:       return fed::INT;
    case mpi::kUInt:      return fed::UINT;
    case mpi::kLong:      return fed::LONG;
    case mpi::kULong:     return fed::ULONG;
    case mpi::kFloat:     return fed::FLOAT;
    case mpi::kDouble:    return fed::DOUBLE;
    case mpi::kLongLong:  return fed::LONGLONG;
    case mpi::kULongLong: return fed::ULONGLONG;
  }
  // No case matched: report it, then fall back to CHAR.
  utils::Error("unknown mpi::DataType");
  return fed::CHAR;
}

/**
 * @brief Map an mpi::OpType value onto the matching xgboost::federated::ReduceOperation.
 *
 * Only max, min and sum are mapped. Bitwise OR and any other value are reported
 * through utils::Error, with MAX returned afterwards as the fallback value.
 */
static xgboost::federated::ReduceOperation GetOp(mpi::OpType op_type) {
  namespace fed = xgboost::federated;
  if (op_type == mpi::kMax) {
    return fed::MAX;
  }
  if (op_type == mpi::kMin) {
    return fed::MIN;
  }
  if (op_type == mpi::kSum) {
    return fed::SUM;
  }
  // Everything below is an error path; only the message differs.
  if (op_type == mpi::kBitwiseOR) {
    utils::Error("Bitwise OR is not supported");
  } else {
    utils::Error("unknown mpi::OpType");
  }
  return fed::MAX;
}

void SetParam(std::string const &name, std::string const &val) {
if (!strcasecmp(name.c_str(), "FEDERATED_SERVER_ADDRESS")) {
server_address_ = val;
Expand Down
14 changes: 6 additions & 8 deletions plugin/federated/federated.proto
Expand Up @@ -12,16 +12,14 @@ service Federated {
}

// Data type identifiers; engine_federated.cc passes a value of this enum to the
// client's Allreduce call.
//
// NOTE(review): the C-style names (CHAR..ULONG) and the fixed-width names
// (INT8..UINT64) below carry the same tag numbers 0-5. Protobuf rejects duplicate
// enum values unless `option allow_alias = true;` is set; confirm which set of
// names is meant to remain, and whether LONGLONG/ULONGLONG (8, 9) are still needed.
enum DataType {
CHAR = 0;
UCHAR = 1;
INT = 2;
UINT = 3;
LONG = 4;
ULONG = 5;
INT8 = 0;
UINT8 = 1;
INT32 = 2;
UINT32 = 3;
INT64 = 4;
UINT64 = 5;
FLOAT = 6;
DOUBLE = 7;
LONGLONG = 8;
ULONGLONG = 9;
}

enum ReduceOperation {
Expand Down