Skip to content

Commit

Permalink
[Breaking] Switch from rabit to the collective communicator (#8257)
Browse files Browse the repository at this point in the history
* Switch from rabit to the collective communicator

* fix size_t specialization

* really fix size_t

* try again

* add include

* more include

* fix lint errors

* remove rabit includes

* fix pylint error

* return dict from communicator context

* fix communicator shutdown

* fix dask test

* reset communicator mocklist

* fix distributed tests

* do not save device communicator

* fix jvm gpu tests

* add python test for federated communicator

* Update gputreeshap submodule

Co-authored-by: Hyunsu Philip Cho <chohyu01@cs.washington.edu>
  • Loading branch information
rongou and hcho3 committed Oct 5, 2022
1 parent e47b3a3 commit 668b8a0
Show file tree
Hide file tree
Showing 79 changed files with 801 additions and 2,208 deletions.
20 changes: 10 additions & 10 deletions demo/nvflare/custom/trainer.py
Expand Up @@ -52,15 +52,15 @@ def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext,
def _do_training(self, fl_ctx: FLContext):
client_name = fl_ctx.get_prop(FLContextKey.CLIENT_NAME)
rank = int(client_name.split('-')[1]) - 1
rabit_env = [
f'federated_server_address={self._server_address}',
f'federated_world_size={self._world_size}',
f'federated_rank={rank}',
f'federated_server_cert={self._server_cert_path}',
f'federated_client_key={self._client_key_path}',
f'federated_client_cert={self._client_cert_path}'
]
with xgb.rabit.RabitContext([e.encode() for e in rabit_env]):
communicator_env = {
'federated_server_address': self._server_address,
'federated_world_size': self._world_size,
'federated_rank': rank,
'federated_server_cert': self._server_cert_path,
'federated_client_key': self._client_key_path,
'federated_client_cert': self._client_cert_path
}
with xgb.collective.CommunicatorContext(**communicator_env):
# Load file, file will not be sharded in federated mode.
dtrain = xgb.DMatrix('agaricus.txt.train')
dtest = xgb.DMatrix('agaricus.txt.test')
Expand All @@ -86,4 +86,4 @@ def _do_training(self, fl_ctx: FLContext):
run_number = fl_ctx.get_prop(FLContextKey.CURRENT_RUN)
run_dir = workspace.get_run_dir(run_number)
bst.save_model(os.path.join(run_dir, "test.model.json"))
xgb.rabit.tracker_print("Finished training\n")
xgb.collective.communicator_print("Finished training\n")
Expand Up @@ -19,7 +19,7 @@ package ml.dmlc.xgboost4j.scala.flink
import scala.collection.JavaConverters.asScalaIteratorConverter

import ml.dmlc.xgboost4j.LabeledPoint
import ml.dmlc.xgboost4j.java.{Rabit, RabitTracker}
import ml.dmlc.xgboost4j.java.{Communicator, RabitTracker}
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => XGBoostScala}

import org.apache.commons.logging.LogFactory
Expand All @@ -46,7 +46,7 @@ object XGBoost {
collector: Collector[XGBoostModel]): Unit = {
workerEnvs.put("DMLC_TASK_ID", String.valueOf(this.getRuntimeContext.getIndexOfThisSubtask))
logger.info("start with env" + workerEnvs.toString)
Rabit.init(workerEnvs)
Communicator.init(workerEnvs)
val mapper = (x: LabeledVector) => {
val (index, value) = x.vector.toSeq.unzip
LabeledPoint(x.label.toFloat, x.vector.size, index.toArray, value.map(_.toFloat).toArray)
Expand All @@ -59,7 +59,7 @@ object XGBoost {
.map(_.toString.toInt).getOrElse(0)
val booster = XGBoostScala.train(trainMat, paramMap, round, watches,
earlyStoppingRound = numEarlyStoppingRounds)
Rabit.shutdown()
Communicator.shutdown()
collector.collect(new XGBoostModel(booster))
}
}
Expand Down
Expand Up @@ -22,7 +22,7 @@ import java.util.ServiceLoader
import scala.collection.JavaConverters._
import scala.collection.{AbstractIterator, Iterator, mutable}

import ml.dmlc.xgboost4j.java.Rabit
import ml.dmlc.xgboost4j.java.Communicator
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix}
import ml.dmlc.xgboost4j.scala.spark.util.DataUtils.PackedParams
import ml.dmlc.xgboost4j.scala.spark.params.XGBoostEstimatorCommon
Expand Down Expand Up @@ -266,7 +266,7 @@ object PreXGBoost extends PreXGBoostProvider {
if (batchCnt == 0) {
val rabitEnv = Array(
"DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
Rabit.init(rabitEnv.asJava)
Communicator.init(rabitEnv.asJava)
}

val features = batchRow.iterator.map(row => row.getAs[Vector](featuresCol))
Expand Down Expand Up @@ -298,7 +298,7 @@ object PreXGBoost extends PreXGBoostProvider {
override def next(): Row = {
val ret = batchIterImpl.next()
if (!batchIterImpl.hasNext) {
Rabit.shutdown()
Communicator.shutdown()
}
ret
}
Expand Down
Expand Up @@ -22,7 +22,7 @@ import scala.collection.mutable
import scala.util.Random
import scala.collection.JavaConverters._

import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, XGBoostError, RabitTracker => PyRabitTracker}
import ml.dmlc.xgboost4j.java.{Communicator, IRabitTracker, XGBoostError, RabitTracker => PyRabitTracker}
import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
import ml.dmlc.xgboost4j.scala.spark.params.LearningTaskParams
import ml.dmlc.xgboost4j.scala.ExternalCheckpointManager
Expand Down Expand Up @@ -303,7 +303,7 @@ object XGBoost extends Serializable {
val makeCheckpoint = xgbExecutionParam.checkpointParam.isDefined && taskId.toInt == 0

try {
Rabit.init(rabitEnv)
Communicator.init(rabitEnv)

watches = buildWatchesAndCheck(buildWatches)

Expand Down Expand Up @@ -342,7 +342,7 @@ object XGBoost extends Serializable {
logger.error(s"XGBooster worker $taskId has failed $attempt times due to ", xgbException)
throw xgbException
} finally {
Rabit.shutdown()
Communicator.shutdown()
if (watches != null) watches.delete()
}
}
Expand Down

This file was deleted.

0 comments on commit 668b8a0

Please sign in to comment.