[jvm-packages] cancel job instead of killing SparkContext #6019

Merged · 5 commits · Sep 2, 2020
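
For orientation before the diff: a hedged sketch of how a caller opts into the new behavior. The "kill_spark_context_on_worker_failure" key and the XGBoostClassifier(Map(...)) constructor come from the diff below; the SparkSession setup and toy data are illustrative assumptions.

import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession

object CancelJobExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]").appName("xgb-cancel-job").getOrCreate()
    import spark.implicits._

    // Toy training data, an assumption of this sketch.
    val training = Seq(
      (Vectors.dense(0.0, 1.0), 0.0),
      (Vectors.dense(1.0, 0.0), 1.0)
    ).toDF("features", "label")

    val model = new XGBoostClassifier(Map(
      "objective" -> "binary:logistic",
      "num_round" -> 5,
      "num_workers" -> 2,
      // New switch from this PR: on worker failure, cancel the XGBoost job
      // instead of stopping the whole SparkContext (the default remains true).
      "kill_spark_context_on_worker_failure" -> false
    )).fit(training)

    // If training fails with the switch set to false, the SparkContext stays
    // usable for other jobs in this session.
    spark.stop()
  }
}
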
@@ -76,7 +76,8 @@ private[this] case class XGBoostExecutionParams(
     earlyStoppingParams: XGBoostExecutionEarlyStoppingParams,
     cacheTrainingSet: Boolean,
     treeMethod: Option[String],
-    isLocal: Boolean) {
+    isLocal: Boolean,
+    killSparkContextOnWorkerFailure: Boolean) {
 
   private var rawParamMap: Map[String, Any] = _
 
@@ -220,6 +221,9 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
     val cacheTrainingSet = overridedParams.getOrElse("cache_training_set", false)
       .asInstanceOf[Boolean]
 
+    val killSparkContext = overridedParams.getOrElse("kill_spark_context_on_worker_failure", true)
+      .asInstanceOf[Boolean]
+
     val xgbExecParam = XGBoostExecutionParams(nWorkers, round, useExternalMemory, obj, eval,
       missing, allowNonZeroForMissing, trackerConf,
       timeoutRequestWorkers,
@@ -228,7 +232,8 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
       xgbExecEarlyStoppingParams,
       cacheTrainingSet,
       treeMethod,
-      isLocal)
+      isLocal,
+      killSparkContext)
     xgbExecParam.setRawParamMap(overridedParams)
     xgbExecParam
   }
@@ -588,7 +593,8 @@ object XGBoost extends Serializable {
     val (booster, metrics) = try {
       val parallelismTracker = new SparkParallelismTracker(sc,
         xgbExecParams.timeoutRequestWorkers,
-        xgbExecParams.numWorkers)
+        xgbExecParams.numWorkers,
+        xgbExecParams.killSparkContextOnWorkerFailure)
       val rabitEnv = tracker.getWorkerEnvs
       val boostersAndMetrics = if (hasGroup) {
         trainForRanking(transformedTrainingData.left.get, xgbExecParams, rabitEnv, prevBooster,
@@ -628,7 +634,9 @@ object XGBoost extends Serializable {
       case t: Throwable =>
         // if the job was aborted due to an exception
         logger.error("the job was aborted due to ", t)
-        trainingData.sparkContext.stop()
+        if (xgbExecParams.killSparkContextOnWorkerFailure) {
+          trainingData.sparkContext.stop()
+        }
         throw t
     } finally {
       uncacheTrainingData(xgbExecParams.cacheTrainingSet, transformedTrainingData)
@@ -105,8 +105,14 @@ private[spark] trait LearningTaskParams extends Params {
 
   final def getMaximizeEvaluationMetrics: Boolean = $(maximizeEvaluationMetrics)
 
-  setDefault(objective -> "reg:squarederror", baseScore -> 0.5,
-    trainTestRatio -> 1.0, numEarlyStoppingRounds -> 0, cacheTrainingSet -> false)
+  /**
+   * whether to kill the SparkContext when the training task fails
+   */
+  final val killSparkContextOnWorkerFailure = new BooleanParam(this,
+    "killSparkContextOnWorkerFailure", "whether to kill the SparkContext when the training task fails")
+
+  setDefault(objective -> "reg:squarederror", baseScore -> 0.5, trainTestRatio -> 1.0,
+    numEarlyStoppingRounds -> 0, cacheTrainingSet -> false, killSparkContextOnWorkerFailure -> true)
 }
 
 private[spark] object LearningTaskParams {
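
Since killSparkContextOnWorkerFailure is an ordinary Spark ML BooleanParam, it can also be overridden per fit call through a ParamMap, in addition to the raw "kill_spark_context_on_worker_failure" map key used elsewhere in this PR. A hedged sketch, assuming XGBoostClassifier mixes in LearningTaskParams as in this module, and that trainingDF is an assumed DataFrame with "features"/"label" columns:

import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier
import org.apache.spark.ml.param.ParamMap

val clf = new XGBoostClassifier(Map(
  "objective" -> "binary:logistic",
  "num_round" -> 5))
// Override the default (true) for this fit only, via the standard Params API.
val model = clf.fit(trainingDF, ParamMap(clf.killSparkContextOnWorkerFailure -> false))
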
@@ -19,18 +19,22 @@ package org.apache.spark
 import org.apache.commons.logging.LogFactory
 import org.apache.spark.scheduler._
 
+import scala.collection.mutable.{HashMap, HashSet}
+
 /**
  * A tracker that ensures enough number of executor cores are alive.
  * Throws an exception when the number of alive cores is less than nWorkers.
  *
  * @param sc The SparkContext object
  * @param timeout The maximum time to wait for enough number of workers.
  * @param numWorkers nWorkers used in an XGBoost Job
+ * @param killSparkContextOnWorkerFailure whether to kill the SparkContext when a task fails
  */
 class SparkParallelismTracker(
     val sc: SparkContext,
     timeout: Long,
-    numWorkers: Int) {
+    numWorkers: Int,
+    killSparkContextOnWorkerFailure: Boolean = true) {
 
   private[this] val requestedCores = numWorkers * sc.conf.getInt("spark.task.cpus", 1)
   private[this] val logger = LogFactory.getLog("XGBoostSpark")
@@ -58,7 +62,7 @@ class SparkParallelismTracker(
   }
 
   private[this] def safeExecute[T](body: => T): T = {
-    val listener = new TaskFailedListener
+    val listener = new TaskFailedListener(killSparkContextOnWorkerFailure)
     sc.addSparkListener(listener)
     try {
       body
@@ -79,7 +83,7 @@ class SparkParallelismTracker(
   def execute[T](body: => T): T = {
     if (timeout <= 0) {
       logger.info("starting training without setting timeout for waiting for resources")
-      body
+      safeExecute(body)

Member:
Why do we have to change this to safeExecute()?

Contributor Author:
safeExecute wraps the body with a TaskFailedListener. I don't know why the body was not run through safeExecute in the previous version; without the TaskFailedListener registered, the call may hang forever when a task fails. (A sketch of this pattern follows the hunk below.)

     } else {
       logger.info(s"starting training with timeout set as $timeout ms for waiting for resources")
       if (!waitForCondition(numAliveCores >= requestedCores, timeout)) {
@@ -90,16 +94,51 @@ class SparkParallelismTracker(
     }
   }
 }
 
-private[spark] class TaskFailedListener extends SparkListener {
+class TaskFailedListener(killSparkContext: Boolean = true) extends SparkListener {
 
   private[this] val logger = LogFactory.getLog("XGBoostTaskFailedListener")
 
+  // {jobId -> [stageId0, stageId1, ...]}
+  // Keeps track of the mapping from job ids to stage ids so that, when a task
+  // fails, we can find the job the task belongs to and cancel it.
+  private val jobIdToStageIds: HashMap[Int, HashSet[Int]] = HashMap.empty
+
+  override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
+    if (!killSparkContext) {
+      jobStart.stageIds.foreach(stageId => {
+        jobIdToStageIds.getOrElseUpdate(jobStart.jobId, new HashSet[Int]()) += stageId
+      })
+    }
+  }
+
+  override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = {
+    if (!killSparkContext) {
+      jobIdToStageIds.remove(jobEnd.jobId)
+    }
+  }
+
   override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
     taskEnd.reason match {
       case taskEndReason: TaskFailedReason =>
         logger.error(s"Training Task Failed during XGBoost Training: " +
-          s"$taskEndReason, stopping SparkContext")
-        TaskFailedListener.startedSparkContextKiller()
+          s"$taskEndReason")
+        if (killSparkContext) {
+          logger.error("killing SparkContext")
+          TaskFailedListener.startedSparkContextKiller()
+        } else {
+          val stageId = taskEnd.stageId
+          // find the jobs that contain this stage id, then cancel those jobs
+          jobIdToStageIds.foreach {
+            case (jobId, stageIds) =>
+              if (stageIds.contains(stageId)) {
+                logger.error("Cancelling jobId:" + jobId)
+                jobIdToStageIds.remove(jobId)
+                SparkContext.getOrCreate().cancelJob(jobId)
+              }
+          }
+        }
       case _ =>
     }
   }
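
To make the cancel-instead-of-kill behavior above concrete, a self-contained sketch. This is not the PR's listener: for brevity it cancels the failed stage directly with SparkContext.cancelStage instead of maintaining the jobId-to-stageIds map, and all names are illustrative:

import org.apache.spark.{SparkConf, SparkContext, TaskContext, TaskFailedReason}
import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}

object CancelDemo {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setMaster("local[2]").setAppName("cancel-demo"))
    sc.addSparkListener(new SparkListener {
      override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = taskEnd.reason match {
        case _: TaskFailedReason =>
          // Cancel only the failed work; the SparkContext itself stays up.
          sc.cancelStage(taskEnd.stageId)
        case _ =>
      }
    })
    try {
      sc.parallelize(1 to 4, 2).map { i =>
        if (TaskContext.get().partitionId() == 0) throw new RuntimeException("boom")
        i
      }.sum()
    } catch {
      case _: Exception => // the cancelled/failed job surfaces as an exception here
    }
    assert(!sc.isStopped) // unlike sc.stop(), cancellation leaves the context alive
    sc.stop()
  }
}
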
@@ -116,4 +116,28 @@ class XGBoostRabitRegressionSuite extends FunSuite with PerTest {
       assert(waitAndCheckSparkShutdown(100) == true)
     }
   }
+
+  test("test SparkContext should not be killed") {
+    val training = buildDataFrame(Classification.train)
+    // mock a rank-0 failure during the 8th allreduce synchronization
+    Rabit.mockList = Array("0,8,0,0").toList.asJava
+
+    try {
+      new XGBoostClassifier(Map(
+        "eta" -> "0.1",
+        "max_depth" -> "10",
+        "verbosity" -> "1",
+        "objective" -> "binary:logistic",
+        "num_round" -> 5,
+        "num_workers" -> numWorkers,
+        "kill_spark_context_on_worker_failure" -> false,
+        "rabit_timeout" -> 0))
+        .fit(training)
+    } catch {
+      case e: Throwable => // swallow anything
+    } finally {
+      // wait up to 3s to verify that the SparkContext has not been killed
+      assert(waitAndCheckSparkShutdown(3000) == false)
+    }
+  }
 }
@@ -34,6 +34,15 @@ class SparkParallelismTrackerSuite extends FunSuite with PerTest {
     .config("spark.driver.memory", "512m")
     .config("spark.task.cpus", 1)
 
+  private def waitAndCheckSparkShutdown(waitMiliSec: Int): Boolean = {
+    var totalWaitedTime = 0L
+    while (!ss.sparkContext.isStopped && totalWaitedTime <= waitMiliSec) {
+      Thread.sleep(100)
+      totalWaitedTime += 100
+    }
+    ss.sparkContext.isStopped
+  }
+
   test("tracker should not affect execution result when timeout is not larger than 0") {
     val nWorkers = numParallelism
     val rdd: RDD[Int] = sc.parallelize(1 to nWorkers)
@@ -74,4 +83,69 @@ class SparkParallelismTrackerSuite extends FunSuite with PerTest {
       }
     }
   }
+
+  test("tracker should not kill SparkContext when killSparkContextOnWorkerFailure=false") {
+    val nWorkers = numParallelism
+    val tracker = new SparkParallelismTracker(sc, 0, nWorkers, false)
+    val rdd: RDD[Int] = sc.parallelize(1 to nWorkers, nWorkers)
+    try {
+      tracker.execute {
+        rdd.map { i =>
+          val partitionId = TaskContext.get().partitionId()
+          if (partitionId == 0) {
+            throw new RuntimeException("mocking task failing")
+          }
+          i
+        }.sum()
+      }
+    } catch {
+      case e: Exception => // swallow the mocked failure
+    } finally {
+      // wait up to 3s to verify that the SparkContext has not been killed
+      assert(waitAndCheckSparkShutdown(3000) == false)
+    }
+  }
+
+  test("tracker should cancel the correct job when killSparkContextOnWorkerFailure=false") {
+    val nWorkers = 2
+    val tracker = new SparkParallelismTracker(sc, 0, nWorkers, false)
+    val rdd: RDD[Int] = sc.parallelize(1 to 10, nWorkers)
+    val thread = new TestThread(sc)
+    thread.start()
+    try {
+      tracker.execute {
+        rdd.map { i =>
+          Thread.sleep(100)
+          val partitionId = TaskContext.get().partitionId()
+          if (partitionId == 0) {
+            throw new RuntimeException("mocking task failing")
+          }
+          i
+        }.sum()
+      }
+    } catch {
+      case e: Exception => // swallow the mocked failure
+    } finally {
+      thread.join(8000)
+      // wait up to 3s to verify that the SparkContext has not been killed
+      assert(waitAndCheckSparkShutdown(3000) == false)
+    }
+  }
+
+  private[this] class TestThread(sc: SparkContext) extends Thread {
+    override def run(): Unit = {
+      var sum: Double = 0.0
+      try {
+        val rdd = sc.parallelize(1 to 4, 2)
+        sum = rdd.mapPartitions(iter => {
+          // sleep 2s to keep this job's tasks alive while the failing job is cancelled
+          Thread.sleep(2000)
+          iter
+        }).sum()
+      } finally {
+        // this job must still produce the correct result: only the failing job was cancelled
+        assert(sum.toInt == 10)
+      }
+    }
+  }
 }