From 2a2b47af7568b1c8a8d5e82b173a0cd624d9a28e Mon Sep 17 00:00:00 2001
From: Bobby Wang
Date: Thu, 13 Aug 2020 11:36:54 +0800
Subject: [PATCH 1/5] cancel job instead of killing SparkContext

This PR changes the default behavior of killing the SparkContext when a
training task fails. Instead, it cancels the running jobs, which means
the SparkContext stays alive even when exceptions happen during training.
---
 .../dmlc/xgboost4j/scala/spark/XGBoost.scala  |  3 +--
 .../spark/SparkParallelismTracker.scala       | 25 +++++++++----------
 .../scala/spark/ParameterSuite.scala          | 14 +++++------
 .../dmlc/xgboost4j/scala/spark/PerTest.scala  |  2 +-
 .../spark/XGBoostRabitRegressionSuite.scala   | 11 ++++----
 5 files changed, 26 insertions(+), 29 deletions(-)

diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
index 8cfdf7dba792..8773969c1f1d 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
@@ -626,9 +626,8 @@ object XGBoost extends Serializable {
       (booster, metrics)
     } catch {
       case t: Throwable =>
-        // if the job was aborted due to an exception
+        // if the job was aborted due to an exception, just throw the exception
        logger.error("the job was aborted due to ", t)
-        trainingData.sparkContext.stop()
        throw t
     } finally {
       uncacheTrainingData(xgbExecParams.cacheTrainingSet, transformedTrainingData)
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/SparkParallelismTracker.scala b/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/SparkParallelismTracker.scala
index bd0ab4d6dd01..db0f66381c04 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/SparkParallelismTracker.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/SparkParallelismTracker.scala
@@ -79,7 +79,7 @@ class SparkParallelismTracker(
   def execute[T](body: => T): T = {
     if (timeout <= 0) {
       logger.info("starting training without setting timeout for waiting for resources")
-      body
+      safeExecute(body)
     } else {
       logger.info(s"starting training with timeout set as $timeout ms for waiting for resources")
       if (!waitForCondition(numAliveCores >= requestedCores, timeout)) {
@@ -98,8 +98,8 @@ private[spark] class TaskFailedListener extends SparkListener {
     taskEnd.reason match {
       case taskEndReason: TaskFailedReason =>
         logger.error(s"Training Task Failed during XGBoost Training: " +
-          s"$taskEndReason, stopping SparkContext")
-        TaskFailedListener.startedSparkContextKiller()
+          s"$taskEndReason, cancelling all jobs")
+        TaskFailedListener.cancelAllJobs()
       case _ =>
     }
   }
@@ -107,22 +107,21 @@ private[spark] class TaskFailedListener extends SparkListener {
 
 object TaskFailedListener {
 
-  var killerStarted = false
+  var cancelJobStarted = false
 
-  private def startedSparkContextKiller(): Unit = this.synchronized {
-    if (!killerStarted) {
-      // Spark does not allow ListenerThread to shutdown SparkContext so that we have to do it
-      // in a separate thread
-      val sparkContextKiller = new Thread() {
+  private def cancelAllJobs(): Unit = this.synchronized {
+    if (!cancelJobStarted) {
+      val cancelJob = new Thread() {
         override def run(): Unit = {
           LiveListenerBus.withinListenerThread.withValue(false) {
-            SparkContext.getOrCreate().stop()
+            SparkContext.getOrCreate().cancelAllJobs()
           }
         }
       }
-      sparkContextKiller.setDaemon(true)
-      sparkContextKiller.start()
-      killerStarted = true
+      cancelJob.setDaemon(true)
+      cancelJob.start()
+      cancelJobStarted = true
     }
   }
+
 }
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ParameterSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ParameterSuite.scala
index c802fc0b0448..31f7a6b6f0f2 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ParameterSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ParameterSuite.scala
@@ -40,13 +40,13 @@ class ParameterSuite extends FunSuite with PerTest with BeforeAndAfterAll {
     assert(xgbCopy2.MLlib2XGBoostParams("eval_metric").toString === "logloss")
   }
 
-  private def waitForSparkContextShutdown(): Unit = {
+  private def sparkContextShouldNotShutDown(): Unit = {
     var totalWaitedTime = 0L
-    while (!ss.sparkContext.isStopped && totalWaitedTime <= 120000) {
-      Thread.sleep(10000)
-      totalWaitedTime += 10000
+    while (!ss.sparkContext.isStopped && totalWaitedTime <= 10000) {
+      Thread.sleep(1000)
+      totalWaitedTime += 1000
     }
-    assert(ss.sparkContext.isStopped === true)
+    assert(ss.sparkContext.isStopped === false)
   }
 
   test("fail training elegantly with unsupported objective function") {
@@ -60,7 +60,7 @@ class ParameterSuite extends FunSuite with PerTest with BeforeAndAfterAll {
     } catch {
       case e: Throwable => // swallow anything
     } finally {
-      waitForSparkContextShutdown()
+      sparkContextShouldNotShutDown()
     }
   }
 
@@ -75,7 +75,7 @@ class ParameterSuite extends FunSuite with PerTest with BeforeAndAfterAll {
     } catch {
       case e: Throwable => // swallow anything
     } finally {
-      waitForSparkContextShutdown()
+      sparkContextShouldNotShutDown()
     }
   }
 }
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PerTest.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PerTest.scala
index 341db97bc447..8b23a9faa26c 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PerTest.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PerTest.scala
@@ -51,7 +51,7 @@ trait PerTest extends BeforeAndAfterEach { self: FunSuite =>
       cleanExternalCache(currentSession.sparkContext.appName)
       currentSession = null
     }
-    TaskFailedListener.killerStarted = false
+    TaskFailedListener.cancelJobStarted = false
   }
 }
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRabitRegressionSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRabitRegressionSuite.scala
index d1b6ec0f9acc..8f1dd54bd337 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRabitRegressionSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRabitRegressionSuite.scala
@@ -91,8 +91,8 @@ class XGBoostRabitRegressionSuite extends FunSuite with PerTest {
   }
 
   test("test rabit timeout fail handle") {
-    // disable spark kill listener to verify if rabit_timeout take effect and kill tasks
-    TaskFailedListener.killerStarted = true
+    // disable job cancel listener to verify if rabit_timeout take effect and kill tasks
+    TaskFailedListener.cancelJobStarted = true
 
     val training = buildDataFrame(Classification.train)
     // mock rank 0 failure during 8th allreduce synchronization
@@ -109,11 +109,10 @@ class XGBoostRabitRegressionSuite extends FunSuite with PerTest {
         "rabit_timeout" -> 0))
         .fit(training)
     } catch {
-      case e: Throwable => // swallow anything
+      case e: Throwable => println("----- " + e)// swallow anything
     } finally {
-      // assume all tasks throw exception almost same time
-      // 100ms should be enough to exhaust all retries
-      assert(waitAndCheckSparkShutdown(100) == true)
+      // wait 2s to check if SparkContext is killed
+      assert(waitAndCheckSparkShutdown(2000) == false)
     }
   }
 }
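The threading in the listener above is worth spelling out: as the deleted comment notes, Spark does not allow the listener thread to shut down the SparkContext, so both the old context killer and the new cancellation hand the work to a separate daemon thread (the `LiveListenerBus.withinListenerThread` wrapper is Spark-private API, which is why this listener lives under `org.apache.spark`). A minimal, standalone sketch of the same pattern without the private API — the class and variable names are illustrative, not part of the patch:

    import org.apache.spark.{SparkContext, TaskFailedReason}
    import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}

    // Sketch: on the first failed task, cancel all running jobs from a
    // daemon thread rather than from the listener-bus thread itself.
    class CancelOnTaskFailureListener extends SparkListener {
      private var cancelStarted = false

      override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
        taskEnd.reason match {
          case _: TaskFailedReason => this.synchronized {
            if (!cancelStarted) {
              cancelStarted = true
              val canceller = new Thread() {
                override def run(): Unit = SparkContext.getOrCreate().cancelAllJobs()
              }
              canceller.setDaemon(true) // never keep the JVM alive just for this
              canceller.start()
            }
          }
          case _ => // successful tasks need no handling
        }
      }
    }

Registered with `sc.addSparkListener(...)`, this keeps the application alive after a failed training attempt. The trade-off, taken up in patch 3, is that `cancelAllJobs()` is indiscriminate: on a shared SparkContext it also aborts jobs that have nothing to do with XGBoost.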
From f8345378be5e6cacb6ff0d2607f79f9140091f93 Mon Sep 17 00:00:00 2001
From: Bobby Wang
Date: Tue, 18 Aug 2020 15:23:57 +0800
Subject: [PATCH 2/5] add a parameter to control whether to kill the
 SparkContext

---
 .../dmlc/xgboost4j/scala/spark/XGBoost.scala  | 15 +++++--
 .../spark/params/LearningTaskParams.scala     | 10 ++++-
 .../spark/SparkParallelismTracker.scala       | 43 ++++++++++++++++---
 .../scala/spark/ParameterSuite.scala          | 14 +++---
 .../dmlc/xgboost4j/scala/spark/PerTest.scala  |  1 +
 .../spark/XGBoostRabitRegressionSuite.scala   | 35 ++++++++++++---
 .../spark/SparkParallelismTrackerSuite.scala  | 32 ++++++++++++++
 7 files changed, 127 insertions(+), 23 deletions(-)

diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
index 8773969c1f1d..8efbb56fcc60 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
@@ -76,7 +76,8 @@ private[this] case class XGBoostExecutionParams(
     earlyStoppingParams: XGBoostExecutionEarlyStoppingParams,
     cacheTrainingSet: Boolean,
     treeMethod: Option[String],
-    isLocal: Boolean) {
+    isLocal: Boolean,
+    killSparkContext: Boolean) {
 
   private var rawParamMap: Map[String, Any] = _
 
@@ -220,6 +221,9 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
     val cacheTrainingSet = overridedParams.getOrElse("cache_training_set", false)
       .asInstanceOf[Boolean]
 
+    val killSparkContext = overridedParams.getOrElse("kill_spark_context", true)
+      .asInstanceOf[Boolean]
+
     val xgbExecParam = XGBoostExecutionParams(nWorkers, round, useExternalMemory, obj, eval,
       missing, allowNonZeroForMissing, trackerConf,
       timeoutRequestWorkers,
@@ -228,7 +232,8 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
       xgbExecEarlyStoppingParams,
       cacheTrainingSet,
       treeMethod,
-      isLocal)
+      isLocal,
+      killSparkContext)
     xgbExecParam.setRawParamMap(overridedParams)
     xgbExecParam
   }
@@ -588,7 +593,8 @@ object XGBoost extends Serializable {
       val (booster, metrics) = try {
         val parallelismTracker = new SparkParallelismTracker(sc,
           xgbExecParams.timeoutRequestWorkers,
-          xgbExecParams.numWorkers)
+          xgbExecParams.numWorkers,
+          xgbExecParams.killSparkContext)
         val rabitEnv = tracker.getWorkerEnvs
         val boostersAndMetrics = if (hasGroup) {
           trainForRanking(transformedTrainingData.left.get, xgbExecParams, rabitEnv, prevBooster,
@@ -628,6 +634,9 @@ object XGBoost extends Serializable {
         case t: Throwable =>
           // if the job was aborted due to an exception, just throw the exception
           logger.error("the job was aborted due to ", t)
+          if (xgbExecParams.killSparkContext) {
+            trainingData.sparkContext.stop()
+          }
           throw t
       } finally {
         uncacheTrainingData(xgbExecParams.cacheTrainingSet, transformedTrainingData)
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala
index 1512c85d07a9..761c962b4654 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala
@@ -105,8 +105,14 @@ private[spark] trait LearningTaskParams extends Params {
 
   final def getMaximizeEvaluationMetrics: Boolean = $(maximizeEvaluationMetrics)
 
-  setDefault(objective -> "reg:squarederror", baseScore -> 0.5,
-    trainTestRatio -> 1.0, numEarlyStoppingRounds -> 0, cacheTrainingSet -> false)
+  /**
+   * whether killing SparkContext when training task fails
+   */
+  final val killSparkContext = new BooleanParam(this, "killSparkContext",
+    "whether killing SparkContext when training task fails")
+
+  setDefault(objective -> "reg:squarederror", baseScore -> 0.5, trainTestRatio -> 1.0,
+    numEarlyStoppingRounds -> 0, cacheTrainingSet -> false, killSparkContext -> true)
 }
 
 private[spark] object LearningTaskParams {
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/SparkParallelismTracker.scala b/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/SparkParallelismTracker.scala
index db0f66381c04..5baef44c34ac 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/SparkParallelismTracker.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/SparkParallelismTracker.scala
@@ -26,11 +26,13 @@ import org.apache.spark.scheduler._
  * @param sc The SparkContext object
  * @param timeout The maximum time to wait for enough number of workers.
  * @param numWorkers nWorkers used in an XGBoost Job
+ * @param killSparkContext kill SparkContext or not when task fails
  */
 class SparkParallelismTracker(
     val sc: SparkContext,
     timeout: Long,
-    numWorkers: Int) {
+    numWorkers: Int,
+    killSparkContext: Boolean = true) {
 
   private[this] val requestedCores = numWorkers * sc.conf.getInt("spark.task.cpus", 1)
   private[this] val logger = LogFactory.getLog("XGBoostSpark")
@@ -58,7 +60,7 @@ class SparkParallelismTracker(
   }
 
   private[this] def safeExecute[T](body: => T): T = {
-    val listener = new TaskFailedListener
+    val listener = new TaskFailedListener(killSparkContext)
     sc.addSparkListener(listener)
     try {
       body
@@ -90,7 +92,7 @@ class SparkParallelismTracker(
   }
 }
 
-private[spark] class TaskFailedListener extends SparkListener {
+class TaskFailedListener(killSparkContext: Boolean = true) extends SparkListener {
 
   private[this] val logger = LogFactory.getLog("XGBoostTaskFailedListener")
 
@@ -99,28 +101,57 @@ class TaskFailedListener(killSparkContext: Boolean = true) extends SparkListener
       case taskEndReason: TaskFailedReason =>
         logger.error(s"Training Task Failed during XGBoost Training: " +
           s"$taskEndReason, cancelling all jobs")
-        TaskFailedListener.cancelAllJobs()
+        taskFaultHandle
       case _ =>
     }
   }
+
+  private[this] def taskFaultHandle = {
+    if (killSparkContext == true) {
+      TaskFailedListener.startedSparkContextKiller()
+    } else {
+      TaskFailedListener.cancelAllJobs()
+    }
+  }
 }
 
 object TaskFailedListener {
-
+  private[this] val logger = LogFactory.getLog("XGBoostTaskFailedListener")
   var cancelJobStarted = false
 
   private def cancelAllJobs(): Unit = this.synchronized {
     if (!cancelJobStarted) {
+      cancelJobStarted = true
       val cancelJob = new Thread() {
         override def run(): Unit = {
           LiveListenerBus.withinListenerThread.withValue(false) {
+            logger.info("will call spark cancel all jobs")
             SparkContext.getOrCreate().cancelAllJobs()
+            logger.info("after call spark cancel all jobs")
           }
         }
       }
       cancelJob.setDaemon(true)
       cancelJob.start()
-      cancelJobStarted = true
     }
   }
+
+  var killerStarted = false
+
+  private def startedSparkContextKiller(): Unit = this.synchronized {
+    if (!killerStarted) {
+      // Spark does not allow ListenerThread to shutdown SparkContext so that we have to do it
+      // in a separate thread
+      val sparkContextKiller = new Thread() {
+        override def run(): Unit = {
+          LiveListenerBus.withinListenerThread.withValue(false) {
+            SparkContext.getOrCreate().stop()
+          }
+        }
+      }
+      sparkContextKiller.setDaemon(true)
+      sparkContextKiller.start()
+      killerStarted = true
+    }
+  }
 }
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ParameterSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ParameterSuite.scala
index 31f7a6b6f0f2..c802fc0b0448 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ParameterSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ParameterSuite.scala
@@ -40,13 +40,13 @@ class ParameterSuite extends FunSuite with PerTest with BeforeAndAfterAll {
     assert(xgbCopy2.MLlib2XGBoostParams("eval_metric").toString === "logloss")
   }
 
-  private def sparkContextShouldNotShutDown(): Unit = {
+  private def waitForSparkContextShutdown(): Unit = {
     var totalWaitedTime = 0L
-    while (!ss.sparkContext.isStopped && totalWaitedTime <= 10000) {
-      Thread.sleep(1000)
-      totalWaitedTime += 1000
+    while (!ss.sparkContext.isStopped && totalWaitedTime <= 120000) {
+      Thread.sleep(10000)
+      totalWaitedTime += 10000
     }
-    assert(ss.sparkContext.isStopped === false)
+    assert(ss.sparkContext.isStopped === true)
   }
 
   test("fail training elegantly with unsupported objective function") {
@@ -60,7 +60,7 @@ class ParameterSuite extends FunSuite with PerTest with BeforeAndAfterAll {
     } catch {
       case e: Throwable => // swallow anything
     } finally {
-      sparkContextShouldNotShutDown()
+      waitForSparkContextShutdown()
     }
   }
 
@@ -75,7 +75,7 @@ class ParameterSuite extends FunSuite with PerTest with BeforeAndAfterAll {
     } catch {
       case e: Throwable => // swallow anything
     } finally {
-      sparkContextShouldNotShutDown()
+      waitForSparkContextShutdown()
     }
   }
 }
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PerTest.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PerTest.scala
index 8b23a9faa26c..9154bc7c7055 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PerTest.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PerTest.scala
@@ -51,6 +51,7 @@ trait PerTest extends BeforeAndAfterEach { self: FunSuite =>
       cleanExternalCache(currentSession.sparkContext.appName)
       currentSession = null
     }
+    TaskFailedListener.killerStarted = false
     TaskFailedListener.cancelJobStarted = false
   }
 }
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRabitRegressionSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRabitRegressionSuite.scala
index 8f1dd54bd337..6609b04daab0 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRabitRegressionSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRabitRegressionSuite.scala
@@ -91,8 +91,8 @@ class XGBoostRabitRegressionSuite extends FunSuite with PerTest {
   }
 
   test("test rabit timeout fail handle") {
-    // disable job cancel listener to verify if rabit_timeout take effect and kill tasks
-    TaskFailedListener.cancelJobStarted = true
+    // disable spark kill listener to verify if rabit_timeout take effect and kill tasks
+    TaskFailedListener.killerStarted = true
 
     val training = buildDataFrame(Classification.train)
     // mock rank 0 failure during 8th allreduce synchronization
@@ -109,10 +109,35 @@ class XGBoostRabitRegressionSuite extends FunSuite with PerTest {
         "rabit_timeout" -> 0))
         .fit(training)
     } catch {
-      case e: Throwable => println("----- " + e)// swallow anything
+      case e: Throwable => // swallow anything
     } finally {
-      // wait 2s to check if SparkContext is killed
-      assert(waitAndCheckSparkShutdown(2000) == false)
+      // assume all tasks throw exception almost same time
+      // 100ms should be enough to exhaust all retries
+      assert(waitAndCheckSparkShutdown(100) == true)
+    }
+  }
+
+  test("test SparkContext should not be killed ") {
+    val training = buildDataFrame(Classification.train)
+    // mock rank 0 failure during 8th allreduce synchronization
+    Rabit.mockList = Array("0,8,0,0").toList.asJava
+
+    try {
+      new XGBoostClassifier(Map(
+        "eta" -> "0.1",
+        "max_depth" -> "10",
+        "verbosity" -> "1",
+        "objective" -> "binary:logistic",
+        "num_round" -> 5,
+        "num_workers" -> numWorkers,
+        "kill_spark_context" -> false,
+        "rabit_timeout" -> 0))
+        .fit(training)
+    } catch {
+      case e: Throwable => // swallow anything
+    } finally {
+      // wait 3s to check if SparkContext is killed
+      assert(waitAndCheckSparkShutdown(3000) == false)
     }
   }
 }
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/org/apache/spark/SparkParallelismTrackerSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/org/apache/spark/SparkParallelismTrackerSuite.scala
index 7f344674f12b..804740fcda39 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/org/apache/spark/SparkParallelismTrackerSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/org/apache/spark/SparkParallelismTrackerSuite.scala
@@ -34,6 +34,15 @@ class SparkParallelismTrackerSuite extends FunSuite with PerTest {
       .config("spark.driver.memory", "512m")
       .config("spark.task.cpus", 1)
 
+  private def waitAndCheckSparkShutdown(waitMiliSec: Int): Boolean = {
+    var totalWaitedTime = 0L
+    while (!ss.sparkContext.isStopped && totalWaitedTime <= waitMiliSec) {
+      Thread.sleep(100)
+      totalWaitedTime += 100
+    }
+    ss.sparkContext.isStopped
+  }
+
   test("tracker should not affect execution result when timeout is not larger than 0") {
     val nWorkers = numParallelism
     val rdd: RDD[Int] = sc.parallelize(1 to nWorkers)
@@ -74,4 +83,27 @@ class SparkParallelismTrackerSuite extends FunSuite with PerTest {
       }
     }
   }
+
+  test("tracker should not kill SparkContext when killSparkContext=false") {
+    val nWorkers = numParallelism
+    val rdd: RDD[Int] = sc.parallelize(1 to nWorkers, nWorkers)
+    val tracker = new SparkParallelismTracker(sc, 0, nWorkers, false)
+    try {
+      tracker.execute {
+        rdd.map { i =>
+          val partitionId = TaskContext.get().partitionId()
+          if (partitionId == 0) {
+            throw new RuntimeException("mocking task failing")
+          }
+          i
+        }.sum()
+      }
+    } catch {
+      case e: Exception => // catch the exception
+    } finally {
+      // wait 3s to check if SparkContext is killed
+      assert(waitAndCheckSparkShutdown(3000) == false)
+    }
+  }
+
 }
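With this patch the failure handling is selectable per training job through the new `kill_spark_context` parameter (default `true`, preserving the old behaviour). A usage sketch mirroring the new test above — `trainingDf` and the literal values are placeholders, not part of the patch:

    import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier

    // Keep the SparkContext alive if training fails; only the XGBoost
    // jobs get cancelled.
    val classifier = new XGBoostClassifier(Map(
      "objective" -> "binary:logistic",
      "num_round" -> 5,
      "num_workers" -> 2,
      "kill_spark_context" -> false))
    val model = classifier.fit(trainingDf) // trainingDf: a labeled DataFrame

Note that with the default `kill_spark_context -> true` the driver still calls `trainingData.sparkContext.stop()` on failure, so existing applications keep their semantics.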
From 8852f595f7d807143af0cc76c13d863f6cbdc18a Mon Sep 17 00:00:00 2001
From: Bobby Wang
Date: Fri, 21 Aug 2020 17:52:00 +0800
Subject: [PATCH 3/5] cancel the jobs that the failed task belongs to

---
 .../dmlc/xgboost4j/scala/spark/XGBoost.scala  |  2 +-
 .../spark/SparkParallelismTracker.scala       | 70 +++++++++++--------
 .../dmlc/xgboost4j/scala/spark/PerTest.scala  |  1 -
 .../spark/SparkParallelismTrackerSuite.scala  | 42 +++++++++++
 4 files changed, 82 insertions(+), 33 deletions(-)

diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
index 8efbb56fcc60..f7bf437d01d9 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
@@ -632,7 +632,7 @@ object XGBoost extends Serializable {
         (booster, metrics)
       } catch {
         case t: Throwable =>
-          // if the job was aborted due to an exception, just throw the exception
+          // if the job was aborted due to an exception
           logger.error("the job was aborted due to ", t)
           if (xgbExecParams.killSparkContext) {
             trainingData.sparkContext.stop()
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/SparkParallelismTracker.scala b/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/SparkParallelismTracker.scala
index 5baef44c34ac..94cfeaefe293 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/SparkParallelismTracker.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/SparkParallelismTracker.scala
@@ -19,6 +19,8 @@ package org.apache.spark
 import org.apache.commons.logging.LogFactory
 import org.apache.spark.scheduler._
 
+import scala.collection.mutable.{HashMap, HashSet}
+
 /**
  * A tracker that ensures enough number of executor cores are alive.
  * Throws an exception when the number of alive cores is less than nWorkers.
@@ -96,46 +98,53 @@ class TaskFailedListener(killSparkContext: Boolean = true) extends SparkListener
 
   private[this] val logger = LogFactory.getLog("XGBoostTaskFailedListener")
 
+  // {jobId, [stageId0, stageId1, ...]}
+  // keep track of the mapping of job id and stage ids
+  // when a task failed, find the job id and stage Id the task belongs to, finally
+  // cancel the jobs
+  private val jobIdToStageIds: HashMap[Int, HashSet[Int]] = HashMap.empty
+
+  override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
+    if (!killSparkContext) {
+      jobStart.stageIds.foreach(stageId => {
+        jobIdToStageIds.getOrElseUpdate(jobStart.jobId, new HashSet[Int]()) += stageId
+      })
+    }
+  }
+
+  override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = {
+    if (!killSparkContext) {
+      jobIdToStageIds.remove(jobEnd.jobId)
+    }
+  }
+
   override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
     taskEnd.reason match {
       case taskEndReason: TaskFailedReason =>
         logger.error(s"Training Task Failed during XGBoost Training: " +
-          s"$taskEndReason, cancelling all jobs")
-        taskFaultHandle
+          s"$taskEndReason")
+        if (killSparkContext) {
+          logger.error("killing SparkContext")
+          TaskFailedListener.startedSparkContextKiller()
+        } else {
+          val stageId = taskEnd.stageId
+          // find job ids according to stage id and then cancel the job
+          jobIdToStageIds.foreach(t => {
+            val jobId = t._1
+            val stageIds = t._2
+
+            if (stageIds.contains(stageId)) {
+              logger.error("Cancelling jobId:" + jobId)
+              SparkContext.getOrCreate().cancelJob(jobId)
+            }
+          })
+        }
       case _ =>
     }
   }
-
-  private[this] def taskFaultHandle = {
-    if (killSparkContext == true) {
-      TaskFailedListener.startedSparkContextKiller()
-    } else {
-      TaskFailedListener.cancelAllJobs()
-    }
-  }
 }
 
 object TaskFailedListener {
-  private[this] val logger = LogFactory.getLog("XGBoostTaskFailedListener")
-  var cancelJobStarted = false
-
-  private def cancelAllJobs(): Unit = this.synchronized {
-    if (!cancelJobStarted) {
-      cancelJobStarted = true
-      val cancelJob = new Thread() {
-        override def run(): Unit = {
-          LiveListenerBus.withinListenerThread.withValue(false) {
-            logger.info("will call spark cancel all jobs")
-            SparkContext.getOrCreate().cancelAllJobs()
-            logger.info("after call spark cancel all jobs")
-          }
-        }
-      }
-      cancelJob.setDaemon(true)
-      cancelJob.start()
-    }
-  }
-
   var killerStarted = false
 
   private def startedSparkContextKiller(): Unit = this.synchronized {
@@ -154,5 +163,4 @@ object TaskFailedListener {
       killerStarted = true
     }
   }
-
 }
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PerTest.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PerTest.scala
index 9154bc7c7055..341db97bc447 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PerTest.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PerTest.scala
@@ -52,7 +52,6 @@ trait PerTest extends BeforeAndAfterEach { self: FunSuite =>
       currentSession = null
     }
     TaskFailedListener.killerStarted = false
-    TaskFailedListener.cancelJobStarted = false
   }
 }
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/org/apache/spark/SparkParallelismTrackerSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/org/apache/spark/SparkParallelismTrackerSuite.scala
index 804740fcda39..dadadeb219f4 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/org/apache/spark/SparkParallelismTrackerSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/org/apache/spark/SparkParallelismTrackerSuite.scala
@@ -86,11 +86,36 @@ class SparkParallelismTrackerSuite extends FunSuite with PerTest {
 
   test("tracker should not kill SparkContext when killSparkContext=false") {
     val nWorkers = numParallelism
+    val tracker = new SparkParallelismTracker(sc, 0, nWorkers, false)
     val rdd: RDD[Int] = sc.parallelize(1 to nWorkers, nWorkers)
+    try {
+      tracker.execute {
+        rdd.map { i =>
+          val partitionId = TaskContext.get().partitionId()
+          if (partitionId == 0) {
+            throw new RuntimeException("mocking task failing")
+          }
+          i
+        }.sum()
+      }
+    } catch {
+      case e: Exception => // catch the exception
+    } finally {
+      // wait 3s to check if SparkContext is killed
+      assert(waitAndCheckSparkShutdown(3000) == false)
+    }
+  }
+
+  test("tracker should cancel correct job when killSparkContext=false") {
+    val nWorkers = 2
     val tracker = new SparkParallelismTracker(sc, 0, nWorkers, false)
+    val rdd: RDD[Int] = sc.parallelize(1 to 10, nWorkers)
+    val thread = new MyThread(sc)
+    thread.start()
     try {
       tracker.execute {
         rdd.map { i =>
+          Thread.sleep(100)
           val partitionId = TaskContext.get().partitionId()
           if (partitionId == 0) {
             throw new RuntimeException("mocking task failing")
@@ -101,9 +126,26 @@ class SparkParallelismTrackerSuite extends FunSuite with PerTest {
     } catch {
       case e: Exception => // catch the exception
     } finally {
+      thread.join(8000)
       // wait 3s to check if SparkContext is killed
       assert(waitAndCheckSparkShutdown(3000) == false)
     }
   }
+
+  private[this] class MyThread(sc: SparkContext) extends Thread {
+    override def run(): Unit = {
+      var sum: Double = 0.0f
+      try {
+        val rdd = sc.parallelize(1 to 4, 2)
+        sum = rdd.mapPartitions(iter => {
+          // sleep 2s to ensure task is alive when cancelling other jobs
+          Thread.sleep(2000)
+          iter
+        }).sum()
+      } finally {
+        // get the correct result
+        assert(sum.toInt == 10)
+      }
+    }
+  }
 }
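The substantive change in this patch is scoping: instead of `cancelAllJobs()`, the listener now records which stages belong to which job in `onJobStart`, forgets finished jobs in `onJobEnd`, and on a task failure cancels only the jobs owning the failed stage. The bookkeeping in isolation — a sketch with illustrative names, not the patch itself:

    import scala.collection.mutable

    // Maps jobId -> stageIds so that a failed stage can be traced back to
    // the job(s) that own it.
    class JobStageIndex {
      private val jobIdToStageIds = mutable.HashMap.empty[Int, mutable.HashSet[Int]]

      // mirror of onJobStart: remember every stage the job owns
      def jobStarted(jobId: Int, stageIds: Seq[Int]): Unit =
        stageIds.foreach { stageId =>
          jobIdToStageIds.getOrElseUpdate(jobId, mutable.HashSet.empty[Int]) += stageId
        }

      // mirror of onJobEnd: a finished job can no longer need cancelling
      def jobEnded(jobId: Int): Unit = jobIdToStageIds.remove(jobId)

      // mirror of the onTaskEnd failure path: which jobs own this stage?
      def ownersOf(failedStageId: Int): Seq[Int] =
        jobIdToStageIds.collect {
          case (jobId, stageIds) if stageIds.contains(failedStageId) => jobId
        }.toSeq
    }

Note the guard in the diff: the map is maintained only when `killSparkContext` is false, since the kill path never consults it.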
From f1df786f401551b524da71ea26136f57dcf611c7 Mon Sep 17 00:00:00 2001
From: Bobby Wang
Date: Wed, 26 Aug 2020 08:39:44 +0800
Subject: [PATCH 4/5] remove the jobId from the map when a job fails

---
 .../main/scala/org/apache/spark/SparkParallelismTracker.scala  | 1 +
 .../scala/org/apache/spark/SparkParallelismTrackerSuite.scala  | 2 +-
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/SparkParallelismTracker.scala b/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/SparkParallelismTracker.scala
index 94cfeaefe293..17ec6a9102e8 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/SparkParallelismTracker.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/SparkParallelismTracker.scala
@@ -135,6 +135,7 @@ class TaskFailedListener(killSparkContext: Boolean = true) extends SparkListener
 
             if (stageIds.contains(stageId)) {
               logger.error("Cancelling jobId:" + jobId)
+              jobIdToStageIds.remove(jobId)
               SparkContext.getOrCreate().cancelJob(jobId)
             }
           })
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/org/apache/spark/SparkParallelismTrackerSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/org/apache/spark/SparkParallelismTrackerSuite.scala
index dadadeb219f4..eab81f02a8d4 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/org/apache/spark/SparkParallelismTrackerSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/org/apache/spark/SparkParallelismTrackerSuite.scala
@@ -106,7 +106,7 @@ class SparkParallelismTrackerSuite extends FunSuite with PerTest {
     }
   }
 
-  test("tracker should cancel correct job when killSparkContext=false") {
+  test("tracker should cancel the correct job when killSparkContext=false") {
     val nWorkers = 2
     val tracker = new SparkParallelismTracker(sc, 0, nWorkers, false)
     val rdd: RDD[Int] = sc.parallelize(1 to 10, nWorkers)
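This one-line fix makes the cancellation idempotent per job: the entry is removed from the map before `cancelJob` is issued, so when several tasks of the same job fail close together, later failures find no entry and the job is cancelled only once. In terms of the bookkeeping sketch after patch 3 (again with illustrative names; `cancel` stands in for `SparkContext.cancelJob`):

    import scala.collection.mutable

    // Cancel each job owning the failed stage exactly once: drop the
    // bookkeeping entry first, so a later failed task from the same job
    // finds nothing left to cancel.
    def cancelOwners(index: mutable.HashMap[Int, mutable.HashSet[Int]],
                     failedStageId: Int,
                     cancel: Int => Unit): Unit = {
      val owners = index.collect {
        case (jobId, stageIds) if stageIds.contains(failedStageId) => jobId
      }.toSeq
      owners.foreach(index.remove) // remove first ...
      owners.foreach(cancel)       // ... then cancel once per job
    }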
From 2824d66d93ca668481e7b6e1d7a131e153e9c86c Mon Sep 17 00:00:00 2001
From: Bobby Wang
Date: Mon, 31 Aug 2020 11:48:25 +0800
Subject: [PATCH 5/5] address review comments

---
 .../dmlc/xgboost4j/scala/spark/XGBoost.scala  |  8 +++---
 .../spark/params/LearningTaskParams.scala     |  6 ++--
 .../spark/SparkParallelismTracker.scala       | 28 +++++++++----------
 .../spark/XGBoostRabitRegressionSuite.scala   |  2 +-
 .../spark/SparkParallelismTrackerSuite.scala  |  8 +++---
 5 files changed, 26 insertions(+), 26 deletions(-)

diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
index f7bf437d01d9..66cce0c32bfe 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
@@ -77,7 +77,7 @@ private[this] case class XGBoostExecutionParams(
     cacheTrainingSet: Boolean,
     treeMethod: Option[String],
     isLocal: Boolean,
-    killSparkContext: Boolean) {
+    killSparkContextOnWorkerFailure: Boolean) {
 
   private var rawParamMap: Map[String, Any] = _
 
@@ -221,7 +221,7 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
     val cacheTrainingSet = overridedParams.getOrElse("cache_training_set", false)
       .asInstanceOf[Boolean]
 
-    val killSparkContext = overridedParams.getOrElse("kill_spark_context", true)
+    val killSparkContext = overridedParams.getOrElse("kill_spark_context_on_worker_failure", true)
       .asInstanceOf[Boolean]
 
     val xgbExecParam = XGBoostExecutionParams(nWorkers, round, useExternalMemory, obj, eval,
@@ -594,7 +594,7 @@ object XGBoost extends Serializable {
         val parallelismTracker = new SparkParallelismTracker(sc,
           xgbExecParams.timeoutRequestWorkers,
           xgbExecParams.numWorkers,
-          xgbExecParams.killSparkContext)
+          xgbExecParams.killSparkContextOnWorkerFailure)
         val rabitEnv = tracker.getWorkerEnvs
         val boostersAndMetrics = if (hasGroup) {
           trainForRanking(transformedTrainingData.left.get, xgbExecParams, rabitEnv, prevBooster,
@@ -634,7 +634,7 @@ object XGBoost extends Serializable {
         case t: Throwable =>
           // if the job was aborted due to an exception
           logger.error("the job was aborted due to ", t)
-          if (xgbExecParams.killSparkContext) {
+          if (xgbExecParams.killSparkContextOnWorkerFailure) {
             trainingData.sparkContext.stop()
           }
           throw t
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala
index 761c962b4654..aba8d45f3c69 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala
@@ -108,11 +108,11 @@ private[spark] trait LearningTaskParams extends Params {
   /**
    * whether killing SparkContext when training task fails
    */
-  final val killSparkContext = new BooleanParam(this, "killSparkContext",
-    "whether killing SparkContext when training task fails")
+  final val killSparkContextOnWorkerFailure = new BooleanParam(this,
+    "killSparkContextOnWorkerFailure", "whether killing SparkContext when training task fails")
 
   setDefault(objective -> "reg:squarederror", baseScore -> 0.5, trainTestRatio -> 1.0,
-    numEarlyStoppingRounds -> 0, cacheTrainingSet -> false, killSparkContext -> true)
+    numEarlyStoppingRounds -> 0, cacheTrainingSet -> false, killSparkContextOnWorkerFailure -> true)
 }
 
 private[spark] object LearningTaskParams {
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/SparkParallelismTracker.scala b/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/SparkParallelismTracker.scala
index 17ec6a9102e8..3e514ebd885b 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/SparkParallelismTracker.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/SparkParallelismTracker.scala
@@ -28,13 +28,13 @@ import scala.collection.mutable.{HashMap, HashSet}
  * @param sc The SparkContext object
  * @param timeout The maximum time to wait for enough number of workers.
  * @param numWorkers nWorkers used in an XGBoost Job
- * @param killSparkContext kill SparkContext or not when task fails
+ * @param killSparkContextOnWorkerFailure kill SparkContext or not when task fails
  */
 class SparkParallelismTracker(
     val sc: SparkContext,
     timeout: Long,
     numWorkers: Int,
-    killSparkContext: Boolean = true) {
+    killSparkContextOnWorkerFailure: Boolean = true) {
 
   private[this] val requestedCores = numWorkers * sc.conf.getInt("spark.task.cpus", 1)
   private[this] val logger = LogFactory.getLog("XGBoostSpark")
@@ -62,7 +62,7 @@ class SparkParallelismTracker(
   }
 
   private[this] def safeExecute[T](body: => T): T = {
-    val listener = new TaskFailedListener(killSparkContext)
+    val listener = new TaskFailedListener(killSparkContextOnWorkerFailure)
     sc.addSparkListener(listener)
     try {
       body
@@ -100,7 +100,7 @@ class TaskFailedListener(killSparkContext: Boolean = true) extends SparkListener
 
   // {jobId, [stageId0, stageId1, ...]}
   // keep track of the mapping of job id and stage ids
-  // when a task failed, find the job id and stage Id the task belongs to, finally
+  // when a task fails, find the job id and stage id the task belongs to, finally
   // cancel the jobs
   private val jobIdToStageIds: HashMap[Int, HashSet[Int]] = HashMap.empty
 
@@ -129,16 +129,15 @@ class TaskFailedListener(killSparkContext: Boolean = true) extends SparkListener
         } else {
           val stageId = taskEnd.stageId
           // find job ids according to stage id and then cancel the job
-          jobIdToStageIds.foreach(t => {
-            val jobId = t._1
-            val stageIds = t._2
-
-            if (stageIds.contains(stageId)) {
-              logger.error("Cancelling jobId:" + jobId)
-              jobIdToStageIds.remove(jobId)
-              SparkContext.getOrCreate().cancelJob(jobId)
-            }
-          })
+
+          jobIdToStageIds.foreach {
+            case (jobId, stageIds) =>
+              if (stageIds.contains(stageId)) {
+                logger.error("Cancelling jobId:" + jobId)
+                jobIdToStageIds.remove(jobId)
+                SparkContext.getOrCreate().cancelJob(jobId)
+              }
+          }
         }
       case _ =>
     }
@@ -146,6 +145,7 @@ class TaskFailedListener(killSparkContext: Boolean = true) extends SparkListener
 }
 
 object TaskFailedListener {
+
   var killerStarted = false
 
   private def startedSparkContextKiller(): Unit = this.synchronized {
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRabitRegressionSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRabitRegressionSuite.scala
index 6609b04daab0..e1d58f26ed50 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRabitRegressionSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRabitRegressionSuite.scala
@@ -130,7 +130,7 @@ class XGBoostRabitRegressionSuite extends FunSuite with PerTest {
         "objective" -> "binary:logistic",
         "num_round" -> 5,
         "num_workers" -> numWorkers,
-        "kill_spark_context" -> false,
+        "kill_spark_context_on_worker_failure" -> false,
         "rabit_timeout" -> 0))
         .fit(training)
     } catch {
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/org/apache/spark/SparkParallelismTrackerSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/org/apache/spark/SparkParallelismTrackerSuite.scala
index eab81f02a8d4..cb8fa579476a 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/org/apache/spark/SparkParallelismTrackerSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/org/apache/spark/SparkParallelismTrackerSuite.scala
@@ -84,7 +84,7 @@ class SparkParallelismTrackerSuite extends FunSuite with PerTest {
     }
   }
 
-  test("tracker should not kill SparkContext when killSparkContext=false") {
+  test("tracker should not kill SparkContext when killSparkContextOnWorkerFailure=false") {
     val nWorkers = numParallelism
     val tracker = new SparkParallelismTracker(sc, 0, nWorkers, false)
     val rdd: RDD[Int] = sc.parallelize(1 to nWorkers, nWorkers)
@@ -106,11 +106,11 @@ class SparkParallelismTrackerSuite extends FunSuite with PerTest {
     }
   }
 
-  test("tracker should cancel the correct job when killSparkContext=false") {
+  test("tracker should cancel the correct job when killSparkContextOnWorkerFailure=false") {
     val nWorkers = 2
     val tracker = new SparkParallelismTracker(sc, 0, nWorkers, false)
     val rdd: RDD[Int] = sc.parallelize(1 to 10, nWorkers)
-    val thread = new MyThread(sc)
+    val thread = new TestThread(sc)
     thread.start()
     try {
       tracker.execute {
@@ -132,7 +132,7 @@ class SparkParallelismTrackerSuite extends FunSuite with PerTest {
     }
   }
 
-  private[this] class MyThread(sc: SparkContext) extends Thread {
+  private[this] class TestThread(sc: SparkContext) extends Thread {
     override def run(): Unit = {
       var sum: Double = 0.0f
       try {