Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
  • Loading branch information
WeichenXu123 committed Sep 8, 2022
1 parent b397d64 commit 9854131
Showing 1 changed file with 7 additions and 11 deletions.
18 changes: 7 additions & 11 deletions python-package/xgboost/spark/core.py
Expand Up @@ -20,7 +20,9 @@
HasWeightCol,
)
from pyspark.ml.util import MLReadable, MLWritable
from pyspark.sql.functions import col, countDistinct, pandas_udf, struct
from pyspark.sql.functions import (
col, countDistinct, pandas_udf, struct, monotonically_increasing_id
)
from pyspark.sql.types import (
ArrayType,
DoubleType,
Expand Down Expand Up @@ -270,15 +272,6 @@ def _validate_params(self):
f"It cannot be less than 1 [Default is 1]"
)

if (
self.getOrDefault(self.force_repartition)
and self.getOrDefault(self.num_workers) == 1
):
get_logger(self.__class__.__name__).warning(
"You set force_repartition to true when there is no need for a repartition."
"Therefore, that parameter will be ignored."
)

if self.getOrDefault(self.features_cols):
if not self.getOrDefault(self.use_gpu):
raise ValueError("features_cols param requires enabling use_gpu.")
Expand Down Expand Up @@ -691,7 +684,10 @@ def _fit(self, dataset):
)

if self._repartition_needed(dataset):
dataset = dataset.repartition(num_workers)
            # Repartition on the `monotonically_increasing_id` column to avoid
            # unbalanced repartition results. Calling `.repartition(N)` directly
            # might leave some partitions empty.
dataset = dataset.repartition(num_workers, monotonically_increasing_id())
train_params = self._get_distributed_train_params(dataset)
booster_params, train_call_kwargs_params = self._get_xgb_train_call_args(
train_params
Expand Down

0 comments on commit 9854131

Please sign in to comment.