diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala index 8cfdf7dba792..8773969c1f1d 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala @@ -626,9 +626,8 @@ object XGBoost extends Serializable { (booster, metrics) } catch { case t: Throwable => - // if the job was aborted due to an exception + // if the job was aborted due to an exception, just throw the exception logger.error("the job was aborted due to ", t) - trainingData.sparkContext.stop() throw t } finally { uncacheTrainingData(xgbExecParams.cacheTrainingSet, transformedTrainingData) diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/SparkParallelismTracker.scala b/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/SparkParallelismTracker.scala index bd0ab4d6dd01..db0f66381c04 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/SparkParallelismTracker.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/SparkParallelismTracker.scala @@ -79,7 +79,7 @@ class SparkParallelismTracker( def execute[T](body: => T): T = { if (timeout <= 0) { logger.info("starting training without setting timeout for waiting for resources") - body + safeExecute(body) } else { logger.info(s"starting training with timeout set as $timeout ms for waiting for resources") if (!waitForCondition(numAliveCores >= requestedCores, timeout)) { @@ -98,8 +98,8 @@ private[spark] class TaskFailedListener extends SparkListener { taskEnd.reason match { case taskEndReason: TaskFailedReason => logger.error(s"Training Task Failed during XGBoost Training: " + - s"$taskEndReason, stopping SparkContext") - TaskFailedListener.startedSparkContextKiller() + s"$taskEndReason, cancelling all jobs") + 
TaskFailedListener.cancelAllJobs() case _ => } } @@ -107,22 +107,21 @@ private[spark] class TaskFailedListener extends SparkListener { object TaskFailedListener { - var killerStarted = false + var cancelJobStarted = false - private def startedSparkContextKiller(): Unit = this.synchronized { - if (!killerStarted) { - // Spark does not allow ListenerThread to shutdown SparkContext so that we have to do it - // in a separate thread - val sparkContextKiller = new Thread() { + private def cancelAllJobs(): Unit = this.synchronized { + if (!cancelJobStarted) { + val cancelJob = new Thread() { override def run(): Unit = { LiveListenerBus.withinListenerThread.withValue(false) { - SparkContext.getOrCreate().stop() + SparkContext.getOrCreate().cancelAllJobs() } } } - sparkContextKiller.setDaemon(true) - sparkContextKiller.start() - killerStarted = true + cancelJob.setDaemon(true) + cancelJob.start() + cancelJobStarted = true } } + } diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ParameterSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ParameterSuite.scala index c802fc0b0448..31f7a6b6f0f2 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ParameterSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ParameterSuite.scala @@ -40,13 +40,13 @@ class ParameterSuite extends FunSuite with PerTest with BeforeAndAfterAll { assert(xgbCopy2.MLlib2XGBoostParams("eval_metric").toString === "logloss") } - private def waitForSparkContextShutdown(): Unit = { + private def sparkContextShouldNotShutDown(): Unit = { var totalWaitedTime = 0L - while (!ss.sparkContext.isStopped && totalWaitedTime <= 120000) { - Thread.sleep(10000) - totalWaitedTime += 10000 + while (!ss.sparkContext.isStopped && totalWaitedTime <= 10000) { + Thread.sleep(1000) + totalWaitedTime += 1000 } - assert(ss.sparkContext.isStopped === true) + 
assert(ss.sparkContext.isStopped === false) } test("fail training elegantly with unsupported objective function") { @@ -60,7 +60,7 @@ class ParameterSuite extends FunSuite with PerTest with BeforeAndAfterAll { } catch { case e: Throwable => // swallow anything } finally { - waitForSparkContextShutdown() + sparkContextShouldNotShutDown() } } @@ -75,7 +75,7 @@ class ParameterSuite extends FunSuite with PerTest with BeforeAndAfterAll { } catch { case e: Throwable => // swallow anything } finally { - waitForSparkContextShutdown() + sparkContextShouldNotShutDown() } } } diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PerTest.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PerTest.scala index 341db97bc447..8b23a9faa26c 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PerTest.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PerTest.scala @@ -51,7 +51,7 @@ trait PerTest extends BeforeAndAfterEach { self: FunSuite => cleanExternalCache(currentSession.sparkContext.appName) currentSession = null } - TaskFailedListener.killerStarted = false + TaskFailedListener.cancelJobStarted = false } } diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRabitRegressionSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRabitRegressionSuite.scala index d1b6ec0f9acc..8f1dd54bd337 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRabitRegressionSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRabitRegressionSuite.scala @@ -91,8 +91,8 @@ class XGBoostRabitRegressionSuite extends FunSuite with PerTest { } test("test rabit timeout fail handle") { - // disable spark kill listener to verify if rabit_timeout take effect and kill tasks - TaskFailedListener.killerStarted = true + 
// disable job cancel listener to verify if rabit_timeout takes effect and kill tasks + TaskFailedListener.cancelJobStarted = true val training = buildDataFrame(Classification.train) // mock rank 0 failure during 8th allreduce synchronization @@ -109,11 +109,10 @@ class XGBoostRabitRegressionSuite extends FunSuite with PerTest { "rabit_timeout" -> 0)) .fit(training) } catch { - case e: Throwable => // swallow anything + case e: Throwable => println("----- " + e)// swallow anything } finally { - // assume all tasks throw exception almost same time - // 100ms should be enough to exhaust all retries - assert(waitAndCheckSparkShutdown(100) == true) + // wait 2s to check that SparkContext is not killed + assert(waitAndCheckSparkShutdown(2000) == false) } } }