Commit c8dacc6

add a parameter to control if killing SparkContext
wbo4958 committed Aug 18, 2020
1 parent 9ab36f2 commit c8dacc6
Showing 7 changed files with 127 additions and 23 deletions.
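
This commit threads a new killSparkContext flag from the user-facing parameters down to SparkParallelismTracker and TaskFailedListener: when a training task fails, the listener either stops the SparkContext (the old behaviour, still the default) or merely cancels the running jobs. A minimal usage sketch, adapted from the new test added below; it assumes a SparkSession and a training DataFrame `training` are already in scope:

import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier

val classifier = new XGBoostClassifier(Map(
  "objective" -> "binary:logistic",
  "num_round" -> 5,
  "num_workers" -> 2,
  // keep the SparkContext alive if training fails; running jobs are cancelled instead
  "kill_spark_context" -> false))
val model = classifier.fit(training)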
@@ -76,7 +76,8 @@ private[this] case class XGBoostExecutionParams(
earlyStoppingParams: XGBoostExecutionEarlyStoppingParams,
cacheTrainingSet: Boolean,
treeMethod: Option[String],
- isLocal: Boolean) {
+ isLocal: Boolean,
+ killSparkContext: Boolean) {

private var rawParamMap: Map[String, Any] = _

@@ -220,6 +221,9 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
val cacheTrainingSet = overridedParams.getOrElse("cache_training_set", false)
.asInstanceOf[Boolean]

+ val killSparkContext = overridedParams.getOrElse("kill_spark_context", true)
+   .asInstanceOf[Boolean]
+
val xgbExecParam = XGBoostExecutionParams(nWorkers, round, useExternalMemory, obj, eval,
missing, allowNonZeroForMissing, trackerConf,
timeoutRequestWorkers,
@@ -228,7 +232,8 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
xgbExecEarlyStoppingParams,
cacheTrainingSet,
treeMethod,
- isLocal)
+ isLocal,
+ killSparkContext)
xgbExecParam.setRawParamMap(overridedParams)
xgbExecParam
}
@@ -588,7 +593,8 @@ object XGBoost extends Serializable {
val (booster, metrics) = try {
val parallelismTracker = new SparkParallelismTracker(sc,
xgbExecParams.timeoutRequestWorkers,
- xgbExecParams.numWorkers)
+ xgbExecParams.numWorkers,
+ xgbExecParams.killSparkContext)
val rabitEnv = tracker.getWorkerEnvs
val boostersAndMetrics = if (hasGroup) {
trainForRanking(transformedTrainingData.left.get, xgbExecParams, rabitEnv, prevBooster,
@@ -628,6 +634,9 @@ object XGBoost extends Serializable {
case t: Throwable =>
// if the job was aborted due to an exception, just throw the exception
logger.error("the job was aborted due to ", t)
+ if (xgbExecParams.killSparkContext) {
+   trainingData.sparkContext.stop()
+ }
throw t
} finally {
uncacheTrainingData(xgbExecParams.cacheTrainingSet, transformedTrainingData)
@@ -105,8 +105,14 @@ private[spark] trait LearningTaskParams extends Params {

final def getMaximizeEvaluationMetrics: Boolean = $(maximizeEvaluationMetrics)

- setDefault(objective -> "reg:squarederror", baseScore -> 0.5,
-   trainTestRatio -> 1.0, numEarlyStoppingRounds -> 0, cacheTrainingSet -> false)
+ /**
+  * whether to kill the SparkContext when the training task fails
+  */
+ final val killSparkContext = new BooleanParam(this, "killSparkContext",
+   "whether to kill the SparkContext when the training task fails")
+
+ setDefault(objective -> "reg:squarederror", baseScore -> 0.5, trainTestRatio -> 1.0,
+   numEarlyStoppingRounds -> 0, cacheTrainingSet -> false, killSparkContext -> true)
}

private[spark] object LearningTaskParams {
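For illustration, a hedged sketch (not part of the diff) of reading the new param's default through the standard MLlib params API; it assumes the param is reachable via `getParam` on `XGBoostClassifier`, which mixes in `LearningTaskParams`:

import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier

val clf = new XGBoostClassifier()
// "killSparkContext" is the name string passed to the BooleanParam above
val param = clf.getParam("killSparkContext")
assert(clf.getOrDefault(param).asInstanceOf[Boolean]) // defaults to true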
@@ -26,11 +26,13 @@ import org.apache.spark.scheduler._
* @param sc The SparkContext object
* @param timeout The maximum time to wait for enough number of workers.
* @param numWorkers nWorkers used in an XGBoost Job
+ * @param killSparkContext whether to kill the SparkContext when a task fails
*/
class SparkParallelismTracker(
val sc: SparkContext,
timeout: Long,
- numWorkers: Int) {
+ numWorkers: Int,
+ killSparkContext: Boolean = true) {

private[this] val requestedCores = numWorkers * sc.conf.getInt("spark.task.cpus", 1)
private[this] val logger = LogFactory.getLog("XGBoostSpark")
@@ -58,7 +60,7 @@ class SparkParallelismTracker(
}

private[this] def safeExecute[T](body: => T): T = {
- val listener = new TaskFailedListener
+ val listener = new TaskFailedListener(killSparkContext)
sc.addSparkListener(listener)
try {
body
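
safeExecute registers a TaskFailedListener for the duration of the job, so the kill-versus-cancel decision rides along with every tracked execution. A hedged sketch (not from the diff) of driver-side use, assuming a running SparkContext `sc`:

// timeout = 0 skips the worker-count wait; killSparkContext = false means a task
// failure cancels jobs rather than stopping the SparkContext
val tracker = new SparkParallelismTracker(sc, timeout = 0L, numWorkers = 2,
  killSparkContext = false)
val result = tracker.execute {
  sc.parallelize(1 to 2, 2).map(_ * 2).sum()
}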
@@ -90,7 +92,7 @@
}
}

- private[spark] class TaskFailedListener extends SparkListener {
+ class TaskFailedListener(killSparkContext: Boolean = true) extends SparkListener {

private[this] val logger = LogFactory.getLog("XGBoostTaskFailedListener")

@@ -99,28 +101,57 @@ private[spark] class TaskFailedListener extends SparkListener {
case taskEndReason: TaskFailedReason =>
logger.error(s"Training Task Failed during XGBoost Training: " +
s"$taskEndReason, cancelling all jobs")
- TaskFailedListener.cancelAllJobs()
+ taskFaultHandle
case _ =>
}
}

+ private[this] def taskFaultHandle = {
+   if (killSparkContext) {
+     TaskFailedListener.startedSparkContextKiller()
+   } else {
+     TaskFailedListener.cancelAllJobs()
+   }
+ }
}

object TaskFailedListener {

private[this] val logger = LogFactory.getLog("XGBoostTaskFailedListener")
var cancelJobStarted = false

private def cancelAllJobs(): Unit = this.synchronized {
if (!cancelJobStarted) {
- cancelJobStarted = true
val cancelJob = new Thread() {
override def run(): Unit = {
LiveListenerBus.withinListenerThread.withValue(false) {
logger.info("will call spark cancel all jobs")
SparkContext.getOrCreate().cancelAllJobs()
logger.info("after call spark cancel all jobs")
}
}
}
cancelJob.setDaemon(true)
cancelJob.start()
+ cancelJobStarted = true
}
}

+ var killerStarted = false
+
+ private def startedSparkContextKiller(): Unit = this.synchronized {
+   if (!killerStarted) {
+     // Spark does not allow the listener thread to shut down the SparkContext,
+     // so we have to do it in a separate thread
+     val sparkContextKiller = new Thread() {
+       override def run(): Unit = {
+         LiveListenerBus.withinListenerThread.withValue(false) {
+           SparkContext.getOrCreate().stop()
+         }
+       }
+     }
+     sparkContextKiller.setDaemon(true)
+     sparkContextKiller.start()
+     killerStarted = true
+   }
+ }
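Because TaskFailedListener is no longer private[spark], it can also be registered by hand; a minimal sketch, assuming a running SparkContext `sc` (safeExecute above does the equivalent around each XGBoost job):

sc.addSparkListener(new TaskFailedListener(killSparkContext = false))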

@@ -40,13 +40,13 @@ class ParameterSuite extends FunSuite with PerTest with BeforeAndAfterAll {
assert(xgbCopy2.MLlib2XGBoostParams("eval_metric").toString === "logloss")
}

- private def sparkContextShouldNotShutDown(): Unit = {
+ private def waitForSparkContextShutdown(): Unit = {
var totalWaitedTime = 0L
- while (!ss.sparkContext.isStopped && totalWaitedTime <= 10000) {
-   Thread.sleep(1000)
-   totalWaitedTime += 1000
+ while (!ss.sparkContext.isStopped && totalWaitedTime <= 120000) {
+   Thread.sleep(10000)
+   totalWaitedTime += 10000
}
- assert(ss.sparkContext.isStopped === false)
+ assert(ss.sparkContext.isStopped === true)
}

test("fail training elegantly with unsupported objective function") {
@@ -60,7 +60,7 @@ class ParameterSuite extends FunSuite with PerTest with BeforeAndAfterAll {
} catch {
case e: Throwable => // swallow anything
} finally {
- sparkContextShouldNotShutDown()
+ waitForSparkContextShutdown()
}
}

@@ -75,7 +75,7 @@ class ParameterSuite extends FunSuite with PerTest with BeforeAndAfterAll {
} catch {
case e: Throwable => // swallow anything
} finally {
- sparkContextShouldNotShutDown()
+ waitForSparkContextShutdown()
}
}
}
@@ -51,6 +51,7 @@ trait PerTest extends BeforeAndAfterEach { self: FunSuite =>
cleanExternalCache(currentSession.sparkContext.appName)
currentSession = null
}
+ TaskFailedListener.killerStarted = false
TaskFailedListener.cancelJobStarted = false
}
}
@@ -91,8 +91,8 @@ class XGBoostRabitRegressionSuite extends FunSuite with PerTest {
}

test("test rabit timeout fail handle") {
- // disable job cancel listener to verify if rabit_timeout take effect and kill tasks
- TaskFailedListener.cancelJobStarted = true
+ // disable the SparkContext killer listener to verify that rabit_timeout takes effect and kills tasks
+ TaskFailedListener.killerStarted = true

val training = buildDataFrame(Classification.train)
// mock rank 0 failure during 8th allreduce synchronization
@@ -109,10 +109,35 @@
"rabit_timeout" -> 0))
.fit(training)
} catch {
- case e: Throwable => println("----- " + e) // swallow anything
+ case e: Throwable => // swallow anything
} finally {
- // wait 2s to check if SparkContext is killed
- assert(waitAndCheckSparkShutdown(2000) == false)
+ // assume all tasks throw their exceptions at almost the same time;
+ // 100ms should be enough to exhaust all retries
+ assert(waitAndCheckSparkShutdown(100) == true)
}
}

test("test SparkContext should not be killed ") {
val training = buildDataFrame(Classification.train)
// mock rank 0 failure during 8th allreduce synchronization
Rabit.mockList = Array("0,8,0,0").toList.asJava

try {
new XGBoostClassifier(Map(
"eta" -> "0.1",
"max_depth" -> "10",
"verbosity" -> "1",
"objective" -> "binary:logistic",
"num_round" -> 5,
"num_workers" -> numWorkers,
"kill_spark_context" -> false,
"rabit_timeout" -> 0))
.fit(training)
} catch {
case e: Throwable => // swallow anything
} finally {
// wait 3s to check if SparkContext is killed
assert(waitAndCheckSparkShutdown(3000) == false)
}
}
}
@@ -34,6 +34,15 @@ class SparkParallelismTrackerSuite extends FunSuite with PerTest {
.config("spark.driver.memory", "512m")
.config("spark.task.cpus", 1)

+ private def waitAndCheckSparkShutdown(waitMilliSec: Int): Boolean = {
+   var totalWaitedTime = 0L
+   while (!ss.sparkContext.isStopped && totalWaitedTime <= waitMilliSec) {
+     Thread.sleep(100)
+     totalWaitedTime += 100
+   }
+   ss.sparkContext.isStopped
+ }
+
test("tracker should not affect execution result when timeout is not larger than 0") {
val nWorkers = numParallelism
val rdd: RDD[Int] = sc.parallelize(1 to nWorkers)
@@ -74,4 +83,27 @@
}
}
}

test("tracker should not kill SparkContext when killSparkContext=false") {
val nWorkers = numParallelism
val rdd: RDD[Int] = sc.parallelize(1 to nWorkers, nWorkers)
val tracker = new SparkParallelismTracker(sc, 0, nWorkers, false)
try {
tracker.execute {
rdd.map { i =>
val partitionId = TaskContext.get().partitionId()
if (partitionId == 0) {
throw new RuntimeException("mocking task failing")
}
i
}.sum()
}
} catch {
case e: Exception => // catch the exception
} finally {
// wait 3s to check if SparkContext is killed
assert(waitAndCheckSparkShutdown(3000) == false)
}
}

}
