Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Breaking] Switch from rabit to the collective communicator #8257

Merged
merged 29 commits into master from switch-to-communicator on Oct 5, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
2d6e0dc
Switch from rabit to the collective communicator
rongou Sep 21, 2022
c72fbca
Merge remote-tracking branch 'upstream/master' into switch-to-communi…
rongou Sep 21, 2022
07ce6f1
fix size_t specialization
rongou Sep 22, 2022
f888985
Merge remote-tracking branch 'upstream/master' into switch-to-communi…
rongou Sep 22, 2022
9db6d96
really fix size_t
rongou Sep 22, 2022
96a232e
try again
rongou Sep 22, 2022
5853f83
add include
rongou Sep 22, 2022
3482b53
more include
rongou Sep 22, 2022
2ae73c0
fix lint errors
rongou Sep 22, 2022
2c594ca
remove rabit includes
rongou Sep 23, 2022
0f2cdd1
fix pylint error
rongou Sep 23, 2022
217e3a5
return dict from communicator context
rongou Sep 23, 2022
aa731a0
fix communicator shutdown
rongou Sep 23, 2022
8561853
fix dask test
rongou Sep 23, 2022
747093c
Merge remote-tracking branch 'upstream/master' into switch-to-communi…
rongou Sep 26, 2022
c89df47
reset communicator mocklist
rongou Sep 26, 2022
01e19ab
Merge remote-tracking branch 'upstream/master' into switch-to-communi…
rongou Sep 27, 2022
8c00deb
Merge remote-tracking branch 'upstream/master' into switch-to-communi…
rongou Sep 28, 2022
c814ac1
Merge remote-tracking branch 'upstream/master' into switch-to-communi…
rongou Sep 29, 2022
a9479d9
fix distributed tests
rongou Sep 29, 2022
af56b1a
do not save device communicator
rongou Sep 29, 2022
501d06f
Merge remote-tracking branch 'upstream/master' into switch-to-communi…
rongou Sep 30, 2022
9be13f8
fix jvm gpu tests
rongou Sep 30, 2022
44a49fd
add python test for federated communicator
rongou Sep 30, 2022
0aff6c0
Merge remote-tracking branch 'upstream/master' into switch-to-communi…
rongou Oct 3, 2022
2e3dc6c
Merge remote-tracking branch 'upstream/master' into switch-to-communi…
hcho3 Oct 4, 2022
27a0f85
Merge remote-tracking branch 'upstream/master' into switch-to-communi…
rongou Oct 4, 2022
1b7c938
Merge remote-tracking branch 'upstream/master' into switch-to-communi…
rongou Oct 5, 2022
43be5ac
Update gputreeshap submodule
hcho3 Oct 5, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
20 changes: 10 additions & 10 deletions demo/nvflare/custom/trainer.py
Expand Up @@ -52,15 +52,15 @@ def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext,
def _do_training(self, fl_ctx: FLContext):
client_name = fl_ctx.get_prop(FLContextKey.CLIENT_NAME)
rank = int(client_name.split('-')[1]) - 1
rabit_env = [
f'federated_server_address={self._server_address}',
f'federated_world_size={self._world_size}',
f'federated_rank={rank}',
f'federated_server_cert={self._server_cert_path}',
f'federated_client_key={self._client_key_path}',
f'federated_client_cert={self._client_cert_path}'
]
with xgb.rabit.RabitContext([e.encode() for e in rabit_env]):
communicator_env = {
'federated_server_address': self._server_address,
'federated_world_size': self._world_size,
'federated_rank': rank,
'federated_server_cert': self._server_cert_path,
'federated_client_key': self._client_key_path,
'federated_client_cert': self._client_cert_path
}
with xgb.collective.CommunicatorContext(**communicator_env):
# Load file, file will not be sharded in federated mode.
dtrain = xgb.DMatrix('agaricus.txt.train')
dtest = xgb.DMatrix('agaricus.txt.test')
Expand All @@ -86,4 +86,4 @@ def _do_training(self, fl_ctx: FLContext):
run_number = fl_ctx.get_prop(FLContextKey.CURRENT_RUN)
run_dir = workspace.get_run_dir(run_number)
bst.save_model(os.path.join(run_dir, "test.model.json"))
xgb.rabit.tracker_print("Finished training\n")
xgb.collective.communicator_print("Finished training\n")
Expand Up @@ -19,7 +19,7 @@ package ml.dmlc.xgboost4j.scala.flink
import scala.collection.JavaConverters.asScalaIteratorConverter

import ml.dmlc.xgboost4j.LabeledPoint
import ml.dmlc.xgboost4j.java.{Rabit, RabitTracker}
import ml.dmlc.xgboost4j.java.{Communicator, RabitTracker}
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => XGBoostScala}

import org.apache.commons.logging.LogFactory
Expand All @@ -46,7 +46,7 @@ object XGBoost {
collector: Collector[XGBoostModel]): Unit = {
workerEnvs.put("DMLC_TASK_ID", String.valueOf(this.getRuntimeContext.getIndexOfThisSubtask))
logger.info("start with env" + workerEnvs.toString)
Rabit.init(workerEnvs)
Communicator.init(workerEnvs)
val mapper = (x: LabeledVector) => {
val (index, value) = x.vector.toSeq.unzip
LabeledPoint(x.label.toFloat, x.vector.size, index.toArray, value.map(_.toFloat).toArray)
Expand All @@ -59,7 +59,7 @@ object XGBoost {
.map(_.toString.toInt).getOrElse(0)
val booster = XGBoostScala.train(trainMat, paramMap, round, watches,
earlyStoppingRound = numEarlyStoppingRounds)
Rabit.shutdown()
Communicator.shutdown()
collector.collect(new XGBoostModel(booster))
}
}
Expand Down
Expand Up @@ -22,7 +22,7 @@ import java.util.ServiceLoader
import scala.collection.JavaConverters._
import scala.collection.{AbstractIterator, Iterator, mutable}

import ml.dmlc.xgboost4j.java.Rabit
import ml.dmlc.xgboost4j.java.Communicator
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix}
import ml.dmlc.xgboost4j.scala.spark.util.DataUtils.PackedParams
import ml.dmlc.xgboost4j.scala.spark.params.XGBoostEstimatorCommon
Expand Down Expand Up @@ -266,7 +266,7 @@ object PreXGBoost extends PreXGBoostProvider {
if (batchCnt == 0) {
val rabitEnv = Array(
"DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
Rabit.init(rabitEnv.asJava)
Communicator.init(rabitEnv.asJava)
}

val features = batchRow.iterator.map(row => row.getAs[Vector](featuresCol))
Expand Down Expand Up @@ -298,7 +298,7 @@ object PreXGBoost extends PreXGBoostProvider {
override def next(): Row = {
val ret = batchIterImpl.next()
if (!batchIterImpl.hasNext) {
Rabit.shutdown()
Communicator.shutdown()
}
ret
}
Expand Down
Expand Up @@ -22,7 +22,7 @@ import scala.collection.mutable
import scala.util.Random
import scala.collection.JavaConverters._

import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, XGBoostError, RabitTracker => PyRabitTracker}
import ml.dmlc.xgboost4j.java.{Communicator, IRabitTracker, XGBoostError, RabitTracker => PyRabitTracker}
trivialfis marked this conversation as resolved.
Show resolved Hide resolved
import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
import ml.dmlc.xgboost4j.scala.spark.params.LearningTaskParams
import ml.dmlc.xgboost4j.scala.ExternalCheckpointManager
Expand Down Expand Up @@ -303,7 +303,7 @@ object XGBoost extends Serializable {
val makeCheckpoint = xgbExecutionParam.checkpointParam.isDefined && taskId.toInt == 0

try {
Rabit.init(rabitEnv)
Communicator.init(rabitEnv)

watches = buildWatchesAndCheck(buildWatches)

Expand Down Expand Up @@ -342,7 +342,7 @@ object XGBoost extends Serializable {
logger.error(s"XGBooster worker $taskId has failed $attempt times due to ", xgbException)
throw xgbException
} finally {
Rabit.shutdown()
Communicator.shutdown()
if (watches != null) watches.delete()
}
}
Expand Down

This file was deleted.