
Distributed XGBoost with PySpark

Starting from version 2.0, xgboost supports pyspark estimator APIs. The feature is still experimental and not yet ready for production use.

XGBoost PySpark Estimator

SparkXGBRegressor

SparkXGBRegressor is a PySpark ML estimator. It implements the XGBoost regression algorithm based on the XGBoost python library, and it can be used in PySpark Pipeline and PySpark ML meta algorithms like CrossValidator/TrainValidationSplit/OneVsRest.

We can create a SparkXGBRegressor estimator like:

from xgboost.spark import SparkXGBRegressor
spark_reg_estimator = SparkXGBRegressor(
  features_col="features",
  label_col="label",
  num_workers=2,
)

The above snippet creates a spark estimator which can fit on a spark dataset and return a spark model that can transform a spark dataset and generate a dataset with a prediction column. We can set almost all of the xgboost sklearn estimator parameters as SparkXGBRegressor parameters, but some parameters, such as nthread, are forbidden in the spark estimator, and some are replaced with pyspark-specific parameters such as weight_col, validation_indicator_col and use_gpu. For details, please see the SparkXGBRegressor doc.
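
For example, a sketch of these pyspark-specific parameters; the "weight" and "is_val" column names here are hypothetical and would need to exist in your training dataframe:

from xgboost.spark import SparkXGBRegressor

# hypothetical columns: "weight" holds per-instance weights, and rows with
# "is_val" set to True are used as the evaluation set during training
regressor_with_extras = SparkXGBRegressor(
  features_col="features",
  label_col="label",
  weight_col="weight",
  validation_indicator_col="is_val",
  num_workers=2,
)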

The following code snippet shows how to train a spark xgboost regressor model. First we need to prepare a training dataset as a spark dataframe containing a "label" column and a "features" column(s); the "features" column(s) must be of pyspark.ml.linalg.Vector type or spark array type, or features_col can be set to a list of feature column names.
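
For illustration, a minimal training dataframe might look like this (a toy two-row, three-feature dataset):

from pyspark.sql import SparkSession
from pyspark.ml.linalg import Vectors

spark = SparkSession.builder.getOrCreate()

# toy training data: a vector "features" column plus a numeric "label" column
train_spark_dataframe = spark.createDataFrame(
  [
    (Vectors.dense(1.0, 2.0, 3.0), 0.0),
    (Vectors.dense(4.0, 5.0, 6.0), 1.0),
  ],
  ["features", "label"],
)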

xgb_regressor_model = spark_reg_estimator.fit(train_spark_dataframe)

The following code snippet shows how to predict on test data using a spark xgboost regressor model. First we need to prepare a test dataset as a spark dataframe containing "features" and "label" columns; the "features" column must be of pyspark.ml.linalg.Vector type or spark array type.
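
A matching toy test dataframe, reusing the spark session and the hypothetical three-feature layout from the training example above:

from pyspark.ml.linalg import Vectors

# toy test data in the same layout as the training dataframe
test_spark_dataframe = spark.createDataFrame(
  [(Vectors.dense(7.0, 8.0, 9.0), 1.0)],
  ["features", "label"],
)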

transformed_test_spark_dataframe = xgb_regressor_model.transform(test_spark_dataframe)

The above code snippet returns a transformed_test_spark_dataframe that contains the input dataset columns and an appended column "prediction" representing the prediction results.
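
For example, the appended column can be inspected like this:

transformed_test_spark_dataframe.select("features", "prediction").show()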

SparkXGBClassifier

SparkXGBClassifier has a similar API to SparkXGBRegressor, but it has some pyspark classifier specific params, e.g. the raw_prediction_col and probability_col parameters. Correspondingly, by default, a SparkXGBClassifierModel transforming a test dataset will generate a result dataset with 3 new columns:

- "prediction": represents the predicted label.
- "raw_prediction": represents the output margin values.
- "probability": represents the prediction probability on each label.
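
As a sketch, a classifier estimator can be created and applied in the same way as the regressor above; the output column names are set explicitly here rather than relying on the defaults:

from xgboost.spark import SparkXGBClassifier

spark_clf_estimator = SparkXGBClassifier(
  features_col="features",
  label_col="label",
  raw_prediction_col="raw_prediction",
  probability_col="probability",
  num_workers=2,
)

# fit on a dataframe with integer-like labels, then transform the test set
clf_model = spark_clf_estimator.fit(train_spark_dataframe)
clf_model.transform(test_spark_dataframe).show()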

XGBoost PySpark GPU support

XGBoost PySpark supports GPU training and prediction. To enable GPU support, you first need to install the xgboost and cudf packages. Then you can set the use_gpu parameter to True.

The following tutorial shows how to train a model with XGBoost PySpark GPU support on a Spark standalone cluster.

Write your PySpark application

from pyspark.sql import SparkSession
from xgboost.spark import SparkXGBRegressor

spark = SparkSession.builder.getOrCreate()

# read data into spark dataframe
train_data_path = "xxxx/train"
train_df = spark.read.parquet(train_data_path)

test_data_path = "xxxx/test"
test_df = spark.read.parquet(test_data_path)

# assume the label column is named "class"
label_name = "class"

# get a list with feature column names
feature_names = [x.name for x in train_df.schema if x.name != label_name]

# create a xgboost pyspark regressor estimator and set use_gpu=True
regressor = SparkXGBRegressor(
  features_col=feature_names,
  label_col=label_name,
  num_workers=2,
  use_gpu=True,
)

# train and return the model
model = regressor.fit(train_df)

# predict on test data
predict_df = model.transform(test_df)
predict_df.show()

Prepare the necessary packages

We recommend using Conda or Virtualenv to manage python dependencies in PySpark. Please refer to How to Manage Python Dependencies in PySpark.

conda create -y -n xgboost-env -c conda-forge conda-pack python=3.9
conda activate xgboost-env
pip install xgboost
pip install cudf
conda pack -f -o xgboost-env.tar.gz

Submit the PySpark application

The following assumes you have already configured your Spark cluster with GPU support; if not, please refer to spark standalone configuration with GPU support.

export PYSPARK_DRIVER_PYTHON=python
export PYSPARK_PYTHON=./environment/bin/python

spark-submit \
  --master spark://<master-ip>:7077 \
  --conf spark.executor.resource.gpu.amount=1 \
  --conf spark.task.resource.gpu.amount=1 \
  --archives xgboost-env.tar.gz#environment \
  xgboost_app.py

Model Persistence

from xgboost.spark import SparkXGBRegressorModel

# save the model
model.save("/tmp/xgboost-pyspark-model")

# load the model
model2 = SparkXGBRegressorModel.load("/tmp/xgboost-pyspark-model")

The above code snippet shows how to save/load an xgboost pyspark model. You can also load the model with the xgboost python package directly, without involving spark.

import xgboost as xgb
bst = xgb.Booster()
bst.load_model("/tmp/xgboost-pyspark-model/model/part-00000")
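
The loaded booster can then be used for local prediction without a spark session; the in-memory feature matrix below is a hypothetical example with the same number of features the model was trained on:

import numpy as np

# hypothetical local feature matrix; must match the training feature count
X = np.array([[1.0, 2.0, 3.0]])
preds = bst.predict(xgb.DMatrix(X))
print(preds)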

Accelerate the whole pipeline of xgboost pyspark

With RAPIDS Accelerator for Apache Spark, you can accelerate the whole pipeline (ETL, Train, Transform) for xgboost pyspark without any code change by leveraging GPU.

You only need to add some configurations to enable the RAPIDS plugin when submitting.

export PYSPARK_DRIVER_PYTHON=python
export PYSPARK_PYTHON=./environment/bin/python

spark-submit \
  --master spark://<master-ip>:7077 \
  --conf spark.executor.resource.gpu.amount=1 \
  --conf spark.task.resource.gpu.amount=1 \
  --packages com.nvidia:rapids-4-spark_2.12:22.08.0 \
  --conf spark.plugins=com.nvidia.spark.SQLPlugin \
  --archives xgboost-env.tar.gz#environment \
  xgboost_app.py