Skip to content

Commit

Permalink
[pyspark] Filter out the unsupported train parameters (#8355)
Browse files Browse the repository at this point in the history
  • Loading branch information
wbo4958 committed Oct 18, 2022
1 parent 3901f5d commit 76f95a6
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 1 deletion.
10 changes: 9 additions & 1 deletion python-package/xgboost/spark/core.py
Expand Up @@ -126,6 +126,11 @@
"eval_qid", # Use spark param `qid_col` instead
}

_unsupported_train_params = {
"evals", # Supported by spark param validation_indicator_col
"evals_result", # Won't support yet+
}

_unsupported_predict_params = {
# for classification, we can use rawPrediction as margin
"output_margin",
Expand Down Expand Up @@ -515,6 +520,7 @@ def setParams(self, **kwargs): # pylint: disable=invalid-name
k in _unsupported_xgb_params
or k in _unsupported_fit_params
or k in _unsupported_predict_params
or k in _unsupported_train_params
):
raise ValueError(f"Unsupported param '{k}'.")
_extra_params[k] = v
Expand Down Expand Up @@ -620,7 +626,9 @@ def _get_distributed_train_params(self, dataset):

@classmethod
def _get_xgb_train_call_args(cls, train_params):
xgb_train_default_args = _get_default_params_from_func(xgboost.train, {})
xgb_train_default_args = _get_default_params_from_func(
xgboost.train, _unsupported_train_params
)
booster_params, kwargs_params = {}, {}
for key, value in train_params.items():
if key in xgb_train_default_args:
Expand Down
4 changes: 4 additions & 0 deletions tests/python/test_spark/test_spark_local.py
Expand Up @@ -1126,3 +1126,7 @@ def test_early_stop_param_validation(self):
classifier = SparkXGBClassifier(early_stopping_rounds=1)
with pytest.raises(ValueError, match="early_stopping_rounds"):
classifier.fit(self.cls_df_train)

def test_unsupported_params(self):
with pytest.raises(ValueError, match="evals_result"):
SparkXGBClassifier(evals_result={})

0 comments on commit 76f95a6

Please sign in to comment.