[pyspark] Add param validation for "objective" and "eval_metric" param, and remove invalid booster params (#8173)


Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
WeichenXu123 committed Aug 24, 2022
1 parent 9b32e6e commit d03794c
Showing 3 changed files with 26 additions and 4 deletions.
19 changes: 15 additions & 4 deletions python-package/xgboost/spark/core.py
@@ -114,10 +114,10 @@
 _unsupported_fit_params = {
     "sample_weight",  # Supported by spark param weightCol
-    # Supported by spark param weightCol # and validationIndicatorCol
-    "eval_set",
-    "sample_weight_eval_set",
+    "eval_set",  # Supported by spark param validation_indicator_col
+    "sample_weight_eval_set",  # Supported by spark param weight_col + validation_indicator_col
     "base_margin",  # Supported by spark param base_margin_col
+    "base_margin_eval_set",  # Supported by spark param base_margin_col + validation_indicator_col
     "group",  # Use spark param `qid_col` instead
     "qid",  # Use spark param `qid_col` instead
     "eval_group",  # Use spark param `qid_col` instead
@@ -287,6 +287,14 @@ def _validate_params(self):
"If features_cols param set, then features_col param is ignored."
)

if self.getOrDefault(self.objective) is not None:
if not isinstance(self.getOrDefault(self.objective), str):
raise ValueError("Only string type 'objective' param is allowed.")

if self.getOrDefault(self.eval_metric) is not None:
if not isinstance(self.getOrDefault(self.eval_metric), str):
raise ValueError("Only string type 'eval_metric' param is allowed.")

if self.getOrDefault(self.enable_sparse_data_optim):
if self.getOrDefault(self.missing) != 0.0:
# If DMatrix is constructed from csr / csc matrix, then inactive elements
@@ -578,7 +586,6 @@ def _get_distributed_train_params(self, dataset):
         params.update(fit_params)
         params["verbose_eval"] = verbose_eval
         classification = self._xgb_cls() == XGBClassifier
-        num_classes = int(dataset.select(countDistinct(alias.label)).collect()[0][0])
         if classification:
             num_classes = int(
                 dataset.select(countDistinct(alias.label)).collect()[0][0]
@@ -610,6 +617,10 @@ def _get_xgb_train_call_args(cls, train_params):
                 kwargs_params[key] = value
             else:
                 booster_params[key] = value
+
+        booster_params = {
+            k: v for k, v in booster_params.items() if k not in _non_booster_params
+        }
         return booster_params, kwargs_params
 
     def _fit(self, dataset):
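Taken together, the core.py changes mean only string-typed "objective"/"eval_metric" values reach training, and pyspark-only params are stripped before they can hit xgboost.train. A minimal standalone sketch of both behaviors; the `_non_booster_params` contents and sample values below are hypothetical stand-ins for the real definitions in core.py:

# Standalone mimic of the two core.py changes.
# Hypothetical stand-in for the module-level set in core.py:
_non_booster_params = {"num_workers", "use_gpu", "features_cols"}


def validate_objective_and_eval_metric(objective, eval_metric):
    # Mirrors the new _validate_params checks: custom callables, lists,
    # etc. are rejected; only plain strings (or None) pass.
    if objective is not None and not isinstance(objective, str):
        raise ValueError("Only string type 'objective' param is allowed.")
    if eval_metric is not None and not isinstance(eval_metric, str):
        raise ValueError("Only string type 'eval_metric' param is allowed.")


def get_xgb_train_call_args(train_params, kwargs_param_names=("verbose_eval",)):
    # Mirrors the patched _get_xgb_train_call_args: split xgboost.train()
    # kwargs from booster params, then drop pyspark-only params so they
    # never reach the booster.
    booster_params, kwargs_params = {}, {}
    for key, value in train_params.items():
        if key in kwargs_param_names:
            kwargs_params[key] = value
        else:
            booster_params[key] = value
    booster_params = {
        k: v for k, v in booster_params.items() if k not in _non_booster_params
    }
    return booster_params, kwargs_params


validate_objective_and_eval_metric("binary:logistic", "logloss")  # passes
print(get_xgb_train_call_args({"max_depth": 5, "num_workers": 4, "verbose_eval": True}))
# -> ({'max_depth': 5}, {'verbose_eval': True})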
9 changes: 9 additions & 0 deletions python-package/xgboost/spark/estimator.py
@@ -211,6 +211,11 @@ class SparkXGBClassifier(_SparkXGBEstimator, HasProbabilityCol, HasRawPrediction

     def __init__(self, **kwargs):
         super().__init__()
+        # The default 'objective' param value comes from sklearn `XGBClassifier` ctor,
+        # but in pyspark we will automatically set objective param depending on
+        # binary or multinomial input dataset, and we need to remove the fixed default
+        # param value as well to avoid causing ambiguity.
+        self._setDefault(objective=None)
         self.setParams(**kwargs)
 
     @classmethod
@@ -227,6 +232,10 @@ def _validate_params(self):
             raise ValueError(
                 "Spark Xgboost classifier estimator does not support `qid_col` param."
             )
+        if self.getOrDefault(self.objective):  # pylint: disable=no-member
+            raise ValueError(
+                "Setting custom 'objective' param is not allowed in 'SparkXGBClassifier'."
+            )
 
 
 _set_pyspark_xgb_cls_param_attrs(SparkXGBClassifier, SparkXGBClassifierModel)
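The net effect on the estimator API, as a usage sketch (assuming pyspark plus an xgboost build that ships xgboost.spark; per this commit, _validate_params runs when fit() is invoked rather than in the constructor, and train_df below is a hypothetical training DataFrame):

from xgboost.spark import SparkXGBClassifier, SparkXGBRegressor

# The classifier no longer carries a fixed default objective; it is
# chosen from the label column (binary vs. multinomial) at fit time.
clf = SparkXGBClassifier()
assert clf.getOrDefault(clf.objective) is None

# A user-supplied objective is rejected for the classifier:
bad_clf = SparkXGBClassifier(objective="multi:softmax")
# bad_clf.fit(train_df) raises:
#   ValueError: Setting custom 'objective' param is not allowed in 'SparkXGBClassifier'.

# The regressor still accepts an objective, but only as a string:
reg = SparkXGBRegressor(objective="reg:squarederror")  # OK
bad_reg = SparkXGBRegressor(objective=lambda labels, preds: (preds, preds))
# bad_reg.fit(train_df) raises:
#   ValueError: Only string type 'objective' param is allowed.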
2 changes: 2 additions & 0 deletions tests/python/test_spark/test_spark_local.py
@@ -433,6 +433,7 @@ def test_regressor_params_basic(self):
         self.assertEqual(py_reg.n_estimators.parent, py_reg.uid)
         self.assertFalse(hasattr(py_reg, "gpu_id"))
         self.assertEqual(py_reg.getOrDefault(py_reg.n_estimators), 100)
+        self.assertEqual(py_reg.getOrDefault(py_reg.objective), "reg:squarederror")
         py_reg2 = SparkXGBRegressor(n_estimators=200)
         self.assertEqual(py_reg2.getOrDefault(py_reg2.n_estimators), 200)
         py_reg3 = py_reg2.copy({py_reg2.max_depth: 10})
@@ -445,6 +446,7 @@ def test_classifier_params_basic(self):
         self.assertEqual(py_cls.n_estimators.parent, py_cls.uid)
         self.assertFalse(hasattr(py_cls, "gpu_id"))
         self.assertEqual(py_cls.getOrDefault(py_cls.n_estimators), 100)
+        self.assertEqual(py_cls.getOrDefault(py_cls.objective), None)
         py_cls2 = SparkXGBClassifier(n_estimators=200)
         self.assertEqual(py_cls2.getOrDefault(py_cls2.n_estimators), 200)
         py_cls3 = py_cls2.copy({py_cls2.max_depth: 10})
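These two assertions pin the new defaults: the regressor keeps sklearn's "reg:squarederror" while the classifier's objective now defaults to None until fit time. Locally, the updated tests can be run with `pytest tests/python/test_spark/test_spark_local.py -k params_basic` (assuming a Spark-enabled test environment).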
