diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
index 8efbb56fcc60..f7bf437d01d9 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
@@ -632,7 +632,7 @@ object XGBoost extends Serializable {
       (booster, metrics)
     } catch {
       case t: Throwable =>
-        // if the job was aborted due to an exception, just throw the exception
+        // if the job was aborted due to an exception
         logger.error("the job was aborted due to ", t)
         if (xgbExecParams.killSparkContext) {
           trainingData.sparkContext.stop()
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/SparkParallelismTracker.scala b/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/SparkParallelismTracker.scala
index 5baef44c34ac..94cfeaefe293 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/SparkParallelismTracker.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/SparkParallelismTracker.scala
@@ -19,6 +19,8 @@ package org.apache.spark
 import org.apache.commons.logging.LogFactory
 import org.apache.spark.scheduler._
 
+import scala.collection.mutable.{HashMap, HashSet}
+
 /**
  * A tracker that ensures enough number of executor cores are alive.
  * Throws an exception when the number of alive cores is less than nWorkers.
@@ -96,46 +98,53 @@ class TaskFailedListener(killSparkContext: Boolean = true) extends SparkListener
 
   private[this] val logger = LogFactory.getLog("XGBoostTaskFailedListener")
 
-  override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
-    taskEnd.reason match {
-      case taskEndReason: TaskFailedReason =>
-        logger.error(s"Training Task Failed during XGBoost Training: " +
-          s"$taskEndReason, cancelling all jobs")
-        taskFaultHandle
-      case _ =>
+  // {jobId, [stageId0, stageId1, ...]}
+  // keep track of the mapping from job id to stage ids
+  // when a task fails, find the job id and stage ids the task belongs to, and finally
+  // cancel those jobs
+  private val jobIdToStageIds: HashMap[Int, HashSet[Int]] = HashMap.empty
+
+  override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
+    if (!killSparkContext) {
+      jobStart.stageIds.foreach(stageId => {
+        jobIdToStageIds.getOrElseUpdate(jobStart.jobId, new HashSet[Int]()) += stageId
+      })
     }
   }
 
-  private[this] def taskFaultHandle = {
-    if (killSparkContext == true) {
-      TaskFailedListener.startedSparkContextKiller()
-    } else {
-      TaskFailedListener.cancelAllJobs()
+  override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = {
+    if (!killSparkContext) {
+      jobIdToStageIds.remove(jobEnd.jobId)
     }
   }
-}
-
-object TaskFailedListener {
-  private[this] val logger = LogFactory.getLog("XGBoostTaskFailedListener")
-  var cancelJobStarted = false
 
-  private def cancelAllJobs(): Unit = this.synchronized {
-    if (!cancelJobStarted) {
-      cancelJobStarted = true
-      val cancelJob = new Thread() {
-        override def run(): Unit = {
-          LiveListenerBus.withinListenerThread.withValue(false) {
-            logger.info("will call spark cancel all jobs")
-            SparkContext.getOrCreate().cancelAllJobs()
-            logger.info("after call spark cancel all jobs")
-          }
+  override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
+    taskEnd.reason match {
+      case taskEndReason: TaskFailedReason =>
+        logger.error(s"Training Task Failed during XGBoost Training: " +
+          s"$taskEndReason")
+        if (killSparkContext) {
+          logger.error("killing SparkContext")
+          TaskFailedListener.startedSparkContextKiller()
+        } else {
+          val stageId = taskEnd.stageId
+          // find the job ids that own this stage id and then cancel those jobs
+          jobIdToStageIds.foreach(t => {
+            val jobId = t._1
+            val stageIds = t._2
+
+            if (stageIds.contains(stageId)) {
+              logger.error("Cancelling jobId:" + jobId)
+              SparkContext.getOrCreate().cancelJob(jobId)
+            }
+          })
         }
-      }
-      cancelJob.setDaemon(true)
-      cancelJob.start()
+      case _ =>
     }
   }
+}
 
+object TaskFailedListener {
   var killerStarted = false
 
   private def startedSparkContextKiller(): Unit = this.synchronized {
@@ -154,5 +163,4 @@ object TaskFailedListener {
       killerStarted = true
     }
   }
-
 }
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PerTest.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PerTest.scala
index 9154bc7c7055..341db97bc447 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PerTest.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PerTest.scala
@@ -52,7 +52,6 @@ trait PerTest extends BeforeAndAfterEach { self: FunSuite =>
       currentSession = null
     }
     TaskFailedListener.killerStarted = false
-    TaskFailedListener.cancelJobStarted = false
   }
 }
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/org/apache/spark/SparkParallelismTrackerSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/org/apache/spark/SparkParallelismTrackerSuite.scala
index 804740fcda39..dadadeb219f4 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/org/apache/spark/SparkParallelismTrackerSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/org/apache/spark/SparkParallelismTrackerSuite.scala
@@ -86,11 +86,36 @@ class SparkParallelismTrackerSuite extends FunSuite with PerTest {
 
   test("tracker should not kill SparkContext when killSparkContext=false") {
     val nWorkers = numParallelism
+    val tracker = new SparkParallelismTracker(sc, 0, nWorkers, false)
     val rdd: RDD[Int] = sc.parallelize(1 to nWorkers, nWorkers)
+    try {
+      tracker.execute {
+        rdd.map { i =>
+          val partitionId = TaskContext.get().partitionId()
+          if (partitionId == 0) {
+            throw new RuntimeException("mocking task failing")
+          }
+          i
+        }.sum()
+      }
+    } catch {
+      case e: Exception => // catch the exception
+    } finally {
+      // wait 3s to check if SparkContext is killed
+      assert(waitAndCheckSparkShutdown(3000) == false)
+    }
+  }
+
+  test("tracker should cancel correct job when killSparkContext=false") {
+    val nWorkers = 2
     val tracker = new SparkParallelismTracker(sc, 0, nWorkers, false)
+    val rdd: RDD[Int] = sc.parallelize(1 to 10, nWorkers)
+    val thread = new MyThread(sc)
+    thread.start()
     try {
       tracker.execute {
         rdd.map { i =>
+          Thread.sleep(100)
           val partitionId = TaskContext.get().partitionId()
           if (partitionId == 0) {
             throw new RuntimeException("mocking task failing")
@@ -101,9 +126,26 @@ class SparkParallelismTrackerSuite extends FunSuite with PerTest {
     } catch {
       case e: Exception => // catch the exception
     } finally {
+      thread.join(8000)
       // wait 3s to check if SparkContext is killed
       assert(waitAndCheckSparkShutdown(3000) == false)
     }
   }
 
+  private[this] class MyThread(sc: SparkContext) extends Thread {
+    override def run(): Unit = {
+      var sum: Double = 0.0
+      try {
+        val rdd = sc.parallelize(1 to 4, 2)
+        sum = rdd.mapPartitions(iter => {
+          // sleep 2s to ensure the tasks are still alive when the failing job is cancelled
+          Thread.sleep(2000)
+          iter
+        }).sum()
+      } finally {
+        // the job that was not cancelled should still get the correct result
+        assert(sum.toInt == 10)
+      }
+    }
+  }
 }
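
For reference (not part of the patch): a minimal, standalone sketch of the job-to-stage bookkeeping the new TaskFailedListener relies on, using only the public Spark listener API. The class name JobCancellingListener and the registration call at the bottom are illustrative, not taken from the codebase.

import scala.collection.mutable.{HashMap, HashSet}

import org.apache.spark.{SparkContext, TaskFailedReason}
import org.apache.spark.scheduler._

// Illustrative sketch: remember which stages belong to which job, and on a task
// failure cancel only the job(s) owning the failed stage instead of stopping
// the whole SparkContext.
class JobCancellingListener(sc: SparkContext) extends SparkListener {
  private val jobIdToStageIds = HashMap.empty[Int, HashSet[Int]]

  override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
    // record every stage id submitted by this job
    jobStart.stageIds.foreach { stageId =>
      jobIdToStageIds.getOrElseUpdate(jobStart.jobId, HashSet.empty[Int]) += stageId
    }
  }

  override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = {
    // the job finished, drop its bookkeeping
    jobIdToStageIds.remove(jobEnd.jobId)
  }

  override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = taskEnd.reason match {
    case _: TaskFailedReason =>
      // cancel only the job(s) whose stages contain the failed task's stage
      jobIdToStageIds
        .collect { case (jobId, stageIds) if stageIds.contains(taskEnd.stageId) => jobId }
        .foreach(jobId => sc.cancelJob(jobId))
    case _ => // nothing to do for successful tasks
  }
}

// Hypothetical usage:
//   sc.addSparkListener(new JobCancellingListener(sc))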