cancel job instead of killing SparkContext
This PR changes the default behavior of killing the SparkContext when a task fails.
Instead, it cancels the running jobs, so the SparkContext stays alive even when
exceptions occur during training.
wbo4958 committed Aug 18, 2020
1 parent a418278 commit 9ab36f2
Showing 5 changed files with 26 additions and 29 deletions.
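
For context: SparkContext.cancelAllJobs() aborts jobs that are currently scheduled or running but leaves the context usable for later jobs, while SparkContext.stop() tears the whole context down. A minimal sketch of the behavior this change enables (hypothetical driver code, not part of this commit; the object name and the simulated failure are illustrative):

import org.apache.spark.sql.SparkSession

object RecoverAfterFailureDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]").appName("cancel-vs-stop-demo").getOrCreate()
    val sc = spark.sparkContext

    try {
      // A job whose task throws, failing the job and surfacing a SparkException.
      sc.parallelize(1 to 10).map { i =>
        if (i == 5) throw new RuntimeException("simulated task failure")
        i
      }.collect()
    } catch {
      case t: Throwable => println(s"job failed as expected: ${t.getMessage}")
    }

    // Because only the jobs were cancelled, not the context, follow-up jobs still run.
    println(s"follow-up job result: ${sc.parallelize(1 to 100).sum()}")
    spark.stop()
  }
}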
@@ -626,9 +626,8 @@ object XGBoost extends Serializable {
       (booster, metrics)
     } catch {
       case t: Throwable =>
-        // if the job was aborted due to an exception
+        // if the job was aborted due to an exception, just throw the exception
         logger.error("the job was aborted due to ", t)
-        trainingData.sparkContext.stop()
         throw t
     } finally {
       uncacheTrainingData(xgbExecParams.cacheTrainingSet, transformedTrainingData)
@@ -79,7 +79,7 @@ class SparkParallelismTracker(
   def execute[T](body: => T): T = {
     if (timeout <= 0) {
       logger.info("starting training without setting timeout for waiting for resources")
-      body
+      safeExecute(body)
     } else {
       logger.info(s"starting training with timeout set as $timeout ms for waiting for resources")
       if (!waitForCondition(numAliveCores >= requestedCores, timeout)) {
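
Note that the no-timeout path previously ran body directly, bypassing the failure-handling wrapper; routing it through safeExecute makes both paths register the task-failure listener. The body of safeExecute is not shown in this hunk; a plausible shape, assuming it simply installs TaskFailedListener around the user code (an assumption, not the verbatim source), is:

  // Assumed sketch of safeExecute: register the failure listener, run the body,
  // and always deregister afterwards. Not the verbatim method from this file.
  private def safeExecute[T](body: => T): T = {
    val listener = new TaskFailedListener
    sc.addSparkListener(listener)
    try {
      body
    } finally {
      sc.removeSparkListener(listener)
    }
  }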
@@ -98,31 +98,30 @@ private[spark] class TaskFailedListener extends SparkListener {
     taskEnd.reason match {
       case taskEndReason: TaskFailedReason =>
         logger.error(s"Training Task Failed during XGBoost Training: " +
-          s"$taskEndReason, stopping SparkContext")
-        TaskFailedListener.startedSparkContextKiller()
+          s"$taskEndReason, cancelling all jobs")
+        TaskFailedListener.cancelAllJobs()
       case _ =>
     }
   }
 }
 
 object TaskFailedListener {
 
-  var killerStarted = false
+  var cancelJobStarted = false
 
-  private def startedSparkContextKiller(): Unit = this.synchronized {
-    if (!killerStarted) {
-      // Spark does not allow ListenerThread to shutdown SparkContext so that we have to do it
-      // in a separate thread
-      val sparkContextKiller = new Thread() {
+  private def cancelAllJobs(): Unit = this.synchronized {
+    if (!cancelJobStarted) {
+      val cancelJob = new Thread() {
         override def run(): Unit = {
           LiveListenerBus.withinListenerThread.withValue(false) {
-            SparkContext.getOrCreate().stop()
+            SparkContext.getOrCreate().cancelAllJobs()
           }
         }
       }
-      sparkContextKiller.setDaemon(true)
-      sparkContextKiller.start()
-      killerStarted = true
+      cancelJob.setDaemon(true)
+      cancelJob.start()
+      cancelJobStarted = true
     }
   }
 
 }
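
The LiveListenerBus.withinListenerThread.withValue(false) wrapper, kept from the old killer thread, exists because Spark's listener bus marks its dispatch thread with a thread-local DynamicVariable flag and rejects certain re-entrant calls made from it. Since scala.util.DynamicVariable is backed by an InheritableThreadLocal, a thread spawned from the dispatch thread inherits the flag, which is why the new thread explicitly rebinds it to false. An illustration of the mechanism itself (illustrative only, using the Scala standard library, not Spark's internal field):

import scala.util.DynamicVariable

// Mimics the flag Spark's LiveListenerBus uses to detect calls made
// from its own listener-dispatch thread.
val withinListenerThread = new DynamicVariable[Boolean](false)

withinListenerThread.withValue(true) {
  // Inside this scope the flag reads true, as it would on the dispatch thread.
  assert(withinListenerThread.value)
  withinListenerThread.withValue(false) {
    // Rebinding it to false is what the cancelAllJobs thread relies on.
    assert(!withinListenerThread.value)
  }
}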
@@ -40,13 +40,13 @@ class ParameterSuite extends FunSuite with PerTest with BeforeAndAfterAll {
     assert(xgbCopy2.MLlib2XGBoostParams("eval_metric").toString === "logloss")
   }
 
-  private def waitForSparkContextShutdown(): Unit = {
+  private def sparkContextShouldNotShutDown(): Unit = {
     var totalWaitedTime = 0L
-    while (!ss.sparkContext.isStopped && totalWaitedTime <= 120000) {
-      Thread.sleep(10000)
-      totalWaitedTime += 10000
+    while (!ss.sparkContext.isStopped && totalWaitedTime <= 10000) {
+      Thread.sleep(1000)
+      totalWaitedTime += 1000
     }
-    assert(ss.sparkContext.isStopped === true)
+    assert(ss.sparkContext.isStopped === false)
   }
 
   test("fail training elegantly with unsupported objective function") {
@@ -60,7 +60,7 @@ class ParameterSuite extends FunSuite with PerTest with BeforeAndAfterAll {
     } catch {
       case e: Throwable => // swallow anything
     } finally {
-      waitForSparkContextShutdown()
+      sparkContextShouldNotShutDown()
     }
   }
 
@@ -75,7 +75,7 @@ class ParameterSuite extends FunSuite with PerTest with BeforeAndAfterAll {
     } catch {
       case e: Throwable => // swallow anything
     } finally {
-      waitForSparkContextShutdown()
+      sparkContextShouldNotShutDown()
     }
   }
 }
@@ -51,7 +51,7 @@ trait PerTest extends BeforeAndAfterEach { self: FunSuite =>
       cleanExternalCache(currentSession.sparkContext.appName)
       currentSession = null
     }
-    TaskFailedListener.killerStarted = false
+    TaskFailedListener.cancelJobStarted = false
   }
 }
@@ -91,8 +91,8 @@ class XGBoostRabitRegressionSuite extends FunSuite with PerTest {
   }
 
   test("test rabit timeout fail handle") {
-    // disable spark kill listener to verify if rabit_timeout take effect and kill tasks
-    TaskFailedListener.killerStarted = true
+    // disable job cancel listener to verify if rabit_timeout take effect and kill tasks
+    TaskFailedListener.cancelJobStarted = true
 
     val training = buildDataFrame(Classification.train)
     // mock rank 0 failure during 8th allreduce synchronization
@@ -109,11 +109,10 @@ class XGBoostRabitRegressionSuite extends FunSuite with PerTest {
         "rabit_timeout" -> 0))
         .fit(training)
     } catch {
-      case e: Throwable => // swallow anything
+      case e: Throwable => println("----- " + e) // swallow anything
     } finally {
-      // assume all tasks throw exception almost same time
-      // 100ms should be enough to exhaust all retries
-      assert(waitAndCheckSparkShutdown(100) == true)
+      // wait 2s to check if SparkContext is killed
+      assert(waitAndCheckSparkShutdown(2000) == false)
     }
   }
 }
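
waitAndCheckSparkShutdown is defined in the shared test utilities rather than in this diff. Judging by the polling helper in ParameterSuite above, it presumably polls isStopped for the given window and reports whether the context went down, so asserting == false here verifies the SparkContext survived the rabit failure. A reconstruction under that assumption (not the actual source):

  // Assumed reconstruction of the helper used above (not part of this diff):
  // poll SparkContext.isStopped for up to waitMs milliseconds, then report the outcome.
  def waitAndCheckSparkShutdown(waitMs: Long): Boolean = {
    var totalWaitedTime = 0L
    while (!ss.sparkContext.isStopped && totalWaitedTime <= waitMs) {
      Thread.sleep(100)
      totalWaitedTime += 100
    }
    ss.sparkContext.isStopped
  }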
