diff --git a/python-package/xgboost/spark/__init__.py b/python-package/xgboost/spark/__init__.py
index 4d4cd41627d0..c0b347eefb30 100644
--- a/python-package/xgboost/spark/__init__.py
+++ b/python-package/xgboost/spark/__init__.py
@@ -10,6 +10,7 @@
 from .estimator import (
     SparkXGBClassifier,
     SparkXGBClassifierModel,
+    SparkXGBRanker,
     SparkXGBRegressor,
     SparkXGBRegressorModel,
 )
@@ -19,4 +20,5 @@
     "SparkXGBClassifierModel",
     "SparkXGBRegressor",
     "SparkXGBRegressorModel",
+    "SparkXGBRanker",
 ]
diff --git a/python-package/xgboost/spark/core.py b/python-package/xgboost/spark/core.py
index d86c36d02c57..0e9509099e13 100644
--- a/python-package/xgboost/spark/core.py
+++ b/python-package/xgboost/spark/core.py
@@ -35,7 +35,7 @@
 from xgboost.training import train as worker_train
 
 import xgboost
-from xgboost import XGBClassifier, XGBRegressor
+from xgboost import XGBClassifier, XGBRanker, XGBRegressor
 
 from .data import (
     _read_csr_matrix_from_unwrapped_spark_vec,
@@ -54,6 +54,7 @@
     HasBaseMarginCol,
     HasEnableSparseDataOptim,
     HasFeaturesCols,
+    HasQueryIdCol,
 )
 from .utils import (
     RabitContext,
@@ -86,6 +87,7 @@
     "feature_names",
     "features_cols",
     "enable_sparse_data_optim",
+    "qid_col",
 ]
 
 _non_booster_params = ["missing", "n_estimators", "feature_types", "feature_weights"]
@@ -116,6 +118,10 @@
     "eval_set",
     "sample_weight_eval_set",
     "base_margin",  # Supported by spark param base_margin_col
+    "group",  # Use spark param `qid_col` instead
+    "qid",  # Use spark param `qid_col` instead
+    "eval_group",  # Use spark param `qid_col` instead
+    "eval_qid",  # Use spark param `qid_col` instead
 }
 
 _unsupported_predict_params = {
@@ -136,6 +142,7 @@ class _SparkXGBParams(
     HasBaseMarginCol,
     HasFeaturesCols,
     HasEnableSparseDataOptim,
+    HasQueryIdCol,
 ):
     num_workers = Param(
         Params._dummy(),
@@ -572,13 +579,19 @@ def _get_distributed_train_params(self, dataset):
         params["verbose_eval"] = verbose_eval
         classification = self._xgb_cls() == XGBClassifier
-        num_classes = int(dataset.select(countDistinct(alias.label)).collect()[0][0])
-        if classification and num_classes == 2:
-            params["objective"] = "binary:logistic"
-        elif classification and num_classes > 2:
-            params["objective"] = "multi:softprob"
-            params["num_class"] = num_classes
+        if classification:
+            num_classes = int(
+                dataset.select(countDistinct(alias.label)).collect()[0][0]
+            )
+            if num_classes <= 2:
+                params["objective"] = "binary:logistic"
+            else:
+                params["objective"] = "multi:softprob"
+                params["num_class"] = num_classes
         else:
-            params["objective"] = "reg:squarederror"
+            # Use the user-specified objective or the estimator's default one,
+            # e.g. the default objective of the regressor is 'reg:squarederror'.
+            params["objective"] = self.getOrDefault(self.objective)
 
         # TODO: support "num_parallel_tree" for random forest
         params["num_boost_round"] = self.getOrDefault(self.n_estimators)
@@ -648,6 +661,9 @@ def _fit(self, dataset):
                 col(self.getOrDefault(self.base_margin_col)).alias(alias.margin)
             )
 
+        if self.isDefined(self.qid_col) and self.getOrDefault(self.qid_col):
+            select_cols.append(col(self.getOrDefault(self.qid_col)).alias(alias.qid))
+
         dataset = dataset.select(*select_cols)
 
         num_workers = self.getOrDefault(self.num_workers)
@@ -782,6 +798,10 @@ def __init__(self, xgb_sklearn_model=None):
         super().__init__()
         self._xgb_sklearn_model = xgb_sklearn_model
 
+    @classmethod
+    def _xgb_cls(cls):
+        raise NotImplementedError()
+
     def get_booster(self):
         """
         Return the `xgboost.core.Booster` instance.
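A minimal sketch (not part of the patch) of what the reworked objective resolution in the `_get_distributed_train_params` hunk above implies for callers. It assumes xgboost's PySpark support and pyspark are installed; the "rank:pairwise" default matches the assertion in test_spark_local.py further down, and "rank:ndcg" is used purely as an illustrative override.

    from xgboost.spark import SparkXGBRanker, SparkXGBRegressor

    # Non-classification estimators now resolve the objective via
    # getOrDefault(self.objective) instead of a hard-coded "reg:squarederror".
    reg = SparkXGBRegressor()
    assert reg.getOrDefault(reg.objective) == "reg:squarederror"

    # The ranker picks up XGBRanker's default objective...
    ranker = SparkXGBRanker(qid_col="qid")
    assert ranker.getOrDefault(ranker.objective) == "rank:pairwise"

    # ...and a user-supplied objective is passed through unchanged.
    ranker_ndcg = SparkXGBRanker(qid_col="qid", objective="rank:ndcg")
    assert ranker_ndcg.getOrDefault(ranker_ndcg.objective) == "rank:ndcg"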
@@ -818,9 +838,6 @@ def read(cls):
         """
         return SparkXGBModelReader(cls)
 
-    def _transform(self, dataset):
-        raise NotImplementedError()
-
     def _get_feature_col(self, dataset) -> (list, Optional[list]):
         """XGBoost model trained with features_cols parameter can also predict
         vector or array feature type. But first we need to check features_cols
@@ -855,18 +872,6 @@ def _get_feature_col(self, dataset) -> (list, Optional[list]):
         )
         return features_col, feature_col_names
 
-
-class SparkXGBRegressorModel(_SparkXGBModel):
-    """
-    The model returned by :func:`xgboost.spark.SparkXGBRegressor.fit`
-
-    .. Note:: This API is experimental.
-    """
-
-    @classmethod
-    def _xgb_cls(cls):
-        return XGBRegressor
-
     def _transform(self, dataset):
         # Save xgb_sklearn_model and predict_params to be local variable
         # to avoid the `self` object to be pickled to remote.
@@ -920,6 +925,30 @@ def predict_udf(iterator: Iterator[pd.DataFrame]) -> Iterator[pd.Series]:
         return dataset.withColumn(predictionColName, pred_col)
 
 
+class SparkXGBRegressorModel(_SparkXGBModel):
+    """
+    The model returned by :func:`xgboost.spark.SparkXGBRegressor.fit`
+
+    .. Note:: This API is experimental.
+    """
+
+    @classmethod
+    def _xgb_cls(cls):
+        return XGBRegressor
+
+
+class SparkXGBRankerModel(_SparkXGBModel):
+    """
+    The model returned by :func:`xgboost.spark.SparkXGBRanker.fit`
+
+    .. Note:: This API is experimental.
+    """
+
+    @classmethod
+    def _xgb_cls(cls):
+        return XGBRanker
+
+
 class SparkXGBClassifierModel(_SparkXGBModel, HasProbabilityCol, HasRawPredictionCol):
     """
     The model returned by :func:`xgboost.spark.SparkXGBClassifier.fit`
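With `_transform` hoisted onto `_SparkXGBModel` and `_xgb_cls` turned into an abstract hook, a concrete model class reduces to naming its sklearn counterpart. A sketch of the resulting subclass contract; the `MyRankerModel` name is hypothetical and for illustration only.

    from xgboost import XGBRanker
    from xgboost.spark.core import _SparkXGBModel

    class MyRankerModel(_SparkXGBModel):
        """Hypothetical model subclass: the prediction logic is inherited from
        _SparkXGBModel._transform; only the sklearn class needs to be named."""

        @classmethod
        def _xgb_cls(cls):
            return XGBRanker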
diff --git a/python-package/xgboost/spark/data.py b/python-package/xgboost/spark/data.py
index 77e51c412e7f..c01468c12492 100644
--- a/python-package/xgboost/spark/data.py
+++ b/python-package/xgboost/spark/data.py
@@ -19,8 +19,8 @@ def stack_series(series: pd.Series) -> np.ndarray:
 
 # Global constant for defining column alias shared between estimator and data
 # processing procedures.
-Alias = namedtuple("Alias", ("data", "label", "weight", "margin", "valid"))
-alias = Alias("values", "label", "weight", "baseMargin", "validationIndicator")
+Alias = namedtuple("Alias", ("data", "label", "weight", "margin", "valid", "qid"))
+alias = Alias("values", "label", "weight", "baseMargin", "validationIndicator", "qid")
 
 
 def concat_or_none(seq: Optional[Sequence[np.ndarray]]) -> Optional[np.ndarray]:
@@ -41,6 +41,7 @@ def make_blob(part: pd.DataFrame, is_valid: bool) -> None:
             append(part, alias.label, is_valid)
             append(part, alias.weight, is_valid)
             append(part, alias.margin, is_valid)
+            append(part, alias.qid, is_valid)
 
         has_validation: Optional[bool] = None
 
@@ -94,6 +95,7 @@ def next(self, input_data: Callable) -> int:
                 label=self._fetch(self._data.get(alias.label, None)),
                 weight=self._fetch(self._data.get(alias.weight, None)),
                 base_margin=self._fetch(self._data.get(alias.margin, None)),
+                qid=self._fetch(self._data.get(alias.qid, None)),
             )
         self._iter += 1
         return 1
@@ -226,9 +228,9 @@ def make(values: Dict[str, List[np.ndarray]], kwargs: Dict[str, Any]) -> DMatrix
         label = concat_or_none(values.get(alias.label, None))
         weight = concat_or_none(values.get(alias.weight, None))
         margin = concat_or_none(values.get(alias.margin, None))
-
+        qid = concat_or_none(values.get(alias.qid, None))
         return DMatrix(
-            data=data, label=label, weight=weight, base_margin=margin, **kwargs
+            data=data, label=label, weight=weight, base_margin=margin, qid=qid, **kwargs
         )
 
     is_dmatrix = feature_cols is None
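The data-layer change above threads a `qid` alias column from the selected dataset through the partition iterator into `DMatrix`. A self-contained sketch of the call shape that `make()` now produces, with illustrative values; rows are grouped by query id, as XGBoost's ranking support requires.

    import numpy as np
    from xgboost import DMatrix

    # Three rows belonging to two query groups, mirroring the ranker test data.
    X = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [9.0, 4.0, 8.0]])
    y = np.array([0.0, 1.0, 2.0])
    qid = np.array([0, 0, 1])  # must be non-decreasing within the matrix

    # Same keyword shape as the DMatrix(...) call added in make() above.
    dtrain = DMatrix(data=X, label=y, qid=qid)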
diff --git a/python-package/xgboost/spark/estimator.py b/python-package/xgboost/spark/estimator.py
index 3e7c5fdf65a3..dbf080944cce 100644
--- a/python-package/xgboost/spark/estimator.py
+++ b/python-package/xgboost/spark/estimator.py
@@ -3,10 +3,11 @@
 # pylint: disable=too-many-ancestors
 from pyspark.ml.param.shared import HasProbabilityCol, HasRawPredictionCol
 
-from xgboost import XGBClassifier, XGBRegressor
+from xgboost import XGBClassifier, XGBRanker, XGBRegressor
 
 from .core import (
     SparkXGBClassifierModel,
+    SparkXGBRankerModel,
     SparkXGBRegressorModel,
     _set_pyspark_xgb_cls_param_attrs,
     _SparkXGBEstimator,
@@ -106,6 +107,13 @@ def _xgb_cls(cls):
     def _pyspark_model_cls(cls):
         return SparkXGBRegressorModel
 
+    def _validate_params(self):
+        super()._validate_params()
+        if self.isDefined(self.qid_col):
+            raise ValueError(
+                "Spark XGBoost regressor estimator does not support the `qid_col` param."
+            )
+
 
 _set_pyspark_xgb_cls_param_attrs(SparkXGBRegressor, SparkXGBRegressorModel)
 
@@ -213,5 +221,126 @@ def _xgb_cls(cls):
     def _pyspark_model_cls(cls):
         return SparkXGBClassifierModel
 
+    def _validate_params(self):
+        super()._validate_params()
+        if self.isDefined(self.qid_col):
+            raise ValueError(
+                "Spark XGBoost classifier estimator does not support the `qid_col` param."
+            )
+
 
 _set_pyspark_xgb_cls_param_attrs(SparkXGBClassifier, SparkXGBClassifierModel)
+
+
+class SparkXGBRanker(_SparkXGBEstimator):
+    """SparkXGBRanker is a PySpark ML estimator. It implements the XGBoost
+    ranking algorithm based on the XGBoost python library, and it can be used in
+    PySpark Pipeline and PySpark ML meta algorithms like
+    :py:class:`~pyspark.ml.tuning.CrossValidator` and
+    :py:class:`~pyspark.ml.tuning.TrainValidationSplit`.
+
+    SparkXGBRanker automatically supports most of the parameters in the
+    :py:class:`xgboost.XGBRanker` constructor and most of the parameters used
+    in the :py:class:`xgboost.XGBRanker` fit and predict methods.
+
+    SparkXGBRanker doesn't support setting `gpu_id`; instead, it supports the
+    `use_gpu` param. See the doc below for more details.
+
+    SparkXGBRanker doesn't support setting `base_margin` explicitly either; instead,
+    it supports the `base_margin_col` param. See the doc below for more details.
+
+    SparkXGBRanker doesn't support the `validate_features` and `output_margin` params.
+
+    SparkXGBRanker doesn't support setting the `nthread` xgboost param; instead, the
+    `nthread` param for each xgboost worker is set to the `spark.task.cpus` config value.
+
+
+    Parameters
+    ----------
+
+    callbacks:
+        The export and import of the callback functions are at best effort. For
+        details, see :py:attr:`xgboost.spark.SparkXGBRanker.callbacks` param doc.
+    validation_indicator_col:
+        For params related to `xgboost.XGBRanker` training with
+        evaluation dataset's supervision,
+        set :py:attr:`xgboost.spark.SparkXGBRanker.validation_indicator_col`
+        parameter instead of setting the `eval_set` parameter in the `xgboost.XGBRanker`
+        fit method.
+    weight_col:
+        To specify the weight of the training and validation dataset, set
+        :py:attr:`xgboost.spark.SparkXGBRanker.weight_col` parameter instead of setting
+        the `sample_weight` and `sample_weight_eval_set` parameters in the `xgboost.XGBRanker`
+        fit method.
+    xgb_model:
+        Set the value to be the instance returned by
+        :func:`xgboost.spark.SparkXGBRankerModel.get_booster`.
+    num_workers:
+        Integer that specifies the number of XGBoost workers to use.
+        Each XGBoost worker corresponds to one spark task.
+    use_gpu:
+        Boolean that specifies whether the executors are running on GPU
+        instances.
+    base_margin_col:
+        To specify the base margins of the training and validation
+        dataset, set :py:attr:`xgboost.spark.SparkXGBRanker.base_margin_col` parameter
+        instead of setting `base_margin` and `base_margin_eval_set` in the
+        `xgboost.XGBRanker` fit method.
+    qid_col:
+        To specify the qid of the training and validation
+        dataset, set :py:attr:`xgboost.spark.SparkXGBRanker.qid_col` parameter
+        instead of setting `qid` / `group` and `eval_qid` / `eval_group` in the
+        `xgboost.XGBRanker` fit method.
+
+    .. Note:: The Parameters chart above contains parameters that need special handling.
+        For a full list of parameters, see entries with `Param(parent=...` below.
+
+    .. Note:: This API is experimental.
+
+    Examples
+    --------
+
+    >>> from xgboost.spark import SparkXGBRanker
+    >>> from pyspark.ml.linalg import Vectors
+    >>> ranker = SparkXGBRanker(qid_col="qid")
+    >>> df_train = spark.createDataFrame(
+    ...     [
+    ...         (Vectors.dense(1.0, 2.0, 3.0), 0, 0),
+    ...         (Vectors.dense(4.0, 5.0, 6.0), 1, 0),
+    ...         (Vectors.dense(9.0, 4.0, 8.0), 2, 0),
+    ...         (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 0, 1),
+    ...         (Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, 1),
+    ...         (Vectors.sparse(3, {1: 8.0, 2: 9.5}), 2, 1),
+    ...     ],
+    ...     ["features", "label", "qid"],
+    ... )
+    >>> df_test = spark.createDataFrame(
+    ...     [(Vectors.dense(1.5, 2.0, 3.0), 0)],
+    ...     ["features", "qid"],
+    ... )
+    >>> ranker_model = ranker.fit(df_train)
+    >>> ranker_model.transform(df_test).show()
+
+    """
+
+    def __init__(self, **kwargs):
+        super().__init__()
+        self.setParams(**kwargs)
+
+    @classmethod
+    def _xgb_cls(cls):
+        return XGBRanker
+
+    @classmethod
+    def _pyspark_model_cls(cls):
+        return SparkXGBRankerModel
+
+    def _validate_params(self):
+        super()._validate_params()
+        if not self.isDefined(self.qid_col):
+            raise ValueError(
+                "Spark XGBoost ranker estimator requires setting the `qid_col` param."
+            )
+
+
+_set_pyspark_xgb_cls_param_attrs(SparkXGBRanker, SparkXGBRankerModel)
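A sketch of the new `_validate_params` rules above, assuming an active `SparkSession` and a `df_train` frame like the one in the ranker docstring example; validation is triggered when `fit` is called.

    from xgboost.spark import SparkXGBRanker, SparkXGBRegressor

    # The ranker requires qid_col: fitting without it raises at validation time.
    try:
        SparkXGBRanker().fit(df_train)  # df_train as in the docstring example
    except ValueError as err:
        print(err)  # ...ranker estimator requires setting the `qid_col` param.

    # Conversely, the regressor (and classifier) reject qid_col outright.
    try:
        SparkXGBRegressor(qid_col="qid").fit(df_train)
    except ValueError as err:
        print(err)  # ...regressor estimator does not support the `qid_col` param.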
diff --git a/python-package/xgboost/spark/params.py b/python-package/xgboost/spark/params.py
index 01417c016af4..ed46ba20ec40 100644
--- a/python-package/xgboost/spark/params.py
+++ b/python-package/xgboost/spark/params.py
@@ -72,3 +72,17 @@ class HasEnableSparseDataOptim(Params):
     def __init__(self):
         super().__init__()
         self._setDefault(enable_sparse_data_optim=False)
+
+
+class HasQueryIdCol(Params):
+    """
+    Mixin for param qid_col: the query id column name, used by
+    :py:class:`~xgboost.spark.SparkXGBRanker`.
+    """
+
+    qid_col = Param(
+        Params._dummy(),
+        "qid_col",
+        "query id column name",
+        typeConverter=TypeConverters.toString,
+    )
diff --git a/tests/python/test_spark/test_spark_local.py b/tests/python/test_spark/test_spark_local.py
index 22adc899e7e1..14377663c388 100644
--- a/tests/python/test_spark/test_spark_local.py
+++ b/tests/python/test_spark/test_spark_local.py
@@ -24,6 +24,7 @@
 from xgboost.spark import (
     SparkXGBClassifier,
     SparkXGBClassifierModel,
+    SparkXGBRanker,
     SparkXGBRegressor,
     SparkXGBRegressorModel,
 )
@@ -380,6 +381,28 @@ def setUp(self):
                 "expected_prediction_with_base_margin",
             ],
         )
+        self.ranker_df_train = self.session.createDataFrame(
+            [
+                (Vectors.dense(1.0, 2.0, 3.0), 0, 0),
+                (Vectors.dense(4.0, 5.0, 6.0), 1, 0),
+                (Vectors.dense(9.0, 4.0, 8.0), 2, 0),
+                (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 0, 1),
+                (Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, 1),
+                (Vectors.sparse(3, {1: 8.0, 2: 9.5}), 2, 1),
+            ],
+            ["features", "label", "qid"],
+        )
+        self.ranker_df_test = self.session.createDataFrame(
+            [
+                (Vectors.dense(1.5, 2.0, 3.0), 0, -1.87988),
+                (Vectors.dense(4.5, 5.0, 6.0), 0, 0.29556),
+                (Vectors.dense(9.0, 4.5, 8.0), 0, 2.36570),
+                (Vectors.sparse(3, {1: 1.0, 2: 6.0}), 1, -1.87988),
+                (Vectors.sparse(3, {1: 6.0, 2: 7.0}), 1, -0.30612),
+                (Vectors.sparse(3, {1: 8.0, 2: 10.5}), 1, 2.44826),
+            ],
+            ["features", "qid", "expected_prediction"],
+        )
 
         self.reg_df_sparse_train = self.session.createDataFrame(
             [
@@ -1024,3 +1047,12 @@ def test_classifier_with_sparse_optim(self):
 
         for row1, row2 in zip(pred_result, pred_result2):
             self.assertTrue(np.allclose(row1.probability, row2.probability, rtol=1e-3))
+
+    def test_ranker(self):
+        ranker = SparkXGBRanker(qid_col="qid")
+        assert ranker.getOrDefault(ranker.objective) == "rank:pairwise"
+        model = ranker.fit(self.ranker_df_train)
+        pred_result = model.transform(self.ranker_df_test).collect()
+
+        for row in pred_result:
+            assert np.isclose(row.prediction, row.expected_prediction, rtol=1e-3)
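Finally, an end-to-end sketch of the feature this PR adds, assembled from the docstring example and the new test's data; the session setup is illustrative (any local SparkSession works).

    from pyspark.ml.linalg import Vectors
    from pyspark.sql import SparkSession
    from xgboost.spark import SparkXGBRanker

    spark = SparkSession.builder.master("local[2]").appName("xgb-ranker-demo").getOrCreate()

    df_train = spark.createDataFrame(
        [
            (Vectors.dense(1.0, 2.0, 3.0), 0, 0),
            (Vectors.dense(4.0, 5.0, 6.0), 1, 0),
            (Vectors.dense(9.0, 4.0, 8.0), 2, 0),
            (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 0, 1),
            (Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, 1),
            (Vectors.sparse(3, {1: 8.0, 2: 9.5}), 2, 1),
        ],
        ["features", "label", "qid"],
    )

    # qid_col is mandatory for the ranker; each qid value marks one query group.
    ranker = SparkXGBRanker(qid_col="qid")
    model = ranker.fit(df_train)
    model.transform(df_train.select("features", "qid")).show()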