Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Don't split input data in federated mode #8279

Merged
merged 2 commits into from Oct 6, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 6 additions & 0 deletions plugin/federated/federated_communicator.h
Expand Up @@ -170,6 +170,12 @@ class FederatedCommunicator : public Communicator {
*/
bool IsDistributed() const override { return true; }

/**
* \brief Get if the communicator is federated.
* \return True.
*/
bool IsFederated() const override { return true; }

/**
* \brief Perform in-place allreduce.
* \param send_receive_buffer Buffer for both sending and receiving data.
Expand Down
12 changes: 4 additions & 8 deletions src/c_api/c_api.cc
Expand Up @@ -208,16 +208,12 @@ XGB_DLL int XGBGetGlobalConfig(const char** json_str) {
XGB_DLL int XGDMatrixCreateFromFile(const char *fname, int silent, DMatrixHandle *out) {
API_BEGIN();
bool load_row_split = false;
#if defined(XGBOOST_USE_FEDERATED)
LOG(CONSOLE) << "XGBoost federated mode detected, not splitting data among workers";
#else
if (collective::IsDistributed()) {
LOG(CONSOLE) << "XGBoost distributed mode detected, "
<< "will split data among workers";
if (collective::IsFederated()) {
LOG(CONSOLE) << "XGBoost federated mode detected, not splitting data among workers";
} else if (collective::IsDistributed()) {
LOG(CONSOLE) << "XGBoost distributed mode detected, will split data among workers";
load_row_split = true;
}
#endif

xgboost_CHECK_C_ARG_PTR(fname);
xgboost_CHECK_C_ARG_PTR(out);
*out = new std::shared_ptr<DMatrix>(DMatrix::Load(fname, silent != 0, load_row_split));
Expand Down
7 changes: 7 additions & 0 deletions src/collective/communicator-inl.h
Expand Up @@ -88,6 +88,13 @@ inline int GetWorldSize() { return Communicator::Get()->GetWorldSize(); }
*/
inline bool IsDistributed() { return Communicator::Get()->IsDistributed(); }

/*!
* \brief Get if the communicator is federated.
*
* \return True if the communicator is federated.
*/
inline bool IsFederated() { return Communicator::Get()->IsFederated(); }

/*!
* \brief Print the message to the communicator.
*
Expand Down
3 changes: 3 additions & 0 deletions src/collective/communicator.h
Expand Up @@ -78,6 +78,9 @@ class Communicator {
/** @brief Whether the communicator is running in distributed mode. */
virtual bool IsDistributed() const = 0;

/** @brief Whether the communicator is running in federated mode. */
virtual bool IsFederated() const = 0;

/**
* @brief Combines values from all processes and distributes the result back to all processes.
*
Expand Down
1 change: 1 addition & 0 deletions src/collective/noop_communicator.h
Expand Up @@ -16,6 +16,7 @@ class NoOpCommunicator : public Communicator {
public:
NoOpCommunicator() : Communicator(1, 0) {}
bool IsDistributed() const override { return false; }
bool IsFederated() const override { return false; }
void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
Operation op) override {}
void Broadcast(void *send_receive_buffer, std::size_t size, int root) override {}
Expand Down
2 changes: 2 additions & 0 deletions src/collective/rabit_communicator.h
Expand Up @@ -53,6 +53,8 @@ class RabitCommunicator : public Communicator {

bool IsDistributed() const override { return rabit::IsDistributed(); }

bool IsFederated() const override { return false; }

void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
Operation op) override {
switch (data_type) {
Expand Down