diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
index 8773969c1f1d..8efbb56fcc60 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
@@ -76,7 +76,8 @@ private[this] case class XGBoostExecutionParams(
     earlyStoppingParams: XGBoostExecutionEarlyStoppingParams,
     cacheTrainingSet: Boolean,
     treeMethod: Option[String],
-    isLocal: Boolean) {
+    isLocal: Boolean,
+    killSparkContext: Boolean) {
 
   private var rawParamMap: Map[String, Any] = _
 
@@ -220,6 +221,9 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
     val cacheTrainingSet = overridedParams.getOrElse("cache_training_set", false)
       .asInstanceOf[Boolean]
 
+    val killSparkContext = overridedParams.getOrElse("kill_spark_context", true)
+      .asInstanceOf[Boolean]
+
     val xgbExecParam = XGBoostExecutionParams(nWorkers, round, useExternalMemory, obj, eval,
       missing, allowNonZeroForMissing, trackerConf,
       timeoutRequestWorkers,
@@ -228,7 +232,8 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
       xgbExecEarlyStoppingParams,
       cacheTrainingSet,
       treeMethod,
-      isLocal)
+      isLocal,
+      killSparkContext)
     xgbExecParam.setRawParamMap(overridedParams)
     xgbExecParam
   }
@@ -588,7 +593,8 @@ object XGBoost extends Serializable {
     val (booster, metrics) = try {
       val parallelismTracker = new SparkParallelismTracker(sc,
         xgbExecParams.timeoutRequestWorkers,
-        xgbExecParams.numWorkers)
+        xgbExecParams.numWorkers,
+        xgbExecParams.killSparkContext)
       val rabitEnv = tracker.getWorkerEnvs
       val boostersAndMetrics = if (hasGroup) {
         trainForRanking(transformedTrainingData.left.get, xgbExecParams, rabitEnv, prevBooster,
@@ -628,6 +634,9 @@ object XGBoost extends Serializable {
       case t: Throwable =>
         // if the job was aborted due to an exception, just throw the exception
         logger.error("the job was aborted due to ", t)
+        if (xgbExecParams.killSparkContext) {
+          trainingData.sparkContext.stop()
+        }
         throw t
     } finally {
       uncacheTrainingData(xgbExecParams.cacheTrainingSet, transformedTrainingData)
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala
index 1512c85d07a9..761c962b4654 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala
@@ -105,8 +105,14 @@ private[spark] trait LearningTaskParams extends Params {
 
   final def getMaximizeEvaluationMetrics: Boolean = $(maximizeEvaluationMetrics)
 
-  setDefault(objective -> "reg:squarederror", baseScore -> 0.5,
-    trainTestRatio -> 1.0, numEarlyStoppingRounds -> 0, cacheTrainingSet -> false)
+  /**
+   * whether to stop the SparkContext when the training task fails
+   */
+  final val killSparkContext = new BooleanParam(this, "killSparkContext",
+    "whether to stop the SparkContext when the training task fails")
+
+  setDefault(objective -> "reg:squarederror", baseScore -> 0.5, trainTestRatio -> 1.0,
+    numEarlyStoppingRounds -> 0, cacheTrainingSet -> false, killSparkContext -> true)
 }
 
 private[spark] object LearningTaskParams {
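
Usage sketch (not part of the patch): the factory above reads the new "kill_spark_context" key with a default of true, and LearningTaskParams exposes it as the MLlib param killSparkContext, so a caller only needs to set it when the SparkContext should survive a failed training job. The Map-based constructor below is the same one used by the test suites later in this patch; the parameter values are illustrative.

    import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier

    // Keep the SparkContext alive if training fails; only the running XGBoost jobs are cancelled.
    val classifier = new XGBoostClassifier(Map(
      "objective" -> "binary:logistic",
      "num_round" -> 10,
      "num_workers" -> 2,
      "kill_spark_context" -> false))
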
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/SparkParallelismTracker.scala b/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/SparkParallelismTracker.scala
index db0f66381c04..5baef44c34ac 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/SparkParallelismTracker.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/SparkParallelismTracker.scala
@@ -26,11 +26,13 @@ import org.apache.spark.scheduler._
  * @param sc The SparkContext object
  * @param timeout The maximum time to wait for enough number of workers.
  * @param numWorkers nWorkers used in an XGBoost Job
+ * @param killSparkContext whether to stop the SparkContext when a task fails
  */
 class SparkParallelismTracker(
     val sc: SparkContext,
     timeout: Long,
-    numWorkers: Int) {
+    numWorkers: Int,
+    killSparkContext: Boolean = true) {
 
   private[this] val requestedCores = numWorkers * sc.conf.getInt("spark.task.cpus", 1)
   private[this] val logger = LogFactory.getLog("XGBoostSpark")
@@ -58,7 +60,7 @@ class SparkParallelismTracker(
   }
 
   private[this] def safeExecute[T](body: => T): T = {
-    val listener = new TaskFailedListener
+    val listener = new TaskFailedListener(killSparkContext)
     sc.addSparkListener(listener)
     try {
       body
@@ -90,7 +92,7 @@ class SparkParallelismTracker(
   }
 }
 
-private[spark] class TaskFailedListener extends SparkListener {
+class TaskFailedListener(killSparkContext: Boolean = true) extends SparkListener {
 
   private[this] val logger = LogFactory.getLog("XGBoostTaskFailedListener")
 
@@ -99,28 +101,57 @@ private[spark] class TaskFailedListener extends SparkListener {
       case taskEndReason: TaskFailedReason =>
         logger.error(s"Training Task Failed during XGBoost Training: " +
           s"$taskEndReason, cancelling all jobs")
-        TaskFailedListener.cancelAllJobs()
+        taskFaultHandle
       case _ =>
     }
   }
+
+  private[this] def taskFaultHandle = {
+    if (killSparkContext) {
+      TaskFailedListener.startedSparkContextKiller()
+    } else {
+      TaskFailedListener.cancelAllJobs()
+    }
+  }
 }
 
 object TaskFailedListener {
-
+  private[this] val logger = LogFactory.getLog("XGBoostTaskFailedListener")
   var cancelJobStarted = false
 
   private def cancelAllJobs(): Unit = this.synchronized {
     if (!cancelJobStarted) {
+      cancelJobStarted = true
       val cancelJob = new Thread() {
         override def run(): Unit = {
           LiveListenerBus.withinListenerThread.withValue(false) {
+            logger.info("will call spark cancel all jobs")
             SparkContext.getOrCreate().cancelAllJobs()
+            logger.info("after call spark cancel all jobs")
           }
         }
      }
       cancelJob.setDaemon(true)
       cancelJob.start()
-      cancelJobStarted = true
     }
   }
+
+  var killerStarted = false
+
+  private def startedSparkContextKiller(): Unit = this.synchronized {
+    if (!killerStarted) {
+      // Spark does not allow a listener thread to stop the SparkContext,
+      // so the stop has to happen in a separate thread
+      val sparkContextKiller = new Thread() {
+        override def run(): Unit = {
+          LiveListenerBus.withinListenerThread.withValue(false) {
+            SparkContext.getOrCreate().stop()
+          }
+        }
+      }
+      sparkContextKiller.setDaemon(true)
+      sparkContextKiller.start()
+      killerStarted = true
+    }
+  }
 }
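
Also not part of the patch: a sketch of driving the tracker directly with the new flag, mirroring the SparkParallelismTrackerSuite test added at the end of this patch. With killSparkContext = false a failed task only triggers TaskFailedListener.cancelAllJobs(); with the default of true the listener stops the SparkContext from a separate daemon thread. The helper name runTracked and its argument values are illustrative.

    import org.apache.spark.{SparkContext, SparkParallelismTracker}

    // sc is an existing SparkContext; per the existing tests, a non-positive timeout skips
    // the wait for executors.
    def runTracked(sc: SparkContext, numWorkers: Int): Double = {
      val tracker = new SparkParallelismTracker(sc, 0L, numWorkers, killSparkContext = false)
      tracker.execute {
        // any Spark job; a task failure here cancels the running jobs but leaves sc usable
        sc.parallelize(1 to numWorkers, numWorkers).map(_ * 2.0).sum()
      }
    }
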
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ParameterSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ParameterSuite.scala
index 31f7a6b6f0f2..c802fc0b0448 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ParameterSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ParameterSuite.scala
@@ -40,13 +40,13 @@ class ParameterSuite extends FunSuite with PerTest with BeforeAndAfterAll {
     assert(xgbCopy2.MLlib2XGBoostParams("eval_metric").toString === "logloss")
   }
 
-  private def sparkContextShouldNotShutDown(): Unit = {
+  private def waitForSparkContextShutdown(): Unit = {
     var totalWaitedTime = 0L
-    while (!ss.sparkContext.isStopped && totalWaitedTime <= 10000) {
-      Thread.sleep(1000)
-      totalWaitedTime += 1000
+    while (!ss.sparkContext.isStopped && totalWaitedTime <= 120000) {
+      Thread.sleep(10000)
+      totalWaitedTime += 10000
     }
-    assert(ss.sparkContext.isStopped === false)
+    assert(ss.sparkContext.isStopped === true)
   }
 
   test("fail training elegantly with unsupported objective function") {
@@ -60,7 +60,7 @@ class ParameterSuite extends FunSuite with PerTest with BeforeAndAfterAll {
     } catch {
       case e: Throwable => // swallow anything
     } finally {
-      sparkContextShouldNotShutDown()
+      waitForSparkContextShutdown()
     }
   }
 
@@ -75,7 +75,7 @@ class ParameterSuite extends FunSuite with PerTest with BeforeAndAfterAll {
     } catch {
       case e: Throwable => // swallow anything
     } finally {
-      sparkContextShouldNotShutDown()
+      waitForSparkContextShutdown()
     }
   }
 }
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PerTest.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PerTest.scala
index 8b23a9faa26c..9154bc7c7055 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PerTest.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PerTest.scala
@@ -51,6 +51,7 @@ trait PerTest extends BeforeAndAfterEach { self: FunSuite =>
       cleanExternalCache(currentSession.sparkContext.appName)
       currentSession = null
     }
+    TaskFailedListener.killerStarted = false
     TaskFailedListener.cancelJobStarted = false
   }
 }
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRabitRegressionSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRabitRegressionSuite.scala
index 8f1dd54bd337..6609b04daab0 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRabitRegressionSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRabitRegressionSuite.scala
@@ -91,8 +91,8 @@ class XGBoostRabitRegressionSuite extends FunSuite with PerTest {
   }
 
   test("test rabit timeout fail handle") {
-    // disable job cancel listener to verify if rabit_timeout take effect and kill tasks
-    TaskFailedListener.cancelJobStarted = true
+    // disable the SparkContext killer to verify that rabit_timeout takes effect and kills tasks
+    TaskFailedListener.killerStarted = true
 
     val training = buildDataFrame(Classification.train)
     // mock rank 0 failure during 8th allreduce synchronization
@@ -109,10 +109,35 @@ class XGBoostRabitRegressionSuite extends FunSuite with PerTest {
         "rabit_timeout" -> 0))
         .fit(training)
     } catch {
-      case e: Throwable => println("----- " + e)// swallow anything
+      case e: Throwable => // swallow anything
     } finally {
-      // wait 2s to check if SparkContext is killed
-      assert(waitAndCheckSparkShutdown(2000) == false)
+      // assume all tasks throw the exception at almost the same time;
+      // 100ms should be enough to exhaust all retries
+      assert(waitAndCheckSparkShutdown(100) == true)
+    }
+  }
+
+  test("test SparkContext should not be killed") {
+    val training = buildDataFrame(Classification.train)
+    // mock rank 0 failure during 8th allreduce synchronization
Array("0,8,0,0").toList.asJava + + try { + new XGBoostClassifier(Map( + "eta" -> "0.1", + "max_depth" -> "10", + "verbosity" -> "1", + "objective" -> "binary:logistic", + "num_round" -> 5, + "num_workers" -> numWorkers, + "kill_spark_context" -> false, + "rabit_timeout" -> 0)) + .fit(training) + } catch { + case e: Throwable => // swallow anything + } finally { + // wait 3s to check if SparkContext is killed + assert(waitAndCheckSparkShutdown(3000) == false) } } } diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/org/apache/spark/SparkParallelismTrackerSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/org/apache/spark/SparkParallelismTrackerSuite.scala index 7f344674f12b..804740fcda39 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/org/apache/spark/SparkParallelismTrackerSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/org/apache/spark/SparkParallelismTrackerSuite.scala @@ -34,6 +34,15 @@ class SparkParallelismTrackerSuite extends FunSuite with PerTest { .config("spark.driver.memory", "512m") .config("spark.task.cpus", 1) + private def waitAndCheckSparkShutdown(waitMiliSec: Int): Boolean = { + var totalWaitedTime = 0L + while (!ss.sparkContext.isStopped && totalWaitedTime <= waitMiliSec) { + Thread.sleep(100) + totalWaitedTime += 100 + } + ss.sparkContext.isStopped + } + test("tracker should not affect execution result when timeout is not larger than 0") { val nWorkers = numParallelism val rdd: RDD[Int] = sc.parallelize(1 to nWorkers) @@ -74,4 +83,27 @@ class SparkParallelismTrackerSuite extends FunSuite with PerTest { } } } + + test("tracker should not kill SparkContext when killSparkContext=false") { + val nWorkers = numParallelism + val rdd: RDD[Int] = sc.parallelize(1 to nWorkers, nWorkers) + val tracker = new SparkParallelismTracker(sc, 0, nWorkers, false) + try { + tracker.execute { + rdd.map { i => + val partitionId = TaskContext.get().partitionId() + if (partitionId == 0) { + throw new RuntimeException("mocking task failing") + } + i + }.sum() + } + } catch { + case e: Exception => // catch the exception + } finally { + // wait 3s to check if SparkContext is killed + assert(waitAndCheckSparkShutdown(3000) == false) + } + } + }