From 1ff1562c9aed8d86322b6a963dfcad05ebf3923b Mon Sep 17 00:00:00 2001
From: Rong Ou
Date: Tue, 27 Sep 2022 14:05:51 -0700
Subject: [PATCH] don't split input data in federated mode

---
 plugin/federated/federated_communicator.h |  6 ++++++
 src/c_api/c_api.cc                        | 12 ++++--------
 src/collective/communicator-inl.h         |  7 +++++++
 src/collective/communicator.h             |  3 +++
 src/collective/noop_communicator.h        |  1 +
 src/collective/rabit_communicator.h       |  2 ++
 6 files changed, 23 insertions(+), 8 deletions(-)

diff --git a/plugin/federated/federated_communicator.h b/plugin/federated/federated_communicator.h
index 9defef719bba..cb1eb0b8109f 100644
--- a/plugin/federated/federated_communicator.h
+++ b/plugin/federated/federated_communicator.h
@@ -170,6 +170,12 @@ class FederatedCommunicator : public Communicator {
    */
   bool IsDistributed() const override { return true; }
 
+  /**
+   * \brief Get if the communicator is federated.
+   * \return True.
+   */
+  bool IsFederated() const override { return true; }
+
   /**
    * \brief Perform in-place allreduce.
    * \param send_receive_buffer Buffer for both sending and receiving data.
diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc
index 63e8c53898a3..87a30283f925 100644
--- a/src/c_api/c_api.cc
+++ b/src/c_api/c_api.cc
@@ -208,16 +208,12 @@ XGB_DLL int XGBGetGlobalConfig(const char** json_str) {
 XGB_DLL int XGDMatrixCreateFromFile(const char *fname, int silent, DMatrixHandle *out) {
   API_BEGIN();
   bool load_row_split = false;
-#if defined(XGBOOST_USE_FEDERATED)
-  LOG(CONSOLE) << "XGBoost federated mode detected, not splitting data among workers";
-#else
-  if (collective::IsDistributed()) {
-    LOG(CONSOLE) << "XGBoost distributed mode detected, "
-                 << "will split data among workers";
+  if (collective::IsFederated()) {
+    LOG(CONSOLE) << "XGBoost federated mode detected, not splitting data among workers";
+  } else if (collective::IsDistributed()) {
+    LOG(CONSOLE) << "XGBoost distributed mode detected, will split data among workers";
     load_row_split = true;
   }
-#endif
-
   xgboost_CHECK_C_ARG_PTR(fname);
   xgboost_CHECK_C_ARG_PTR(out);
   *out = new std::shared_ptr<DMatrix>(DMatrix::Load(fname, silent != 0, load_row_split));
diff --git a/src/collective/communicator-inl.h b/src/collective/communicator-inl.h
index 923c6d291ad3..f9fe8f18763a 100644
--- a/src/collective/communicator-inl.h
+++ b/src/collective/communicator-inl.h
@@ -88,6 +88,13 @@ inline int GetWorldSize() { return Communicator::Get()->GetWorldSize(); }
  */
 inline bool IsDistributed() { return Communicator::Get()->IsDistributed(); }
 
+/*!
+ * \brief Get if the communicator is federated.
+ *
+ * \return True if the communicator is federated.
+ */
+inline bool IsFederated() { return Communicator::Get()->IsFederated(); }
+
 /*!
  * \brief Print the message to the communicator.
  *
diff --git a/src/collective/communicator.h b/src/collective/communicator.h
index 9c0d98a7e8c2..ac9346c64e78 100644
--- a/src/collective/communicator.h
+++ b/src/collective/communicator.h
@@ -78,6 +78,9 @@ class Communicator {
   /** @brief Whether the communicator is running in distributed mode. */
   virtual bool IsDistributed() const = 0;
 
+  /** @brief Whether the communicator is running in federated mode. */
+  virtual bool IsFederated() const = 0;
+
   /**
    * @brief Combines values from all processes and distributes the result back to all processes.
    *
diff --git a/src/collective/noop_communicator.h b/src/collective/noop_communicator.h
index 7e5aaa026735..5d4fb89f409c 100644
--- a/src/collective/noop_communicator.h
+++ b/src/collective/noop_communicator.h
@@ -16,6 +16,7 @@ class NoOpCommunicator : public Communicator {
  public:
   NoOpCommunicator() : Communicator(1, 0) {}
   bool IsDistributed() const override { return false; }
+  bool IsFederated() const override { return false; }
   void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
                  Operation op) override {}
   void Broadcast(void *send_receive_buffer, std::size_t size, int root) override {}
diff --git a/src/collective/rabit_communicator.h b/src/collective/rabit_communicator.h
index 3b145e6577c5..d17cabc01512 100644
--- a/src/collective/rabit_communicator.h
+++ b/src/collective/rabit_communicator.h
@@ -53,6 +53,8 @@ class RabitCommunicator : public Communicator {
 
   bool IsDistributed() const override { return rabit::IsDistributed(); }
 
+  bool IsFederated() const override { return false; }
+
   void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
                  Operation op) override {
     switch (data_type) {