[jvm-package] remove the coalesce in barrier mode (dmlc#7846)
wbo4958 committed Apr 28, 2022
1 parent 11271e0 commit c24acff
Showing 3 changed files with 8 additions and 17 deletions.
@@ -69,7 +69,7 @@ public void testBooster() throws XGBoostError {
         .hasHeader().build();
 
     int maxBin = 16;
-    int round = 100;
+    int round = 10;
     //set params
     Map<String, Object> paramMap = new HashMap<String, Object>() {
       {
@@ -407,14 +407,9 @@ object GpuPreXGBoost extends PreXGBoostProvider {
   }
 
   private def repartitionInputData(dataFrame: DataFrame, nWorkers: Int): DataFrame = {
-    // We can't check dataFrame.rdd.getNumPartitions == nWorkers here, since dataFrame.rdd is
-    // a lazy variable. If we call it here, we will not directly extract RDD[Table] again,
-    // instead, we will involve Columnar -> Row -> Columnar and decrease the performance
-    if (nWorkers == 1) {
-      dataFrame.coalesce(1)
-    } else {
-      dataFrame.repartition(nWorkers)
-    }
+    // We can't use any coalesce operation here, since barrier mode checks the
+    // RDD chain pattern and does not allow coalesce.
+    dataFrame.repartition(nWorkers)
   }
 
   private def repartitionForGroup(
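For context on the new comment: Spark's barrier execution mode validates the RDD chain of a barrier stage and rejects narrow dependencies whose parent partition counts differ from the stage's task count, which is exactly the shape coalesce() produces; repartition() instead inserts a shuffle, so the barrier stage starts at the shuffle boundary with the requested partition count. The sketch below is illustrative only, not part of the commit, and assumes a local SparkSession with plain RDDs rather than the GPU DataFrame path used here.

import org.apache.spark.sql.SparkSession

object BarrierVsCoalesceSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("barrier-vs-coalesce-sketch")
      .getOrCreate()

    val rdd = spark.sparkContext.parallelize(1 to 100, numSlices = 8)

    // coalesce(1) keeps a narrow dependency on all 8 parent partitions, so the whole
    // chain sits in one barrier stage whose task count (1) differs from its ancestors'
    // partition count (8); Spark rejects such a job with an unsupported-RDD-chain
    // error, which is why this line is left commented out:
    // rdd.coalesce(1).barrier().mapPartitions(iter => iter).collect()

    // repartition(2) adds a shuffle, so the barrier stage begins at the shuffle
    // boundary with exactly 2 partitions and the job is accepted.
    val out = rdd.repartition(2).barrier().mapPartitions(iter => iter).collect()
    println(s"collected ${out.length} elements from the barrier stage")

    spark.stop()
  }
}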
@@ -39,13 +39,8 @@ trait GpuTestSuite extends FunSuite with TmpFolderSuite {
 
   def enableCsvConf(): SparkConf = {
     new SparkConf()
-      .set(RapidsConf.ENABLE_READ_CSV_DATES.key, "true")
-      .set(RapidsConf.ENABLE_READ_CSV_BYTES.key, "true")
-      .set(RapidsConf.ENABLE_READ_CSV_SHORTS.key, "true")
-      .set(RapidsConf.ENABLE_READ_CSV_INTEGERS.key, "true")
-      .set(RapidsConf.ENABLE_READ_CSV_LONGS.key, "true")
-      .set(RapidsConf.ENABLE_READ_CSV_FLOATS.key, "true")
-      .set(RapidsConf.ENABLE_READ_CSV_DOUBLES.key, "true")
+      .set("spark.rapids.sql.csv.read.float.enabled", "true")
+      .set("spark.rapids.sql.csv.read.double.enabled", "true")
   }
 
   def withGpuSparkSession[U](conf: SparkConf = new SparkConf())(f: SparkSession => U): U = {
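The enableCsvConf() helper above now sets the spark-rapids CSV options by their string keys rather than through RapidsConf constants. A hypothetical usage sketch, assuming a test inside this trait; the file path and assertion are placeholders, and only enableCsvConf and withGpuSparkSession come from the code shown here:

// Hypothetical test body; the CSV path is a placeholder, not from the commit.
withGpuSparkSession(enableCsvConf()) { spark =>
  val df = spark.read
    .option("header", "true")
    .csv("src/test/resources/some-train-data.csv")
  assert(df.count() > 0)
}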
@@ -246,12 +241,13 @@ object SparkSessionHolder extends Logging {
     Locale.setDefault(Locale.US)
 
     val builder = SparkSession.builder()
-      .master("local[1]")
+      .master("local[2]")
       .config("spark.sql.adaptive.enabled", "false")
       .config("spark.rapids.sql.enabled", "false")
       .config("spark.rapids.sql.test.enabled", "false")
       .config("spark.plugins", "com.nvidia.spark.SQLPlugin")
       .config("spark.rapids.memory.gpu.pooling.enabled", "false") // Disable RMM for unit tests.
+      .config("spark.sql.files.maxPartitionBytes", "1000")
       .appName("XGBoost4j-Spark-Gpu unit test")
 
     builder.getOrCreate()
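Rough arithmetic, not from the commit, on why the tiny spark.sql.files.maxPartitionBytes value is useful in these tests: it splits even a small input file into many scan partitions, so repartitionInputData above has a realistic partition count to normalize, and local[2] gives the two resulting barrier tasks a slot each. The file size below is an assumed example.

// Assumed numbers for illustration only.
val fileBytes = 64 * 1024           // a hypothetical 64 KB test CSV
val maxPartitionBytes = 1000        // the value configured above
val scanPartitions = math.ceil(fileBytes.toDouble / maxPartitionBytes).toInt
// roughly 66 scan partitions, later normalized to nWorkers by repartition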
