diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
index df19858749cd..e6ccb6349b57 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
@@ -21,6 +21,7 @@ import java.io.File
 import scala.collection.mutable
 import scala.util.Random
 import scala.collection.JavaConverters._
+
 import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, XGBoostError, RabitTracker => PyRabitTracker}
 import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
 import ml.dmlc.xgboost4j.scala.spark.params.LearningTaskParams
@@ -30,8 +31,9 @@ import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
 import org.apache.commons.io.FileUtils
 import org.apache.commons.logging.LogFactory
 import org.apache.hadoop.fs.FileSystem
+
 import org.apache.spark.rdd.RDD
-import org.apache.spark.{SparkContext, SparkParallelismTracker, TaskContext}
+import org.apache.spark.{SparkContext, TaskContext}
 import org.apache.spark.sql.SparkSession
 
 /**
@@ -79,8 +81,7 @@ private[scala] case class XGBoostExecutionParams(
     earlyStoppingParams: XGBoostExecutionEarlyStoppingParams,
     cacheTrainingSet: Boolean,
     treeMethod: Option[String],
-    isLocal: Boolean,
-    killSparkContextOnWorkerFailure: Boolean) {
+    isLocal: Boolean) {
 
   private var rawParamMap: Map[String, Any] = _
 
@@ -224,9 +225,6 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
     val cacheTrainingSet = overridedParams.getOrElse("cache_training_set", false)
       .asInstanceOf[Boolean]
 
-    val killSparkContext = overridedParams.getOrElse("kill_spark_context_on_worker_failure", true)
-      .asInstanceOf[Boolean]
-
     val xgbExecParam = XGBoostExecutionParams(nWorkers, round, useExternalMemory, obj, eval,
       missing, allowNonZeroForMissing, trackerConf,
       timeoutRequestWorkers,
@@ -235,8 +233,7 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
       xgbExecEarlyStoppingParams,
       cacheTrainingSet,
       treeMethod,
-      isLocal,
-      killSparkContext)
+      isLocal)
     xgbExecParam.setRawParamMap(overridedParams)
     xgbExecParam
   }
@@ -351,7 +348,11 @@ object XGBoost extends Serializable {
         watches.toMap, metrics, obj, eval, earlyStoppingRound = numEarlyStoppingRounds,
         prevBooster)
       }
-      Iterator(booster -> watches.toMap.keys.zip(metrics).toMap)
+      if (TaskContext.get().partitionId() == 0) {
+        Iterator(booster -> watches.toMap.keys.zip(metrics).toMap)
+      } else {
+        Iterator.empty
+      }
     } catch {
       case xgbException: XGBoostError =>
         logger.error(s"XGBooster worker $taskId has failed $attempt times due to ", xgbException)
@@ -409,15 +410,10 @@ object XGBoost extends Serializable {
     // Train for every ${savingRound} rounds and save the partially completed booster
     val tracker = startTracker(xgbExecParams.numWorkers, xgbExecParams.trackerConf)
     val (booster, metrics) = try {
-      val parallelismTracker = new SparkParallelismTracker(sc,
-        xgbExecParams.timeoutRequestWorkers,
-        xgbExecParams.numWorkers,
-        xgbExecParams.killSparkContextOnWorkerFailure)
-
       tracker.getWorkerEnvs().putAll(xgbRabitParams)
       val rabitEnv = tracker.getWorkerEnvs
 
-      val boostersAndMetrics = trainingRDD.mapPartitions { iter => {
+      val boostersAndMetrics = trainingRDD.barrier().mapPartitions { iter => {
         var optionWatches: Option[() => Watches] = None
 
         // take the first Watches to train
@@ -430,24 +426,14 @@ object XGBoost extends Serializable {
           xgbExecParams.eval, prevBooster)}
           .getOrElse(throw new RuntimeException("No Watches to train"))
-      }}.cache()
-
-      val sparkJobThread = new Thread() {
-        override def run() {
-          // force the job
-          boostersAndMetrics.foreachPartition(() => _)
-        }
-      }
-      sparkJobThread.setUncaughtExceptionHandler(tracker)
-
-      val trackerReturnVal = parallelismTracker.execute {
-        sparkJobThread.start()
-        tracker.waitFor(0L)
-      }
+      }}
+      val (booster, metrics) = boostersAndMetrics.collect()(0)
+      val trackerReturnVal = tracker.waitFor(0L)
       logger.info(s"Rabit returns with exit code $trackerReturnVal")
-      val (booster, metrics) = postTrackerReturnProcessing(trackerReturnVal,
-        boostersAndMetrics, sparkJobThread)
+      if (trackerReturnVal != 0) {
+        throw new XGBoostError("XGBoostModel training failed.")
+      }
 
       (booster, metrics)
     } finally {
       tracker.stop()
     }
@@ -467,42 +453,12 @@ object XGBoost extends Serializable {
       case t: Throwable =>
         // if the job was aborted due to an exception
         logger.error("the job was aborted due to ", t)
-        if (xgbExecParams.killSparkContextOnWorkerFailure) {
-          sc.stop()
-        }
         throw t
     } finally {
       optionalCachedRDD.foreach(_.unpersist())
    }
  }
 
-  private def postTrackerReturnProcessing(
-      trackerReturnVal: Int,
-      distributedBoostersAndMetrics: RDD[(Booster, Map[String, Array[Float]])],
-      sparkJobThread: Thread): (Booster, Map[String, Array[Float]]) = {
-    if (trackerReturnVal == 0) {
-      // Copies of the final booster and the corresponding metrics
-      // reside in each partition of the `distributedBoostersAndMetrics`.
-      // Any of them can be used to create the model.
-      // it's safe to block here forever, as the tracker has returned successfully, and the Spark
-      // job should have finished, there is no reason for the thread cannot return
-      sparkJobThread.join()
-      val (booster, metrics) = distributedBoostersAndMetrics.first()
-      distributedBoostersAndMetrics.unpersist(false)
-      (booster, metrics)
-    } else {
-      try {
-        if (sparkJobThread.isAlive) {
-          sparkJobThread.interrupt()
-        }
-      } catch {
-        case _: InterruptedException =>
-          logger.info("spark job thread is interrupted")
-      }
-      throw new XGBoostError("XGBoostModel training failed")
-    }
-  }
-
 }
 
 class Watches private[scala] (
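Note: the hunks above replace the old thread-plus-listener orchestration with Spark's barrier execution mode. trainingRDD.barrier().mapPartitions gang-schedules all training tasks, every partition trains its own full copy of the booster, and only partition 0 emits a (Booster, metrics) pair, so the driver can collect() exactly one result instead of caching one per partition. A minimal, self-contained sketch of that pattern (generic integer work standing in for buildDistributedBooster; this is illustrative, not the library code):

import org.apache.spark.TaskContext
import org.apache.spark.sql.SparkSession

object BarrierCollectSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")                 // barrier mode needs at least as many slots as partitions
      .appName("barrier-collect-sketch")
      .getOrCreate()

    val results = spark.sparkContext
      .parallelize(1 to 10, 2)
      .barrier()                          // all tasks of this stage start together, or not at all
      .mapPartitions { iter =>
        val trained = iter.sum            // stand-in for the real distributed training step
        // every task does the work, but only partition 0 surfaces the result
        if (TaskContext.get().partitionId() == 0) Iterator(trained) else Iterator.empty
      }
      .collect()

    assert(results.length == 1)           // exactly one copy reaches the driver
    spark.stop()
  }
}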
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala
index 988535547441..852864d9cb1c 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala
@@ -105,14 +105,8 @@ private[spark] trait LearningTaskParams extends Params {
 
   final def getMaximizeEvaluationMetrics: Boolean = $(maximizeEvaluationMetrics)
 
-  /**
-   * whether killing SparkContext when training task fails
-   */
-  final val killSparkContextOnWorkerFailure = new BooleanParam(this,
-    "killSparkContextOnWorkerFailure", "whether killing SparkContext when training task fails")
-
   setDefault(objective -> "reg:squarederror", baseScore -> 0.5, trainTestRatio -> 1.0,
-    numEarlyStoppingRounds -> 0, cacheTrainingSet -> false, killSparkContextOnWorkerFailure -> true)
+    numEarlyStoppingRounds -> 0, cacheTrainingSet -> false)
 }
 
 private[spark] object LearningTaskParams {
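Note: with the Param and its default removed, callers that passed the key through the parameter map should simply drop it; a worker failure now fails the Spark job rather than optionally stopping the SparkContext. A hypothetical before/after for caller code (params is a placeholder map, not from this diff):

val params: Map[String, Any] = Map(
  "objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> 2)

// before: opt out of killing the SparkContext when a training task failed
val before = new XGBoostClassifier(params + ("kill_spark_context_on_worker_failure" -> false))

// after this change: the key no longer exists; a failed training job surfaces
// as a SparkException and the SparkContext is left running either way
val after = new XGBoostClassifier(params)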
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/SparkParallelismTracker.scala b/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/SparkParallelismTracker.scala
deleted file mode 100644
index 99c1cccf2761..000000000000
--- a/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/SparkParallelismTracker.scala
+++ /dev/null
@@ -1,175 +0,0 @@
-/*
- Copyright (c) 2014 by Contributors
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- */
-
-package org.apache.spark
-
-import org.apache.commons.logging.LogFactory
-import org.apache.spark.scheduler._
-
-import scala.collection.mutable.{HashMap, HashSet}
-
-/**
- * A tracker that ensures enough number of executor cores are alive.
- * Throws an exception when the number of alive cores is less than nWorkers.
- *
- * @param sc The SparkContext object
- * @param timeout The maximum time to wait for enough number of workers.
- * @param numWorkers nWorkers used in an XGBoost Job
- * @param killSparkContextOnWorkerFailure kill SparkContext or not when task fails
- */
-class SparkParallelismTracker(
-    val sc: SparkContext,
-    timeout: Long,
-    numWorkers: Int,
-    killSparkContextOnWorkerFailure: Boolean = true) {
-
-  private[this] val requestedCores = numWorkers * sc.conf.getInt("spark.task.cpus", 1)
-  private[this] val logger = LogFactory.getLog("XGBoostSpark")
-
-  private[this] def numAliveCores: Int = {
-    sc.statusStore.executorList(true).map(_.totalCores).sum
-  }
-
-  private[this] def waitForCondition(
-      condition: => Boolean,
-      timeout: Long,
-      checkInterval: Long = 100L) = {
-    val waitImpl = new ((Long, Boolean) => Boolean) {
-      override def apply(waitedTime: Long, status: Boolean): Boolean = status match {
-        case s if s => true
-        case _ => waitedTime match {
-          case t if t < timeout =>
-            Thread.sleep(checkInterval)
-            apply(t + checkInterval, status = condition)
-          case _ => false
-        }
-      }
-    }
-    waitImpl(0L, condition)
-  }
-
-  private[this] def safeExecute[T](body: => T): T = {
-    val listener = new TaskFailedListener(killSparkContextOnWorkerFailure)
-    sc.addSparkListener(listener)
-    try {
-      body
-    } finally {
-      sc.removeSparkListener(listener)
-    }
-  }
-
-  /**
-   * Execute a blocking function call with two checks on enough nWorkers:
-   *  - Before the function starts, wait until there are enough executor cores.
-   *  - During the execution, throws an exception if there is any executor lost.
-   *
-   * @param body A blocking function call
-   * @tparam T Return type
-   * @return The return of body
-   */
-  def execute[T](body: => T): T = {
-    if (timeout <= 0) {
-      logger.info("starting training without setting timeout for waiting for resources")
-      safeExecute(body)
-    } else {
-      logger.info(s"starting training with timeout set as $timeout ms for waiting for resources")
-      if (!waitForCondition(numAliveCores >= requestedCores, timeout)) {
-        throw new IllegalStateException(s"Unable to get $requestedCores cores for XGBoost training")
-      }
-      safeExecute(body)
-    }
-  }
-}
-
-class TaskFailedListener(killSparkContext: Boolean = true) extends SparkListener {
-
-  private[this] val logger = LogFactory.getLog("XGBoostTaskFailedListener")
-
-  // {jobId, [stageId0, stageId1, ...]}
-  // keep track of the mapping of job id and stage ids
-  // when a task fails, find the job id and stage id the task belongs to, finally
-  // cancel the jobs
-  private val jobIdToStageIds: HashMap[Int, HashSet[Int]] = HashMap.empty
-
-  override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
-    if (!killSparkContext) {
-      jobStart.stageIds.foreach(stageId => {
-        jobIdToStageIds.getOrElseUpdate(jobStart.jobId, new HashSet[Int]()) += stageId
-      })
-    }
-  }
-
-  override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = {
-    if (!killSparkContext) {
-      jobIdToStageIds.remove(jobEnd.jobId)
-    }
-  }
-
-  override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
-    taskEnd.reason match {
-      case taskEndReason: TaskFailedReason =>
-        logger.error(s"Training Task Failed during XGBoost Training: " +
-          s"$taskEndReason")
-        if (killSparkContext) {
-          logger.error("killing SparkContext")
-          TaskFailedListener.startedSparkContextKiller()
-        } else {
-          val stageId = taskEnd.stageId
-          // find job ids according to stage id and then cancel the job
-
-          jobIdToStageIds.foreach {
-            case (jobId, stageIds) =>
-              if (stageIds.contains(stageId)) {
-                logger.error("Cancelling jobId:" + jobId)
-                jobIdToStageIds.remove(jobId)
-                SparkContext.getOrCreate().cancelJob(jobId)
-              }
-          }
-        }
-      case _ =>
-    }
-  }
-}
-
-object TaskFailedListener {
-
-  var killerStarted: Boolean = false
-
-  var sparkContextKiller: Thread = _
-
-  val sparkContextShutdownLock = new AnyRef
-
-  private def startedSparkContextKiller(): Unit = this.synchronized {
-    if (!killerStarted) {
-      killerStarted = true
-      // Spark does not allow ListenerThread to shutdown SparkContext so that we have to do it
-      // in a separate thread
-      sparkContextKiller = new Thread() {
-        override def run(): Unit = {
-          LiveListenerBus.withinListenerThread.withValue(false) {
-            sparkContextShutdownLock.synchronized {
-              SparkContext.getActive.foreach(_.stop())
-              killerStarted = false
-              sparkContextShutdownLock.notify()
-            }
-          }
-        }
-      }
-      sparkContextKiller.setDaemon(true)
-      sparkContextKiller.start()
-    }
-  }
-}
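Note: the deleted tracker did two jobs by hand: wait until enough executor cores were registered before starting training, and react to a failed task by either stopping the SparkContext or cancelling the owning job. Barrier scheduling subsumes both: a barrier stage launches only once all of its tasks can run concurrently, and any task failure aborts the whole stage and fails the job with a SparkException while the SparkContext stays alive. A small self-contained sketch of those semantics (assumed local master and app name, not part of this diff):

import org.apache.spark.{BarrierTaskContext, SparkException}
import org.apache.spark.sql.SparkSession

object BarrierFailureSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("barrier-failure-sketch")
      .getOrCreate()
    val sc = spark.sparkContext

    try {
      sc.parallelize(1 to 2, 2).barrier().mapPartitions { iter =>
        val ctx = BarrierTaskContext.get()
        ctx.barrier()                  // global sync point, bounded by spark.barrier.sync.timeout
        if (ctx.partitionId() == 0) {  // mock one worker dying mid-training
          throw new RuntimeException("mocking task failing")
        }
        iter
      }.collect()
    } catch {
      case e: SparkException => println(s"job failed: ${e.getMessage}")
    }

    assert(!sc.isStopped)              // failure is job-scoped; no context-killer thread involved
    spark.stop()
  }
}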
diff --git a/jvm-packages/xgboost4j-spark/src/test/resources/log4j.properties b/jvm-packages/xgboost4j-spark/src/test/resources/log4j.properties
index dcd02d2c878d..900a698ae76c 100644
--- a/jvm-packages/xgboost4j-spark/src/test/resources/log4j.properties
+++ b/jvm-packages/xgboost4j-spark/src/test/resources/log4j.properties
@@ -1 +1 @@
-log4j.logger.org.apache.spark=ERROR
\ No newline at end of file
+log4j.logger.org.apache.spark=ERROR
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ExternalCheckpointManagerSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ExternalCheckpointManagerSuite.scala
index 5ef49431468f..cdcfd76f5bf7 100755
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ExternalCheckpointManagerSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ExternalCheckpointManagerSuite.scala
@@ -1,5 +1,5 @@
 /*
- Copyright (c) 2014 by Contributors
+ Copyright (c) 2014-2022 by Contributors
 
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
@@ -19,7 +19,7 @@ package ml.dmlc.xgboost4j.scala.spark
 import java.io.File
 
 import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, ExternalCheckpointManager, XGBoost => SXGBoost}
-import org.scalatest.{FunSuite, Ignore}
+import org.scalatest.FunSuite
 import org.apache.hadoop.fs.{FileSystem, Path}
 
 class ExternalCheckpointManagerSuite extends FunSuite with TmpFolderPerSuite with PerTest {
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/FeatureSizeValidatingSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/FeatureSizeValidatingSuite.scala
index 7e560827b5b6..79562d1f428b 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/FeatureSizeValidatingSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/FeatureSizeValidatingSuite.scala
@@ -1,5 +1,5 @@
 /*
- Copyright (c) 2014 by Contributors
+ Copyright (c) 2014-2022 by Contributors
 
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
@@ -16,10 +16,8 @@
 
 package ml.dmlc.xgboost4j.scala.spark
 
-import ml.dmlc.xgboost4j.java.XGBoostError
 import org.apache.spark.Partitioner
 import org.apache.spark.ml.feature.VectorAssembler
-import org.apache.spark.sql.SparkSession
 import org.scalatest.FunSuite
 import org.apache.spark.sql.functions._
 
@@ -53,7 +51,7 @@ class FeatureSizeValidatingSuite extends FunSuite with PerTest {
       "objective" -> "binary:logistic",
       "num_round" -> 5, "num_workers" -> 2, "use_external_memory" -> true, "missing" -> 0)
     import DataUtils._
-    val sparkSession = SparkSession.builder().getOrCreate()
+    val sparkSession = ss
     import sparkSession.implicits._
     val repartitioned = sc.parallelize(Synthetic.trainWithDiffFeatureSize, 2)
       .map(lp => (lp.label, lp)).partitionBy(
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/MissingValueHandlingSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/MissingValueHandlingSuite.scala
index 9e23d81b51d1..5863e2ace566 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/MissingValueHandlingSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/MissingValueHandlingSuite.scala
@@ -1,5 +1,5 @@
 /*
- Copyright (c) 2014 by Contributors
+ Copyright (c) 2014-2022 by Contributors
 
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
@@ -16,14 +16,14 @@
 
 package ml.dmlc.xgboost4j.scala.spark
 
-import ml.dmlc.xgboost4j.java.XGBoostError
 import org.apache.spark.ml.feature.VectorAssembler
 import org.apache.spark.ml.linalg.Vectors
 import org.apache.spark.sql.DataFrame
 import org.scalatest.FunSuite
-
 import scala.util.Random
 
+import org.apache.spark.SparkException
+
 class MissingValueHandlingSuite extends FunSuite with PerTest {
   test("dense vectors containing missing value") {
     def buildDenseDataFrame(): DataFrame = {
@@ -113,7 +113,7 @@ class MissingValueHandlingSuite extends FunSuite with PerTest {
     val inputDF = vectorAssembler.transform(testDF).select("features", "label")
     val paramMap = List("eta" -> "1", "max_depth" -> "2",
       "objective" -> "binary:logistic", "missing" -> -1.0f, "num_workers" -> 1).toMap
-    intercept[XGBoostError] {
+    intercept[SparkException] {
       new XGBoostClassifier(paramMap).fit(inputDF)
     }
   }
@@ -140,7 +140,7 @@ class MissingValueHandlingSuite extends FunSuite with PerTest {
     inputDF.show()
     val paramMap = List("eta" -> "1", "max_depth" -> "2",
       "objective" -> "binary:logistic", "missing" -> -1.0f, "num_workers" -> 1).toMap
-    intercept[XGBoostError] {
+    intercept[SparkException] {
       new XGBoostClassifier(paramMap).fit(inputDF)
     }
   }
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ParameterSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ParameterSuite.scala
index 50596c69f7ae..ab1226d2bf2f 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ParameterSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ParameterSuite.scala
@@ -1,5 +1,5 @@
 /*
- Copyright (c) 2014 by Contributors
+ Copyright (c) 2014-2022 by Contributors
 
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
@@ -16,9 +16,9 @@
 
 package ml.dmlc.xgboost4j.scala.spark
 
-import ml.dmlc.xgboost4j.java.XGBoostError
-import org.scalatest.{BeforeAndAfterAll, FunSuite, Ignore}
+import org.scalatest.{BeforeAndAfterAll, FunSuite}
+
+import org.apache.spark.SparkException
 import org.apache.spark.ml.param.ParamMap
 
 class ParameterSuite extends FunSuite with PerTest with BeforeAndAfterAll {
@@ -40,28 +40,16 @@ class ParameterSuite extends FunSuite with PerTest with BeforeAndAfterAll {
     assert(xgbCopy2.MLlib2XGBoostParams("eval_metric").toString === "logloss")
   }
 
-  private def waitForSparkContextShutdown(): Unit = {
-    var totalWaitedTime = 0L
-    while (!ss.sparkContext.isStopped && totalWaitedTime <= 120000) {
-      Thread.sleep(10000)
-      totalWaitedTime += 10000
-    }
-    assert(ss.sparkContext.isStopped === true)
-  }
-
   test("fail training elegantly with unsupported objective function") {
     val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
       "objective" -> "wrong_objective_function", "num_class" -> "6", "num_round" -> 5,
       "num_workers" -> numWorkers)
     val trainingDF = buildDataFrame(MultiClassification.train)
     val xgb = new XGBoostClassifier(paramMap)
-    try {
-      val model = xgb.fit(trainingDF)
-    } catch {
-      case e: Throwable => // swallow anything
-    } finally {
-      waitForSparkContextShutdown()
+    intercept[SparkException] {
+      xgb.fit(trainingDF)
     }
+
   }
 
   test("fail training elegantly with unsupported eval metrics") {
@@ -70,12 +58,8 @@ class ParameterSuite extends FunSuite with PerTest with BeforeAndAfterAll {
       "num_workers" -> numWorkers, "eval_metric" -> "wrong_eval_metrics")
     val trainingDF = buildDataFrame(MultiClassification.train)
     val xgb = new XGBoostClassifier(paramMap)
-    try {
-      val model = xgb.fit(trainingDF)
-    } catch {
-      case e: Throwable => // swallow anything
-    } finally {
-      waitForSparkContextShutdown()
+    intercept[SparkException] {
+      xgb.fit(trainingDF)
     }
   }
 
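Note: the two "fail elegantly" tests above pin down the new user-facing contract: a bad parameter no longer tears down the shared SparkContext while the test polls for shutdown; fit() simply throws a SparkException and the session stays usable. A self-contained sketch of that contract outside ScalaTest (toy two-row DataFrame; the invalid objective is borrowed from the test above):

import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier
import org.apache.spark.SparkException
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession

object FailElegantlySketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[1]").appName("fail-elegantly-sketch").getOrCreate()
    import spark.implicits._

    val trainingDF = Seq(
      (Vectors.dense(1.0, 2.0), 0.0),
      (Vectors.dense(2.0, 1.0), 1.0)).toDF("features", "label")

    val xgb = new XGBoostClassifier(Map(
      "objective" -> "wrong_objective_function",   // deliberately unsupported
      "num_round" -> 5, "num_workers" -> 1))

    try {
      xgb.fit(trainingDF)
    } catch {
      case _: SparkException => ()                 // expected: the job fails, not the JVM
    }
    assert(!spark.sparkContext.isStopped)          // the session survives for a retry
    spark.stop()
  }
}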
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PerTest.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PerTest.scala
index 6148e6dbe8e7..f5775bc4d7bb 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PerTest.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PerTest.scala
@@ -1,5 +1,5 @@
 /*
- Copyright (c) 2014 by Contributors
+ Copyright (c) 2014-2022 by Contributors
 
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
@@ -19,7 +19,7 @@ package ml.dmlc.xgboost4j.scala.spark
 import java.io.File
 
 import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
-import org.apache.spark.{SparkConf, SparkContext, TaskFailedListener}
+import org.apache.spark.SparkContext
 import org.apache.spark.sql._
 import org.scalatest.{BeforeAndAfterEach, FunSuite}
 
@@ -40,32 +40,16 @@ trait PerTest extends BeforeAndAfterEach { self: FunSuite =>
       .appName("XGBoostSuite")
       .config("spark.ui.enabled", false)
       .config("spark.driver.memory", "512m")
+      .config("spark.barrier.sync.timeout", 10)
       .config("spark.task.cpus", 1)
 
   override def beforeEach(): Unit = getOrCreateSession
 
   override def afterEach() {
-    TaskFailedListener.sparkContextShutdownLock.synchronized {
-      if (currentSession != null) {
-        // this synchronization is mostly for the tests involving SparkContext shutdown
-        // for unit test involving the sparkContext shutdown there are two different events sequence
-        // 1. SparkContext killer is executed before afterEach, in this case, before SparkContext
-        //    is fully stopped, afterEach() will block at the following code block
-        // 2. SparkContext killer is executed afterEach, in this case, currentSession.stop() in will
-        //    block to wait for all msgs in ListenerBus get processed. Because currentSession.stop()
-        //    has been called, SparkContext killer will not take effect
-        while (TaskFailedListener.killerStarted) {
-          TaskFailedListener.sparkContextShutdownLock.wait()
-        }
-        currentSession.stop()
-        cleanExternalCache(currentSession.sparkContext.appName)
-        currentSession = null
-      }
-      if (TaskFailedListener.sparkContextKiller != null) {
-        TaskFailedListener.sparkContextKiller.interrupt()
-        TaskFailedListener.sparkContextKiller = null
-      }
-      TaskFailedListener.killerStarted = false
+    if (currentSession != null) {
+      currentSession.stop()
+      cleanExternalCache(currentSession.sparkContext.appName)
+      currentSession = null
     }
   }
 
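Note: spark.barrier.sync.timeout caps, in seconds, how long each BarrierTaskContext.barrier() call waits for the other tasks; the default is 365 days, so without the 10-second override a test whose peer task has already died would effectively hang. A sketch of the builder the trait now produces (values copied from the hunk above; the master is assumed for the sketch):

import org.apache.spark.sql.SparkSession

val sessionBuilder = SparkSession.builder()
  .master("local[2]")                         // assumed; the suite sets its own master
  .appName("XGBoostSuite")
  .config("spark.ui.enabled", false)
  .config("spark.driver.memory", "512m")
  .config("spark.barrier.sync.timeout", 10)   // seconds; keeps a broken barrier sync from hanging
  .config("spark.task.cpus", 1)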
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PersistenceSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PersistenceSuite.scala
index a1732c7f7e1b..93b7554017a0 100755
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PersistenceSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PersistenceSuite.scala
@@ -1,5 +1,5 @@
 /*
- Copyright (c) 2014,2021 by Contributors
+ Copyright (c) 2014-2022 by Contributors
 
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostConfigureSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostConfigureSuite.scala
index 4b3d8d7c936a..7d588d97ce0a 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostConfigureSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostConfigureSuite.scala
@@ -1,5 +1,5 @@
 /*
- Copyright (c) 2014 by Contributors
+ Copyright (c) 2014-2022 by Contributors
 
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
@@ -16,10 +16,8 @@
 
 package ml.dmlc.xgboost4j.scala.spark
 
-import ml.dmlc.xgboost4j.java.Rabit
 import ml.dmlc.xgboost4j.scala.{Booster, DMatrix}
-
-import scala.collection.JavaConverters._
 import org.apache.spark.sql._
 import org.scalatest.FunSuite
 
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala
index 875960ed667c..cd13e4b6cafd 100755
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala
@@ -1,5 +1,5 @@
 /*
- Copyright (c) 2014 by Contributors
+ Copyright (c) 2014-2022 by Contributors
 
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
@@ -16,13 +16,12 @@
 
 package ml.dmlc.xgboost4j.scala.spark
 
-import ml.dmlc.xgboost4j.java.XGBoostError
 import scala.util.Random
 
 import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
 import ml.dmlc.xgboost4j.scala.DMatrix
 
-import org.apache.spark.TaskContext
+import org.apache.spark.{SparkException, TaskContext}
 import org.scalatest.FunSuite
 
 import org.apache.spark.ml.feature.VectorAssembler
@@ -375,13 +374,14 @@ class XGBoostGeneralSuite extends FunSuite with TmpFolderPerSuite with PerTest {
 
   test("throw exception for empty partition in trainingset") {
     val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
-      "objective" -> "multi:softmax", "num_class" -> "2", "num_round" -> 5,
-      "num_workers" -> numWorkers, "tree_method" -> "auto")
+      "objective" -> "binary:logistic", "num_class" -> "2", "num_round" -> 5,
+      "num_workers" -> numWorkers, "tree_method" -> "auto", "allow_non_zero_for_missing" -> true)
     // The Dmatrix will be empty
-    val trainingDF = buildDataFrame(Seq(XGBLabeledPoint(1.0f, 1, Array(), Array())))
+    val trainingDF = buildDataFrame(Seq(XGBLabeledPoint(1.0f, 4,
+      Array(0, 1, 2, 3), Array(0, 1, 2, 3))))
     val xgb = new XGBoostClassifier(paramMap)
-    intercept[XGBoostError] {
-      val model = xgb.fit(trainingDF)
+    intercept[SparkException] {
+      xgb.fit(trainingDF)
     }
   }
 
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRabitRegressionSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRabitRegressionSuite.scala
index 7e2cbb6d537f..00a29681ca73 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRabitRegressionSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRabitRegressionSuite.scala
@@ -1,5 +1,5 @@
 /*
- Copyright (c) 2014 by Contributors
+ Copyright (c) 2014-2022 by Contributors
 
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
@@ -16,14 +16,15 @@
 
 package ml.dmlc.xgboost4j.scala.spark
 
-import ml.dmlc.xgboost4j.java.{Rabit, XGBoostError}
-import ml.dmlc.xgboost4j.scala.{Booster, DMatrix}
-import org.apache.spark.TaskFailedListener
-import org.apache.spark.SparkException
+import ml.dmlc.xgboost4j.java.Rabit
+import ml.dmlc.xgboost4j.scala.Booster
 import scala.collection.JavaConverters._
+
 import org.apache.spark.sql._
 import org.scalatest.FunSuite
 
+import org.apache.spark.SparkException
+
 class XGBoostRabitRegressionSuite extends FunSuite with PerTest {
   val predictionErrorMin = 0.00001f
   val maxFailure = 2;
@@ -33,15 +34,6 @@ class XGBoostRabitRegressionSuite extends FunSuite with PerTest {
       .config("spark.kryo.classesToRegister", classOf[Booster].getName)
       .master(s"local[${numWorkers},${maxFailure}]")
 
-  private def waitAndCheckSparkShutdown(waitMiliSec: Int): Boolean = {
-    var totalWaitedTime = 0L
-    while (!ss.sparkContext.isStopped && totalWaitedTime <= waitMiliSec) {
-      Thread.sleep(10)
-      totalWaitedTime += 10
-    }
-    return ss.sparkContext.isStopped
-  }
-
   test("test classification prediction parity w/o ring reduce") {
     val training = buildDataFrame(Classification.train)
     val testDF = buildDataFrame(Classification.test)
@@ -91,14 +83,11 @@ class XGBoostRabitRegressionSuite extends FunSuite with PerTest {
   }
 
   test("test rabit timeout fail handle") {
-    // disable spark kill listener to verify if rabit_timeout take effect and kill tasks
-    TaskFailedListener.killerStarted = true
-
     val training = buildDataFrame(Classification.train)
     // mock rank 0 failure during 8th allreduce synchronization
     Rabit.mockList = Array("0,8,0,0").toList.asJava
 
-    try {
+    intercept[SparkException] {
       new XGBoostClassifier(Map(
         "eta" -> "0.1",
         "max_depth" -> "10",
@@ -108,39 +97,7 @@ class XGBoostRabitRegressionSuite extends FunSuite with PerTest {
         "num_workers" -> numWorkers,
         "rabit_timeout" -> 0))
         .fit(training)
-    } catch {
-      case e: Throwable => // swallow anything
-    } finally {
-      // assume all tasks throw exception almost same time
-      // 100ms should be enough to exhaust all retries
-      assert(waitAndCheckSparkShutdown(100) == true)
-      TaskFailedListener.killerStarted = false
     }
   }
 
-  test("test SparkContext should not be killed ") {
-    cancel("For some reason, sparkContext can't cancel the job locally in the CI env," +
-      "which will be resolved when introducing barrier mode")
-    val training = buildDataFrame(Classification.train)
-    // mock rank 0 failure during 8th allreduce synchronization
-    Rabit.mockList = Array("0,8,0,0").toList.asJava
-
-    try {
-      new XGBoostClassifier(Map(
-        "eta" -> "0.1",
-        "max_depth" -> "10",
-        "verbosity" -> "1",
-        "objective" -> "binary:logistic",
-        "num_round" -> 5,
-        "num_workers" -> numWorkers,
-        "kill_spark_context_on_worker_failure" -> false,
-        "rabit_timeout" -> 0))
-        .fit(training)
-    } catch {
-      case e: Throwable => // swallow anything
-    } finally {
-      // wait 3s to check if SparkContext is killed
-      assert(waitAndCheckSparkShutdown(3000) == false)
-    }
-  }
 }
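Note: the rabit-timeout test keeps the fault injection (Rabit.mockList kills rank 0 at the 8th allreduce) but now only asserts that fit() fails with a SparkException; the deleted "SparkContext should not be killed" test is subsumed, since job-scoped failure is the only remaining behavior. A hypothetical follow-up inside the same suite would show the context surviving for another run (condensed from the hunk above; the trailing cleanup and assertion are an assumption, not in the diff):

Rabit.mockList = Array("0,8,0,0").toList.asJava    // mock rank 0 dying at the 8th allreduce
intercept[SparkException] {
  new XGBoostClassifier(Map(
    "objective" -> "binary:logistic",
    "num_round" -> 5,
    "num_workers" -> numWorkers,
    "rabit_timeout" -> 0)).fit(training)
}

Rabit.mockList = List.empty[String].asJava         // clear the fault injection
assert(!ss.sparkContext.isStopped)                 // same SparkContext, ready to train again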
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressorSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressorSuite.scala
index e427c17e31a5..bd104f6c7987 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressorSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressorSuite.scala
@@ -21,7 +21,6 @@ import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
 import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.{DataFrame, Row}
-import org.apache.spark.sql.types._
 import org.scalatest.FunSuite
 
 import org.apache.spark.ml.feature.VectorAssembler
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/org/apache/spark/SparkParallelismTrackerSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/org/apache/spark/SparkParallelismTrackerSuite.scala
deleted file mode 100644
index cb8fa579476a..000000000000
--- a/jvm-packages/xgboost4j-spark/src/test/scala/org/apache/spark/SparkParallelismTrackerSuite.scala
+++ /dev/null
@@ -1,151 +0,0 @@
-/*
- Copyright (c) 2014 by Contributors
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- */
-
-package org.apache.spark
-
-import org.scalatest.FunSuite
-import _root_.ml.dmlc.xgboost4j.scala.spark.PerTest
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.SparkSession
-
-import scala.math.min
-
-class SparkParallelismTrackerSuite extends FunSuite with PerTest {
-
-  val numParallelism: Int = min(Runtime.getRuntime.availableProcessors(), 4)
-
-  override protected def sparkSessionBuilder: SparkSession.Builder = SparkSession.builder()
-    .master(s"local[${numParallelism}]")
-    .appName("XGBoostSuite")
-    .config("spark.ui.enabled", true)
-    .config("spark.driver.memory", "512m")
-    .config("spark.task.cpus", 1)
-
-  private def waitAndCheckSparkShutdown(waitMiliSec: Int): Boolean = {
-    var totalWaitedTime = 0L
-    while (!ss.sparkContext.isStopped && totalWaitedTime <= waitMiliSec) {
-      Thread.sleep(100)
-      totalWaitedTime += 100
-    }
-    ss.sparkContext.isStopped
-  }
-
-  test("tracker should not affect execution result when timeout is not larger than 0") {
-    val nWorkers = numParallelism
-    val rdd: RDD[Int] = sc.parallelize(1 to nWorkers)
-    val tracker = new SparkParallelismTracker(sc, 10000, nWorkers)
-    val disabledTracker = new SparkParallelismTracker(sc, 0, nWorkers)
-    assert(tracker.execute(rdd.sum()) == rdd.sum())
-    assert(disabledTracker.execute(rdd.sum()) == rdd.sum())
-  }
-
-  test("tracker should throw exception if parallelism is not sufficient") {
-    val nWorkers = numParallelism * 3
-    val rdd: RDD[Int] = sc.parallelize(1 to nWorkers)
-    val tracker = new SparkParallelismTracker(sc, 1000, nWorkers)
-    intercept[IllegalStateException] {
-      tracker.execute {
-        rdd.map { i =>
-          // Test interruption
-          Thread.sleep(Long.MaxValue)
-          i
-        }.sum()
-      }
-    }
-  }
-
-  test("tracker should throw exception if parallelism is not sufficient with" +
-    " spark.task.cpus larger than 1") {
-    sc.conf.set("spark.task.cpus", "2")
-    val nWorkers = numParallelism
-    val rdd: RDD[Int] = sc.parallelize(1 to nWorkers)
-    val tracker = new SparkParallelismTracker(sc, 1000, nWorkers)
-    intercept[IllegalStateException] {
-      tracker.execute {
-        rdd.map { i =>
-          // Test interruption
-          Thread.sleep(Long.MaxValue)
-          i
-        }.sum()
-      }
-    }
-  }
-
-  test("tracker should not kill SparkContext when killSparkContextOnWorkerFailure=false") {
-    val nWorkers = numParallelism
-    val tracker = new SparkParallelismTracker(sc, 0, nWorkers, false)
-    val rdd: RDD[Int] = sc.parallelize(1 to nWorkers, nWorkers)
-    try {
-      tracker.execute {
-        rdd.map { i =>
-          val partitionId = TaskContext.get().partitionId()
-          if (partitionId == 0) {
-            throw new RuntimeException("mocking task failing")
-          }
-          i
-        }.sum()
-      }
-    } catch {
-      case e: Exception => // catch the exception
-    } finally {
-      // wait 3s to check if SparkContext is killed
-      assert(waitAndCheckSparkShutdown(3000) == false)
-    }
-  }
-
-  test("tracker should cancel the correct job when killSparkContextOnWorkerFailure=false") {
-    val nWorkers = 2
-    val tracker = new SparkParallelismTracker(sc, 0, nWorkers, false)
-    val rdd: RDD[Int] = sc.parallelize(1 to 10, nWorkers)
-    val thread = new TestThread(sc)
-    thread.start()
-    try {
-      tracker.execute {
-        rdd.map { i =>
-          Thread.sleep(100)
-          val partitionId = TaskContext.get().partitionId()
-          if (partitionId == 0) {
-            throw new RuntimeException("mocking task failing")
-          }
-          i
-        }.sum()
-      }
-    } catch {
-      case e: Exception => // catch the exception
-    } finally {
-      thread.join(8000)
-      // wait 3s to check if SparkContext is killed
-      assert(waitAndCheckSparkShutdown(3000) == false)
-    }
-  }
-
-  private[this] class TestThread(sc: SparkContext) extends Thread {
-    override def run(): Unit = {
-      var sum: Double = 0.0f
-      try {
-        val rdd = sc.parallelize(1 to 4, 2)
-        sum = rdd.mapPartitions(iter => {
-          // sleep 2s to ensure task is alive when cancelling other jobs
-          Thread.sleep(2000)
-          iter
-        }).sum()
-      } finally {
-        // get the correct result
-        assert(sum.toInt == 10)
-      }
-    }
-  }
-}