Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Xgboost pyspark integration: User guide doc and tutorials #8082

Merged
merged 7 commits into from Jul 19, 2022
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
82 changes: 82 additions & 0 deletions demo/guide-python/spark_estimator_examples.py
@@ -0,0 +1,82 @@
'''
Collection of examples for using xgboost.spark estimator interface
==================================================

@author: Weichen Xu
'''
from pyspark.sql import SparkSession
from pyspark.sql.functions import rand
from pyspark.ml.linalg import Vectors
import sklearn.datasets
from sklearn.model_selection import train_test_split
from xgboost.spark import SparkXGBClassifier, SparkXGBRegressor
from pyspark.ml.evaluation import RegressionEvaluator, MulticlassClassificationEvaluator


spark = SparkSession.builder.master("local[*]").getOrCreate()


def create_spark_df(X, y):
return spark.createDataFrame(
spark.sparkContext.parallelize([
(Vectors.dense(features), float(label))
for features, label in zip(X, y)
]),
["features", "label"]
)


# load diabetes dataset (regression dataset)
diabetes_X, diabetes_y = sklearn.datasets.load_diabetes(return_X_y=True)
diabetes_X_train, diabetes_X_test, diabetes_y_train, diabetes_y_test = \
train_test_split(diabetes_X, diabetes_y, test_size=0.3, shuffle=True)

diabetes_train_spark_df = create_spark_df(diabetes_X_train, diabetes_y_train)
diabetes_test_spark_df = create_spark_df(diabetes_X_test, diabetes_y_test)

# train xgboost regressor model
xgb_regressor = SparkXGBRegressor(max_depth=5)
xgb_regressor_model = xgb_regressor.fit(diabetes_train_spark_df)

transformed_diabetes_test_spark_df = xgb_regressor_model.transform(diabetes_test_spark_df)
regressor_evaluator = RegressionEvaluator(metricName="rmse")
print(f"regressor rmse={regressor_evaluator.evaluate(transformed_diabetes_test_spark_df)}")

diabetes_train_spark_df2 = diabetes_train_spark_df.withColumn(
"validationIndicatorCol", rand(1) > 0.7
)

# train xgboost regressor model with validation dataset
xgb_regressor2 = SparkXGBRegressor(max_depth=5, validation_indicator_col="validationIndicatorCol")
xgb_regressor_model2 = xgb_regressor.fit(diabetes_train_spark_df2)
transformed_diabetes_test_spark_df2 = xgb_regressor_model2.transform(diabetes_test_spark_df)
print(f"regressor2 rmse={regressor_evaluator.evaluate(transformed_diabetes_test_spark_df2)}")


# load iris dataset (classification dataset)
iris_X, iris_y = sklearn.datasets.load_iris(return_X_y=True)
iris_X_train, iris_X_test, iris_y_train, iris_y_test = \
train_test_split(iris_X, iris_y, test_size=0.3, shuffle=True)

iris_train_spark_df = create_spark_df(iris_X_train, iris_y_train)
iris_test_spark_df = create_spark_df(iris_X_test, iris_y_test)

# train xgboost classifier model
xgb_classifier = SparkXGBClassifier(max_depth=5)
xgb_classifier_model = xgb_classifier.fit(iris_train_spark_df)

transformed_iris_test_spark_df = xgb_classifier_model.transform(iris_test_spark_df)
classifier_evaluator = MulticlassClassificationEvaluator(metricName="f1")
print(f"classifier f1={classifier_evaluator.evaluate(transformed_iris_test_spark_df)}")

iris_train_spark_df2 = iris_train_spark_df.withColumn(
"validationIndicatorCol", rand(1) > 0.7
)

# train xgboost classifier model with validation dataset
xgb_classifier2 = SparkXGBClassifier(max_depth=5, validation_indicator_col="validationIndicatorCol")
xgb_classifier_model2 = xgb_classifier.fit(iris_train_spark_df2)
transformed_iris_test_spark_df2 = xgb_classifier_model2.transform(iris_test_spark_df)
print(f"classifier2 f1={classifier_evaluator.evaluate(transformed_iris_test_spark_df2)}")

spark.stop()
66 changes: 66 additions & 0 deletions doc/tutorials/spark_estimator.rst
@@ -0,0 +1,66 @@
###############################
Using XGBoost PySpark Estimator
###############################
Starting from version 1.5, xgboost supports pyspark estimator APIs.
WeichenXu123 marked this conversation as resolved.
Show resolved Hide resolved
The feature is still experimental and not yet ready for production use.

*****************
SparkXGBRegressor
*****************

SparkXGBRegressor is a PySpark ML estimator. It implements the XGBoost classification
algorithm based on XGBoost python library, and it can be used in PySpark Pipeline
and PySpark ML meta algorithms like CrossValidator/TrainValidationSplit/OneVsRest.

We can create a `SparkXGBRegressor` estimator like:

.. code-block:: python

from xgboost.spark import SparkXGBRegressor
spark_reg_estimator = SparkXGBRegressor(num_workers=2, max_depth=5)


The above snippet create an spark estimator which can fit on a spark dataset,
and return a spark model that can transform a spark dataset and generate dataset
with prediction column. We can set almost all of xgboost sklearn estimator parameters
as `SparkXGBRegressor` parameters, but some parameter such as `nthread` is forbidden
in spark estimator, and some parameters are replaced with pyspark specific parameters
such as `weight_col`, `validation_indicator_col`, `use_gpu`, for details please see
`SparkXGBRegressor` doc.

The following code snippet shows how to train a spark xgboost regressor model,
first we need to prepare a training dataset as a spark dataframe contains
"features" and "label" column, the "features" column must be `pyspark.ml.linalg.Vector`
type or spark array type.

.. code-block:: python

xgb_regressor_model = xgb_regressor.fit(train_spark_dataframe)


The following code snippet shows how to predict test data using a spark xgboost regressor model,
first we need to prepare a test dataset as a spark dataframe contains
"features" and "label" column, the "features" column must be `pyspark.ml.linalg.Vector`
type or spark array type.

.. code-block:: python

transformed_test_spark_dataframe = xgb_regressor.predict(test_spark_dataframe)


The above snippet code returns a `transformed_test_spark_dataframe` that contains the input
dataset columns and an appended column "prediction" representing the prediction results.


******************
SparkXGBClassifier
******************


`SparkXGBClassifier` estimator has similar API with `SparkXGBRegressor`, but it has some
pyspark classifier specific params, e.g. `raw_prediction_col` and `probability_col` parameters.
Correspondingly, by default, `SparkXGBClassifierModel` transforming test dataset will
generate result dataset with 3 new columns:
- "prediction": represents the predicted label.
- "raw_prediction": represents the output margin values.
- "probability": represents the prediction probability on each label.
6 changes: 1 addition & 5 deletions python-package/xgboost/spark/core.py
Expand Up @@ -88,7 +88,7 @@
"features_col": "featuresCol",
"label_col": "labelCol",
"weight_col": "weightCol",
"raw_prediction_ol": "rawPredictionCol",
"raw_prediction_col": "rawPredictionCol",
"prediction_col": "predictionCol",
"probability_col": "probabilityCol",
"validation_indicator_col": "validationIndicatorCol",
Expand Down Expand Up @@ -355,10 +355,6 @@ def setParams(self, **kwargs): # pylint: disable=invalid-name
)
if k in _pyspark_param_alias_map:
real_k = _pyspark_param_alias_map[k]
if real_k in kwargs:
raise ValueError(
f"You should set only one of param '{k}' and '{real_k}'"
)
k = real_k

if self.hasParam(k):
Expand Down
27 changes: 17 additions & 10 deletions python-package/xgboost/spark/estimator.py
Expand Up @@ -30,18 +30,21 @@ class SparkXGBRegressor(_SparkXGBEstimator):

SparkXGBRegressor doesn't support `validate_features` and `output_margin` param.

SparkXGBRegressor doesn't support setting `nthread` xgboost param, instead, the `nthread`
param for each xgboost worker will be set equal to `spark.task.cpus` config value.

callbacks:
The export and import of the callback functions are at best effort.
For details, see :py:attr:`xgboost.spark.SparkXGBRegressor.callbacks` param doc.
validationIndicatorCol
validation_indicator_col
For params related to `xgboost.XGBRegressor` training
with evaluation dataset's supervision, set
:py:attr:`xgboost.spark.SparkXGBRegressor.validationIndicatorCol`
:py:attr:`xgboost.spark.SparkXGBRegressor.validation_indicator_col`
parameter instead of setting the `eval_set` parameter in `xgboost.XGBRegressor`
fit method.
weightCol:
weight_col:
To specify the weight of the training and validation dataset, set
:py:attr:`xgboost.spark.SparkXGBRegressor.weightCol` parameter instead of setting
:py:attr:`xgboost.spark.SparkXGBRegressor.weight_col` parameter instead of setting
`sample_weight` and `sample_weight_eval_set` parameter in `xgboost.XGBRegressor`
fit method.
xgb_model:
Expand Down Expand Up @@ -121,28 +124,32 @@ class SparkXGBClassifier(_SparkXGBEstimator, HasProbabilityCol, HasRawPrediction
another param called `base_margin_col`. see doc below for more details.

SparkXGBClassifier doesn't support setting `output_margin`, but we can get output margin
from the raw prediction column. See `rawPredictionCol` param doc below for more details.
from the raw prediction column. See `raw_prediction_col` param doc below for more details.

SparkXGBClassifier doesn't support `validate_features` and `output_margin` param.

SparkXGBRegressor doesn't support setting `nthread` xgboost param, instead, the `nthread`
param for each xgboost worker will be set equal to `spark.task.cpus` config value.


Parameters
----------
callbacks:
The export and import of the callback functions are at best effort. For
details, see :py:attr:`xgboost.spark.SparkXGBClassifier.callbacks` param doc.
rawPredictionCol:
raw_prediction_col:
The `output_margin=True` is implicitly supported by the
`rawPredictionCol` output column, which is always returned with the predicted margin
values.
validationIndicatorCol:
validation_indicator_col:
For params related to `xgboost.XGBClassifier` training with
evaluation dataset's supervision,
set :py:attr:`xgboost.spark.SparkXGBClassifier.validationIndicatorCol`
set :py:attr:`xgboost.spark.SparkXGBClassifier.validation_indicator_col`
parameter instead of setting the `eval_set` parameter in `xgboost.XGBClassifier`
fit method.
weightCol:
weight_col:
To specify the weight of the training and validation dataset, set
:py:attr:`xgboost.spark.SparkXGBClassifier.weightCol` parameter instead of setting
:py:attr:`xgboost.spark.SparkXGBClassifier.weight_col` parameter instead of setting
`sample_weight` and `sample_weight_eval_set` parameter in `xgboost.XGBClassifier`
fit method.
xgb_model:
Expand Down