From 9ab36f2e197a0bb89dac26ce258bd5cd5bb59a90 Mon Sep 17 00:00:00 2001
From: Bobby Wang
Date: Thu, 13 Aug 2020 11:36:54 +0800
Subject: [PATCH] cancel job instead of killing SparkContext

This PR changes the default failure-handling behavior, which used to kill
the SparkContext. Instead, it cancels all running jobs when a task fails,
which means the SparkContext stays alive even when exceptions happen
during training.
---
 .../dmlc/xgboost4j/scala/spark/XGBoost.scala  |  3 +--
 .../spark/SparkParallelismTracker.scala       | 25 +++++++++----------
 .../scala/spark/ParameterSuite.scala          | 14 +++++------
 .../dmlc/xgboost4j/scala/spark/PerTest.scala  |  2 +-
 .../spark/XGBoostRabitRegressionSuite.scala   | 11 ++++----
 5 files changed, 26 insertions(+), 29 deletions(-)

diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
index 8cfdf7dba792..8773969c1f1d 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
@@ -626,9 +626,8 @@ object XGBoost extends Serializable {
       (booster, metrics)
     } catch {
       case t: Throwable =>
-        // if the job was aborted due to an exception
+        // if the job was aborted due to an exception, just throw the exception
         logger.error("the job was aborted due to ", t)
-        trainingData.sparkContext.stop()
         throw t
     } finally {
       uncacheTrainingData(xgbExecParams.cacheTrainingSet, transformedTrainingData)
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/SparkParallelismTracker.scala b/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/SparkParallelismTracker.scala
index bd0ab4d6dd01..db0f66381c04 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/SparkParallelismTracker.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/SparkParallelismTracker.scala
@@ -79,7 +79,7 @@ class SparkParallelismTracker(
   def execute[T](body: => T): T = {
     if (timeout <= 0) {
       logger.info("starting training without setting timeout for waiting for resources")
-      body
+      safeExecute(body)
     } else {
       logger.info(s"starting training with timeout set as $timeout ms for waiting for resources")
       if (!waitForCondition(numAliveCores >= requestedCores, timeout)) {
@@ -98,8 +98,8 @@ private[spark] class TaskFailedListener extends SparkListener {
     taskEnd.reason match {
       case taskEndReason: TaskFailedReason =>
         logger.error(s"Training Task Failed during XGBoost Training: " +
-          s"$taskEndReason, stopping SparkContext")
-        TaskFailedListener.startedSparkContextKiller()
+          s"$taskEndReason, cancelling all jobs")
+        TaskFailedListener.cancelAllJobs()
       case _ =>
     }
   }
@@ -107,22 +107,21 @@

 object TaskFailedListener {

-  var killerStarted = false
+  var cancelJobStarted = false

-  private def startedSparkContextKiller(): Unit = this.synchronized {
-    if (!killerStarted) {
-      // Spark does not allow ListenerThread to shutdown SparkContext so that we have to do it
-      // in a separate thread
-      val sparkContextKiller = new Thread() {
+  private def cancelAllJobs(): Unit = this.synchronized {
+    if (!cancelJobStarted) {
+      val cancelJob = new Thread() {
         override def run(): Unit = {
           LiveListenerBus.withinListenerThread.withValue(false) {
-            SparkContext.getOrCreate().stop()
+            SparkContext.getOrCreate().cancelAllJobs()
           }
         }
       }
-      sparkContextKiller.setDaemon(true)
-      sparkContextKiller.start()
-      killerStarted = true
+      cancelJob.setDaemon(true)
+      cancelJob.start()
+      cancelJobStarted = true
     }
   }
+
 }
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ParameterSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ParameterSuite.scala
index c802fc0b0448..31f7a6b6f0f2 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ParameterSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ParameterSuite.scala
@@ -40,13 +40,13 @@ class ParameterSuite extends FunSuite with PerTest with BeforeAndAfterAll {
     assert(xgbCopy2.MLlib2XGBoostParams("eval_metric").toString === "logloss")
   }

-  private def waitForSparkContextShutdown(): Unit = {
+  private def sparkContextShouldNotShutDown(): Unit = {
     var totalWaitedTime = 0L
-    while (!ss.sparkContext.isStopped && totalWaitedTime <= 120000) {
-      Thread.sleep(10000)
-      totalWaitedTime += 10000
+    while (!ss.sparkContext.isStopped && totalWaitedTime <= 10000) {
+      Thread.sleep(1000)
+      totalWaitedTime += 1000
     }
-    assert(ss.sparkContext.isStopped === true)
+    assert(ss.sparkContext.isStopped === false)
   }

   test("fail training elegantly with unsupported objective function") {
@@ -60,7 +60,7 @@ class ParameterSuite extends FunSuite with PerTest with BeforeAndAfterAll {
     } catch {
       case e: Throwable => // swallow anything
     } finally {
-      waitForSparkContextShutdown()
+      sparkContextShouldNotShutDown()
     }
   }

@@ -75,7 +75,7 @@ class ParameterSuite extends FunSuite with PerTest with BeforeAndAfterAll {
     } catch {
       case e: Throwable => // swallow anything
     } finally {
-      waitForSparkContextShutdown()
+      sparkContextShouldNotShutDown()
     }
   }
 }
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PerTest.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PerTest.scala
index 341db97bc447..8b23a9faa26c 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PerTest.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PerTest.scala
@@ -51,7 +51,7 @@ trait PerTest extends BeforeAndAfterEach { self: FunSuite =>
       cleanExternalCache(currentSession.sparkContext.appName)
       currentSession = null
     }
-    TaskFailedListener.killerStarted = false
+    TaskFailedListener.cancelJobStarted = false
   }
 }

diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRabitRegressionSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRabitRegressionSuite.scala
index d1b6ec0f9acc..8f1dd54bd337 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRabitRegressionSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRabitRegressionSuite.scala
@@ -91,8 +91,8 @@ class XGBoostRabitRegressionSuite extends FunSuite with PerTest {
   }

   test("test rabit timeout fail handle") {
-    // disable spark kill listener to verify if rabit_timeout take effect and kill tasks
-    TaskFailedListener.killerStarted = true
+    // disable the job-cancelling listener to verify that rabit_timeout takes effect and kills tasks
+    TaskFailedListener.cancelJobStarted = true

     val training = buildDataFrame(Classification.train)
     // mock rank 0 failure during 8th allreduce synchronization
@@ -109,11 +109,10 @@
           "rabit_timeout" -> 0))
         .fit(training)
     } catch {
-      case e: Throwable => // swallow anything
+      case e: Throwable => println("----- " + e) // swallow anything
     } finally {
-      // assume all tasks throw exception almost same time
-      // 100ms should be enough to exhaust all retries
-      assert(waitAndCheckSparkShutdown(100) == true)
+      // wait up to 2s and verify the SparkContext was NOT shut down
+      assert(waitAndCheckSparkShutdown(2000) == false)
     }
   }
 }
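
Reviewer note: the core pattern of this change, reacting to a task failure by
cancelling jobs instead of stopping the SparkContext, can be sketched with
Spark's public listener API alone. The snippet below is illustrative and not
part of the patch: DemoFailureListener and CancelDemo are made-up names, and it
deliberately skips the private[spark] LiveListenerBus.withinListenerThread
handling that the real TaskFailedListener needs because it lives inside the
org.apache.spark package and runs on the listener bus.

import org.apache.spark.{SparkConf, SparkContext, TaskFailedReason}
import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}

// Cancels all jobs (at most once) when any task fails, leaving the
// SparkContext alive for subsequent jobs.
class DemoFailureListener(sc: SparkContext) extends SparkListener {
  @volatile private var cancelStarted = false

  override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
    taskEnd.reason match {
      case _: TaskFailedReason if !cancelStarted =>
        cancelStarted = true
        // cancel from a separate daemon thread so the listener callback
        // itself never blocks the listener bus
        val canceller = new Thread() {
          override def run(): Unit = sc.cancelAllJobs()
        }
        canceller.setDaemon(true)
        canceller.start()
      case _ =>
    }
  }
}

object CancelDemo {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("cancel-demo"))
    sc.addSparkListener(new DemoFailureListener(sc))
    try {
      // this job fails deterministically and is cancelled by the listener...
      sc.parallelize(1 to 10, 2).map { i =>
        if (i == 5) throw new RuntimeException("boom")
        i
      }.count()
    } catch {
      case e: Throwable => println(s"job failed as expected: $e")
    }
    // ...but the context survives the failure and can still run jobs,
    // which is exactly what the updated tests above assert
    assert(!sc.isStopped)
    println(sc.parallelize(1 to 4).sum())
    sc.stop()
  }
}

The design point is that cancelAllJobs() only aborts the running XGBoost jobs,
so a shared SparkContext (e.g. in a notebook or multi-tenant application) is
not torn down by a single failed training attempt.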