diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/SparkParallelismTracker.scala b/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/SparkParallelismTracker.scala index 94cfeaefe293..17ec6a9102e8 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/SparkParallelismTracker.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/SparkParallelismTracker.scala @@ -135,6 +135,7 @@ class TaskFailedListener(killSparkContext: Boolean = true) extends SparkListener if (stageIds.contains(stageId)) { logger.error("Cancelling jobId:" + jobId) + jobIdToStageIds.remove(jobId) SparkContext.getOrCreate().cancelJob(jobId) } }) diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/org/apache/spark/SparkParallelismTrackerSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/org/apache/spark/SparkParallelismTrackerSuite.scala index dadadeb219f4..eab81f02a8d4 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/org/apache/spark/SparkParallelismTrackerSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/org/apache/spark/SparkParallelismTrackerSuite.scala @@ -106,7 +106,7 @@ class SparkParallelismTrackerSuite extends FunSuite with PerTest { } } - test("tracker should cancel correct job when killSparkContext=false") { + test("tracker should cancel the correct job when killSparkContext=false") { val nWorkers = 2 val tracker = new SparkParallelismTracker(sc, 0, nWorkers, false) val rdd: RDD[Int] = sc.parallelize(1 to 10, nWorkers)