cancel the jobs the failed task belongs to
wbo4958 committed Aug 26, 2020
1 parent f834537 commit 8852f59
Showing 4 changed files with 82 additions and 33 deletions.
@@ -632,7 +632,7 @@ object XGBoost extends Serializable {
(booster, metrics)
} catch {
case t: Throwable =>
// if the job was aborted due to an exception, just throw the exception
// if the job was aborted due to an exception
logger.error("the job was aborted due to ", t)
if (xgbExecParams.killSparkContext) {
trainingData.sparkContext.stop()

@@ -19,6 +19,8 @@ package org.apache.spark
import org.apache.commons.logging.LogFactory
import org.apache.spark.scheduler._

import scala.collection.mutable.{HashMap, HashSet}

/**
* A tracker that ensures a sufficient number of executor cores are alive.
* Throws an exception when the number of alive cores is less than nWorkers.
@@ -96,46 +98,53 @@ class TaskFailedListener(killSparkContext: Boolean = true) extends SparkListener

private[this] val logger = LogFactory.getLog("XGBoostTaskFailedListener")

override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
taskEnd.reason match {
case taskEndReason: TaskFailedReason =>
logger.error(s"Training Task Failed during XGBoost Training: " +
s"$taskEndReason, cancelling all jobs")
taskFaultHandle
case _ =>
// {jobId, [stageId0, stageId1, ...]}
// Tracks which stage ids belong to each job id: when a task fails, look up
// the job that owns the task's stage and cancel that job.
private val jobIdToStageIds: HashMap[Int, HashSet[Int]] = HashMap.empty

override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
if (!killSparkContext) {
jobStart.stageIds.foreach(stageId => {
jobIdToStageIds.getOrElseUpdate(jobStart.jobId, new HashSet[Int]()) += stageId
})
}
}

private[this] def taskFaultHandle = {
if (killSparkContext == true) {
TaskFailedListener.startedSparkContextKiller()
} else {
TaskFailedListener.cancelAllJobs()
override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = {
if (!killSparkContext) {
jobIdToStageIds.remove(jobEnd.jobId)
}
}
}

object TaskFailedListener {
private[this] val logger = LogFactory.getLog("XGBoostTaskFailedListener")
var cancelJobStarted = false

private def cancelAllJobs(): Unit = this.synchronized {
if (!cancelJobStarted) {
cancelJobStarted = true
val cancelJob = new Thread() {
override def run(): Unit = {
LiveListenerBus.withinListenerThread.withValue(false) {
logger.info("will call spark cancel all jobs")
SparkContext.getOrCreate().cancelAllJobs()
logger.info("after call spark cancel all jobs")
}
override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
taskEnd.reason match {
case taskEndReason: TaskFailedReason =>
logger.error(s"Training Task Failed during XGBoost Training: " +
s"$taskEndReason")
if (killSparkContext) {
logger.error("killing SparkContext")
TaskFailedListener.startedSparkContextKiller()
} else {
val stageId = taskEnd.stageId
// find the job that owns this stage and cancel it
jobIdToStageIds.foreach(t => {
val jobId = t._1
val stageIds = t._2

if (stageIds.contains(stageId)) {
logger.error("Cancelling jobId:" + jobId)
SparkContext.getOrCreate().cancelJob(jobId)
}
})
}
}
cancelJob.setDaemon(true)
cancelJob.start()
case _ =>
}
}
}

object TaskFailedListener {
var killerStarted = false

private def startedSparkContextKiller(): Unit = this.synchronized {
@@ -154,5 +163,4 @@ object TaskFailedListener {
killerStarted = true
}
}

}
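
The diff never shows where this listener is attached; in xgboost4j-spark it appears to be registered around the training body by the SparkParallelismTracker defined earlier in this same file. Below is a minimal sketch of driving the new "cancel only the owning job" path directly. It is illustrative only: the ListenerDemo name is made up, and it assumes TaskFailedListener remains publicly constructible.

import org.apache.spark.{SparkConf, SparkContext, TaskContext, TaskFailedListener}

object ListenerDemo {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setMaster("local[2]").setAppName("listener-demo"))
    // killSparkContext = false selects the new per-job cancellation path
    sc.addSparkListener(new TaskFailedListener(killSparkContext = false))
    try {
      sc.parallelize(1 to 2, 2).map { i =>
        if (TaskContext.get().partitionId() == 0) {
          throw new RuntimeException("mock task failure")
        }
        i
      }.sum()
    } catch {
      case _: Exception => // the cancelled job surfaces its failure here
    }
    assert(!sc.isStopped) // only the job died; the context stays alive
    sc.stop()
  }
}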

@@ -52,7 +52,6 @@ trait PerTest extends BeforeAndAfterEach { self: FunSuite =>
currentSession = null
}
TaskFailedListener.killerStarted = false
TaskFailedListener.cancelJobStarted = false
}
}

@@ -86,11 +86,36 @@ class SparkParallelismTrackerSuite extends FunSuite with PerTest {

test("tracker should not kill SparkContext when killSparkContext=false") {
val nWorkers = numParallelism
val tracker = new SparkParallelismTracker(sc, 0, nWorkers, false)
val rdd: RDD[Int] = sc.parallelize(1 to nWorkers, nWorkers)
try {
tracker.execute {
rdd.map { i =>
val partitionId = TaskContext.get().partitionId()
if (partitionId == 0) {
throw new RuntimeException("mocking task failing")
}
i
}.sum()
}
} catch {
case e: Exception => // expected: the mocked task failure surfaces here
} finally {
// wait up to 3s and verify that the SparkContext was NOT shut down
assert(waitAndCheckSparkShutdown(3000) == false)
}
}
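
waitAndCheckSparkShutdown, used above, is a helper defined elsewhere in this suite and not shown in the diff; presumably it polls the SparkContext until it stops or the timeout elapses. A sketch of that idea (the body below is an assumption, not the committed code):

// Hypothetical reconstruction: returns true iff the SparkContext
// was stopped within `waitMillis` milliseconds.
private def waitAndCheckSparkShutdown(waitMillis: Int): Boolean = {
  var waited = 0
  while (!sc.isStopped && waited < waitMillis) {
    Thread.sleep(100)
    waited += 100
  }
  sc.isStopped
}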

test("tracker should cancel correct job when killSparkContext=false") {
val nWorkers = 2
val tracker = new SparkParallelismTracker(sc, 0, nWorkers, false)
val rdd: RDD[Int] = sc.parallelize(1 to 10, nWorkers)
val thread = new MyThread(sc)
thread.start()
try {
tracker.execute {
rdd.map { i =>
Thread.sleep(100)
val partitionId = TaskContext.get().partitionId()
if (partitionId == 0) {
throw new RuntimeException("mocking task failing")
@@ -101,9 +126,26 @@ class SparkParallelismTrackerSuite extends FunSuite with PerTest {
} catch {
case e: Exception => // expected: the mocked task failure surfaces here
} finally {
thread.join(8000)
// wait up to 3s and verify that the SparkContext was NOT shut down
assert(waitAndCheckSparkShutdown(3000) == false)
}
}

private[this] class MyThread(sc: SparkContext) extends Thread {
override def run(): Unit = {
var sum: Double = 0.0
try {
val rdd = sc.parallelize(1 to 4, 2)
sum = rdd.mapPartitions(iter => {
// sleep 2s so these tasks are still running when the failed job is cancelled
Thread.sleep(2000)
iter
}).sum()
} finally {
// this unrelated job must not be cancelled: it should finish with the correct result
assert(sum.toInt == 10)
}
}
}
}
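
Taken together, the two new tests pin down the behavioral change in this commit. In SparkContext API terms the contrast is simply (comment fragment, not runnable code; jobId stands for an id recovered from jobIdToStageIds):

// Before this commit, the killSparkContext = false path called:
//   SparkContext.getOrCreate().cancelAllJobs()   // every running job died
// After this commit it calls:
//   SparkContext.getOrCreate().cancelJob(jobId)  // only the owning job dies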
