diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala index 69e606d9c3b6..c16e45858415 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala @@ -1,5 +1,5 @@ /* - Copyright (c) 2014,2021 by Contributors + Copyright (c) 2014-2022 by Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -46,8 +46,14 @@ import org.apache.spark.sql.SparkSession * the Python Rabit tracker (in dmlc_core), whereas the latter is implemented * in Scala without Python components, and with full support of timeouts. * The Scala implementation is currently experimental, use at your own risk. + * + * @param hostIp The Rabit Tracker host IP address which is only used for python implementation. + * This is only needed if the host IP cannot be automatically guessed. + * @param pythonExec The python executable path for Rabit Tracker, + * which is only used for python implementation. 
*/ -case class TrackerConf(workerConnectionTimeout: Long, trackerImpl: String ) +case class TrackerConf(workerConnectionTimeout: Long, trackerImpl: String, + hostIp: String = "", pythonExec: String = "") object TrackerConf { def apply(): TrackerConf = TrackerConf(0L, "python") @@ -336,13 +342,18 @@ object XGBoost extends Serializable { } } - private def startTracker(nWorkers: Int, trackerConf: TrackerConf): IRabitTracker = { + /** Creates (but does not start) the configured tracker; visible for testing. */ + private[scala] def getTracker(nWorkers: Int, trackerConf: TrackerConf): IRabitTracker = { val tracker: IRabitTracker = trackerConf.trackerImpl match { case "scala" => new RabitTracker(nWorkers) - case "python" => new PyRabitTracker(nWorkers) - case _ => new PyRabitTracker(nWorkers) + case "python" => new PyRabitTracker(nWorkers, trackerConf.hostIp, trackerConf.pythonExec) + case _ => new PyRabitTracker(nWorkers, trackerConf.hostIp, trackerConf.pythonExec) } + tracker + } + private def startTracker(nWorkers: Int, trackerConf: TrackerConf): IRabitTracker = { + val tracker = getTracker(nWorkers, trackerConf) require(tracker.start(trackerConf.workerConnectionTimeout), "FAULT: Failed to start tracker") tracker } diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/RabitRobustnessSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/RabitRobustnessSuite.scala index e106883ca254..26ea2ef71595 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/RabitRobustnessSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/RabitRobustnessSuite.scala @@ -1,5 +1,5 @@ /* - Copyright (c) 2014 by Contributors + Copyright (c) 2014-2022 by Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
@@ -24,11 +24,61 @@ import ml.dmlc.xgboost4j.java.{Rabit, RabitTracker => PyRabitTracker} import ml.dmlc.xgboost4j.scala.rabit.{RabitTracker => ScalaRabitTracker} import ml.dmlc.xgboost4j.java.IRabitTracker.TrackerStatus import ml.dmlc.xgboost4j.scala.DMatrix - -import org.scalatest.{FunSuite, Ignore} +import org.scalatest.FunSuite class RabitRobustnessSuite extends FunSuite with PerTest { + private def getXGBoostExecutionParams(paramMap: Map[String, Any]): XGBoostExecutionParams = { + val classifier = new XGBoostClassifier(paramMap) + val xgbParamsFactory = new XGBoostExecutionParamsFactory(classifier.MLlib2XGBoostParams, sc) + xgbParamsFactory.buildXGBRuntimeParams + } + + // Verifies that hostIp and pythonExec in TrackerConf end up in the tracker command line. + test("Customize host ip and python exec for Rabit tracker") { + val hostIp = "192.168.22.111" + val pythonExec = "/usr/bin/python3" + + val paramMap = Map( + "num_workers" -> numWorkers, + "tracker_conf" -> TrackerConf(0L, "python", hostIp)) + val xgbExecParams = getXGBoostExecutionParams(paramMap) + val tracker = XGBoost.getTracker(xgbExecParams.numWorkers, xgbExecParams.trackerConf) + tracker match { + case pyTracker: PyRabitTracker => + val cmd = pyTracker.getRabitTrackerCommand + assert(cmd.contains(hostIp)) + assert(cmd.startsWith("python")) + case _ => fail("expected python tracker implementation") + } + + val paramMap1 = Map( + "num_workers" -> numWorkers, + "tracker_conf" -> TrackerConf(0L, "python", "", pythonExec)) + val xgbExecParams1 = getXGBoostExecutionParams(paramMap1) + val tracker1 = XGBoost.getTracker(xgbExecParams1.numWorkers, xgbExecParams1.trackerConf) + tracker1 match { + case pyTracker: PyRabitTracker => + val cmd = pyTracker.getRabitTrackerCommand + assert(cmd.startsWith(pythonExec)) + assert(!cmd.contains(hostIp)) + case _ => fail("expected python tracker implementation") + } + + val paramMap2 = Map( + "num_workers" -> numWorkers, + "tracker_conf" -> TrackerConf(0L, "python", hostIp, pythonExec)) + val xgbExecParams2 = 
getXGBoostExecutionParams(paramMap2) + val tracker2 = XGBoost.getTracker(xgbExecParams2.numWorkers, xgbExecParams2.trackerConf) + tracker2 match { + case pyTracker: PyRabitTracker => + val cmd = pyTracker.getRabitTrackerCommand + assert(cmd.startsWith(pythonExec)) + assert(cmd.contains(s" --host-ip=${hostIp}")) + case _ => fail("expected python tracker implementation") + } + } + test("training with Scala-implemented Rabit tracker") { val eval = new EvalError() val training = buildDataFrame(Classification.train) diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/RabitTracker.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/RabitTracker.java index 23866d5ba1bd..0e94ce69fab1 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/RabitTracker.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/RabitTracker.java @@ -30,6 +30,8 @@ public class RabitTracker implements IRabitTracker { private Map envs = new HashMap(); // number of workers to be submitted. 
private int numWorkers; + private String hostIp = ""; + private String pythonExec = ""; private AtomicReference trackerProcess = new AtomicReference(); static { @@ -85,6 +87,13 @@ public RabitTracker(int numWorkers) this.numWorkers = numWorkers; } + public RabitTracker(int numWorkers, String hostIp, String pythonExec) + throws XGBoostError { + this(numWorkers); + this.hostIp = hostIp; + this.pythonExec = pythonExec; + } + public void uncaughtException(Thread t, Throwable e) { logger.error("Uncaught exception thrown by worker:", e); try { @@ -126,12 +135,34 @@ private void loadEnvs(InputStream ins) throws IOException { } } + /** Builds the command line used to launch the Python Rabit tracker; visible for testing. */ + public String getRabitTrackerCommand() { + StringBuilder sb = new StringBuilder(); + if (pythonExec == null || pythonExec.isEmpty()) { + sb.append("python "); + } else { + sb.append(pythonExec + " "); + } + sb.append(" " + tracker_py + " "); + sb.append(" --log-level=DEBUG" + " "); + sb.append(" --num-workers=" + numWorkers + " "); + + // host-ip from the tracker properties takes precedence over the constructor parameter + String hostIpFromProperties = trackerProperties.getHostIp(); + if(hostIpFromProperties != null && !hostIpFromProperties.isEmpty()) { + logger.debug("Using provided host-ip: " + hostIpFromProperties + " from properties"); + sb.append(" --host-ip=" + hostIpFromProperties + " "); + } else if (hostIp != null && !hostIp.isEmpty()) { + logger.debug("Using the parametr host-ip: " + hostIp); + sb.append(" --host-ip=" + hostIp + " "); + } + return sb.toString(); + } + private boolean startTrackerProcess() { try { - String trackerExecString = this.addTrackerProperties("python " + tracker_py + - " --log-level=DEBUG --num-workers=" + String.valueOf(numWorkers)); - - trackerProcess.set(Runtime.getRuntime().exec(trackerExecString)); + String cmd = getRabitTrackerCommand(); + trackerProcess.set(Runtime.getRuntime().exec(cmd)); loadEnvs(trackerProcess.get().getInputStream()); return true; } catch (IOException ioe) { @@ -140,18 +171,6 @@ private 
boolean startTrackerProcess() { } } - private String addTrackerProperties(String trackerExecString) { - StringBuilder sb = new StringBuilder(trackerExecString); - String hostIp = trackerProperties.getHostIp(); - - if(hostIp != null && !hostIp.isEmpty()){ - logger.debug("Using provided host-ip: " + hostIp); - sb.append(" --host-ip=").append(hostIp); - } - - return sb.toString(); - } - public void stop() { if (trackerProcess.get() != null) { trackerProcess.get().destroy();