/
PerTest.scala
111 lines (90 loc) · 3.5 KB
/
PerTest.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark
import java.io.File
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
import org.apache.spark.{SparkConf, SparkContext, TaskFailedListener}
import org.apache.spark.sql._
import org.scalatest.{BeforeAndAfterEach, FunSuite}
import scala.math.min
import scala.util.Random
/**
 * Test fixture that gives each test in a [[FunSuite]] its own local [[SparkSession]].
 *
 * The session is created on demand (on first access of `ss`/`sc`, or eagerly in
 * `beforeEach`) and stopped in `afterEach`, so Spark state is not shared between tests.
 */
trait PerTest extends BeforeAndAfterEach { self: FunSuite =>

  /** Number of local Spark worker threads: the available processors, capped at 4. */
  protected val numWorkers: Int = min(Runtime.getRuntime.availableProcessors(), 4)

  // Session of the currently running test; null between tests. Accessed under `synchronized`.
  @transient private var currentSession: SparkSession = _

  /** The session for the current test, created if it does not exist yet. */
  def ss: SparkSession = getOrCreateSession

  /** The SparkContext of `ss`; implicit so helpers that take an implicit SparkContext can use it. */
  implicit def sc: SparkContext = ss.sparkContext

  /**
   * Builder used to create each test's session. Subclasses may override it to change
   * the configuration. The default runs locally with `numWorkers` threads, the UI
   * disabled, 512m driver memory and one CPU per task.
   */
  protected def sparkSessionBuilder: SparkSession.Builder = SparkSession.builder()
    .master(s"local[$numWorkers]")
    .appName("XGBoostSuite")
    .config("spark.ui.enabled", false)
    .config("spark.driver.memory", "512m")
    .config("spark.task.cpus", 1)

  /** Creates the session before each test, then runs any stacked `beforeEach`. */
  override def beforeEach(): Unit = {
    getOrCreateSession
    // Required for BeforeAndAfterEach to remain stackable with other fixtures.
    super.beforeEach()
  }

  /**
   * Stops the current session, deletes its cache files and resets the
   * `TaskFailedListener` flags. The flag reset and any stacked `afterEach` run
   * even if stopping the session throws.
   */
  override def afterEach(): Unit = {
    try {
      synchronized {
        try stopCurrentSession()
        finally {
          // These flags live on the shared TaskFailedListener object; reset them so
          // state from this test cannot leak into the next one.
          TaskFailedListener.killerStarted = false
          TaskFailedListener.cancelJobStarted = false
        }
      }
    } finally {
      super.afterEach()
    }
  }

  /** Stops the current session, if any, then deletes files named after its app. Caller holds the lock. */
  private def stopCurrentSession(): Unit = {
    val session = currentSession
    if (session != null) {
      // Drop the reference before stopping, so a failing stop() cannot hand a
      // half-stopped session to the next test via getOrCreateSession.
      currentSession = null
      val appName = session.sparkContext.appName
      session.stop()
      cleanExternalCache(appName)
    }
  }

  /** Returns the current session, creating it (with log level ERROR) if none exists. */
  private def getOrCreateSession: SparkSession = synchronized {
    if (currentSession == null) {
      currentSession = sparkSessionBuilder.getOrCreate()
      currentSession.sparkContext.setLogLevel("ERROR")
    }
    currentSession
  }

  /**
   * Deletes files in the working directory whose names start with `prefix`.
   * NOTE(review): presumably these are external-memory cache files written by tests.
   * Best effort: a failed directory listing or an undeletable file is ignored.
   */
  private def cleanExternalCache(prefix: String): Unit = {
    // File.listFiles returns null (not an empty array) when the listing fails.
    val files = Option(new File(".").listFiles()).getOrElse(Array.empty[File])
    files.filter(_.getName.startsWith(prefix)).foreach(_.delete())
  }

  /**
   * Builds a DataFrame with columns `id`, `label`, `features`, where `id` is each
   * point's index in `labeledPoints`.
   *
   * @param labeledPoints rows to load, in order
   * @param numPartitions number of RDD partitions (defaults to `numWorkers`)
   */
  protected def buildDataFrame(
      labeledPoints: Seq[XGBLabeledPoint],
      numPartitions: Int = numWorkers): DataFrame = {
    // NOTE(review): presumably DataUtils supplies the `features` accessor used below — confirm.
    import DataUtils._
    val rows = labeledPoints.iterator.zipWithIndex.map { case (point, id) =>
      (id, point.label, point.features)
    }.toList
    ss.createDataFrame(sc.parallelize(rows, numPartitions))
      .toDF("id", "label", "features")
  }

  /**
   * Same as `buildDataFrame`, but rows are randomly permuted within each partition.
   * The permutation is unseeded, so the row order differs between runs.
   *
   * @param labeledPoints rows to load
   * @param numPartitions number of RDD partitions (defaults to `numWorkers`)
   */
  protected def buildDataFrameWithRandSort(
      labeledPoints: Seq[XGBLabeledPoint],
      numPartitions: Int = numWorkers): DataFrame = {
    val df = buildDataFrame(labeledPoints, numPartitions)
    val shuffled = df.rdd.mapPartitions { rows => Random.shuffle(rows.toList).iterator }
    ss.createDataFrame(shuffled, df.schema)
  }

  /**
   * Builds a DataFrame with columns `id`, `label`, `features`, `group`, where `id` is
   * each point's index in `labeledPoints` and `group` is the point's `group` field.
   *
   * @param labeledPoints rows to load, in order
   * @param numPartitions number of RDD partitions (defaults to `numWorkers`)
   */
  protected def buildDataFrameWithGroup(
      labeledPoints: Seq[XGBLabeledPoint],
      numPartitions: Int = numWorkers): DataFrame = {
    // NOTE(review): presumably DataUtils supplies the `features` accessor used below — confirm.
    import DataUtils._
    val rows = labeledPoints.iterator.zipWithIndex.map { case (point, id) =>
      (id, point.label, point.features, point.group)
    }.toList
    ss.createDataFrame(sc.parallelize(rows, numPartitions))
      .toDF("id", "label", "features", "group")
  }
}