diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
index 8cfdf7dba792..66cce0c32bfe 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
@@ -76,7 +76,8 @@ private[this] case class XGBoostExecutionParams(
     earlyStoppingParams: XGBoostExecutionEarlyStoppingParams,
     cacheTrainingSet: Boolean,
     treeMethod: Option[String],
-    isLocal: Boolean) {
+    isLocal: Boolean,
+    killSparkContextOnWorkerFailure: Boolean) {
 
   private var rawParamMap: Map[String, Any] = _
 
@@ -220,6 +221,9 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
     val cacheTrainingSet = overridedParams.getOrElse("cache_training_set", false)
       .asInstanceOf[Boolean]
 
+    val killSparkContext = overridedParams.getOrElse("kill_spark_context_on_worker_failure", true)
+      .asInstanceOf[Boolean]
+
     val xgbExecParam = XGBoostExecutionParams(nWorkers, round, useExternalMemory, obj, eval,
       missing, allowNonZeroForMissing, trackerConf,
       timeoutRequestWorkers,
@@ -228,7 +232,8 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
       xgbExecEarlyStoppingParams,
       cacheTrainingSet,
       treeMethod,
-      isLocal)
+      isLocal,
+      killSparkContext)
     xgbExecParam.setRawParamMap(overridedParams)
     xgbExecParam
   }
@@ -588,7 +593,8 @@ object XGBoost extends Serializable {
     val (booster, metrics) = try {
       val parallelismTracker = new SparkParallelismTracker(sc,
         xgbExecParams.timeoutRequestWorkers,
-        xgbExecParams.numWorkers)
+        xgbExecParams.numWorkers,
+        xgbExecParams.killSparkContextOnWorkerFailure)
       val rabitEnv = tracker.getWorkerEnvs
       val boostersAndMetrics = if (hasGroup) {
         trainForRanking(transformedTrainingData.left.get, xgbExecParams, rabitEnv, prevBooster,
@@ -628,7 +634,9 @@ object XGBoost extends Serializable {
       case t: Throwable =>
         // if the job was aborted due to an exception
        logger.error("the job was aborted due to ", t)
-        trainingData.sparkContext.stop()
+        if (xgbExecParams.killSparkContextOnWorkerFailure) {
+          trainingData.sparkContext.stop()
+        }
        throw t
     } finally {
       uncacheTrainingData(xgbExecParams.cacheTrainingSet, transformedTrainingData)
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala
index 1512c85d07a9..aba8d45f3c69 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala
@@ -105,8 +105,14 @@ private[spark] trait LearningTaskParams extends Params {
 
   final def getMaximizeEvaluationMetrics: Boolean = $(maximizeEvaluationMetrics)
 
-  setDefault(objective -> "reg:squarederror", baseScore -> 0.5,
-    trainTestRatio -> 1.0, numEarlyStoppingRounds -> 0, cacheTrainingSet -> false)
+  /**
+   * Whether to kill the SparkContext when a training task fails.
+   */
+  final val killSparkContextOnWorkerFailure = new BooleanParam(this,
+    "killSparkContextOnWorkerFailure", "whether to kill the SparkContext when a training task fails")
+
+  setDefault(objective -> "reg:squarederror", baseScore -> 0.5, trainTestRatio -> 1.0,
+    numEarlyStoppingRounds -> 0, cacheTrainingSet -> false, killSparkContextOnWorkerFailure -> true)
 }
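
Note (not part of the patch): a minimal usage sketch of the new flag from the application side, assuming an existing SparkSession and a prepared labeled training DataFrame (`trainingDF` is a placeholder). Setting `kill_spark_context_on_worker_failure` to false asks XGBoost4J-Spark to cancel only the failed training job instead of stopping the whole SparkContext; the default remains true, preserving the old behavior. The `XGBoostClassifier(Map[String, Any])` constructor used here is the same one exercised by the test added later in this patch.

// Sketch only; trainingDF is a hypothetical labeled DataFrame prepared by the caller.
import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier

val classifier = new XGBoostClassifier(Map(
  "objective" -> "binary:logistic",
  "num_round" -> 5,
  "num_workers" -> 2,
  // new parameter introduced by this patch; defaults to true
  "kill_spark_context_on_worker_failure" -> false))
val model = classifier.fit(trainingDF)
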
 
 private[spark] object LearningTaskParams {
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/SparkParallelismTracker.scala b/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/SparkParallelismTracker.scala
index bd0ab4d6dd01..3e514ebd885b 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/SparkParallelismTracker.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/SparkParallelismTracker.scala
@@ -19,6 +19,8 @@ package org.apache.spark
 import org.apache.commons.logging.LogFactory
 import org.apache.spark.scheduler._
 
+import scala.collection.mutable.{HashMap, HashSet}
+
 /**
  * A tracker that ensures enough number of executor cores are alive.
  * Throws an exception when the number of alive cores is less than nWorkers.
@@ -26,11 +28,13 @@ import org.apache.spark.scheduler._
  * @param sc The SparkContext object
  * @param timeout The maximum time to wait for enough number of workers.
  * @param numWorkers nWorkers used in an XGBoost Job
+ * @param killSparkContextOnWorkerFailure whether to kill the SparkContext when a task fails
  */
 class SparkParallelismTracker(
     val sc: SparkContext,
     timeout: Long,
-    numWorkers: Int) {
+    numWorkers: Int,
+    killSparkContextOnWorkerFailure: Boolean = true) {
 
   private[this] val requestedCores = numWorkers * sc.conf.getInt("spark.task.cpus", 1)
   private[this] val logger = LogFactory.getLog("XGBoostSpark")
@@ -58,7 +62,7 @@ class SparkParallelismTracker(
   }
 
   private[this] def safeExecute[T](body: => T): T = {
-    val listener = new TaskFailedListener
+    val listener = new TaskFailedListener(killSparkContextOnWorkerFailure)
     sc.addSparkListener(listener)
     try {
       body
@@ -79,7 +83,7 @@ class SparkParallelismTracker(
   def execute[T](body: => T): T = {
     if (timeout <= 0) {
       logger.info("starting training without setting timeout for waiting for resources")
-      body
+      safeExecute(body)
     } else {
       logger.info(s"starting training with timeout set as $timeout ms for waiting for resources")
       if (!waitForCondition(numAliveCores >= requestedCores, timeout)) {
@@ -90,16 +94,51 @@ class SparkParallelismTracker(
   }
 }
 
-private[spark] class TaskFailedListener extends SparkListener {
+class TaskFailedListener(killSparkContext: Boolean = true) extends SparkListener {
 
   private[this] val logger = LogFactory.getLog("XGBoostTaskFailedListener")
 
+  // {jobId, [stageId0, stageId1, ...]}
+  // keep track of the mapping of job ids to stage ids;
+  // when a task fails, find the job the failed stage belongs to, and finally
+  // cancel that job
+  private val jobIdToStageIds: HashMap[Int, HashSet[Int]] = HashMap.empty
+
+  override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
+    if (!killSparkContext) {
+      jobStart.stageIds.foreach(stageId => {
+        jobIdToStageIds.getOrElseUpdate(jobStart.jobId, new HashSet[Int]()) += stageId
+      })
+    }
+  }
+
+  override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = {
+    if (!killSparkContext) {
+      jobIdToStageIds.remove(jobEnd.jobId)
+    }
+  }
+
   override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
     taskEnd.reason match {
       case taskEndReason: TaskFailedReason =>
         logger.error(s"Training Task Failed during XGBoost Training: " +
-          s"$taskEndReason, stopping SparkContext")
-        TaskFailedListener.startedSparkContextKiller()
+          s"$taskEndReason")
+        if (killSparkContext) {
+          logger.error("killing SparkContext")
+          TaskFailedListener.startedSparkContextKiller()
+        } else {
+          val stageId = taskEnd.stageId
+          // find the jobs that contain this stage and cancel them
+
+          jobIdToStageIds.foreach {
+            case (jobId, stageIds) =>
+              if (stageIds.contains(stageId)) {
+                logger.error("Cancelling jobId:" + jobId)
+                jobIdToStageIds.remove(jobId)
+                SparkContext.getOrCreate().cancelJob(jobId)
+              }
+          }
+        }
       case _ =>
     }
   }
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRabitRegressionSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRabitRegressionSuite.scala
index d1b6ec0f9acc..e1d58f26ed50 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRabitRegressionSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRabitRegressionSuite.scala
@@ -116,4 +116,28 @@ class XGBoostRabitRegressionSuite extends FunSuite with PerTest {
       assert(waitAndCheckSparkShutdown(100) == true)
     }
   }
+
+  test("SparkContext should not be killed when kill_spark_context_on_worker_failure is false") {
+    val training = buildDataFrame(Classification.train)
+    // mock rank 0 failure during 8th allreduce synchronization
+    Rabit.mockList = Array("0,8,0,0").toList.asJava
+
+    try {
+      new XGBoostClassifier(Map(
+        "eta" -> "0.1",
+        "max_depth" -> "10",
+        "verbosity" -> "1",
+        "objective" -> "binary:logistic",
+        "num_round" -> 5,
+        "num_workers" -> numWorkers,
+        "kill_spark_context_on_worker_failure" -> false,
+        "rabit_timeout" -> 0))
+        .fit(training)
+    } catch {
+      case e: Throwable => // swallow anything
+    } finally {
+      // wait up to 3s to confirm that the SparkContext has not been stopped
+      assert(waitAndCheckSparkShutdown(3000) == false)
+    }
+  }
 }
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/org/apache/spark/SparkParallelismTrackerSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/org/apache/spark/SparkParallelismTrackerSuite.scala
index 7f344674f12b..cb8fa579476a 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/org/apache/spark/SparkParallelismTrackerSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/org/apache/spark/SparkParallelismTrackerSuite.scala
@@ -34,6 +34,15 @@ class SparkParallelismTrackerSuite extends FunSuite with PerTest {
     .config("spark.driver.memory", "512m")
     .config("spark.task.cpus", 1)
 
+  private def waitAndCheckSparkShutdown(waitMilliSec: Int): Boolean = {
+    var totalWaitedTime = 0L
+    while (!ss.sparkContext.isStopped && totalWaitedTime <= waitMilliSec) {
+      Thread.sleep(100)
+      totalWaitedTime += 100
+    }
+    ss.sparkContext.isStopped
+  }
+
test("tracker should not affect execution result when timeout is not larger than 0") { val nWorkers = numParallelism val rdd: RDD[Int] = sc.parallelize(1 to nWorkers) @@ -74,4 +83,69 @@ class SparkParallelismTrackerSuite extends FunSuite with PerTest { } } } + + test("tracker should not kill SparkContext when killSparkContextOnWorkerFailure=false") { + val nWorkers = numParallelism + val tracker = new SparkParallelismTracker(sc, 0, nWorkers, false) + val rdd: RDD[Int] = sc.parallelize(1 to nWorkers, nWorkers) + try { + tracker.execute { + rdd.map { i => + val partitionId = TaskContext.get().partitionId() + if (partitionId == 0) { + throw new RuntimeException("mocking task failing") + } + i + }.sum() + } + } catch { + case e: Exception => // catch the exception + } finally { + // wait 3s to check if SparkContext is killed + assert(waitAndCheckSparkShutdown(3000) == false) + } + } + + test("tracker should cancel the correct job when killSparkContextOnWorkerFailure=false") { + val nWorkers = 2 + val tracker = new SparkParallelismTracker(sc, 0, nWorkers, false) + val rdd: RDD[Int] = sc.parallelize(1 to 10, nWorkers) + val thread = new TestThread(sc) + thread.start() + try { + tracker.execute { + rdd.map { i => + Thread.sleep(100) + val partitionId = TaskContext.get().partitionId() + if (partitionId == 0) { + throw new RuntimeException("mocking task failing") + } + i + }.sum() + } + } catch { + case e: Exception => // catch the exception + } finally { + thread.join(8000) + // wait 3s to check if SparkContext is killed + assert(waitAndCheckSparkShutdown(3000) == false) + } + } + + private[this] class TestThread(sc: SparkContext) extends Thread { + override def run(): Unit = { + var sum: Double = 0.0f + try { + val rdd = sc.parallelize(1 to 4, 2) + sum = rdd.mapPartitions(iter => { + // sleep 2s to ensure task is alive when cancelling other jobs + Thread.sleep(2000) + iter + }).sum() + } finally { + // get the correct result + assert(sum.toInt == 10) + } + } + } }