diff --git a/python-package/xgboost/spark/core.py b/python-package/xgboost/spark/core.py index 10ec30b6227a..cbec5c795f19 100644 --- a/python-package/xgboost/spark/core.py +++ b/python-package/xgboost/spark/core.py @@ -126,6 +126,11 @@ "eval_qid", # Use spark param `qid_col` instead } +_unsupported_train_params = { + "evals", # Supported by spark param validation_indicator_col + "evals_result", # Won't support yet+ +} + _unsupported_predict_params = { # for classification, we can use rawPrediction as margin "output_margin", @@ -515,6 +520,7 @@ def setParams(self, **kwargs): # pylint: disable=invalid-name k in _unsupported_xgb_params or k in _unsupported_fit_params or k in _unsupported_predict_params + or k in _unsupported_train_params ): raise ValueError(f"Unsupported param '{k}'.") _extra_params[k] = v @@ -620,7 +626,9 @@ def _get_distributed_train_params(self, dataset): @classmethod def _get_xgb_train_call_args(cls, train_params): - xgb_train_default_args = _get_default_params_from_func(xgboost.train, {}) + xgb_train_default_args = _get_default_params_from_func( + xgboost.train, _unsupported_train_params + ) booster_params, kwargs_params = {}, {} for key, value in train_params.items(): if key in xgb_train_default_args: diff --git a/tests/python/test_spark/test_spark_local.py b/tests/python/test_spark/test_spark_local.py index b7505b4a89e3..03981d955040 100644 --- a/tests/python/test_spark/test_spark_local.py +++ b/tests/python/test_spark/test_spark_local.py @@ -1126,3 +1126,7 @@ def test_early_stop_param_validation(self): classifier = SparkXGBClassifier(early_stopping_rounds=1) with pytest.raises(ValueError, match="early_stopping_rounds"): classifier.fit(self.cls_df_train) + + def test_unsupported_params(self): + with pytest.raises(ValueError, match="evals_result"): + SparkXGBClassifier(evals_result={})