diff --git a/amalgamation/xgboost-all0.cc b/amalgamation/xgboost-all0.cc index 50a819ee8f17..ded96bcbab4c 100644 --- a/amalgamation/xgboost-all0.cc +++ b/amalgamation/xgboost-all0.cc @@ -71,6 +71,9 @@ #include "../src/logging.cc" #include "../src/global_config.cc" +// collective +#include "../src/collective/communicator.cc" + // common #include "../src/common/common.cc" #include "../src/common/column_matrix.cc" diff --git a/include/xgboost/c_api.h b/include/xgboost/c_api.h index f66b8097f50e..92d0e3c3bfae 100644 --- a/include/xgboost/c_api.h +++ b/include/xgboost/c_api.h @@ -9,10 +9,12 @@ #ifdef __cplusplus #define XGB_EXTERN_C extern "C" +#include #include #include #else #define XGB_EXTERN_C +#include #include #include #endif // __cplusplus @@ -1386,4 +1388,135 @@ XGB_DLL int XGBoosterFeatureScore(BoosterHandle handle, const char *json_config, bst_ulong *out_dim, bst_ulong const **out_shape, float const **out_scores); + +/*! + * \brief Initialize the collective communicator. + * + * Currently the communicator API is experimental, function signatures may change in the future + * without notice. + * + * Call this once before using anything. + * + * The additional configuration is not required. Usually the communicator will detect settings + * from environment variables. + * + * \param json_config JSON encoded configuration. Accepted JSON keys are: + * - xgboost_communicator: The type of the communicator. Can be set as an environment variable. + * * rabit: Use Rabit. This is the default if the type is unspecified. + * * mpi: Use MPI. + * * federated: Use the gRPC interface for Federated Learning. + * Only applicable to the Rabit communicator (these are case-sensitive): + * - rabit_tracker_uri: Hostname of the tracker. + * - rabit_tracker_port: Port number of the tracker. + * - rabit_task_id: ID of the current task, can be used to obtain deterministic rank assignment. + * - rabit_world_size: Total number of workers. + * - rabit_hadoop_mode: Enable Hadoop support. 
+ * - rabit_tree_reduce_minsize: Minimal size for tree reduce. + * - rabit_reduce_ring_mincount: Minimal count to perform ring reduce. + * - rabit_reduce_buffer: Size of the reduce buffer. + * - rabit_bootstrap_cache: Size of the bootstrap cache. + * - rabit_debug: Enable debugging. + * - rabit_timeout: Enable timeout. + * - rabit_timeout_sec: Timeout in seconds. + * - rabit_enable_tcp_no_delay: Enable TCP no delay on Unix platforms. + * Only applicable to the Rabit communicator (these are case-sensitive, and can be set as + * environment variables): + * - DMLC_TRACKER_URI: Hostname of the tracker. + * - DMLC_TRACKER_PORT: Port number of the tracker. + * - DMLC_TASK_ID: ID of the current task, can be used to obtain deterministic rank assignment. + * - DMLC_ROLE: Role of the current task, "worker" or "server". + * - DMLC_NUM_ATTEMPT: Number of attempts after task failure. + * - DMLC_WORKER_CONNECT_RETRY: Number of retries to connect to the tracker. + * Only applicable to the Federated communicator (use upper case for environment variables, use + * lower case for runtime configuration): + * - federated_server_address: Address of the federated server. + * - federated_world_size: Number of federated workers. + * - federated_rank: Rank of the current worker. + * - federated_server_cert: Server certificate file path. Only needed for the SSL mode. + * - federated_client_key: Client key file path. Only needed for the SSL mode. + * - federated_client_cert: Client certificate file path. Only needed for the SSL mode. + * \return 0 for success, -1 for failure. + */ +XGB_DLL int XGCommunicatorInit(char const* json_config); + +/*! + * \brief Finalize the collective communicator. + * + * Call this function after you finished all jobs. + * + * \return 0 for success, -1 for failure. + */ +XGB_DLL int XGCommunicatorFinalize(void); + +/*! + * \brief Get rank of current process. + * + * \return Rank of the worker. + */ +XGB_DLL int XGCommunicatorGetRank(void); + +/*! 
+ * \brief Get total number of processes. + * + * \return Total world size. + */ +XGB_DLL int XGCommunicatorGetWorldSize(void); + +/*! + * \brief Get if the communicator is distributed. + * + * \return True if the communicator is distributed. + */ +XGB_DLL int XGCommunicatorIsDistributed(void); + +/*! + * \brief Print the message to the communicator. + * + * This function can be used to communicate the information of the progress to the user who monitors + * the communicator. + * + * \param message The message to be printed. + * \return 0 for success, -1 for failure. + */ +XGB_DLL int XGCommunicatorPrint(char const *message); + +/*! + * \brief Get the name of the processor. + * + * \param name_str Pointer to receive the returned processor name. + * \return 0 for success, -1 for failure. + */ +XGB_DLL int XGCommunicatorGetProcessorName(const char** name_str); + +/*! + * \brief Broadcast a memory region to all others from root. This function is NOT thread-safe. + * + * Example: + * int a = 1; + * Broadcast(&a, sizeof(a), root); + * + * \param send_receive_buffer Pointer to the send or receive buffer. + * \param size Size of the data. + * \param root The process rank to broadcast from. + * \return 0 for success, -1 for failure. + */ +XGB_DLL int XGCommunicatorBroadcast(void *send_receive_buffer, size_t size, int root); + +/*! + * \brief Perform in-place allreduce. This function is NOT thread-safe. + * + * Example Usage: the following code gives sum of the result + * vector<int> data(10); + * ... + * Allreduce(&data[0], data.size(), DataType::kInt32, Op::kSum); + * ... + * \param send_receive_buffer Buffer for both sending and receiving data. + * \param count Number of elements to be reduced. + * \param data_type Enumeration of data type, see xgboost::collective::DataType in communicator.h. + * \param op Enumeration of operation type, see xgboost::collective::Operation in communicator.h. + * \return 0 for success, -1 for failure.
+ */ +XGB_DLL int XGCommunicatorAllreduce(void *send_receive_buffer, size_t count, int data_type, int op); + + #endif // XGBOOST_C_API_H_ diff --git a/plugin/federated/engine_federated.cc b/plugin/federated/engine_federated.cc index 8f2447cf2b90..feb589e54548 100644 --- a/plugin/federated/engine_federated.cc +++ b/plugin/federated/engine_federated.cc @@ -71,7 +71,9 @@ class FederatedEngine : public IEngine { void Allreduce(void *sendrecvbuf, size_t size, mpi::DataType dtype, mpi::OpType op) { auto *buffer = reinterpret_cast(sendrecvbuf); std::string const send_buffer(buffer, size); - auto const receive_buffer = client_->Allreduce(send_buffer, GetDataType(dtype), GetOp(op)); + auto const receive_buffer = + client_->Allreduce(send_buffer, static_cast(dtype), + static_cast(op)); receive_buffer.copy(buffer, size); } @@ -113,51 +115,6 @@ class FederatedEngine : public IEngine { } private: - /** @brief Transform mpi::DataType to xgboost::federated::DataType. */ - static xgboost::federated::DataType GetDataType(mpi::DataType data_type) { - switch (data_type) { - case mpi::kChar: - return xgboost::federated::CHAR; - case mpi::kUChar: - return xgboost::federated::UCHAR; - case mpi::kInt: - return xgboost::federated::INT; - case mpi::kUInt: - return xgboost::federated::UINT; - case mpi::kLong: - return xgboost::federated::LONG; - case mpi::kULong: - return xgboost::federated::ULONG; - case mpi::kFloat: - return xgboost::federated::FLOAT; - case mpi::kDouble: - return xgboost::federated::DOUBLE; - case mpi::kLongLong: - return xgboost::federated::LONGLONG; - case mpi::kULongLong: - return xgboost::federated::ULONGLONG; - } - utils::Error("unknown mpi::DataType"); - return xgboost::federated::CHAR; - } - - /** @brief Transform mpi::OpType to enum to MPI OP */ - static xgboost::federated::ReduceOperation GetOp(mpi::OpType op_type) { - switch (op_type) { - case mpi::kMax: - return xgboost::federated::MAX; - case mpi::kMin: - return xgboost::federated::MIN; - case mpi::kSum: - 
return xgboost::federated::SUM; - case mpi::kBitwiseOR: - utils::Error("Bitwise OR is not supported"); - return xgboost::federated::MAX; - } - utils::Error("unknown mpi::OpType"); - return xgboost::federated::MAX; - } - void SetParam(std::string const &name, std::string const &val) { if (!strcasecmp(name.c_str(), "FEDERATED_SERVER_ADDRESS")) { server_address_ = val; diff --git a/plugin/federated/federated.proto b/plugin/federated/federated.proto index cba897c0ea81..5a338ba0d8f9 100644 --- a/plugin/federated/federated.proto +++ b/plugin/federated/federated.proto @@ -12,16 +12,14 @@ service Federated { } enum DataType { - CHAR = 0; - UCHAR = 1; - INT = 2; - UINT = 3; - LONG = 4; - ULONG = 5; + INT8 = 0; + UINT8 = 1; + INT32 = 2; + UINT32 = 3; + INT64 = 4; + UINT64 = 5; FLOAT = 6; DOUBLE = 7; - LONGLONG = 8; - ULONGLONG = 9; } enum ReduceOperation { diff --git a/plugin/federated/federated_communicator.h b/plugin/federated/federated_communicator.h new file mode 100644 index 000000000000..6a3186b4f608 --- /dev/null +++ b/plugin/federated/federated_communicator.h @@ -0,0 +1,192 @@ +/*! + * Copyright 2022 XGBoost contributors + */ +#pragma once +#include + +#include "../../src/collective/communicator.h" +#include "../../src/common/io.h" +#include "federated_client.h" + +namespace xgboost { +namespace collective { + +/** + * @brief A Federated Learning communicator class that handles collective communication. + */ +class FederatedCommunicator : public Communicator { + public: + /** + * @brief Create a new communicator based on JSON configuration. + * @param config JSON configuration. + * @return Communicator as specified by the JSON configuration. + */ + static Communicator *Create(Json const &config) { + std::string server_address{}; + int world_size{0}; + int rank{-1}; + std::string server_cert{}; + std::string client_key{}; + std::string client_cert{}; + + // Parse environment variables first. 
+ auto *value = getenv("FEDERATED_SERVER_ADDRESS"); + if (value != nullptr) { + server_address = value; + } + value = getenv("FEDERATED_WORLD_SIZE"); + if (value != nullptr) { + world_size = std::stoi(value); + } + value = getenv("FEDERATED_RANK"); + if (value != nullptr) { + rank = std::stoi(value); + } + value = getenv("FEDERATED_SERVER_CERT"); + if (value != nullptr) { + server_cert = value; + } + value = getenv("FEDERATED_CLIENT_KEY"); + if (value != nullptr) { + client_key = value; + } + value = getenv("FEDERATED_CLIENT_CERT"); + if (value != nullptr) { + client_cert = value; + } + + // Runtime configuration overrides. + auto const &j_server_address = config["federated_server_address"]; + if (IsA(j_server_address)) { + server_address = get(j_server_address); + } + auto const &j_world_size = config["federated_world_size"]; + if (IsA(j_world_size)) { + world_size = static_cast(get(j_world_size)); + } + auto const &j_rank = config["federated_rank"]; + if (IsA(j_rank)) { + rank = static_cast(get(j_rank)); + } + auto const &j_server_cert = config["federated_server_cert"]; + if (IsA(j_server_cert)) { + server_cert = get(j_server_cert); + } + auto const &j_client_key = config["federated_client_key"]; + if (IsA(j_client_key)) { + client_key = get(j_client_key); + } + auto const &j_client_cert = config["federated_client_cert"]; + if (IsA(j_client_cert)) { + client_cert = get(j_client_cert); + } + + if (server_address.empty()) { + LOG(FATAL) << "Federated server address must be set."; + } + if (world_size == 0) { + LOG(FATAL) << "Federated world size must be set."; + } + if (rank == -1) { + LOG(FATAL) << "Federated rank must be set."; + } + return new FederatedCommunicator(world_size, rank, server_address, server_cert, client_key, + client_cert); + } + + /** + * @brief Construct a new federated communicator. + * + * @param world_size Total number of processes. + * @param rank Rank of the current process. 
+ * @param server_address Address of the federated server (host:port). + * @param server_cert_path Path to the server cert file. + * @param client_key_path Path to the client key file. + * @param client_cert_path Path to the client cert file. + */ + FederatedCommunicator(int world_size, int rank, std::string const &server_address, + std::string const &server_cert_path, std::string const &client_key_path, + std::string const &client_cert_path) + : Communicator{world_size, rank} { + if (server_cert_path.empty() || client_key_path.empty() || client_cert_path.empty()) { + client_.reset(new xgboost::federated::FederatedClient(server_address, rank)); + } else { + client_.reset(new xgboost::federated::FederatedClient( + server_address, rank, xgboost::common::ReadAll(server_cert_path), + xgboost::common::ReadAll(client_key_path), xgboost::common::ReadAll(client_cert_path))); + } + } + + /** + * @brief Construct an insecure federated communicator without using SSL. + * @param world_size Total number of processes. + * @param rank Rank of the current process. + * @param server_address Address of the federated server (host:port). + */ + FederatedCommunicator(int world_size, int rank, std::string const &server_address) + : Communicator{world_size, rank} { + client_.reset(new xgboost::federated::FederatedClient(server_address, rank)); + } + + ~FederatedCommunicator() override { client_.reset(); } + + /** + * \brief Get if the communicator is distributed. + * \return True. + */ + bool IsDistributed() const override { return true; } + + /** + * \brief Perform in-place allreduce. + * \param send_receive_buffer Buffer for both sending and receiving data. + * \param count Number of elements to be reduced. + * \param data_type Enumeration of data type. + * \param op Enumeration of operation type. 
+ */ + void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type, + Operation op) override { + std::string const send_buffer(reinterpret_cast(send_receive_buffer), + count * GetTypeSize(data_type)); + auto const received = + client_->Allreduce(send_buffer, static_cast(data_type), + static_cast(op)); + received.copy(reinterpret_cast(send_receive_buffer), count * GetTypeSize(data_type)); + } + + /** + * \brief Broadcast a memory region to all others from root. + * \param send_receive_buffer Pointer to the send or receive buffer. + * \param size Size of the data. + * \param root The process rank to broadcast from. + */ + void Broadcast(void *send_receive_buffer, std::size_t size, int root) override { + if (GetWorldSize() == 1) return; + if (GetRank() == root) { + std::string const send_buffer(reinterpret_cast(send_receive_buffer), size); + client_->Broadcast(send_buffer, root); + } else { + auto const received = client_->Broadcast("", root); + received.copy(reinterpret_cast(send_receive_buffer), size); + } + } + + /** + * \brief Get the name of the processor. + * \return Name of the processor. + */ + std::string GetProcessorName() override { return "rank" + std::to_string(GetRank()); } + + /** + * \brief Print the message to the communicator. + * \param message The message to be printed. 
+ */ + void Print(const std::string &message) override { LOG(CONSOLE) << message; } + + protected: + void Shutdown() override {} + + private: + std::unique_ptr client_{}; +}; +} // namespace collective +} // namespace xgboost diff --git a/plugin/federated/federated_server.cc b/plugin/federated/federated_server.cc index 9e11fd51ee0b..0738f776bb43 100644 --- a/plugin/federated/federated_server.cc +++ b/plugin/federated/federated_server.cc @@ -10,6 +10,8 @@ #include #include +#include "../../src/common/io.h" + namespace xgboost { namespace federated { @@ -71,32 +73,35 @@ class AllreduceFunctor { void Accumulate(std::string& buffer, std::string const& input, DataType data_type, ReduceOperation reduce_operation) const { switch (data_type) { - case DataType::CHAR: - Accumulate(&buffer[0], reinterpret_cast(input.data()), buffer.size(), + case DataType::INT8: + Accumulate(reinterpret_cast(&buffer[0]), + reinterpret_cast(input.data()), buffer.size(), reduce_operation); break; - case DataType::UCHAR: - Accumulate(reinterpret_cast(&buffer[0]), - reinterpret_cast(input.data()), buffer.size(), + case DataType::UINT8: + Accumulate(reinterpret_cast(&buffer[0]), + reinterpret_cast(input.data()), buffer.size(), reduce_operation); break; - case DataType::INT: - Accumulate(reinterpret_cast(&buffer[0]), reinterpret_cast(input.data()), - buffer.size() / sizeof(int), reduce_operation); + case DataType::INT32: + Accumulate(reinterpret_cast(&buffer[0]), + reinterpret_cast(input.data()), + buffer.size() / sizeof(std::uint32_t), reduce_operation); break; - case DataType::UINT: - Accumulate(reinterpret_cast(&buffer[0]), - reinterpret_cast(input.data()), - buffer.size() / sizeof(unsigned int), reduce_operation); + case DataType::UINT32: + Accumulate(reinterpret_cast(&buffer[0]), + reinterpret_cast(input.data()), + buffer.size() / sizeof(std::uint32_t), reduce_operation); break; - case DataType::LONG: - Accumulate(reinterpret_cast(&buffer[0]), reinterpret_cast(input.data()), - buffer.size() / 
sizeof(long), reduce_operation); + case DataType::INT64: + Accumulate(reinterpret_cast(&buffer[0]), + reinterpret_cast(input.data()), + buffer.size() / sizeof(std::int64_t), reduce_operation); break; - case DataType::ULONG: - Accumulate(reinterpret_cast(&buffer[0]), - reinterpret_cast(input.data()), - buffer.size() / sizeof(unsigned long), reduce_operation); + case DataType::UINT64: + Accumulate(reinterpret_cast(&buffer[0]), + reinterpret_cast(input.data()), + buffer.size() / sizeof(std::uint64_t), reduce_operation); break; case DataType::FLOAT: Accumulate(reinterpret_cast(&buffer[0]), @@ -108,16 +113,6 @@ class AllreduceFunctor { reinterpret_cast(input.data()), buffer.size() / sizeof(double), reduce_operation); break; - case DataType::LONGLONG: - Accumulate(reinterpret_cast(&buffer[0]), - reinterpret_cast(input.data()), - buffer.size() / sizeof(long long), reduce_operation); - break; - case DataType::ULONGLONG: - Accumulate(reinterpret_cast(&buffer[0]), - reinterpret_cast(input.data()), - buffer.size() / sizeof(unsigned long long), reduce_operation); - break; default: throw std::invalid_argument("Invalid data type"); } @@ -201,13 +196,6 @@ grpc::Status FederatedService::Handle(Request const* request, Reply* reply, return grpc::Status::OK; } -std::string ReadFile(char const* path) { - auto stream = std::ifstream(path); - std::ostringstream out; - out << stream.rdbuf(); - return out.str(); -} - void RunServer(int port, int world_size, char const* server_key_file, char const* server_cert_file, char const* client_cert_file) { std::string const server_address = "0.0.0.0:" + std::to_string(port); @@ -216,10 +204,10 @@ void RunServer(int port, int world_size, char const* server_key_file, char const grpc::ServerBuilder builder; auto options = grpc::SslServerCredentialsOptions(GRPC_SSL_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY); - options.pem_root_certs = ReadFile(client_cert_file); + options.pem_root_certs = xgboost::common::ReadAll(client_cert_file); auto key = 
grpc::SslServerCredentialsOptions::PemKeyCertPair(); - key.private_key = ReadFile(server_key_file); - key.cert_chain = ReadFile(server_cert_file); + key.private_key = xgboost::common::ReadAll(server_key_file); + key.cert_chain = xgboost::common::ReadAll(server_cert_file); options.pem_key_cert_pairs.push_back(key); builder.SetMaxReceiveMessageSize(std::numeric_limits::max()); builder.AddListeningPort(server_address, grpc::SslServerCredentials(options)); diff --git a/python-package/xgboost/collective.py b/python-package/xgboost/collective.py new file mode 100644 index 000000000000..e4662d744e50 --- /dev/null +++ b/python-package/xgboost/collective.py @@ -0,0 +1,243 @@ +"""XGBoost collective communication related API.""" +import ctypes +import json +import logging +import pickle +from enum import IntEnum, unique +from typing import Any, List + +import numpy as np + +from ._typing import _T +from .core import _LIB, _check_call, c_str, py_str, from_pystr_to_cstr + +LOGGER = logging.getLogger("[xgboost.collective]") + + +def init(**args: Any) -> None: + """Initialize the collective library with arguments. + + Parameters + ---------- + args: Dict[str, Any] + Keyword arguments representing the parameters and their values. + + Accepted parameters: + - xgboost_communicator: The type of the communicator. Can be set as an environment + variable. + * rabit: Use Rabit. This is the default if the type is unspecified. + * federated: Use the gRPC interface for Federated Learning. + Only applicable to the Rabit communicator (these are case sensitive): + -- rabit_tracker_uri: Hostname of the tracker. + -- rabit_tracker_port: Port number of the tracker. + -- rabit_task_id: ID of the current task, can be used to obtain deterministic rank + assignment. + -- rabit_world_size: Total number of workers. + -- rabit_hadoop_mode: Enable Hadoop support. + -- rabit_tree_reduce_minsize: Minimal size for tree reduce. + -- rabit_reduce_ring_mincount: Minimal count to perform ring reduce. 
+ -- rabit_reduce_buffer: Size of the reduce buffer. + -- rabit_bootstrap_cache: Size of the bootstrap cache. + -- rabit_debug: Enable debugging. + -- rabit_timeout: Enable timeout. + -- rabit_timeout_sec: Timeout in seconds. + -- rabit_enable_tcp_no_delay: Enable TCP no delay on Unix platforms. + Only applicable to the Rabit communicator (these are case-sensitive, and can be set as + environment variables): + -- DMLC_TRACKER_URI: Hostname of the tracker. + -- DMLC_TRACKER_PORT: Port number of the tracker. + -- DMLC_TASK_ID: ID of the current task, can be used to obtain deterministic rank + assignment. + -- DMLC_ROLE: Role of the current task, "worker" or "server". + -- DMLC_NUM_ATTEMPT: Number of attempts after task failure. + -- DMLC_WORKER_CONNECT_RETRY: Number of retries to connect to the tracker. + Only applicable to the Federated communicator (use upper case for environment variables, use + lower case for runtime configuration): + -- federated_server_address: Address of the federated server. + -- federated_world_size: Number of federated workers. + -- federated_rank: Rank of the current worker. + -- federated_server_cert: Server certificate file path. Only needed for the SSL mode. + -- federated_client_key: Client key file path. Only needed for the SSL mode. + -- federated_client_cert: Client certificate file path. Only needed for the SSL mode. + """ + config = from_pystr_to_cstr(json.dumps(args)) + _check_call(_LIB.XGCommunicatorInit(config)) + + +def finalize() -> None: + """Finalize the communicator.""" + _check_call(_LIB.XGCommunicatorFinalize()) + + +def get_rank() -> int: + """Get rank of current process. + + Returns + ------- + rank : int + Rank of current process. + """ + ret = _LIB.XGCommunicatorGetRank() + return ret + + +def get_world_size() -> int: + """Get total number workers. + + Returns + ------- + n : int + Total number of process. 
+ """ + ret = _LIB.XGCommunicatorGetWorldSize() + return ret + + +def is_distributed() -> int: + """If the collective communicator is distributed.""" + is_dist = _LIB.XGCommunicatorIsDistributed() + return is_dist + + +def communicator_print(msg: Any) -> None: + """Print message to the communicator. + + This function can be used to communicate the information of + the progress to the communicator. + + Parameters + ---------- + msg : str + The message to be printed to the communicator. + """ + if not isinstance(msg, str): + msg = str(msg) + is_dist = _LIB.XGCommunicatorIsDistributed() + if is_dist != 0: + _check_call(_LIB.XGCommunicatorPrint(c_str(msg))) + else: + print(msg.strip(), flush=True) + + +def get_processor_name() -> str: + """Get the processor name. + + Returns + ------- + name : str + The name of the processor (host). + """ + name_str = ctypes.c_char_p() + _check_call(_LIB.XGCommunicatorGetProcessorName(ctypes.byref(name_str))) + value = name_str.value + assert value + return py_str(value) + + +def broadcast(data: _T, root: int) -> _T: + """Broadcast object from one node to all other nodes. + + Parameters + ---------- + data : any type that can be pickled + Input data, if current rank does not equal root, this can be None + root : int + Rank of the node to broadcast data from. + + Returns + ------- + object : any type that can be pickled + The result of the broadcast.
+ """ + rank = get_rank() + length = ctypes.c_ulong() + if root == rank: + assert data is not None, 'need to pass in data when broadcasting' + s = pickle.dumps(data, protocol=pickle.HIGHEST_PROTOCOL) + length.value = len(s) + # run first broadcast + _check_call(_LIB.XGCommunicatorBroadcast(ctypes.byref(length), + ctypes.sizeof(ctypes.c_ulong), root)) + if root != rank: + dptr = (ctypes.c_char * length.value)() + # run second + _check_call(_LIB.XGCommunicatorBroadcast(ctypes.cast(dptr, ctypes.c_void_p), + length.value, root)) + data = pickle.loads(dptr.raw) + del dptr + else: + _check_call(_LIB.XGCommunicatorBroadcast(ctypes.cast(ctypes.c_char_p(s), ctypes.c_void_p), + length.value, root)) + del s + return data + + +# enumeration of dtypes +DTYPE_ENUM__ = { + np.dtype('int8'): 0, + np.dtype('uint8'): 1, + np.dtype('int32'): 2, + np.dtype('uint32'): 3, + np.dtype('int64'): 4, + np.dtype('uint64'): 5, + np.dtype('float32'): 6, + np.dtype('float64'): 7 +} + + +@unique +class Op(IntEnum): + """Supported operations for allreduce.""" + MAX = 0 + MIN = 1 + SUM = 2 + + +def allreduce( # pylint:disable=invalid-name + data: np.ndarray, op: Op +) -> np.ndarray: + """Perform allreduce, return the result. + + Parameters + ---------- + data : + Input data. + op : + Reduction operator. + + Returns + ------- + result : + The result of allreduce, have same shape as data + + Notes + ----- + This function is not thread-safe. 
+ """ + if not isinstance(data, np.ndarray): + raise TypeError('allreduce only takes in numpy.ndarray') + buf = data.ravel() + if buf.base is data.base: + buf = buf.copy() + if buf.dtype not in DTYPE_ENUM__: + raise Exception(f"data type {buf.dtype} not supported") + _check_call(_LIB.XGCommunicatorAllreduce(buf.ctypes.data_as(ctypes.c_void_p), + buf.size, DTYPE_ENUM__[buf.dtype], + int(op))) + return buf + + +class CommunicatorContext: + """A context controlling collective communicator initialization and finalization.""" + + def __init__(self, **args: Any) -> None: + self.args = args + + def __enter__(self) -> None: + init(**self.args) + assert is_distributed() + LOGGER.debug("-------------- communicator say hello ------------------") + + def __exit__(self, *args: List) -> None: + finalize() + LOGGER.debug("--------------- communicator say bye ------------------") diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 698b8fb8fba9..9b5bea3acd00 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -22,6 +22,7 @@ #include "c_api_error.h" #include "c_api_utils.h" +#include "../collective/communicator.h" #include "../common/io.h" #include "../common/charconv.h" #include "../data/adapter.h" @@ -1370,6 +1371,62 @@ XGB_DLL int XGBoosterFeatureScore(BoosterHandle handle, char const *json_config, API_END(); } +using xgboost::collective::Communicator; + +XGB_DLL int XGCommunicatorInit(char const* json_config) { + API_BEGIN(); + Json config { Json::Load(StringView{json_config}) }; + Communicator::Init(config); + API_END(); +} + +XGB_DLL int XGCommunicatorFinalize(void) { + API_BEGIN(); + Communicator::Finalize(); + API_END(); +} + +XGB_DLL int XGCommunicatorGetRank(void) { + return Communicator::Get()->GetRank(); +} + +XGB_DLL int XGCommunicatorGetWorldSize(void) { + return Communicator::Get()->GetWorldSize(); +} + +XGB_DLL int XGCommunicatorIsDistributed(void) { + return Communicator::Get()->IsDistributed(); +} + +XGB_DLL int XGCommunicatorPrint(char
const *message) { + API_BEGIN(); + Communicator::Get()->Print(message); + API_END(); +} + +XGB_DLL int XGCommunicatorGetProcessorName(char const **name_str) { + API_BEGIN(); + auto& local = *GlobalConfigAPIThreadLocalStore::Get(); + local.ret_str = Communicator::Get()->GetProcessorName(); + *name_str = local.ret_str.c_str(); + API_END(); +} + +XGB_DLL int XGCommunicatorBroadcast(void *send_receive_buffer, size_t size, int root) { + API_BEGIN(); + Communicator::Get()->Broadcast(send_receive_buffer, size, root); + API_END(); +} + +XGB_DLL int XGCommunicatorAllreduce(void *send_receive_buffer, size_t count, int enum_dtype, + int enum_op) { + API_BEGIN(); + Communicator::Get()->AllReduce( + send_receive_buffer, count, static_cast(enum_dtype), + static_cast(enum_op)); + API_END(); +} + #if defined(XGBOOST_USE_FEDERATED) XGB_DLL int XGBRunFederatedServer(int port, int world_size, char const *server_key_path, char const *server_cert_path, char const *client_cert_path) { diff --git a/src/collective/communicator.cc b/src/collective/communicator.cc new file mode 100644 index 000000000000..73765223b225 --- /dev/null +++ b/src/collective/communicator.cc @@ -0,0 +1,59 @@ +/*! + * Copyright 2022 XGBoost contributors + */ +#include "communicator.h" + +#include "rabit_communicator.h" + +#if defined(XGBOOST_USE_FEDERATED) +#include "../../plugin/federated/federated_communicator.h" +#endif + +namespace xgboost { +namespace collective { + +thread_local std::unique_ptr Communicator::communicator_{}; +thread_local CommunicatorType Communicator::type_{}; + +void Communicator::Init(Json const& config) { + if (communicator_) { + LOG(FATAL) << "Communicator can only be initialized once."; + } + + auto type = GetTypeFromEnv(); + auto const arg = GetTypeFromConfig(config); + if (arg != CommunicatorType::kUnknown) { + type = arg; + } + if (type == CommunicatorType::kUnknown) { + // Default to Rabit if unspecified. 
+ type = CommunicatorType::kRabit; + } + type_ = type; + switch (type) { + case CommunicatorType::kRabit: { + communicator_.reset(RabitCommunicator::Create(config)); + break; + } + case CommunicatorType::kFederated: { +#if defined(XGBOOST_USE_FEDERATED) + communicator_.reset(FederatedCommunicator::Create(config)); +#else + LOG(FATAL) << "XGBoost is not compiled with Federated Learning support."; +#endif + break; + } + case CommunicatorType::kUnknown: + LOG(FATAL) << "Unknown communicator type."; + } +} + +#ifndef XGBOOST_USE_CUDA +void Communicator::Finalize() { + communicator_->Shutdown(); + communicator_.reset(nullptr); +} +#endif + +} // namespace collective +} // namespace xgboost diff --git a/src/collective/communicator.cu b/src/collective/communicator.cu new file mode 100644 index 000000000000..2485000d9ad4 --- /dev/null +++ b/src/collective/communicator.cu @@ -0,0 +1,41 @@ +/*! + * Copyright 2022 XGBoost contributors + */ +#include "communicator.h" +#include "device_communicator.cuh" +#include "device_communicator_adapter.cuh" +#ifdef XGBOOST_USE_NCCL +#include "nccl_device_communicator.cuh" +#endif + +namespace xgboost { +namespace collective { + +thread_local int Communicator::device_ordinal_{-1}; +thread_local std::unique_ptr Communicator::device_communicator_{}; + +void Communicator::Finalize() { + communicator_->Shutdown(); + communicator_.reset(nullptr); + device_ordinal_ = -1; + device_communicator_.reset(nullptr); +} + +DeviceCommunicator* Communicator::GetDevice(int device_ordinal) { + if (!device_communicator_ || device_ordinal_ != device_ordinal) { + device_ordinal_ = device_ordinal; +#ifdef XGBOOST_USE_NCCL + if (type_ != CommunicatorType::kFederated) { + device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal, Get())); + } else { + device_communicator_.reset(new DeviceCommunicatorAdapter(device_ordinal, Get())); + } +#else + device_communicator_.reset(new DeviceCommunicatorAdapter(device_ordinal, Get())); +#endif + } + return 
device_communicator_.get(); +} + +} // namespace collective +} // namespace xgboost diff --git a/src/collective/communicator.h b/src/collective/communicator.h new file mode 100644 index 000000000000..14ee201618cd --- /dev/null +++ b/src/collective/communicator.h @@ -0,0 +1,218 @@ +/*! + * Copyright 2022 XGBoost contributors + */ +#pragma once +#include +#include + +#include +#include + +namespace xgboost { +namespace collective { + +/** @brief Defines the integral and floating data types. */ +enum class DataType { + kInt8 = 0, + kUInt8 = 1, + kInt32 = 2, + kUInt32 = 3, + kInt64 = 4, + kUInt64 = 5, + kFloat = 6, + kDouble = 7 +}; + +/** @brief Get the size of the data type. */ +inline std::size_t GetTypeSize(DataType data_type) { + std::size_t size{0}; + switch (data_type) { + case DataType::kInt8: + size = sizeof(std::int8_t); + break; + case DataType::kUInt8: + size = sizeof(std::uint8_t); + break; + case DataType::kInt32: + size = sizeof(std::int32_t); + break; + case DataType::kUInt32: + size = sizeof(std::uint32_t); + break; + case DataType::kInt64: + size = sizeof(std::int64_t); + break; + case DataType::kUInt64: + size = sizeof(std::uint64_t); + break; + case DataType::kFloat: + size = sizeof(float); + break; + case DataType::kDouble: + size = sizeof(double); + break; + default: + LOG(FATAL) << "Unknown data type."; + } + return size; +} + +/** @brief Defines the reduction operation. */ +enum class Operation { kMax = 0, kMin = 1, kSum = 2 }; + +class DeviceCommunicator; + +enum class CommunicatorType { kUnknown, kRabit, kFederated }; + +/** \brief Case-insensitive string comparison. */ +inline int CompareStringsCaseInsensitive(const char *s1, const char *s2) { +#ifdef _MSC_VER + return _stricmp(s1, s2); +#else // _MSC_VER + return strcasecmp(s1, s2); +#endif // _MSC_VER +} + +/** + * @brief A communicator class that handles collective communication. + */ +class Communicator { + public: + /** + * @brief Initialize the communicator. This can only be done once. 
+ * + * @param config JSON configuration for the communicator. + */ + static void Init(Json const &config); + + /** @brief Finalize the communicator. */ + static void Finalize(); + + /** @brief Get the communicator instance. */ + static Communicator *Get() { return communicator_.get(); } + +#if defined(XGBOOST_USE_CUDA) + /** + * @brief Get the device communicator. + * + * @param device_ordinal ID of the device. + * @return An instance of device communicator. + */ + static DeviceCommunicator *GetDevice(int device_ordinal); +#endif + + virtual ~Communicator() = default; + + /** @brief Get the total number of processes. */ + int GetWorldSize() const { return world_size_; } + + /** @brief Get the rank of the current processes. */ + int GetRank() const { return rank_; } + + /** @brief Whether the communicator is running in distributed mode. */ + virtual bool IsDistributed() const = 0; + + /** + * @brief Combines values from all processes and distributes the result back to all processes. + * + * @param send_receive_buffer Buffer storing the data. + * @param count Number of elements in the buffer. + * @param data_type Data type stored in the buffer. + * @param op The operation to perform. + */ + virtual void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type, + Operation op) = 0; + + /** + * @brief Broadcasts a message from the process with rank `root` to all other processes of the + * group. + * + * @param send_receive_buffer Buffer storing the data. + * @param size Size of the data in bytes. + * @param root Rank of broadcast root. + */ + virtual void Broadcast(void *send_receive_buffer, std::size_t size, int root) = 0; + + /** + * @brief Gets the name of the processor. + */ + virtual std::string GetProcessorName() = 0; + + /** + * @brief Prints the message. + */ + virtual void Print(std::string const &message) = 0; + + /** @brief Get the communicator type from environment variables. Visible for testing. 
*/ + static CommunicatorType GetTypeFromEnv() { + auto *env = std::getenv("XGBOOST_COMMUNICATOR"); + if (env != nullptr) { + return StringToType(env); + } else { + return CommunicatorType::kUnknown; + } + } + + /** @brief Get the communicator type from runtime configuration. Visible for testing. */ + static CommunicatorType GetTypeFromConfig(Json const &config) { + auto const &j_upper = config["XGBOOST_COMMUNICATOR"]; + if (IsA(j_upper)) { + return StringToType(get(j_upper).c_str()); + } + auto const &j_lower = config["xgboost_communicator"]; + if (IsA(j_lower)) { + return StringToType(get(j_lower).c_str()); + } + return CommunicatorType::kUnknown; + } + + protected: + /** + * @brief Construct a new communicator. + * + * @param world_size Total number of processes. + * @param rank Rank of the current process. + */ + Communicator(int world_size, int rank) : world_size_(world_size), rank_(rank) { + if (world_size < 1) { + LOG(FATAL) << "World size " << world_size << " is less than 1."; + } + if (rank < 0) { + LOG(FATAL) << "Rank " << rank << " is less than 0."; + } + if (rank >= world_size) { + LOG(FATAL) << "Rank " << rank << " is greater than world_size - 1: " << world_size - 1 << "."; + } + } + + /** + * @brief Shuts down the communicator. 
+ */ + virtual void Shutdown() = 0; + + private: + static CommunicatorType StringToType(char const *str) { + CommunicatorType result = CommunicatorType::kUnknown; + if (!CompareStringsCaseInsensitive("rabit", str)) { + result = CommunicatorType::kRabit; + } else if (!CompareStringsCaseInsensitive("federated", str)) { + result = CommunicatorType::kFederated; + } else { + LOG(FATAL) << "Unknown communicator type " << str; + } + return result; + } + + static thread_local std::unique_ptr communicator_; + static thread_local CommunicatorType type_; +#if defined(XGBOOST_USE_CUDA) + static thread_local int device_ordinal_; + static thread_local std::unique_ptr device_communicator_; +#endif + + int const world_size_; + int const rank_; +}; + +} // namespace collective +} // namespace xgboost diff --git a/src/collective/device_communicator.cuh b/src/collective/device_communicator.cuh new file mode 100644 index 000000000000..15d18cead02f --- /dev/null +++ b/src/collective/device_communicator.cuh @@ -0,0 +1,42 @@ +/*! + * Copyright 2022 XGBoost contributors + */ +#pragma once +#include + +#include "../common/device_helpers.cuh" + +namespace xgboost { +namespace collective { + +/** + * @brief Collective communicator for device buffers. + */ +class DeviceCommunicator { + public: + virtual ~DeviceCommunicator() = default; + + /** + * @brief Sum values from all processes and distribute the result back to all processes. + * @param send_receive_buffer Buffer storing the data. + * @param count Number of elements in the buffer. + */ + virtual void AllReduceSum(double *send_receive_buffer, int count) = 0; + + /** + * @brief Gather variable-length values from all processes. + * @param send_buffer Buffer storing the input data. + * @param length_bytes Length in bytes of the input data. + * @param segments Size of each segment. + * @param receive_buffer Buffer storing the output data. 
+ */ + virtual void AllGatherV(void const *send_buffer, size_t length_bytes, + std::vector *segments, + dh::caching_device_vector *receive_buffer) = 0; + + /** @brief Synchronize device operations. */ + virtual void Synchronize() = 0; +}; + +} // namespace collective +} // namespace xgboost diff --git a/src/collective/device_communicator_adapter.cuh b/src/collective/device_communicator_adapter.cuh new file mode 100644 index 000000000000..794049bfcb1c --- /dev/null +++ b/src/collective/device_communicator_adapter.cuh @@ -0,0 +1,76 @@ +/*! + * Copyright 2022 XGBoost contributors + */ +#pragma once + +#include "communicator.h" +#include "device_communicator.cuh" + +namespace xgboost { +namespace collective { + +class DeviceCommunicatorAdapter : public DeviceCommunicator { + public: + DeviceCommunicatorAdapter(int device_ordinal, Communicator *communicator) + : device_ordinal_{device_ordinal}, communicator_{communicator} { + if (device_ordinal_ < 0) { + LOG(FATAL) << "Invalid device ordinal: " << device_ordinal_; + } + if (communicator_ == nullptr) { + LOG(FATAL) << "Communicator cannot be null."; + } + } + + ~DeviceCommunicatorAdapter() override = default; + + void AllReduceSum(double *send_receive_buffer, int count) override { + dh::safe_cuda(cudaSetDevice(device_ordinal_)); + auto size = count * sizeof(double); + host_buffer_.reserve(size); + dh::safe_cuda(cudaMemcpy(host_buffer_.data(), send_receive_buffer, size, cudaMemcpyDefault)); + communicator_->AllReduce(host_buffer_.data(), count, DataType::kDouble, Operation::kSum); + dh::safe_cuda(cudaMemcpy(send_receive_buffer, host_buffer_.data(), size, cudaMemcpyDefault)); + } + + void AllGatherV(void const *send_buffer, size_t length_bytes, std::vector *segments, + dh::caching_device_vector *receive_buffer) override { + dh::safe_cuda(cudaSetDevice(device_ordinal_)); + int const world_size = communicator_->GetWorldSize(); + int const rank = communicator_->GetRank(); + + segments->clear(); + segments->resize(world_size, 
0); + segments->at(rank) = length_bytes; + communicator_->AllReduce(segments->data(), segments->size(), DataType::kUInt64, + Operation::kMax); + auto total_bytes = std::accumulate(segments->cbegin(), segments->cend(), 0UL); + receive_buffer->resize(total_bytes); + + host_buffer_.reserve(total_bytes); + size_t offset = 0; + for (int32_t i = 0; i < world_size; ++i) { + size_t as_bytes = segments->at(i); + if (i == rank) { + dh::safe_cuda(cudaMemcpy(host_buffer_.data() + offset, send_buffer, segments->at(rank), + cudaMemcpyDefault)); + } + communicator_->Broadcast(host_buffer_.data() + offset, as_bytes, i); + offset += as_bytes; + } + dh::safe_cuda(cudaMemcpy(receive_buffer->data().get(), host_buffer_.data(), total_bytes, + cudaMemcpyDefault)); + } + + void Synchronize() override { + // Noop. + } + + private: + int const device_ordinal_; + Communicator *communicator_; + /// Host buffer used to call communicator functions. + std::vector host_buffer_{}; +}; + +} // namespace collective +} // namespace xgboost diff --git a/src/collective/nccl_device_communicator.cuh b/src/collective/nccl_device_communicator.cuh new file mode 100644 index 000000000000..ad9f57589c53 --- /dev/null +++ b/src/collective/nccl_device_communicator.cuh @@ -0,0 +1,149 @@ +/*! 
+ * Copyright 2022 XGBoost contributors + */ +#pragma once + +#include "../common/device_helpers.cuh" +#include "communicator.h" +#include "device_communicator.cuh" + +namespace xgboost { +namespace collective { + +class NcclDeviceCommunicator : public DeviceCommunicator { + public: + NcclDeviceCommunicator(int device_ordinal, Communicator *communicator) + : device_ordinal_{device_ordinal}, communicator_{communicator} { + if (device_ordinal_ < 0) { + LOG(FATAL) << "Invalid device ordinal: " << device_ordinal_; + } + if (communicator_ == nullptr) { + LOG(FATAL) << "Communicator cannot be null."; + } + + int32_t const rank = communicator_->GetRank(); + int32_t const world = communicator_->GetWorldSize(); + + std::vector uuids(world * kUuidLength, 0); + auto s_uuid = xgboost::common::Span{uuids.data(), uuids.size()}; + auto s_this_uuid = s_uuid.subspan(rank * kUuidLength, kUuidLength); + GetCudaUUID(s_this_uuid); + + // TODO(rongou): replace this with allgather. + communicator_->AllReduce(uuids.data(), uuids.size(), DataType::kUInt64, Operation::kSum); + + std::vector> converted(world); + size_t j = 0; + for (size_t i = 0; i < uuids.size(); i += kUuidLength) { + converted[j] = xgboost::common::Span{uuids.data() + i, kUuidLength}; + j++; + } + + auto iter = std::unique(converted.begin(), converted.end()); + auto n_uniques = std::distance(converted.begin(), iter); + + CHECK_EQ(n_uniques, world) + << "Multiple processes within communication group running on same CUDA " + << "device is not supported. 
" << PrintUUID(s_this_uuid) << "\n"; + + nccl_unique_id_ = GetUniqueId(); + dh::safe_nccl(ncclCommInitRank(&nccl_comm_, world, nccl_unique_id_, rank)); + dh::safe_cuda(cudaStreamCreate(&cuda_stream_)); + } + + ~NcclDeviceCommunicator() override { + dh::safe_cuda(cudaStreamDestroy(cuda_stream_)); + ncclCommDestroy(nccl_comm_); + if (xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug)) { + LOG(CONSOLE) << "======== NCCL Statistics========"; + LOG(CONSOLE) << "AllReduce calls: " << allreduce_calls_; + LOG(CONSOLE) << "AllReduce total MiB communicated: " << allreduce_bytes_ / 1048576; + } + } + + void AllReduceSum(double *send_receive_buffer, int count) override { + dh::safe_cuda(cudaSetDevice(device_ordinal_)); + dh::safe_nccl(ncclAllReduce(send_receive_buffer, send_receive_buffer, count, ncclDouble, + ncclSum, nccl_comm_, cuda_stream_)); + allreduce_bytes_ += count * sizeof(double); + allreduce_calls_ += 1; + } + + void AllGatherV(void const *send_buffer, size_t length_bytes, std::vector *segments, + dh::caching_device_vector *receive_buffer) override { + dh::safe_cuda(cudaSetDevice(device_ordinal_)); + int const world_size = communicator_->GetWorldSize(); + int const rank = communicator_->GetRank(); + + segments->clear(); + segments->resize(world_size, 0); + segments->at(rank) = length_bytes; + communicator_->AllReduce(segments->data(), segments->size(), DataType::kUInt64, + Operation::kMax); + auto total_bytes = std::accumulate(segments->cbegin(), segments->cend(), 0UL); + receive_buffer->resize(total_bytes); + + size_t offset = 0; + dh::safe_nccl(ncclGroupStart()); + for (int32_t i = 0; i < world_size; ++i) { + size_t as_bytes = segments->at(i); + dh::safe_nccl(ncclBroadcast(send_buffer, receive_buffer->data().get() + offset, as_bytes, + ncclChar, i, nccl_comm_, cuda_stream_)); + offset += as_bytes; + } + dh::safe_nccl(ncclGroupEnd()); + } + + void Synchronize() override { + dh::safe_cuda(cudaSetDevice(device_ordinal_)); + 
dh::safe_cuda(cudaStreamSynchronize(cuda_stream_)); + } + + private: + static constexpr std::size_t kUuidLength = + sizeof(std::declval().uuid) / sizeof(uint64_t); + + void GetCudaUUID(xgboost::common::Span const &uuid) const { + cudaDeviceProp prob{}; + dh::safe_cuda(cudaGetDeviceProperties(&prob, device_ordinal_)); + std::memcpy(uuid.data(), static_cast(&(prob.uuid)), sizeof(prob.uuid)); + } + + static std::string PrintUUID(xgboost::common::Span const &uuid) { + std::stringstream ss; + for (auto v : uuid) { + ss << std::hex << v; + } + return ss.str(); + } + + /** + * \fn ncclUniqueId GetUniqueId() + * + * \brief Gets the Unique ID from NCCL to be used in setting up interprocess + * communication + * + * \return the Unique ID + */ + ncclUniqueId GetUniqueId() { + static const int kRootRank = 0; + ncclUniqueId id; + if (communicator_->GetRank() == kRootRank) { + dh::safe_nccl(ncclGetUniqueId(&id)); + } + communicator_->Broadcast(static_cast(&id), sizeof(ncclUniqueId), + static_cast(kRootRank)); + return id; + } + + int const device_ordinal_; + Communicator *communicator_; + ncclComm_t nccl_comm_{}; + cudaStream_t cuda_stream_{}; + ncclUniqueId nccl_unique_id_{}; + size_t allreduce_bytes_{0}; // Keep statistics of the number of bytes communicated. + size_t allreduce_calls_{0}; // Keep statistics of the number of reduce calls. +}; + +} // namespace collective +} // namespace xgboost diff --git a/src/collective/rabit_communicator.h b/src/collective/rabit_communicator.h new file mode 100644 index 000000000000..3b145e6577c5 --- /dev/null +++ b/src/collective/rabit_communicator.h @@ -0,0 +1,120 @@ +/*! 
+ * Copyright 2022 XGBoost contributors + */ +#pragma once +#include + +#include +#include + +#include "communicator.h" +#include "xgboost/json.h" + +namespace xgboost { +namespace collective { + +class RabitCommunicator : public Communicator { + public: + static Communicator *Create(Json const &config) { + std::vector args_str; + for (auto &items : get(config)) { + switch (items.second.GetValue().Type()) { + case xgboost::Value::ValueKind::kString: { + args_str.push_back(items.first + "=" + get(items.second)); + break; + } + case xgboost::Value::ValueKind::kInteger: { + args_str.push_back(items.first + "=" + std::to_string(get(items.second))); + break; + } + case xgboost::Value::ValueKind::kBoolean: { + if (get(items.second)) { + args_str.push_back(items.first + "=1"); + } else { + args_str.push_back(items.first + "=0"); + } + break; + } + default: + break; + } + } + std::vector args; + for (auto &key_value : args_str) { + args.push_back(&key_value[0]); + } + if (!rabit::Init(static_cast(args.size()), &args[0])) { + LOG(FATAL) << "Failed to initialize Rabit"; + } + return new RabitCommunicator(rabit::GetWorldSize(), rabit::GetRank()); + } + + RabitCommunicator(int world_size, int rank) : Communicator(world_size, rank) {} + + bool IsDistributed() const override { return rabit::IsDistributed(); } + + void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type, + Operation op) override { + switch (data_type) { + case DataType::kInt8: + DoAllReduce(send_receive_buffer, count, op); + break; + case DataType::kUInt8: + DoAllReduce(send_receive_buffer, count, op); + break; + case DataType::kInt32: + DoAllReduce(send_receive_buffer, count, op); + break; + case DataType::kUInt32: + DoAllReduce(send_receive_buffer, count, op); + break; + case DataType::kInt64: + DoAllReduce(send_receive_buffer, count, op); + break; + case DataType::kUInt64: + DoAllReduce(send_receive_buffer, count, op); + break; + case DataType::kFloat: + DoAllReduce(send_receive_buffer, 
count, op); + break; + case DataType::kDouble: + DoAllReduce(send_receive_buffer, count, op); + break; + default: + LOG(FATAL) << "Unknown data type"; + } + } + + void Broadcast(void *send_receive_buffer, std::size_t size, int root) override { + rabit::Broadcast(send_receive_buffer, size, root); + } + + std::string GetProcessorName() override { return rabit::GetProcessorName(); } + + void Print(const std::string &message) override { rabit::TrackerPrint(message); } + + protected: + void Shutdown() override { + rabit::Finalize(); + } + + private: + template + void DoAllReduce(void *send_receive_buffer, std::size_t count, Operation op) { + switch (op) { + case Operation::kMax: + rabit::Allreduce(static_cast(send_receive_buffer), count); + break; + case Operation::kMin: + rabit::Allreduce(static_cast(send_receive_buffer), count); + break; + case Operation::kSum: + rabit::Allreduce(static_cast(send_receive_buffer), count); + break; + default: + LOG(FATAL) << "Unknown allreduce operation"; + } + } +}; +} // namespace collective +} // namespace xgboost diff --git a/src/common/io.h b/src/common/io.h index b377623ea138..bcc6c4704a5d 100644 --- a/src/common/io.h +++ b/src/common/io.h @@ -12,6 +12,7 @@ #include #include #include +#include #include "common.h" @@ -111,6 +112,22 @@ inline std::string ReadAll(dmlc::Stream* fi, PeekableInStream* fp) { } return buffer; } + +/** + * \brief Read the whole file content into a string. 
+ */ +inline std::string ReadAll(std::string const &path) { + std::ifstream stream(path); + if (!stream.is_open()) { + LOG(FATAL) << "Could not open file " << path; + } + std::string content{std::istreambuf_iterator(stream), std::istreambuf_iterator()}; + if (content.empty()) { + LOG(FATAL) << "Empty file " << path; + } + return content; +} + } // namespace common } // namespace xgboost #endif // XGBOOST_COMMON_IO_H_ diff --git a/tests/cpp/CMakeLists.txt b/tests/cpp/CMakeLists.txt index 9dfd429d01e2..51cdecd9d4be 100644 --- a/tests/cpp/CMakeLists.txt +++ b/tests/cpp/CMakeLists.txt @@ -22,7 +22,7 @@ if (PLUGIN_FEDERATED) target_include_directories(testxgboost PRIVATE ${xgboost_SOURCE_DIR}/plugin/federated) target_link_libraries(testxgboost PRIVATE federated_client) else (PLUGIN_FEDERATED) - file(GLOB_RECURSE FEDERATED_TEST_SOURCES "plugin/*_federated_*.cc") + file(GLOB_RECURSE FEDERATED_TEST_SOURCES "plugin/*_federated_*.*") list(REMOVE_ITEM TEST_SOURCES ${FEDERATED_TEST_SOURCES}) endif (PLUGIN_FEDERATED) diff --git a/tests/cpp/collective/test_communicator.cc b/tests/cpp/collective/test_communicator.cc new file mode 100644 index 000000000000..e66e38255345 --- /dev/null +++ b/tests/cpp/collective/test_communicator.cc @@ -0,0 +1,54 @@ +/*! 
+ * Copyright 2022 XGBoost contributors + */ +#include +#include + +#include "../../../src/collective/communicator.h" + +namespace xgboost { +namespace collective { + +TEST(CommunicatorFactory, TypeFromEnv) { + EXPECT_EQ(CommunicatorType::kUnknown, Communicator::GetTypeFromEnv()); + + dmlc::SetEnv("XGBOOST_COMMUNICATOR", "rabit"); + EXPECT_EQ(CommunicatorType::kRabit, Communicator::GetTypeFromEnv()); + + dmlc::SetEnv("XGBOOST_COMMUNICATOR", "Federated"); + EXPECT_EQ(CommunicatorType::kFederated, Communicator::GetTypeFromEnv()); + + dmlc::SetEnv("XGBOOST_COMMUNICATOR", "foo"); + EXPECT_THROW(Communicator::GetTypeFromEnv(), dmlc::Error); +} + +TEST(CommunicatorFactory, TypeFromArgs) { + Json config{JsonObject()}; + EXPECT_EQ(CommunicatorType::kUnknown, Communicator::GetTypeFromConfig(config)); + + config["xgboost_communicator"] = String("rabit"); + EXPECT_EQ(CommunicatorType::kRabit, Communicator::GetTypeFromConfig(config)); + + config["xgboost_communicator"] = String("federated"); + EXPECT_EQ(CommunicatorType::kFederated, Communicator::GetTypeFromConfig(config)); + + config["xgboost_communicator"] = String("foo"); + EXPECT_THROW(Communicator::GetTypeFromConfig(config), dmlc::Error); +} + +TEST(CommunicatorFactory, TypeFromArgsUpperCase) { + Json config{JsonObject()}; + EXPECT_EQ(CommunicatorType::kUnknown, Communicator::GetTypeFromConfig(config)); + + config["XGBOOST_COMMUNICATOR"] = String("rabit"); + EXPECT_EQ(CommunicatorType::kRabit, Communicator::GetTypeFromConfig(config)); + + config["XGBOOST_COMMUNICATOR"] = String("federated"); + EXPECT_EQ(CommunicatorType::kFederated, Communicator::GetTypeFromConfig(config)); + + config["XGBOOST_COMMUNICATOR"] = String("foo"); + EXPECT_THROW(Communicator::GetTypeFromConfig(config), dmlc::Error); +} + +} // namespace collective +} // namespace xgboost diff --git a/tests/cpp/collective/test_nccl_device_communicator.cu b/tests/cpp/collective/test_nccl_device_communicator.cu new file mode 100644 index 000000000000..47de054c6d4b 
--- /dev/null +++ b/tests/cpp/collective/test_nccl_device_communicator.cu @@ -0,0 +1,26 @@ +/*! + * Copyright 2022 XGBoost contributors + */ +#ifdef XGBOOST_USE_NCCL + +#include + +#include "../../../src/collective/nccl_device_communicator.cuh" + +namespace xgboost { +namespace collective { + +TEST(NcclDeviceCommunicatorSimpleTest, ThrowOnInvalidDeviceOrdinal) { + auto construct = []() { NcclDeviceCommunicator comm{-1, nullptr}; }; + EXPECT_THROW(construct(), dmlc::Error); +} + +TEST(NcclDeviceCommunicatorSimpleTest, ThrowOnInvalidCommunicator) { + auto construct = []() { NcclDeviceCommunicator comm{0, nullptr}; }; + EXPECT_THROW(construct(), dmlc::Error); +} + +} // namespace collective +} // namespace xgboost + +#endif diff --git a/tests/cpp/collective/test_rabit_communicator.cc b/tests/cpp/collective/test_rabit_communicator.cc new file mode 100644 index 000000000000..ba22d8fdb84f --- /dev/null +++ b/tests/cpp/collective/test_rabit_communicator.cc @@ -0,0 +1,39 @@ +/*! + * Copyright 2022 XGBoost contributors + */ +#include + +#include "../../../src/collective/rabit_communicator.h" + +namespace xgboost { +namespace collective { + +TEST(RabitCommunicatorSimpleTest, ThrowOnWorldSizeTooSmall) { + auto construct = []() { RabitCommunicator comm{0, 0}; }; + EXPECT_THROW(construct(), dmlc::Error); +} + +TEST(RabitCommunicatorSimpleTest, ThrowOnRankTooSmall) { + auto construct = []() { RabitCommunicator comm{1, -1}; }; + EXPECT_THROW(construct(), dmlc::Error); +} + +TEST(RabitCommunicatorSimpleTest, ThrowOnRankTooBig) { + auto construct = []() { RabitCommunicator comm{1, 1}; }; + EXPECT_THROW(construct(), dmlc::Error); +} + +TEST(RabitCommunicatorSimpleTest, GetWorldSizeAndRank) { + RabitCommunicator comm{6, 3}; + EXPECT_EQ(comm.GetWorldSize(), 6); + EXPECT_EQ(comm.GetRank(), 3); +} + +TEST(RabitCommunicatorSimpleTest, IsNotDistributed) { + RabitCommunicator comm{2, 1}; + // Rabit is only distributed with a tracker. 
+ EXPECT_FALSE(comm.IsDistributed()); +} + +} // namespace collective +} // namespace xgboost diff --git a/tests/cpp/plugin/test_federated_adapter.cu b/tests/cpp/plugin/test_federated_adapter.cu new file mode 100644 index 000000000000..09187f940c5f --- /dev/null +++ b/tests/cpp/plugin/test_federated_adapter.cu @@ -0,0 +1,105 @@ +/*! + * Copyright 2022 XGBoost contributors + */ +#include +#include +#include + +#include + +#include "../../../plugin/federated/federated_communicator.h" +#include "../../../plugin/federated/federated_server.h" +#include "../../../src/collective/device_communicator_adapter.cuh" + +namespace xgboost { +namespace collective { + +std::string const kServerAddress{"localhost:56789"}; // NOLINT(cert-err58-cpp) + +class FederatedAdapterTest : public ::testing::Test { + protected: + void SetUp() override { + server_thread_.reset(new std::thread([this] { + grpc::ServerBuilder builder; + federated::FederatedService service{kWorldSize}; + builder.AddListeningPort(kServerAddress, grpc::InsecureServerCredentials()); + builder.RegisterService(&service); + server_ = builder.BuildAndStart(); + server_->Wait(); + })); + } + + void TearDown() override { + server_->Shutdown(); + server_thread_->join(); + } + + static int const kWorldSize{2}; + std::unique_ptr server_thread_; + std::unique_ptr server_; +}; + +TEST(FederatedAdapterSimpleTest, ThrowOnInvalidDeviceOrdinal) { + auto construct = []() { DeviceCommunicatorAdapter adapter{-1, nullptr}; }; + EXPECT_THROW(construct(), dmlc::Error); +} + +TEST(FederatedAdapterSimpleTest, ThrowOnInvalidCommunicator) { + auto construct = []() { DeviceCommunicatorAdapter adapter{0, nullptr}; }; + EXPECT_THROW(construct(), dmlc::Error); +} + +TEST_F(FederatedAdapterTest, DeviceAllReduceSum) { + std::vector threads; + for (auto rank = 0; rank < kWorldSize; rank++) { + threads.emplace_back(std::thread([rank] { + FederatedCommunicator comm{kWorldSize, rank, kServerAddress}; + DeviceCommunicatorAdapter adapter{rank, &comm}; + 
int const count = 3; + thrust::device_vector buffer(count, 0); + thrust::sequence(buffer.begin(), buffer.end()); + adapter.AllReduceSum(buffer.data().get(), count); + thrust::host_vector host_buffer = buffer; + EXPECT_EQ(host_buffer.size(), count); + for (auto i = 0; i < count; i++) { + EXPECT_EQ(host_buffer[i], i * 2); + } + })); + } + for (auto& thread : threads) { + thread.join(); + } +} + +TEST_F(FederatedAdapterTest, DeviceAllGatherV) { + std::vector threads; + for (auto rank = 0; rank < kWorldSize; rank++) { + threads.emplace_back(std::thread([rank] { + FederatedCommunicator comm{kWorldSize, rank, kServerAddress}; + DeviceCommunicatorAdapter adapter{rank, &comm}; + + int const count = rank + 2; + thrust::device_vector buffer(count, 0); + thrust::sequence(buffer.begin(), buffer.end()); + std::vector segments(kWorldSize); + dh::caching_device_vector receive_buffer{}; + + adapter.AllGatherV(buffer.data().get(), count, &segments, &receive_buffer); + + EXPECT_EQ(segments[0], 2); + EXPECT_EQ(segments[1], 3); + thrust::host_vector host_buffer = receive_buffer; + EXPECT_EQ(host_buffer.size(), 5); + int expected[] = {0, 1, 0, 1, 2}; + for (auto i = 0; i < 5; i++) { + EXPECT_EQ(host_buffer[i], expected[i]); + } + })); + } + for (auto& thread : threads) { + thread.join(); + } +} + +} // namespace collective +} // namespace xgboost diff --git a/tests/cpp/plugin/test_federated_communicator.cc b/tests/cpp/plugin/test_federated_communicator.cc new file mode 100644 index 000000000000..2d9f233db573 --- /dev/null +++ b/tests/cpp/plugin/test_federated_communicator.cc @@ -0,0 +1,119 @@ +/*! 
+ * Copyright 2022 XGBoost contributors + */ +#include +#include +#include + +#include + +#include "../../../plugin/federated/federated_communicator.h" +#include "../../../plugin/federated/federated_server.h" + +namespace xgboost { +namespace collective { + +std::string const kServerAddress{"localhost:56789"}; // NOLINT(cert-err58-cpp) + +class FederatedCommunicatorTest : public ::testing::Test { + public: + static void VerifyAllreduce(int rank) { + FederatedCommunicator comm{kWorldSize, rank, kServerAddress}; + CheckAllreduce(comm); + } + + static void VerifyBroadcast(int rank) { + FederatedCommunicator comm{kWorldSize, rank, kServerAddress}; + CheckBroadcast(comm, rank); + } + + protected: + void SetUp() override { + server_thread_.reset(new std::thread([this] { + grpc::ServerBuilder builder; + federated::FederatedService service{kWorldSize}; + builder.AddListeningPort(kServerAddress, grpc::InsecureServerCredentials()); + builder.RegisterService(&service); + server_ = builder.BuildAndStart(); + server_->Wait(); + })); + } + + void TearDown() override { + server_->Shutdown(); + server_thread_->join(); + } + + static void CheckAllreduce(FederatedCommunicator &comm) { + int buffer[] = {1, 2, 3, 4, 5}; + comm.AllReduce(buffer, sizeof(buffer) / sizeof(buffer[0]), DataType::kInt32, Operation::kSum); + int expected[] = {3, 6, 9, 12, 15}; + for (auto i = 0; i < 5; i++) { + EXPECT_EQ(buffer[i], expected[i]); + } + } + + static void CheckBroadcast(FederatedCommunicator &comm, int rank) { + if (rank == 0) { + std::string buffer{"hello"}; + comm.Broadcast(&buffer[0], buffer.size(), 0); + EXPECT_EQ(buffer, "hello"); + } else { + std::string buffer{" "}; + comm.Broadcast(&buffer[0], buffer.size(), 0); + EXPECT_EQ(buffer, "hello"); + } + } + + static int const kWorldSize{3}; + std::unique_ptr server_thread_; + std::unique_ptr server_; +}; + +TEST(FederatedCommunicatorSimpleTest, ThrowOnWorldSizeTooSmall) { + auto construct = []() { FederatedCommunicator comm{0, 0, 
kServerAddress, "", "", ""}; }; + EXPECT_THROW(construct(), dmlc::Error); +} + +TEST(FederatedCommunicatorSimpleTest, ThrowOnRankTooSmall) { + auto construct = []() { FederatedCommunicator comm{1, -1, kServerAddress, "", "", ""}; }; + EXPECT_THROW(construct(), dmlc::Error); +} + +TEST(FederatedCommunicatorSimpleTest, ThrowOnRankTooBig) { + auto construct = []() { FederatedCommunicator comm{1, 1, kServerAddress, "", "", ""}; }; + EXPECT_THROW(construct(), dmlc::Error); +} + +TEST(FederatedCommunicatorSimpleTest, GetWorldSizeAndRank) { + FederatedCommunicator comm{6, 3, kServerAddress}; + EXPECT_EQ(comm.GetWorldSize(), 6); + EXPECT_EQ(comm.GetRank(), 3); +} + +TEST(FederatedCommunicatorSimpleTest, IsDistributed) { + FederatedCommunicator comm{2, 1, kServerAddress}; + EXPECT_TRUE(comm.IsDistributed()); +} + +TEST_F(FederatedCommunicatorTest, Allreduce) { + std::vector threads; + for (auto rank = 0; rank < kWorldSize; rank++) { + threads.emplace_back(std::thread(&FederatedCommunicatorTest::VerifyAllreduce, rank)); + } + for (auto &thread : threads) { + thread.join(); + } +} + +TEST_F(FederatedCommunicatorTest, Broadcast) { + std::vector threads; + for (auto rank = 0; rank < kWorldSize; rank++) { + threads.emplace_back(std::thread(&FederatedCommunicatorTest::VerifyBroadcast, rank)); + } + for (auto &thread : threads) { + thread.join(); + } +} +} // namespace collective +} // namespace xgboost diff --git a/tests/cpp/plugin/test_federated_server.cc b/tests/cpp/plugin/test_federated_server.cc index b20c17b09de5..1c3e4f0bc84c 100644 --- a/tests/cpp/plugin/test_federated_server.cc +++ b/tests/cpp/plugin/test_federated_server.cc @@ -62,7 +62,7 @@ class FederatedServerTest : public ::testing::Test { static void CheckAllreduce(federated::FederatedClient& client) { int data[] = {1, 2, 3, 4, 5}; std::string send_buffer(reinterpret_cast(data), sizeof(data)); - auto reply = client.Allreduce(send_buffer, federated::INT, federated::SUM); + auto reply = client.Allreduce(send_buffer, 
federated::INT32, federated::SUM); auto const* result = reinterpret_cast(reply.data()); int expected[] = {3, 6, 9, 12, 15}; for (auto i = 0; i < 5; i++) { diff --git a/tests/distributed/test_federated.py b/tests/distributed/test_federated.py index d83d7f2292fe..ea5a3a0f35cf 100644 --- a/tests/distributed/test_federated.py +++ b/tests/distributed/test_federated.py @@ -22,6 +22,7 @@ def run_server(port: int, world_size: int, with_ssl: bool) -> None: def run_worker(port: int, world_size: int, rank: int, with_ssl: bool, with_gpu: bool) -> None: rabit_env = [ + 'xgboost_communicator=federated', f'federated_server_address=localhost:{port}', f'federated_world_size={world_size}', f'federated_rank={rank}' diff --git a/tests/python/test_collective.py b/tests/python/test_collective.py new file mode 100644 index 000000000000..1b9727ebf05b --- /dev/null +++ b/tests/python/test_collective.py @@ -0,0 +1,39 @@ +import multiprocessing +import socket +import sys + +import numpy as np +import pytest + +import xgboost as xgb +from xgboost import RabitTracker +from xgboost import collective + +if sys.platform.startswith("win"): + pytest.skip("Skipping collective tests on Windows", allow_module_level=True) + + +def run_rabit_worker(rabit_env, world_size): + with xgb.collective.CommunicatorContext(**rabit_env): + assert xgb.collective.get_world_size() == world_size + assert xgb.collective.is_distributed() + assert xgb.collective.get_processor_name() == socket.gethostname() + ret = xgb.collective.broadcast('test1234', 0) + assert str(ret) == 'test1234' + ret = xgb.collective.allreduce(np.asarray([1, 2, 3]), xgb.collective.Op.SUM) + assert np.array_equal(ret, np.asarray([2, 4, 6])) + + +def test_rabit_communicator(): + world_size = 2 + tracker = RabitTracker(host_ip='127.0.0.1', n_workers=world_size) + tracker.start(world_size) + workers = [] + for _ in range(world_size): + worker = multiprocessing.Process(target=run_rabit_worker, + args=(tracker.worker_envs(), world_size)) + 
workers.append(worker) + worker.start() + for worker in workers: + worker.join() + assert worker.exitcode == 0