[pyspark] Fix xgboost spark estimator dataset repartition issues (#8231)
WeichenXu123 committed Sep 22, 2022
1 parent 3fd331f commit ab342af
Showing 1 changed file with 23 additions and 12 deletions.
python-package/xgboost/spark/core.py
@@ -20,7 +20,7 @@
     HasWeightCol,
 )
 from pyspark.ml.util import MLReadable, MLWritable
-from pyspark.sql.functions import col, countDistinct, pandas_udf, struct
+from pyspark.sql.functions import col, countDistinct, pandas_udf, rand, struct
 from pyspark.sql.types import (
     ArrayType,
     DoubleType,
@@ -164,6 +164,12 @@ class _SparkXGBParams(
         + "Note: The auto repartitioning judgement is not fully accurate, so it is recommended"
         + "to have force_repartition be True.",
     )
+    repartition_random_shuffle = Param(
+        Params._dummy(),
+        "repartition_random_shuffle",
+        "A boolean variable. Set repartition_random_shuffle=true if you want to randomly "
+        "shuffle the dataset when repartitioning is required. The default is True.",
+    )
     feature_names = Param(
         Params._dummy(), "feature_names", "A list of str to specify feature names."
     )
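For readers trying the new flag, here is a minimal, hypothetical usage sketch (not part of the commit). It assumes an xgboost build that includes the pyspark estimators; the toy DataFrame and its column names are illustrative only.

from pyspark.ml.linalg import Vectors
from pyspark.sql import SparkSession
from xgboost.spark import SparkXGBClassifier

spark = SparkSession.builder.getOrCreate()
# Toy two-feature dataset; column names match the estimator defaults.
df = spark.createDataFrame(
    [(Vectors.dense(1.0, 2.0), 0), (Vectors.dense(3.0, 4.0), 1)] * 100,
    ["features", "label"],
)

# Force a repartition and keep the (default) random shuffle so each of the
# num_workers partitions receives a well-mixed sample of rows.
clf = SparkXGBClassifier(
    num_workers=2,
    force_repartition=True,
    repartition_random_shuffle=True,
)
model = clf.fit(df)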
@@ -270,15 +276,6 @@ def _validate_params(self):
                 f"It cannot be less than 1 [Default is 1]"
             )
 
-        if (
-            self.getOrDefault(self.force_repartition)
-            and self.getOrDefault(self.num_workers) == 1
-        ):
-            get_logger(self.__class__.__name__).warning(
-                "You set force_repartition to true when there is no need for a repartition."
-                "Therefore, that parameter will be ignored."
-            )
-
         if self.getOrDefault(self.features_cols):
             if not self.getOrDefault(self.use_gpu):
                 raise ValueError("features_cols param requires enabling use_gpu.")
@@ -470,6 +467,7 @@ def __init__(self):
             num_workers=1,
             use_gpu=False,
             force_repartition=False,
+            repartition_random_shuffle=True,
             feature_names=None,
             feature_types=None,
             arbitrary_params_dict={},
@@ -695,8 +693,21 @@ def _fit(self, dataset):
             num_workers,
         )
 
-        if self._repartition_needed(dataset):
-            dataset = dataset.repartition(num_workers)
+        if self._repartition_needed(dataset) or (
+            self.isDefined(self.validationIndicatorCol)
+            and self.getOrDefault(self.validationIndicatorCol)
+        ):
+            # If validationIndicatorCol is defined, we always repartition the dataset
+            # to balance the data, because the user might have unioned the train and
+            # validation datasets; without shuffling, some partitions might contain
+            # only train or only validation data.
+            if self.getOrDefault(self.repartition_random_shuffle):
+                # In some cases, Spark's round-robin repartition might cause data
+                # skew; a random shuffle can address it.
+                dataset = dataset.repartition(num_workers, rand(1))
+            else:
+                dataset = dataset.repartition(num_workers)
+
         train_params = self._get_distributed_train_params(dataset)
         booster_params, train_call_kwargs_params = self._get_xgb_train_call_args(
             train_params
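To see why the shuffled repartition is used above: plain repartition(n) round-robins rows from each input partition, which can stay skewed when the input partitioning is itself uneven, while repartition(n, rand(seed)) distributes rows by a random key. A standalone sketch (not part of the commit) that inspects the resulting partition sizes:

from pyspark.sql import SparkSession
from pyspark.sql.functions import rand, spark_partition_id

spark = SparkSession.builder.getOrCreate()
df = spark.range(0, 100_000)

# Random-shuffle repartition, as done in _fit() when
# repartition_random_shuffle is enabled.
shuffled = df.repartition(8, rand(1))

# Count rows per partition; with the random key the counts should be
# roughly equal even if df's input partitioning was skewed.
shuffled.groupBy(spark_partition_id().alias("pid")).count().orderBy("pid").show()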
