/*
 Copyright (c) 2014 by Contributors

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */
package ml.dmlc.xgboost4j.scala.spark

import scala.collection.JavaConverters._

import ml.dmlc.xgboost4j.java.{Rabit, XGBoostError}
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix}
import org.apache.spark.{SparkException, TaskFailedListener}
import org.apache.spark.sql._
import org.scalatest.FunSuite
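
/**
 * Verifies Rabit-related behavior of distributed XGBoost training on Spark:
 * prediction parity with and without ring-based allreduce, and whether the
 * SparkContext is (or is not) shut down when Rabit workers fail.
 */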
class XGBoostRabitRegressionSuite extends FunSuite with PerTest {
  val predictionErrorMin = 0.00001f
  val maxFailure = 2

  override def sparkSessionBuilder: SparkSession.Builder = super.sparkSessionBuilder
    .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
    .config("spark.kryo.classesToRegister", classOf[Booster].getName)
    // local[N, M] runs locally with N threads and spark.task.maxFailures = M
    .master(s"local[${numWorkers},${maxFailure}]")
  private def waitAndCheckSparkShutdown(waitMilliSec: Int): Boolean = {
    var totalWaitedTime = 0L
    while (!ss.sparkContext.isStopped && totalWaitedTime <= waitMilliSec) {
      Thread.sleep(10)
      totalWaitedTime += 10
    }
    ss.sparkContext.isStopped
  }
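
  // Training twice with the same parameters, once with "rabit_ring_reduce_threshold"
  // set to 1 (which should route every allreduce through the ring-based
  // implementation), must yield matching predictions.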
test("test classification prediction parity w/o ring reduce") {
val training = buildDataFrame(Classification.train)
val testDF = buildDataFrame(Classification.test)
val xgbSettings = Map("eta" -> "1", "max_depth" -> "2", "verbosity" -> "1",
"objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> numWorkers)
val model1 = new XGBoostClassifier(xgbSettings).fit(training)
val prediction1 = model1.transform(testDF).select("prediction").collect()
val model2 = new XGBoostClassifier(xgbSettings ++ Map("rabit_ring_reduce_threshold" -> 1))
.fit(training)
assert(Rabit.rabitEnvs.asScala.size > 3)
Rabit.rabitEnvs.asScala.foreach( item => {
if (item._1.toString == "rabit_reduce_ring_mincount") assert(item._2 == "1")
})
val prediction2 = model2.transform(testDF).select("prediction").collect()
// check parity w/o rabit cache
prediction1.zip(prediction2).foreach { case (Row(p1: Double), Row(p2: Double)) =>
assert(p1 == p2)
}
}
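
  // Unlike the classification test, parity here is checked within
  // predictionErrorMin rather than exactly, tolerating differences in
  // floating-point accumulation order between allreduce implementations.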
test("test regression prediction parity w/o ring reduce") {
val training = buildDataFrame(Regression.train)
val testDF = buildDataFrame(Regression.test)
val xgbSettings = Map("eta" -> "1", "max_depth" -> "2", "verbosity" -> "1",
"objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> numWorkers)
val model1 = new XGBoostRegressor(xgbSettings).fit(training)
val prediction1 = model1.transform(testDF).select("prediction").collect()
val model2 = new XGBoostRegressor(xgbSettings ++ Map("rabit_ring_reduce_threshold" -> 1)
).fit(training)
assert(Rabit.rabitEnvs.asScala.size > 3)
Rabit.rabitEnvs.asScala.foreach( item => {
if (item._1.toString == "rabit_reduce_ring_mincount") assert(item._2 == "1")
})
// check the equality of single instance prediction
val prediction2 = model2.transform(testDF).select("prediction").collect()
// check parity w/o rabit cache
prediction1.zip(prediction2).foreach { case (Row(p1: Double), Row(p2: Double)) =>
assert(math.abs(p1 - p2) < predictionErrorMin)
}
}
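
  // Rabit.mockList entries follow rabit's mock-failure hook, assumed here to be
  // of the form "rank,version,seq,ndeath"; "0,8,0,0" injects a failure for
  // rank 0 at its 8th collective call.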
test("test rabit timeout fail handle") {
// disable spark kill listener to verify if rabit_timeout take effect and kill tasks
TaskFailedListener.killerStarted = true
val training = buildDataFrame(Classification.train)
// mock rank 0 failure during 8th allreduce synchronization
Rabit.mockList = Array("0,8,0,0").toList.asJava
try {
new XGBoostClassifier(Map(
"eta" -> "0.1",
"max_depth" -> "10",
"verbosity" -> "1",
"objective" -> "binary:logistic",
"num_round" -> 5,
"num_workers" -> numWorkers,
"rabit_timeout" -> 0))
.fit(training)
} catch {
case e: Throwable => // swallow anything
} finally {
// assume all tasks throw exception almost same time
// 100ms should be enough to exhaust all retries
assert(waitAndCheckSparkShutdown(100) == true)
}
}
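
  // With "kill_spark_context_on_worker_failure" set to false, the same injected
  // worker failure must leave the SparkContext running.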
test("test SparkContext should not be killed ") {
val training = buildDataFrame(Classification.train)
// mock rank 0 failure during 8th allreduce synchronization
Rabit.mockList = Array("0,8,0,0").toList.asJava
try {
new XGBoostClassifier(Map(
"eta" -> "0.1",
"max_depth" -> "10",
"verbosity" -> "1",
"objective" -> "binary:logistic",
"num_round" -> 5,
"num_workers" -> numWorkers,
"kill_spark_context_on_worker_failure" -> false,
"rabit_timeout" -> 0))
.fit(training)
} catch {
case e: Throwable => // swallow anything
} finally {
// wait 3s to check if SparkContext is killed
assert(waitAndCheckSparkShutdown(3000) == false)
}
}
}