
[pyspark] Implement SparkXGBRanker estimator #8172

Merged
18 commits merged on Aug 22, 2022
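In short, this PR adds a learning-to-rank estimator to the PySpark interface, alongside the existing SparkXGBRegressor and SparkXGBClassifier. A minimal usage sketch (hypothetical DataFrames `train_df`/`test_df` with `features`, `label`, and `qid` columns, and an active SparkSession):

from xgboost.spark import SparkXGBRanker

# qid_col names the query-id column that groups rows into queries;
# the ranker requires it (see _validate_params below).
ranker = SparkXGBRanker(qid_col="qid")
model = ranker.fit(train_df)        # train_df: features / label / qid
scored = model.transform(test_df)   # adds a "prediction" column of ranking scores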
2 changes: 2 additions & 0 deletions python-package/xgboost/spark/__init__.py
@@ -10,6 +10,7 @@
from .estimator import (
SparkXGBClassifier,
SparkXGBClassifierModel,
SparkXGBRanker,
SparkXGBRegressor,
SparkXGBRegressorModel,
)
@@ -19,4 +20,5 @@
"SparkXGBClassifierModel",
"SparkXGBRegressor",
"SparkXGBRegressorModel",
"SparkXGBRanker",
]
73 changes: 51 additions & 22 deletions python-package/xgboost/spark/core.py
@@ -35,7 +35,7 @@
from xgboost.training import train as worker_train

import xgboost
-from xgboost import XGBClassifier, XGBRegressor
+from xgboost import XGBClassifier, XGBRanker, XGBRegressor

from .data import (
_read_csr_matrix_from_unwrapped_spark_vec,
@@ -54,6 +54,7 @@
HasBaseMarginCol,
HasEnableSparseDataOptim,
HasFeaturesCols,
HasQueryIdCol,
)
from .utils import (
RabitContext,
@@ -86,6 +87,7 @@
"feature_names",
"features_cols",
"enable_sparse_data_optim",
"qid_col",
]

_non_booster_params = ["missing", "n_estimators", "feature_types", "feature_weights"]
@@ -116,6 +118,10 @@
"eval_set",
"sample_weight_eval_set",
"base_margin", # Supported by spark param base_margin_col
"group", # Use spark param `qid_col` instead
"qid", # Use spark param `qid_col` instead
"eval_group", # Use spark param `qid_col` instead
"eval_qid", # Use spark param `qid_col` instead
}

_unsupported_predict_params = {
@@ -136,6 +142,7 @@ class _SparkXGBParams(
HasBaseMarginCol,
HasFeaturesCols,
HasEnableSparseDataOptim,
HasQueryIdCol,
):
num_workers = Param(
Params._dummy(),
@@ -572,13 +579,19 @@ def _get_distributed_train_params(self, dataset):
params["verbose_eval"] = verbose_eval
classification = self._xgb_cls() == XGBClassifier
num_classes = int(dataset.select(countDistinct(alias.label)).collect()[0][0])
if classification and num_classes == 2:
params["objective"] = "binary:logistic"
elif classification and num_classes > 2:
params["objective"] = "multi:softprob"
params["num_class"] = num_classes
if classification:
num_classes = int(
dataset.select(countDistinct(alias.label)).collect()[0][0]
)
if num_classes <= 2:
params["objective"] = "binary:logistic"
else:
params["objective"] = "multi:softprob"
params["num_class"] = num_classes
else:
params["objective"] = "reg:squarederror"
# use user specified objective or default objective.
# e.g., the default objective for Regressor is 'reg:squarederror'
params["objective"] = self.getOrDefault(self.objective)

# TODO: support "num_parallel_tree" for random forest
params["num_boost_round"] = self.getOrDefault(self.n_estimators)
@@ -648,6 +661,9 @@ def _fit(self, dataset):
col(self.getOrDefault(self.base_margin_col)).alias(alias.margin)
)

if self.isDefined(self.qid_col) and self.getOrDefault(self.qid_col):
select_cols.append(col(self.getOrDefault(self.qid_col)).alias(alias.qid))

dataset = dataset.select(*select_cols)

num_workers = self.getOrDefault(self.num_workers)
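The selection above narrows the input DataFrame to only the aliased columns the workers need, with the user's query-id column renamed to the internal `qid` alias. A standalone sketch of the same pattern (hypothetical column names, not code from this PR):

from pyspark.sql.functions import col

# Rename user-facing columns to the fixed internal aliases so the
# per-worker data code can address them uniformly.
select_cols = [col("features").alias("values"), col("label").alias("label")]
user_qid_col = "query_id"  # illustrative user column name
if user_qid_col:
    select_cols.append(col(user_qid_col).alias("qid"))
narrow_df = df.select(*select_cols)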
@@ -782,6 +798,10 @@ def __init__(self, xgb_sklearn_model=None):
super().__init__()
self._xgb_sklearn_model = xgb_sklearn_model

@classmethod
def _xgb_cls(cls):
raise NotImplementedError()

def get_booster(self):
"""
Return the `xgboost.core.Booster` instance.
@@ -818,9 +838,6 @@ def read(cls):
"""
return SparkXGBModelReader(cls)

-    def _transform(self, dataset):
-        raise NotImplementedError()

def _get_feature_col(self, dataset) -> (list, Optional[list]):
"""XGBoost model trained with features_cols parameter can also predict
vector or array feature type. But first we need to check features_cols
@@ -855,18 +872,6 @@ def _get_feature_col(self, dataset) -> (list, Optional[list]):
)
return features_col, feature_col_names


-class SparkXGBRegressorModel(_SparkXGBModel):
-    """
-    The model returned by :func:`xgboost.spark.SparkXGBRegressor.fit`
-
-    .. Note:: This API is experimental.
-    """
-
-    @classmethod
-    def _xgb_cls(cls):
-        return XGBRegressor

def _transform(self, dataset):
# Save xgb_sklearn_model and predict_params to local variables
# to avoid pickling the `self` object to the remote workers.
@@ -920,6 +925,30 @@ def predict_udf(iterator: Iterator[pd.DataFrame]) -> Iterator[pd.Series]:
return dataset.withColumn(predictionColName, pred_col)


class SparkXGBRegressorModel(_SparkXGBModel):
"""
The model returned by :func:`xgboost.spark.SparkXGBRegressor.fit`

.. Note:: This API is experimental.
"""

@classmethod
def _xgb_cls(cls):
return XGBRegressor


class SparkXGBRankerModel(_SparkXGBModel):
"""
The model returned by :func:`xgboost.spark.SparkXGBRanker.fit`

.. Note:: This API is experimental.
"""

@classmethod
def _xgb_cls(cls):
return XGBRanker


class SparkXGBClassifierModel(_SparkXGBModel, HasProbabilityCol, HasRawPredictionCol):
"""
The model returned by :func:`xgboost.spark.SparkXGBClassifier.fit`
10 changes: 6 additions & 4 deletions python-package/xgboost/spark/data.py
@@ -19,8 +19,8 @@ def stack_series(series: pd.Series) -> np.ndarray:

# Global constant for defining column alias shared between estimator and data
# processing procedures.
Alias = namedtuple("Alias", ("data", "label", "weight", "margin", "valid"))
alias = Alias("values", "label", "weight", "baseMargin", "validationIndicator")
Alias = namedtuple("Alias", ("data", "label", "weight", "margin", "valid", "qid"))
alias = Alias("values", "label", "weight", "baseMargin", "validationIndicator", "qid")


def concat_or_none(seq: Optional[Sequence[np.ndarray]]) -> Optional[np.ndarray]:
@@ -41,6 +41,7 @@ def make_blob(part: pd.DataFrame, is_valid: bool) -> None:
append(part, alias.label, is_valid)
append(part, alias.weight, is_valid)
append(part, alias.margin, is_valid)
append(part, alias.qid, is_valid)

has_validation: Optional[bool] = None

@@ -94,6 +95,7 @@ def next(self, input_data: Callable) -> int:
label=self._fetch(self._data.get(alias.label, None)),
weight=self._fetch(self._data.get(alias.weight, None)),
base_margin=self._fetch(self._data.get(alias.margin, None)),
qid=self._fetch(self._data.get(alias.qid, None)),
)
self._iter += 1
return 1
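For context, the `input_data` callback used here is xgboost's `DataIter` protocol: each call to `next` feeds one batch of arrays (now including `qid`) into the DMatrix being constructed. A self-contained sketch of that protocol with synthetic NumPy batches (not code from this PR; the cache path is illustrative):

import os
import numpy as np
import xgboost as xgb

class BatchIter(xgb.DataIter):
    """Feed a list of (X, y, qid) batches to xgboost one batch at a time."""

    def __init__(self, batches):
        self._batches = batches
        self._it = 0
        super().__init__(cache_prefix=os.path.join(".", "cache"))

    def next(self, input_data):
        if self._it == len(self._batches):
            return 0  # signal end of data
        X, y, qid = self._batches[self._it]
        input_data(data=X, label=y, qid=qid)  # same callback shape as above
        self._it += 1
        return 1  # signal more data available

    def reset(self):
        self._it = 0

rng = np.random.default_rng(0)
batches = [(rng.random((4, 3)), np.arange(4) % 2, np.repeat(i, 4)) for i in range(2)]
dtrain = xgb.DMatrix(BatchIter(batches))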
@@ -226,9 +228,9 @@ def make(values: Dict[str, List[np.ndarray]], kwargs: Dict[str, Any]) -> DMatrix
         label = concat_or_none(values.get(alias.label, None))
         weight = concat_or_none(values.get(alias.weight, None))
         margin = concat_or_none(values.get(alias.margin, None))
-
+        qid = concat_or_none(values.get(alias.qid, None))
         return DMatrix(
-            data=data, label=label, weight=weight, base_margin=margin, **kwargs
+            data=data, label=label, weight=weight, base_margin=margin, qid=qid, **kwargs
         )

is_dmatrix = feature_cols is None
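Downstream of all this plumbing is plain xgboost: `DMatrix` accepts an optional `qid` array that groups rows into queries for the ranking objectives. A self-contained sketch with synthetic data (not code from this PR):

import numpy as np
import xgboost as xgb

X = np.random.rand(6, 3)
y = np.array([2, 1, 0, 2, 1, 0])    # relevance labels within each query
qid = np.array([0, 0, 0, 1, 1, 1])  # two queries of three documents each
dtrain = xgb.DMatrix(X, label=y, qid=qid)
booster = xgb.train({"objective": "rank:ndcg"}, dtrain, num_boost_round=4)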
131 changes: 130 additions & 1 deletion python-package/xgboost/spark/estimator.py
@@ -3,10 +3,11 @@
# pylint: disable=too-many-ancestors
from pyspark.ml.param.shared import HasProbabilityCol, HasRawPredictionCol

-from xgboost import XGBClassifier, XGBRegressor
+from xgboost import XGBClassifier, XGBRanker, XGBRegressor

from .core import (
SparkXGBClassifierModel,
SparkXGBRankerModel,
SparkXGBRegressorModel,
_set_pyspark_xgb_cls_param_attrs,
_SparkXGBEstimator,
@@ -106,6 +107,13 @@ def _xgb_cls(cls):
def _pyspark_model_cls(cls):
return SparkXGBRegressorModel

def _validate_params(self):
super()._validate_params()
if self.isDefined(self.qid_col):
raise ValueError(
"Spark Xgboost regressor estimator does not support `qid_col` param."
)


_set_pyspark_xgb_cls_param_attrs(SparkXGBRegressor, SparkXGBRegressorModel)

@@ -213,5 +221,126 @@ def _xgb_cls(cls):
def _pyspark_model_cls(cls):
return SparkXGBClassifierModel

def _validate_params(self):
super()._validate_params()
if self.isDefined(self.qid_col):
raise ValueError(
"Spark Xgboost classifier estimator does not support `qid_col` param."
)


_set_pyspark_xgb_cls_param_attrs(SparkXGBClassifier, SparkXGBClassifierModel)


class SparkXGBRanker(_SparkXGBEstimator):
"""SparkXGBRanker is a PySpark ML estimator. It implements the XGBoost
classification algorithm based on XGBoost python library, and it can be used in
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be It implements the XGBoost **ranking** algorithm based on XGBoost python library instead of classification algorithm? I noticed that there are quite a few references to classification left behind in this block comment... should they and the example provided below be updated to ranking so the documentation comes out right?

Just to avoid confusion for new users like me trying to use this new estimator.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@WeichenXu123 Would you like to open a fix?

PySpark Pipeline and PySpark ML meta algorithms like
:py:class:`~pyspark.ml.tuning.CrossValidator`/
:py:class:`~pyspark.ml.tuning.TrainValidationSplit`/
:py:class:`~pyspark.ml.classification.OneVsRest`

SparkXGBRanker automatically supports most of the parameters in the
:py:class:`xgboost.XGBRanker` constructor and most of the parameters used in
the :py:class:`xgboost.XGBRanker` fit and predict methods.

SparkXGBRanker doesn't support setting `gpu_id` but supports another param `use_gpu`;
see doc below for more details.

SparkXGBRanker doesn't support setting `base_margin` explicitly either, but supports
another param called `base_margin_col`. See doc below for more details.

SparkXGBRanker doesn't support the `validate_features` and `output_margin` params.

SparkXGBRanker doesn't support setting the `nthread` xgboost param; instead, the
`nthread` param for each xgboost worker will be set equal to the `spark.task.cpus`
config value.


Parameters
----------

callbacks:
The export and import of the callback functions are at best effort. For
details, see :py:attr:`xgboost.spark.SparkXGBRanker.callbacks` param doc.
validation_indicator_col:
For params related to `xgboost.XGBRanker` training with
evaluation dataset's supervision,
set :py:attr:`xgboost.spark.SparkXGBRanker.validation_indicator_col`
parameter instead of setting the `eval_set` parameter in the `xgboost.XGBRanker`
fit method.
weight_col:
To specify the weight of the training and validation dataset, set
:py:attr:`xgboost.spark.SparkXGBRanker.weight_col` parameter instead of setting
`sample_weight` and `sample_weight_eval_set` parameter in the `xgboost.XGBRanker`
fit method.
xgb_model:
Set the value to be the instance returned by
:func:`xgboost.spark.SparkXGBRankerModel.get_booster`.
num_workers:
Integer that specifies the number of XGBoost workers to use.
Each XGBoost worker corresponds to one spark task.
use_gpu:
Boolean that specifies whether the executors are running on GPU
instances.
base_margin_col:
To specify the base margins of the training and validation
dataset, set :py:attr:`xgboost.spark.SparkXGBRanker.base_margin_col` parameter
instead of setting `base_margin` and `base_margin_eval_set` in the
`xgboost.XGBRanker` fit method.
qid_col:
To specify the qid of the training and validation
dataset, set :py:attr:`xgboost.spark.SparkXGBRanker.qid_col` parameter
instead of setting `qid` / `group`, `eval_qid` / `eval_group` in the
`xgboost.XGBRanker` fit method.

.. Note:: The Parameters chart above contains parameters that need special handling.
For a full list of parameters, see entries with `Param(parent=...` below.

.. Note:: This API is experimental.

Examples
--------

>>> from xgboost.spark import SparkXGBRanker
>>> from pyspark.ml.linalg import Vectors
>>> ranker = SparkXGBRanker(qid_col="qid")
>>> df_train = spark.createDataFrame(
...     [
...         (Vectors.dense(1.0, 2.0, 3.0), 0, 0),
...         (Vectors.dense(4.0, 5.0, 6.0), 1, 0),
...         (Vectors.dense(9.0, 4.0, 8.0), 2, 0),
...         (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 0, 1),
...         (Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, 1),
...         (Vectors.sparse(3, {1: 8.0, 2: 9.5}), 2, 1),
...     ],
...     ["features", "label", "qid"],
... )
>>> df_test = spark.createDataFrame(
...     [
...         (Vectors.dense(1.5, 2.0, 3.0), 0),
...         (Vectors.sparse(3, {1: 1.0, 2: 6.0}), 1),
...     ],
...     ["features", "qid"],
... )
>>> ranker_model = ranker.fit(df_train)
>>> ranker_model.transform(df_test).show()

"""

def __init__(self, **kwargs):
super().__init__()
self.setParams(**kwargs)

@classmethod
def _xgb_cls(cls):
return XGBRanker

@classmethod
def _pyspark_model_cls(cls):
return SparkXGBRankerModel

def _validate_params(self):
super()._validate_params()
if not self.isDefined(self.qid_col):
raise ValueError(
"Spark Xgboost ranker estimator requires setting `qid_col` param."
)


_set_pyspark_xgb_cls_param_attrs(SparkXGBRanker, SparkXGBRankerModel)
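Taken together, the three `_validate_params` overrides make `qid_col` ranker-only: the regressor and classifier reject it, while the ranker requires it. A hedged sketch of the resulting behavior (hypothetical `train_df`; assuming `_validate_params` is invoked at fit time, as elsewhere in this module):

from xgboost.spark import SparkXGBRanker, SparkXGBRegressor

try:
    SparkXGBRegressor(qid_col="qid").fit(train_df)
except ValueError as err:
    print(err)  # Spark Xgboost regressor estimator does not support `qid_col` param.

try:
    SparkXGBRanker().fit(train_df)
except ValueError as err:
    print(err)  # Spark Xgboost ranker estimator requires setting `qid_col` param.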
14 changes: 14 additions & 0 deletions python-package/xgboost/spark/params.py
@@ -72,3 +72,17 @@ class HasEnableSparseDataOptim(Params):
def __init__(self):
super().__init__()
self._setDefault(enable_sparse_data_optim=False)


class HasQueryIdCol(Params):
"""
Mixin for param qid_col: the query id column name.
"""

qid_col = Param(
Params._dummy(),
"qid_col",
"query id column name",
typeConverter=TypeConverters.toString,
)
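For reference, this is the standard PySpark `Params` mixin shape: a class-level `Param` descriptor plus the get/set machinery inherited from `Params`. A minimal sketch with an illustrative name (not part of this PR):

from pyspark.ml.param import Param, Params, TypeConverters

class HasExampleCol(Params):
    example_col = Param(
        Params._dummy(),
        "example_col",
        "example column name",
        typeConverter=TypeConverters.toString,
    )

    def get_example_col(self):
        # getOrDefault resolves an explicitly set value or the declared default.
        return self.getOrDefault(self.example_col)

Mixing HasQueryIdCol into _SparkXGBParams (see core.py above) is what makes `qid_col` a first-class Pipeline parameter, so it participates in CrossValidator copies and model persistence like any other Spark ML param.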