Rabit update. #5978

Merged: 3 commits, Aug 11, 2020
@@ -381,7 +381,6 @@ object XGBoost extends Serializable {
val attempt = TaskContext.get().attemptNumber.toString
rabitEnv.put("DMLC_TASK_ID", taskId)
rabitEnv.put("DMLC_NUM_ATTEMPT", attempt)
- rabitEnv.put("DMLC_WORKER_STOP_PROCESS_ON_ERROR", "false")
Member

Well, as I commented in dmlc/rabit#143, this is something we need in order to implement the current error-handling strategy.

Regarding #4826, what's your current plan? I think we could replace "kill SparkContext" with "kill the corresponding job" in the case of a task failure in the JVM layer.

Member Author

@trivialfis, Aug 4, 2020

@CodingCat It's still in the early planning stage, and I'm obviously not the right person to talk about Spark, so my thoughts are largely based on experience with dask. For Spark-specific things I usually just forward them to @wbo4958 ;-). I have a few plans and haven't decided which one to go with yet. They can be roughly listed as follows:

  • We drop support for single-node recovery and rely on Spark to recover from training failures. My goal is to have XGBoost fail gracefully without interrupting others. On dask, a Python exception should do it. If we do this, I want to replace rabit with a more mature MPI solution.

"kill the corresponding job" in the case of a task failure in JVM layer

Sounds good to me. Anything that doesn't interrupt the others works.

  • We continue investing in support for single-node recovery, but use the call sequence number as the cache key instead. In the earlier doc from @chenqin, the bootstrapping calls can be out of order. But so far I believe they are strictly ordered in a single-threaded application (not counting the OpenMP parallelization that happens inside XGBoost). My major concern here is multi-threaded applications. It's not common for users to create DMatrix on their own with the dask interface, but I have seen people try to access DMatrix themselves, bypassing the dask data loading procedure to avoid extra memory overhead and data balancing. If they do it asynchronously, then we need serious tests of rabit's thread-safety guarantee (which doesn't exist).
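
To make the second option more concrete, here is a minimal sketch of caching collective results by call sequence number. It is only an illustration under stated assumptions: `ResultCache`, `Worker`, and `collective` are hypothetical names, not rabit APIs, and the cache is an in-process map standing in for state that surviving peers would hold. It also assumes strictly ordered, single-threaded calls, which is exactly the condition questioned above.

```scala
import scala.collection.mutable

// Hypothetical sketch; none of these names exist in rabit.
object SeqNoCacheSketch {
  // Stand-in for results kept by surviving peers, keyed by call sequence number.
  final class ResultCache[A] {
    private val bySeqNo = mutable.Map.empty[Long, A]

    // Returns the cached result for seqNo, computing and storing it on a miss.
    def getOrCompute(seqNo: Long)(compute: => A): A =
      bySeqNo.getOrElseUpdate(seqNo, compute)
  }

  // A worker numbers its collective calls 0, 1, 2, ... in program order.
  // Not thread-safe: concurrent callers would race on nextSeqNo.
  final class Worker[A](cache: ResultCache[A]) {
    private var nextSeqNo = 0L

    def collective(compute: => A): A = {
      val result = cache.getOrCompute(nextSeqNo)(compute)
      nextSeqNo += 1
      result
    }
  }

  // Returns the values seen by a restarted worker and the number of real computations.
  def demo(): (Seq[Int], Int) = {
    var computed = 0
    val cache = new ResultCache[Int]

    val original = new Worker(cache)
    original.collective { computed += 1; 10 }
    original.collective { computed += 1; 20 }

    // The restarted worker begins again at sequence number 0 and replays from the cache.
    val restarted = new Worker(cache)
    val replayed = Seq(
      restarted.collective { computed += 1; -1 }, // cache hit, block not evaluated
      restarted.collective { computed += 1; -1 }, // cache hit, block not evaluated
      restarted.collective { computed += 1; 30 }  // first call past the failure point
    )
    (replayed, computed)
  }

  def main(args: Array[String]): Unit = println(demo())
}
```

If two threads issued collective calls concurrently, the sequence numbers would no longer identify the same logical call across restarts, so the replay would return wrong results.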

Member Author

@trivialfis, Aug 4, 2020

Either way, the exit(-1) call has to go, as we need a proper exception.
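
A rough sketch of what "a proper exception" buys on the JVM side, assuming nothing about the real XGBoost4J-Spark code: `WorkerFailedException`, `runTask`, and `runJob` are hypothetical stand-ins, and plain `scala.util.Try` replaces Spark's job handling. The point is only that a thrown exception can be contained by the caller of one job, whereas a process exit cannot.

```scala
import scala.util.Try

// Hypothetical exception type for this sketch; not part of XGBoost or rabit.
final class WorkerFailedException(val taskId: Int, cause: Throwable)
  extends RuntimeException(s"worker $taskId failed: ${cause.getMessage}", cause)

object GracefulFailureSketch {
  // Stand-in for one training task. It throws instead of exiting the process.
  def runTask(taskId: Int, failOn: Set[Int]): String =
    try {
      if (failOn.contains(taskId)) {
        throw new IllegalStateException("simulated native-layer error")
      }
      s"task-$taskId-ok"
    } catch {
      case e: IllegalStateException => throw new WorkerFailedException(taskId, e)
    }

  // Stand-in for "kill the corresponding job": the first failing task
  // fails this job only, and the caller decides what to do next.
  def runJob(taskIds: Seq[Int], failOn: Set[Int]): Try[Seq[String]] =
    Try(taskIds.map(id => runTask(id, failOn)))

  def main(args: Array[String]): Unit = {
    val failing = runJob(Seq(0, 1, 2), failOn = Set(1))
    val healthy = runJob(Seq(0, 1, 2), failOn = Set.empty)
    println(failing.isFailure) // the failing job is reported as a Failure
    println(healthy)           // the other job still runs in the same process
  }
}
```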

val numRounds = xgbExecutionParam.numRounds
val makeCheckpoint = xgbExecutionParam.checkpointParam.isDefined && taskId.toInt == 0
try {
@@ -997,4 +996,3 @@ private[spark] class LabeledPointGroupIterator(base: Iterator[XGBLabeledPoint])
group
}
}

@@ -303,8 +303,7 @@ class XGBoostClassificationModel private[ml](
private val batchIterImpl = rowIterator.grouped($(inferBatchSize)).flatMap { batchRow =>
if (batchCnt == 0) {
val rabitEnv = Array(
- "DMLC_TASK_ID" -> TaskContext.getPartitionId().toString,
- "DMLC_WORKER_STOP_PROCESS_ON_ERROR" -> "false").toMap
+ "DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
Rabit.init(rabitEnv.asJava)
}

@@ -281,8 +281,7 @@ class XGBoostRegressionModel private[ml] (
private val batchIterImpl = rowIterator.grouped($(inferBatchSize)).flatMap { batchRow =>
if (batchCnt == 0) {
val rabitEnv = Array(
- "DMLC_TASK_ID" -> TaskContext.getPartitionId().toString,
- "DMLC_WORKER_STOP_PROCESS_ON_ERROR" -> "false").toMap
+ "DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
Rabit.init(rabitEnv.asJava)
}
