diff --git a/python-package/xgboost/spark/core.py b/python-package/xgboost/spark/core.py
index 4121d7610bd7..6ec8dfb57aaa 100644
--- a/python-package/xgboost/spark/core.py
+++ b/python-package/xgboost/spark/core.py
@@ -164,6 +164,12 @@ class _SparkXGBParams(
         + "Note: The auto repartitioning judgement is not fully accurate, so it is recommended"
         + "to have force_repartition be True.",
     )
+    repartition_random_shuffle = Param(
+        Params._dummy(),
+        "repartition_random_shuffle",
+        "A boolean variable. Set repartition_random_shuffle=true if you want to randomly shuffle "
+        "the dataset when repartitioning is required. The default value is True.",
+    )
     feature_names = Param(
         Params._dummy(), "feature_names", "A list of str to specify feature names."
     )
@@ -461,6 +467,7 @@ def __init__(self):
             num_workers=1,
             use_gpu=False,
             force_repartition=False,
+            repartition_random_shuffle=True,
             feature_names=None,
             feature_types=None,
             arbitrary_params_dict={},
@@ -689,10 +696,12 @@ def _fit(self, dataset):
             # to balance data, because user might unionise train and validation dataset,
             # without shuffling data then some partitions might contain only train or validation
             # dataset.
-            # Repartition on `rand` column to avoid repartition
-            # result unbalance. Directly using `.repartition(N)` might result in some
-            # empty partitions.
-            dataset = dataset.repartition(num_workers, rand(1))
+            if self.getOrDefault(self.repartition_random_shuffle):
+                # In some cases, Spark's round-robin repartitioning might cause data skew;
+                # a random shuffle can address it.
+                dataset = dataset.repartition(num_workers, rand(1))
+            else:
+                dataset = dataset.repartition(num_workers)
 
         train_params = self._get_distributed_train_params(dataset)
         booster_params, train_call_kwargs_params = self._get_xgb_train_call_args(
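
The `_fit` hunk above gates the `rand(1)`-based shuffle behind the new Param. Below is a minimal usage sketch (not part of the diff) showing how a caller could turn the shuffle off; it assumes the `xgboost.spark` estimators accept `repartition_random_shuffle` as a constructor keyword the same way they accept `force_repartition`, and the data path and column names are placeholders.

```python
# Usage sketch (assumption: the estimator constructor forwards the new Param
# like force_repartition; "/path/to/train", "features" and "label" are placeholders).
from pyspark.sql import SparkSession
from xgboost.spark import SparkXGBClassifier

spark = SparkSession.builder.getOrCreate()
train_df = spark.read.parquet("/path/to/train")  # hypothetical input

classifier = SparkXGBClassifier(
    features_col="features",
    label_col="label",
    num_workers=4,
    force_repartition=True,
    # Fall back to plain .repartition(num_workers) when the input is already
    # well shuffled and the extra random-shuffle exchange is too costly.
    repartition_random_shuffle=False,
)
model = classifier.fit(train_df)
```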