From 3acb92d98e0d2d521c2d5c95fd81d0034808e69b Mon Sep 17 00:00:00 2001
From: Weichen Xu
Date: Fri, 15 Jul 2022 20:11:22 +0800
Subject: [PATCH 1/6] update

---
 demo/guide-python/spark_estimator_examples.py | 48 +++++++++++++++++++
 1 file changed, 48 insertions(+)
 create mode 100644 demo/guide-python/spark_estimator_examples.py

diff --git a/demo/guide-python/spark_estimator_examples.py b/demo/guide-python/spark_estimator_examples.py
new file mode 100644
index 000000000000..5fcca354afee
--- /dev/null
+++ b/demo/guide-python/spark_estimator_examples.py
@@ -0,0 +1,48 @@
+'''
+Collection of examples for using xgboost.spark estimator interface
+==================================================
+
+@author: Weichen Xu
+'''
+from pyspark.sql import SparkSession
+from pyspark.sql.functions import rand
+from pyspark.ml.linalg import Vectors
+import sklearn
+import sklearn.datasets
+from sklearn.model_selection import train_test_split
+from xgboost.spark import SparkXGBClassifier, SparkXGBRegressor
+from pyspark.ml.evaluation import RegressionEvaluator, MulticlassClassificationEvaluator
+
+
+spark = SparkSession.builder.master("local[*]").getOrCreate()
+
+
+def create_spark_df(X, y):
+    return spark.createDataFrame(
+        spark.sparkContext.parallelize([
+            (Vectors.dense(features), float(label))
+            for features, label in zip(X, y)
+        ]),
+        ["features", "label"]
+    )
+
+
+diabetes_X, diabetes_y = sklearn.datasets.load_diabetes(return_X_y=True)
+diabetes_X_train, diabetes_X_test, diabetes_y_train, diabetes_y_test = \
+    train_test_split(diabetes_X, diabetes_y, test_size=0.3, shuffle=True)
+
+diabetes_train_spark_df = create_spark_df(diabetes_X_train, diabetes_y_train)
+diabetes_test_spark_df = create_spark_df(diabetes_X_test, diabetes_y_test)
+
+xgb_regressor = SparkXGBRegressor(max_depth=5)
+xgb_regressor_model = xgb_regressor.fit(diabetes_train_spark_df)
+
+transformed_diabetes_test_spark_df = xgb_regressor_model.transform(diabetes_test_spark_df)
+regressor_evaluator = RegressionEvaluator(metricName="rmse")
+print(f"regressor rmse={regressor_evaluator.evaluate(transformed_diabetes_test_spark_df)}")
+
+diabetes_train_spark_df2 = diabetes_train_spark_df.withColumn(
+    "validationIndicatorCol",
+)
+
+spark.stop()
From b71d0b5f2f3139062263e0f647696898c0d9ab8c Mon Sep 17 00:00:00 2001
From: Weichen Xu
Date: Fri, 15 Jul 2022 20:38:12 +0800
Subject: [PATCH 2/6] update

---
 demo/guide-python/spark_estimator_examples.py | 38 ++++++++++++++++++-
 1 file changed, 36 insertions(+), 2 deletions(-)

diff --git a/demo/guide-python/spark_estimator_examples.py b/demo/guide-python/spark_estimator_examples.py
index 5fcca354afee..5cb1ca2fdad6 100644
--- a/demo/guide-python/spark_estimator_examples.py
+++ b/demo/guide-python/spark_estimator_examples.py
@@ -7,7 +7,6 @@
 from pyspark.sql import SparkSession
 from pyspark.sql.functions import rand
 from pyspark.ml.linalg import Vectors
-import sklearn
 import sklearn.datasets
 from sklearn.model_selection import train_test_split
 from xgboost.spark import SparkXGBClassifier, SparkXGBRegressor
 from pyspark.ml.evaluation import RegressionEvaluator, MulticlassClassificationEvaluator
@@ -27,6 +26,7 @@ def create_spark_df(X, y):
     )
 
 
+# load diabetes dataset (regression dataset)
 diabetes_X, diabetes_y = sklearn.datasets.load_diabetes(return_X_y=True)
 diabetes_X_train, diabetes_X_test, diabetes_y_train, diabetes_y_test = \
     train_test_split(diabetes_X, diabetes_y, test_size=0.3, shuffle=True)
@@ -34,6 +34,7 @@ def create_spark_df(X, y):
 diabetes_train_spark_df = create_spark_df(diabetes_X_train, diabetes_y_train)
 diabetes_test_spark_df = create_spark_df(diabetes_X_test, diabetes_y_test)
 
+# train xgboost regressor model
 xgb_regressor = SparkXGBRegressor(max_depth=5)
 xgb_regressor_model = xgb_regressor.fit(diabetes_train_spark_df)
 
@@ -42,7 +43,40 @@ def create_spark_df(X, y):
 transformed_diabetes_test_spark_df = xgb_regressor_model.transform(diabetes_test_spark_df)
 regressor_evaluator = RegressionEvaluator(metricName="rmse")
 print(f"regressor rmse={regressor_evaluator.evaluate(transformed_diabetes_test_spark_df)}")
 
 diabetes_train_spark_df2 = diabetes_train_spark_df.withColumn(
-    "validationIndicatorCol",
+    "validationIndicatorCol", rand(1) > 0.7
 )
 
+# train xgboost regressor model with validation dataset
+xgb_regressor2 = SparkXGBRegressor(max_depth=5, validation_indicator_col="validationIndicatorCol")
+xgb_regressor_model2 = xgb_regressor2.fit(diabetes_train_spark_df2)
+
+transformed_diabetes_test_spark_df2 = xgb_regressor_model2.transform(diabetes_test_spark_df)
+print(f"regressor2 rmse={regressor_evaluator.evaluate(transformed_diabetes_test_spark_df2)}")
+
+
+# load iris dataset (classification dataset)
+iris_X, iris_y = sklearn.datasets.load_iris(return_X_y=True)
+iris_X_train, iris_X_test, iris_y_train, iris_y_test = \
+    train_test_split(iris_X, iris_y, test_size=0.3, shuffle=True)
+
+iris_train_spark_df = create_spark_df(iris_X_train, iris_y_train)
+iris_test_spark_df = create_spark_df(iris_X_test, iris_y_test)
+
+# train xgboost classifier model
+xgb_classifier = SparkXGBClassifier(max_depth=5)
+xgb_classifier_model = xgb_classifier.fit(iris_train_spark_df)
+
+transformed_iris_test_spark_df = xgb_classifier_model.transform(iris_test_spark_df)
+classifier_evaluator = MulticlassClassificationEvaluator(metricName="f1")
+print(f"classifier f1={classifier_evaluator.evaluate(transformed_iris_test_spark_df)}")
+
+iris_train_spark_df2 = iris_train_spark_df.withColumn(
+    "validationIndicatorCol", rand(1) > 0.7
+)
+
+# train xgboost classifier model with validation dataset
+xgb_classifier2 = SparkXGBClassifier(max_depth=5, validation_indicator_col="validationIndicatorCol")
+xgb_classifier_model2 = xgb_classifier2.fit(iris_train_spark_df2)
+transformed_iris_test_spark_df2 = xgb_classifier_model2.transform(iris_test_spark_df)
+print(f"classifier2 f1={classifier_evaluator.evaluate(transformed_iris_test_spark_df2)}")
+
 spark.stop()
From 6750ef214ddce6d7854e249a3d83b059bb8078d2 Mon Sep 17 00:00:00 2001
From: Bobby Wang
Date: Wed, 13 Jul 2022 15:39:03 +0800
Subject: [PATCH 3/6] [PySpark] fix raw_prediction_col parameter and minor cleanup

---
 python-package/xgboost/spark/core.py      |  2 +-
 python-package/xgboost/spark/estimator.py | 20 ++++++++++----------
 2 files changed, 11 insertions(+), 11 deletions(-)

diff --git a/python-package/xgboost/spark/core.py b/python-package/xgboost/spark/core.py
index 68a15a534f33..f8697733c935 100644
--- a/python-package/xgboost/spark/core.py
+++ b/python-package/xgboost/spark/core.py
@@ -88,7 +88,7 @@
     "features_col": "featuresCol",
     "label_col": "labelCol",
     "weight_col": "weightCol",
-    "raw_prediction_ol": "rawPredictionCol",
+    "raw_prediction_col": "rawPredictionCol",
     "prediction_col": "predictionCol",
     "probability_col": "probabilityCol",
     "validation_indicator_col": "validationIndicatorCol",
diff --git a/python-package/xgboost/spark/estimator.py b/python-package/xgboost/spark/estimator.py
index 3f50ab2bf2b9..664d7c06182e 100644
--- a/python-package/xgboost/spark/estimator.py
+++ b/python-package/xgboost/spark/estimator.py
@@ -33,15 +33,15 @@ class SparkXGBRegressor(_SparkXGBEstimator):
     callbacks:
         The export and import of the callback functions are at best effort.
         For details, see :py:attr:`xgboost.spark.SparkXGBRegressor.callbacks` param doc.
-    validationIndicatorCol
+    validation_indicator_col:
         For params related to `xgboost.XGBRegressor` training
         with evaluation dataset's supervision, set
-        :py:attr:`xgboost.spark.SparkXGBRegressor.validationIndicatorCol`
+        :py:attr:`xgboost.spark.SparkXGBRegressor.validation_indicator_col`
         parameter instead of setting the `eval_set` parameter in `xgboost.XGBRegressor`
         fit method.
-    weightCol:
+    weight_col:
         To specify the weight of the training and validation dataset, set
-        :py:attr:`xgboost.spark.SparkXGBRegressor.weightCol` parameter instead of setting
+        :py:attr:`xgboost.spark.SparkXGBRegressor.weight_col` parameter instead of setting
         `sample_weight` and `sample_weight_eval_set` parameter in `xgboost.XGBRegressor`
         fit method.
     xgb_model:
@@ -121,7 +121,7 @@ class SparkXGBClassifier(_SparkXGBEstimator, HasProbabilityCol, HasRawPrediction
     another param called `base_margin_col`. see doc below for more details.
 
     SparkXGBClassifier doesn't support setting `output_margin`, but we can get output margin
-    from the raw prediction column. See `rawPredictionCol` param doc below for more details.
+    from the raw prediction column. See `raw_prediction_col` param doc below for more details.
 
     SparkXGBClassifier doesn't support `validate_features` and `output_margin` param.
 
@@ -130,19 +130,19 @@ class SparkXGBClassifier(_SparkXGBEstimator, HasProbabilityCol, HasRawPrediction
     callbacks:
         The export and import of the callback functions are at best effort. For
         details, see :py:attr:`xgboost.spark.SparkXGBClassifier.callbacks` param doc.
-    rawPredictionCol:
+    raw_prediction_col:
        The `output_margin=True` is implicitly supported by the
        `rawPredictionCol` output column, which is always returned with the predicted margin
        values.
-    validationIndicatorCol:
+    validation_indicator_col:
        For params related to `xgboost.XGBClassifier` training
        with evaluation dataset's supervision,
-        set :py:attr:`xgboost.spark.SparkXGBClassifier.validationIndicatorCol`
+        set :py:attr:`xgboost.spark.SparkXGBClassifier.validation_indicator_col`
        parameter instead of setting the `eval_set` parameter in `xgboost.XGBClassifier`
        fit method.
-    weightCol:
+    weight_col:
        To specify the weight of the training and validation dataset, set
-        :py:attr:`xgboost.spark.SparkXGBClassifier.weightCol` parameter instead of setting
+        :py:attr:`xgboost.spark.SparkXGBClassifier.weight_col` parameter instead of setting
        `sample_weight` and `sample_weight_eval_set` parameter in `xgboost.XGBClassifier`
       fit method.
     xgb_model:
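The regressor and classifier docstrings touched above describe `weight_col` and `validation_indicator_col` as the Spark-side replacements for `sample_weight`/`sample_weight_eval_set` and `eval_set`. A minimal sketch of how the two parameters are meant to be used together is shown below; the toy data and the "weight"/"isVal" column names are illustrative assumptions, not part of this patch series.

.. code-block:: python

  from pyspark.sql import SparkSession
  from pyspark.ml.linalg import Vectors
  from xgboost.spark import SparkXGBRegressor

  spark = SparkSession.builder.master("local[*]").getOrCreate()

  # A "features" vector column and a "label" column, as in the demo script;
  # "weight" carries per-row weights (instead of sample_weight / sample_weight_eval_set)
  # and "isVal" marks the rows used as the evaluation set (instead of eval_set).
  train_df = spark.createDataFrame(
      [
          (Vectors.dense([1.0, 2.0]), 1.0, 1.0, False),
          (Vectors.dense([2.0, 3.0]), 0.0, 2.0, False),
          (Vectors.dense([3.0, 4.0]), 1.0, 1.0, True),
          (Vectors.dense([4.0, 5.0]), 0.0, 2.0, True),
      ],
      ["features", "label", "weight", "isVal"],
  )

  regressor = SparkXGBRegressor(
      max_depth=5,
      weight_col="weight",
      validation_indicator_col="isVal",
  )
  model = regressor.fit(train_df)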
From 65de9b48c92653f646e85ea7f61c5fd1f6ec1a87 Mon Sep 17 00:00:00 2001
From: Weichen Xu
Date: Fri, 15 Jul 2022 21:20:49 +0800
Subject: [PATCH 4/6] update

---
 doc/tutorials/spark_estimator.rst         | 66 +++++++++++++++++++++++
 python-package/xgboost/spark/estimator.py |  7 +++
 2 files changed, 73 insertions(+)
 create mode 100644 doc/tutorials/spark_estimator.rst

diff --git a/doc/tutorials/spark_estimator.rst b/doc/tutorials/spark_estimator.rst
new file mode 100644
index 000000000000..d27d27e05399
--- /dev/null
+++ b/doc/tutorials/spark_estimator.rst
@@ -0,0 +1,66 @@
+###############################
+Using XGBoost PySpark Estimator
+###############################
+Starting from version 1.5, xgboost supports pyspark estimator APIs.
+The feature is still experimental and not yet ready for production use.
+
+*****************
+SparkXGBRegressor
+*****************
+
+SparkXGBRegressor is a PySpark ML estimator. It implements the XGBoost regression
+algorithm based on the XGBoost python library, and it can be used in PySpark Pipeline
+and PySpark ML meta algorithms like CrossValidator/TrainValidationSplit/OneVsRest.
+
+We can create a `SparkXGBRegressor` estimator like:
+
+.. code-block:: python
+
+  from xgboost.spark import SparkXGBRegressor
+  spark_reg_estimator = SparkXGBRegressor(num_workers=2, max_depth=5)
+
+
+The above snippet creates a spark estimator which can fit on a spark dataset and
+return a spark model that can transform a spark dataset and generate a dataset with
+a prediction column. We can set almost all of the xgboost sklearn estimator parameters
+as `SparkXGBRegressor` parameters, but some parameters such as `nthread` are forbidden
+in the spark estimator, and some parameters are replaced with pyspark specific parameters
+such as `weight_col`, `validation_indicator_col` and `use_gpu`. For details, please see
+the `SparkXGBRegressor` doc.
+
+The following code snippet shows how to train a spark xgboost regressor model.
+First we need to prepare a training dataset as a spark dataframe that contains
+"features" and "label" columns; the "features" column must be of
+`pyspark.ml.linalg.Vector` type or spark array type.
+
+.. code-block:: python
+
+  xgb_regressor_model = xgb_regressor.fit(train_spark_dataframe)
+
+
+The following code snippet shows how to predict on test data using a spark xgboost
+regressor model. First we need to prepare a test dataset as a spark dataframe that
+contains "features" and "label" columns; the "features" column must be of
+`pyspark.ml.linalg.Vector` type or spark array type.
+
+.. code-block:: python
+
+  transformed_test_spark_dataframe = xgb_regressor_model.transform(test_spark_dataframe)
+
+
+The above code snippet returns a `transformed_test_spark_dataframe` that contains the input
+dataset columns and an appended column "prediction" representing the prediction results.
+
+
+******************
+SparkXGBClassifier
+******************
+
+
+The `SparkXGBClassifier` estimator has a similar API to `SparkXGBRegressor`, but it has some
+pyspark classifier specific params, e.g. the `raw_prediction_col` and `probability_col` parameters.
+Correspondingly, by default, a `SparkXGBClassifierModel` transforming a test dataset will
+generate a result dataset with 3 new columns:
+  - "prediction": represents the predicted label.
+  - "raw_prediction": represents the output margin values.
+  - "probability": represents the prediction probability on each label.
diff --git a/python-package/xgboost/spark/estimator.py b/python-package/xgboost/spark/estimator.py
index 664d7c06182e..af6c10e4bb52 100644
--- a/python-package/xgboost/spark/estimator.py
+++ b/python-package/xgboost/spark/estimator.py
@@ -30,6 +30,9 @@ class SparkXGBRegressor(_SparkXGBEstimator):
 
     SparkXGBRegressor doesn't support `validate_features` and `output_margin` param.
 
+    SparkXGBRegressor doesn't support setting `nthread` xgboost param; instead, the `nthread`
+    param for each xgboost worker will be set equal to the `spark.task.cpus` config value.
+
     callbacks:
         The export and import of the callback functions are at best effort.
         For details, see :py:attr:`xgboost.spark.SparkXGBRegressor.callbacks` param doc.
@@ -125,6 +128,10 @@ class SparkXGBClassifier(_SparkXGBEstimator, HasProbabilityCol, HasRawPrediction
 
     SparkXGBClassifier doesn't support `validate_features` and `output_margin` param.
 
+    SparkXGBClassifier doesn't support setting `nthread` xgboost param; instead, the `nthread`
+    param for each xgboost worker will be set equal to the `spark.task.cpus` config value.
+
+
     Parameters
     ----------
     callbacks:
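The tutorial added above lists the three columns a `SparkXGBClassifierModel` appends, and the new docstring lines note that each xgboost worker's `nthread` follows the `spark.task.cpus` config. A small illustrative sketch combining the two points follows; the session config, the toy data and the final `show()` call are assumptions for demonstration only, not part of the patch.

.. code-block:: python

  from pyspark.sql import SparkSession
  from pyspark.ml.linalg import Vectors
  from xgboost.spark import SparkXGBClassifier

  # Per the docstring above, each xgboost worker uses spark.task.cpus threads,
  # so the thread count is controlled through the Spark config rather than nthread.
  spark = (
      SparkSession.builder
      .master("local[*]")
      .config("spark.task.cpus", "1")
      .getOrCreate()
  )

  train_df = spark.createDataFrame(
      [
          (Vectors.dense([1.0, 0.0]), 0.0),
          (Vectors.dense([0.0, 1.0]), 1.0),
          (Vectors.dense([1.5, 0.2]), 0.0),
          (Vectors.dense([0.1, 1.2]), 1.0),
      ],
      ["features", "label"],
  )

  classifier = SparkXGBClassifier(max_depth=3)
  model = classifier.fit(train_df)

  # The transformed dataframe keeps the input columns and appends the prediction,
  # raw prediction (margin) and probability columns described in the tutorial.
  model.transform(train_df).show(truncate=False)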
From 486be451afbd5ad350024dcef34fe57865a184e6 Mon Sep 17 00:00:00 2001
From: Weichen Xu
Date: Fri, 15 Jul 2022 21:22:16 +0800
Subject: [PATCH 5/6] clean

---
 python-package/xgboost/spark/core.py | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/python-package/xgboost/spark/core.py b/python-package/xgboost/spark/core.py
index f8697733c935..7040f241f83c 100644
--- a/python-package/xgboost/spark/core.py
+++ b/python-package/xgboost/spark/core.py
@@ -355,10 +355,6 @@ def setParams(self, **kwargs):  # pylint: disable=invalid-name
             )
             if k in _pyspark_param_alias_map:
                 real_k = _pyspark_param_alias_map[k]
-                if real_k in kwargs:
-                    raise ValueError(
-                        f"You should set only one of param '{k}' and '{real_k}'"
-                    )
                 k = real_k
 
             if self.hasParam(k):

From 8d66aaee842bb019aa166e8c11c1b09ba6f8a244 Mon Sep 17 00:00:00 2001
From: Weichen Xu
Date: Mon, 18 Jul 2022 10:04:11 +0800
Subject: [PATCH 6/6] update

---
 demo/guide-python/spark_estimator_examples.py | 2 +-
 doc/tutorials/spark_estimator.rst             | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/demo/guide-python/spark_estimator_examples.py b/demo/guide-python/spark_estimator_examples.py
index 5cb1ca2fdad6..de1bda560251 100644
--- a/demo/guide-python/spark_estimator_examples.py
+++ b/demo/guide-python/spark_estimator_examples.py
@@ -1,6 +1,6 @@
 '''
 Collection of examples for using xgboost.spark estimator interface
-==================================================
+==================================================================
 
 @author: Weichen Xu
 '''
diff --git a/doc/tutorials/spark_estimator.rst b/doc/tutorials/spark_estimator.rst
index d27d27e05399..963a79377a6d 100644
--- a/doc/tutorials/spark_estimator.rst
+++ b/doc/tutorials/spark_estimator.rst
@@ -1,7 +1,7 @@
 ###############################
 Using XGBoost PySpark Estimator
 ###############################
-Starting from version 1.5, xgboost supports pyspark estimator APIs.
+Starting from version 2.0, xgboost supports pyspark estimator APIs.
 The feature is still experimental and not yet ready for production use.
 
 *****************
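The `_pyspark_param_alias_map` entry fixed in PATCH 3 and the `setParams` cleanup in PATCH 5 both concern how snake_case constructor arguments are translated to the camelCase PySpark ML `Param` names. A rough sketch of what that mapping implies for callers is below; the column names are invented and the exact constructor behaviour should be verified against the merged implementation.

.. code-block:: python

  from pyspark.sql import SparkSession
  from xgboost.spark import SparkXGBClassifier

  spark = SparkSession.builder.master("local[*]").getOrCreate()

  # Snake_case kwargs are resolved through _pyspark_param_alias_map in setParams(),
  # e.g. raw_prediction_col -> rawPredictionCol (the rename fixed in PATCH 3/6).
  clf = SparkXGBClassifier(
      features_col="features",         # resolved to the featuresCol param
      label_col="label",               # resolved to the labelCol param
      raw_prediction_col="rawMargin",  # resolved to the rawPredictionCol param
  )

  # The underlying PySpark ML params keep their camelCase names.
  print(clf.getOrDefault(clf.rawPredictionCol))  # expected: "rawMargin"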