Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
  • Loading branch information
WeichenXu123 committed Sep 15, 2022
1 parent 5f89620 commit 9410a89
Showing 1 changed file with 13 additions and 4 deletions.
17 changes: 13 additions & 4 deletions python-package/xgboost/spark/core.py
Expand Up @@ -164,6 +164,12 @@ class _SparkXGBParams(
+ "Note: The auto repartitioning judgement is not fully accurate, so it is recommended"
+ "to have force_repartition be True.",
)
# Pyspark param controlling whether a required repartition also randomly
# shuffles rows (repartition on a rand() column) instead of using Spark's
# default round-robin repartition, which can leave partitions skewed.
repartition_random_shuffle = Param(
    Params._dummy(),
    "repartition_random_shuffle",
    # Fixed grammar of the user-facing description (shown by explainParams()).
    "A boolean variable. Set repartition_random_shuffle=true if you want to randomly "
    "shuffle the dataset when repartitioning is required. The default is True.",
)
# Pyspark param holding optional feature names passed through to xgboost.
feature_names = Param(
    Params._dummy(),
    "feature_names",
    "A list of str to specify feature names.",
)
Expand Down Expand Up @@ -461,6 +467,7 @@ def __init__(self):
num_workers=1,
use_gpu=False,
force_repartition=False,
repartition_random_shuffle=True,
feature_names=None,
feature_types=None,
arbitrary_params_dict={},
Expand Down Expand Up @@ -689,10 +696,12 @@ def _fit(self, dataset):
# to balance data, because user might unionise train and validation dataset,
# without shuffling data then some partitions might contain only train or validation
# dataset.
# Repartition on `rand` column to avoid repartition
# result unbalance. Directly using `.repartition(N)` might result in some
# empty partitions.
dataset = dataset.repartition(num_workers, rand(1))
if self.getOrDefault(self.repartition_random_shuffle):
# In some cases, spark round-robin repartition might cause data skew
# use random shuffle can address it.
dataset = dataset.repartition(num_workers, rand(1))
else:
dataset = dataset.repartition(num_workers)

train_params = self._get_distributed_train_params(dataset)
booster_params, train_call_kwargs_params = self._get_xgb_train_call_args(
Expand Down

0 comments on commit 9410a89

Please sign in to comment.