From 2d6e0dc84b3aff89c7971a1d26ef07ddfa6f91b8 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Wed, 21 Sep 2022 16:19:50 -0700 Subject: [PATCH 01/18] Switch from rabit to the collective communicator --- demo/nvflare/custom/trainer.py | 20 +- .../dmlc/xgboost4j/scala/flink/XGBoost.scala | 6 +- .../xgboost4j/scala/spark/PreXGBoost.scala | 6 +- .../dmlc/xgboost4j/scala/spark/XGBoost.scala | 6 +- .../scala/spark/RabitRobustnessSuite.scala | 277 ---------- ... XGBoostCommunicatorRegressionSuite.scala} | 14 +- .../java/ml/dmlc/xgboost4j/java/Rabit.java | 154 ------ .../java/ml/dmlc/xgboost4j/java/XGBoost.java | 6 +- .../ml/dmlc/xgboost4j/java/XGBoostJNI.java | 13 - .../xgboost4j/src/native/xgboost4j.cpp | 105 ---- jvm-packages/xgboost4j/src/native/xgboost4j.h | 56 -- .../ml/dmlc/xgboost4j/java/DMatrixTest.java | 4 +- plugin/federated/CMakeLists.txt | 2 +- plugin/federated/engine_federated.cc | 197 ------- plugin/federated/federated_communicator.h | 34 ++ python-package/xgboost/__init__.py | 5 +- python-package/xgboost/callback.py | 16 +- python-package/xgboost/dask.py | 36 +- python-package/xgboost/rabit.py | 249 --------- python-package/xgboost/spark/core.py | 11 +- python-package/xgboost/spark/utils.py | 20 +- rabit/CMakeLists.txt | 4 +- src/c_api/c_api.cc | 27 +- src/cli_main.cc | 52 +- src/collective/communicator-inl.h | 198 +++++++ src/collective/communicator.cc | 7 +- src/collective/communicator.h | 34 -- src/collective/device_communicator.cuh | 16 +- .../device_communicator_adapter.cuh | 35 +- src/collective/nccl_device_communicator.cuh | 43 +- src/collective/noop_communicator.h | 28 + src/common/device_helpers.cu | 139 ----- src/common/device_helpers.cuh | 512 +----------------- src/common/quantile.cc | 29 +- src/common/quantile.cu | 23 +- src/common/quantile.cuh | 3 +- src/common/random.h | 8 +- src/common/timer.cc | 10 +- src/data/data.cc | 48 +- src/data/iterative_dmatrix.cc | 5 +- src/data/iterative_dmatrix.cu | 4 +- src/data/simple_dmatrix.cc | 4 +- src/data/simple_dmatrix.cu | 2 +- src/data/sparse_page_dmatrix.cc | 8 +- src/gbm/gbtree.cc | 2 +- src/learner.cc | 11 +- src/logging.cc | 4 +- src/metric/auc.cc | 27 +- src/metric/auc.cu | 12 +- src/metric/auc.h | 9 +- src/metric/elementwise_metric.cu | 6 +- src/metric/multiclass_metric.cu | 2 +- src/metric/rank_metric.cc | 15 +- src/metric/survival_metric.cu | 2 +- src/objective/adaptive.h | 7 +- src/tree/hist/histogram.h | 6 +- src/tree/updater_approx.cc | 4 +- src/tree/updater_colmaker.cc | 2 +- src/tree/updater_gpu_hist.cu | 46 +- src/tree/updater_quantile_hist.cc | 4 +- src/tree/updater_refresh.cc | 12 +- src/tree/updater_sync.cc | 14 +- tests/cpp/common/test_quantile.cu | 9 +- tests/distributed/distributed_gpu.py | 10 +- tests/distributed/test_basic.py | 6 +- tests/distributed/test_federated.py | 2 +- tests/distributed/test_issue3402.py | 2 +- tests/python/test_tracker.py | 22 +- 68 files changed, 659 insertions(+), 2053 deletions(-) delete mode 100644 jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/RabitRobustnessSuite.scala rename jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/{XGBoostRabitRegressionSuite.scala => XGBoostCommunicatorRegressionSuite.scala} (89%) delete mode 100644 jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Rabit.java delete mode 100644 plugin/federated/engine_federated.cc delete mode 100644 python-package/xgboost/rabit.py create mode 100644 src/collective/communicator-inl.h create mode 100644 src/collective/noop_communicator.h delete mode 100644 
src/common/device_helpers.cu diff --git a/demo/nvflare/custom/trainer.py b/demo/nvflare/custom/trainer.py index c19d9799f143..fd56da0e57fe 100644 --- a/demo/nvflare/custom/trainer.py +++ b/demo/nvflare/custom/trainer.py @@ -52,15 +52,15 @@ def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, def _do_training(self, fl_ctx: FLContext): client_name = fl_ctx.get_prop(FLContextKey.CLIENT_NAME) rank = int(client_name.split('-')[1]) - 1 - rabit_env = [ - f'federated_server_address={self._server_address}', - f'federated_world_size={self._world_size}', - f'federated_rank={rank}', - f'federated_server_cert={self._server_cert_path}', - f'federated_client_key={self._client_key_path}', - f'federated_client_cert={self._client_cert_path}' - ] - with xgb.rabit.RabitContext([e.encode() for e in rabit_env]): + communicator_env = { + 'federated_server_address': self._server_address, + 'federated_world_size': self._world_size, + 'federated_rank': rank, + 'federated_server_cert': self._server_cert_path, + 'federated_client_key': self._client_key_path, + 'federated_client_cert': self._client_cert_path + } + with xgb.collective.CommunicatorContext(**communicator_env): # Load file, file will not be sharded in federated mode. dtrain = xgb.DMatrix('agaricus.txt.train') dtest = xgb.DMatrix('agaricus.txt.test') @@ -86,4 +86,4 @@ def _do_training(self, fl_ctx: FLContext): run_number = fl_ctx.get_prop(FLContextKey.CURRENT_RUN) run_dir = workspace.get_run_dir(run_number) bst.save_model(os.path.join(run_dir, "test.model.json")) - xgb.rabit.tracker_print("Finished training\n") + xgb.collective.communicator_print("Finished training\n") diff --git a/jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/scala/flink/XGBoost.scala b/jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/scala/flink/XGBoost.scala index c9aa1631f02e..6878f1865bbe 100644 --- a/jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/scala/flink/XGBoost.scala +++ b/jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/scala/flink/XGBoost.scala @@ -19,7 +19,7 @@ package ml.dmlc.xgboost4j.scala.flink import scala.collection.JavaConverters.asScalaIteratorConverter import ml.dmlc.xgboost4j.LabeledPoint -import ml.dmlc.xgboost4j.java.{Rabit, RabitTracker} +import ml.dmlc.xgboost4j.java.{Communicator, RabitTracker} import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => XGBoostScala} import org.apache.commons.logging.LogFactory @@ -46,7 +46,7 @@ object XGBoost { collector: Collector[XGBoostModel]): Unit = { workerEnvs.put("DMLC_TASK_ID", String.valueOf(this.getRuntimeContext.getIndexOfThisSubtask)) logger.info("start with env" + workerEnvs.toString) - Rabit.init(workerEnvs) + Communicator.init(workerEnvs) val mapper = (x: LabeledVector) => { val (index, value) = x.vector.toSeq.unzip LabeledPoint(x.label.toFloat, x.vector.size, index.toArray, value.map(_.toFloat).toArray) @@ -59,7 +59,7 @@ object XGBoost { .map(_.toString.toInt).getOrElse(0) val booster = XGBoostScala.train(trainMat, paramMap, round, watches, earlyStoppingRound = numEarlyStoppingRounds) - Rabit.shutdown() + Communicator.shutdown() collector.collect(new XGBoostModel(booster)) } } diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/PreXGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/PreXGBoost.scala index 818842608ab8..176a54832859 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/PreXGBoost.scala +++ 
b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/PreXGBoost.scala @@ -22,7 +22,7 @@ import java.util.ServiceLoader import scala.collection.JavaConverters._ import scala.collection.{AbstractIterator, Iterator, mutable} -import ml.dmlc.xgboost4j.java.Rabit +import ml.dmlc.xgboost4j.java.Communicator import ml.dmlc.xgboost4j.scala.{Booster, DMatrix} import ml.dmlc.xgboost4j.scala.spark.util.DataUtils.PackedParams import ml.dmlc.xgboost4j.scala.spark.params.XGBoostEstimatorCommon @@ -266,7 +266,7 @@ object PreXGBoost extends PreXGBoostProvider { if (batchCnt == 0) { val rabitEnv = Array( "DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap - Rabit.init(rabitEnv.asJava) + Communicator.init(rabitEnv.asJava) } val features = batchRow.iterator.map(row => row.getAs[Vector](featuresCol)) @@ -298,7 +298,7 @@ object PreXGBoost extends PreXGBoostProvider { override def next(): Row = { val ret = batchIterImpl.next() if (!batchIterImpl.hasNext) { - Rabit.shutdown() + Communicator.shutdown() } ret } diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala index c49655217468..281997295850 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala @@ -22,7 +22,7 @@ import scala.collection.mutable import scala.util.Random import scala.collection.JavaConverters._ -import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, XGBoostError, RabitTracker => PyRabitTracker} +import ml.dmlc.xgboost4j.java.{Communicator, IRabitTracker, XGBoostError, RabitTracker => PyRabitTracker} import ml.dmlc.xgboost4j.scala.rabit.RabitTracker import ml.dmlc.xgboost4j.scala.spark.params.LearningTaskParams import ml.dmlc.xgboost4j.scala.ExternalCheckpointManager @@ -303,7 +303,7 @@ object XGBoost extends Serializable { val makeCheckpoint = xgbExecutionParam.checkpointParam.isDefined && taskId.toInt == 0 try { - Rabit.init(rabitEnv) + Communicator.init(rabitEnv) watches = buildWatchesAndCheck(buildWatches) @@ -342,7 +342,7 @@ object XGBoost extends Serializable { logger.error(s"XGBooster worker $taskId has failed $attempt times due to ", xgbException) throw xgbException } finally { - Rabit.shutdown() + Communicator.shutdown() if (watches != null) watches.delete() } } diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/RabitRobustnessSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/RabitRobustnessSuite.scala deleted file mode 100644 index 26ea2ef71595..000000000000 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/RabitRobustnessSuite.scala +++ /dev/null @@ -1,277 +0,0 @@ -/* - Copyright (c) 2014-2022 by Contributors - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
- */ - -package ml.dmlc.xgboost4j.scala.spark - -import java.util.concurrent.LinkedBlockingDeque - -import scala.util.Random - -import ml.dmlc.xgboost4j.java.{Rabit, RabitTracker => PyRabitTracker} -import ml.dmlc.xgboost4j.scala.rabit.{RabitTracker => ScalaRabitTracker} -import ml.dmlc.xgboost4j.java.IRabitTracker.TrackerStatus -import ml.dmlc.xgboost4j.scala.DMatrix -import org.scalatest.{FunSuite} - -class RabitRobustnessSuite extends FunSuite with PerTest { - - private def getXGBoostExecutionParams(paramMap: Map[String, Any]): XGBoostExecutionParams = { - val classifier = new XGBoostClassifier(paramMap) - val xgbParamsFactory = new XGBoostExecutionParamsFactory(classifier.MLlib2XGBoostParams, sc) - xgbParamsFactory.buildXGBRuntimeParams - } - - - test("Customize host ip and python exec for Rabit tracker") { - val hostIp = "192.168.22.111" - val pythonExec = "/usr/bin/python3" - - val paramMap = Map( - "num_workers" -> numWorkers, - "tracker_conf" -> TrackerConf(0L, "python", hostIp)) - val xgbExecParams = getXGBoostExecutionParams(paramMap) - val tracker = XGBoost.getTracker(xgbExecParams.numWorkers, xgbExecParams.trackerConf) - tracker match { - case pyTracker: PyRabitTracker => - val cmd = pyTracker.getRabitTrackerCommand - assert(cmd.contains(hostIp)) - assert(cmd.startsWith("python")) - case _ => assert(false, "expected python tracker implementation") - } - - val paramMap1 = Map( - "num_workers" -> numWorkers, - "tracker_conf" -> TrackerConf(0L, "python", "", pythonExec)) - val xgbExecParams1 = getXGBoostExecutionParams(paramMap1) - val tracker1 = XGBoost.getTracker(xgbExecParams1.numWorkers, xgbExecParams1.trackerConf) - tracker1 match { - case pyTracker: PyRabitTracker => - val cmd = pyTracker.getRabitTrackerCommand - assert(cmd.startsWith(pythonExec)) - assert(!cmd.contains(hostIp)) - case _ => assert(false, "expected python tracker implementation") - } - - val paramMap2 = Map( - "num_workers" -> numWorkers, - "tracker_conf" -> TrackerConf(0L, "python", hostIp, pythonExec)) - val xgbExecParams2 = getXGBoostExecutionParams(paramMap2) - val tracker2 = XGBoost.getTracker(xgbExecParams2.numWorkers, xgbExecParams2.trackerConf) - tracker2 match { - case pyTracker: PyRabitTracker => - val cmd = pyTracker.getRabitTrackerCommand - assert(cmd.startsWith(pythonExec)) - assert(cmd.contains(s" --host-ip=${hostIp}")) - case _ => assert(false, "expected python tracker implementation") - } - } - - test("training with Scala-implemented Rabit tracker") { - val eval = new EvalError() - val training = buildDataFrame(Classification.train) - val testDM = new DMatrix(Classification.test.iterator) - val paramMap = Map("eta" -> "1", "max_depth" -> "6", - "objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> numWorkers, - "tracker_conf" -> TrackerConf(60 * 60 * 1000, "scala")) - val model = new XGBoostClassifier(paramMap).fit(training) - assert(eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM) < 0.1) - } - - test("test Rabit allreduce to validate Scala-implemented Rabit tracker") { - val vectorLength = 100 - val rdd = sc.parallelize( - (1 to numWorkers * vectorLength).toArray.map { _ => Random.nextFloat() }, numWorkers).cache() - - val tracker = new ScalaRabitTracker(numWorkers) - tracker.start(0) - val trackerEnvs = tracker.getWorkerEnvs - val collectedAllReduceResults = new LinkedBlockingDeque[Array[Float]]() - - val rawData = rdd.mapPartitions { iter => - Iterator(iter.toArray) - }.collect() - - val maxVec = (0 until vectorLength).toArray.map { j => - (0 until 
numWorkers).toArray.map { i => rawData(i)(j) }.max - } - - val allReduceResults = rdd.mapPartitions { iter => - Rabit.init(trackerEnvs) - val arr = iter.toArray - val results = Rabit.allReduce(arr, Rabit.OpType.MAX) - Rabit.shutdown() - Iterator(results) - }.cache() - - val sparkThread = new Thread() { - override def run(): Unit = { - allReduceResults.foreachPartition(() => _) - val byPartitionResults = allReduceResults.collect() - assert(byPartitionResults(0).length == vectorLength) - collectedAllReduceResults.put(byPartitionResults(0)) - } - } - sparkThread.start() - assert(tracker.waitFor(0L) == 0) - sparkThread.join() - - assert(collectedAllReduceResults.poll().sameElements(maxVec)) - } - - test("test Java RabitTracker wrapper's exception handling: it should not hang forever.") { - /* - Deliberately create new instances of SparkContext in each unit test to avoid reusing the - same thread pool spawned by the local mode of Spark. As these tests simulate worker crashes - by throwing exceptions, the crashed worker thread never calls Rabit.shutdown, and therefore - corrupts the internal state of the native Rabit C++ code. Calling Rabit.init() in subsequent - tests on a reentrant thread will crash the entire Spark application, an undesired side-effect - that should be avoided. - */ - val rdd = sc.parallelize(1 to numWorkers, numWorkers).cache() - - val tracker = new PyRabitTracker(numWorkers) - tracker.start(0) - val trackerEnvs = tracker.getWorkerEnvs - - val workerCount: Int = numWorkers - /* - Simulate worker crash events by creating dummy Rabit workers, and throw exceptions in the - last created worker. A cascading event chain will be triggered once the RuntimeException is - thrown: the thread running the dummy spark job (sparkThread) catches the exception and - delegates it to the UnCaughtExceptionHandler, which is the Rabit tracker itself. - - The Java RabitTracker class reacts to exceptions by killing the spawned process running - the Python tracker. If at least one Rabit worker has yet connected to the tracker before - it is killed, the resulted connection failure will trigger the Rabit worker to call - "exit(-1);" in the native C++ code, effectively ending the dummy Spark task. - - In cluster (standalone or YARN) mode of Spark, tasks are run in containers and thus are - isolated from each other. That is, one task calling "exit(-1);" has no effect on other tasks - running in separate containers. However, as unit tests are run in Spark local mode, in which - tasks are executed by threads belonging to the same process, one thread calling "exit(-1);" - ultimately kills the entire process, which also happens to host the Spark driver, causing - the entire Spark application to crash. - - To prevent unit tests from crashing, deterministic delays were introduced to make sure that - the exception is thrown at last, ideally after all worker connections have been established. - For the same reason, the Java RabitTracker class delays the killing of the Python tracker - process to ensure that pending worker connections are handled. - */ - val dummyTasks = rdd.mapPartitions { iter => - Rabit.init(trackerEnvs) - val index = iter.next() - Thread.sleep(100 + index * 10) - if (index == workerCount) { - // kill the worker by throwing an exception - throw new RuntimeException("Worker exception.") - } - Rabit.shutdown() - Iterator(index) - }.cache() - - val sparkThread = new Thread() { - override def run(): Unit = { - // forces a Spark job. 
- dummyTasks.foreachPartition(() => _) - } - } - - sparkThread.setUncaughtExceptionHandler(tracker) - sparkThread.start() - assert(tracker.waitFor(0) != 0) - } - - test("test Scala RabitTracker's exception handling: it should not hang forever.") { - val rdd = sc.parallelize(1 to numWorkers, numWorkers).cache() - - val tracker = new ScalaRabitTracker(numWorkers) - tracker.start(0) - val trackerEnvs = tracker.getWorkerEnvs - - val workerCount: Int = numWorkers - val dummyTasks = rdd.mapPartitions { iter => - Rabit.init(trackerEnvs) - val index = iter.next() - Thread.sleep(100 + index * 10) - if (index == workerCount) { - // kill the worker by throwing an exception - throw new RuntimeException("Worker exception.") - } - Rabit.shutdown() - Iterator(index) - }.cache() - - val sparkThread = new Thread() { - override def run(): Unit = { - // forces a Spark job. - dummyTasks.foreachPartition(() => _) - } - } - sparkThread.setUncaughtExceptionHandler(tracker) - sparkThread.start() - assert(tracker.waitFor(0L) == TrackerStatus.FAILURE.getStatusCode) - } - - test("test Scala RabitTracker's workerConnectionTimeout") { - val rdd = sc.parallelize(1 to numWorkers, numWorkers).cache() - - val tracker = new ScalaRabitTracker(numWorkers) - tracker.start(500) - val trackerEnvs = tracker.getWorkerEnvs - - val dummyTasks = rdd.mapPartitions { iter => - val index = iter.next() - // simulate that the first worker cannot connect to tracker due to network issues. - if (index != 1) { - Rabit.init(trackerEnvs) - Thread.sleep(1000) - Rabit.shutdown() - } - - Iterator(index) - }.cache() - - val sparkThread = new Thread() { - override def run(): Unit = { - // forces a Spark job. - dummyTasks.foreachPartition(() => _) - } - } - sparkThread.setUncaughtExceptionHandler(tracker) - sparkThread.start() - // should fail due to connection timeout - assert(tracker.waitFor(0L) == TrackerStatus.FAILURE.getStatusCode) - } - - test("should allow the dataframe containing rabit calls to be partially evaluated for" + - " multiple times (ISSUE-4406)") { - val paramMap = Map( - "eta" -> "1", - "max_depth" -> "6", - "silent" -> "1", - "objective" -> "binary:logistic") - val trainingDF = buildDataFrame(Classification.train) - val model = new XGBoostClassifier(paramMap ++ Array("num_round" -> 10, - "num_workers" -> numWorkers)).fit(trainingDF) - val prediction = model.transform(trainingDF) - // a partial evaluation of dataframe will cause rabit initialized but not shutdown in some - // threads - prediction.show() - // a full evaluation here will re-run init and shutdown all rabit proxy - // expecting no error - prediction.collect() - } -} diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRabitRegressionSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostCommunicatorRegressionSuite.scala similarity index 89% rename from jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRabitRegressionSuite.scala rename to jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostCommunicatorRegressionSuite.scala index 00a29681ca73..1094a89f111a 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRabitRegressionSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostCommunicatorRegressionSuite.scala @@ -16,7 +16,7 @@ package ml.dmlc.xgboost4j.scala.spark -import ml.dmlc.xgboost4j.java.Rabit +import ml.dmlc.xgboost4j.java.Communicator 
import ml.dmlc.xgboost4j.scala.Booster import scala.collection.JavaConverters._ @@ -25,7 +25,7 @@ import org.scalatest.FunSuite import org.apache.spark.SparkException -class XGBoostRabitRegressionSuite extends FunSuite with PerTest { +class XGBoostCommunicatorRegressionSuite extends FunSuite with PerTest { val predictionErrorMin = 0.00001f val maxFailure = 2; @@ -47,8 +47,8 @@ class XGBoostRabitRegressionSuite extends FunSuite with PerTest { val model2 = new XGBoostClassifier(xgbSettings ++ Map("rabit_ring_reduce_threshold" -> 1)) .fit(training) - assert(Rabit.rabitEnvs.asScala.size > 3) - Rabit.rabitEnvs.asScala.foreach( item => { + assert(Communicator.communicatorEnvs.asScala.size > 3) + Communicator.communicatorEnvs.asScala.foreach( item => { if (item._1.toString == "rabit_reduce_ring_mincount") assert(item._2 == "1") }) @@ -70,8 +70,8 @@ class XGBoostRabitRegressionSuite extends FunSuite with PerTest { val model2 = new XGBoostRegressor(xgbSettings ++ Map("rabit_ring_reduce_threshold" -> 1) ).fit(training) - assert(Rabit.rabitEnvs.asScala.size > 3) - Rabit.rabitEnvs.asScala.foreach( item => { + assert(Communicator.communicatorEnvs.asScala.size > 3) + Communicator.communicatorEnvs.asScala.foreach( item => { if (item._1.toString == "rabit_reduce_ring_mincount") assert(item._2 == "1") }) // check the equality of single instance prediction @@ -85,7 +85,7 @@ class XGBoostRabitRegressionSuite extends FunSuite with PerTest { test("test rabit timeout fail handle") { val training = buildDataFrame(Classification.train) // mock rank 0 failure during 8th allreduce synchronization - Rabit.mockList = Array("0,8,0,0").toList.asJava + Communicator.mockList = Array("0,8,0,0").toList.asJava intercept[SparkException] { new XGBoostClassifier(Map( diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Rabit.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Rabit.java deleted file mode 100644 index 7e019dc65ccc..000000000000 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Rabit.java +++ /dev/null @@ -1,154 +0,0 @@ -package ml.dmlc.xgboost4j.java; - -import java.io.Serializable; -import java.nio.ByteBuffer; -import java.nio.ByteOrder; -import java.util.LinkedList; -import java.util.List; -import java.util.Map; - -/** - * Rabit global class for synchronization. - */ -public class Rabit { - - public enum OpType implements Serializable { - MAX(0), MIN(1), SUM(2), BITWISE_OR(3); - - private int op; - - public int getOperand() { - return this.op; - } - - OpType(int op) { - this.op = op; - } - } - - public enum DataType implements Serializable { - CHAR(0, 1), UCHAR(1, 1), INT(2, 4), UNIT(3, 4), - LONG(4, 8), ULONG(5, 8), FLOAT(6, 4), DOUBLE(7, 8), - LONGLONG(8, 8), ULONGLONG(9, 8); - - private int enumOp; - private int size; - - public int getEnumOp() { - return this.enumOp; - } - - public int getSize() { - return this.size; - } - - DataType(int enumOp, int size) { - this.enumOp = enumOp; - this.size = size; - } - } - - private static void checkCall(int ret) throws XGBoostError { - if (ret != 0) { - throw new XGBoostError(XGBoostJNI.XGBGetLastError()); - } - } - // used as way to test/debug passed rabit init parameters - public static Map rabitEnvs; - public static List mockList = new LinkedList<>(); - /** - * Initialize the rabit library on current working thread. - * @param envs The additional environment variables to pass to rabit. 
- * @throws XGBoostError - */ - public static void init(Map envs) throws XGBoostError { - rabitEnvs = envs; - String[] args = new String[envs.size() + mockList.size()]; - int idx = 0; - for (java.util.Map.Entry e : envs.entrySet()) { - args[idx++] = e.getKey() + '=' + e.getValue(); - } - // pass list of rabit mock strings eg mock=0,1,0,0 - for(String mock : mockList) { - args[idx++] = "mock=" + mock; - } - checkCall(XGBoostJNI.RabitInit(args)); - } - - /** - * Shutdown the rabit engine in current working thread, equals to finalize. - * @throws XGBoostError - */ - public static void shutdown() throws XGBoostError { - checkCall(XGBoostJNI.RabitFinalize()); - } - - /** - * Print the message on rabit tracker. - * @param msg - * @throws XGBoostError - */ - public static void trackerPrint(String msg) throws XGBoostError { - checkCall(XGBoostJNI.RabitTrackerPrint(msg)); - } - - /** - * Get version number of current stored model in the thread. - * which means how many calls to CheckPoint we made so far. - * @return version Number. - * @throws XGBoostError - */ - public static int versionNumber() throws XGBoostError { - int[] out = new int[1]; - checkCall(XGBoostJNI.RabitVersionNumber(out)); - return out[0]; - } - - /** - * get rank of current thread. - * @return the rank. - * @throws XGBoostError - */ - public static int getRank() throws XGBoostError { - int[] out = new int[1]; - checkCall(XGBoostJNI.RabitGetRank(out)); - return out[0]; - } - - /** - * get world size of current job. - * @return the worldsize - * @throws XGBoostError - */ - public static int getWorldSize() throws XGBoostError { - int[] out = new int[1]; - checkCall(XGBoostJNI.RabitGetWorldSize(out)); - return out[0]; - } - - /** - * perform Allreduce on distributed float vectors using operator op. - * This implementation of allReduce does not support customized prepare function callback in the - * native code, as this function is meant for testing purposes only (to test the Rabit tracker.) - * - * @param elements local elements on distributed workers. - * @param op operator used for Allreduce. - * @return All-reduced float elements according to the given operator. 
- */ - public static float[] allReduce(float[] elements, OpType op) { - DataType dataType = DataType.FLOAT; - ByteBuffer buffer = ByteBuffer.allocateDirect(dataType.getSize() * elements.length) - .order(ByteOrder.nativeOrder()); - - for (float el : elements) { - buffer.putFloat(el); - } - buffer.flip(); - - XGBoostJNI.RabitAllreduce(buffer, elements.length, dataType.getEnumOp(), op.getOperand()); - float[] results = new float[elements.length]; - buffer.asFloatBuffer().get(results); - - return results; - } -} diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java index bd521dda0b08..75e18957ff9f 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java @@ -254,16 +254,16 @@ public static Booster trainAndSaveCheckpoint( } if (shouldEarlyStop(earlyStoppingRounds, iter, bestIteration)) { if (shouldPrint(params, iter)) { - Rabit.trackerPrint(String.format( + Communicator.communicatorPrint(String.format( "early stopping after %d rounds away from the best iteration", earlyStoppingRounds )); } break; } - if (Rabit.getRank() == 0 && shouldPrint(params, iter)) { + if (Communicator.getRank() == 0 && shouldPrint(params, iter)) { if (shouldPrint(params, iter)){ - Rabit.trackerPrint(evalInfo + '\n'); + Communicator.communicatorPrint(evalInfo + '\n'); } } } diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java index 72234f526b08..afe576598956 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java @@ -135,19 +135,6 @@ public final static native int XGBoosterDumpModelExWithFeatures( public final static native int XGBoosterSaveRabitCheckpoint(long handle); public final static native int XGBoosterGetNumFeature(long handle, long[] feature); - // rabit functions - public final static native int RabitInit(String[] args); - public final static native int RabitFinalize(); - public final static native int RabitTrackerPrint(String msg); - public final static native int RabitGetRank(int[] out); - public final static native int RabitGetWorldSize(int[] out); - public final static native int RabitVersionNumber(int[] out); - - // Perform Allreduce operation on data in sendrecvbuf. - // This JNI function does not support the callback function for data preparation yet. 
- final static native int RabitAllreduce(ByteBuffer sendrecvbuf, int count, - int enum_dtype, int enum_op); - // communicator functions public final static native int CommunicatorInit(String[] args); public final static native int CommunicatorFinalize(); diff --git a/jvm-packages/xgboost4j/src/native/xgboost4j.cpp b/jvm-packages/xgboost4j/src/native/xgboost4j.cpp index a89e0f07a341..749fa5b40cdb 100644 --- a/jvm-packages/xgboost4j/src/native/xgboost4j.cpp +++ b/jvm-packages/xgboost4j/src/native/xgboost4j.cpp @@ -872,111 +872,6 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetNumFea return ret; } -/* - * Class: ml_dmlc_xgboost4j_java_XGBoostJNI - * Method: RabitInit - * Signature: ([Ljava/lang/String;)I - */ -JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitInit - (JNIEnv *jenv, jclass jcls, jobjectArray jargs) { - std::vector args; - std::vector argv; - bst_ulong len = (bst_ulong)jenv->GetArrayLength(jargs); - for (bst_ulong i = 0; i < len; ++i) { - jstring arg = (jstring)jenv->GetObjectArrayElement(jargs, i); - const char *s = jenv->GetStringUTFChars(arg, 0); - args.push_back(std::string(s, jenv->GetStringLength(arg))); - if (s != nullptr) jenv->ReleaseStringUTFChars(arg, s); - if (args.back().length() == 0) args.pop_back(); - } - - for (size_t i = 0; i < args.size(); ++i) { - argv.push_back(&args[i][0]); - } - - if (RabitInit(args.size(), dmlc::BeginPtr(argv))) { - return 0; - } else { - return 1; - } -} - -/* - * Class: ml_dmlc_xgboost4j_java_XGBoostJNI - * Method: RabitFinalize - * Signature: ()I - */ -JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitFinalize - (JNIEnv *jenv, jclass jcls) { - if (RabitFinalize()) { - return 0; - } else { - return 1; - } -} - -/* - * Class: ml_dmlc_xgboost4j_java_XGBoostJNI - * Method: RabitTrackerPrint - * Signature: (Ljava/lang/String;)I - */ -JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitTrackerPrint - (JNIEnv *jenv, jclass jcls, jstring jmsg) { - std::string str(jenv->GetStringUTFChars(jmsg, 0), - jenv->GetStringLength(jmsg)); - JVM_CHECK_CALL(RabitTrackerPrint(str.c_str())); - return 0; -} - -/* - * Class: ml_dmlc_xgboost4j_java_XGBoostJNI - * Method: RabitGetRank - * Signature: ([I)I - */ -JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitGetRank - (JNIEnv *jenv, jclass jcls, jintArray jout) { - jint rank = RabitGetRank(); - jenv->SetIntArrayRegion(jout, 0, 1, &rank); - return 0; -} - -/* - * Class: ml_dmlc_xgboost4j_java_XGBoostJNI - * Method: RabitGetWorldSize - * Signature: ([I)I - */ -JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitGetWorldSize - (JNIEnv *jenv, jclass jcls, jintArray jout) { - jint out = RabitGetWorldSize(); - jenv->SetIntArrayRegion(jout, 0, 1, &out); - return 0; -} - -/* - * Class: ml_dmlc_xgboost4j_java_XGBoostJNI - * Method: RabitVersionNumber - * Signature: ([I)I - */ -JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitVersionNumber - (JNIEnv *jenv, jclass jcls, jintArray jout) { - jint out = RabitVersionNumber(); - jenv->SetIntArrayRegion(jout, 0, 1, &out); - return 0; -} - -/* - * Class: ml_dmlc_xgboost4j_java_XGBoostJNI - * Method: RabitAllreduce - * Signature: (Ljava/nio/ByteBuffer;III)I - */ -JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitAllreduce - (JNIEnv *jenv, jclass jcls, jobject jsendrecvbuf, jint jcount, jint jenum_dtype, jint jenum_op) { - void *ptr_sendrecvbuf = jenv->GetDirectBufferAddress(jsendrecvbuf); - JVM_CHECK_CALL(RabitAllreduce(ptr_sendrecvbuf, 
(size_t) jcount, jenum_dtype, jenum_op, NULL, NULL)); - - return 0; -} - /* * Class: ml_dmlc_xgboost4j_java_XGBoostJNI * Method: CommunicatorInit diff --git a/jvm-packages/xgboost4j/src/native/xgboost4j.h b/jvm-packages/xgboost4j/src/native/xgboost4j.h index 7baae983cf51..5afe92b524ab 100644 --- a/jvm-packages/xgboost4j/src/native/xgboost4j.h +++ b/jvm-packages/xgboost4j/src/native/xgboost4j.h @@ -279,62 +279,6 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSaveRabit JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetNumFeature (JNIEnv *, jclass, jlong, jlongArray); -/* - * Class: ml_dmlc_xgboost4j_java_XGBoostJNI - * Method: RabitInit - * Signature: ([Ljava/lang/String;)I - */ -JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitInit - (JNIEnv *, jclass, jobjectArray); - -/* - * Class: ml_dmlc_xgboost4j_java_XGBoostJNI - * Method: RabitFinalize - * Signature: ()I - */ -JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitFinalize - (JNIEnv *, jclass); - -/* - * Class: ml_dmlc_xgboost4j_java_XGBoostJNI - * Method: RabitTrackerPrint - * Signature: (Ljava/lang/String;)I - */ -JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitTrackerPrint - (JNIEnv *, jclass, jstring); - -/* - * Class: ml_dmlc_xgboost4j_java_XGBoostJNI - * Method: RabitGetRank - * Signature: ([I)I - */ -JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitGetRank - (JNIEnv *, jclass, jintArray); - -/* - * Class: ml_dmlc_xgboost4j_java_XGBoostJNI - * Method: RabitGetWorldSize - * Signature: ([I)I - */ -JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitGetWorldSize - (JNIEnv *, jclass, jintArray); - -/* - * Class: ml_dmlc_xgboost4j_java_XGBoostJNI - * Method: RabitVersionNumber - * Signature: ([I)I - */ -JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitVersionNumber - (JNIEnv *, jclass, jintArray); - -/* - * Class: ml_dmlc_xgboost4j_java_XGBoostJNI - * Method: RabitAllreduce - * Signature: (Ljava/nio/ByteBuffer;III)I - */ -JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitAllreduce - (JNIEnv *, jclass, jobject, jint, jint, jint); - /* * Class: ml_dmlc_xgboost4j_java_XGBoostJNI * Method: CommunicatorInit diff --git a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/DMatrixTest.java b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/DMatrixTest.java index 7ea1604c3fd4..cf174c6dd82d 100644 --- a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/DMatrixTest.java +++ b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/DMatrixTest.java @@ -300,7 +300,7 @@ public void testCreateFromDenseMatrixRef() throws XGBoostError { public void testTrainWithDenseMatrixRef() throws XGBoostError { Map rabitEnv = new HashMap<>(); rabitEnv.put("DMLC_TASK_ID", "0"); - Rabit.init(rabitEnv); + Communicator.init(rabitEnv); DMatrix trainMat = null; BigDenseMatrix data0 = null; try { @@ -348,7 +348,7 @@ public void testTrainWithDenseMatrixRef() throws XGBoostError { else if (data0 != null) { data0.dispose(); } - Rabit.shutdown(); + Communicator.shutdown(); } } diff --git a/plugin/federated/CMakeLists.txt b/plugin/federated/CMakeLists.txt index 39eac924dd89..24ba47abfb8e 100644 --- a/plugin/federated/CMakeLists.txt +++ b/plugin/federated/CMakeLists.txt @@ -23,6 +23,6 @@ target_sources(federated_client INTERFACE federated_client.h) target_link_libraries(federated_client INTERFACE federated_proto) # Rabit engine for Federated Learning. 
-target_sources(objxgboost PRIVATE federated_server.cc engine_federated.cc) +target_sources(objxgboost PRIVATE federated_server.cc) target_link_libraries(objxgboost PRIVATE federated_client) target_compile_definitions(objxgboost PUBLIC -DXGBOOST_USE_FEDERATED=1) diff --git a/plugin/federated/engine_federated.cc b/plugin/federated/engine_federated.cc deleted file mode 100644 index feb589e54548..000000000000 --- a/plugin/federated/engine_federated.cc +++ /dev/null @@ -1,197 +0,0 @@ -/*! - * Copyright 2022 XGBoost contributors - */ -#include -#include -#include - -#include "federated_client.h" -#include "rabit/internal/engine.h" -#include "rabit/internal/utils.h" - -namespace MPI { // NOLINT -// MPI data type to be compatible with existing MPI interface -class Datatype { - public: - size_t type_size; - explicit Datatype(size_t type_size) : type_size(type_size) {} -}; -} // namespace MPI - -namespace rabit { -namespace engine { - -/*! \brief implementation of engine using federated learning */ -class FederatedEngine : public IEngine { - public: - void Init(int argc, char *argv[]) { - // Parse environment variables first. - for (auto const &env_var : env_vars_) { - char const *value = getenv(env_var.c_str()); - if (value != nullptr) { - SetParam(env_var, value); - } - } - // Command line argument overrides. - for (int i = 0; i < argc; ++i) { - std::string const key_value = argv[i]; - auto const delimiter = key_value.find('='); - if (delimiter != std::string::npos) { - SetParam(key_value.substr(0, delimiter), key_value.substr(delimiter + 1)); - } - } - utils::Printf("Connecting to federated server %s, world size %d, rank %d", - server_address_.c_str(), world_size_, rank_); - if (server_cert_.empty() || client_key_.empty() || client_cert_.empty()) { - utils::Printf("Certificates not specified, turning off SSL."); - client_.reset(new xgboost::federated::FederatedClient(server_address_, rank_)); - } else { - client_.reset(new xgboost::federated::FederatedClient(server_address_, rank_, server_cert_, - client_key_, client_cert_)); - } - } - - void Finalize() { client_.reset(); } - - void Allgather(void *sendrecvbuf, size_t total_size, size_t slice_begin, size_t slice_end, - size_t size_prev_slice) override { - throw std::logic_error("FederatedEngine:: Allgather is not supported"); - } - - std::string Allgather(void *sendbuf, size_t total_size) { - std::string const send_buffer(reinterpret_cast(sendbuf), total_size); - return client_->Allgather(send_buffer); - } - - void Allreduce(void *sendrecvbuf, size_t type_nbytes, size_t count, ReduceFunction reducer, - PreprocFunction prepare_fun, void *prepare_arg) override { - throw std::logic_error("FederatedEngine:: Allreduce is not supported, use Allreduce_ instead"); - } - - void Allreduce(void *sendrecvbuf, size_t size, mpi::DataType dtype, mpi::OpType op) { - auto *buffer = reinterpret_cast(sendrecvbuf); - std::string const send_buffer(buffer, size); - auto const receive_buffer = - client_->Allreduce(send_buffer, static_cast(dtype), - static_cast(op)); - receive_buffer.copy(buffer, size); - } - - int GetRingPrevRank() const override { - throw std::logic_error("FederatedEngine:: GetRingPrevRank is not supported"); - } - - void Broadcast(void *sendrecvbuf, size_t size, int root) override { - if (world_size_ == 1) return; - auto *buffer = reinterpret_cast(sendrecvbuf); - std::string const send_buffer(buffer, size); - auto const receive_buffer = client_->Broadcast(send_buffer, root); - if (rank_ != root) { - receive_buffer.copy(buffer, size); - } - } - - int 
LoadCheckPoint() override { return 0; } - - void CheckPoint() override { version_number_ += 1; } - - int VersionNumber() const override { return version_number_; } - - /*! \brief get rank of current node */ - int GetRank() const override { return rank_; } - - /*! \brief get total number of */ - int GetWorldSize() const override { return world_size_; } - - /*! \brief whether it is distributed */ - bool IsDistributed() const override { return true; } - - /*! \brief get the host name of current node */ - std::string GetHost() const override { return "rank" + std::to_string(rank_); } - - void TrackerPrint(const std::string &msg) override { - // simply print information into the tracker - utils::Printf("%s", msg.c_str()); - } - - private: - void SetParam(std::string const &name, std::string const &val) { - if (!strcasecmp(name.c_str(), "FEDERATED_SERVER_ADDRESS")) { - server_address_ = val; - } else if (!strcasecmp(name.c_str(), "FEDERATED_WORLD_SIZE")) { - world_size_ = std::stoi(val); - } else if (!strcasecmp(name.c_str(), "FEDERATED_RANK")) { - rank_ = std::stoi(val); - } else if (!strcasecmp(name.c_str(), "FEDERATED_SERVER_CERT")) { - server_cert_ = ReadFile(val); - } else if (!strcasecmp(name.c_str(), "FEDERATED_CLIENT_KEY")) { - client_key_ = ReadFile(val); - } else if (!strcasecmp(name.c_str(), "FEDERATED_CLIENT_CERT")) { - client_cert_ = ReadFile(val); - } - } - - static std::string ReadFile(std::string const &path) { - auto stream = std::ifstream(path.data()); - std::ostringstream out; - out << stream.rdbuf(); - return out.str(); - } - - // clang-format off - std::vector const env_vars_{ - "FEDERATED_SERVER_ADDRESS", - "FEDERATED_WORLD_SIZE", - "FEDERATED_RANK", - "FEDERATED_SERVER_CERT", - "FEDERATED_CLIENT_KEY", - "FEDERATED_CLIENT_CERT" }; - // clang-format on - std::string server_address_{"localhost:9091"}; - int world_size_{1}; - int rank_{0}; - std::string server_cert_{}; - std::string client_key_{}; - std::string client_cert_{}; - std::unique_ptr client_{}; - int version_number_{0}; -}; - -// Singleton federated engine. -FederatedEngine engine; // NOLINT(cert-err58-cpp) - -/*! \brief initialize the synchronization module */ -bool Init(int argc, char *argv[]) { - try { - engine.Init(argc, argv); - return true; - } catch (std::exception const &e) { - fprintf(stderr, " failed in federated Init %s\n", e.what()); - return false; - } -} - -/*! \brief finalize synchronization module */ -bool Finalize() { - try { - engine.Finalize(); - return true; - } catch (const std::exception &e) { - fprintf(stderr, "failed in federated shutdown %s\n", e.what()); - return false; - } -} - -/*! \brief singleton method to get engine */ -IEngine *GetEngine() { return &engine; } - -// perform in-place allreduce, on sendrecvbuf -void Allreduce_(void *sendrecvbuf, size_t type_nbytes, size_t count, IEngine::ReduceFunction red, - mpi::DataType dtype, mpi::OpType op, IEngine::PreprocFunction prepare_fun, - void *prepare_arg) { - if (prepare_fun != nullptr) prepare_fun(prepare_arg); - if (engine.GetWorldSize() == 1) return; - engine.Allreduce(sendrecvbuf, type_nbytes * count, dtype, op); -} -} // namespace engine -} // namespace rabit diff --git a/plugin/federated/federated_communicator.h b/plugin/federated/federated_communicator.h index 6a3186b4f608..9defef719bba 100644 --- a/plugin/federated/federated_communicator.h +++ b/plugin/federated/federated_communicator.h @@ -11,6 +11,40 @@ namespace xgboost { namespace collective { +/** @brief Get the size of the data type. 
*/ +inline std::size_t GetTypeSize(DataType data_type) { + std::size_t size{0}; + switch (data_type) { + case DataType::kInt8: + size = sizeof(std::int8_t); + break; + case DataType::kUInt8: + size = sizeof(std::uint8_t); + break; + case DataType::kInt32: + size = sizeof(std::int32_t); + break; + case DataType::kUInt32: + size = sizeof(std::uint32_t); + break; + case DataType::kInt64: + size = sizeof(std::int64_t); + break; + case DataType::kUInt64: + size = sizeof(std::uint64_t); + break; + case DataType::kFloat: + size = sizeof(float); + break; + case DataType::kDouble: + size = sizeof(double); + break; + default: + LOG(FATAL) << "Unknown data type."; + } + return size; +} + /** * @brief A Federated Learning communicator class that handles collective communication. */ diff --git a/python-package/xgboost/__init__.py b/python-package/xgboost/__init__.py index 6c29de98d9dc..220093b47c4c 100644 --- a/python-package/xgboost/__init__.py +++ b/python-package/xgboost/__init__.py @@ -3,9 +3,8 @@ Contributors: https://github.com/dmlc/xgboost/blob/master/CONTRIBUTORS.md """ -from . import rabit # noqa from . import tracker # noqa -from . import dask +from . import collective, dask from .core import ( Booster, DataIter, @@ -63,4 +62,6 @@ "XGBRFRegressor", # dask "dask", + # collective + "collective", ] diff --git a/python-package/xgboost/callback.py b/python-package/xgboost/callback.py index 102d53ce4252..b1bf882b092d 100644 --- a/python-package/xgboost/callback.py +++ b/python-package/xgboost/callback.py @@ -13,7 +13,7 @@ from typing import Callable, List, Optional, Union, Dict, Tuple, TypeVar, cast, Sequence, Any import numpy -from . import rabit +from . import collective from .core import Booster, DMatrix, XGBoostError, _get_booster_layer_trees @@ -100,7 +100,7 @@ def _allreduce_metric(score: _ART) -> _ART: as final result. 
''' - world = rabit.get_world_size() + world = collective.get_world_size() assert world != 0 if world == 1: return score @@ -108,7 +108,7 @@ def _allreduce_metric(score: _ART) -> _ART: raise ValueError( 'xgboost.cv function should not be used in distributed environment.') arr = numpy.array([score]) - arr = rabit.allreduce(arr, rabit.Op.SUM) / world + arr = collective.allreduce(arr, collective.Op.SUM) / world return arr[0] @@ -485,7 +485,7 @@ def after_iteration(self, model: _Model, epoch: int, return False msg: str = f'[{epoch}]' - if rabit.get_rank() == self.printer_rank: + if collective.get_rank() == self.printer_rank: for data, metric in evals_log.items(): for metric_name, log in metric.items(): stdv: Optional[float] = None @@ -498,7 +498,7 @@ def after_iteration(self, model: _Model, epoch: int, msg += '\n' if (epoch % self.period) == 0 or self.period == 1: - rabit.tracker_print(msg) + collective.communicator_print(msg) self._latest = None else: # There is skipped message @@ -506,8 +506,8 @@ def after_iteration(self, model: _Model, epoch: int, return False def after_training(self, model: _Model) -> _Model: - if rabit.get_rank() == self.printer_rank and self._latest is not None: - rabit.tracker_print(self._latest) + if collective.get_rank() == self.printer_rank and self._latest is not None: + collective.communicator_print(self._latest) return model @@ -552,7 +552,7 @@ def after_iteration(self, model: _Model, epoch: int, path = os.path.join(self._path, self._name + '_' + str(epoch) + ('.pkl' if self._as_pickle else '.json')) self._epoch = 0 - if rabit.get_rank() == 0: + if collective.get_rank() == 0: if self._as_pickle: with open(path, 'wb') as fd: pickle.dump(model, fd) diff --git a/python-package/xgboost/dask.py b/python-package/xgboost/dask.py index 9a74d0143681..675c3c8a499e 100644 --- a/python-package/xgboost/dask.py +++ b/python-package/xgboost/dask.py @@ -58,7 +58,7 @@ import numpy -from . import config, rabit +from . 
import collective, config from ._typing import _T, FeatureNames, FeatureTypes from .callback import TrainingCallback from .compat import DataFrame, LazyLoader, concat, lazy_isinstance @@ -117,7 +117,7 @@ TrainReturnT = Dict[str, Any] # type:ignore __all__ = [ - "RabitContext", + "CommunicatorContext", "DaskDMatrix", "DaskDeviceQuantileDMatrix", "DaskXGBRegressor", @@ -163,7 +163,7 @@ def _try_start_tracker( if isinstance(addrs[0], tuple): host_ip = addrs[0][0] port = addrs[0][1] - rabit_context = RabitTracker( + rabit_tracker = RabitTracker( host_ip=get_host_ip(host_ip), n_workers=n_workers, port=port, @@ -173,12 +173,12 @@ def _try_start_tracker( addr = addrs[0] assert isinstance(addr, str) or addr is None host_ip = get_host_ip(addr) - rabit_context = RabitTracker( + rabit_tracker = RabitTracker( host_ip=host_ip, n_workers=n_workers, use_logger=False, sortby="task" ) - env.update(rabit_context.worker_envs()) - rabit_context.start(n_workers) - thread = Thread(target=rabit_context.join) + env.update(rabit_tracker.worker_envs()) + rabit_tracker.start(n_workers) + thread = Thread(target=rabit_tracker.join) thread.daemon = True thread.start() except socket.error as e: @@ -218,11 +218,11 @@ def _assert_dask_support() -> None: LOGGER.warning(msg) -class RabitContext(rabit.RabitContext): - """A context controlling rabit initialization and finalization.""" +class CommunicatorContext(collective.CommunicatorContext): + """A context controlling collective communicator initialization and finalization.""" - def __init__(self, args: List[bytes]) -> None: - super().__init__(args) + def __init__(self, **args) -> None: + super().__init__(**args) worker = distributed.get_worker() with distributed.worker_client() as client: info = client.scheduler_info() @@ -232,9 +232,7 @@ def __init__(self, args: List[bytes]) -> None: # not the same as task ID is string and "10" is sorted before "2") with dask # worker ID. This outsources the rank assignment to dask and prevents # non-deterministic issue. - self.args.append( - (f"DMLC_TASK_ID=[xgboost.dask-{wid}]:" + str(worker.address)).encode() - ) + self.args["DMLC_TASK_ID"] = f"[xgboost.dask-{wid}]:" + str(worker.address) def dconcat(value: Sequence[_T]) -> _T: @@ -816,7 +814,7 @@ def _dmatrix_from_list_of_parts(is_quantile: bool, **kwargs: Any) -> DMatrix: async def _get_rabit_args( n_workers: int, dconfig: Optional[Dict[str, Any]], client: "distributed.Client" -) -> List[bytes]: +) -> Dict[str, Union[str, int]]: """Get rabit context arguments from data distribution in DaskDMatrix.""" # There are 3 possible different addresses: # 1. 
Provided by user via dask.config @@ -855,9 +853,7 @@ async def _get_rabit_args( env = await client.run_on_scheduler( _start_tracker, n_workers, sched_addr, user_addr ) - - rabit_args = [f"{k}={v}".encode() for k, v in env.items()] - return rabit_args + return env def _get_dask_config() -> Optional[Dict[str, Any]]: @@ -912,7 +908,7 @@ async def _train_async( def dispatched_train( parameters: Dict, - rabit_args: List[bytes], + rabit_args: Dict[str, Union[str, int]], train_id: int, evals_name: List[str], evals_id: List[int], @@ -936,7 +932,7 @@ def dispatched_train( n_threads = dwnt local_param.update({"nthread": n_threads, "n_jobs": n_threads}) local_history: TrainingCallback.EvalsLog = {} - with RabitContext(rabit_args), config.config_context(**global_config): + with CommunicatorContext(**rabit_args), config.config_context(**global_config): Xy = _dmatrix_from_list_of_parts(**train_ref, nthread=n_threads) evals: List[Tuple[DMatrix, str]] = [] for i, ref in enumerate(refs): diff --git a/python-package/xgboost/rabit.py b/python-package/xgboost/rabit.py deleted file mode 100644 index f5da7a353330..000000000000 --- a/python-package/xgboost/rabit.py +++ /dev/null @@ -1,249 +0,0 @@ -"""Distributed XGBoost Rabit related API.""" -import ctypes -from enum import IntEnum, unique -import logging -import pickle -from typing import Any, TypeVar, Callable, Optional, cast, List, Union - -import numpy as np - -from .core import _LIB, c_str, _check_call - -LOGGER = logging.getLogger("[xgboost.rabit]") - - -def _init_rabit() -> None: - """internal library initializer.""" - if _LIB is not None: - _LIB.RabitGetRank.restype = ctypes.c_int - _LIB.RabitGetWorldSize.restype = ctypes.c_int - _LIB.RabitIsDistributed.restype = ctypes.c_int - _LIB.RabitVersionNumber.restype = ctypes.c_int - - -def init(args: Optional[List[bytes]] = None) -> None: - """Initialize the rabit library with arguments""" - if args is None: - args = [] - arr = (ctypes.c_char_p * len(args))() - arr[:] = cast(List[Union[ctypes.c_char_p, bytes, None, int]], args) - _LIB.RabitInit(len(arr), arr) - - -def finalize() -> None: - """Finalize the process, notify tracker everything is done.""" - _LIB.RabitFinalize() - - -def get_rank() -> int: - """Get rank of current process. - - Returns - ------- - rank : int - Rank of current process. - """ - ret = _LIB.RabitGetRank() - return ret - - -def get_world_size() -> int: - """Get total number workers. - - Returns - ------- - n : int - Total number of process. - """ - ret = _LIB.RabitGetWorldSize() - return ret - - -def is_distributed() -> int: - '''If rabit is distributed.''' - is_dist = _LIB.RabitIsDistributed() - return is_dist - - -def tracker_print(msg: Any) -> None: - """Print message to the tracker. - - This function can be used to communicate the information of - the progress to the tracker - - Parameters - ---------- - msg : str - The message to be printed to tracker. - """ - if not isinstance(msg, str): - msg = str(msg) - is_dist = _LIB.RabitIsDistributed() - if is_dist != 0: - _check_call(_LIB.RabitTrackerPrint(c_str(msg))) - else: - print(msg.strip(), flush=True) - - -def get_processor_name() -> bytes: - """Get the processor name. 
- - Returns - ------- - name : str - the name of processor(host) - """ - mxlen = 256 - length = ctypes.c_ulong() - buf = ctypes.create_string_buffer(mxlen) - _LIB.RabitGetProcessorName(buf, ctypes.byref(length), mxlen) - return buf.value - - -T = TypeVar("T") # pylint:disable=invalid-name - - -def broadcast(data: T, root: int) -> T: - """Broadcast object from one node to all other nodes. - - Parameters - ---------- - data : any type that can be pickled - Input data, if current rank does not equal root, this can be None - root : int - Rank of the node to broadcast data from. - - Returns - ------- - object : int - the result of broadcast. - """ - rank = get_rank() - length = ctypes.c_ulong() - if root == rank: - assert data is not None, 'need to pass in data when broadcasting' - s = pickle.dumps(data, protocol=pickle.HIGHEST_PROTOCOL) - length.value = len(s) - # run first broadcast - _check_call(_LIB.RabitBroadcast(ctypes.byref(length), - ctypes.sizeof(ctypes.c_ulong), root)) - if root != rank: - dptr = (ctypes.c_char * length.value)() - # run second - _check_call(_LIB.RabitBroadcast(ctypes.cast(dptr, ctypes.c_void_p), - length.value, root)) - data = pickle.loads(dptr.raw) - del dptr - else: - _check_call(_LIB.RabitBroadcast(ctypes.cast(ctypes.c_char_p(s), ctypes.c_void_p), - length.value, root)) - del s - return data - - -# enumeration of dtypes -DTYPE_ENUM__ = { - np.dtype('int8'): 0, - np.dtype('uint8'): 1, - np.dtype('int32'): 2, - np.dtype('uint32'): 3, - np.dtype('int64'): 4, - np.dtype('uint64'): 5, - np.dtype('float32'): 6, - np.dtype('float64'): 7 -} - - -@unique -class Op(IntEnum): - '''Supported operations for rabit.''' - MAX = 0 - MIN = 1 - SUM = 2 - OR = 3 - - -def allreduce( # pylint:disable=invalid-name - data: np.ndarray, op: Op, prepare_fun: Optional[Callable[[np.ndarray], None]] = None -) -> np.ndarray: - """Perform allreduce, return the result. - - Parameters - ---------- - data : - Input data. - op : - Reduction operators, can be MIN, MAX, SUM, BITOR - prepare_fun : - Lazy preprocessing function, if it is not None, prepare_fun(data) - will be called by the function before performing allreduce, to initialize the data - If the result of Allreduce can be recovered directly, - then prepare_fun will NOT be called - - Returns - ------- - result : - The result of allreduce, have same shape as data - - Notes - ----- - This function is not thread-safe. - """ - if not isinstance(data, np.ndarray): - raise Exception('allreduce only takes in numpy.ndarray') - buf = data.ravel() - if buf.base is data.base: - buf = buf.copy() - if buf.dtype not in DTYPE_ENUM__: - raise Exception(f"data type {buf.dtype} not supported") - if prepare_fun is None: - _check_call(_LIB.RabitAllreduce(buf.ctypes.data_as(ctypes.c_void_p), - buf.size, DTYPE_ENUM__[buf.dtype], - int(op), None, None)) - else: - func_ptr = ctypes.CFUNCTYPE(None, ctypes.c_void_p) - - def pfunc(_: Any) -> None: - """prepare function.""" - fn = cast(Callable[[np.ndarray], None], prepare_fun) - fn(data) - _check_call(_LIB.RabitAllreduce(buf.ctypes.data_as(ctypes.c_void_p), - buf.size, DTYPE_ENUM__[buf.dtype], - op, func_ptr(pfunc), None)) - return buf - - -def version_number() -> int: - """Returns version number of current stored model. - - This means how many calls to CheckPoint we made so far. 
- - Returns - ------- - version : int - Version number of currently stored model - """ - ret = _LIB.RabitVersionNumber() - return ret - - -class RabitContext: - """A context controlling rabit initialization and finalization.""" - - def __init__(self, args: List[bytes] = None) -> None: - if args is None: - args = [] - self.args = args - - def __enter__(self) -> None: - init(self.args) - assert is_distributed() - LOGGER.debug("-------------- rabit say hello ------------------") - - def __exit__(self, *args: List) -> None: - finalize() - LOGGER.debug("--------------- rabit say bye ------------------") - - -# initialization script -_init_rabit() diff --git a/python-package/xgboost/spark/core.py b/python-package/xgboost/spark/core.py index adb46d92c4c9..b99941dcbf01 100644 --- a/python-package/xgboost/spark/core.py +++ b/python-package/xgboost/spark/core.py @@ -3,6 +3,7 @@ # pylint: disable=fixme, too-many-ancestors, protected-access, no-member, invalid-name # pylint: disable=too-few-public-methods, too-many-lines from typing import Iterator, Optional, Tuple +import json import numpy as np import pandas as pd @@ -57,7 +58,7 @@ HasQueryIdCol, ) from .utils import ( - RabitContext, + CommunicatorContext, _get_args_from_message_list, _get_default_params_from_func, _get_gpu_id, @@ -747,7 +748,7 @@ def _train_booster(pandas_df_iter): ): dmatrix_kwargs["max_bin"] = booster_params["max_bin"] - _rabit_args = "" + _rabit_args = {} if context.partitionId() == 0: get_logger("XGBoostPySpark").info( "booster params: %s\n" @@ -758,12 +759,12 @@ def _train_booster(pandas_df_iter): dmatrix_kwargs, ) - _rabit_args = str(_get_rabit_args(context, num_workers)) + _rabit_args = _get_rabit_args(context, num_workers) - messages = context.allGather(message=str(_rabit_args)) + messages = context.allGather(message=json.dumps(_rabit_args)) _rabit_args = _get_args_from_message_list(messages) evals_result = {} - with RabitContext(_rabit_args, context): + with CommunicatorContext(context, **_rabit_args): dtrain, dvalid = create_dmatrix_from_partitions( pandas_df_iter, features_cols_names, diff --git a/python-package/xgboost/spark/utils.py b/python-package/xgboost/spark/utils.py index d05ef7623a37..79c040f2701f 100644 --- a/python-package/xgboost/spark/utils.py +++ b/python-package/xgboost/spark/utils.py @@ -1,6 +1,7 @@ # type: ignore """Xgboost pyspark integration submodule for helper functions.""" import inspect +import json import logging import sys from threading import Thread @@ -9,7 +10,7 @@ from pyspark.sql.session import SparkSession from xgboost.tracker import RabitTracker -from xgboost import rabit +from xgboost import collective def get_class_name(cls): @@ -36,21 +37,21 @@ def _get_default_params_from_func(func, unsupported_set): return filtered_params_dict -class RabitContext: +class CommunicatorContext: """ - A context controlling rabit initialization and finalization. + A context controlling collective communicator initialization and finalization. This isn't specificially necessary (note Part 3), but it is more understandable coding-wise. 
""" - def __init__(self, args, context): + def __init__(self, context, **args): self.args = args - self.args.append(("DMLC_TASK_ID=" + str(context.partitionId())).encode()) + self.args["DMLC_TASK_ID"] = str(context.partitionId()) def __enter__(self): - rabit.init(self.args) + collective.init(**self.args) def __exit__(self, *args): - rabit.finalize() + collective.finalize() def _start_tracker(context, n_workers): @@ -74,8 +75,7 @@ def _get_rabit_args(context, n_workers): """ # pylint: disable=consider-using-f-string env = _start_tracker(context, n_workers) - rabit_args = [("%s=%s" % item).encode() for item in env.items()] - return rabit_args + return env def _get_host_ip(context): @@ -95,7 +95,7 @@ def _get_args_from_message_list(messages): if message != "": output = message break - return [elem.split("'")[1].encode() for elem in output.strip("][").split(", ")] + return json.loads(output) def _get_spark_session(): diff --git a/rabit/CMakeLists.txt b/rabit/CMakeLists.txt index 3a76794f5f58..ad39fb249791 100644 --- a/rabit/CMakeLists.txt +++ b/rabit/CMakeLists.txt @@ -6,9 +6,7 @@ set(RABIT_SOURCES ${CMAKE_CURRENT_LIST_DIR}/src/allreduce_base.cc ${CMAKE_CURRENT_LIST_DIR}/src/rabit_c_api.cc) -if (PLUGIN_FEDERATED) - # Skip the engine if the Federated Learning plugin is enabled. -elseif (RABIT_BUILD_MPI) +if (RABIT_BUILD_MPI) list(APPEND RABIT_SOURCES ${CMAKE_CURRENT_LIST_DIR}/src/engine_mpi.cc) elseif (RABIT_MOCK) list(APPEND RABIT_SOURCES ${CMAKE_CURRENT_LIST_DIR}/src/engine_mock.cc) diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 9b5bea3acd00..19e04f903475 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -22,7 +22,7 @@ #include "c_api_error.h" #include "c_api_utils.h" -#include "../collective/communicator.h" +#include "../collective/communicator-inl.h" #include "../common/io.h" #include "../common/charconv.h" #include "../data/adapter.h" @@ -210,7 +210,7 @@ XGB_DLL int XGDMatrixCreateFromFile(const char *fname, int silent, DMatrixHandle #if defined(XGBOOST_USE_FEDERATED) LOG(CONSOLE) << "XGBoost federated mode detected, not splitting data among workers"; #else - if (rabit::IsDistributed()) { + if (collective::IsDistributed()) { LOG(CONSOLE) << "XGBoost distributed mode detected, " << "will split data among workers"; load_row_split = true; @@ -1371,59 +1371,54 @@ XGB_DLL int XGBoosterFeatureScore(BoosterHandle handle, char const *json_config, API_END(); } -using xgboost::collective::Communicator; - XGB_DLL int XGCommunicatorInit(char const* json_config) { API_BEGIN(); - Json config { Json::Load(StringView{json_config}) }; - Communicator::Init(config); + collective::Init(json_config); API_END(); } XGB_DLL int XGCommunicatorFinalize(void) { API_BEGIN(); - Communicator::Finalize(); + collective::Finalize(); API_END(); } XGB_DLL int XGCommunicatorGetRank(void) { - return Communicator::Get()->GetRank(); + return collective::GetRank(); } XGB_DLL int XGCommunicatorGetWorldSize(void) { - return Communicator::Get()->GetWorldSize(); + return collective::GetWorldSize(); } XGB_DLL int XGCommunicatorIsDistributed(void) { - return Communicator::Get()->IsDistributed(); + return collective::IsDistributed(); } XGB_DLL int XGCommunicatorPrint(char const *message) { API_BEGIN(); - Communicator::Get()->Print(message); + collective::Print(message); API_END(); } XGB_DLL int XGCommunicatorGetProcessorName(char const **name_str) { API_BEGIN(); auto& local = *GlobalConfigAPIThreadLocalStore::Get(); - local.ret_str = Communicator::Get()->GetProcessorName(); + local.ret_str = 
collective::GetProcessorName(); *name_str = local.ret_str.c_str(); API_END(); } XGB_DLL int XGCommunicatorBroadcast(void *send_receive_buffer, size_t size, int root) { API_BEGIN(); - Communicator::Get()->Broadcast(send_receive_buffer, size, root); + collective::Broadcast(send_receive_buffer, size, root); API_END(); } XGB_DLL int XGCommunicatorAllreduce(void *send_receive_buffer, size_t count, int enum_dtype, int enum_op) { API_BEGIN(); - Communicator::Get()->AllReduce( - send_receive_buffer, count, static_cast(enum_dtype), - static_cast(enum_op)); + collective::Allreduce(send_receive_buffer, count, enum_dtype, enum_op); API_END(); } diff --git a/src/cli_main.cc b/src/cli_main.cc index 9e3b9e6c820f..7fbe0a34b774 100644 --- a/src/cli_main.cc +++ b/src/cli_main.cc @@ -25,6 +25,7 @@ #include #include #include +#include "collective/communicator-inl.h" #include "common/common.h" #include "common/config.h" #include "common/io.h" @@ -156,7 +157,7 @@ struct CLIParam : public XGBoostParameter { if (name_pred == "stdout") { save_period = 0; } - if (dsplit == 0 && rabit::IsDistributed()) { + if (dsplit == 0 && collective::IsDistributed()) { dsplit = 2; } } @@ -186,26 +187,22 @@ class CLI { kHelp } print_info_ {kNone}; - int ResetLearner(std::vector> const &matrices) { + void ResetLearner(std::vector> const &matrices) { learner_.reset(Learner::Create(matrices)); - int version = rabit::LoadCheckPoint(); - if (version == 0) { - if (param_.model_in != CLIParam::kNull) { - this->LoadModel(param_.model_in, learner_.get()); - learner_->SetParams(param_.cfg); - } else { - learner_->SetParams(param_.cfg); - } + if (param_.model_in != CLIParam::kNull) { + this->LoadModel(param_.model_in, learner_.get()); + learner_->SetParams(param_.cfg); + } else { + learner_->SetParams(param_.cfg); } learner_->Configure(); - return version; } void CLITrain() { const double tstart_data_load = dmlc::GetTime(); - if (rabit::IsDistributed()) { - std::string pname = rabit::GetProcessorName(); - LOG(CONSOLE) << "start " << pname << ":" << rabit::GetRank(); + if (collective::IsDistributed()) { + std::string pname = collective::GetProcessorName(); + LOG(CONSOLE) << "start " << pname << ":" << collective::GetRank(); } // load in data. std::shared_ptr dtrain(DMatrix::Load( @@ -230,48 +227,45 @@ class CLI { eval_data_names.emplace_back("train"); } // initialize the learner. - int32_t version = this->ResetLearner(cache_mats); + this->ResetLearner(cache_mats); LOG(INFO) << "Loading data: " << dmlc::GetTime() - tstart_data_load << " sec"; // start training. 
const double start = dmlc::GetTime(); + int32_t version = 0; for (int i = version / 2; i < param_.num_round; ++i) { double elapsed = dmlc::GetTime() - start; if (version % 2 == 0) { LOG(INFO) << "boosting round " << i << ", " << elapsed << " sec elapsed"; learner_->UpdateOneIter(i, dtrain); - rabit::CheckPoint(); version += 1; } - CHECK_EQ(version, rabit::VersionNumber()); std::string res = learner_->EvalOneIter(i, eval_datasets, eval_data_names); - if (rabit::IsDistributed()) { - if (rabit::GetRank() == 0) { + if (collective::IsDistributed()) { + if (collective::GetRank() == 0) { LOG(TRACKER) << res; } } else { LOG(CONSOLE) << res; } if (param_.save_period != 0 && (i + 1) % param_.save_period == 0 && - rabit::GetRank() == 0) { + collective::GetRank() == 0) { std::ostringstream os; os << param_.model_dir << '/' << std::setfill('0') << std::setw(4) << i + 1 << ".model"; this->SaveModel(os.str(), learner_.get()); } - rabit::CheckPoint(); version += 1; - CHECK_EQ(version, rabit::VersionNumber()); } LOG(INFO) << "Complete Training loop time: " << dmlc::GetTime() - start << " sec"; // always save final round if ((param_.save_period == 0 || param_.num_round % param_.save_period != 0) && - rabit::GetRank() == 0) { + collective::GetRank() == 0) { std::ostringstream os; if (param_.model_out == CLIParam::kNull) { os << param_.model_dir << '/' << std::setfill('0') << std::setw(4) @@ -467,7 +461,6 @@ class CLI { return; } - rabit::Init(argc, argv); std::string config_path = argv[1]; common::ConfigParser cp(config_path); @@ -480,6 +473,15 @@ class CLI { } } + // Initialize the collective communicator. + Json json{JsonObject()}; + for (auto& kv: cfg) { + json[kv.first] = String(kv.second); + } + std::string json_str; + Json::Dump(json, &json_str); + collective::Init(json_str.c_str()); + param_.Configure(cfg); } @@ -517,7 +519,7 @@ class CLI { } ~CLI() { - rabit::Finalize(); + collective::Finalize(); } }; } // namespace xgboost diff --git a/src/collective/communicator-inl.h b/src/collective/communicator-inl.h new file mode 100644 index 000000000000..d848eaef2829 --- /dev/null +++ b/src/collective/communicator-inl.h @@ -0,0 +1,198 @@ +/*! + * Copyright 2022 XGBoost contributors + */ +#pragma once +#include "communicator.h" + +namespace xgboost { +namespace collective { + +/*! + * \brief Initialize the collective communicator. + * + * Currently the communicator API is experimental, function signatures may change in the future + * without notice. + * + * Call this once before using anything. + * + * The additional configuration is not required. Usually the communicator will detect settings + * from environment variables. + * + * \param json_config JSON encoded configuration. Accepted JSON keys are: + * - xgboost_communicator: The type of the communicator. Can be set as an environment variable. + * * rabit: Use Rabit. This is the default if the type is unspecified. + * * mpi: Use MPI. + * * federated: Use the gRPC interface for Federated Learning. + * Only applicable to the Rabit communicator (these are case-sensitive): + * - rabit_tracker_uri: Hostname of the tracker. + * - rabit_tracker_port: Port number of the tracker. + * - rabit_task_id: ID of the current task, can be used to obtain deterministic rank assignment. + * - rabit_world_size: Total number of workers. + * - rabit_hadoop_mode: Enable Hadoop support. + * - rabit_tree_reduce_minsize: Minimal size for tree reduce. + * - rabit_reduce_ring_mincount: Minimal count to perform ring reduce. + * - rabit_reduce_buffer: Size of the reduce buffer. 
+ * - rabit_bootstrap_cache: Size of the bootstrap cache. + * - rabit_debug: Enable debugging. + * - rabit_timeout: Enable timeout. + * - rabit_timeout_sec: Timeout in seconds. + * - rabit_enable_tcp_no_delay: Enable TCP no delay on Unix platforms. + * Only applicable to the Rabit communicator (these are case-sensitive, and can be set as + * environment variables): + * - DMLC_TRACKER_URI: Hostname of the tracker. + * - DMLC_TRACKER_PORT: Port number of the tracker. + * - DMLC_TASK_ID: ID of the current task, can be used to obtain deterministic rank assignment. + * - DMLC_ROLE: Role of the current task, "worker" or "server". + * - DMLC_NUM_ATTEMPT: Number of attempts after task failure. + * - DMLC_WORKER_CONNECT_RETRY: Number of retries to connect to the tracker. + * Only applicable to the Federated communicator (use upper case for environment variables, use + * lower case for runtime configuration): + * - federated_server_address: Address of the federated server. + * - federated_world_size: Number of federated workers. + * - federated_rank: Rank of the current worker. + * - federated_server_cert: Server certificate file path. Only needed for the SSL mode. + * - federated_client_key: Client key file path. Only needed for the SSL mode. + * - federated_client_cert: Client certificate file path. Only needed for the SSL mode. + */ +inline void Init(char const *json_config) { + Json config{Json::Load(StringView{json_config})}; + Communicator::Init(config); +} + +/*! + * \brief Finalize the collective communicator. + * + * Call this function after you finished all jobs. + */ +inline void Finalize() { Communicator::Finalize(); } + +/*! + * \brief Get rank of current process. + * + * \return Rank of the worker. + */ +inline int GetRank() { return Communicator::Get()->GetRank(); } + +/*! + * \brief Get total number of processes. + * + * \return Total world size. + */ +inline int GetWorldSize() { return Communicator::Get()->GetWorldSize(); } + +/*! + * \brief Get if the communicator is distributed. + * + * \return True if the communicator is distributed. + */ +inline bool IsDistributed() { return Communicator::Get()->IsDistributed(); } + +/*! + * \brief Print the message to the communicator. + * + * This function can be used to communicate the information of the progress to the user who monitors + * the communicator. + * + * \param message The message to be printed. + */ +inline void Print(char const *message) { Communicator::Get()->Print(message); } + +inline void Print(std::string const &message) { Communicator::Get()->Print(message); } + +/*! + * \brief Get the name of the processor. + * + * \return Name of the processor. + */ +inline std::string GetProcessorName() { return Communicator::Get()->GetProcessorName(); } + +/*! + * \brief Broadcast a memory region to all others from root. This function is NOT thread-safe. + * + * Example: + * int a = 1; + * Broadcast(&a, sizeof(a), root); + * + * \param send_receive_buffer Pointer to the send or receive buffer. + * \param size Size of the data. + * \param root The process rank to broadcast from. 
+ */ +inline void Broadcast(void *send_receive_buffer, size_t size, int root) { + Communicator::Get()->Broadcast(send_receive_buffer, size, root); +} + +inline void Broadcast(std::string *sendrecv_data, int root) { + size_t size = sendrecv_data->length(); + Broadcast(&size, sizeof(size), root); + if (sendrecv_data->length() != size) { + sendrecv_data->resize(size); + } + if (size != 0) { + Broadcast(&(*sendrecv_data)[0], size * sizeof(char), root); + } +} + +/*! + * \brief Perform in-place allreduce. This function is NOT thread-safe. + * + * Example Usage: the following code gives sum of the result + * vector data(10); + * ... + * Allreduce(&data[0], data.size(), DataType:kInt32, Op::kSum); + * ... + * \param send_receive_buffer Buffer for both sending and receiving data. + * \param count Number of elements to be reduced. + * \param data_type Enumeration of data type, see xgboost::collective::DataType in communicator.h. + * \param op Enumeration of operation type, see xgboost::collective::Operation in communicator.h. + */ +inline void Allreduce(void *send_receive_buffer, size_t count, int data_type, int op) { + Communicator::Get()->AllReduce(send_receive_buffer, count, static_cast(data_type), + static_cast(op)); +} + +inline void Allreduce(void *send_receive_buffer, size_t count, DataType data_type, Operation op) { + Communicator::Get()->AllReduce(send_receive_buffer, count, data_type, op); +} + +template +inline void Allreduce(int8_t *send_receive_buffer, size_t count) { + Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kInt8, op); +} + +template +inline void Allreduce(uint8_t *send_receive_buffer, size_t count) { + Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kUInt8, op); +} + +template +inline void Allreduce(int32_t *send_receive_buffer, size_t count) { + Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kInt32, op); +} + +template +inline void Allreduce(uint32_t *send_receive_buffer, size_t count) { + Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kUInt32, op); +} + +template +inline void Allreduce(int64_t *send_receive_buffer, size_t count) { + Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kInt64, op); +} + +template +inline void Allreduce(uint64_t *send_receive_buffer, size_t count) { + Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kUInt64, op); +} + +template +inline void Allreduce(float *send_receive_buffer, size_t count) { + Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kFloat, op); +} + +template +inline void Allreduce(double *send_receive_buffer, size_t count) { + Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kDouble, op); +} + +} // namespace collective +} // namespace xgboost diff --git a/src/collective/communicator.cc b/src/collective/communicator.cc index 73765223b225..ad5de231c6e8 100644 --- a/src/collective/communicator.cc +++ b/src/collective/communicator.cc @@ -3,6 +3,7 @@ */ #include "communicator.h" +#include "noop_communicator.h" #include "rabit_communicator.h" #if defined(XGBOOST_USE_FEDERATED) @@ -12,14 +13,10 @@ namespace xgboost { namespace collective { -thread_local std::unique_ptr Communicator::communicator_{}; +thread_local std::unique_ptr Communicator::communicator_{new NoOpCommunicator()}; thread_local CommunicatorType Communicator::type_{}; void Communicator::Init(Json const& config) { - if (communicator_) { - LOG(FATAL) << "Communicator can only be initialized once."; - } - auto 
type = GetTypeFromEnv(); auto const arg = GetTypeFromConfig(config); if (arg != CommunicatorType::kUnknown) { diff --git a/src/collective/communicator.h b/src/collective/communicator.h index 14ee201618cd..9c0d98a7e8c2 100644 --- a/src/collective/communicator.h +++ b/src/collective/communicator.h @@ -23,40 +23,6 @@ enum class DataType { kDouble = 7 }; -/** @brief Get the size of the data type. */ -inline std::size_t GetTypeSize(DataType data_type) { - std::size_t size{0}; - switch (data_type) { - case DataType::kInt8: - size = sizeof(std::int8_t); - break; - case DataType::kUInt8: - size = sizeof(std::uint8_t); - break; - case DataType::kInt32: - size = sizeof(std::int32_t); - break; - case DataType::kUInt32: - size = sizeof(std::uint32_t); - break; - case DataType::kInt64: - size = sizeof(std::int64_t); - break; - case DataType::kUInt64: - size = sizeof(std::uint64_t); - break; - case DataType::kFloat: - size = sizeof(float); - break; - case DataType::kDouble: - size = sizeof(double); - break; - default: - LOG(FATAL) << "Unknown data type."; - } - return size; -} - /** @brief Defines the reduction operation. */ enum class Operation { kMax = 0, kMin = 1, kSum = 2 }; diff --git a/src/collective/device_communicator.cuh b/src/collective/device_communicator.cuh index 15d18cead02f..07664213a698 100644 --- a/src/collective/device_communicator.cuh +++ b/src/collective/device_communicator.cuh @@ -21,7 +21,21 @@ class DeviceCommunicator { * @param send_receive_buffer Buffer storing the data. * @param count Number of elements in the buffer. */ - virtual void AllReduceSum(double *send_receive_buffer, int count) = 0; + virtual void AllReduceSum(float *send_receive_buffer, size_t count) = 0; + + /** + * @brief Sum values from all processes and distribute the result back to all processes. + * @param send_receive_buffer Buffer storing the data. + * @param count Number of elements in the buffer. + */ + virtual void AllReduceSum(double *send_receive_buffer, size_t count) = 0; + + /** + * @brief Sum values from all processes and distribute the result back to all processes. + * @param send_receive_buffer Buffer storing the data. + * @param count Number of elements in the buffer. + */ + virtual void AllReduceSum(uint64_t *send_receive_buffer, size_t count) = 0; /** * @brief Gather variable-length values from all processes. 
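For context, a minimal sketch (not part of this patch) of how a caller might drive the widened DeviceCommunicator interface above. The function name ExampleDeviceSum is hypothetical; Communicator::GetDevice and the AllReduceSum/Synchronize calls follow the usage this patch introduces in src/common/quantile.cu further below.

#include "communicator.h"
#include "device_communicator.cuh"
#include "../common/device_helpers.cuh"

namespace xgboost {
namespace collective {

// Hypothetical caller: sum a device-resident uint64_t buffer across all workers.
inline void ExampleDeviceSum(int device_ordinal) {
  // NCCL-backed when XGBoost is built with NCCL, otherwise the host-staging adapter.
  DeviceCommunicator *comm = Communicator::GetDevice(device_ordinal);
  dh::device_vector<uint64_t> counts(16, 1);
  // In-place sum; the new implementations return early when the world size is 1.
  comm->AllReduceSum(counts.data().get(), counts.size());
  comm->Synchronize();
}

}  // namespace collective
}  // namespace xgboost
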
diff --git a/src/collective/device_communicator_adapter.cuh b/src/collective/device_communicator_adapter.cuh index 794049bfcb1c..43ed51cad4a5 100644 --- a/src/collective/device_communicator_adapter.cuh +++ b/src/collective/device_communicator_adapter.cuh @@ -23,17 +23,24 @@ class DeviceCommunicatorAdapter : public DeviceCommunicator { ~DeviceCommunicatorAdapter() override = default; - void AllReduceSum(double *send_receive_buffer, int count) override { - dh::safe_cuda(cudaSetDevice(device_ordinal_)); - auto size = count * sizeof(double); - host_buffer_.reserve(size); - dh::safe_cuda(cudaMemcpy(host_buffer_.data(), send_receive_buffer, size, cudaMemcpyDefault)); - communicator_->AllReduce(host_buffer_.data(), count, DataType::kDouble, Operation::kSum); - dh::safe_cuda(cudaMemcpy(send_receive_buffer, host_buffer_.data(), size, cudaMemcpyDefault)); + void AllReduceSum(float *send_receive_buffer, size_t count) override { + DoAllReduceSum(send_receive_buffer, count); + } + + void AllReduceSum(double *send_receive_buffer, size_t count) override { + DoAllReduceSum(send_receive_buffer, count); + } + + void AllReduceSum(uint64_t *send_receive_buffer, size_t count) override { + DoAllReduceSum(send_receive_buffer, count); } void AllGatherV(void const *send_buffer, size_t length_bytes, std::vector *segments, dh::caching_device_vector *receive_buffer) override { + if (communicator_->GetWorldSize() == 1) { + return; + } + dh::safe_cuda(cudaSetDevice(device_ordinal_)); int const world_size = communicator_->GetWorldSize(); int const rank = communicator_->GetRank(); @@ -66,6 +73,20 @@ class DeviceCommunicatorAdapter : public DeviceCommunicator { } private: + template + void DoAllReduceSum(T *send_receive_buffer, size_t count) { + if (communicator_->GetWorldSize() == 1) { + return; + } + + dh::safe_cuda(cudaSetDevice(device_ordinal_)); + auto size = count * sizeof(T); + host_buffer_.reserve(size); + dh::safe_cuda(cudaMemcpy(host_buffer_.data(), send_receive_buffer, size, cudaMemcpyDefault)); + communicator_->AllReduce(host_buffer_.data(), count, data_type, collective::Operation::kSum); + dh::safe_cuda(cudaMemcpy(send_receive_buffer, host_buffer_.data(), size, cudaMemcpyDefault)); + } + int const device_ordinal_; Communicator *communicator_; /// Host buffer used to call communicator functions. 
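The adapter above implements device allreduce by staging through a host buffer, since without NCCL the underlying communicator only operates on host memory. The following standalone sketch of that round trip is illustrative only: the helper name StagedAllReduceSum and the use of the typed collective::Allreduce wrapper are assumptions, whereas the adapter itself calls Communicator::AllReduce with an explicit DataType.

#include <cuda_runtime.h>
#include <cstddef>
#include <vector>
#include "communicator-inl.h"

// Illustrative only: copy device data to the host, reduce there, copy the result back.
inline void StagedAllReduceSum(double *device_buffer, std::size_t count) {
  if (xgboost::collective::GetWorldSize() == 1) {
    return;  // Single worker: nothing to reduce, mirroring the adapter's early return.
  }
  std::vector<double> host(count);
  cudaMemcpy(host.data(), device_buffer, count * sizeof(double), cudaMemcpyDefault);
  xgboost::collective::Allreduce<xgboost::collective::Operation::kSum>(host.data(), count);
  cudaMemcpy(device_buffer, host.data(), count * sizeof(double), cudaMemcpyDefault);
}
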
diff --git a/src/collective/nccl_device_communicator.cuh b/src/collective/nccl_device_communicator.cuh index ad9f57589c53..a28c66f3555f 100644 --- a/src/collective/nccl_device_communicator.cuh +++ b/src/collective/nccl_device_communicator.cuh @@ -24,6 +24,10 @@ class NcclDeviceCommunicator : public DeviceCommunicator { int32_t const rank = communicator_->GetRank(); int32_t const world = communicator_->GetWorldSize(); + if (world == 1) { + return; + } + std::vector uuids(world * kUuidLength, 0); auto s_uuid = xgboost::common::Span{uuids.data(), uuids.size()}; auto s_this_uuid = s_uuid.subspan(rank * kUuidLength, kUuidLength); @@ -52,6 +56,9 @@ class NcclDeviceCommunicator : public DeviceCommunicator { } ~NcclDeviceCommunicator() override { + if (communicator_->GetWorldSize() == 1) { + return; + } dh::safe_cuda(cudaStreamDestroy(cuda_stream_)); ncclCommDestroy(nccl_comm_); if (xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug)) { @@ -61,16 +68,24 @@ class NcclDeviceCommunicator : public DeviceCommunicator { } } - void AllReduceSum(double *send_receive_buffer, int count) override { - dh::safe_cuda(cudaSetDevice(device_ordinal_)); - dh::safe_nccl(ncclAllReduce(send_receive_buffer, send_receive_buffer, count, ncclDouble, - ncclSum, nccl_comm_, cuda_stream_)); - allreduce_bytes_ += count * sizeof(double); - allreduce_calls_ += 1; + void AllReduceSum(float *send_receive_buffer, size_t count) override { + DoAllReduceSum(send_receive_buffer, count); + } + + void AllReduceSum(double *send_receive_buffer, size_t count) override { + DoAllReduceSum(send_receive_buffer, count); + } + + void AllReduceSum(uint64_t *send_receive_buffer, size_t count) override { + DoAllReduceSum(send_receive_buffer, count); } void AllGatherV(void const *send_buffer, size_t length_bytes, std::vector *segments, dh::caching_device_vector *receive_buffer) override { + if (communicator_->GetWorldSize() == 1) { + return; + } + dh::safe_cuda(cudaSetDevice(device_ordinal_)); int const world_size = communicator_->GetWorldSize(); int const rank = communicator_->GetRank(); @@ -95,6 +110,9 @@ class NcclDeviceCommunicator : public DeviceCommunicator { } void Synchronize() override { + if (communicator_->GetWorldSize() == 1) { + return; + } dh::safe_cuda(cudaSetDevice(device_ordinal_)); dh::safe_cuda(cudaStreamSynchronize(cuda_stream_)); } @@ -136,6 +154,19 @@ class NcclDeviceCommunicator : public DeviceCommunicator { return id; } + template + void DoAllReduceSum(T *send_receive_buffer, size_t count) { + if (communicator_->GetWorldSize() == 1) { + return; + } + + dh::safe_cuda(cudaSetDevice(device_ordinal_)); + dh::safe_nccl(ncclAllReduce(send_receive_buffer, send_receive_buffer, count, data_type, ncclSum, + nccl_comm_, cuda_stream_)); + allreduce_bytes_ += count * sizeof(T); + allreduce_calls_ += 1; + } + int const device_ordinal_; Communicator *communicator_; ncclComm_t nccl_comm_{}; diff --git a/src/collective/noop_communicator.h b/src/collective/noop_communicator.h new file mode 100644 index 000000000000..0bc4a4948c35 --- /dev/null +++ b/src/collective/noop_communicator.h @@ -0,0 +1,28 @@ +/*! + * Copyright 2022 XGBoost contributors + */ +#pragma once +#include "communicator.h" + +namespace xgboost { +namespace collective { + +/** + * A no-op communicator, used for non-distributed training. 
+ */ +class NoOpCommunicator : public Communicator { + public: + NoOpCommunicator() : Communicator(1, 0) {} + bool IsDistributed() const override { return false; } + void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type, + Operation op) override {} + void Broadcast(void *send_receive_buffer, std::size_t size, int root) override {} + std::string GetProcessorName() override { return ""; } + void Print(const std::string &message) override { LOG(CONSOLE) << message; } + + protected: + void Shutdown() override {} +}; + +} // namespace collective +} // namespace xgboost diff --git a/src/common/device_helpers.cu b/src/common/device_helpers.cu deleted file mode 100644 index 12ee25e87ee3..000000000000 --- a/src/common/device_helpers.cu +++ /dev/null @@ -1,139 +0,0 @@ -/*! - * Copyright 2017-2019 XGBoost contributors - * - * \brief Utilities for CUDA. - */ -#ifdef XGBOOST_USE_NCCL -#include -#endif // #ifdef XGBOOST_USE_NCCL -#include - -#include "device_helpers.cuh" - -namespace dh { - -constexpr std::size_t kUuidLength = - sizeof(std::declval().uuid) / sizeof(uint64_t); - -void GetCudaUUID(int device_ord, xgboost::common::Span uuid) { - cudaDeviceProp prob; - safe_cuda(cudaGetDeviceProperties(&prob, device_ord)); - std::memcpy(uuid.data(), static_cast(&(prob.uuid)), sizeof(prob.uuid)); -} - -std::string PrintUUID(xgboost::common::Span uuid) { - std::stringstream ss; - for (auto v : uuid) { - ss << std::hex << v; - } - return ss.str(); -} - -#ifdef XGBOOST_USE_NCCL -void NcclAllReducer::DoInit(int _device_ordinal) { - int32_t const rank = rabit::GetRank(); - int32_t const world = rabit::GetWorldSize(); - if (world == 1) { - return; - } - - std::vector uuids(world * kUuidLength, 0); - auto s_uuid = xgboost::common::Span{uuids.data(), uuids.size()}; - auto s_this_uuid = s_uuid.subspan(rank * kUuidLength, kUuidLength); - GetCudaUUID(_device_ordinal, s_this_uuid); - - // No allgather yet. - rabit::Allreduce(uuids.data(), uuids.size()); - - std::vector> converted(world);; - size_t j = 0; - for (size_t i = 0; i < uuids.size(); i += kUuidLength) { - converted[j] = - xgboost::common::Span{uuids.data() + i, kUuidLength}; - j++; - } - - auto iter = std::unique(converted.begin(), converted.end()); - auto n_uniques = std::distance(converted.begin(), iter); - - CHECK_EQ(n_uniques, world) - << "Multiple processes within communication group running on same CUDA " - << "device is not supported. 
" << PrintUUID(s_this_uuid) << "\n"; - - - id_ = GetUniqueId(); - dh::safe_nccl(ncclCommInitRank(&comm_, rabit::GetWorldSize(), id_, rank)); - safe_cuda(cudaStreamCreate(&stream_)); -} - -void NcclAllReducer::DoAllGather(void const *data, size_t length_bytes, - std::vector *segments, - dh::caching_device_vector *recvbuf) { - int32_t world = rabit::GetWorldSize(); - segments->clear(); - segments->resize(world, 0); - segments->at(rabit::GetRank()) = length_bytes; - rabit::Allreduce(segments->data(), segments->size()); - auto total_bytes = std::accumulate(segments->cbegin(), segments->cend(), 0); - recvbuf->resize(total_bytes); - - size_t offset = 0; - safe_nccl(ncclGroupStart()); - for (int32_t i = 0; i < world; ++i) { - size_t as_bytes = segments->at(i); - safe_nccl( - ncclBroadcast(data, recvbuf->data().get() + offset, - as_bytes, ncclChar, i, comm_, stream_)); - offset += as_bytes; - } - safe_nccl(ncclGroupEnd()); -} - -NcclAllReducer::~NcclAllReducer() { - if (initialised_) { - dh::safe_cuda(cudaStreamDestroy(stream_)); - ncclCommDestroy(comm_); - } - if (xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug)) { - LOG(CONSOLE) << "======== NCCL Statistics========"; - LOG(CONSOLE) << "AllReduce calls: " << allreduce_calls_; - LOG(CONSOLE) << "AllReduce total MiB communicated: " << allreduce_bytes_/1048576; - } -} -#else -void RabitAllReducer::DoInit(int _device_ordinal) { -#if !defined(XGBOOST_USE_FEDERATED) - if (rabit::IsDistributed()) { - LOG(CONSOLE) << "XGBoost is not compiled with NCCL, falling back to Rabit."; - } -#endif -} - -void RabitAllReducer::DoAllGather(void const *data, size_t length_bytes, - std::vector *segments, - dh::caching_device_vector *recvbuf) { - size_t world = rabit::GetWorldSize(); - segments->clear(); - segments->resize(world, 0); - segments->at(rabit::GetRank()) = length_bytes; - rabit::Allreduce(segments->data(), segments->size()); - auto total_bytes = std::accumulate(segments->cbegin(), segments->cend(), 0UL); - recvbuf->resize(total_bytes); - - sendrecvbuf_.reserve(total_bytes); - auto rank = rabit::GetRank(); - size_t offset = 0; - for (int32_t i = 0; i < world; ++i) { - size_t as_bytes = segments->at(i); - if (i == rank) { - safe_cuda( - cudaMemcpy(sendrecvbuf_.data() + offset, data, segments->at(rank), cudaMemcpyDefault)); - } - rabit::Broadcast(sendrecvbuf_.data() + offset, as_bytes, i); - offset += as_bytes; - } - safe_cuda(cudaMemcpy(recvbuf->data().get(), sendrecvbuf_.data(), total_bytes, cudaMemcpyDefault)); -} -#endif // XGBOOST_USE_NCCL - -} // namespace dh diff --git a/src/common/device_helpers.cuh b/src/common/device_helpers.cuh index 754e47ff40ca..6e9781d7da75 100644 --- a/src/common/device_helpers.cuh +++ b/src/common/device_helpers.cuh @@ -17,7 +17,6 @@ #include #include -#include #include #include @@ -34,6 +33,7 @@ #include "xgboost/span.h" #include "xgboost/global_config.h" +#include "../collective/communicator-inl.h" #include "common.h" #include "algorithm.cuh" @@ -402,7 +402,7 @@ inline detail::MemoryLogger &GlobalMemoryLogger() { // dh::DebugSyncDevice(__FILE__, __LINE__); inline void DebugSyncDevice(std::string file="", int32_t line = -1) { if (file != "" && line != -1) { - auto rank = rabit::GetRank(); + auto rank = xgboost::collective::GetRank(); LOG(DEBUG) << "R:" << rank << ": " << file << ":" << line; } safe_cuda(cudaDeviceSynchronize()); @@ -421,7 +421,7 @@ using XGBBaseDeviceAllocator = thrust::device_malloc_allocator; inline void ThrowOOMError(std::string const& err, size_t bytes) { auto device = CurrentDevice(); 
- auto rank = rabit::GetRank(); + auto rank = xgboost::collective::GetRank(); std::stringstream ss; ss << "Memory allocation error on worker " << rank << ": " << err << "\n" << "- Free memory: " << AvailableMemory(device) << "\n" @@ -735,512 +735,6 @@ using TypedDiscard = std::conditional_t(), detail::TypedDiscardCTK114, detail::TypedDiscard>; -/** - * \class AllReducer - * - * \brief All reducer class that manages its own communication group and - * streams. Must be initialised before use. If XGBoost is compiled without NCCL, - * this falls back to use Rabit. - */ -template -class AllReducerBase : public xgboost::common::Crtp { - public: - virtual ~AllReducerBase() = default; - - /** - * \brief Initialise with the desired device ordinal for this allreducer. - * - * \param device_ordinal The device ordinal. - */ - void Init(int _device_ordinal) { - device_ordinal_ = _device_ordinal; - dh::safe_cuda(cudaSetDevice(device_ordinal_)); - if (rabit::GetWorldSize() == 1) { - return; - } - this->Underlying().DoInit(_device_ordinal); - initialised_ = true; - } - - /** - * \brief Allgather implemented as grouped calls to Broadcast. This way we can accept - * different size of data on different workers. - * - * \param data Buffer storing the input data. - * \param length_bytes Size of input data in bytes. - * \param segments Size of data on each worker. - * \param recvbuf Buffer storing the result of data from all workers. - */ - void AllGather(void const *data, size_t length_bytes, std::vector *segments, - dh::caching_device_vector *recvbuf) { - if (rabit::GetWorldSize() == 1) { - return; - } - CHECK(initialised_); - dh::safe_cuda(cudaSetDevice(device_ordinal_)); - this->Underlying().DoAllGather(data, length_bytes, segments, recvbuf); - } - - /** - * \brief Allgather. Use in exactly the same way as NCCL but without needing - * streams or comms. - * - * \param data Buffer storing the input data. - * \param length Size of input data in bytes. - * \param recvbuf Buffer storing the result of data from all workers. - */ - void AllGather(uint32_t const *data, size_t length, - dh::caching_device_vector *recvbuf) { - if (rabit::GetWorldSize() == 1) { - return; - } - CHECK(initialised_); - dh::safe_cuda(cudaSetDevice(device_ordinal_)); - this->Underlying().DoAllGather(data, length, recvbuf); - } - - /** - * \brief Allreduce. Use in exactly the same way as NCCL but without needing - * streams or comms. - * - * \param sendbuff The sendbuff. - * \param recvbuff The recvbuff. - * \param count Number of elements. - */ - void AllReduceSum(const double *sendbuff, double *recvbuff, int count) { - if (rabit::GetWorldSize() == 1) { - return; - } - CHECK(initialised_); - dh::safe_cuda(cudaSetDevice(device_ordinal_)); - this->Underlying().DoAllReduceSum(sendbuff, recvbuff, count); - allreduce_bytes_ += count * sizeof(double); - allreduce_calls_ += 1; - } - - /** - * \brief Allreduce. Use in exactly the same way as NCCL but without needing - * streams or comms. - * - * \param sendbuff The sendbuff. - * \param recvbuff The recvbuff. - * \param count Number of elements. - */ - void AllReduceSum(const float *sendbuff, float *recvbuff, int count) { - if (rabit::GetWorldSize() == 1) { - return; - } - CHECK(initialised_); - dh::safe_cuda(cudaSetDevice(device_ordinal_)); - this->Underlying().DoAllReduceSum(sendbuff, recvbuff, count); - allreduce_bytes_ += count * sizeof(float); - allreduce_calls_ += 1; - } - - /** - * \brief Allreduce. Use in exactly the same way as NCCL but without needing streams or comms. 
- * - * \param count Number of. - * - * \param sendbuff The sendbuff. - * \param recvbuff The recvbuff. - * \param count Number of. - */ - void AllReduceSum(const int64_t *sendbuff, int64_t *recvbuff, int count) { - if (rabit::GetWorldSize() == 1) { - return; - } - CHECK(initialised_); - dh::safe_cuda(cudaSetDevice(device_ordinal_)); - this->Underlying().DoAllReduceSum(sendbuff, recvbuff, count); - allreduce_bytes_ += count * sizeof(int64_t); - allreduce_calls_ += 1; - } - - /** - * \brief Allreduce. Use in exactly the same way as NCCL but without needing - * streams or comms. - * - * \param sendbuff The sendbuff. - * \param recvbuff The recvbuff. - * \param count Number of elements. - */ - void AllReduceSum(const uint32_t *sendbuff, uint32_t *recvbuff, int count) { - if (rabit::GetWorldSize() == 1) { - return; - } - CHECK(initialised_); - dh::safe_cuda(cudaSetDevice(device_ordinal_)); - this->Underlying().DoAllReduceSum(sendbuff, recvbuff, count); - allreduce_bytes_ += count * sizeof(uint32_t); - allreduce_calls_ += 1; - } - - /** - * \brief Allreduce. Use in exactly the same way as NCCL but without needing - * streams or comms. - * - * \param sendbuff The sendbuff. - * \param recvbuff The recvbuff. - * \param count Number of elements. - */ - void AllReduceSum(const uint64_t *sendbuff, uint64_t *recvbuff, int count) { - if (rabit::GetWorldSize() == 1) { - return; - } - CHECK(initialised_); - dh::safe_cuda(cudaSetDevice(device_ordinal_)); - this->Underlying().DoAllReduceSum(sendbuff, recvbuff, count); - allreduce_bytes_ += count * sizeof(uint64_t); - allreduce_calls_ += 1; - } - - /** - * \brief Allreduce. Use in exactly the same way as NCCL but without needing - * streams or comms. - * - * Specialization for size_t, which is implementation defined so it might or might not - * be one of uint64_t/uint32_t/unsigned long long/unsigned long. - * - * \param sendbuff The sendbuff. - * \param recvbuff The recvbuff. - * \param count Number of elements. - */ - template ::value && - !std::is_same::value> // NOLINT - * = nullptr> - void AllReduceSum(const T *sendbuff, T *recvbuff, int count) { // NOLINT - if (rabit::GetWorldSize() == 1) { - return; - } - CHECK(initialised_); - dh::safe_cuda(cudaSetDevice(device_ordinal_)); - static_assert(sizeof(unsigned long long) == sizeof(uint64_t), ""); // NOLINT - this->Underlying().DoAllReduceSum(sendbuff, recvbuff, count); - allreduce_bytes_ += count * sizeof(T); - allreduce_calls_ += 1; - } - - /** - * \fn void Synchronize() - * - * \brief Synchronizes the entire communication group. - */ - void Synchronize() { - CHECK(initialised_); - dh::safe_cuda(cudaSetDevice(device_ordinal_)); - this->Underlying().DoSynchronize(); - } - - protected: - bool initialised_{false}; - size_t allreduce_bytes_{0}; // Keep statistics of the number of bytes communicated. - size_t allreduce_calls_{0}; // Keep statistics of the number of reduce calls. - - private: - int device_ordinal_{-1}; -}; - -#ifdef XGBOOST_USE_NCCL -class NcclAllReducer : public AllReducerBase { - public: - friend class AllReducerBase; - - ~NcclAllReducer() override; - - private: - /** - * \brief Initialise with the desired device ordinal for this communication - * group. - * - * \param device_ordinal The device ordinal. - */ - void DoInit(int _device_ordinal); - - /** - * \brief Allgather implemented as grouped calls to Broadcast. This way we can accept - * different size of data on different workers. - * - * \param data Buffer storing the input data. - * \param length_bytes Size of input data in bytes. 
- * \param segments Size of data on each worker. - * \param recvbuf Buffer storing the result of data from all workers. - */ - void DoAllGather(void const *data, size_t length_bytes, std::vector *segments, - dh::caching_device_vector *recvbuf); - - /** - * \brief Allgather. Use in exactly the same way as NCCL but without needing - * streams or comms. - * - * \param data Buffer storing the input data. - * \param length Size of input data in bytes. - * \param recvbuf Buffer storing the result of data from all workers. - */ - void DoAllGather(uint32_t const *data, size_t length, - dh::caching_device_vector *recvbuf) { - size_t world = rabit::GetWorldSize(); - recvbuf->resize(length * world); - safe_nccl(ncclAllGather(data, recvbuf->data().get(), length, ncclUint32, comm_, stream_)); - } - - /** - * \brief Allreduce. Use in exactly the same way as NCCL but without needing - * streams or comms. - * - * \param sendbuff The sendbuff. - * \param recvbuff The recvbuff. - * \param count Number of elements. - */ - void DoAllReduceSum(const double *sendbuff, double *recvbuff, int count) { - dh::safe_nccl(ncclAllReduce(sendbuff, recvbuff, count, ncclDouble, ncclSum, comm_, stream_)); - } - - /** - * \brief Allreduce. Use in exactly the same way as NCCL but without needing - * streams or comms. - * - * \param sendbuff The sendbuff. - * \param recvbuff The recvbuff. - * \param count Number of elements. - */ - void DoAllReduceSum(const float *sendbuff, float *recvbuff, int count) { - dh::safe_nccl(ncclAllReduce(sendbuff, recvbuff, count, ncclFloat, ncclSum, comm_, stream_)); - } - - /** - * \brief Allreduce. Use in exactly the same way as NCCL but without needing streams or comms. - * - * \param count Number of. - * - * \param sendbuff The sendbuff. - * \param recvbuff The recvbuff. - * \param count Number of. - */ - void DoAllReduceSum(const int64_t *sendbuff, int64_t *recvbuff, int count) { - dh::safe_nccl(ncclAllReduce(sendbuff, recvbuff, count, ncclInt64, ncclSum, comm_, stream_)); - } - - /** - * \brief Allreduce. Use in exactly the same way as NCCL but without needing - * streams or comms. - * - * \param sendbuff The sendbuff. - * \param recvbuff The recvbuff. - * \param count Number of elements. - */ - void DoAllReduceSum(const uint32_t *sendbuff, uint32_t *recvbuff, int count) { - dh::safe_nccl(ncclAllReduce(sendbuff, recvbuff, count, ncclUint32, ncclSum, comm_, stream_)); - } - - /** - * \brief Allreduce. Use in exactly the same way as NCCL but without needing - * streams or comms. - * - * \param sendbuff The sendbuff. - * \param recvbuff The recvbuff. - * \param count Number of elements. - */ - void DoAllReduceSum(const uint64_t *sendbuff, uint64_t *recvbuff, int count) { - dh::safe_nccl(ncclAllReduce(sendbuff, recvbuff, count, ncclUint64, ncclSum, comm_, stream_)); - } - - /** - * \brief Allreduce. Use in exactly the same way as NCCL but without needing - * streams or comms. - * - * Specialization for size_t, which is implementation defined so it might or might not - * be one of uint64_t/uint32_t/unsigned long long/unsigned long. - * - * \param sendbuff The sendbuff. - * \param recvbuff The recvbuff. - * \param count Number of elements. - */ - template ::value && - !std::is_same::value> // NOLINT - * = nullptr> - void DoAllReduceSum(const T *sendbuff, T *recvbuff, int count) { // NOLINT - dh::safe_nccl(ncclAllReduce(sendbuff, recvbuff, count, ncclUint64, ncclSum, comm_, stream_)); - } - - /** - * \brief Synchronizes the entire communication group. 
- */ - void DoSynchronize() { dh::safe_cuda(cudaStreamSynchronize(stream_)); } - - /** - * \fn ncclUniqueId GetUniqueId() - * - * \brief Gets the Unique ID from NCCL to be used in setting up interprocess - * communication - * - * \return the Unique ID - */ - ncclUniqueId GetUniqueId() { - static const int kRootRank = 0; - ncclUniqueId id; - if (rabit::GetRank() == kRootRank) { - dh::safe_nccl(ncclGetUniqueId(&id)); - } - rabit::Broadcast(static_cast(&id), sizeof(ncclUniqueId), static_cast(kRootRank)); - return id; - } - - ncclComm_t comm_; - cudaStream_t stream_; - ncclUniqueId id_; -}; - -using AllReducer = NcclAllReducer; -#else -class RabitAllReducer : public AllReducerBase { - public: - friend class AllReducerBase; - - private: - /** - * \brief Initialise with the desired device ordinal for this allreducer. - * - * \param device_ordinal The device ordinal. - */ - static void DoInit(int _device_ordinal); - - /** - * \brief Allgather implemented as grouped calls to Broadcast. This way we can accept - * different size of data on different workers. - * - * \param data Buffer storing the input data. - * \param length_bytes Size of input data in bytes. - * \param segments Size of data on each worker. - * \param recvbuf Buffer storing the result of data from all workers. - */ - void DoAllGather(void const *data, size_t length_bytes, std::vector *segments, - dh::caching_device_vector *recvbuf); - - /** - * \brief Allgather. Use in exactly the same way as NCCL. - * - * \param data Buffer storing the input data. - * \param length Size of input data in bytes. - * \param recvbuf Buffer storing the result of data from all workers. - */ - void DoAllGather(uint32_t *data, size_t length, dh::caching_device_vector *recvbuf) { - size_t world = rabit::GetWorldSize(); - auto total_size = length * world; - recvbuf->resize(total_size); - sendrecvbuf_.reserve(total_size); - auto rank = rabit::GetRank(); - safe_cuda(cudaMemcpy(sendrecvbuf_.data() + rank * length, data, length, cudaMemcpyDefault)); - rabit::Allgather(sendrecvbuf_.data(), total_size, rank * length, length, length); - safe_cuda(cudaMemcpy(data, sendrecvbuf_.data(), total_size, cudaMemcpyDefault)); - } - - /** - * \brief Allreduce. Use in exactly the same way as NCCL. - * - * \param sendbuff The sendbuff. - * \param recvbuff The recvbuff. - * \param count Number of elements. - */ - void DoAllReduceSum(const double *sendbuff, double *recvbuff, int count) { - RabitAllReduceSum(sendbuff, recvbuff, count); - } - - /** - * \brief Allreduce. Use in exactly the same way as NCCL. - * - * \param sendbuff The sendbuff. - * \param recvbuff The recvbuff. - * \param count Number of elements. - */ - void DoAllReduceSum(const float *sendbuff, float *recvbuff, int count) { - RabitAllReduceSum(sendbuff, recvbuff, count); - } - - /** - * \brief Allreduce. Use in exactly the same way as NCCL. - * - * \param sendbuff The sendbuff. - * \param recvbuff The recvbuff. - * \param count Number of elements. - */ - void DoAllReduceSum(const int64_t *sendbuff, int64_t *recvbuff, int count) { - RabitAllReduceSum(sendbuff, recvbuff, count); - } - - /** - * \brief Allreduce. Use in exactly the same way as NCCL. - * - * \param sendbuff The sendbuff. - * \param recvbuff The recvbuff. - * \param count Number of elements. - */ - void DoAllReduceSum(const uint32_t *sendbuff, uint32_t *recvbuff, int count) { - RabitAllReduceSum(sendbuff, recvbuff, count); - } - - /** - * \brief Allreduce. Use in exactly the same way as NCCL. - * - * \param sendbuff The sendbuff. 
- * \param recvbuff The recvbuff. - * \param count Number of elements. - */ - void DoAllReduceSum(const uint64_t *sendbuff, uint64_t *recvbuff, int count) { - RabitAllReduceSum(sendbuff, recvbuff, count); - } - - /** - * \brief Allreduce. Use in exactly the same way as NCCL. - * - * Specialization for size_t, which is implementation defined so it might or might not - * be one of uint64_t/uint32_t/unsigned long long/unsigned long. - * - * \param sendbuff The sendbuff. - * \param recvbuff The recvbuff. - * \param count Number of elements. - */ - template ::value && - !std::is_same::value> // NOLINT - * = nullptr> - void DoAllReduceSum(const T *sendbuff, T *recvbuff, int count) { // NOLINT - RabitAllReduceSum(sendbuff, recvbuff, count); - } - - /** - * \brief Synchronizes the allreducer. - */ - void DoSynchronize() {} - - /** - * \brief Allreduce. Use in exactly the same way as NCCL. - * - * Copy the device buffer to host, call rabit allreduce, then copy the buffer back - * to device. - * - * \param sendbuff The sendbuff. - * \param recvbuff The recvbuff. - * \param count Number of elements. - */ - template - void RabitAllReduceSum(const T *sendbuff, T *recvbuff, int count) { - auto total_size = count * sizeof(T); - sendrecvbuf_.reserve(total_size); - safe_cuda(cudaMemcpy(sendrecvbuf_.data(), sendbuff, total_size, cudaMemcpyDefault)); - rabit::Allreduce(reinterpret_cast(sendrecvbuf_.data()), count); - safe_cuda(cudaMemcpy(recvbuff, sendrecvbuf_.data(), total_size, cudaMemcpyDefault)); - } - - /// Host buffer used to call rabit functions. - std::vector sendrecvbuf_{}; -}; - -using AllReducer = RabitAllReducer; -#endif - template ::index_type> xgboost::common::Span ToSpan( diff --git a/src/common/quantile.cc b/src/common/quantile.cc index 180588c31b8c..3fa8c66b1a71 100644 --- a/src/common/quantile.cc +++ b/src/common/quantile.cc @@ -6,6 +6,7 @@ #include #include +#include "../collective/communicator-inl.h" #include "../data/adapter.h" #include "categorical.h" #include "hist_util.h" @@ -144,8 +145,8 @@ struct QuantileAllreduce { void AllreduceCategories(Span feature_types, int32_t n_threads, std::vector> *p_categories) { auto &categories = *p_categories; - auto world_size = rabit::GetWorldSize(); - auto rank = rabit::GetRank(); + auto world_size = collective::GetWorldSize(); + auto rank = collective::GetRank(); if (world_size == 1) { return; } @@ -163,7 +164,8 @@ void AllreduceCategories(Span feature_types, int32_t n_thread std::vector global_feat_ptrs(feature_ptr.size() * world_size, 0); size_t feat_begin = rank * feature_ptr.size(); // pointer to current worker std::copy(feature_ptr.begin(), feature_ptr.end(), global_feat_ptrs.begin() + feat_begin); - rabit::Allreduce(global_feat_ptrs.data(), global_feat_ptrs.size()); + collective::Allreduce(global_feat_ptrs.data(), + global_feat_ptrs.size()); // move all categories into a flatten vector to prepare for allreduce size_t total = feature_ptr.back(); @@ -176,7 +178,8 @@ void AllreduceCategories(Span feature_types, int32_t n_thread // indptr for indexing workers std::vector global_worker_ptr(world_size + 1, 0); global_worker_ptr[rank + 1] = total; // shift 1 to right for constructing the indptr - rabit::Allreduce(global_worker_ptr.data(), global_worker_ptr.size()); + collective::Allreduce(global_worker_ptr.data(), + global_worker_ptr.size()); std::partial_sum(global_worker_ptr.cbegin(), global_worker_ptr.cend(), global_worker_ptr.begin()); // total number of categories in all workers with all features auto gtotal = global_worker_ptr.back(); @@ 
-188,7 +191,8 @@ void AllreduceCategories(Span feature_types, int32_t n_thread CHECK_EQ(rank_size, total); std::copy(flatten.cbegin(), flatten.cend(), global_categories.begin() + rank_begin); // gather values from all workers. - rabit::Allreduce(global_categories.data(), global_categories.size()); + collective::Allreduce(global_categories.data(), + global_categories.size()); QuantileAllreduce allreduce_result{global_categories, global_worker_ptr, global_feat_ptrs, categories.size()}; ParallelFor(categories.size(), n_threads, [&](auto fidx) { @@ -217,8 +221,8 @@ void SketchContainerImpl::GatherSketchInfo( std::vector *p_global_sketches) { auto &worker_segments = *p_worker_segments; worker_segments.resize(1, 0); - auto world = rabit::GetWorldSize(); - auto rank = rabit::GetRank(); + auto world = collective::GetWorldSize(); + auto rank = collective::GetRank(); auto n_columns = sketches_.size(); // get the size of each feature. @@ -237,7 +241,7 @@ void SketchContainerImpl::GatherSketchInfo( std::partial_sum(sketch_size.cbegin(), sketch_size.cend(), sketches_scan.begin() + beg_scan + 1); // Gather all column pointers - rabit::Allreduce(sketches_scan.data(), sketches_scan.size()); + collective::Allreduce(sketches_scan.data(), sketches_scan.size()); for (int32_t i = 0; i < world; ++i) { size_t back = (i + 1) * (n_columns + 1) - 1; auto n_entries = sketches_scan.at(back); @@ -265,7 +269,7 @@ void SketchContainerImpl::GatherSketchInfo( static_assert(sizeof(typename WQSketch::Entry) / 4 == sizeof(float), "Unexpected size of sketch entry."); - rabit::Allreduce( + collective::Allreduce( reinterpret_cast(global_sketches.data()), global_sketches.size() * sizeof(typename WQSketch::Entry) / sizeof(float)); } @@ -277,7 +281,7 @@ void SketchContainerImpl::AllReduce( monitor_.Start(__func__); size_t n_columns = sketches_.size(); - rabit::Allreduce(&n_columns, 1); + collective::Allreduce(&n_columns, 1); CHECK_EQ(n_columns, sketches_.size()) << "Number of columns differs across workers"; AllreduceCategories(feature_types_, n_threads_, &categories_); @@ -291,7 +295,8 @@ void SketchContainerImpl::AllReduce( // Prune the intermediate num cuts for synchronization. std::vector global_column_size(columns_size_); - rabit::Allreduce(global_column_size.data(), global_column_size.size()); + collective::Allreduce(global_column_size.data(), + global_column_size.size()); ParallelFor(sketches_.size(), n_threads_, [&](size_t i) { int32_t intermediate_num_cuts = static_cast( @@ -311,7 +316,7 @@ void SketchContainerImpl::AllReduce( num_cuts[i] = intermediate_num_cuts; }); - auto world = rabit::GetWorldSize(); + auto world = collective::GetWorldSize(); if (world == 1) { monitor_.Stop(__func__); return; diff --git a/src/common/quantile.cu b/src/common/quantile.cu index 5f69eafb300e..39589bf69121 100644 --- a/src/common/quantile.cu +++ b/src/common/quantile.cu @@ -500,19 +500,18 @@ void SketchContainer::FixError() { void SketchContainer::AllReduce() { dh::safe_cuda(cudaSetDevice(device_)); - auto world = rabit::GetWorldSize(); + auto world = collective::GetWorldSize(); if (world == 1) { return; } timer_.Start(__func__); - if (!reducer_) { - reducer_ = std::make_unique(); - reducer_->Init(device_); + if (!communicator_) { + communicator_ = collective::Communicator::GetDevice(device_); } // Reduce the overhead on syncing. 
size_t global_sum_rows = num_rows_; - rabit::Allreduce(&global_sum_rows, 1); + collective::Allreduce(&global_sum_rows, 1); size_t intermediate_num_cuts = std::min(global_sum_rows, static_cast(num_bins_ * kFactor)); this->Prune(intermediate_num_cuts); @@ -520,26 +519,24 @@ void SketchContainer::AllReduce() { auto d_columns_ptr = this->columns_ptr_.ConstDeviceSpan(); CHECK_EQ(d_columns_ptr.size(), num_columns_ + 1); size_t n = d_columns_ptr.size(); - rabit::Allreduce(&n, 1); + collective::Allreduce(&n, 1); CHECK_EQ(n, d_columns_ptr.size()) << "Number of columns differs across workers"; // Get the columns ptr from all workers dh::device_vector gathered_ptrs; gathered_ptrs.resize(d_columns_ptr.size() * world, 0); - size_t rank = rabit::GetRank(); + size_t rank = collective::GetRank(); auto offset = rank * d_columns_ptr.size(); thrust::copy(thrust::device, d_columns_ptr.data(), d_columns_ptr.data() + d_columns_ptr.size(), gathered_ptrs.begin() + offset); - reducer_->AllReduceSum(gathered_ptrs.data().get(), gathered_ptrs.data().get(), - gathered_ptrs.size()); + communicator_->AllReduceSum(gathered_ptrs.data().get(), gathered_ptrs.size()); // Get the data from all workers. std::vector recv_lengths; dh::caching_device_vector recvbuf; - reducer_->AllGather(this->Current().data().get(), - dh::ToSpan(this->Current()).size_bytes(), &recv_lengths, - &recvbuf); - reducer_->Synchronize(); + communicator_->AllGatherV(this->Current().data().get(), dh::ToSpan(this->Current()).size_bytes(), + &recv_lengths, &recvbuf); + communicator_->Synchronize(); // Segment the received data. auto s_recvbuf = dh::ToSpan(recvbuf); diff --git a/src/common/quantile.cuh b/src/common/quantile.cuh index be8ea1834caf..0bd45e09e563 100644 --- a/src/common/quantile.cuh +++ b/src/common/quantile.cuh @@ -5,6 +5,7 @@ #include "xgboost/span.h" #include "xgboost/data.h" +#include "../collective/device_communicator.cuh" #include "device_helpers.cuh" #include "quantile.h" #include "timer.h" @@ -37,7 +38,7 @@ class SketchContainer { private: Monitor timer_; - std::unique_ptr reducer_; + collective::DeviceCommunicator* communicator_; HostDeviceVector feature_types_; bst_row_t num_rows_; bst_feature_t num_columns_; diff --git a/src/common/random.h b/src/common/random.h index c5d38f3394eb..1467aa57c177 100644 --- a/src/common/random.h +++ b/src/common/random.h @@ -9,18 +9,20 @@ #include #include + #include #include -#include #include #include #include #include #include #include +#include -#include "xgboost/host_device_vector.h" +#include "../collective/communicator-inl.h" #include "common.h" +#include "xgboost/host_device_vector.h" namespace xgboost { namespace common { @@ -143,7 +145,7 @@ class ColumnSampler { */ ColumnSampler() { uint32_t seed = common::GlobalRandom()(); - rabit::Broadcast(&seed, sizeof(seed), 0); + collective::Broadcast(&seed, sizeof(seed), 0); rng_.seed(seed); } diff --git a/src/common/timer.cc b/src/common/timer.cc index d711446d3016..6e5585171f0e 100644 --- a/src/common/timer.cc +++ b/src/common/timer.cc @@ -1,13 +1,17 @@ /*! 
* Copyright by Contributors 2019 */ +#include "timer.h" + #include + #include +#include #include #include #include -#include -#include "timer.h" + +#include "../collective/communicator-inl.h" #if defined(XGBOOST_USE_NVTX) #include @@ -54,7 +58,7 @@ void Monitor::PrintStatistics(StatMap const& statistics) const { void Monitor::Print() const { if (!ConsoleLogger::ShouldLog(ConsoleLogger::LV::kDebug)) { return; } - auto rank = rabit::GetRank(); + auto rank = collective::GetRank(); StatMap stat_map; for (auto const &kv : statistics_map_) { stat_map[kv.first] = std::make_pair( diff --git a/src/data/data.cc b/src/data/data.cc index 71bb2b32b462..715684417928 100644 --- a/src/data/data.cc +++ b/src/data/data.cc @@ -2,36 +2,36 @@ * Copyright 2015-2022 by XGBoost Contributors * \file data.cc */ +#include "xgboost/data.h" + #include + #include #include -#include "dmlc/io.h" -#include "xgboost/data.h" -#include "xgboost/c_api.h" -#include "xgboost/host_device_vector.h" -#include "xgboost/logging.h" -#include "xgboost/version_config.h" -#include "xgboost/learner.h" -#include "xgboost/string_view.h" - -#include "sparse_page_writer.h" -#include "simple_dmatrix.h" - +#include "../collective/communicator-inl.h" +#include "../common/group_data.h" #include "../common/io.h" #include "../common/linalg_op.h" #include "../common/math.h" #include "../common/numeric.h" -#include "../common/version.h" -#include "../common/group_data.h" #include "../common/threading_utils.h" +#include "../common/version.h" #include "../data/adapter.h" #include "../data/iterative_dmatrix.h" +#include "./sparse_page_dmatrix.h" +#include "./sparse_page_source.h" +#include "dmlc/io.h" #include "file_iterator.h" - +#include "simple_dmatrix.h" +#include "sparse_page_writer.h" #include "validation.h" -#include "./sparse_page_source.h" -#include "./sparse_page_dmatrix.h" +#include "xgboost/c_api.h" +#include "xgboost/host_device_vector.h" +#include "xgboost/learner.h" +#include "xgboost/logging.h" +#include "xgboost/string_view.h" +#include "xgboost/version_config.h" namespace dmlc { DMLC_REGISTRY_ENABLE(::xgboost::data::SparsePageFormatReg<::xgboost::SparsePage>); @@ -792,12 +792,12 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, bool load_row_split, size_t pos = cache_shards[i].rfind('.'); if (pos == std::string::npos) { os << cache_shards[i] - << ".r" << rabit::GetRank() - << "-" << rabit::GetWorldSize(); + << ".r" << collective::GetRank() + << "-" << collective::GetWorldSize(); } else { os << cache_shards[i].substr(0, pos) - << ".r" << rabit::GetRank() - << "-" << rabit::GetWorldSize() + << ".r" << collective::GetRank() + << "-" << collective::GetWorldSize() << cache_shards[i].substr(pos, cache_shards[i].length()); } if (i + 1 != cache_shards.size()) { @@ -820,8 +820,8 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, bool load_row_split, int partid = 0, npart = 1; if (load_row_split) { - partid = rabit::GetRank(); - npart = rabit::GetWorldSize(); + partid = collective::GetRank(); + npart = collective::GetWorldSize(); } else { // test option to load in part npart = 1; @@ -876,7 +876,7 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, bool load_row_split, /* sync up number of features after matrix loaded. * partitioned data will fail the train/val validation check * since partitioned data not knowing the real number of features. 
*/ - rabit::Allreduce(&dmat->Info().num_col_, 1); + collective::Allreduce(&dmat->Info().num_col_, 1); return dmat; } diff --git a/src/data/iterative_dmatrix.cc b/src/data/iterative_dmatrix.cc index 30583a9439bc..2e86cd957380 100644 --- a/src/data/iterative_dmatrix.cc +++ b/src/data/iterative_dmatrix.cc @@ -5,6 +5,7 @@ #include +#include "../collective/communicator-inl.h" #include "../common/column_matrix.h" #include "../common/hist_util.h" #include "../tree/param.h" // FIXME(jiamingy): Find a better way to share this parameter. @@ -138,7 +139,7 @@ void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing, // We use do while here as the first batch is fetched in ctor if (n_features == 0) { n_features = num_cols(); - rabit::Allreduce(&n_features, 1); + collective::Allreduce(&n_features, 1); column_sizes.resize(n_features); info_.num_col_ = n_features; } else { @@ -156,7 +157,7 @@ void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing, // From here on Info() has the correct data shape Info().num_row_ = accumulated_rows; Info().num_nonzero_ = nnz; - rabit::Allreduce(&info_.num_col_, 1); + collective::Allreduce(&info_.num_col_, 1); CHECK(std::none_of(column_sizes.cbegin(), column_sizes.cend(), [&](auto f) { return f > accumulated_rows; })) << "Something went wrong during iteration."; diff --git a/src/data/iterative_dmatrix.cu b/src/data/iterative_dmatrix.cu index ceb470a5c7e5..291614dc6869 100644 --- a/src/data/iterative_dmatrix.cu +++ b/src/data/iterative_dmatrix.cu @@ -62,7 +62,7 @@ void IterativeDMatrix::InitFromCUDA(DataIterHandle iter_handle, float missing, dh::safe_cuda(cudaSetDevice(get_device())); if (cols == 0) { cols = num_cols(); - rabit::Allreduce(&cols, 1); + collective::Allreduce(&cols, 1); this->info_.num_col_ = cols; } else { CHECK_EQ(cols, num_cols()) << "Inconsistent number of columns."; @@ -163,7 +163,7 @@ void IterativeDMatrix::InitFromCUDA(DataIterHandle iter_handle, float missing, iter.Reset(); // Synchronise worker columns - rabit::Allreduce(&info_.num_col_, 1); + collective::Allreduce(&info_.num_col_, 1); } BatchSet IterativeDMatrix::GetEllpackBatches(BatchParam const& param) { diff --git a/src/data/simple_dmatrix.cc b/src/data/simple_dmatrix.cc index a373ff0196e5..584f11d72b3c 100644 --- a/src/data/simple_dmatrix.cc +++ b/src/data/simple_dmatrix.cc @@ -181,7 +181,7 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) { // Synchronise worker columns - rabit::Allreduce(&info_.num_col_, 1); + collective::Allreduce(&info_.num_col_, 1); if (adapter->NumRows() == kAdapterUnknownSize) { using IteratorAdapterT @@ -312,7 +312,7 @@ SimpleDMatrix::SimpleDMatrix(RecordBatchesIterAdapter* adapter, float missing, i } // Synchronise worker columns info_.num_col_ = adapter->NumColumns(); - rabit::Allreduce(&info_.num_col_, 1); + collective::Allreduce(&info_.num_col_, 1); info_.num_row_ = total_batch_size; info_.num_nonzero_ = data_vec.size(); CHECK_EQ(offset_vec.back(), info_.num_nonzero_); diff --git a/src/data/simple_dmatrix.cu b/src/data/simple_dmatrix.cu index 12f44fb85db1..64f308b8c2bd 100644 --- a/src/data/simple_dmatrix.cu +++ b/src/data/simple_dmatrix.cu @@ -35,7 +35,7 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int32_t /*nthread info_.num_col_ = adapter->NumColumns(); info_.num_row_ = adapter->NumRows(); // Synchronise worker columns - rabit::Allreduce(&info_.num_col_, 1); + collective::Allreduce(&info_.num_col_, 1); } template SimpleDMatrix::SimpleDMatrix(CudfAdapter* adapter, float 
missing, diff --git a/src/data/sparse_page_dmatrix.cc b/src/data/sparse_page_dmatrix.cc index a90150ce8c83..df9d99fdd6d7 100644 --- a/src/data/sparse_page_dmatrix.cc +++ b/src/data/sparse_page_dmatrix.cc @@ -5,6 +5,8 @@ * \author Tianqi Chen */ #include "./sparse_page_dmatrix.h" + +#include "../collective/communicator-inl.h" #include "./simple_batch_iterator.h" #include "gradient_index.h" @@ -46,8 +48,8 @@ SparsePageDMatrix::SparsePageDMatrix(DataIterHandle iter_handle, DMatrixHandle p cache_prefix_{std::move(cache_prefix)} { ctx_.nthread = nthreads; cache_prefix_ = cache_prefix_.empty() ? "DMatrix" : cache_prefix_; - if (rabit::IsDistributed()) { - cache_prefix_ += ("-r" + std::to_string(rabit::GetRank())); + if (collective::IsDistributed()) { + cache_prefix_ += ("-r" + std::to_string(collective::GetRank())); } DMatrixProxy *proxy = MakeProxy(proxy_); auto iter = DataIterProxy{ @@ -94,7 +96,7 @@ SparsePageDMatrix::SparsePageDMatrix(DataIterHandle iter_handle, DMatrixHandle p this->info_.num_col_ = n_features; this->info_.num_nonzero_ = nnz; - rabit::Allreduce(&info_.num_col_, 1); + collective::Allreduce(&info_.num_col_, 1); CHECK_NE(info_.num_col_, 0); } diff --git a/src/gbm/gbtree.cc b/src/gbm/gbtree.cc index a4106888f240..e9e888a98e46 100644 --- a/src/gbm/gbtree.cc +++ b/src/gbm/gbtree.cc @@ -135,7 +135,7 @@ void GBTree::PerformTreeMethodHeuristic(DMatrix* fmat) { return; } - if (rabit::IsDistributed()) { + if (collective::IsDistributed()) { LOG(INFO) << "Tree method is automatically selected to be 'approx' " "for distributed training."; tparam_.tree_method = TreeMethod::kApprox; diff --git a/src/learner.cc b/src/learner.cc index 2ee83fb71cbf..b411524f0ef1 100644 --- a/src/learner.cc +++ b/src/learner.cc @@ -23,6 +23,7 @@ #include #include +#include "collective/communicator-inl.h" #include "common/charconv.h" #include "common/common.h" #include "common/io.h" @@ -476,7 +477,7 @@ class LearnerConfiguration : public Learner { // add additional parameters // These are cosntraints that need to be satisfied. - if (tparam_.dsplit == DataSplitMode::kAuto && rabit::IsDistributed()) { + if (tparam_.dsplit == DataSplitMode::kAuto && collective::IsDistributed()) { tparam_.dsplit = DataSplitMode::kRow; } @@ -755,7 +756,7 @@ class LearnerConfiguration : public Learner { num_feature = std::max(num_feature, static_cast(num_col)); } - rabit::Allreduce(&num_feature, 1); + collective::Allreduce(&num_feature, 1); if (num_feature > mparam_.num_feature) { mparam_.num_feature = num_feature; } @@ -1081,7 +1082,7 @@ class LearnerIO : public LearnerConfiguration { cfg_.insert(n.cbegin(), n.cend()); // copy dsplit from config since it will not run again during restore - if (tparam_.dsplit == DataSplitMode::kAuto && rabit::IsDistributed()) { + if (tparam_.dsplit == DataSplitMode::kAuto && collective::IsDistributed()) { tparam_.dsplit = DataSplitMode::kRow; } @@ -1226,7 +1227,7 @@ class LearnerImpl : public LearnerIO { } // Configuration before data is known. 
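(Sketch, not part of the patch.) The DMatrix and learner hunks above all share one shape: a worker-local column count is turned into a cluster-wide one with a single allreduce. The angle-bracketed template argument does not survive in this copy of the patch, so the reduction named below (collective::Operation::kMax) is an assumption based on the new header, not a quote.

    // Minimal sketch of the column-count sync used in data.cc,
    // iterative_dmatrix.cc, simple_dmatrix.cc, sparse_page_dmatrix.cc and the
    // learner.cc hunk above. Assumes the Allreduce overloads this series adds
    // in src/collective/communicator-inl.h; exact enum spelling is assumed.
    #include <cstdint>

    #include "../collective/communicator-inl.h"

    namespace xgboost {
    // Each worker passes the column count it observed locally; afterwards every
    // worker holds the maximum across the cluster, so row-partitioned shards
    // that are missing trailing features still agree on the schema.
    inline uint64_t GlobalNumColumns(uint64_t local_num_col) {
      collective::Allreduce<collective::Operation::kMax>(&local_num_col, 1);
      return local_num_col;
    }
    }  // namespace xgboost
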
void CheckDataSplitMode() { - if (rabit::IsDistributed()) { + if (collective::IsDistributed()) { CHECK(tparam_.dsplit != DataSplitMode::kAuto) << "Precondition violated; dsplit cannot be 'auto' in distributed mode"; if (tparam_.dsplit == DataSplitMode::kCol) { @@ -1486,7 +1487,7 @@ class LearnerImpl : public LearnerIO { } if (p_fmat->Info().num_row_ == 0) { - LOG(WARNING) << "Empty dataset at worker: " << rabit::GetRank(); + LOG(WARNING) << "Empty dataset at worker: " << collective::GetRank(); } } diff --git a/src/logging.cc b/src/logging.cc index d689ae34c4a1..2fe2d0f10e2f 100644 --- a/src/logging.cc +++ b/src/logging.cc @@ -13,6 +13,8 @@ #include "xgboost/logging.h" #include "xgboost/json.h" +#include "collective/communicator-inl.h" + #if !defined(XGBOOST_STRICT_R_MODE) || XGBOOST_STRICT_R_MODE == 0 // Override logging mechanism for non-R interfaces void dmlc::CustomLogMessage::Log(const std::string& msg) { @@ -32,7 +34,7 @@ ConsoleLogger::~ConsoleLogger() { TrackerLogger::~TrackerLogger() { log_stream_ << '\n'; - rabit::TrackerPrint(log_stream_.str()); + collective::Print(log_stream_.str()); } } // namespace xgboost diff --git a/src/metric/auc.cc b/src/metric/auc.cc index d7a7fff4ada3..e564134c4347 100644 --- a/src/metric/auc.cc +++ b/src/metric/auc.cc @@ -1,27 +1,27 @@ /*! * Copyright 2021 by XGBoost Contributors */ +#include "auc.h" + +#include #include #include -#include #include #include #include #include -#include #include +#include #include -#include "rabit/rabit.h" -#include "xgboost/linalg.h" -#include "xgboost/host_device_vector.h" -#include "xgboost/metric.h" - -#include "auc.h" - +#include "../collective/communicator-inl.h" #include "../common/common.h" #include "../common/math.h" #include "../common/threading_utils.h" +#include "rabit/rabit.h" +#include "xgboost/host_device_vector.h" +#include "xgboost/linalg.h" +#include "xgboost/metric.h" namespace xgboost { namespace metric { @@ -117,7 +117,8 @@ double MultiClassOVR(common::Span predts, MetaInfo const &info, // we have 2 averages going in here, first is among workers, second is among // classes. allreduce sums up fp/tp auc for each class. - rabit::Allreduce(results.Values().data(), results.Values().size()); + collective::Allreduce(results.Values().data(), + results.Values().size()); double auc_sum{0}; double tp_sum{0}; for (size_t c = 0; c < n_classes; ++c) { @@ -265,7 +266,7 @@ class EvalAUC : public Metric { } // We use the global size to handle empty dataset. std::array meta{info.labels.Size(), preds.Size()}; - rabit::Allreduce(meta.data(), meta.size()); + collective::Allreduce(meta.data(), meta.size()); if (meta[0] == 0) { // Empty across all workers, which is not supported. 
auc = std::numeric_limits::quiet_NaN(); @@ -287,7 +288,7 @@ class EvalAUC : public Metric { } std::array results{auc, static_cast(valid_groups)}; - rabit::Allreduce(results.data(), results.size()); + collective::Allreduce(results.data(), results.size()); auc = results[0]; valid_groups = static_cast(results[1]); @@ -316,7 +317,7 @@ class EvalAUC : public Metric { } double local_area = fp * tp; std::array result{auc, local_area}; - rabit::Allreduce(result.data(), result.size()); + collective::Allreduce(result.data(), result.size()); std::tie(auc, local_area) = common::UnpackArr(std::move(result)); if (local_area <= 0) { // the dataset across all workers have only positive or negative sample diff --git a/src/metric/auc.cu b/src/metric/auc.cu index 536f6442d9fa..96356bbb2c55 100644 --- a/src/metric/auc.cu +++ b/src/metric/auc.cu @@ -15,6 +15,7 @@ #include "xgboost/span.h" #include "xgboost/data.h" #include "auc.h" +#include "../collective/device_communicator.cuh" #include "../common/device_helpers.cuh" #include "../common/ranking_utils.cuh" @@ -46,7 +47,7 @@ struct DeviceAUCCache { dh::device_vector unique_idx; // p^T: transposed prediction matrix, used by MultiClassAUC dh::device_vector predts_t; - std::unique_ptr reducer; + collective::DeviceCommunicator* communicator; void Init(common::Span predts, bool is_multi, int32_t device) { if (sorted_idx.size() != predts.size()) { @@ -58,9 +59,8 @@ struct DeviceAUCCache { predts_t.resize(sorted_idx.size()); } } - if (is_multi && !reducer) { - reducer.reset(new dh::AllReducer); - reducer->Init(device); + if (is_multi && !communicator) { + communicator = collective::Communicator::GetDevice(device); } } }; @@ -205,9 +205,9 @@ double ScaleClasses(common::Span results, common::Span local_are common::Span tp, common::Span auc, std::shared_ptr cache, size_t n_classes) { dh::XGBDeviceAllocator alloc; - if (rabit::IsDistributed()) { + if (collective::IsDistributed()) { CHECK_EQ(dh::CudaGetPointerDevice(results.data()), dh::CurrentDevice()); - cache->reducer->AllReduceSum(results.data(), results.data(), results.size()); + cache->communicator->AllReduceSum(results.data(), results.size()); } auto reduce_in = dh::MakeTransformIterator( thrust::make_counting_iterator(0), [=] XGBOOST_DEVICE(size_t i) { diff --git a/src/metric/auc.h b/src/metric/auc.h index c42df6890a39..82fe46bfe065 100644 --- a/src/metric/auc.h +++ b/src/metric/auc.h @@ -10,13 +10,14 @@ #include #include +#include "../collective/communicator-inl.h" +#include "../common/common.h" +#include "../common/threading_utils.h" #include "rabit/rabit.h" #include "xgboost/base.h" -#include "xgboost/span.h" #include "xgboost/data.h" #include "xgboost/metric.h" -#include "../common/common.h" -#include "../common/threading_utils.h" +#include "xgboost/span.h" namespace xgboost { namespace metric { @@ -101,7 +102,7 @@ XGBOOST_DEVICE inline double CalcDeltaPRAUC(double fp_prev, double fp, inline void InvalidGroupAUC() { LOG(INFO) << "Invalid group with less than 3 samples is found on worker " - << rabit::GetRank() << ". Calculating AUC value requires at " + << collective::GetRank() << ". 
Calculating AUC value requires at " << "least 2 pairs of samples."; } diff --git a/src/metric/elementwise_metric.cu b/src/metric/elementwise_metric.cu index 596894a547c2..be8530f0882d 100644 --- a/src/metric/elementwise_metric.cu +++ b/src/metric/elementwise_metric.cu @@ -196,8 +196,8 @@ class PseudoErrorLoss : public Metric { return std::make_tuple(v, wt); }); double dat[2]{result.Residue(), result.Weights()}; - if (rabit::IsDistributed()) { - rabit::Allreduce(dat, 2); + if (collective::IsDistributed()) { + collective::Allreduce(dat, 2); } return EvalRowMAPE::GetFinal(dat[0], dat[1]); } @@ -365,7 +365,7 @@ struct EvalEWiseBase : public Metric { }); double dat[2]{result.Residue(), result.Weights()}; - rabit::Allreduce(dat, 2); + collective::Allreduce(dat, 2); return Policy::GetFinal(dat[0], dat[1]); } diff --git a/src/metric/multiclass_metric.cu b/src/metric/multiclass_metric.cu index 3c2ef7d38a21..3a8695df04aa 100644 --- a/src/metric/multiclass_metric.cu +++ b/src/metric/multiclass_metric.cu @@ -185,7 +185,7 @@ struct EvalMClassBase : public Metric { dat[0] = result.Residue(); dat[1] = result.Weights(); } - rabit::Allreduce(dat, 2); + collective::Allreduce(dat, 2); return Derived::GetFinal(dat[0], dat[1]); } /*! diff --git a/src/metric/rank_metric.cc b/src/metric/rank_metric.cc index a79b67cb3b84..a251e7c87c56 100644 --- a/src/metric/rank_metric.cc +++ b/src/metric/rank_metric.cc @@ -20,17 +20,18 @@ // corresponding headers that brings in those function declaration can't be included with CUDA). // This precludes the CPU and GPU logic to coexist inside a .cu file +#include #include #include -#include -#include +#include #include -#include "xgboost/host_device_vector.h" +#include "../collective/communicator-inl.h" #include "../common/math.h" #include "../common/threading_utils.h" #include "metric_common.h" +#include "xgboost/host_device_vector.h" namespace { @@ -103,7 +104,7 @@ struct EvalAMS : public Metric { } double Eval(const HostDeviceVector& preds, const MetaInfo& info) override { - CHECK(!rabit::IsDistributed()) << "metric AMS do not support distributed evaluation"; + CHECK(!collective::IsDistributed()) << "metric AMS do not support distributed evaluation"; using namespace std; // NOLINT(*) const auto ndata = static_cast(info.labels.Size()); @@ -216,10 +217,10 @@ struct EvalRank : public Metric, public EvalRankConfig { exc.Rethrow(); } - if (rabit::IsDistributed()) { + if (collective::IsDistributed()) { double dat[2]{sum_metric, static_cast(ngroups)}; // approximately estimate the metric using mean - rabit::Allreduce(dat, 2); + collective::Allreduce(dat, 2); return dat[0] / dat[1]; } else { return sum_metric / ngroups; @@ -341,7 +342,7 @@ struct EvalCox : public Metric { public: EvalCox() = default; double Eval(const HostDeviceVector& preds, const MetaInfo& info) override { - CHECK(!rabit::IsDistributed()) << "Cox metric does not support distributed evaluation"; + CHECK(!collective::IsDistributed()) << "Cox metric does not support distributed evaluation"; using namespace std; // NOLINT(*) const auto ndata = static_cast(info.labels.Size()); diff --git a/src/metric/survival_metric.cu b/src/metric/survival_metric.cu index 296764be0b14..db5ceaa84f7a 100644 --- a/src/metric/survival_metric.cu +++ b/src/metric/survival_metric.cu @@ -214,7 +214,7 @@ template struct EvalEWiseSurvivalBase : public Metric { info.labels_upper_bound_, preds); double dat[2]{result.Residue(), result.Weights()}; - rabit::Allreduce(dat, 2); + collective::Allreduce(dat, 2); return Policy::GetFinal(dat[0], dat[1]); } 
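(Sketch, not part of the patch.) The metric hunks above (elementwise, multiclass, rank and survival) are all the same mechanical change: pack the local residue and weight into a two-element double array, sum it across workers in one call, then finish with the metric's own ratio. The summation template argument is stripped in this text, so collective::Operation::kSum below is inferred rather than quoted.

    // Sketch of the distributed metric reduction used by the hunks above.
    #include "../collective/communicator-inl.h"

    namespace xgboost {
    inline double GlobalWeightedMean(double local_residue, double local_weight) {
      double dat[2]{local_residue, local_weight};
      // One allreduce for both values keeps each metric to a single sync point.
      collective::Allreduce<collective::Operation::kSum>(dat, 2);
      // Each real metric finalises via its Policy::GetFinal; a plain weighted
      // mean with an empty-dataset fallback illustrates the shape.
      return dat[1] > 0.0 ? dat[0] / dat[1] : dat[0];
    }
    }  // namespace xgboost
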
diff --git a/src/objective/adaptive.h b/src/objective/adaptive.h index 00d27a57afef..66fd0e4f6fb7 100644 --- a/src/objective/adaptive.h +++ b/src/objective/adaptive.h @@ -7,6 +7,7 @@ #include #include +#include "../collective/communicator-inl.h" #include "../common/common.h" #include "rabit/rabit.h" #include "xgboost/generic_parameters.h" @@ -39,7 +40,7 @@ inline void UpdateLeafValues(std::vector* p_quantiles, std::vector(&n_leaf, 1); + collective::Allreduce(&n_leaf, 1); CHECK(quantiles.empty() || quantiles.size() == n_leaf); if (quantiles.empty()) { quantiles.resize(n_leaf, std::numeric_limits::quiet_NaN()); @@ -49,12 +50,12 @@ inline void UpdateLeafValues(std::vector* p_quantiles, std::vector n_valids(quantiles.size()); std::transform(quantiles.cbegin(), quantiles.cend(), n_valids.begin(), [](float q) { return static_cast(!std::isnan(q)); }); - rabit::Allreduce(n_valids.data(), n_valids.size()); + collective::Allreduce(n_valids.data(), n_valids.size()); // convert to 0 for all reduce std::replace_if( quantiles.begin(), quantiles.end(), [](float q) { return std::isnan(q); }, 0.f); // use the mean value - rabit::Allreduce(quantiles.data(), quantiles.size()); + collective::Allreduce(quantiles.data(), quantiles.size()); for (size_t i = 0; i < n_leaf; ++i) { if (n_valids[i] > 0) { quantiles[i] /= static_cast(n_valids[i]); diff --git a/src/tree/hist/histogram.h b/src/tree/hist/histogram.h index 70fad741be73..80bea6271c56 100644 --- a/src/tree/hist/histogram.h +++ b/src/tree/hist/histogram.h @@ -8,6 +8,7 @@ #include #include +#include "../../collective/communicator-inl.h" #include "../../common/hist_util.h" #include "../../data/gradient_index.h" #include "expand_entry.h" @@ -196,8 +197,9 @@ class HistogramBuilder { } }); - rabit::Allreduce(reinterpret_cast(this->hist_[starting_index].data()), - builder_.GetNumBins() * sync_count * 2); + collective::Allreduce( + reinterpret_cast(this->hist_[starting_index].data()), + builder_.GetNumBins() * sync_count * 2); ParallelSubtractionHist(space, nodes_for_explicit_hist_build, nodes_for_subtraction_trick, p_tree); diff --git a/src/tree/updater_approx.cc b/src/tree/updater_approx.cc index 8bb787d7d2f5..5b56eaa52d57 100644 --- a/src/tree/updater_approx.cc +++ b/src/tree/updater_approx.cc @@ -74,7 +74,7 @@ class GloablApproxBuilder { } histogram_builder_.Reset(n_total_bins, BatchSpec(param_, hess), ctx_->Threads(), n_batches_, - rabit::IsDistributed()); + collective::IsDistributed()); monitor_->Stop(__func__); } @@ -88,7 +88,7 @@ class GloablApproxBuilder { for (auto const &g : gpair) { root_sum.Add(g); } - rabit::Allreduce(reinterpret_cast(&root_sum), 2); + collective::Allreduce(reinterpret_cast(&root_sum), 2); std::vector nodes{best}; size_t i = 0; auto space = ConstructHistSpace(partitioner_, nodes); diff --git a/src/tree/updater_colmaker.cc b/src/tree/updater_colmaker.cc index f55769cc0869..fd06aeb02282 100644 --- a/src/tree/updater_colmaker.cc +++ b/src/tree/updater_colmaker.cc @@ -100,7 +100,7 @@ class ColMaker: public TreeUpdater { void Update(HostDeviceVector *gpair, DMatrix *dmat, common::Span> /*out_position*/, const std::vector &trees) override { - if (rabit::IsDistributed()) { + if (collective::IsDistributed()) { LOG(FATAL) << "Updater `grow_colmaker` or `exact` tree method doesn't " "support distributed training."; } diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index e1a940a9df34..ebadb8ac0e02 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -19,6 +19,7 @@ #include "xgboost/span.h" 
#include "xgboost/json.h" +#include "../collective/device_communicator.cuh" #include "../common/io.h" #include "../common/device_helpers.cuh" #include "../common/hist_util.h" @@ -523,14 +524,13 @@ struct GPUHistMakerDevice { } // num histograms is the number of contiguous histograms in memory to reduce over - void AllReduceHist(int nidx, dh::AllReducer* reducer, int num_histograms) { + void AllReduceHist(int nidx, collective::DeviceCommunicator* communicator, int num_histograms) { monitor.Start("AllReduce"); auto d_node_hist = hist.GetNodeHistogram(nidx).data(); - reducer->AllReduceSum(reinterpret_cast(d_node_hist), - reinterpret_cast(d_node_hist), - page->Cuts().TotalBins() * - (sizeof(GradientSumT) / sizeof(typename GradientSumT::ValueT)) * - num_histograms); + communicator->AllReduceSum(reinterpret_cast(d_node_hist), + page->Cuts().TotalBins() * + (sizeof(GradientSumT) / sizeof(typename GradientSumT::ValueT)) * + num_histograms); monitor.Stop("AllReduce"); } @@ -538,8 +538,8 @@ struct GPUHistMakerDevice { /** * \brief Build GPU local histograms for the left and right child of some parent node */ - void BuildHistLeftRight(std::vector const& candidates, dh::AllReducer* reducer, - const RegTree& tree) { + void BuildHistLeftRight(std::vector const& candidates, + collective::DeviceCommunicator* communicator, const RegTree& tree) { if (candidates.empty()) return; // Some nodes we will manually compute histograms // others we will do by subtraction @@ -570,7 +570,7 @@ struct GPUHistMakerDevice { // Reduce all in one go // This gives much better latency in a distributed setting // when processing a large batch - this->AllReduceHist(hist_nidx.at(0), reducer, hist_nidx.size()); + this->AllReduceHist(hist_nidx.at(0), communicator, hist_nidx.size()); for (size_t i = 0; i < subtraction_nidx.size(); i++) { auto build_hist_nidx = hist_nidx.at(i); @@ -580,7 +580,7 @@ struct GPUHistMakerDevice { if (!this->SubtractionTrick(parent_nidx, build_hist_nidx, subtraction_trick_nidx)) { // Calculate other histogram manually this->BuildHist(subtraction_trick_nidx); - this->AllReduceHist(subtraction_trick_nidx, reducer, 1); + this->AllReduceHist(subtraction_trick_nidx, communicator, 1); } } } @@ -589,7 +589,7 @@ struct GPUHistMakerDevice { RegTree& tree = *p_tree; // Sanity check - have we created a leaf with no training instances? 
- if (!rabit::IsDistributed() && row_partitioner) { + if (!collective::IsDistributed() && row_partitioner) { CHECK(row_partitioner->GetRows(candidate.nid).size() > 0) << "No training instances in this leaf!"; } @@ -638,7 +638,7 @@ struct GPUHistMakerDevice { parent.RightChild()); } - GPUExpandEntry InitRoot(RegTree* p_tree, dh::AllReducer* reducer) { + GPUExpandEntry InitRoot(RegTree* p_tree, collective::DeviceCommunicator* communicator) { constexpr bst_node_t kRootNIdx = 0; dh::XGBCachingDeviceAllocator alloc; auto gpair_it = dh::MakeTransformIterator( @@ -646,11 +646,11 @@ struct GPUHistMakerDevice { GradientPairPrecise root_sum = dh::Reduce(thrust::cuda::par(alloc), gpair_it, gpair_it + gpair.size(), GradientPairPrecise{}, thrust::plus{}); - rabit::Allreduce(reinterpret_cast(&root_sum), 2); + collective::Allreduce(reinterpret_cast(&root_sum), 2); hist.AllocateHistograms({kRootNIdx}); this->BuildHist(kRootNIdx); - this->AllReduceHist(kRootNIdx, reducer, 1); + this->AllReduceHist(kRootNIdx, communicator, 1); // Remember root stats node_sum_gradients[kRootNIdx] = root_sum; @@ -665,7 +665,7 @@ struct GPUHistMakerDevice { } void UpdateTree(HostDeviceVector* gpair_all, DMatrix* p_fmat, ObjInfo task, - RegTree* p_tree, dh::AllReducer* reducer, + RegTree* p_tree, collective::DeviceCommunicator* communicator, HostDeviceVector* p_out_position) { auto& tree = *p_tree; // Process maximum 32 nodes at a time @@ -676,7 +676,7 @@ struct GPUHistMakerDevice { monitor.Stop("Reset"); monitor.Start("InitRoot"); - driver.Push({ this->InitRoot(p_tree, reducer) }); + driver.Push({ this->InitRoot(p_tree, communicator) }); monitor.Stop("InitRoot"); // The set of leaves that can be expanded asynchronously @@ -703,7 +703,7 @@ struct GPUHistMakerDevice { monitor.Stop("UpdatePosition"); monitor.Start("BuildHist"); - this->BuildHistLeftRight(filtered_expand_set, reducer, tree); + this->BuildHistLeftRight(filtered_expand_set, communicator, tree); monitor.Stop("BuildHist"); monitor.Start("EvaluateSplits"); @@ -785,11 +785,11 @@ class GPUHistMaker : public TreeUpdater { void InitDataOnce(DMatrix* dmat) { CHECK_GE(ctx_->gpu_id, 0) << "Must have at least one device"; info_ = &dmat->Info(); - reducer_.Init({ctx_->gpu_id}); // NOLINT + communicator_ = collective::Communicator::GetDevice(ctx_->gpu_id); // Synchronise the column sampling seed uint32_t column_sampling_seed = common::GlobalRandom()(); - rabit::Broadcast(&column_sampling_seed, sizeof(column_sampling_seed), 0); + collective::Broadcast(&column_sampling_seed, sizeof(column_sampling_seed), 0); BatchParam batch_param{ ctx_->gpu_id, @@ -819,12 +819,12 @@ class GPUHistMaker : public TreeUpdater { void CheckTreesSynchronized(RegTree* local_tree) const { std::string s_model; common::MemoryBufferStream fs(&s_model); - int rank = rabit::GetRank(); + int rank = collective::GetRank(); if (rank == 0) { local_tree->Save(&fs); } fs.Seek(0); - rabit::Broadcast(&s_model, 0); + collective::Broadcast(&s_model, 0); RegTree reference_tree{}; // rank 0 tree reference_tree.Load(&fs); CHECK(*local_tree == reference_tree); @@ -837,7 +837,7 @@ class GPUHistMaker : public TreeUpdater { monitor_.Stop("InitData"); gpair->SetDevice(ctx_->gpu_id); - maker->UpdateTree(gpair, p_fmat, task_, p_tree, &reducer_, p_out_position); + maker->UpdateTree(gpair, p_fmat, task_, p_tree, communicator_, p_out_position); } bool UpdatePredictionCache(const DMatrix* data, @@ -863,7 +863,7 @@ class GPUHistMaker : public TreeUpdater { GPUHistMakerTrainParam hist_maker_param_; - dh::AllReducer reducer_; + 
collective::DeviceCommunicator* communicator_; DMatrix* p_last_fmat_{nullptr}; RegTree const* p_last_tree_{nullptr}; diff --git a/src/tree/updater_quantile_hist.cc b/src/tree/updater_quantile_hist.cc index 2ccc426cf73a..3a35781d43ba 100644 --- a/src/tree/updater_quantile_hist.cc +++ b/src/tree/updater_quantile_hist.cc @@ -103,7 +103,7 @@ CPUExpandEntry QuantileHistMaker::Builder::InitRoot( for (auto const &grad : gpair_h) { grad_stat.Add(grad.GetGrad(), grad.GetHess()); } - rabit::Allreduce(reinterpret_cast(&grad_stat), 2); + collective::Allreduce(reinterpret_cast(&grad_stat), 2); } auto weight = evaluator_->InitRoot(GradStats{grad_stat}); @@ -320,7 +320,7 @@ void QuantileHistMaker::Builder::InitData(DMatrix *fmat, const RegTree &tree, ++page_id; } histogram_builder_->Reset(n_total_bins, HistBatch(param_), ctx_->Threads(), page_id, - rabit::IsDistributed()); + collective::IsDistributed()); if (param_.subsample < 1.0f) { CHECK_EQ(param_.sampling_method, TrainParam::kUniform) diff --git a/src/tree/updater_refresh.cc b/src/tree/updater_refresh.cc index 5dd7694b9792..f3cc1a9fa105 100644 --- a/src/tree/updater_refresh.cc +++ b/src/tree/updater_refresh.cc @@ -7,14 +7,15 @@ #include #include -#include #include +#include -#include "xgboost/json.h" -#include "./param.h" +#include "../collective/communicator-inl.h" #include "../common/io.h" #include "../common/threading_utils.h" #include "../predictor/predict_fn.h" +#include "./param.h" +#include "xgboost/json.h" namespace xgboost { namespace tree { @@ -100,8 +101,9 @@ class TreeRefresher : public TreeUpdater { } }); }; - rabit::Allreduce(&dmlc::BeginPtr(stemp[0])->sum_grad, stemp[0].size() * 2, - lazy_get_stats); + lazy_get_stats(); + collective::Allreduce(&dmlc::BeginPtr(stemp[0])->sum_grad, + stemp[0].size() * 2); // rescale learning rate according to size of trees float lr = param_.learning_rate; param_.learning_rate = lr / trees.size(); diff --git a/src/tree/updater_sync.cc b/src/tree/updater_sync.cc index 5b89f80ec993..331a982b15ba 100644 --- a/src/tree/updater_sync.cc +++ b/src/tree/updater_sync.cc @@ -4,12 +4,14 @@ * \brief synchronize the tree in all distributed nodes */ #include -#include -#include + #include +#include +#include -#include "xgboost/json.h" +#include "../collective/communicator-inl.h" #include "../common/io.h" +#include "xgboost/json.h" namespace xgboost { namespace tree { @@ -35,17 +37,17 @@ class TreeSyncher : public TreeUpdater { void Update(HostDeviceVector*, DMatrix*, common::Span> /*out_position*/, const std::vector& trees) override { - if (rabit::GetWorldSize() == 1) return; + if (collective::GetWorldSize() == 1) return; std::string s_model; common::MemoryBufferStream fs(&s_model); - int rank = rabit::GetRank(); + int rank = collective::GetRank(); if (rank == 0) { for (auto tree : trees) { tree->Save(&fs); } } fs.Seek(0); - rabit::Broadcast(&s_model, 0); + collective::Broadcast(&s_model, 0); for (auto tree : trees) { tree->Load(&fs); } diff --git a/tests/cpp/common/test_quantile.cu b/tests/cpp/common/test_quantile.cu index a179ab7f9c1d..06b832319d66 100644 --- a/tests/cpp/common/test_quantile.cu +++ b/tests/cpp/common/test_quantile.cu @@ -1,6 +1,7 @@ #include #include "test_quantile.h" #include "../helpers.h" +#include "../../../src/collective/device_communicator.cuh" #include "../../../src/common/hist_util.cuh" #include "../../../src/common/quantile.cuh" @@ -467,12 +468,10 @@ TEST(GPUQuantile, SameOnAllWorkers) { thrust::copy(thrust::device, local_data.data(), local_data.data() + local_data.size(), 
all_workers.begin() + local_data.size() * rank); - dh::AllReducer reducer; - reducer.Init(0); + collective::DeviceCommunicator* communicator = collective::Communicator::GetDevice(0); - reducer.AllReduceSum(all_workers.data().get(), all_workers.data().get(), - all_workers.size()); - reducer.Synchronize(); + communicator->AllReduceSum(all_workers.data().get(), all_workers.size()); + communicator->Synchronize(); auto base_line = dh::ToSpan(all_workers).subspan(0, size_as_float); std::vector h_base_line(base_line.size()); diff --git a/tests/distributed/distributed_gpu.py b/tests/distributed/distributed_gpu.py index d10d2aed4884..8ec210ec65fa 100644 --- a/tests/distributed/distributed_gpu.py +++ b/tests/distributed/distributed_gpu.py @@ -8,9 +8,9 @@ def run_test(name, params_fun): """Runs a distributed GPU test.""" # Always call this before using distributed module - with xgb.rabit.RabitContext(): - rank = xgb.rabit.get_rank() - world = xgb.rabit.get_world_size() + with xgb.collective.CommunicatorContext(): + rank = xgb.collective.get_rank() + world = xgb.collective.get_world_size() # Load file, file will be automatically sharded in distributed mode. dtrain = xgb.DMatrix('../../demo/data/agaricus.txt.train') @@ -28,8 +28,8 @@ def run_test(name, params_fun): # Have each worker save its model model_name = "test.model.%s.%d" % (name, rank) bst.dump_model(model_name, with_stats=True) - xgb.rabit.allreduce(np.ones((1, 1)), xgb.rabit.Op.MAX) # sync - xgb.rabit.tracker_print("Finished training\n") + xgb.collective.allreduce(np.ones((1, 1)), xgb.collective.Op.MAX) # sync + xgb.collective.communicator_print("Finished training\n") if (rank == 0): for i in range(0, world): diff --git a/tests/distributed/test_basic.py b/tests/distributed/test_basic.py index db2916b39a3c..4e64497b1bfd 100644 --- a/tests/distributed/test_basic.py +++ b/tests/distributed/test_basic.py @@ -2,7 +2,7 @@ import xgboost as xgb # Always call this before using distributed module -with xgb.rabit.RabitContext(): +with xgb.collective.CommunicatorContext(): # Load file, file will be automatically sharded in distributed mode. dtrain = xgb.DMatrix('../../demo/data/agaricus.txt.train') dtest = xgb.DMatrix('../../demo/data/agaricus.txt.test') @@ -19,6 +19,6 @@ bst = xgb.train(param, dtrain, num_round, watchlist, early_stopping_rounds=2) # Save the model, only ask process 0 to save the model. - if xgb.rabit.get_rank() == 0: + if xgb.collective.get_rank() == 0: bst.save_model("test.model") - xgb.rabit.tracker_print("Finished training\n") + xgb.collective.tracker_print("Finished training\n") diff --git a/tests/distributed/test_federated.py b/tests/distributed/test_federated.py index ea5a3a0f35cf..d62e5d923702 100644 --- a/tests/distributed/test_federated.py +++ b/tests/distributed/test_federated.py @@ -35,7 +35,7 @@ def run_worker(port: int, world_size: int, rank: int, with_ssl: bool, with_gpu: ] # Always call this before using distributed module - with xgb.rabit.RabitContext([e.encode() for e in rabit_env]): + with xgb.collective.CommunicatorContext([e.encode() for e in rabit_env]): # Load file, file will not be sharded in federated mode. 
dtrain = xgb.DMatrix('agaricus.txt.train-%02d' % rank) dtest = xgb.DMatrix('agaricus.txt.test-%02d' % rank) diff --git a/tests/distributed/test_issue3402.py b/tests/distributed/test_issue3402.py index 7a40d3420ebb..6f86cc0d837d 100644 --- a/tests/distributed/test_issue3402.py +++ b/tests/distributed/test_issue3402.py @@ -2,7 +2,7 @@ import xgboost as xgb import numpy as np -with xgb.rabit.RabitContext(): +with xgb.rabit.CommunicatorContext(): X = [ [15.00,28.90,29.00,3143.70,0.00,0.10,69.90,90.00,13726.07,0.00,2299.70,0.00,0.05, 4327.03,0.00,24.00,0.18,3.00,0.41,3.77,0.00,0.00,4.00,0.00,150.92,0.00,2.00,0.00, diff --git a/tests/python/test_tracker.py b/tests/python/test_tracker.py index b9ae17531790..5b260cbc63d3 100644 --- a/tests/python/test_tracker.py +++ b/tests/python/test_tracker.py @@ -13,36 +13,32 @@ def test_rabit_tracker(): tracker = RabitTracker(host_ip='127.0.0.1', n_workers=1) tracker.start(1) - worker_env = tracker.worker_envs() - rabit_env = [] - for k, v in worker_env.items(): - rabit_env.append(f"{k}={v}".encode()) - with xgb.rabit.RabitContext(rabit_env): - ret = xgb.rabit.broadcast('test1234', 0) + with xgb.collective.CommunicatorContext(**tracker.worker_envs()): + ret = xgb.collective.broadcast('test1234', 0) assert str(ret) == 'test1234' def run_rabit_ops(client, n_workers): from test_with_dask import _get_client_workers from xgboost.dask import RabitContext, _get_rabit_args - from xgboost import rabit + from xgboost import collective workers = _get_client_workers(client) rabit_args = client.sync(_get_rabit_args, len(workers), None, client) - assert not rabit.is_distributed() + assert not collective.is_distributed() n_workers_from_dask = len(workers) assert n_workers == n_workers_from_dask def local_test(worker_id): with RabitContext(rabit_args): a = 1 - assert rabit.is_distributed() + assert collective.is_distributed() a = np.array([a]) - reduced = rabit.allreduce(a, rabit.Op.SUM) + reduced = collective.allreduce(a, collective.Op.SUM) assert reduced[0] == n_workers worker_id = np.array([worker_id]) - reduced = rabit.allreduce(worker_id, rabit.Op.MAX) + reduced = collective.allreduce(worker_id, collective.Op.MAX) assert reduced == n_workers - 1 return 1 @@ -66,14 +62,14 @@ def test_rank_assignment() -> None: from test_with_dask import _get_client_workers def local_test(worker_id): - with xgb.dask.RabitContext(args): + with xgb.dask.CommunicatorContext(**args): for val in args: sval = val.decode("utf-8") if sval.startswith("DMLC_TASK_ID"): task_id = sval break matched = re.search(".*-([0-9]).*", task_id) - rank = xgb.rabit.get_rank() + rank = xgb.collective.get_rank() # As long as the number of workers is lesser than 10, rank and worker id # should be the same assert rank == int(matched.group(1)) From 07ce6f17faf8249df89422396f3eba92add28b47 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Wed, 21 Sep 2022 18:13:38 -0700 Subject: [PATCH 02/18] fix size_t specialization --- src/collective/communicator-inl.h | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/collective/communicator-inl.h b/src/collective/communicator-inl.h index d848eaef2829..8a59a2ccfd69 100644 --- a/src/collective/communicator-inl.h +++ b/src/collective/communicator-inl.h @@ -179,8 +179,16 @@ inline void Allreduce(int64_t *send_receive_buffer, size_t count) { Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kInt64, op); } -template -inline void Allreduce(uint64_t *send_receive_buffer, size_t count) { +template ::value, bool> = true> +inline void 
Allreduce(T *send_receive_buffer, size_t count) { + Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kUInt64, op); +} + +// Specialize on size_t for platforms where size_t != uint64_t. +template ::value>::type> +inline void Allreduce(T *send_receive_buffer, size_t count) { + static_assert(sizeof(size_t) == sizeof(uint64_t), ""); Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kUInt64, op); } From 9db6d96ca00139b1b02af85a85d6cf80a7f5cb0e Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Thu, 22 Sep 2022 10:20:04 -0700 Subject: [PATCH 03/18] really fix size_t --- src/collective/communicator-inl.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/collective/communicator-inl.h b/src/collective/communicator-inl.h index 8a59a2ccfd69..3b8ca276ab70 100644 --- a/src/collective/communicator-inl.h +++ b/src/collective/communicator-inl.h @@ -179,16 +179,16 @@ inline void Allreduce(int64_t *send_receive_buffer, size_t count) { Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kInt64, op); } -template ::value, bool> = true> -inline void Allreduce(T *send_receive_buffer, size_t count) { +template +inline void Allreduce(uint64_t *send_receive_buffer, size_t count) { Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kUInt64, op); } // Specialize on size_t for platforms where size_t != uint64_t. template ::value>::type> + std::enable_if_t::value && !std::is_same::value> > inline void Allreduce(T *send_receive_buffer, size_t count) { - static_assert(sizeof(size_t) == sizeof(uint64_t), ""); + static_assert(sizeof(T) == sizeof(uint64_t), ""); Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kUInt64, op); } From 96a232edbac4b1458ada67f03d720f61d8ddc6cb Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Thu, 22 Sep 2022 11:37:26 -0700 Subject: [PATCH 04/18] try again --- src/collective/communicator-inl.h | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/collective/communicator-inl.h b/src/collective/communicator-inl.h index 3b8ca276ab70..4b143107919f 100644 --- a/src/collective/communicator-inl.h +++ b/src/collective/communicator-inl.h @@ -184,9 +184,10 @@ inline void Allreduce(uint64_t *send_receive_buffer, size_t count) { Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kUInt64, op); } -// Specialize on size_t for platforms where size_t != uint64_t. -template ::value && !std::is_same::value> > +// Specialization for size_t, which is implementation defined, so it might or might not +// be one of uint64_t/uint32_t/unsigned long long/unsigned long. 
+template {} && !std::is_same{}> > inline void Allreduce(T *send_receive_buffer, size_t count) { static_assert(sizeof(T) == sizeof(uint64_t), ""); Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kUInt64, op); From 5853f83f4f013fff090b852c20695d4c93f8335f Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Thu, 22 Sep 2022 11:49:04 -0700 Subject: [PATCH 05/18] add include --- src/metric/elementwise_metric.cu | 1 + src/metric/multiclass_metric.cu | 1 + 2 files changed, 2 insertions(+) diff --git a/src/metric/elementwise_metric.cu b/src/metric/elementwise_metric.cu index be8530f0882d..bb5343a60f37 100644 --- a/src/metric/elementwise_metric.cu +++ b/src/metric/elementwise_metric.cu @@ -12,6 +12,7 @@ #include +#include "../collective/communicator-inl.h" #include "../common/common.h" #include "../common/math.h" #include "../common/pseudo_huber.h" diff --git a/src/metric/multiclass_metric.cu b/src/metric/multiclass_metric.cu index 3a8695df04aa..75930e1aaa7f 100644 --- a/src/metric/multiclass_metric.cu +++ b/src/metric/multiclass_metric.cu @@ -11,6 +11,7 @@ #include #include "metric_common.h" +#include "../collective/communicator-inl.h" #include "../common/math.h" #include "../common/common.h" #include "../common/threading_utils.h" From 3482b53589bb680e6f522927132573ea379f62d4 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Thu, 22 Sep 2022 11:56:27 -0700 Subject: [PATCH 06/18] more include --- src/metric/survival_metric.cu | 1 + 1 file changed, 1 insertion(+) diff --git a/src/metric/survival_metric.cu b/src/metric/survival_metric.cu index db5ceaa84f7a..1922ff1d990f 100644 --- a/src/metric/survival_metric.cu +++ b/src/metric/survival_metric.cu @@ -16,6 +16,7 @@ #include "xgboost/host_device_vector.h" #include "metric_common.h" +#include "../collective/communicator-inl.h" #include "../common/math.h" #include "../common/survival_util.h" #include "../common/threading_utils.h" From 2ae73c088a30c8176fc6aeee1d081d87b0cf4e6d Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Thu, 22 Sep 2022 12:25:42 -0700 Subject: [PATCH 07/18] fix lint errors --- python-package/xgboost/dask.py | 2 +- src/cli_main.cc | 2 +- src/collective/communicator-inl.h | 2 ++ src/collective/noop_communicator.h | 2 ++ 4 files changed, 6 insertions(+), 2 deletions(-) diff --git a/python-package/xgboost/dask.py b/python-package/xgboost/dask.py index 6e15b0aa9059..3f051330c1de 100644 --- a/python-package/xgboost/dask.py +++ b/python-package/xgboost/dask.py @@ -216,7 +216,7 @@ def _assert_dask_support() -> None: class CommunicatorContext(collective.CommunicatorContext): """A context controlling collective communicator initialization and finalization.""" - def __init__(self, **args) -> None: + def __init__(self, **args: Any) -> None: super().__init__(**args) worker = distributed.get_worker() with distributed.worker_client() as client: diff --git a/src/cli_main.cc b/src/cli_main.cc index 7fbe0a34b774..bb89756c8315 100644 --- a/src/cli_main.cc +++ b/src/cli_main.cc @@ -475,7 +475,7 @@ class CLI { // Initialize the collective communicator. 
Json json{JsonObject()}; - for (auto& kv: cfg) { + for (auto& kv : cfg) { json[kv.first] = String(kv.second); } std::string json_str; diff --git a/src/collective/communicator-inl.h b/src/collective/communicator-inl.h index 4b143107919f..a4573f37193e 100644 --- a/src/collective/communicator-inl.h +++ b/src/collective/communicator-inl.h @@ -2,6 +2,8 @@ * Copyright 2022 XGBoost contributors */ #pragma once +#include + #include "communicator.h" namespace xgboost { diff --git a/src/collective/noop_communicator.h b/src/collective/noop_communicator.h index 0bc4a4948c35..7e5aaa026735 100644 --- a/src/collective/noop_communicator.h +++ b/src/collective/noop_communicator.h @@ -2,6 +2,8 @@ * Copyright 2022 XGBoost contributors */ #pragma once +#include + #include "communicator.h" namespace xgboost { From 2c594ca0b968badf809b8ff5964453902bcf0b3b Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Fri, 23 Sep 2022 10:56:04 -0700 Subject: [PATCH 08/18] remove rabit includes --- src/c_api/c_api.cc | 7 ++-- src/cli_main.cc | 4 +-- src/collective/communicator-inl.h | 3 +- src/common/hist_util.cc | 5 --- src/common/quantile.cc | 1 - src/common/random.h | 1 - src/common/timer.cc | 5 --- src/data/iterative_dmatrix.cc | 3 -- src/data/sparse_page_source.h | 1 - src/logging.cc | 4 --- src/metric/auc.cc | 4 --- src/metric/auc.cu | 2 -- src/metric/auc.h | 1 - src/metric/elementwise_metric.cu | 1 - src/metric/multiclass_metric.cu | 2 -- src/metric/rank_metric.cc | 1 - src/metric/rank_metric.cu | 3 -- src/metric/survival_metric.cu | 1 - src/objective/adaptive.h | 1 - src/objective/regression_obj.cu | 4 +-- src/tree/hist/histogram.h | 1 - src/tree/updater_colmaker.cc | 2 -- src/tree/updater_prune.cc | 3 -- src/tree/updater_quantile_hist.cc | 7 ---- src/tree/updater_quantile_hist.h | 1 - src/tree/updater_refresh.cc | 1 - tests/cpp/common/test_quantile.cc | 30 ++++++++--------- tests/cpp/common/test_quantile.cu | 16 ++++----- tests/cpp/common/test_quantile.h | 15 ++++----- tests/distributed/test_federated.py | 26 +++++++-------- tests/distributed/test_issue3402.py | 6 ++-- .../test_gpu_with_dask/test_gpu_with_dask.py | 17 +++------- tests/python/test_with_dask.py | 33 ++++++++----------- 33 files changed, 69 insertions(+), 143 deletions(-) diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 8fc905f61332..8e01c8487b60 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -1,11 +1,8 @@ // Copyright (c) 2014-2022 by Contributors -#include #include -#include #include #include -#include #include #include #include @@ -27,7 +24,6 @@ #include "../common/charconv.h" #include "../data/adapter.h" #include "../data/simple_dmatrix.h" -#include "../data/proxy_dmatrix.h" #if defined(XGBOOST_USE_FEDERATED) #include "../../plugin/federated/federated_server.h" @@ -1525,7 +1521,8 @@ XGB_DLL int XGBoosterFeatureScore(BoosterHandle handle, char const *json_config, XGB_DLL int XGCommunicatorInit(char const* json_config) { API_BEGIN(); xgboost_CHECK_C_ARG_PTR(json_config); - collective::Init(json_config); + Json config{Json::Load(StringView{json_config})}; + collective::Init(config); API_END(); } diff --git a/src/cli_main.cc b/src/cli_main.cc index bb89756c8315..de9ae6253a0c 100644 --- a/src/cli_main.cc +++ b/src/cli_main.cc @@ -478,9 +478,7 @@ class CLI { for (auto& kv : cfg) { json[kv.first] = String(kv.second); } - std::string json_str; - Json::Dump(json, &json_str); - collective::Init(json_str.c_str()); + collective::Init(json); param_.Configure(cfg); } diff --git a/src/collective/communicator-inl.h 
b/src/collective/communicator-inl.h index a4573f37193e..923c6d291ad3 100644 --- a/src/collective/communicator-inl.h +++ b/src/collective/communicator-inl.h @@ -56,8 +56,7 @@ namespace collective { * - federated_client_key: Client key file path. Only needed for the SSL mode. * - federated_client_cert: Client certificate file path. Only needed for the SSL mode. */ -inline void Init(char const *json_config) { - Json config{Json::Load(StringView{json_config})}; +inline void Init(Json const& config) { Communicator::Init(config); } diff --git a/src/common/hist_util.cc b/src/common/hist_util.cc index 42e3ff9ce202..ab67bc92fee7 100644 --- a/src/common/hist_util.cc +++ b/src/common/hist_util.cc @@ -3,19 +3,14 @@ * \file hist_util.cc */ #include -#include -#include -#include #include #include "xgboost/base.h" #include "../common/common.h" #include "hist_util.h" -#include "random.h" #include "column_matrix.h" #include "quantile.h" -#include "../data/gradient_index.h" #if defined(XGBOOST_MM_PREFETCH_PRESENT) #include diff --git a/src/common/quantile.cc b/src/common/quantile.cc index 3fa8c66b1a71..3f3bb53265d4 100644 --- a/src/common/quantile.cc +++ b/src/common/quantile.cc @@ -10,7 +10,6 @@ #include "../data/adapter.h" #include "categorical.h" #include "hist_util.h" -#include "rabit/rabit.h" namespace xgboost { namespace common { diff --git a/src/common/random.h b/src/common/random.h index 1467aa57c177..2d29bede376e 100644 --- a/src/common/random.h +++ b/src/common/random.h @@ -7,7 +7,6 @@ #ifndef XGBOOST_COMMON_RANDOM_H_ #define XGBOOST_COMMON_RANDOM_H_ -#include #include #include diff --git a/src/common/timer.cc b/src/common/timer.cc index 6e5585171f0e..99150aa2695e 100644 --- a/src/common/timer.cc +++ b/src/common/timer.cc @@ -3,13 +3,8 @@ */ #include "timer.h" -#include - -#include #include -#include #include -#include #include "../collective/communicator-inl.h" diff --git a/src/data/iterative_dmatrix.cc b/src/data/iterative_dmatrix.cc index 2e86cd957380..251b9ed04053 100644 --- a/src/data/iterative_dmatrix.cc +++ b/src/data/iterative_dmatrix.cc @@ -3,11 +3,8 @@ */ #include "iterative_dmatrix.h" -#include - #include "../collective/communicator-inl.h" #include "../common/column_matrix.h" -#include "../common/hist_util.h" #include "../tree/param.h" // FIXME(jiamingy): Find a better way to share this parameter. #include "gradient_index.h" #include "proxy_dmatrix.h" diff --git a/src/data/sparse_page_source.h b/src/data/sparse_page_source.h index 0a3e32e75e1f..ad19847e0e44 100644 --- a/src/data/sparse_page_source.h +++ b/src/data/sparse_page_source.h @@ -14,7 +14,6 @@ #include #include -#include "rabit/rabit.h" #include "xgboost/base.h" #include "xgboost/data.h" diff --git a/src/logging.cc b/src/logging.cc index 2fe2d0f10e2f..d24c6633d987 100644 --- a/src/logging.cc +++ b/src/logging.cc @@ -4,14 +4,10 @@ * \brief Implementation of loggers. 
* \author Tianqi Chen */ -#include - #include -#include #include "xgboost/parameter.h" #include "xgboost/logging.h" -#include "xgboost/json.h" #include "collective/communicator-inl.h" diff --git a/src/metric/auc.cc b/src/metric/auc.cc index e564134c4347..2cda80f069f5 100644 --- a/src/metric/auc.cc +++ b/src/metric/auc.cc @@ -14,11 +14,7 @@ #include #include -#include "../collective/communicator-inl.h" -#include "../common/common.h" #include "../common/math.h" -#include "../common/threading_utils.h" -#include "rabit/rabit.h" #include "xgboost/host_device_vector.h" #include "xgboost/linalg.h" #include "xgboost/metric.h" diff --git a/src/metric/auc.cu b/src/metric/auc.cu index 96356bbb2c55..aa78002d2ef5 100644 --- a/src/metric/auc.cu +++ b/src/metric/auc.cu @@ -11,12 +11,10 @@ #include #include -#include "rabit/rabit.h" #include "xgboost/span.h" #include "xgboost/data.h" #include "auc.h" #include "../collective/device_communicator.cuh" -#include "../common/device_helpers.cuh" #include "../common/ranking_utils.cuh" namespace xgboost { diff --git a/src/metric/auc.h b/src/metric/auc.h index 82fe46bfe065..ab205bf7e573 100644 --- a/src/metric/auc.h +++ b/src/metric/auc.h @@ -13,7 +13,6 @@ #include "../collective/communicator-inl.h" #include "../common/common.h" #include "../common/threading_utils.h" -#include "rabit/rabit.h" #include "xgboost/base.h" #include "xgboost/data.h" #include "xgboost/metric.h" diff --git a/src/metric/elementwise_metric.cu b/src/metric/elementwise_metric.cu index bb5343a60f37..17151e4b14c5 100644 --- a/src/metric/elementwise_metric.cu +++ b/src/metric/elementwise_metric.cu @@ -7,7 +7,6 @@ * The expressions like wsum == 0 ? esum : esum / wsum is used to handle empty dataset. */ #include -#include #include #include diff --git a/src/metric/multiclass_metric.cu b/src/metric/multiclass_metric.cu index 75930e1aaa7f..c453f6686303 100644 --- a/src/metric/multiclass_metric.cu +++ b/src/metric/multiclass_metric.cu @@ -4,7 +4,6 @@ * \brief evaluation metrics for multiclass classification. * \author Kailong Chen, Tianqi Chen */ -#include #include #include @@ -13,7 +12,6 @@ #include "metric_common.h" #include "../collective/communicator-inl.h" #include "../common/math.h" -#include "../common/common.h" #include "../common/threading_utils.h" #if defined(XGBOOST_USE_CUDA) diff --git a/src/metric/rank_metric.cc b/src/metric/rank_metric.cc index a251e7c87c56..2956a3fa7a54 100644 --- a/src/metric/rank_metric.cc +++ b/src/metric/rank_metric.cc @@ -21,7 +21,6 @@ // This precludes the CPU and GPU logic to coexist inside a .cu file #include -#include #include #include diff --git a/src/metric/rank_metric.cu b/src/metric/rank_metric.cu index d2dc2f4ebb70..1cf9558ed3cd 100644 --- a/src/metric/rank_metric.cu +++ b/src/metric/rank_metric.cu @@ -4,15 +4,12 @@ * \brief prediction rank based metrics. 
* \author Kailong Chen, Tianqi Chen */ -#include #include #include #include #include -#include -#include #include #include "metric_common.h" diff --git a/src/metric/survival_metric.cu b/src/metric/survival_metric.cu index 1922ff1d990f..86ce9672a65c 100644 --- a/src/metric/survival_metric.cu +++ b/src/metric/survival_metric.cu @@ -5,7 +5,6 @@ * \author Avinash Barnwal, Hyunsu Cho and Toby Hocking */ -#include #include #include diff --git a/src/objective/adaptive.h b/src/objective/adaptive.h index 66fd0e4f6fb7..ba37f83e4e64 100644 --- a/src/objective/adaptive.h +++ b/src/objective/adaptive.h @@ -9,7 +9,6 @@ #include "../collective/communicator-inl.h" #include "../common/common.h" -#include "rabit/rabit.h" #include "xgboost/generic_parameters.h" #include "xgboost/host_device_vector.h" #include "xgboost/tree_model.h" diff --git a/src/objective/regression_obj.cu b/src/objective/regression_obj.cu index 6b2ce6371a6d..632a27778a89 100644 --- a/src/objective/regression_obj.cu +++ b/src/objective/regression_obj.cu @@ -724,8 +724,8 @@ class MeanAbsoluteError : public ObjFunction { } // Weighted average base score across all workers - rabit::Allreduce(out.Values().data(), out.Values().size()); - rabit::Allreduce(&w, 1); + collective::Allreduce(out.Values().data(), out.Values().size()); + collective::Allreduce(&w, 1); std::transform(linalg::cbegin(out), linalg::cend(out), linalg::begin(out), [w](float v) { return v / w; }); diff --git a/src/tree/hist/histogram.h b/src/tree/hist/histogram.h index a46f8db78113..acc13f6817a5 100644 --- a/src/tree/hist/histogram.h +++ b/src/tree/hist/histogram.h @@ -12,7 +12,6 @@ #include "../../common/hist_util.h" #include "../../data/gradient_index.h" #include "expand_entry.h" -#include "rabit/rabit.h" #include "xgboost/tree_model.h" namespace xgboost { diff --git a/src/tree/updater_colmaker.cc b/src/tree/updater_colmaker.cc index fd06aeb02282..89e928e4d8cd 100644 --- a/src/tree/updater_colmaker.cc +++ b/src/tree/updater_colmaker.cc @@ -4,8 +4,6 @@ * \brief use columnwise update to construct a tree * \author Tianqi Chen */ -#include -#include #include #include #include diff --git a/src/tree/updater_prune.cc b/src/tree/updater_prune.cc index 1cedeb25cb0d..a2f1a31b17aa 100644 --- a/src/tree/updater_prune.cc +++ b/src/tree/updater_prune.cc @@ -4,16 +4,13 @@ * \brief prune a tree given the statistics * \author Tianqi Chen */ -#include #include -#include #include #include "xgboost/base.h" #include "xgboost/json.h" #include "./param.h" -#include "../common/io.h" #include "../common/timer.h" namespace xgboost { namespace tree { diff --git a/src/tree/updater_quantile_hist.cc b/src/tree/updater_quantile_hist.cc index 3a35781d43ba..cd6345619d5b 100644 --- a/src/tree/updater_quantile_hist.cc +++ b/src/tree/updater_quantile_hist.cc @@ -6,19 +6,12 @@ */ #include "./updater_quantile_hist.h" -#include - #include #include -#include #include #include #include -#include "../common/column_matrix.h" -#include "../common/hist_util.h" -#include "../common/random.h" -#include "../common/threading_utils.h" #include "constraints.h" #include "hist/evaluate_splits.h" #include "param.h" diff --git a/src/tree/updater_quantile_hist.h b/src/tree/updater_quantile_hist.h index f50e6ab8a0bb..4c939bf7d4d0 100644 --- a/src/tree/updater_quantile_hist.h +++ b/src/tree/updater_quantile_hist.h @@ -7,7 +7,6 @@ #ifndef XGBOOST_TREE_UPDATER_QUANTILE_HIST_H_ #define XGBOOST_TREE_UPDATER_QUANTILE_HIST_H_ -#include #include #include diff --git a/src/tree/updater_refresh.cc b/src/tree/updater_refresh.cc index 
f3cc1a9fa105..a70074740a90 100644 --- a/src/tree/updater_refresh.cc +++ b/src/tree/updater_refresh.cc @@ -4,7 +4,6 @@ * \brief refresh the statistics and leaf value on the tree on the dataset * \author Tianqi Chen */ -#include #include #include diff --git a/tests/cpp/common/test_quantile.cc b/tests/cpp/common/test_quantile.cc index a6026b7c6ab5..73fa4d5e74a1 100644 --- a/tests/cpp/common/test_quantile.cc +++ b/tests/cpp/common/test_quantile.cc @@ -46,8 +46,8 @@ template void TestDistributedQuantile(size_t rows, size_t cols) { std::string msg {"Skipping AllReduce test"}; int32_t constexpr kWorkers = 4; - InitRabitContext(msg, kWorkers); - auto world = rabit::GetWorldSize(); + InitCommunicatorContext(msg, kWorkers); + auto world = collective::GetWorldSize(); if (world != 1) { ASSERT_EQ(world, kWorkers); } else { @@ -65,7 +65,7 @@ void TestDistributedQuantile(size_t rows, size_t cols) { // Generate cuts for distributed environment. auto sparsity = 0.5f; - auto rank = rabit::GetRank(); + auto rank = collective::GetRank(); std::vector ft(cols); for (size_t i = 0; i < ft.size(); ++i) { ft[i] = (i % 2 == 0) ? FeatureType::kNumerical : FeatureType::kCategorical; @@ -99,8 +99,8 @@ void TestDistributedQuantile(size_t rows, size_t cols) { sketch_distributed.MakeCuts(&distributed_cuts); // Generate cuts for single node environment - rabit::Finalize(); - CHECK_EQ(rabit::GetWorldSize(), 1); + collective::Finalize(); + CHECK_EQ(collective::GetWorldSize(), 1); std::for_each(column_size.begin(), column_size.end(), [=](auto& size) { size *= world; }); m->Info().num_row_ = world * rows; ContainerType sketch_on_single_node(n_bins, m->Info().feature_types.ConstHostSpan(), @@ -184,8 +184,8 @@ TEST(Quantile, SameOnAllWorkers) { #if defined(__unix__) std::string msg{"Skipping Quantile AllreduceBasic test"}; int32_t constexpr kWorkers = 4; - InitRabitContext(msg, kWorkers); - auto world = rabit::GetWorldSize(); + InitCommunicatorContext(msg, kWorkers); + auto world = collective::GetWorldSize(); if (world != 1) { CHECK_EQ(world, kWorkers); } else { @@ -196,7 +196,7 @@ TEST(Quantile, SameOnAllWorkers) { constexpr size_t kRows = 1000, kCols = 100; RunWithSeedsAndBins( kRows, [=](int32_t seed, size_t n_bins, MetaInfo const&) { - auto rank = rabit::GetRank(); + auto rank = collective::GetRank(); HostDeviceVector storage; std::vector ft(kCols); for (size_t i = 0; i < ft.size(); ++i) { @@ -217,12 +217,12 @@ TEST(Quantile, SameOnAllWorkers) { std::vector cut_min_values(cuts.MinValues().size() * world, 0); size_t value_size = cuts.Values().size(); - rabit::Allreduce(&value_size, 1); + collective::Allreduce(&value_size, 1); size_t ptr_size = cuts.Ptrs().size(); - rabit::Allreduce(&ptr_size, 1); + collective::Allreduce(&ptr_size, 1); CHECK_EQ(ptr_size, kCols + 1); size_t min_value_size = cuts.MinValues().size(); - rabit::Allreduce(&min_value_size, 1); + collective::Allreduce(&min_value_size, 1); CHECK_EQ(min_value_size, kCols); size_t value_offset = value_size * rank; @@ -235,9 +235,9 @@ TEST(Quantile, SameOnAllWorkers) { std::copy(cuts.MinValues().cbegin(), cuts.MinValues().cend(), cut_min_values.begin() + min_values_offset); - rabit::Allreduce(cut_values.data(), cut_values.size()); - rabit::Allreduce(cut_ptrs.data(), cut_ptrs.size()); - rabit::Allreduce(cut_min_values.data(), cut_min_values.size()); + collective::Allreduce(cut_values.data(), cut_values.size()); + collective::Allreduce(cut_ptrs.data(), cut_ptrs.size()); + collective::Allreduce(cut_min_values.data(), cut_min_values.size()); for (int32_t i = 0; i < world; 
i++) { for (size_t j = 0; j < value_size; ++j) { @@ -256,7 +256,7 @@ TEST(Quantile, SameOnAllWorkers) { } } }); - rabit::Finalize(); + collective::Finalize(); #endif // defined(__unix__) } } // namespace common diff --git a/tests/cpp/common/test_quantile.cu b/tests/cpp/common/test_quantile.cu index 06b832319d66..628c322b602e 100644 --- a/tests/cpp/common/test_quantile.cu +++ b/tests/cpp/common/test_quantile.cu @@ -342,8 +342,8 @@ TEST(GPUQuantile, AllReduceBasic) { // This test is supposed to run by a python test that setups the environment. std::string msg {"Skipping AllReduce test"}; auto n_gpus = AllVisibleGPUs(); - InitRabitContext(msg, n_gpus); - auto world = rabit::GetWorldSize(); + InitCommunicatorContext(msg, n_gpus); + auto world = collective::GetWorldSize(); if (world != 1) { ASSERT_EQ(world, n_gpus); } else { @@ -383,7 +383,7 @@ TEST(GPUQuantile, AllReduceBasic) { // Set up distributed version. We rely on using rank as seed to generate // the exact same copy of data. - auto rank = rabit::GetRank(); + auto rank = collective::GetRank(); SketchContainer sketch_distributed(ft, n_bins, kCols, kRows, 0); HostDeviceVector storage; std::string interface_str = RandomDataGenerator{kRows, kCols, 0} @@ -420,14 +420,14 @@ TEST(GPUQuantile, AllReduceBasic) { ASSERT_NEAR(single_node_data[i].wmin, distributed_data[i].wmin, Eps); } }); - rabit::Finalize(); + collective::Finalize(); } TEST(GPUQuantile, SameOnAllWorkers) { std::string msg {"Skipping SameOnAllWorkers test"}; auto n_gpus = AllVisibleGPUs(); - InitRabitContext(msg, n_gpus); - auto world = rabit::GetWorldSize(); + InitCommunicatorContext(msg, n_gpus); + auto world = collective::GetWorldSize(); if (world != 1) { ASSERT_EQ(world, n_gpus); } else { @@ -437,7 +437,7 @@ TEST(GPUQuantile, SameOnAllWorkers) { constexpr size_t kRows = 1000, kCols = 100; RunWithSeedsAndBins(kRows, [=](int32_t seed, size_t n_bins, MetaInfo const &info) { - auto rank = rabit::GetRank(); + auto rank = collective::GetRank(); HostDeviceVector ft; SketchContainer sketch_distributed(ft, n_bins, kCols, kRows, 0); HostDeviceVector storage; @@ -455,7 +455,7 @@ TEST(GPUQuantile, SameOnAllWorkers) { // Test for all workers having the same sketch. 
size_t n_data = sketch_distributed.Data().size(); - rabit::Allreduce(&n_data, 1); + collective::Allreduce(&n_data, 1); ASSERT_EQ(n_data, sketch_distributed.Data().size()); size_t size_as_float = sketch_distributed.Data().size_bytes() / sizeof(float); diff --git a/tests/cpp/common/test_quantile.h b/tests/cpp/common/test_quantile.h index d92695f53276..f6efdfca8bc2 100644 --- a/tests/cpp/common/test_quantile.h +++ b/tests/cpp/common/test_quantile.h @@ -1,16 +1,16 @@ #ifndef XGBOOST_TESTS_CPP_COMMON_TEST_QUANTILE_H_ #define XGBOOST_TESTS_CPP_COMMON_TEST_QUANTILE_H_ -#include #include #include #include #include "../helpers.h" +#include "../../src/collective/communicator-inl.h" namespace xgboost { namespace common { -inline void InitRabitContext(std::string msg, int32_t n_workers) { +inline void InitCommunicatorContext(std::string msg, int32_t n_workers) { auto port = std::getenv("DMLC_TRACKER_PORT"); std::string port_str; if (port) { @@ -28,12 +28,11 @@ inline void InitRabitContext(std::string msg, int32_t n_workers) { return; } - std::vector envs{ - "DMLC_TRACKER_PORT=" + port_str, - "DMLC_TRACKER_URI=" + uri_str, - "DMLC_NUM_WORKER=" + std::to_string(n_workers)}; - char* c_envs[] {&(envs[0][0]), &(envs[1][0]), &(envs[2][0])}; - rabit::Init(3, c_envs); + Json config{JsonObject()}; + config["DMLC_TRACKER_PORT"] = port_str; + config["DMLC_TRACKER_URI"] = uri_str; + config["DMLC_NUM_WORKER"] = n_workers; + collective::Init(config); } template void RunWithSeedsAndBins(size_t rows, Fn fn) { diff --git a/tests/distributed/test_federated.py b/tests/distributed/test_federated.py index d62e5d923702..4e75cff14d9f 100644 --- a/tests/distributed/test_federated.py +++ b/tests/distributed/test_federated.py @@ -21,21 +21,19 @@ def run_server(port: int, world_size: int, with_ssl: bool) -> None: def run_worker(port: int, world_size: int, rank: int, with_ssl: bool, with_gpu: bool) -> None: - rabit_env = [ - 'xgboost_communicator=federated', - f'federated_server_address=localhost:{port}', - f'federated_world_size={world_size}', - f'federated_rank={rank}' - ] + communicator_env = { + 'xgboost_communicator': 'federated', + 'federated_server_address': f'localhost:{port}', + 'federated_world_size': f'{world_size}', + 'federated_rank': f'{rank}' + } if with_ssl: - rabit_env = rabit_env + [ - f'federated_server_cert={SERVER_CERT}', - f'federated_client_key={CLIENT_KEY}', - f'federated_client_cert={CLIENT_CERT}' - ] + communicator_env['federated_server_cert'] = SERVER_CERT + communicator_env['federated_client_key'] = CLIENT_KEY + communicator_env['federated_client_cert'] = CLIENT_CERT # Always call this before using distributed module - with xgb.collective.CommunicatorContext([e.encode() for e in rabit_env]): + with xgb.collective.CommunicatorContext(**communicator_env): # Load file, file will not be sharded in federated mode. dtrain = xgb.DMatrix('agaricus.txt.train-%02d' % rank) dtest = xgb.DMatrix('agaricus.txt.test-%02d' % rank) @@ -55,9 +53,9 @@ def run_worker(port: int, world_size: int, rank: int, with_ssl: bool, with_gpu: early_stopping_rounds=2) # Save the model, only ask process 0 to save the model. 
- if xgb.rabit.get_rank() == 0: + if xgb.collective.get_rank() == 0: bst.save_model("test.model.json") - xgb.rabit.tracker_print("Finished training\n") + xgb.collective.communicator_print("Finished training\n") def run_test(with_ssl: bool = True, with_gpu: bool = False) -> None: diff --git a/tests/distributed/test_issue3402.py b/tests/distributed/test_issue3402.py index 6f86cc0d837d..cd9beb8d1013 100644 --- a/tests/distributed/test_issue3402.py +++ b/tests/distributed/test_issue3402.py @@ -2,7 +2,7 @@ import xgboost as xgb import numpy as np -with xgb.rabit.CommunicatorContext(): +with xgb.collective.CommunicatorContext(): X = [ [15.00,28.90,29.00,3143.70,0.00,0.10,69.90,90.00,13726.07,0.00,2299.70,0.00,0.05, 4327.03,0.00,24.00,0.18,3.00,0.41,3.77,0.00,0.00,4.00,0.00,150.92,0.00,2.00,0.00, @@ -69,6 +69,6 @@ num_round = 2 bst = xgb.train(param, dtrain, num_round, watchlist) - if xgb.rabit.get_rank() == 0: + if xgb.collective.get_rank() == 0: bst.save_model("test_issue3402.model") - xgb.rabit.tracker_print("Finished training\n") + xgb.collective.communicator_print("Finished training\n") diff --git a/tests/python-gpu/test_gpu_with_dask/test_gpu_with_dask.py b/tests/python-gpu/test_gpu_with_dask/test_gpu_with_dask.py index 3cb110bd6c6e..05ce13b57c0d 100644 --- a/tests/python-gpu/test_gpu_with_dask/test_gpu_with_dask.py +++ b/tests/python-gpu/test_gpu_with_dask/test_gpu_with_dask.py @@ -1,7 +1,7 @@ """Copyright 2019-2022 XGBoost contributors""" import sys import os -from typing import Type, TypeVar, Any, Dict, List +from typing import Type, TypeVar, Any, Dict, List, Union import pytest import numpy as np import asyncio @@ -407,7 +407,7 @@ def test_data_initialization(self, local_cuda_cluster: LocalCUDACluster) -> None rabit_args = client.sync(dxgb._get_rabit_args, len(workers), None, client) def worker_fn(worker_addr: str, data_ref: Dict) -> None: - with dxgb.RabitContext(rabit_args): + with dxgb.CommunicatorContext(**rabit_args): local_dtrain = dxgb._dmatrix_from_list_of_parts(**data_ref, nthread=7) fw_rows = local_dtrain.get_float_info("feature_weights").shape[0] assert fw_rows == local_dtrain.num_col() @@ -494,20 +494,13 @@ def run_quantile(self, name: str, local_cuda_cluster: LocalCUDACluster) -> None: test = "--gtest_filter=GPUQuantile." + name def runit( - worker_addr: str, rabit_args: List[bytes] + worker_addr: str, rabit_args: Dict[str, Union[int, str]] ) -> subprocess.CompletedProcess: port_env = '' # setup environment for running the c++ part. 
- for arg in rabit_args: - if arg.decode('utf-8').startswith('DMLC_TRACKER_PORT'): - port_env = arg.decode('utf-8') - if arg.decode("utf-8").startswith("DMLC_TRACKER_URI"): - uri_env = arg.decode("utf-8") - port = port_env.split('=') env = os.environ.copy() - env[port[0]] = port[1] - uri = uri_env.split("=") - env[uri[0]] = uri[1] + env['DMLC_TRACKER_PORT'] = str(rabit_args['DMLC_TRACKER_PORT']) + env["DMLC_TRACKER_URI"] = rabit_args["DMLC_TRACKER_URI"] return subprocess.run([str(exe), test], env=env, stdout=subprocess.PIPE) with Client(local_cuda_cluster) as client: diff --git a/tests/python/test_with_dask.py b/tests/python/test_with_dask.py index d6eb4f32b9f7..4d98f79a3246 100644 --- a/tests/python/test_with_dask.py +++ b/tests/python/test_with_dask.py @@ -1267,17 +1267,17 @@ def test_dask_iteration_range(client: "Client"): class TestWithDask: def test_dmatrix_binary(self, client: "Client") -> None: - def save_dmatrix(rabit_args: List[bytes], tmpdir: str) -> None: - with xgb.dask.RabitContext(rabit_args): - rank = xgb.rabit.get_rank() + def save_dmatrix(rabit_args: Dict[str, Union[int, str]], tmpdir: str) -> None: + with xgb.dask.CommunicatorContext(**rabit_args): + rank = xgb.collective.get_rank() X, y = tm.make_categorical(100, 4, 4, False) Xy = xgb.DMatrix(X, y, enable_categorical=True) path = os.path.join(tmpdir, f"{rank}.bin") Xy.save_binary(path) - def load_dmatrix(rabit_args: List[bytes], tmpdir: str) -> None: - with xgb.dask.RabitContext(rabit_args): - rank = xgb.rabit.get_rank() + def load_dmatrix(rabit_args: Dict[str, Union[int,str]], tmpdir: str) -> None: + with xgb.dask.CommunicatorContext(rabit_args): + rank = xgb.collective.get_rank() path = os.path.join(tmpdir, f"{rank}.bin") Xy = xgb.DMatrix(path) assert Xy.num_row() == 100 @@ -1488,20 +1488,13 @@ def run_quantile(self, name: str) -> None: test = "--gtest_filter=Quantile." + name def runit( - worker_addr: str, rabit_args: List[bytes] + worker_addr: str, rabit_args: Dict[str, Union[int, str]] ) -> subprocess.CompletedProcess: port_env = '' # setup environment for running the c++ part. 
- for arg in rabit_args: - if arg.decode('utf-8').startswith('DMLC_TRACKER_PORT'): - port_env = arg.decode('utf-8') - if arg.decode("utf-8").startswith("DMLC_TRACKER_URI"): - uri_env = arg.decode("utf-8") - port = port_env.split('=') env = os.environ.copy() - env[port[0]] = port[1] - uri = uri_env.split("=") - env["DMLC_TRACKER_URI"] = uri[1] + env['DMLC_TRACKER_PORT'] = str(rabit_args['DMLC_TRACKER_PORT']) + env["DMLC_TRACKER_URI"] = rabit_args["DMLC_TRACKER_URI"] return subprocess.run([str(exe), test], env=env, capture_output=True) with LocalCluster(n_workers=4, dashboard_address=":0") as cluster: @@ -1543,8 +1536,8 @@ def test_adaptive(self) -> None: def get_score(config: Dict) -> float: return float(config["learner"]["learner_model_param"]["base_score"]) - def local_test(rabit_args: List[bytes], worker_id: int) -> bool: - with xgb.dask.RabitContext(rabit_args): + def local_test(rabit_args: Dict[str, Union[int, str]], worker_id: int) -> bool: + with xgb.dask.CommunicatorContext(**rabit_args): if worker_id == 0: y = np.array([0.0, 0.0, 0.0]) x = np.array([[0.0]] * 3) @@ -1686,12 +1679,12 @@ def test_no_duplicated_partition(self) -> None: n_workers = len(workers) def worker_fn(worker_addr: str, data_ref: Dict) -> None: - with xgb.dask.RabitContext(rabit_args): + with xgb.dask.CommunicatorContext(**rabit_args): local_dtrain = xgb.dask._dmatrix_from_list_of_parts( **data_ref, nthread=7 ) total = np.array([local_dtrain.num_row()]) - total = xgb.rabit.allreduce(total, xgb.rabit.Op.SUM) + total = xgb.collective.allreduce(total, xgb.collective.Op.SUM) assert total[0] == kRows futures = [] From 0f2cdd1972b6fb4c719ea3a659d41fd1e1de70e6 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Fri, 23 Sep 2022 11:19:27 -0700 Subject: [PATCH 09/18] fix pylint error --- tests/python-gpu/test_gpu_with_dask/test_gpu_with_dask.py | 2 +- tests/python/test_with_dask.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python-gpu/test_gpu_with_dask/test_gpu_with_dask.py b/tests/python-gpu/test_gpu_with_dask/test_gpu_with_dask.py index 05ce13b57c0d..ea02405aab74 100644 --- a/tests/python-gpu/test_gpu_with_dask/test_gpu_with_dask.py +++ b/tests/python-gpu/test_gpu_with_dask/test_gpu_with_dask.py @@ -500,7 +500,7 @@ def runit( # setup environment for running the c++ part. env = os.environ.copy() env['DMLC_TRACKER_PORT'] = str(rabit_args['DMLC_TRACKER_PORT']) - env["DMLC_TRACKER_URI"] = rabit_args["DMLC_TRACKER_URI"] + env["DMLC_TRACKER_URI"] = str(rabit_args["DMLC_TRACKER_URI"]) return subprocess.run([str(exe), test], env=env, stdout=subprocess.PIPE) with Client(local_cuda_cluster) as client: diff --git a/tests/python/test_with_dask.py b/tests/python/test_with_dask.py index 4d98f79a3246..20f5ac20ea72 100644 --- a/tests/python/test_with_dask.py +++ b/tests/python/test_with_dask.py @@ -1494,7 +1494,7 @@ def runit( # setup environment for running the c++ part. 
env = os.environ.copy() env['DMLC_TRACKER_PORT'] = str(rabit_args['DMLC_TRACKER_PORT']) - env["DMLC_TRACKER_URI"] = rabit_args["DMLC_TRACKER_URI"] + env["DMLC_TRACKER_URI"] = str(rabit_args["DMLC_TRACKER_URI"]) return subprocess.run([str(exe), test], env=env, capture_output=True) with LocalCluster(n_workers=4, dashboard_address=":0") as cluster: From 217e3a5472de9fde5f7a1685d0f5476453bb8366 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Fri, 23 Sep 2022 12:31:45 -0700 Subject: [PATCH 10/18] return dict from communicator context --- python-package/xgboost/collective.py | 5 +++-- python-package/xgboost/spark/core.py | 2 +- tests/python/test_tracker.py | 8 ++------ 3 files changed, 6 insertions(+), 9 deletions(-) diff --git a/python-package/xgboost/collective.py b/python-package/xgboost/collective.py index e4662d744e50..72890c78c0f2 100644 --- a/python-package/xgboost/collective.py +++ b/python-package/xgboost/collective.py @@ -4,7 +4,7 @@ import logging import pickle from enum import IntEnum, unique -from typing import Any, List +from typing import Any, List, Dict import numpy as np @@ -233,10 +233,11 @@ class CommunicatorContext: def __init__(self, **args: Any) -> None: self.args = args - def __enter__(self) -> None: + def __enter__(self) -> Dict[str, Any]: init(**self.args) assert is_distributed() LOGGER.debug("-------------- communicator say hello ------------------") + return self.args def __exit__(self, *args: List) -> None: finalize() diff --git a/python-package/xgboost/spark/core.py b/python-package/xgboost/spark/core.py index 172c15af8d52..6d3589384679 100644 --- a/python-package/xgboost/spark/core.py +++ b/python-package/xgboost/spark/core.py @@ -2,8 +2,8 @@ """Xgboost pyspark integration submodule for core code.""" # pylint: disable=fixme, too-many-ancestors, protected-access, no-member, invalid-name # pylint: disable=too-few-public-methods, too-many-lines -from typing import Iterator, Optional, Tuple import json +from typing import Iterator, Optional, Tuple import numpy as np import pandas as pd diff --git a/tests/python/test_tracker.py b/tests/python/test_tracker.py index e545dcb10762..67543a96829d 100644 --- a/tests/python/test_tracker.py +++ b/tests/python/test_tracker.py @@ -79,12 +79,8 @@ def test_rank_assignment() -> None: from test_with_dask import _get_client_workers def local_test(worker_id): - with xgb.dask.CommunicatorContext(**args): - for val in args: - sval = val.decode("utf-8") - if sval.startswith("DMLC_TASK_ID"): - task_id = sval - break + with xgb.dask.CommunicatorContext(**args) as ctx: + task_id = ctx["DMLC_TASK_ID"] matched = re.search(".*-([0-9]).*", task_id) rank = xgb.collective.get_rank() # As long as the number of workers is lesser than 10, rank and worker id From aa731a03c2c9f99af5dcce51a263b2b675fc090a Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Fri, 23 Sep 2022 12:42:23 -0700 Subject: [PATCH 11/18] fix communicator shutdown --- src/collective/communicator.cc | 2 +- src/collective/communicator.cu | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/collective/communicator.cc b/src/collective/communicator.cc index ad5de231c6e8..4b45f1e3182a 100644 --- a/src/collective/communicator.cc +++ b/src/collective/communicator.cc @@ -48,7 +48,7 @@ void Communicator::Init(Json const& config) { #ifndef XGBOOST_USE_CUDA void Communicator::Finalize() { communicator_->Shutdown(); - communicator_.reset(nullptr); + communicator_.reset(new NoOpCommunicator()); } #endif diff --git a/src/collective/communicator.cu b/src/collective/communicator.cu index 
2485000d9ad4..0880741f9470 100644 --- a/src/collective/communicator.cu +++ b/src/collective/communicator.cu @@ -4,6 +4,7 @@ #include "communicator.h" #include "device_communicator.cuh" #include "device_communicator_adapter.cuh" +#include "noop_communicator.h" #ifdef XGBOOST_USE_NCCL #include "nccl_device_communicator.cuh" #endif @@ -16,7 +17,7 @@ thread_local std::unique_ptr Communicator::device_communicat void Communicator::Finalize() { communicator_->Shutdown(); - communicator_.reset(nullptr); + communicator_.reset(new NoOpCommunicator()); device_ordinal_ = -1; device_communicator_.reset(nullptr); } From 8561853b9595034181e0b47004ecc371daf73011 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Fri, 23 Sep 2022 14:10:13 -0700 Subject: [PATCH 12/18] fix dask test --- tests/python/test_with_dask.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/test_with_dask.py b/tests/python/test_with_dask.py index 20f5ac20ea72..bdd432a75005 100644 --- a/tests/python/test_with_dask.py +++ b/tests/python/test_with_dask.py @@ -1276,7 +1276,7 @@ def save_dmatrix(rabit_args: Dict[str, Union[int, str]], tmpdir: str) -> None: Xy.save_binary(path) def load_dmatrix(rabit_args: Dict[str, Union[int,str]], tmpdir: str) -> None: - with xgb.dask.CommunicatorContext(rabit_args): + with xgb.dask.CommunicatorContext(**rabit_args): rank = xgb.collective.get_rank() path = os.path.join(tmpdir, f"{rank}.bin") Xy = xgb.DMatrix(path) From c89df4725d8466cc0bf6830b8500345eec076ef3 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Mon, 26 Sep 2022 11:13:11 -0700 Subject: [PATCH 13/18] reset communicator mocklist --- .../scala/spark/XGBoostCommunicatorRegressionSuite.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostCommunicatorRegressionSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostCommunicatorRegressionSuite.scala index 1094a89f111a..a7310f1ab5d3 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostCommunicatorRegressionSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostCommunicatorRegressionSuite.scala @@ -98,6 +98,8 @@ class XGBoostCommunicatorRegressionSuite extends FunSuite with PerTest { "rabit_timeout" -> 0)) .fit(training) } + + Communicator.mockList = Array.empty.toList.asJava } } From a9479d9f0b651e95ed05601737d6363f4b1a41d2 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Thu, 29 Sep 2022 11:33:02 -0700 Subject: [PATCH 14/18] fix distributed tests --- tests/distributed/test_basic.py | 2 +- tests/distributed/test_federated.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/distributed/test_basic.py b/tests/distributed/test_basic.py index 4e64497b1bfd..f54acafcc4b2 100644 --- a/tests/distributed/test_basic.py +++ b/tests/distributed/test_basic.py @@ -21,4 +21,4 @@ # Save the model, only ask process 0 to save the model. 
if xgb.collective.get_rank() == 0: bst.save_model("test.model") - xgb.collective.tracker_print("Finished training\n") + xgb.collective.communicator_print("Finished training\n") diff --git a/tests/distributed/test_federated.py b/tests/distributed/test_federated.py index 4e75cff14d9f..afd968d5d42e 100644 --- a/tests/distributed/test_federated.py +++ b/tests/distributed/test_federated.py @@ -24,8 +24,8 @@ def run_worker(port: int, world_size: int, rank: int, with_ssl: bool, with_gpu: communicator_env = { 'xgboost_communicator': 'federated', 'federated_server_address': f'localhost:{port}', - 'federated_world_size': f'{world_size}', - 'federated_rank': f'{rank}' + 'federated_world_size': world_size, + 'federated_rank': rank } if with_ssl: communicator_env['federated_server_cert'] = SERVER_CERT From af56b1a26cbf3278c979fc5b63b9f8c95e19d270 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Thu, 29 Sep 2022 16:03:17 -0700 Subject: [PATCH 15/18] do not save device communicator --- src/common/quantile.cu | 12 ++++++------ src/common/quantile.cuh | 2 -- src/metric/auc.cu | 14 ++++++-------- src/tree/updater_gpu_hist.cu | 6 ++---- 4 files changed, 14 insertions(+), 20 deletions(-) diff --git a/src/common/quantile.cu b/src/common/quantile.cu index 39589bf69121..55d5da6df721 100644 --- a/src/common/quantile.cu +++ b/src/common/quantile.cu @@ -11,6 +11,8 @@ #include #include +#include "../collective/communicator.h" +#include "../collective/device_communicator.cuh" #include "categorical.h" #include "common.h" #include "device_helpers.cuh" @@ -506,9 +508,7 @@ void SketchContainer::AllReduce() { } timer_.Start(__func__); - if (!communicator_) { - communicator_ = collective::Communicator::GetDevice(device_); - } + auto* communicator = collective::Communicator::GetDevice(device_); // Reduce the overhead on syncing. size_t global_sum_rows = num_rows_; collective::Allreduce(&global_sum_rows, 1); @@ -529,14 +529,14 @@ void SketchContainer::AllReduce() { auto offset = rank * d_columns_ptr.size(); thrust::copy(thrust::device, d_columns_ptr.data(), d_columns_ptr.data() + d_columns_ptr.size(), gathered_ptrs.begin() + offset); - communicator_->AllReduceSum(gathered_ptrs.data().get(), gathered_ptrs.size()); + communicator->AllReduceSum(gathered_ptrs.data().get(), gathered_ptrs.size()); // Get the data from all workers. std::vector recv_lengths; dh::caching_device_vector recvbuf; - communicator_->AllGatherV(this->Current().data().get(), dh::ToSpan(this->Current()).size_bytes(), + communicator->AllGatherV(this->Current().data().get(), dh::ToSpan(this->Current()).size_bytes(), &recv_lengths, &recvbuf); - communicator_->Synchronize(); + communicator->Synchronize(); // Segment the received data. 
auto s_recvbuf = dh::ToSpan(recvbuf); diff --git a/src/common/quantile.cuh b/src/common/quantile.cuh index 0bd45e09e563..373b4e085ef7 100644 --- a/src/common/quantile.cuh +++ b/src/common/quantile.cuh @@ -5,7 +5,6 @@ #include "xgboost/span.h" #include "xgboost/data.h" -#include "../collective/device_communicator.cuh" #include "device_helpers.cuh" #include "quantile.h" #include "timer.h" @@ -38,7 +37,6 @@ class SketchContainer { private: Monitor timer_; - collective::DeviceCommunicator* communicator_; HostDeviceVector feature_types_; bst_row_t num_rows_; bst_feature_t num_columns_; diff --git a/src/metric/auc.cu b/src/metric/auc.cu index aa78002d2ef5..a5b0116f8c00 100644 --- a/src/metric/auc.cu +++ b/src/metric/auc.cu @@ -45,9 +45,8 @@ struct DeviceAUCCache { dh::device_vector unique_idx; // p^T: transposed prediction matrix, used by MultiClassAUC dh::device_vector predts_t; - collective::DeviceCommunicator* communicator; - void Init(common::Span predts, bool is_multi, int32_t device) { + void Init(common::Span predts, bool is_multi) { if (sorted_idx.size() != predts.size()) { sorted_idx.resize(predts.size()); fptp.resize(sorted_idx.size()); @@ -57,9 +56,6 @@ struct DeviceAUCCache { predts_t.resize(sorted_idx.size()); } } - if (is_multi && !communicator) { - communicator = collective::Communicator::GetDevice(device); - } } }; @@ -70,7 +66,7 @@ void InitCacheOnce(common::Span predts, int32_t device, if (!cache) { cache.reset(new DeviceAUCCache); } - cache->Init(predts, is_multi, device); + cache->Init(predts, is_multi); } /** @@ -204,8 +200,10 @@ double ScaleClasses(common::Span results, common::Span local_are std::shared_ptr cache, size_t n_classes) { dh::XGBDeviceAllocator alloc; if (collective::IsDistributed()) { - CHECK_EQ(dh::CudaGetPointerDevice(results.data()), dh::CurrentDevice()); - cache->communicator->AllReduceSum(results.data(), results.size()); + int32_t device = dh::CurrentDevice(); + CHECK_EQ(dh::CudaGetPointerDevice(results.data()), device); + auto* communicator = collective::Communicator::GetDevice(device); + communicator->AllReduceSum(results.data(), results.size()); } auto reduce_in = dh::MakeTransformIterator( thrust::make_counting_iterator(0), [=] XGBOOST_DEVICE(size_t i) { diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index 5dfbe042e722..17cb2de6999e 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -789,7 +789,6 @@ class GPUHistMaker : public TreeUpdater { void InitDataOnce(DMatrix* dmat) { CHECK_GE(ctx_->gpu_id, 0) << "Must have at least one device"; info_ = &dmat->Info(); - communicator_ = collective::Communicator::GetDevice(ctx_->gpu_id); // Synchronise the column sampling seed uint32_t column_sampling_seed = common::GlobalRandom()(); @@ -841,7 +840,8 @@ class GPUHistMaker : public TreeUpdater { monitor_.Stop("InitData"); gpair->SetDevice(ctx_->gpu_id); - maker->UpdateTree(gpair, p_fmat, task_, p_tree, communicator_, p_out_position); + auto* communicator = collective::Communicator::GetDevice(ctx_->gpu_id); + maker->UpdateTree(gpair, p_fmat, task_, p_tree, communicator, p_out_position); } bool UpdatePredictionCache(const DMatrix* data, @@ -867,8 +867,6 @@ class GPUHistMaker : public TreeUpdater { GPUHistMakerTrainParam hist_maker_param_; - collective::DeviceCommunicator* communicator_; - DMatrix* p_last_fmat_{nullptr}; RegTree const* p_last_tree_{nullptr}; ObjInfo task_; From 9be13f8cb0b0f908afaaca8b6d6ae238c311db85 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Fri, 30 Sep 2022 09:02:41 -0700 Subject: [PATCH 
16/18] fix jvm gpu tests --- src/collective/nccl_device_communicator.cuh | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/collective/nccl_device_communicator.cuh b/src/collective/nccl_device_communicator.cuh index 4b74f65c940b..e14a2e446ed4 100644 --- a/src/collective/nccl_device_communicator.cuh +++ b/src/collective/nccl_device_communicator.cuh @@ -59,8 +59,12 @@ class NcclDeviceCommunicator : public DeviceCommunicator { if (communicator_->GetWorldSize() == 1) { return; } - dh::safe_cuda(cudaStreamDestroy(cuda_stream_)); - ncclCommDestroy(nccl_comm_); + if (cuda_stream_) { + dh::safe_cuda(cudaStreamDestroy(cuda_stream_)); + } + if (nccl_comm_) { + dh::safe_nccl(ncclCommDestroy(nccl_comm_)); + } if (xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug)) { LOG(CONSOLE) << "======== NCCL Statistics========"; LOG(CONSOLE) << "AllReduce calls: " << allreduce_calls_; From 44a49fdfb19f38cc57204c1c6504ba5372967c2a Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Fri, 30 Sep 2022 10:37:27 -0700 Subject: [PATCH 17/18] add python test for federated communicator --- tests/python/test_collective.py | 42 +++++++++++++++++++++++++++++++-- 1 file changed, 40 insertions(+), 2 deletions(-) diff --git a/tests/python/test_collective.py b/tests/python/test_collective.py index 1b9727ebf05b..f7de0400d21f 100644 --- a/tests/python/test_collective.py +++ b/tests/python/test_collective.py @@ -1,13 +1,13 @@ import multiprocessing import socket import sys +import time import numpy as np import pytest import xgboost as xgb -from xgboost import RabitTracker -from xgboost import collective +from xgboost import RabitTracker, build_info, federated if sys.platform.startswith("win"): pytest.skip("Skipping collective tests on Windows", allow_module_level=True) @@ -37,3 +37,41 @@ def test_rabit_communicator(): for worker in workers: worker.join() assert worker.exitcode == 0 + + +def run_federated_worker(port, world_size, rank): + with xgb.collective.CommunicatorContext(xgboost_communicator='federated', + federated_server_address=f'localhost:{port}', + federated_world_size=world_size, + federated_rank=rank): + assert xgb.collective.get_world_size() == world_size + assert xgb.collective.is_distributed() + assert xgb.collective.get_processor_name() == f'rank{rank}' + ret = xgb.collective.broadcast('test1234', 0) + assert str(ret) == 'test1234' + ret = xgb.collective.allreduce(np.asarray([1, 2, 3]), xgb.collective.Op.SUM) + assert np.array_equal(ret, np.asarray([2, 4, 6])) + + +def test_federated_communicator(): + if not build_info()["USE_FEDERATED"]: + pytest.skip("XGBoost not built with federated learning enabled") + + port = 9091 + world_size = 2 + server = multiprocessing.Process(target=xgb.federated.run_federated_server, args=(port, world_size)) + server.start() + time.sleep(1) + if not server.is_alive(): + raise Exception("Error starting Federated Learning server") + + workers = [] + for rank in range(world_size): + worker = multiprocessing.Process(target=run_federated_worker, + args=(port, world_size, rank)) + workers.append(worker) + worker.start() + for worker in workers: + worker.join() + assert worker.exitcode == 0 + server.terminate() From 43be5ac8c9a580c59a6a659e1f1c65e6127f1262 Mon Sep 17 00:00:00 2001 From: Hyunsu Philip Cho Date: Wed, 5 Oct 2022 12:16:00 -0700 Subject: [PATCH 18/18] Update gputreeshap submodule --- gputreeshap | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gputreeshap b/gputreeshap index c78fe621e429..acb5be3c17e9 160000 --- 
a/gputreeshap +++ b/gputreeshap @@ -1 +1 @@ -Subproject commit c78fe621e429117cbca45e7b23eb5c3b6280fa3a +Subproject commit acb5be3c17e9adae34ac0b176da6ea8e197cb17e
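
For reference, the tests touched above exercise the new xgboost.collective Python API that this series introduces. Below is a minimal usage sketch (not part of any patch in the series); it assumes a build that ships the collective module and relies on the behaviour shown in test_issue3402.py, where CommunicatorContext() with no arguments falls back to a single, non-distributed worker. In a real job the keyword arguments (DMLC_TRACKER_URI/DMLC_TRACKER_PORT from a RabitTracker, or the federated_* settings) would come from the tracker or federated server, as in test_federated.py and test_collective.py.

import numpy as np
import xgboost as xgb

# Minimal sketch: with no tracker/federated kwargs the context is assumed to
# fall back to a single-worker communicator (as in test_issue3402.py).
# For a distributed run, pass the tracker arguments instead, e.g.
#   CommunicatorContext(DMLC_TRACKER_URI=..., DMLC_TRACKER_PORT=..., DMLC_TASK_ID=...)
with xgb.collective.CommunicatorContext() as ctx:  # __enter__ returns the kwargs dict (patch 10)
    rank = xgb.collective.get_rank()
    world = xgb.collective.get_world_size()
    # allreduce sums the array across all workers; with a single worker it
    # simply returns the input unchanged.
    total = xgb.collective.allreduce(np.asarray([1, 2, 3]), xgb.collective.Op.SUM)
    if rank == 0:
        xgb.collective.communicator_print(f"workers={world}, sum={total}\n")

Compared with the old rabit API, the context takes plain keyword arguments rather than a list of KEY=value byte strings, which is why the tests in this series build dicts (and pass them with **kwargs) instead of encoded string lists.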