Skip to content

Commit

Permalink
Common interface for collective communication (#8057)
Browse files Browse the repository at this point in the history
* implement broadcast for federated communicator

* implement allreduce

* add communicator factory

* add device adapter

* add device communicator to factory

* add rabit communicator

* add rabit communicator to the factory

* add nccl device communicator

* add synchronize to device communicator

* add back print and getprocessorname

* add python wrapper and c api

* clean up types

* fix non-gpu build

* try to fix ci

* fix std::size_t

* portable string compare ignore case

* c style size_t

* fix lint errors

* cross platform setenv

* fix memory leak

* fix lint errors

* address review feedback

* add python test for rabit communicator

* fix failing gtest

* use json to configure communicators

* fix lint error

* get rid of factories

* fix cpu build

* fix include

* fix python import

* don't export collective.py yet

* skip collective communicator pytest on windows

* add review feedback

* update documentation

* remove mpi communicator type

* fix tests

* shutdown the communicator separately

Co-authored-by: Hyunsu Cho <chohyu01@cs.washington.edu>
  • Loading branch information
rongou and hcho3 committed Sep 12, 2022
1 parent bc81831 commit a268654
Show file tree
Hide file tree
Showing 25 changed files with 1,771 additions and 95 deletions.
3 changes: 3 additions & 0 deletions amalgamation/xgboost-all0.cc
Expand Up @@ -71,6 +71,9 @@
#include "../src/logging.cc"
#include "../src/global_config.cc"

// collective
#include "../src/collective/communicator.cc"

// common
#include "../src/common/common.cc"
#include "../src/common/column_matrix.cc"
Expand Down
133 changes: 133 additions & 0 deletions include/xgboost/c_api.h
Expand Up @@ -9,10 +9,12 @@

#ifdef __cplusplus
#define XGB_EXTERN_C extern "C"
#include <cstddef>
#include <cstdio>
#include <cstdint>
#else
#define XGB_EXTERN_C
#include <stddef.h>
#include <stdio.h>
#include <stdint.h>
#endif // __cplusplus
Expand Down Expand Up @@ -1386,4 +1388,135 @@ XGB_DLL int XGBoosterFeatureScore(BoosterHandle handle, const char *json_config,
bst_ulong *out_dim,
bst_ulong const **out_shape,
float const **out_scores);

/*!
* \brief Initialize the collective communicator.
*
* Currently the communicator API is experimental, function signatures may change in the future
* without notice.
*
* Call this once before using anything.
*
* The additional configuration is not required. Usually the communicator will detect settings
* from environment variables.
*
* \param json_config JSON encoded configuration. Accepted JSON keys are:
* - xgboost_communicator: The type of the communicator. Can be set as an environment variable.
* * rabit: Use Rabit. This is the default if the type is unspecified.
* * mpi: Use MPI.
* * federated: Use the gRPC interface for Federated Learning.
* Only applicable to the Rabit communicator (these are case-sensitive):
* - rabit_tracker_uri: Hostname of the tracker.
* - rabit_tracker_port: Port number of the tracker.
* - rabit_task_id: ID of the current task, can be used to obtain deterministic rank assignment.
* - rabit_world_size: Total number of workers.
* - rabit_hadoop_mode: Enable Hadoop support.
* - rabit_tree_reduce_minsize: Minimal size for tree reduce.
* - rabit_reduce_ring_mincount: Minimal count to perform ring reduce.
* - rabit_reduce_buffer: Size of the reduce buffer.
* - rabit_bootstrap_cache: Size of the bootstrap cache.
* - rabit_debug: Enable debugging.
* - rabit_timeout: Enable timeout.
* - rabit_timeout_sec: Timeout in seconds.
* - rabit_enable_tcp_no_delay: Enable TCP no delay on Unix platforms.
* Only applicable to the Rabit communicator (these are case-sensitive, and can be set as
* environment variables):
* - DMLC_TRACKER_URI: Hostname of the tracker.
* - DMLC_TRACKER_PORT: Port number of the tracker.
* - DMLC_TASK_ID: ID of the current task, can be used to obtain deterministic rank assignment.
* - DMLC_ROLE: Role of the current task, "worker" or "server".
* - DMLC_NUM_ATTEMPT: Number of attempts after task failure.
* - DMLC_WORKER_CONNECT_RETRY: Number of retries to connect to the tracker.
* Only applicable to the Federated communicator (use upper case for environment variables, use
* lower case for runtime configuration):
* - federated_server_address: Address of the federated server.
* - federated_world_size: Number of federated workers.
* - federated_rank: Rank of the current worker.
* - federated_server_cert: Server certificate file path. Only needed for the SSL mode.
* - federated_client_key: Client key file path. Only needed for the SSL mode.
* - federated_client_cert: Client certificate file path. Only needed for the SSL mode.
* \return 0 for success, -1 for failure.
*/
XGB_DLL int XGCommunicatorInit(char const* json_config);

/*!
* \brief Finalize the collective communicator.
*
* Call this function after you have finished all jobs.
*
* \return 0 for success, -1 for failure.
*/
XGB_DLL int XGCommunicatorFinalize(void);

/*!
* \brief Get rank of current process.
*
* \return Rank of the worker.
*/
XGB_DLL int XGCommunicatorGetRank(void);

/*!
* \brief Get total number of processes.
*
* \return Total world size.
*/
XGB_DLL int XGCommunicatorGetWorldSize(void);

/*!
* \brief Check whether the communicator is distributed.
*
* \return Non-zero if the communicator is distributed, 0 otherwise.
*/
XGB_DLL int XGCommunicatorIsDistributed(void);

/*!
* \brief Print the message to the communicator.
*
* This function can be used to communicate the information of the progress to the user who monitors
* the communicator.
*
* \param message The message to be printed.
* \return 0 for success, -1 for failure.
*/
XGB_DLL int XGCommunicatorPrint(char const *message);

/*!
* \brief Get the name of the processor.
*
* \param name_str Pointer that receives the processor name.
* \return 0 for success, -1 for failure.
*/
XGB_DLL int XGCommunicatorGetProcessorName(const char** name_str);

/*!
* \brief Broadcast a memory region to all others from root. This function is NOT thread-safe.
*
* Example:
* int a = 1;
* Broadcast(&a, sizeof(a), root);
*
* \param send_receive_buffer Pointer to the send or receive buffer.
* \param size Size of the data.
* \param root The process rank to broadcast from.
* \return 0 for success, -1 for failure.
*/
XGB_DLL int XGCommunicatorBroadcast(void *send_receive_buffer, size_t size, int root);

/*!
* \brief Perform in-place allreduce. This function is NOT thread-safe.
*
* Example usage: the following code computes the sum of the data:
* vector<int> data(10);
* ...
* Allreduce(&data[0], data.size(), DataType::kInt32, Operation::kSum);
* ...
* \param send_receive_buffer Buffer for both sending and receiving data.
* \param count Number of elements to be reduced.
* \param data_type Enumeration of data type, see xgboost::collective::DataType in communicator.h.
* \param op Enumeration of operation type, see xgboost::collective::Operation in communicator.h.
* \return 0 for success, -1 for failure.
*/
XGB_DLL int XGCommunicatorAllreduce(void *send_receive_buffer, size_t count, int data_type, int op);


#endif // XGBOOST_C_API_H_
49 changes: 3 additions & 46 deletions plugin/federated/engine_federated.cc
Expand Up @@ -71,7 +71,9 @@ class FederatedEngine : public IEngine {
/**
 * @brief In-place allreduce routed through the federated client.
 *
 * The first `size` bytes at `sendrecvbuf` are copied into a string, handed to
 * client_->Allreduce together with the data type and reduce operation, and the
 * reply is copied back into the same buffer.
 *
 * NOTE(review): this span comes from a diff view whose +/- markers were lost, so
 * it shows both the earlier call (via GetDataType/GetOp) and its replacement
 * (via static_cast). Confirm against the repository which one is live.
 *
 * @param sendrecvbuf Buffer holding the input; it is overwritten with the result.
 * @param size Number of bytes taken from sendrecvbuf; the same value is passed as
 *             the count when copying the reply back.
 * @param dtype Element type of the buffer contents.
 * @param op Reduction to apply.
 */
void Allreduce(void *sendrecvbuf, size_t size, mpi::DataType dtype, mpi::OpType op) {
auto *buffer = reinterpret_cast<char *>(sendrecvbuf);
std::string const send_buffer(buffer, size);
// Earlier form: the mpi enums are translated by the GetDataType/GetOp helpers.
auto const receive_buffer = client_->Allreduce(send_buffer, GetDataType(dtype), GetOp(op));
// Later form: the mpi enums are cast directly to the federated enums.
// NOTE(review): this presumably relies on both enum families sharing numeric values — confirm.
auto const receive_buffer =
client_->Allreduce(send_buffer, static_cast<xgboost::federated::DataType>(dtype),
static_cast<xgboost::federated::ReduceOperation>(op));
// Write the reply back over the caller's buffer.
receive_buffer.copy(buffer, size);
}

Expand Down Expand Up @@ -113,51 +115,6 @@ class FederatedEngine : public IEngine {
}

private:
/**
 * @brief Transform mpi::DataType to xgboost::federated::DataType.
 *
 * Each handled mpi enumerator maps to the federated enumerator of the matching
 * type. A value that matches no case is reported through utils::Error.
 *
 * @param data_type The mpi data type to translate.
 * @return The corresponding xgboost::federated::DataType.
 */
static xgboost::federated::DataType GetDataType(mpi::DataType data_type) {
switch (data_type) {
case mpi::kChar:
return xgboost::federated::CHAR;
case mpi::kUChar:
return xgboost::federated::UCHAR;
case mpi::kInt:
return xgboost::federated::INT;
case mpi::kUInt:
return xgboost::federated::UINT;
case mpi::kLong:
return xgboost::federated::LONG;
case mpi::kULong:
return xgboost::federated::ULONG;
case mpi::kFloat:
return xgboost::federated::FLOAT;
case mpi::kDouble:
return xgboost::federated::DOUBLE;
case mpi::kLongLong:
return xgboost::federated::LONGLONG;
case mpi::kULongLong:
return xgboost::federated::ULONGLONG;
}
// Reached only when data_type matched none of the cases above.
utils::Error("unknown mpi::DataType");
// Fallback value for the return type; presumably not reached if utils::Error
// does not return — TODO confirm.
return xgboost::federated::CHAR;
}

/**
 * @brief Transform mpi::OpType to xgboost::federated::ReduceOperation.
 *
 * Max, min and sum map to their federated counterparts. Bitwise OR has no
 * counterpart here and is reported through utils::Error, as is any value that
 * matches no case.
 *
 * @param op_type The mpi reduce operation to translate.
 * @return The corresponding xgboost::federated::ReduceOperation.
 */
static xgboost::federated::ReduceOperation GetOp(mpi::OpType op_type) {
switch (op_type) {
case mpi::kMax:
return xgboost::federated::MAX;
case mpi::kMin:
return xgboost::federated::MIN;
case mpi::kSum:
return xgboost::federated::SUM;
case mpi::kBitwiseOR:
// Explicitly rejected; the return below only supplies a value for the return type.
utils::Error("Bitwise OR is not supported");
return xgboost::federated::MAX;
}
// Reached only when op_type matched none of the cases above.
utils::Error("unknown mpi::OpType");
// Fallback value; presumably not reached if utils::Error does not return — TODO confirm.
return xgboost::federated::MAX;
}

void SetParam(std::string const &name, std::string const &val) {
if (!strcasecmp(name.c_str(), "FEDERATED_SERVER_ADDRESS")) {
server_address_ = val;
Expand Down
14 changes: 6 additions & 8 deletions plugin/federated/federated.proto
Expand Up @@ -12,16 +12,14 @@ service Federated {
}

// Element type tag for data exchanged with the federated service.
// NOTE(review): this span comes from a diff view whose +/- markers were lost, so
// values 0-5 appear twice (CHAR..ULONG and INT8..UINT64). The counts in the hunk
// header suggest the fixed-width names replace the C-style ones and that
// LONGLONG/ULONGLONG are dropped — confirm against the repository.
enum DataType {
CHAR = 0;
UCHAR = 1;
INT = 2;
UINT = 3;
LONG = 4;
ULONG = 5;
INT8 = 0;
UINT8 = 1;
INT32 = 2;
UINT32 = 3;
INT64 = 4;
UINT64 = 5;
FLOAT = 6;
DOUBLE = 7;
LONGLONG = 8;
ULONGLONG = 9;
}

enum ReduceOperation {
Expand Down

0 comments on commit a268654

Please sign in to comment.