Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[jvm-packages] add hostIp and python exec for rabit tracker #7808

Merged
merged 1 commit into from Apr 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
@@ -1,5 +1,5 @@
/*
Copyright (c) 2014,2021 by Contributors
Copyright (c) 2014-2022 by Contributors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -46,8 +46,14 @@ import org.apache.spark.sql.SparkSession
* the Python Rabit tracker (in dmlc_core), whereas the latter is implemented
* in Scala without Python components, and with full support of timeouts.
* The Scala implementation is currently experimental, use at your own risk.
*
* @param hostIp The Rabit Tracker host IP address which is only used for python implementation.
* This is only needed if the host IP cannot be automatically guessed.
* @param pythonExec The path to the python executable used to launch the Rabit Tracker,
* which is only used for the python implementation.
*/
case class TrackerConf(workerConnectionTimeout: Long, trackerImpl: String )
case class TrackerConf(workerConnectionTimeout: Long, trackerImpl: String,
hostIp: String = "", pythonExec: String = "")

object TrackerConf {
def apply(): TrackerConf = TrackerConf(0L, "python")
Expand Down Expand Up @@ -336,13 +342,18 @@ object XGBoost extends Serializable {
}
}

private def startTracker(nWorkers: Int, trackerConf: TrackerConf): IRabitTracker = {
/** visible for testing */
private[scala] def getTracker(nWorkers: Int, trackerConf: TrackerConf): IRabitTracker = {
val tracker: IRabitTracker = trackerConf.trackerImpl match {
case "scala" => new RabitTracker(nWorkers)
case "python" => new PyRabitTracker(nWorkers)
case "python" => new PyRabitTracker(nWorkers, trackerConf.hostIp, trackerConf.pythonExec)
case _ => new PyRabitTracker(nWorkers)
}
tracker
}

/**
 * Creates the tracker selected by `trackerConf` and starts it.
 *
 * @param nWorkers    number of workers handed to the tracker
 * @param trackerConf tracker configuration; its `workerConnectionTimeout` is passed to `start`
 * @return the started tracker
 * @throws IllegalArgumentException (via `require`) when the tracker reports a failed start
 */
private def startTracker(nWorkers: Int, trackerConf: TrackerConf): IRabitTracker = {
  val rabitTracker = getTracker(nWorkers, trackerConf)
  val started = rabitTracker.start(trackerConf.workerConnectionTimeout)
  require(started, "FAULT: Failed to start tracker")
  rabitTracker
}
Expand Down
@@ -1,5 +1,5 @@
/*
Copyright (c) 2014 by Contributors
Copyright (c) 2014-2022 by Contributors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand All @@ -24,11 +24,61 @@ import ml.dmlc.xgboost4j.java.{Rabit, RabitTracker => PyRabitTracker}
import ml.dmlc.xgboost4j.scala.rabit.{RabitTracker => ScalaRabitTracker}
import ml.dmlc.xgboost4j.java.IRabitTracker.TrackerStatus
import ml.dmlc.xgboost4j.scala.DMatrix

import org.scalatest.{FunSuite, Ignore}
import org.scalatest.{FunSuite}

class RabitRobustnessSuite extends FunSuite with PerTest {

/**
 * Builds the runtime execution parameters for the given raw parameter map by routing it
 * through an [[XGBoostClassifier]] and an [[XGBoostExecutionParamsFactory]].
 */
private def getXGBoostExecutionParams(paramMap: Map[String, Any]): XGBoostExecutionParams = {
  val xgbParams = new XGBoostClassifier(paramMap).MLlib2XGBoostParams
  new XGBoostExecutionParamsFactory(xgbParams, sc).buildXGBRuntimeParams
}


test("Customize host ip and python exec for Rabit tracker") {
  val hostIp = "192.168.22.111"
  val pythonExec = "/usr/bin/python3"

  // Builds the tracker for `conf` through the regular parameter pipeline and hands the
  // command line of the python tracker to `verify`; any other tracker type fails the test.
  def checkTrackerCommand(conf: TrackerConf)(verify: String => Unit): Unit = {
    val execParams = getXGBoostExecutionParams(
      Map("num_workers" -> numWorkers, "tracker_conf" -> conf))
    XGBoost.getTracker(execParams.numWorkers, execParams.trackerConf) match {
      case pyTracker: PyRabitTracker => verify(pyTracker.getRabitTrackerCommand)
      case _ => assert(false, "expected python tracker implementation")
    }
  }

  // Only the host IP is customized: the default interpreter is used.
  checkTrackerCommand(TrackerConf(0L, "python", hostIp)) { cmd =>
    assert(cmd.contains(hostIp))
    assert(cmd.startsWith("python"))
  }

  // Only the python executable is customized: no host IP shows up in the command.
  checkTrackerCommand(TrackerConf(0L, "python", "", pythonExec)) { cmd =>
    assert(cmd.startsWith(pythonExec))
    assert(!cmd.contains(hostIp))
  }

  // Both are customized.
  checkTrackerCommand(TrackerConf(0L, "python", hostIp, pythonExec)) { cmd =>
    assert(cmd.startsWith(pythonExec))
    assert(cmd.contains(s" --host-ip=${hostIp}"))
  }
}

test("training with Scala-implemented Rabit tracker") {
val eval = new EvalError()
val training = buildDataFrame(Classification.train)
Expand Down
Expand Up @@ -30,6 +30,8 @@ public class RabitTracker implements IRabitTracker {
private Map<String, String> envs = new HashMap<String, String>();
// number of workers to be submitted.
private int numWorkers;
private String hostIp = "";
private String pythonExec = "";
private AtomicReference<Process> trackerProcess = new AtomicReference<Process>();

static {
Expand Down Expand Up @@ -85,6 +87,13 @@ public RabitTracker(int numWorkers)
this.numWorkers = numWorkers;
}

/**
 * Creates a tracker for the given number of workers with a custom host IP and python
 * executable. Delegates to the single-argument constructor, then stores both values.
 *
 * @param numWorkers number of workers, forwarded to the single-argument constructor
 * @param hostIp host IP appended as --host-ip to the tracker command when it is non-empty
 *               and the tracker properties do not already provide a host IP
 * @param pythonExec executable placed at the start of the tracker command; "python" is
 *                   used instead when this is null or empty
 * @throws XGBoostError declared on this constructor; presumably propagated from the
 *                      delegated single-argument constructor
 */
public RabitTracker(int numWorkers, String hostIp, String pythonExec)
throws XGBoostError {
this(numWorkers);
this.hostIp = hostIp;
this.pythonExec = pythonExec;
}

public void uncaughtException(Thread t, Throwable e) {
logger.error("Uncaught exception thrown by worker:", e);
try {
Expand Down Expand Up @@ -126,12 +135,34 @@ private void loadEnvs(InputStream ins) throws IOException {
}
}

/**
 * Builds the command line used to launch the python Rabit tracker. Visible for testing.
 *
 * The command starts with {@code pythonExec} ("python" when that field is null or empty),
 * followed by the tracker script, the log level and the number of workers. A
 * {@code --host-ip} option is appended when a host IP is available; the value from the
 * tracker properties takes precedence over the {@code hostIp} field.
 *
 * @return the tracker command line
 */
public String getRabitTrackerCommand() {
  StringBuilder sb = new StringBuilder();
  if (pythonExec == null || pythonExec.isEmpty()) {
    sb.append("python ");
  } else {
    sb.append(pythonExec + " ");
  }
  sb.append(" " + tracker_py + " ");
  sb.append(" --log-level=DEBUG" + " ");
  sb.append(" --num-workers=" + numWorkers + " ");

  // we first check the property then check the parameter
  String hostIpFromProperties = trackerProperties.getHostIp();
  if (hostIpFromProperties != null && !hostIpFromProperties.isEmpty()) {
    logger.debug("Using provided host-ip: " + hostIpFromProperties + " from properties");
    sb.append(" --host-ip=" + hostIpFromProperties + " ");
  } else if (hostIp != null && !hostIp.isEmpty()) {
    // Short-circuit && is required here: the non-short-circuit & would still evaluate
    // hostIp.isEmpty() for a null hostIp and throw a NullPointerException.
    logger.debug("Using the parametr host-ip: " + hostIp);
    sb.append(" --host-ip=" + hostIp + " ");
  }
  return sb.toString();
}

private boolean startTrackerProcess() {
try {
String trackerExecString = this.addTrackerProperties("python " + tracker_py +
" --log-level=DEBUG --num-workers=" + String.valueOf(numWorkers));

trackerProcess.set(Runtime.getRuntime().exec(trackerExecString));
String cmd = getRabitTrackerCommand();
trackerProcess.set(Runtime.getRuntime().exec(cmd));
loadEnvs(trackerProcess.get().getInputStream());
return true;
} catch (IOException ioe) {
Expand All @@ -140,18 +171,6 @@ private boolean startTrackerProcess() {
}
}

private String addTrackerProperties(String trackerExecString) {
StringBuilder sb = new StringBuilder(trackerExecString);
String hostIp = trackerProperties.getHostIp();

if(hostIp != null && !hostIp.isEmpty()){
logger.debug("Using provided host-ip: " + hostIp);
sb.append(" --host-ip=").append(hostIp);
}

return sb.toString();
}

public void stop() {
if (trackerProcess.get() != null) {
trackerProcess.get().destroy();
Expand Down