From f23cc92130a172fa7a5fbdf5bb605c0bc5183ffe Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Tue, 19 Jul 2022 22:25:14 +0800 Subject: [PATCH] [pyspark] User guide doc and tutorials (#8082) Co-authored-by: Bobby Wang --- demo/guide-python/spark_estimator_examples.py | 82 +++++++++++++++++++ doc/tutorials/spark_estimator.rst | 66 +++++++++++++++ python-package/xgboost/spark/core.py | 4 - python-package/xgboost/spark/estimator.py | 7 ++ 4 files changed, 155 insertions(+), 4 deletions(-) create mode 100644 demo/guide-python/spark_estimator_examples.py create mode 100644 doc/tutorials/spark_estimator.rst diff --git a/demo/guide-python/spark_estimator_examples.py b/demo/guide-python/spark_estimator_examples.py new file mode 100644 index 000000000000..de1bda560251 --- /dev/null +++ b/demo/guide-python/spark_estimator_examples.py @@ -0,0 +1,82 @@ +''' +Collection of examples for using xgboost.spark estimator interface +================================================================== + +@author: Weichen Xu +''' +from pyspark.sql import SparkSession +from pyspark.sql.functions import rand +from pyspark.ml.linalg import Vectors +import sklearn.datasets +from sklearn.model_selection import train_test_split +from xgboost.spark import SparkXGBClassifier, SparkXGBRegressor +from pyspark.ml.evaluation import RegressionEvaluator, MulticlassClassificationEvaluator + + +spark = SparkSession.builder.master("local[*]").getOrCreate() + + +def create_spark_df(X, y): + return spark.createDataFrame( + spark.sparkContext.parallelize([ + (Vectors.dense(features), float(label)) + for features, label in zip(X, y) + ]), + ["features", "label"] + ) + + +# load diabetes dataset (regression dataset) +diabetes_X, diabetes_y = sklearn.datasets.load_diabetes(return_X_y=True) +diabetes_X_train, diabetes_X_test, diabetes_y_train, diabetes_y_test = \ + train_test_split(diabetes_X, diabetes_y, test_size=0.3, shuffle=True) + +diabetes_train_spark_df = create_spark_df(diabetes_X_train, diabetes_y_train) +diabetes_test_spark_df = create_spark_df(diabetes_X_test, diabetes_y_test) + +# train xgboost regressor model +xgb_regressor = SparkXGBRegressor(max_depth=5) +xgb_regressor_model = xgb_regressor.fit(diabetes_train_spark_df) + +transformed_diabetes_test_spark_df = xgb_regressor_model.transform(diabetes_test_spark_df) +regressor_evaluator = RegressionEvaluator(metricName="rmse") +print(f"regressor rmse={regressor_evaluator.evaluate(transformed_diabetes_test_spark_df)}") + +diabetes_train_spark_df2 = diabetes_train_spark_df.withColumn( + "validationIndicatorCol", rand(1) > 0.7 +) + +# train xgboost regressor model with validation dataset +xgb_regressor2 = SparkXGBRegressor(max_depth=5, validation_indicator_col="validationIndicatorCol") +xgb_regressor_model2 = xgb_regressor.fit(diabetes_train_spark_df2) +transformed_diabetes_test_spark_df2 = xgb_regressor_model2.transform(diabetes_test_spark_df) +print(f"regressor2 rmse={regressor_evaluator.evaluate(transformed_diabetes_test_spark_df2)}") + + +# load iris dataset (classification dataset) +iris_X, iris_y = sklearn.datasets.load_iris(return_X_y=True) +iris_X_train, iris_X_test, iris_y_train, iris_y_test = \ + train_test_split(iris_X, iris_y, test_size=0.3, shuffle=True) + +iris_train_spark_df = create_spark_df(iris_X_train, iris_y_train) +iris_test_spark_df = create_spark_df(iris_X_test, iris_y_test) + +# train xgboost classifier model +xgb_classifier = SparkXGBClassifier(max_depth=5) +xgb_classifier_model = xgb_classifier.fit(iris_train_spark_df) + +transformed_iris_test_spark_df = xgb_classifier_model.transform(iris_test_spark_df) +classifier_evaluator = MulticlassClassificationEvaluator(metricName="f1") +print(f"classifier f1={classifier_evaluator.evaluate(transformed_iris_test_spark_df)}") + +iris_train_spark_df2 = iris_train_spark_df.withColumn( + "validationIndicatorCol", rand(1) > 0.7 +) + +# train xgboost classifier model with validation dataset +xgb_classifier2 = SparkXGBClassifier(max_depth=5, validation_indicator_col="validationIndicatorCol") +xgb_classifier_model2 = xgb_classifier.fit(iris_train_spark_df2) +transformed_iris_test_spark_df2 = xgb_classifier_model2.transform(iris_test_spark_df) +print(f"classifier2 f1={classifier_evaluator.evaluate(transformed_iris_test_spark_df2)}") + +spark.stop() diff --git a/doc/tutorials/spark_estimator.rst b/doc/tutorials/spark_estimator.rst new file mode 100644 index 000000000000..963a79377a6d --- /dev/null +++ b/doc/tutorials/spark_estimator.rst @@ -0,0 +1,66 @@ +############################### +Using XGBoost PySpark Estimator +############################### +Starting from version 2.0, xgboost supports pyspark estimator APIs. +The feature is still experimental and not yet ready for production use. + +***************** +SparkXGBRegressor +***************** + +SparkXGBRegressor is a PySpark ML estimator. It implements the XGBoost classification +algorithm based on XGBoost python library, and it can be used in PySpark Pipeline +and PySpark ML meta algorithms like CrossValidator/TrainValidationSplit/OneVsRest. + +We can create a `SparkXGBRegressor` estimator like: + +.. code-block:: python + + from xgboost.spark import SparkXGBRegressor + spark_reg_estimator = SparkXGBRegressor(num_workers=2, max_depth=5) + + +The above snippet create an spark estimator which can fit on a spark dataset, +and return a spark model that can transform a spark dataset and generate dataset +with prediction column. We can set almost all of xgboost sklearn estimator parameters +as `SparkXGBRegressor` parameters, but some parameter such as `nthread` is forbidden +in spark estimator, and some parameters are replaced with pyspark specific parameters +such as `weight_col`, `validation_indicator_col`, `use_gpu`, for details please see +`SparkXGBRegressor` doc. + +The following code snippet shows how to train a spark xgboost regressor model, +first we need to prepare a training dataset as a spark dataframe contains +"features" and "label" column, the "features" column must be `pyspark.ml.linalg.Vector` +type or spark array type. + +.. code-block:: python + + xgb_regressor_model = xgb_regressor.fit(train_spark_dataframe) + + +The following code snippet shows how to predict test data using a spark xgboost regressor model, +first we need to prepare a test dataset as a spark dataframe contains +"features" and "label" column, the "features" column must be `pyspark.ml.linalg.Vector` +type or spark array type. + +.. code-block:: python + + transformed_test_spark_dataframe = xgb_regressor.predict(test_spark_dataframe) + + +The above snippet code returns a `transformed_test_spark_dataframe` that contains the input +dataset columns and an appended column "prediction" representing the prediction results. + + +****************** +SparkXGBClassifier +****************** + + +`SparkXGBClassifier` estimator has similar API with `SparkXGBRegressor`, but it has some +pyspark classifier specific params, e.g. `raw_prediction_col` and `probability_col` parameters. +Correspondingly, by default, `SparkXGBClassifierModel` transforming test dataset will +generate result dataset with 3 new columns: + - "prediction": represents the predicted label. + - "raw_prediction": represents the output margin values. + - "probability": represents the prediction probability on each label. diff --git a/python-package/xgboost/spark/core.py b/python-package/xgboost/spark/core.py index bd461d49b5cd..dd5369c36d7a 100644 --- a/python-package/xgboost/spark/core.py +++ b/python-package/xgboost/spark/core.py @@ -379,10 +379,6 @@ def setParams(self, **kwargs): # pylint: disable=invalid-name ) if k in _pyspark_param_alias_map: real_k = _pyspark_param_alias_map[k] - if real_k in kwargs: - raise ValueError( - f"You should set only one of param '{k}' and '{real_k}'" - ) k = real_k if self.hasParam(k): diff --git a/python-package/xgboost/spark/estimator.py b/python-package/xgboost/spark/estimator.py index 1af42c3ae120..d174455d5721 100644 --- a/python-package/xgboost/spark/estimator.py +++ b/python-package/xgboost/spark/estimator.py @@ -31,6 +31,9 @@ class SparkXGBRegressor(_SparkXGBEstimator): SparkXGBRegressor doesn't support `validate_features` and `output_margin` param. + SparkXGBRegressor doesn't support setting `nthread` xgboost param, instead, the `nthread` + param for each xgboost worker will be set equal to `spark.task.cpus` config value. + callbacks: The export and import of the callback functions are at best effort. For details, see :py:attr:`xgboost.spark.SparkXGBRegressor.callbacks` param doc. @@ -128,6 +131,10 @@ class SparkXGBClassifier(_SparkXGBEstimator, HasProbabilityCol, HasRawPrediction SparkXGBClassifier doesn't support `validate_features` and `output_margin` param. + SparkXGBRegressor doesn't support setting `nthread` xgboost param, instead, the `nthread` + param for each xgboost worker will be set equal to `spark.task.cpus` config value. + + Parameters ----------