Skip to content

Commit

Permalink
[pyspark] Add validation for param 'early_stopping_rounds' and 'valid…
Browse files Browse the repository at this point in the history
…ation_indicator_col' (#8250)



Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
  • Loading branch information
WeichenXu123 committed Sep 26, 2022
1 parent 0cd11b8 commit ff71c69
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 0 deletions.
10 changes: 10 additions & 0 deletions python-package/xgboost/spark/core.py
Expand Up @@ -292,6 +292,16 @@ def _validate_params(self):
if not isinstance(self.getOrDefault(self.eval_metric), str):
raise ValueError("Only string type 'eval_metric' param is allowed.")

if self.getOrDefault(self.early_stopping_rounds) is not None:
if not (
self.isDefined(self.validationIndicatorCol)
and self.getOrDefault(self.validationIndicatorCol)
):
raise ValueError(
"If 'early_stopping_rounds' param is set, you need to set "
"'validation_indicator_col' param as well."
)

if self.getOrDefault(self.enable_sparse_data_optim):
if self.getOrDefault(self.missing) != 0.0:
# If DMatrix is constructed from csr / csc matrix, then inactive elements
Expand Down
5 changes: 5 additions & 0 deletions tests/python/test_spark/test_spark_local.py
Expand Up @@ -1145,3 +1145,8 @@ def test_empty_partition(self):
num_workers=4,
)
classifier.fit(data_trans)

def test_early_stop_param_validation(self):
classifier = SparkXGBClassifier(early_stopping_rounds=1)
with pytest.raises(ValueError, match="early_stopping_rounds"):
classifier.fit(self.cls_df_train)

0 comments on commit ff71c69

Please sign in to comment.