From 0a15487bf3593afee6a004a50251bd6e3d36b692 Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Tue, 21 Jun 2022 21:40:29 +0800 Subject: [PATCH 01/73] init Signed-off-by: Weichen Xu --- python-package/xgboost/spark/__init__.py | 16 + python-package/xgboost/spark/core.py | 810 ++++++++++++++++++ python-package/xgboost/spark/data.py | 262 ++++++ python-package/xgboost/spark/estimator.py | 192 +++++ python-package/xgboost/spark/model.py | 232 +++++ python-package/xgboost/spark/utils.py | 180 ++++ tests/python/test_spark/__init__.py | 0 tests/python/test_spark/data_test.py | 175 ++++ tests/python/test_spark/utils_test.py | 130 +++ .../test_spark/xgboost_local_cluster_test.py | 382 +++++++++ tests/python/test_spark/xgboost_local_test.py | 654 ++++++++++++++ 11 files changed, 3033 insertions(+) create mode 100644 python-package/xgboost/spark/__init__.py create mode 100644 python-package/xgboost/spark/core.py create mode 100644 python-package/xgboost/spark/data.py create mode 100644 python-package/xgboost/spark/estimator.py create mode 100644 python-package/xgboost/spark/model.py create mode 100644 python-package/xgboost/spark/utils.py create mode 100644 tests/python/test_spark/__init__.py create mode 100644 tests/python/test_spark/data_test.py create mode 100644 tests/python/test_spark/utils_test.py create mode 100644 tests/python/test_spark/xgboost_local_cluster_test.py create mode 100644 tests/python/test_spark/xgboost_local_test.py diff --git a/python-package/xgboost/spark/__init__.py b/python-package/xgboost/spark/__init__.py new file mode 100644 index 000000000000..6e58401aa34f --- /dev/null +++ b/python-package/xgboost/spark/__init__.py @@ -0,0 +1,16 @@ +"""XGBoost: eXtreme Gradient Boosting library. + +Contributors: https://github.com/dmlc/xgboost/blob/master/CONTRIBUTORS.md +""" + +try: + import pyspark +except ImportError: + raise RuntimeError("xgboost spark python API requires pyspark package installed.") + +from .estimator import (XgboostClassifier, XgboostClassifierModel, + XgboostRegressor, XgboostRegressorModel) + +__all__ = ['XgboostClassifier', 'XgboostClassifierModel', + 'XgboostRegressor', 'XgboostRegressorModel'] + diff --git a/python-package/xgboost/spark/core.py b/python-package/xgboost/spark/core.py new file mode 100644 index 000000000000..48d4c75a4212 --- /dev/null +++ b/python-package/xgboost/spark/core.py @@ -0,0 +1,810 @@ +import shutil +import tempfile +from typing import Iterator, Tuple +import numpy as np +import pandas as pd +from scipy.special import expit, softmax +from pyspark.ml import Estimator, Model +from pyspark.ml.param.shared import HasFeaturesCol, HasLabelCol, HasWeightCol, \ + HasPredictionCol, HasProbabilityCol, HasRawPredictionCol, HasValidationIndicatorCol +from pyspark.ml.param import Param, Params, TypeConverters +from pyspark.ml.util import MLReadable, MLWritable +from pyspark.sql.functions import col, pandas_udf, countDistinct, struct +from pyspark.sql.types import ArrayType, FloatType +from xgboost import XGBClassifier, XGBRegressor +from xgboost.core import Booster +import cloudpickle +import xgboost +from xgboost.training import train as worker_train +from .utils import get_logger, _get_max_num_concurrent_tasks +from .data import prepare_predict_data, prepare_train_val_data, convert_partition_data_to_dmatrix +from .model import (XgboostReader, XgboostWriter, XgboostModelReader, + XgboostModelWriter, deserialize_xgb_model, + get_xgb_model_creator, serialize_xgb_model) +from .utils import (_get_default_params_from_func, get_class_name, + 
HasArbitraryParamsDict, HasBaseMarginCol, RabitContext, + _get_rabit_args, _get_args_from_message_list, + _get_spark_session) + +from pyspark.ml.functions import array_to_vector, vector_to_array + +# Put pyspark specific params here, they won't be passed to XGBoost. +# like `validationIndicatorCol`, `baseMarginCol` +_pyspark_specific_params = [ + 'featuresCol', 'labelCol', 'weightCol', 'rawPredictionCol', + 'predictionCol', 'probabilityCol', 'validationIndicatorCol' + 'baseMarginCol' +] + +_unsupported_xgb_params = [ + 'gpu_id', # [ML-12862] +] +_unsupported_fit_params = { + 'sample_weight', # Supported by spark param weightCol + # Supported by spark param weightCol # and validationIndicatorCol + 'eval_set', + 'sample_weight_eval_set', + 'base_margin' # Supported by spark param baseMarginCol +} +_unsupported_predict_params = { + # [ML-12913], for classification, we can use rawPrediction as margin + 'output_margin', + 'validate_features', # [ML-12923] + 'base_margin' # [ML-12689] +} + +_created_params = {"num_workers", "use_gpu"} + + +class _XgboostParams(HasFeaturesCol, HasLabelCol, HasWeightCol, + HasPredictionCol, HasValidationIndicatorCol, + HasArbitraryParamsDict, HasBaseMarginCol): + num_workers = Param( + Params._dummy(), "num_workers", + "The number of XGBoost workers. Each XGBoost worker corresponds to one spark task.", + TypeConverters.toInt) + use_gpu = Param( + Params._dummy(), "use_gpu", + "A boolean variable. Set use_gpu=true if the executors " + + "are running on GPU instances. Currently, only one GPU per task is supported." + ) + force_repartition = Param( + Params._dummy(), "force_repartition", + "A boolean variable. Set force_repartition=true if you " + + "want to force the input dataset to be repartitioned before XGBoost training." + + "Note: The auto repartitioning judgement is not fully accurate, so it is recommended" + + "to have force_repartition be True.") + use_external_storage = Param( + Params._dummy(), "use_external_storage", + "A boolean variable (that is False by default). External storage is a parameter" + + "for distributed training that allows external storage (disk) to be used when." + + "you have an exceptionally large dataset. This should be set to false for" + + "small datasets. Note that base margin and weighting doesn't work if this is True." + + "Also note that you may use precision if you use external storage." 
+ ) + external_storage_precision = Param( + Params._dummy(), "external_storage_precision", + "The number of significant digits for data storage on disk when using external storage.", + TypeConverters.toInt + ) + + @classmethod + def _xgb_cls(cls): + """ + Subclasses should override this method and + returns an xgboost.XGBModel subclass + """ + raise NotImplementedError() + + def _get_xgb_model_creator(self): + arbitaryParamsDict = self.getOrDefault( + self.getParam("arbitraryParamsDict")) + total_params = {**self._gen_xgb_params_dict(), **arbitaryParamsDict} + # Once we have already added all of the elements of kwargs, we can just remove it + del total_params["arbitraryParamsDict"] + for param in _created_params: + del total_params[param] + return get_xgb_model_creator(self._xgb_cls(), total_params) + + # Parameters for xgboost.XGBModel() + @classmethod + def _get_xgb_params_default(cls): + xgb_model_default = cls._xgb_cls()() + params_dict = xgb_model_default.get_params() + filtered_params_dict = { + k: params_dict[k] + for k in params_dict if k not in _unsupported_xgb_params + } + return filtered_params_dict + + def _set_xgb_params_default(self): + filtered_params_dict = self._get_xgb_params_default() + self._setDefault(**filtered_params_dict) + self._setDefault(**{"arbitraryParamsDict": {}}) + + def _gen_xgb_params_dict(self): + xgb_params = {} + non_xgb_params = \ + set(_pyspark_specific_params) | \ + self._get_fit_params_default().keys() | \ + self._get_predict_params_default().keys() + for param in self.extractParamMap(): + if param.name not in non_xgb_params: + xgb_params[param.name] = self.getOrDefault(param) + return xgb_params + + def _set_distributed_params(self): + self.set(self.num_workers, 1) + self.set(self.use_gpu, False) + self.set(self.force_repartition, False) + self.set(self.use_external_storage, False) + self.set(self.external_storage_precision, 5) # Check if this needs to be modified + + # Parameters for xgboost.XGBModel().fit() + @classmethod + def _get_fit_params_default(cls): + fit_params = _get_default_params_from_func(cls._xgb_cls().fit, + _unsupported_fit_params) + return fit_params + + def _set_fit_params_default(self): + filtered_params_dict = self._get_fit_params_default() + self._setDefault(**filtered_params_dict) + + def _gen_fit_params_dict(self): + """ + Returns a dict of params for .fit() + """ + fit_params_keys = self._get_fit_params_default().keys() + fit_params = {} + for param in self.extractParamMap(): + if param.name in fit_params_keys: + fit_params[param.name] = self.getOrDefault(param) + return fit_params + + # Parameters for xgboost.XGBModel().predict() + @classmethod + def _get_predict_params_default(cls): + predict_params = _get_default_params_from_func( + cls._xgb_cls().predict, _unsupported_predict_params) + return predict_params + + def _set_predict_params_default(self): + filtered_params_dict = self._get_predict_params_default() + self._setDefault(**filtered_params_dict) + + def _gen_predict_params_dict(self): + """ + Returns a dict of params for .predict() + """ + predict_params_keys = self._get_predict_params_default().keys() + predict_params = {} + for param in self.extractParamMap(): + if param.name in predict_params_keys: + predict_params[param.name] = self.getOrDefault(param) + return predict_params + + def _validate_params(self): + init_model = self.getOrDefault(self.xgb_model) + if init_model is not None: + if init_model is not None and not isinstance(init_model, Booster): + raise ValueError( + 'The xgb_model param must be set with 
a `xgboost.core.Booster` ' + 'instance.') + + if self.getOrDefault(self.num_workers) < 1: + raise ValueError( + f"Number of workers was {self.getOrDefault(self.num_workers)}." + f"It cannot be less than 1 [Default is 1]") + + if self.getOrDefault(self.num_workers) > 1 and not self.getOrDefault( + self.use_gpu): + cpu_per_task = _get_spark_session().sparkContext.getConf().get( + 'spark.task.cpus') + if cpu_per_task and int(cpu_per_task) > 1: + get_logger(self.__class__.__name__).warning( + f'You configured {cpu_per_task} CPU cores for each spark task, but in ' + f'XGBoost training, every Spark task will only use one CPU core.' + ) + + if self.getOrDefault(self.force_repartition) and self.getOrDefault( + self.num_workers) == 1: + get_logger(self.__class__.__name__).warning( + "You set force_repartition to true when there is no need for a repartition." + "Therefore, that parameter will be ignored.") + + if self.getOrDefault(self.use_gpu): + tree_method = self.getParam("tree_method") + if self.getOrDefault( + tree_method + ) is not None and self.getOrDefault(tree_method) != "gpu_hist": + raise ValueError( + f"tree_method should be 'gpu_hist' or None when use_gpu is True," + f"found {self.getOrDefault(tree_method)}.") + + gpu_per_task = _get_spark_session().sparkContext.getConf().get( + 'spark.task.resource.gpu.amount') + + if not gpu_per_task or int(gpu_per_task) < 1: + raise RuntimeError( + "The spark cluster does not have the necessary GPU" + + "configuration for the spark task. Therefore, we cannot" + + "run xgboost training using GPU.") + + if int(gpu_per_task) > 1: + get_logger(self.__class__.__name__).warning( + f'You configured {gpu_per_task} GPU cores for each spark task, but in ' + f'XGBoost training, every Spark task will only use one GPU core.' + ) + + +class _XgboostEstimator(Estimator, _XgboostParams, MLReadable, MLWritable): + def __init__(self): + super().__init__() + self._set_xgb_params_default() + self._set_fit_params_default() + self._set_predict_params_default() + self._set_distributed_params() + + def setParams(self, **kwargs): + _user_defined = {} + for k, v in kwargs.items(): + if self.hasParam(k): + self._set(**{str(k): v}) + else: + _user_defined[k] = v + _defined_args = self.getOrDefault(self.getParam("arbitraryParamsDict")) + _defined_args.update(_user_defined) + self._set(**{"arbitraryParamsDict": _defined_args}) + + @classmethod + def _pyspark_model_cls(cls): + """ + Subclasses should override this method and + returns a _XgboostModel subclass + """ + raise NotImplementedError() + + def _create_pyspark_model(self, xgb_model): + return self._pyspark_model_cls()(xgb_model) + + @classmethod + def _convert_to_classifier(cls, booster): + clf = XGBClassifier() + clf._Booster = booster + return clf + + @classmethod + def _convert_to_regressor(cls, booster): + reg = XGBRegressor() + reg._Booster = booster + return reg + + def _convert_to_model(self, booster): + if self._xgb_cls() == XGBRegressor: + return self._convert_to_regressor(booster) + elif self._xgb_cls() == XGBClassifier: + return self._convert_to_classifier(booster) + else: + return None # check if this else statement is needed. 
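+
+    # Note on the repartition check below: the "extended" explain string returned by
+    # PythonSQLUtils.explainString contains a section shaped roughly like
+    # (illustrative only, not an exact Spark output):
+    #
+    #   == Optimized Logical Plan ==
+    #   Repartition 8, true
+    #   +- Project [...]
+    #
+    # If the operator right after that header is a Repartition and the DataFrame's
+    # current partition count equals num_workers, the existing partitioning is
+    # treated as valid and no extra repartition is performed.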
+ + def _query_plan_contains_valid_repartition(self, query_plan, + num_partitions): + """ + Returns true if the latest element in the logical plan is a valid repartition + """ + start = query_plan.index("== Optimized Logical Plan ==") + start += len("== Optimized Logical Plan ==") + 1 + num_workers = self.getOrDefault(self.num_workers) + if query_plan[start:start + len("Repartition")] == "Repartition" and \ + num_workers == num_partitions: + return True + return False + + def _repartition_needed(self, dataset): + """ + We repartition the dataset if the number of workers is not equal to the number of + partitions. There is also a check to make sure there was "active partitioning" + where either Round Robin or Hash partitioning was actively used before this stage. + """ + if self.getOrDefault(self.force_repartition): + return True + try: + num_partitions = dataset.rdd.getNumPartitions() + query_plan = dataset._sc._jvm.PythonSQLUtils.explainString( + dataset._jdf.queryExecution(), "extended") + if self._query_plan_contains_valid_repartition( + query_plan, num_partitions): + return False + except: # noqa: E722 + pass + return True + + def _get_distributed_config(self, dataset, params): + """ + This just gets the configuration params for distributed xgboost + """ + + classification = self._xgb_cls() == XGBClassifier + num_classes = int( + dataset.select(countDistinct('label')).collect()[0][0]) + if classification and num_classes == 2: + params["objective"] = "binary:logistic" + elif classification and num_classes > 2: + params["objective"] = "multi:softprob" + params["num_class"] = num_classes + else: + params["objective"] = "reg:squarederror" + + if self.getOrDefault(self.use_gpu): + params["tree_method"] = "gpu_hist" + # TODO: fix this. This only works on databricks runtime. + # On open-source spark, we need get the gpu id from the task allocated gpu resources. 
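+            # A possible approach on open-source Spark (sketch only, not wired up in
+            # this patch): resolve the allocated GPU inside the barrier task rather
+            # than here, e.g.
+            #
+            #   from pyspark import BarrierTaskContext
+            #   gpu_addr = BarrierTaskContext.get().resources()["gpu"].addresses[0]
+            #   params["gpu_id"] = int(gpu_addr)
+            #
+            # assuming the cluster exposes a "gpu" task resource.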
+ params["gpu_id"] = 0 + params["num_boost_round"] = self.getOrDefault(self.n_estimators) + xgb_params = self._gen_xgb_params_dict() + xgb_params.update(params) + return xgb_params + + @classmethod + def _get_dist_booster_params(cls, train_params): + non_booster_params = _get_default_params_from_func(xgboost.train, {}) + booster_params, kwargs_params = {}, {} + for key, value in train_params.items(): + if key in non_booster_params: + kwargs_params[key] = value + else: + booster_params[key] = value + return booster_params, kwargs_params + + def _fit_distributed(self, xgb_model_creator, dataset, has_weight, + has_validation, fit_params): + """ + Takes in the dataset, the other parameters, and produces a valid booster + """ + num_workers = self.getOrDefault(self.num_workers) + sc = _get_spark_session().sparkContext + max_concurrent_tasks = _get_max_num_concurrent_tasks(sc) + + if num_workers > max_concurrent_tasks: + get_logger(self.__class__.__name__) \ + .warning(f'The num_workers {num_workers} set for xgboost distributed ' + f'training is greater than current max number of concurrent ' + f'spark task slots, you need wait until more task slots available ' + f'or you need increase spark cluster workers.') + + if self._repartition_needed(dataset): + dataset = dataset.withColumn("values", col("values").cast(ArrayType(FloatType()))) + dataset = dataset.repartition(num_workers) + train_params = self._get_distributed_config(dataset, fit_params) + + def _train_booster(pandas_df_iter): + """ + Takes in an RDD partition and outputs a booster for that partition after going through + the Rabit Ring protocol + """ + from pyspark import BarrierTaskContext + context = BarrierTaskContext.get() + + use_external_storage = self.getOrDefault(self.use_external_storage) + external_storage_precision = self.getOrDefault(self.external_storage_precision) + external_storage_path_prefix = None + if use_external_storage: + external_storage_path_prefix = tempfile.mkdtemp() + dtrain, dval = None, [] + if has_validation: + dtrain, dval = convert_partition_data_to_dmatrix( + pandas_df_iter, has_weight, has_validation, + use_external_storage, external_storage_path_prefix, external_storage_precision) + dval = [(dtrain, "training"), (dval, "validation")] + else: + dtrain = convert_partition_data_to_dmatrix( + pandas_df_iter, has_weight, has_validation, + use_external_storage, external_storage_path_prefix, external_storage_precision) + + booster_params, kwargs_params = self._get_dist_booster_params( + train_params) + context.barrier() + _rabit_args = "" + if context.partitionId() == 0: + _rabit_args = str(_get_rabit_args(context, num_workers)) + + messages = context.allGather(message=str(_rabit_args)) + _rabit_args = _get_args_from_message_list(messages) + evals_result = {} + with RabitContext(_rabit_args, context): + booster = worker_train(params=booster_params, + dtrain=dtrain, + evals=dval, + evals_result=evals_result, + **kwargs_params) + context.barrier() + + if use_external_storage: + shutil.rmtree(external_storage_path_prefix) + if context.partitionId() == 0: + yield pd.DataFrame( + data={'booster_bytes': [cloudpickle.dumps(booster)]}) + + result_ser_booster = dataset.mapInPandas( + _train_booster, + schema='booster_bytes binary').rdd.barrier().mapPartitions( + lambda x: x).collect()[0][0] + result_xgb_model = self._convert_to_model( + cloudpickle.loads(result_ser_booster)) + return self._copyValues(self._create_pyspark_model(result_xgb_model)) + + def _fit(self, dataset): + self._validate_params() + # Unwrap the 
VectorUDT type column "feature" to 4 primitive columns: + # ['features.type', 'features.size', 'features.indices', 'features.values'] + features_col = col(self.getOrDefault(self.featuresCol)) + label_col = col(self.getOrDefault(self.labelCol)).alias('label') + features_array_col = vector_to_array(features_col, dtype="float32").alias("values") + select_cols = [features_array_col, label_col] + + has_weight = False + has_validation = False + has_base_margin = False + + if self.isDefined(self.weightCol) and self.getOrDefault( + self.weightCol): + has_weight = True + select_cols.append( + col(self.getOrDefault(self.weightCol)).alias('weight')) + + if self.isDefined(self.validationIndicatorCol) and \ + self.getOrDefault(self.validationIndicatorCol): + has_validation = True + select_cols.append( + col(self.getOrDefault( + self.validationIndicatorCol)).alias('validationIndicator')) + + if self.isDefined(self.baseMarginCol) and self.getOrDefault( + self.baseMarginCol): + has_base_margin = True + select_cols.append( + col(self.getOrDefault(self.baseMarginCol)).alias("baseMargin")) + + dataset = dataset.select(*select_cols) + # create local var `xgb_model_creator` to avoid pickle `self` object to remote worker + xgb_model_creator = self._get_xgb_model_creator() # pylint: disable=E1111 + fit_params = self._gen_fit_params_dict() + + if self.getOrDefault(self.num_workers) > 1: + return self._fit_distributed(xgb_model_creator, dataset, has_weight, + has_validation, fit_params) + + # Note: fit_params will be pickled to remote, it may include `xgb_model` param + # which is used as initial model in training. The initial model will be a + # `Booster` instance which support pickling. + def train_func(pandas_df_iter): + xgb_model = xgb_model_creator() + train_val_data = prepare_train_val_data(pandas_df_iter, has_weight, + has_validation, + has_base_margin) + # We don't need to handle callbacks param in fit_params specially. + # User need to ensure callbacks is pickle-able. + if has_validation: + train_X, train_y, train_w, train_base_margin, val_X, val_y, val_w, _ = \ + train_val_data + eval_set = [(val_X, val_y)] + sample_weight_eval_set = [val_w] + # base_margin_eval_set = [val_base_margin] <- the underline + # Note that on XGBoost 1.2.0, the above doesn't exist. + xgb_model.fit(train_X, + train_y, + sample_weight=train_w, + base_margin=train_base_margin, + eval_set=eval_set, + sample_weight_eval_set=sample_weight_eval_set, + **fit_params) + else: + train_X, train_y, train_w, train_base_margin = train_val_data + xgb_model.fit(train_X, + train_y, + sample_weight=train_w, + base_margin=train_base_margin, + **fit_params) + + ser_model_string = serialize_xgb_model(xgb_model) + yield pd.DataFrame(data={'model_string': [ser_model_string]}) + + # Train on 1 remote worker, return the string of the serialized model + result_ser_model_string = dataset.repartition(1) \ + .mapInPandas(train_func, schema='model_string string').collect()[0][0] + + # Load model + result_xgb_model = deserialize_xgb_model(result_ser_model_string, + xgb_model_creator) + return self._copyValues(self._create_pyspark_model(result_xgb_model)) + + def write(self): + return XgboostWriter(self) + + @classmethod + def read(cls): + return XgboostReader(cls) + + +class _XgboostModel(Model, _XgboostParams, MLReadable, MLWritable): + def __init__(self, xgb_sklearn_model=None): + super().__init__() + self._xgb_sklearn_model = xgb_sklearn_model + + def get_booster(self): + """ + Return the `xgboost.core.Booster` instance. 
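+
+        For example (illustrative; `model` is a fitted model and the path is
+        hypothetical):
+
+            booster = model.get_booster()
+            booster.save_model("/tmp/xgb_booster.json")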
+ """ + return self._xgb_sklearn_model.get_booster() + + def get_feature_importances(self, importance_type='weight'): + """Get feature importance of each feature. + Importance type can be defined as: + + * 'weight': the number of times a feature is used to split the data across all trees. + * 'gain': the average gain across all splits the feature is used in. + * 'cover': the average coverage across all splits the feature is used in. + * 'total_gain': the total gain across all splits the feature is used in. + * 'total_cover': the total coverage across all splits the feature is used in. + + .. note:: Feature importance is defined only for tree boosters + + Feature importance is only defined when the decision tree model is chosen as base + learner (`booster=gbtree`). It is not defined for other base learner types, such + as linear learners (`booster=gblinear`). + + Parameters + ---------- + importance_type: str, default 'weight' + One of the importance types defined above. + """ + return self.get_booster().get_score(importance_type=importance_type) + + def write(self): + return XgboostModelWriter(self) + + @classmethod + def read(cls): + return XgboostModelReader(cls) + + def _transform(self, dataset): + raise NotImplementedError() + + +class XgboostRegressorModel(_XgboostModel): + """ + The model returned by :func:`xgboost.spark.XgboostRegressor.fit` + + .. Note:: This API is experimental. + """ + + @classmethod + def _xgb_cls(cls): + return XGBRegressor + + def _transform(self, dataset): + # Save xgb_sklearn_model and predict_params to be local variable + # to avoid the `self` object to be pickled to remote. + xgb_sklearn_model = self._xgb_sklearn_model + predict_params = self._gen_predict_params_dict() + + @pandas_udf('double') + def predict_udf(iterator: Iterator[Tuple[pd.Series]]) \ + -> Iterator[pd.Series]: + # deserialize model from ser_model_string, avoid pickling model to remote worker + X, _, _, _ = prepare_predict_data(iterator, False) + # Note: In every spark job task, pandas UDF will run in separate python process + # so it is safe here to call the thread-unsafe model.predict method + if len(X) > 0: + preds = xgb_sklearn_model.predict(X, **predict_params) + yield pd.Series(preds) + + @pandas_udf('double') + def predict_udf_base_margin(iterator: Iterator[Tuple[pd.Series, pd.Series]]) \ + -> Iterator[pd.Series]: + # deserialize model from ser_model_string, avoid pickling model to remote worker + X, _, _, b_m = prepare_predict_data(iterator, True) + # Note: In every spark job task, pandas UDF will run in separate python process + # so it is safe here to call the thread-unsafe model.predict method + if len(X) > 0: + preds = xgb_sklearn_model.predict(X, + base_margin=b_m, + **predict_params) + yield pd.Series(preds) + + features_col = col(self.getOrDefault(self.featuresCol)) + features_col = struct(vector_to_array(features_col, dtype="float32").alias("values")) + + has_base_margin = False + if self.isDefined(self.baseMarginCol) and self.getOrDefault( + self.baseMarginCol): + has_base_margin = True + + if has_base_margin: + base_margin_col = col(self.getOrDefault(self.baseMarginCol)) + pred_col = predict_udf_base_margin(features_col, + base_margin_col) + else: + pred_col = predict_udf(features_col) + + predictionColName = self.getOrDefault(self.predictionCol) + + return dataset.withColumn(predictionColName, pred_col) + + +class XgboostClassifierModel(_XgboostModel, HasProbabilityCol, + HasRawPredictionCol): + """ + The model returned by :func:`xgboost.spark.XgboostClassifier.fit` + + 
.. Note:: This API is experimental. + """ + + @classmethod + def _xgb_cls(cls): + return XGBClassifier + + def _transform(self, dataset): + # Save xgb_sklearn_model and predict_params to be local variable + # to avoid the `self` object to be pickled to remote. + xgb_sklearn_model = self._xgb_sklearn_model + predict_params = self._gen_predict_params_dict() + + @pandas_udf( + 'rawPrediction array, prediction double, probability array' + ) + def predict_udf(iterator: Iterator[Tuple[pd.Series]]) \ + -> Iterator[pd.DataFrame]: + # deserialize model from ser_model_string, avoid pickling model to remote worker + X, _, _, _ = prepare_predict_data(iterator, False) + # Note: In every spark job task, pandas UDF will run in separate python process + # so it is safe here to call the thread-unsafe model.predict method + if len(X) > 0: + margins = xgb_sklearn_model.predict(X, + output_margin=True, + **predict_params) + if margins.ndim == 1: + # binomial case + classone_probs = expit(margins) + classzero_probs = 1.0 - classone_probs + raw_preds = np.vstack((-margins, margins)).transpose() + class_probs = np.vstack( + (classzero_probs, classone_probs)).transpose() + else: + # multinomial case + raw_preds = margins + class_probs = softmax(raw_preds, axis=1) + + # It seems that they use argmax of class probs, + # not of margin to get the prediction (Note: scala implementation) + preds = np.argmax(class_probs, axis=1) + yield pd.DataFrame( + data={ + 'rawPrediction': pd.Series(raw_preds.tolist()), + 'prediction': pd.Series(preds), + 'probability': pd.Series(class_probs.tolist()) + }) + + @pandas_udf( + 'rawPrediction array, prediction double, probability array' + ) + def predict_udf_base_margin(iterator: Iterator[Tuple[pd.Series, pd.Series]])\ + -> Iterator[pd.DataFrame]: + # deserialize model from ser_model_string, avoid pickling model to remote worker + X, _, _, b_m = prepare_predict_data(iterator, True) + # Note: In every spark job task, pandas UDF will run in separate python process + # so it is safe here to call the thread-unsafe model.predict method + if len(X) > 0: + margins = xgb_sklearn_model.predict(X, + base_margin=b_m, + output_margin=True, + **predict_params) + if margins.ndim == 1: + # binomial case + classone_probs = expit(margins) + classzero_probs = 1.0 - classone_probs + raw_preds = np.vstack((-margins, margins)).transpose() + class_probs = np.vstack( + (classzero_probs, classone_probs)).transpose() + else: + # multinomial case + raw_preds = margins + class_probs = softmax(raw_preds, axis=1) + + # It seems that they use argmax of class probs, + # not of margin to get the prediction (Note: scala implementation) + preds = np.argmax(class_probs, axis=1) + yield pd.DataFrame( + data={ + 'rawPrediction': pd.Series(raw_preds.tolist()), + 'prediction': pd.Series(preds), + 'probability': pd.Series(class_probs.tolist()) + }) + + features_col = col(self.getOrDefault(self.featuresCol)) + features_col = struct(vector_to_array(features_col, dtype="float32").alias("values")) + + has_base_margin = False + if self.isDefined(self.baseMarginCol) and self.getOrDefault( + self.baseMarginCol): + has_base_margin = True + + if has_base_margin: + base_margin_col = col(self.getOrDefault(self.baseMarginCol)) + pred_struct = predict_udf_base_margin(features_col, + base_margin_col) + else: + pred_struct = predict_udf(features_col) + + pred_struct_col = '_prediction_struct' + + rawPredictionColName = self.getOrDefault(self.rawPredictionCol) + predictionColName = self.getOrDefault(self.predictionCol) + probabilityColName 
= self.getOrDefault(self.probabilityCol) + dataset = dataset.withColumn(pred_struct_col, pred_struct) + if rawPredictionColName: + dataset = dataset.withColumn( + rawPredictionColName, + array_to_vector(col(pred_struct_col).rawPrediction)) + if predictionColName: + dataset = dataset.withColumn(predictionColName, + col(pred_struct_col).prediction) + if probabilityColName: + dataset = dataset.withColumn( + probabilityColName, + array_to_vector(col(pred_struct_col).probability)) + + return dataset.drop(pred_struct_col) + + +def _set_pyspark_xgb_cls_param_attrs(pyspark_estimator_class, + pyspark_model_class): + params_dict = pyspark_estimator_class._get_xgb_params_default() + + def param_value_converter(v): + if isinstance(v, np.generic): + # convert numpy scalar values to corresponding python scalar values + return np.array(v).item() + elif isinstance(v, dict): + return {k: param_value_converter(nv) for k, nv in v.items()} + elif isinstance(v, list): + return [param_value_converter(nv) for nv in v] + else: + return v + + def set_param_attrs(attr_name, param_obj_): + param_obj_.typeConverter = param_value_converter + setattr(pyspark_estimator_class, attr_name, param_obj_) + setattr(pyspark_model_class, attr_name, param_obj_) + + for name in params_dict.keys(): + if name == 'missing': + doc = 'Specify the missing value in the features, default np.nan. ' \ + 'We recommend using 0.0 as the missing value for better performance. ' \ + 'Note: In a spark DataFrame, the inactive values in a sparse vector ' \ + 'mean 0 instead of missing values, unless missing=0 is specified.' + else: + doc = f'Refer to XGBoost doc of ' \ + f'{get_class_name(pyspark_estimator_class._xgb_cls())} for this param {name}' + + param_obj = Param(Params._dummy(), name=name, doc=doc) + set_param_attrs(name, param_obj) + + fit_params_dict = pyspark_estimator_class._get_fit_params_default() + for name in fit_params_dict.keys(): + doc = f'Refer to XGBoost doc of {get_class_name(pyspark_estimator_class._xgb_cls())}' \ + f'.fit() for this param {name}' + if name == 'callbacks': + doc += 'The callbacks can be arbitrary functions. It is saved using cloudpickle ' \ + 'which is not a fully self-contained format. It may fail to load with ' \ + 'different versions of dependencies.' 
+ param_obj = Param(Params._dummy(), name=name, doc=doc) + set_param_attrs(name, param_obj) + + predict_params_dict = pyspark_estimator_class._get_predict_params_default() + for name in predict_params_dict.keys(): + doc = f'Refer to XGBoost doc of {get_class_name(pyspark_estimator_class._xgb_cls())}' \ + f'.predict() for this param {name}' + param_obj = Param(Params._dummy(), name=name, doc=doc) + set_param_attrs(name, param_obj) diff --git a/python-package/xgboost/spark/data.py b/python-package/xgboost/spark/data.py new file mode 100644 index 000000000000..b0ca24c7e36c --- /dev/null +++ b/python-package/xgboost/spark/data.py @@ -0,0 +1,262 @@ +import os +from typing import Iterator +import numpy as np +import pandas as pd +from scipy.sparse import csr_matrix +from xgboost import DMatrix + + +# Since sklearn's SVM converter doesn't address weights, this one does address weights: +def _dump_libsvm(features, labels, weights=None, external_storage_precision=5): + esp = external_storage_precision + lines = [] + + def gen_label_str(row_idx): + if weights is not None: + return "{label:.{esp}g}:{weight:.{esp}g}".format( + label=labels[row_idx], esp=esp, weight=weights[row_idx]) + else: + return "{label:.{esp}g}".format(label=labels[row_idx], esp=esp) + + def gen_feature_value_str(feature_idx, feature_val): + return "{idx:.{esp}g}:{value:.{esp}g}".format( + idx=feature_idx, esp=esp, value=feature_val + ) + + is_csr_matrix = isinstance(features, csr_matrix) + + for i in range(len(labels)): + current = [gen_label_str(i)] + if is_csr_matrix: + idx_start = features.indptr[i] + idx_end = features.indptr[i + 1] + for idx in range(idx_start, idx_end): + j = features.indices[idx] + val = features.data[idx] + current.append(gen_feature_value_str(j, val)) + else: + for j, val in enumerate(features[i]): + current.append(gen_feature_value_str(j, val)) + lines.append(" ".join(current) + "\n") + return lines + + +# This is the updated version that handles weights +def _stream_train_val_data(features, labels, weights, main_file, + external_storage_precision): + lines = _dump_libsvm(features, labels, weights, external_storage_precision) + main_file.writelines(lines) + + +def _stream_data_into_libsvm_file(data_iterator, has_weight, + has_validation, file_prefix, + external_storage_precision): + # getting the file names for storage + train_file_name = file_prefix + "/data.txt.train" + train_file = open(train_file_name, "w") + if has_validation: + validation_file_name = file_prefix + "/data.txt.val" + validation_file = open(validation_file_name, "w") + + train_val_data = _process_data_iter(data_iterator, + train=True, + has_weight=has_weight, + has_validation=has_validation) + if has_validation: + train_X, train_y, train_w, _, val_X, val_y, val_w, _ = train_val_data + _stream_train_val_data(train_X, train_y, train_w, train_file, + external_storage_precision) + _stream_train_val_data(val_X, val_y, val_w, validation_file, + external_storage_precision) + else: + train_X, train_y, train_w, _ = train_val_data + _stream_train_val_data(train_X, train_y, train_w, train_file, + external_storage_precision) + + if has_validation: + train_file.close() + validation_file.close() + return train_file_name, validation_file_name + else: + train_file.close() + return train_file_name + + +def _create_dmatrix_from_file(file_name, cache_name): + if os.path.exists(cache_name): + os.remove(cache_name) + if os.path.exists(cache_name + ".row.page"): + os.remove(cache_name + ".row.page") + if os.path.exists(cache_name + ".sorted.col.page"): + 
os.remove(cache_name + ".sorted.col.page") + return DMatrix(file_name + "#" + cache_name) + + +def prepare_train_val_data(data_iterator, + has_weight, + has_validation, + has_fit_base_margin=False): + def gen_data_pdf(): + for pdf in data_iterator: + yield pdf + + return _process_data_iter(gen_data_pdf(), + train=True, + has_weight=has_weight, + has_validation=has_validation, + has_fit_base_margin=has_fit_base_margin, + has_predict_base_margin=False) + + +def prepare_predict_data(data_iterator, has_predict_base_margin): + return _process_data_iter(data_iterator, + train=False, + has_weight=False, + has_validation=False, + has_fit_base_margin=False, + has_predict_base_margin=has_predict_base_margin) + + +def _check_feature_dims(num_dims, expected_dims): + """ + Check all feature vectors has the same dimension + """ + if expected_dims is None: + return num_dims + if num_dims != expected_dims: + raise ValueError("Rows contain different feature dimensions: " + "Expecting {}, got {}.".format( + expected_dims, num_dims)) + return expected_dims + + +def _row_tuple_list_to_feature_matrix_y_w(data_iterator, train, has_weight, + has_fit_base_margin, + has_predict_base_margin, + has_validation: bool = False): + """ + Construct a feature matrix in ndarray format, label array y and weight array w + from the row_tuple_list. + If train == False, y and w will be None. + If has_weight == False, w will be None. + If has_base_margin == False, b_m will be None. + Note: the row_tuple_list will be cleared during + executing for reducing peak memory consumption + """ + expected_feature_dims = None + label_list, weight_list, base_margin_list = [], [], [] + label_val_list, weight_val_list, base_margin_val_list = [], [], [] + values_list, values_val_list = [], [] + + # Process rows + for pdf in data_iterator: + if type(pdf) == tuple: + pdf = pd.concat(list(pdf), axis=1, names=["values", "baseMargin"]) + + if len(pdf) == 0: + continue + if train and has_validation: + pdf_val = pdf.loc[pdf["validationIndicator"], :] + pdf = pdf.loc[~pdf["validationIndicator"], :] + + num_feature_dims = len(pdf["values"].values[0]) + + expected_feature_dims = _check_feature_dims(num_feature_dims, + expected_feature_dims) + + values_list.append(pdf["values"].to_list()) + if train: + label_list.append(pdf["label"].to_list()) + if has_weight: + weight_list.append(pdf["weight"].to_list()) + if has_fit_base_margin or has_predict_base_margin: + base_margin_list.append(pdf.iloc[:, -1].to_list()) + if has_validation: + values_val_list.append(pdf_val["values"].to_list()) + if train: + label_val_list.append(pdf_val["label"].to_list()) + if has_weight: + weight_val_list.append(pdf_val["weight"].to_list()) + if has_fit_base_margin or has_predict_base_margin: + base_margin_val_list.append(pdf_val.iloc[:, -1].to_list()) + + # Construct feature_matrix + if expected_feature_dims is None: + return [], [], [], [] + + # Construct feature_matrix, y and w + feature_matrix = np.concatenate(values_list) + y = np.concatenate(label_list) if train else None + w = np.concatenate(weight_list) if has_weight else None + b_m = np.concatenate(base_margin_list) if ( + has_fit_base_margin or has_predict_base_margin) else None + if has_validation: + feature_matrix_val = np.concatenate(values_val_list) + y_val = np.concatenate(label_val_list) if train else None + w_val = np.concatenate(weight_val_list) if has_weight else None + b_m_val = np.concatenate(base_margin_val_list) if ( + has_fit_base_margin or has_predict_base_margin) else None + return feature_matrix, y, w, 
b_m, feature_matrix_val, y_val, w_val, b_m_val + return feature_matrix, y, w, b_m + + +def _process_data_iter(data_iterator: Iterator[pd.DataFrame], + train: bool, + has_weight: bool, + has_validation: bool, + has_fit_base_margin: bool = False, + has_predict_base_margin: bool = False): + """ + If input is for train and has_validation=True, it will split the train data into train dataset + and validation dataset, and return (train_X, train_y, train_w, train_b_m <- + train base margin, val_X, val_y, val_w, val_b_m <- validation base margin) + otherwise return (X, y, w, b_m <- base margin) + """ + if train and has_validation: + train_X, train_y, train_w, train_b_m, val_X, val_y, val_w, val_b_m = \ + _row_tuple_list_to_feature_matrix_y_w( + data_iterator, train, has_weight, has_fit_base_margin, + has_predict_base_margin, has_validation) + return train_X, train_y, train_w, train_b_m, val_X, val_y, val_w, val_b_m + else: + return _row_tuple_list_to_feature_matrix_y_w(data_iterator, train, has_weight, + has_fit_base_margin, has_predict_base_margin, + has_validation) + + +def convert_partition_data_to_dmatrix(partition_data_iter, + has_weight, + has_validation, + use_external_storage=False, + file_prefix=None, + external_storage_precision=5): + # if we are using external storage, we use a different approach for making the dmatrix + if use_external_storage: + if has_validation: + train_file, validation_file = _stream_data_into_libsvm_file( + partition_data_iter, has_weight, + has_validation, file_prefix, external_storage_precision) + training_dmatrix = _create_dmatrix_from_file( + train_file, "{}/train.cache".format(file_prefix)) + val_dmatrix = _create_dmatrix_from_file( + validation_file, "{}/val.cache".format(file_prefix)) + return training_dmatrix, val_dmatrix + else: + train_file = _stream_data_into_libsvm_file( + partition_data_iter, has_weight, + has_validation, file_prefix, external_storage_precision) + training_dmatrix = _create_dmatrix_from_file( + train_file, "{}/train.cache".format(file_prefix)) + return training_dmatrix + + # if we are not using external storage, we use the standard method of parsing data. + train_val_data = prepare_train_val_data(partition_data_iter, has_weight, has_validation) + if has_validation: + train_X, train_y, train_w, _, val_X, val_y, val_w, _ = train_val_data + training_dmatrix = DMatrix(data=train_X, label=train_y, weight=train_w) + val_dmatrix = DMatrix(data=val_X, label=val_y, weight=val_w) + return training_dmatrix, val_dmatrix + else: + train_X, train_y, train_w, _ = train_val_data + training_dmatrix = DMatrix(data=train_X, label=train_y, weight=train_w) + return training_dmatrix diff --git a/python-package/xgboost/spark/estimator.py b/python-package/xgboost/spark/estimator.py new file mode 100644 index 000000000000..d54903e69fec --- /dev/null +++ b/python-package/xgboost/spark/estimator.py @@ -0,0 +1,192 @@ +from pyspark.ml.param.shared import HasProbabilityCol, HasRawPredictionCol +from xgboost import XGBClassifier, XGBRegressor +from .core import (_XgboostEstimator, XgboostClassifierModel, + XgboostRegressorModel, _set_pyspark_xgb_cls_param_attrs) + + +class XgboostRegressor(_XgboostEstimator): + """ + XgboostRegressor is a PySpark ML estimator. It implements the XGBoost regression + algorithm based on XGBoost python library, and it can be used in PySpark Pipeline + and PySpark ML meta algorithms like CrossValidator/TrainValidationSplit/OneVsRest. 
+ + XgboostRegressor automatically supports most of the parameters in + `xgboost.XGBRegressor` constructor and most of the parameters used in + `xgboost.XGBRegressor` fit and predict method (see `API docs `_ for details). + + XgboostRegressor doesn't support setting `gpu_id` but support another param `use_gpu`, + see doc below for more details. + + XgboostRegressor doesn't support setting `base_margin` explicitly as well, but support + another param called `baseMarginCol`. see doc below for more details. + + XgboostRegressor doesn't support `validate_features` and `output_margin` param. + + :param callbacks: The export and import of the callback functions are at best effort. + For details, see :py:attr:`xgboost.spark.XgboostRegressor.callbacks` param doc. + :param missing: The parameter `missing` in XgboostRegressor has different semantics with + that in `xgboost.XGBRegressor`. For details, see + :py:attr:`xgboost.spark.XgboostRegressor.missing` param doc. + :param validationIndicatorCol: For params related to `xgboost.XGBRegressor` training + with evaluation dataset's supervision, set + :py:attr:`xgboost.spark.XgboostRegressor.validationIndicatorCol` + parameter instead of setting the `eval_set` parameter in `xgboost.XGBRegressor` + fit method. + :param weightCol: To specify the weight of the training and validation dataset, set + :py:attr:`xgboost.spark.XgboostRegressor.weightCol` parameter instead of setting + `sample_weight` and `sample_weight_eval_set` parameter in `xgboost.XGBRegressor` + fit method. + :param xgb_model: Set the value to be the instance returned by + :func:`xgboost.spark.XgboostRegressorModel.get_booster`. + :param num_workers: Integer that specifies the number of XGBoost workers to use. + Each XGBoost worker corresponds to one spark task. + :param use_gpu: Boolean that specifies whether the executors are running on GPU + instances. + :param use_external_storage: Boolean that specifices whether you want to use + external storage when training in a distributed manner. This allows using disk + as cache. Setting this to true is useful when you want better memory utilization + but is not needed for small test datasets. + :param baseMarginCol: To specify the base margins of the training and validation + dataset, set :py:attr:`xgboost.spark.XgboostRegressor.baseMarginCol` parameter + instead of setting `base_margin` and `base_margin_eval_set` in the + `xgboost.XGBRegressor` fit method. Note: this isn't available for distributed + training. + + .. Note:: The Parameters chart above contains parameters that need special handling. + For a full list of parameters, see entries with `Param(parent=...` below. + + .. Note:: This API is experimental. + + **Examples** + + >>> from xgboost.spark import XgboostRegressor + >>> from pyspark.ml.linalg import Vectors + >>> df_train = spark.createDataFrame([ + ... (Vectors.dense(1.0, 2.0, 3.0), 0, False, 1.0), + ... (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1, False, 2.0), + ... (Vectors.dense(4.0, 5.0, 6.0), 2, True, 1.0), + ... (Vectors.sparse(3, {1: 6.0, 2: 7.5}), 3, True, 2.0), + ... ], ["features", "label", "isVal", "weight"]) + >>> df_test = spark.createDataFrame([ + ... (Vectors.dense(1.0, 2.0, 3.0), ), + ... (Vectors.sparse(3, {1: 1.0, 2: 5.5}), ) + ... ], ["features"]) + >>> xgb_regressor = XgboostRegressor(max_depth=5, missing=0.0, + ... validationIndicatorCol='isVal', weightCol='weight', + ... 
early_stopping_rounds=1, eval_metric='rmse') + >>> xgb_reg_model = xgb_regressor.fit(df_train) + >>> xgb_reg_model.transform(df_test) + + """ + def __init__(self, **kwargs): + super().__init__() + self.setParams(**kwargs) + + @classmethod + def _xgb_cls(cls): + return XGBRegressor + + @classmethod + def _pyspark_model_cls(cls): + return XgboostRegressorModel + + +_set_pyspark_xgb_cls_param_attrs(XgboostRegressor, XgboostRegressorModel) + + +class XgboostClassifier(_XgboostEstimator, HasProbabilityCol, + HasRawPredictionCol): + """ + XgboostClassifier is a PySpark ML estimator. It implements the XGBoost classification + algorithm based on XGBoost python library, and it can be used in PySpark Pipeline + and PySpark ML meta algorithms like CrossValidator/TrainValidationSplit/OneVsRest. + + XgboostClassifier automatically supports most of the parameters in + `xgboost.XGBClassifier` constructor and most of the parameters used in + `xgboost.XGBClassifier` fit and predict method (see `API docs `_ for details). + + XgboostClassifier doesn't support setting `gpu_id` but support another param `use_gpu`, + see doc below for more details. + + XgboostClassifier doesn't support setting `base_margin` explicitly as well, but support + another param called `baseMarginCol`. see doc below for more details. + + XgboostClassifier doesn't support setting `output_margin`, but we can get output margin + from the raw prediction column. See `rawPredictionCol` param doc below for more details. + + XgboostClassifier doesn't support `validate_features` and `output_margin` param. + + :param callbacks: The export and import of the callback functions are at best effort. For + details, see :py:attr:`xgboost.spark.XgboostClassifier.callbacks` param doc. + :param missing: The parameter `missing` in XgboostClassifier has different semantics with + that in `xgboost.XGBClassifier`. For details, see + :py:attr:`xgboost.spark.XgboostClassifier.missing` param doc. + :param rawPredictionCol: The `output_margin=True` is implicitly supported by the + `rawPredictionCol` output column, which is always returned with the predicted margin + values. + :param validationIndicatorCol: For params related to `xgboost.XGBClassifier` training with + evaluation dataset's supervision, + set :py:attr:`xgboost.spark.XgboostClassifier.validationIndicatorCol` + parameter instead of setting the `eval_set` parameter in `xgboost.XGBClassifier` + fit method. + :param weightCol: To specify the weight of the training and validation dataset, set + :py:attr:`xgboost.spark.XgboostClassifier.weightCol` parameter instead of setting + `sample_weight` and `sample_weight_eval_set` parameter in `xgboost.XGBClassifier` + fit method. + :param xgb_model: Set the value to be the instance returned by + :func:`xgboost.spark.XgboostClassifierModel.get_booster`. + :param num_workers: Integer that specifies the number of XGBoost workers to use. + Each XGBoost worker corresponds to one spark task. + :param use_gpu: Boolean that specifies whether the executors are running on GPU + instances. + :param use_external_storage: Boolean that specifices whether you want to use + external storage when training in a distributed manner. This allows using disk + as cache. Setting this to true is useful when you want better memory utilization + but is not needed for small test datasets. 
+ :param baseMarginCol: To specify the base margins of the training and validation + dataset, set :py:attr:`xgboost.spark.XgboostClassifier.baseMarginCol` parameter + instead of setting `base_margin` and `base_margin_eval_set` in the + `xgboost.XGBClassifier` fit method. Note: this isn't available for distributed + training. + + .. Note:: The Parameters chart above contains parameters that need special handling. + For a full list of parameters, see entries with `Param(parent=...` below. + + .. Note:: This API is experimental. + + **Examples** + + >>> from xgboost.spark import XgboostClassifier + >>> from pyspark.ml.linalg import Vectors + >>> df_train = spark.createDataFrame([ + ... (Vectors.dense(1.0, 2.0, 3.0), 0, False, 1.0), + ... (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1, False, 2.0), + ... (Vectors.dense(4.0, 5.0, 6.0), 0, True, 1.0), + ... (Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, True, 2.0), + ... ], ["features", "label", "isVal", "weight"]) + >>> df_test = spark.createDataFrame([ + ... (Vectors.dense(1.0, 2.0, 3.0), ), + ... ], ["features"]) + >>> xgb_classifier = XgboostClassifier(max_depth=5, missing=0.0, + ... validationIndicatorCol='isVal', weightCol='weight', + ... early_stopping_rounds=1, eval_metric='logloss') + >>> xgb_clf_model = xgb_classifier.fit(df_train) + >>> xgb_clf_model.transform(df_test).show() + + """ + def __init__(self, **kwargs): + super().__init__() + self.setParams(**kwargs) + + @classmethod + def _xgb_cls(cls): + return XGBClassifier + + @classmethod + def _pyspark_model_cls(cls): + return XgboostClassifierModel + + +_set_pyspark_xgb_cls_param_attrs(XgboostClassifier, XgboostClassifierModel) diff --git a/python-package/xgboost/spark/model.py b/python-package/xgboost/spark/model.py new file mode 100644 index 000000000000..1edb44635490 --- /dev/null +++ b/python-package/xgboost/spark/model.py @@ -0,0 +1,232 @@ +import base64 +import os +import uuid + +from pyspark import cloudpickle +from pyspark import SparkFiles +from pyspark.sql import SparkSession +from pyspark.ml.util import (DefaultParamsReader, DefaultParamsWriter, + MLReader, MLWriter) +from xgboost.core import Booster + +from .utils import get_logger, get_class_name + + +def get_xgb_model_creator(model_cls, xgb_params): + """ + Returns a function that can be used to create an xgboost.XGBModel instance. + This function is used for creating the model instance on the worker, and is + shared by _XgboostEstimator and XgboostModel. + :param model_cls: a subclass of xgboost.XGBModel + :param xgb_params: a dict of params to initialize the model_cls + """ + return lambda: model_cls(**xgb_params) # pylint: disable=W0108 + + +def _get_or_create_tmp_dir(): + root_dir = SparkFiles.getRootDirectory() + xgb_tmp_dir = os.path.join(root_dir, 'xgboost-tmp') + if not os.path.exists(xgb_tmp_dir): + os.makedirs(xgb_tmp_dir) + return xgb_tmp_dir + + +def serialize_xgb_model(model): + """ + Serialize the input model to a string. + :param model: an xgboost.XGBModel instance, + such as xgboost.XGBClassifier or xgboost.XGBRegressor instance + """ + # TODO: change to use string io + tmp_file_name = os.path.join(_get_or_create_tmp_dir(), f'{uuid.uuid4()}.json') + model.save_model(tmp_file_name) + with open(tmp_file_name) as f: + ser_model_string = f.read() + return ser_model_string + + +def deserialize_xgb_model(ser_model_string, xgb_model_creator): + """ + Deserialize an xgboost.XGBModel instance from the input ser_model_string. 
+ """ + xgb_model = xgb_model_creator() + # TODO: change to use string io + tmp_file_name = os.path.join(_get_or_create_tmp_dir(), f'{uuid.uuid4()}.json') + with open(tmp_file_name, "w") as f: + f.write(ser_model_string) + xgb_model.load_model(tmp_file_name) + return xgb_model + + +def serialize_booster(booster): + """ + Serialize the input booster to a string. + :param booster: an xgboost.core.Booster instance + """ + # TODO: change to use string io + tmp_file_name = os.path.join(_get_or_create_tmp_dir(), f'{uuid.uuid4()}.json') + booster.save_model(tmp_file_name) + with open(tmp_file_name) as f: + ser_model_string = f.read() + return ser_model_string + + +def deserialize_booster(ser_model_string): + """ + Deserialize an xgboost.core.Booster from the input ser_model_string. + """ + booster = Booster() + # TODO: change to use string io + tmp_file_name = os.path.join(_get_or_create_tmp_dir(), f'{uuid.uuid4()}.json') + with open(tmp_file_name, "w") as f: + f.write(ser_model_string) + booster.load_model(tmp_file_name) + return booster + + +_INIT_BOOSTER_SAVE_PATH = "init_booster.json" + + +def _get_spark_session(): + return SparkSession.builder.getOrCreate() + + +class XgboostSharedReadWrite: + + @staticmethod + def saveMetadata(instance, path, sc, logger, extraMetadata=None): + """ + Save the metadata of an xgboost.spark._XgboostEstimator or + xgboost.spark._XgboostModel. + """ + instance._validate_params() + skipParams = ['callbacks', 'xgb_model'] + jsonParams = {} + for p, v in instance._paramMap.items(): + if p.name not in skipParams: + jsonParams[p.name] = v + + extraMetadata = extraMetadata or {} + callbacks = instance.getOrDefault(instance.callbacks) + if callbacks is not None: + logger.warning('The callbacks parameter is saved using cloudpickle and it ' + 'is not a fully self-contained format. It may fail to load ' + 'with different versions of dependencies.') + serialized_callbacks = \ + base64.encodebytes(cloudpickle.dumps(callbacks)).decode('ascii') + extraMetadata['serialized_callbacks'] = serialized_callbacks + init_booster = instance.getOrDefault(instance.xgb_model) + if init_booster is not None: + extraMetadata['init_booster'] = _INIT_BOOSTER_SAVE_PATH + DefaultParamsWriter.saveMetadata( + instance, path, sc, extraMetadata=extraMetadata, paramMap=jsonParams) + if init_booster is not None: + ser_init_booster = serialize_booster(init_booster) + save_path = os.path.join(path, _INIT_BOOSTER_SAVE_PATH) + _get_spark_session().createDataFrame( + [(ser_init_booster,)], ['init_booster']).write.parquet(save_path) + + @staticmethod + def loadMetadataAndInstance(pyspark_xgb_cls, path, sc, logger): + """ + Load the metadata and the instance of an xgboost.spark._XgboostEstimator or + xgboost.spark._XgboostModel. + + :return: a tuple of (metadata, instance) + """ + metadata = DefaultParamsReader.loadMetadata( + path, sc, expectedClassName=get_class_name(pyspark_xgb_cls)) + pyspark_xgb = pyspark_xgb_cls() + DefaultParamsReader.getAndSetParams(pyspark_xgb, metadata) + + if 'serialized_callbacks' in metadata: + serialized_callbacks = metadata['serialized_callbacks'] + try: + callbacks = \ + cloudpickle.loads(base64.decodebytes(serialized_callbacks.encode('ascii'))) + pyspark_xgb.set(pyspark_xgb.callbacks, callbacks) + except Exception as e: # pylint: disable=W0703 + logger.warning('Fails to load the callbacks param due to {}. 
Please set the ' + 'callbacks param manually for the loaded estimator.'.format(e)) + + if 'init_booster' in metadata: + load_path = os.path.join(path, metadata['init_booster']) + ser_init_booster = _get_spark_session().read.parquet(load_path) \ + .collect()[0].init_booster + init_booster = deserialize_booster(ser_init_booster) + pyspark_xgb.set(pyspark_xgb.xgb_model, init_booster) + + pyspark_xgb._resetUid(metadata["uid"]) + return metadata, pyspark_xgb + + +class XgboostWriter(MLWriter): + + def __init__(self, instance): + super().__init__() + self.instance = instance + self.logger = get_logger(self.__class__.__name__, level='WARN') + + def saveImpl(self, path): + XgboostSharedReadWrite.saveMetadata(self.instance, path, self.sc, self.logger) + + +class XgboostReader(MLReader): + + def __init__(self, cls): + super().__init__() + self.cls = cls + self.logger = get_logger(self.__class__.__name__, level='WARN') + + def load(self, path): + _, pyspark_xgb = XgboostSharedReadWrite \ + .loadMetadataAndInstance(self.cls, path, self.sc, self.logger) + return pyspark_xgb + + +class XgboostModelWriter(MLWriter): + + def __init__(self, instance): + super().__init__() + self.instance = instance + self.logger = get_logger(self.__class__.__name__, level='WARN') + + def saveImpl(self, path): + """ + Save metadata and model for a :py:class:`_XgboostModel` + - save metadata to path/metadata + - save model to path/model.json + """ + xgb_model = self.instance._xgb_sklearn_model + XgboostSharedReadWrite.saveMetadata(self.instance, path, self.sc, self.logger) + model_save_path = os.path.join(path, "model.json") + ser_xgb_model = serialize_xgb_model(xgb_model) + _get_spark_session().createDataFrame( + [(ser_xgb_model,)], ['xgb_sklearn_model']).write.parquet(model_save_path) + + +class XgboostModelReader(MLReader): + + def __init__(self, cls): + super().__init__() + self.cls = cls + self.logger = get_logger(self.__class__.__name__, level='WARN') + + def load(self, path): + """ + Load metadata and model for a :py:class:`_XgboostModel` + + :return: XgboostRegressorModel or XgboostClassifierModel instance + """ + _, py_model = XgboostSharedReadWrite.loadMetadataAndInstance( + self.cls, path, self.sc, self.logger) + + xgb_params = py_model._gen_xgb_params_dict() + model_load_path = os.path.join(path, "model.json") + + ser_xgb_model = _get_spark_session().read.parquet(model_load_path) \ + .collect()[0].xgb_sklearn_model + xgb_model = deserialize_xgb_model(ser_xgb_model, + lambda: self.cls._xgb_cls()(**xgb_params)) + py_model._xgb_sklearn_model = xgb_model + return py_model diff --git a/python-package/xgboost/spark/utils.py b/python-package/xgboost/spark/utils.py new file mode 100644 index 000000000000..940fbab8d322 --- /dev/null +++ b/python-package/xgboost/spark/utils.py @@ -0,0 +1,180 @@ +import inspect +from threading import Thread +import sys +import logging + +from xgboost import rabit +from xgboost.tracker import RabitTracker +import pyspark +from pyspark.sql.session import SparkSession +from pyspark.ml.param.shared import Param, Params + + +def get_class_name(cls): + return f"{cls.__module__}.{cls.__name__}" + + +def _get_default_params_from_func(func, unsupported_set): + """ + Returns a dictionary of parameters and their default value of function fn. + Only the parameters with a default value will be included. 
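+
+    For example (illustrative): given `def fn(a, b=1, c=2)` and
+    `unsupported_set={'c'}`, this returns `{'b': 1}`.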
+ """ + sig = inspect.signature(func) + filtered_params_dict = dict() + for parameter in sig.parameters.values(): + # Remove parameters without a default value and those in the unsupported_set + if parameter.default is not parameter.empty \ + and parameter.name not in unsupported_set: + filtered_params_dict[parameter.name] = parameter.default + return filtered_params_dict + + +class HasArbitraryParamsDict(Params): + """ + This is a Params based class that is extended by _XGBoostParams + and holds the variable to store the **kwargs parts of the XGBoost + input. + """ + + arbitraryParamsDict = Param(Params._dummy(), "arbitraryParamsDict", + "This parameter holds all of the user defined parameters that" + " the sklearn implementation of XGBoost can't recognize. " + "It is stored as a dictionary.") + + def setArbitraryParamsDict(self, value): + return self._set(arbitraryParamsDict=value) + + def getArbitraryParamsDict(self, value): + return self.getOrDefault(self.arbitraryParamsDict) + + +class HasBaseMarginCol(Params): + """ + This is a Params based class that is extended by _XGBoostParams + and holds the variable to store the base margin column part of XGboost. + """ + baseMarginCol = Param( + Params._dummy(), "baseMarginCol", + "This stores the name for the column of the base margin") + + def setBaseMarginCol(self, value): + return self._set(baseMarginCol=value) + + def getBaseMarginCol(self, value): + return self.getOrDefault(self.baseMarginCol) + + +class RabitContext: + """ + A context controlling rabit initialization and finalization. + This isn't specificially necessary (note Part 3), but it is more understandable coding-wise. + """ + def __init__(self, args, context): + self.args = args + self.args.append( + ('DMLC_TASK_ID=' + str(context.partitionId())).encode()) + + def __enter__(self): + rabit.init(self.args) + + def __exit__(self, *args): + rabit.finalize() + + +def _start_tracker(context, n_workers): + """ + Start Rabit tracker with n_workers + """ + env = {'DMLC_NUM_WORKER': n_workers} + host = get_host_ip(context) + rabit_context = RabitTracker(hostIP=host, nslave=n_workers) + env.update(rabit_context.slave_envs()) + rabit_context.start(n_workers) + thread = Thread(target=rabit_context.join) + thread.daemon = True + thread.start() + return env + + +def _get_rabit_args(context, n_workers): + """ + Get rabit context arguments to send to each worker. + """ + env = _start_tracker(context, n_workers) + rabit_args = [('%s=%s' % item).encode() for item in env.items()] + return rabit_args + + +def get_host_ip(context): + """ + Gets the hostIP for Spark. This essentially gets the IP of the first worker. + """ + task_ip_list = [ + info.address.split(":")[0] for info in context.getTaskInfos() + ] + return task_ip_list[0] + + +def _get_args_from_message_list(messages): + """ + A function to send/recieve messages in barrier context mode + """ + output = "" + for message in messages: + if message != "": + output = message + break + return [ + elem.split("'")[1].encode() for elem in output.strip('][').split(', ') + ] + + +def _get_spark_session(): + """Get or create spark session. Note: This function can only be invoked from driver side.""" + if pyspark.TaskContext.get() is not None: + # This is a safety check. 
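+        # pyspark.TaskContext.get() returns a TaskContext only inside a running
+        # task (i.e. on an executor); on the driver it returns None.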
+ raise RuntimeError( + '_get_spark_session should not be invoked from executor side.') + return SparkSession.builder.getOrCreate() + + +def _getConfBoolean(sqlContext, key, defaultValue): + """ + Get the conf "key" from the given sqlContext, + or return the default value if the conf is not set. + This expects the conf value to be a boolean or string; if the value is a string, + this checks for all capitalization patterns of "true" and "false" to match Scala. + :param key: string for conf name + """ + # Convert default value to str to avoid a Spark 2.3.1 + Python 3 bug: SPARK-25397 + val = sqlContext.getConf(key, str(defaultValue)) + # Convert val to str to handle unicode issues across Python 2 and 3. + lowercase_val = str(val.lower()) + if lowercase_val == 'true': + return True + elif lowercase_val == 'false': + return False + else: + raise Exception("_getConfBoolean expected a boolean conf value but found value of type {} " + "with value: {}".format(type(val), val)) + + +def get_logger(name, level='INFO'): + """ Gets a logger by name, or creates and configures it for the first time. """ + logger = logging.getLogger(name) + logger.setLevel(level) + # If the logger is configured, skip the configure + if not logger.handlers and not logging.getLogger().handlers: + handler = logging.StreamHandler(sys.stderr) + logger.addHandler(handler) + return logger + + +def _get_max_num_concurrent_tasks(sc): + """Gets the current max number of concurrent tasks.""" + # spark 3.1 and above has a different API for fetching max concurrent tasks + if sc._jsc.sc().version() >= '3.1': + return sc._jsc.sc().maxNumConcurrentTasks( + sc._jsc.sc().resourceProfileManager().resourceProfileFromId(0) + ) + return sc._jsc.sc().maxNumConcurrentTasks() diff --git a/tests/python/test_spark/__init__.py b/tests/python/test_spark/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/python/test_spark/data_test.py b/tests/python/test_spark/data_test.py new file mode 100644 index 000000000000..014a394eed21 --- /dev/null +++ b/tests/python/test_spark/data_test.py @@ -0,0 +1,175 @@ +import tempfile +import shutil +import numpy as np +import pandas as pd +from scipy.sparse import csr_matrix + +from xgboost.spark.data import _row_tuple_list_to_feature_matrix_y_w, convert_partition_data_to_dmatrix, _dump_libsvm + +from xgboost import DMatrix, XGBClassifier +from xgboost.training import train as worker_train +from .utils_test import SparkTestCase +import logging +logging.getLogger("py4j").setLevel(logging.INFO) + + +class DataTest(SparkTestCase): + + def test_sparse_dense_vector(self): + def row_tup_iter(data): + pdf = pd.DataFrame(data) + yield pdf + + # row1 = Vectors.dense(1.0, 2.0, 3.0),), + # row2 = Vectors.sparse(3, {1: 1.0, 2: 5.5}) + expected_ndarray = np.array([[1.0, 2.0, 3.0], [0.0, 1.0, 5.5]]) + data = {"values": [[1.0, 2.0, 3.0], [0.0, 1.0, 5.5]]} + feature_matrix, y, w, _ = _row_tuple_list_to_feature_matrix_y_w( + list(row_tup_iter(data)), train=False, has_weight=False, has_fit_base_margin=False, has_predict_base_margin=False) + self.assertIsNone(y) + self.assertIsNone(w) + # self.assertTrue(isinstance(feature_matrix, csr_matrix)) + self.assertTrue(np.allclose(feature_matrix, expected_ndarray)) + + data["label"] = [1, 0] + feature_matrix, y, w, _ = _row_tuple_list_to_feature_matrix_y_w( + row_tup_iter(data), train=True, has_weight=False, has_fit_base_margin=False, has_predict_base_margin=False) + self.assertIsNone(w) + # self.assertTrue(isinstance(feature_matrix, csr_matrix)) + 
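+        # With train=True the label column is extracted as well, so y is checked
+        # against data['label'] below (w stays None because has_weight=False).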
self.assertTrue(np.allclose(feature_matrix, expected_ndarray)) + self.assertTrue(np.array_equal(y, np.array(data['label']))) + + data["weight"] = [0.2, 0.8] + feature_matrix, y, w, _ = _row_tuple_list_to_feature_matrix_y_w( + list(row_tup_iter(data)), train=True, has_weight=True, has_fit_base_margin=False, has_predict_base_margin=False) + # self.assertTrue(isinstance(feature_matrix, csr_matrix)) + self.assertTrue(np.allclose(feature_matrix, expected_ndarray)) + self.assertTrue(np.array_equal(y, np.array(data['label']))) + self.assertTrue(np.array_equal(w, np.array(data['weight']))) + + def test_dmatrix_creator(self): + + # This function acts as a pseudo-itertools.chain() + def row_tup_iter(data): + pdf = pd.DataFrame(data) + yield pdf + + # Standard testing DMatrix creation + expected_features = np.array([[1.0, 2.0, 3.0], [0.0, 1.0, 5.5]] * 100) + expected_labels = np.array([1, 0] * 100) + expected_dmatrix = DMatrix(data=expected_features, label=expected_labels) + + data = {"values": [[1.0, 2.0, 3.0], [0.0, 1.0, 5.5]] * 100, "label": [1, 0] * 100} + output_dmatrix = convert_partition_data_to_dmatrix([pd.DataFrame(data)], has_weight=False, has_validation=False) + # You can't compare DMatrix outputs, so the only way is to predict on the two seperate DMatrices using + # the same classifier and making sure the outputs are equal + model = XGBClassifier() + model.fit(expected_features, expected_labels) + expected_preds = model.get_booster().predict(expected_dmatrix) + output_preds = model.get_booster().predict(output_dmatrix) + self.assertTrue(np.allclose(expected_preds, output_preds, atol=1e-3)) + + # DMatrix creation with weights + expected_weight = np.array([0.2, 0.8] * 100) + expected_dmatrix = DMatrix(data=expected_features, label=expected_labels, weight=expected_weight) + + data["weight"] = [0.2, 0.8] * 100 + output_dmatrix = convert_partition_data_to_dmatrix([pd.DataFrame(data)], has_weight=True, has_validation=False) + + model.fit(expected_features, expected_labels, sample_weight=expected_weight) + expected_preds = model.get_booster().predict(expected_dmatrix) + output_preds = model.get_booster().predict(output_dmatrix) + self.assertTrue(np.allclose(expected_preds, output_preds, atol=1e-3)) + + def test_external_storage(self): + # Instantiating base data (features, labels) + features = np.array([[1.0, 2.0, 3.0], [0.0, 1.0, 5.5]] * 100) + labels = np.array([1, 0] * 100) + normal_dmatrix = DMatrix(features, labels) + test_dmatrix = DMatrix(features) + + data = {"values": [[1.0, 2.0, 3.0], [0.0, 1.0, 5.5]] * 100, "label": [1, 0] * 100} + + # Creating the dmatrix based on storage + temporary_path = tempfile.mkdtemp() + storage_dmatrix = convert_partition_data_to_dmatrix([pd.DataFrame(data)], has_weight=False, + has_validation=False, use_external_storage=True, + file_prefix=temporary_path) + + # Testing without weights + normal_booster = worker_train({}, normal_dmatrix) + storage_booster = worker_train({}, storage_dmatrix) + normal_preds = normal_booster.predict(test_dmatrix) + storage_preds = storage_booster.predict(test_dmatrix) + self.assertTrue(np.allclose(normal_preds, storage_preds, atol=1e-3)) + shutil.rmtree(temporary_path) + + # Testing weights + weights = np.array([0.2, 0.8] * 100) + normal_dmatrix = DMatrix(data=features, label=labels, weight=weights) + data["weight"] = [0.2, 0.8] * 100 + + temporary_path = tempfile.mkdtemp() + storage_dmatrix = convert_partition_data_to_dmatrix([pd.DataFrame(data)], has_weight=True, + has_validation=False, use_external_storage=True, + 
file_prefix=temporary_path) + + normal_booster = worker_train({}, normal_dmatrix) + storage_booster = worker_train({}, storage_dmatrix) + normal_preds = normal_booster.predict(test_dmatrix) + storage_preds = storage_booster.predict(test_dmatrix) + self.assertTrue(np.allclose(normal_preds, storage_preds, atol=1e-3)) + shutil.rmtree(temporary_path) + + def test_dump_libsvm(self): + num_features = 3 + features_test_list = [ + [[1,2,3],[0,1,5.5]], + csr_matrix(([1, 2, 3], [0, 2, 2], [0, 2, 3]), shape=(2, 3)) + ] + labels = [0, 1] + + for features in features_test_list: + if isinstance(features, csr_matrix): + features_array = features.toarray() + else: + features_array = features + # testing without weights + # The format should be label index:feature_value index:feature_value... + # Note: from initial testing, it seems all of the indices must be listed regardless of whether + # they exist or not + output = _dump_libsvm(features, labels) + for i, line in enumerate(output): + split_line = line.split(" ") + self.assertEqual(float(split_line[0]), labels[i]) + split_line = [elem.split(":") for elem in split_line[1:]] + loaded_feature = [0.0] * num_features + for split in split_line: + loaded_feature[int(split[0])] = float(split[1]) + self.assertListEqual(loaded_feature, list(features_array[i])) + + weights = [0.2, 0.8] + # testing with weights + # The format should be label:weight index:feature_value index:feature_value... + output = _dump_libsvm(features, labels, weights) + for i, line in enumerate(output): + split_line = line.split(" ") + split_line = [elem.split(":") for elem in split_line] + self.assertEqual(float(split_line[0][0]), labels[i]) + self.assertEqual(float(split_line[0][1]), weights[i]) + + split_line = split_line[1:] + loaded_feature = [0.0] * num_features + for split in split_line: + loaded_feature[int(split[0])] = float(split[1]) + self.assertListEqual(loaded_feature, list(features_array[i])) + + features = [[1.34234,2.342321,3.34322],[0.344234,1.123123,5.534322],[3.553423e10,3.5632e10,0.00000000000012345]] + features_prec = [[1.34, 2.34, 3.34], [0.344, 1.12, 5.53],[3.55e10, 3.56e10, 1.23e-13]] + labels = [0, 1] + output = _dump_libsvm(features, labels, external_storage_precision=3) + for i, line in enumerate(output): + split_line = line.split(" ") + self.assertEqual(float(split_line[0]), labels[i]) + split_line = [elem.split(":") for elem in split_line[1:]] + self.assertListEqual([float(v[1]) for v in split_line], features_prec[i]) diff --git a/tests/python/test_spark/utils_test.py b/tests/python/test_spark/utils_test.py new file mode 100644 index 000000000000..ca94c99a4068 --- /dev/null +++ b/tests/python/test_spark/utils_test.py @@ -0,0 +1,130 @@ +import unittest +import contextlib +import logging +import shutil +import subprocess +import sys +import tempfile + +import unittest + +from six import StringIO + +from pyspark.sql import SQLContext +from pyspark.sql import SparkSession +from pyspark.taskcontext import TaskContext + +from xgboost.spark.utils import _get_default_params_from_func + + +class UtilsTest(unittest.TestCase): + + def test_get_default_params(self): + + class Foo: + def func1(self, x, y, key1=None, key2="val2", key3=0, key4=None): + pass + + unsupported_params = {"key2", "key4"} + expected_default_params = { + "key1": None, + "key3": 0, + } + actual_default_params = _get_default_params_from_func(Foo.func1, unsupported_params) + self.assertEqual(len(expected_default_params.keys()), len(actual_default_params.keys())) + for k, v in actual_default_params.items(): 
+ self.assertEqual(expected_default_params[k], v) + +@contextlib.contextmanager +def patch_stdout(): + """patch stdout and give an output""" + sys_stdout = sys.stdout + io_out = StringIO() + sys.stdout = io_out + try: + yield io_out + finally: + sys.stdout = sys_stdout + + +@contextlib.contextmanager +def patch_logger(name): + """patch logger and give an output""" + io_out = StringIO() + log = logging.getLogger(name) + handler = logging.StreamHandler(io_out) + log.addHandler(handler) + try: + yield io_out + finally: + log.removeHandler(handler) + + +class TestTempDir(object): + @classmethod + def make_tempdir(cls): + """ + :param dir: Root directory in which to create the temp directory + """ + cls.tempdir = tempfile.mkdtemp(prefix="sparkdl_tests") + + @classmethod + def remove_tempdir(cls): + shutil.rmtree(cls.tempdir) + + +class TestSparkContext(object): + @classmethod + def setup_env(cls, spark_config): + builder = SparkSession.builder.appName('xgboost spark python API Tests') + for k, v in spark_config.items(): + builder.config(k, v) + spark = builder.getOrCreate() + if spark_config['spark.master'].startswith('local-cluster'): + # We run a dummy job so that we block until the workers have connected to the master + spark.sparkContext.parallelize(range(2), 2).barrier().mapPartitions(lambda _: []).collect() + + logging.getLogger('pyspark').setLevel(logging.INFO) + + cls.sc = spark.sparkContext + cls.sql = SQLContext(cls.sc) + cls.session = spark + + @classmethod + def tear_down_env(cls): + cls.session.stop() + cls.session = None + cls.sc.stop() + cls.sc = None + cls.sql = None + + +class SparkTestCase(TestSparkContext, TestTempDir, unittest.TestCase): + + @classmethod + def setUpClass(cls): + cls.setup_env({ + 'spark.master': 'local[2]', + 'spark.python.worker.reuse': 'false', + }) + + @classmethod + def tearDownClass(cls): + cls.tear_down_env() + + +class SparkLocalClusterTestCase(TestSparkContext, TestTempDir, unittest.TestCase): + + @classmethod + def setUpClass(cls): + cls.setup_env({ + 'spark.master': 'local-cluster[2, 1, 1024]', + 'spark.python.worker.reuse': 'false', + }) + cls.make_tempdir() + cls.sc.parallelize(range(4), 4).barrier().mapPartitions(lambda _: []).collect() + + @classmethod + def tearDownClass(cls): + cls.remove_tempdir() + cls.tear_down_env() diff --git a/tests/python/test_spark/xgboost_local_cluster_test.py b/tests/python/test_spark/xgboost_local_cluster_test.py new file mode 100644 index 000000000000..aee638fddb87 --- /dev/null +++ b/tests/python/test_spark/xgboost_local_cluster_test.py @@ -0,0 +1,382 @@ +import random + +import numpy as np +from pyspark.ml.linalg import Vectors + +from xgboost.spark import XgboostClassifier, XgboostRegressor +from .utils_test import SparkLocalClusterTestCase +from xgboost.spark.utils import _get_max_num_concurrent_tasks +import json +import uuid +import os + + +class XgboostLocalClusterTestCase(SparkLocalClusterTestCase): + + def setUp(self): + random.seed(2020) + + self.n_workers = _get_max_num_concurrent_tasks(self.session.sparkContext) + # The following code use xgboost python library to train xgb model and predict. + # + # >>> import numpy as np + # >>> import xgboost + # >>> X = np.array([[1.0, 2.0, 3.0], [0.0, 1.0, 5.5]]) + # >>> y = np.array([0, 1]) + # >>> reg1 = xgboost.XGBRegressor() + # >>> reg1.fit(X, y) + # >>> reg1.predict(X) + # array([8.8363886e-04, 9.9911636e-01], dtype=float32) + # >>> def custom_lr(boosting_round, num_boost_round): + # ... return 1.0 / (boosting_round + 1) + # ... 
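+        # (Aside: reset_learning_rate is the legacy callback helper; current
+        # xgboost releases expose the equivalent
+        # xgboost.callback.LearningRateScheduler, which test_callbacks below uses.)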
+ # >>> reg1.fit(X, y, callbacks=[xgboost.callback.reset_learning_rate(custom_lr)]) + # >>> reg1.predict(X) + # array([0.02406833, 0.97593164], dtype=float32) + # >>> reg2 = xgboost.XGBRegressor(max_depth=5, n_estimators=10) + # >>> reg2.fit(X, y) + # >>> reg2.predict(X, ntree_limit=5) + # array([0.22185263, 0.77814734], dtype=float32) + self.reg_params = {'max_depth': 5, 'n_estimators': 10, 'ntree_limit': 5} + self.reg_df_train = self.session.createDataFrame([ + (Vectors.dense(1.0, 2.0, 3.0), 0), + (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1) + ], ["features", "label"]) + self.reg_df_test = self.session.createDataFrame([ + (Vectors.dense(1.0, 2.0, 3.0), 0.0, 0.2219, 0.02406), + (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1.0, 0.7781, 0.9759) + ], ["features", "expected_prediction", "expected_prediction_with_params", + "expected_prediction_with_callbacks"]) + + # Distributed section + # Binary classification + self.cls_df_train_distributed = self.session.createDataFrame([ + (Vectors.dense(1.0, 2.0, 3.0), 0), + (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1), + (Vectors.dense(4.0, 5.0, 6.0), 0), + (Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1), + ] * 100, ["features", "label"]) + self.cls_df_test_distributed = self.session.createDataFrame([ + (Vectors.dense(1.0, 2.0, 3.0), 0, [0.9949826, 0.0050174]), + (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1, [0.0050174, 0.9949826]), + (Vectors.dense(4.0, 5.0, 6.0), 0, [0.9949826, 0.0050174]), + (Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, [0.0050174, 0.9949826]), + ], ["features", "expected_label", "expected_probability"]) + # Binary classification with different num_estimators + self.cls_df_test_distributed_lower_estimators = self.session.createDataFrame([ + (Vectors.dense(1.0, 2.0, 3.0), 0, [0.9735, 0.0265]), + (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1, [0.0265, 0.9735]), + (Vectors.dense(4.0, 5.0, 6.0), 0, [0.9735, 0.0265]), + (Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, [0.0265 , 0.9735]), + ], ["features", "expected_label", "expected_probability"]) + + # Multiclass classification + self.cls_df_train_distributed_multiclass = self.session.createDataFrame([ + (Vectors.dense(1.0, 2.0, 3.0), 0), + (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1), + (Vectors.dense(4.0, 5.0, 6.0), 0), + (Vectors.sparse(3, {1: 6.0, 2: 7.5}), 2), + ] * 100, ["features", "label"]) + self.cls_df_test_distributed_multiclass = self.session.createDataFrame([ + (Vectors.dense(1.0, 2.0, 3.0), 0, [ 4.294563, -2.449409, -2.449409 ]), + (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1, [-2.3796105, 3.669014, -2.449409 ]), + (Vectors.dense(4.0, 5.0, 6.0), 0, [ 4.294563, -2.449409, -2.449409 ]), + (Vectors.sparse(3, {1: 6.0, 2: 7.5}), 2, [-2.3796105, -2.449409, 3.669014 ]), + ], ["features", "expected_label", "expected_margins"]) + + # Regression + self.reg_df_train_distributed = self.session.createDataFrame([ + (Vectors.dense(1.0, 2.0, 3.0), 0), + (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1), + (Vectors.dense(4.0, 5.0, 6.0), 0), + (Vectors.sparse(3, {1: 6.0, 2: 7.5}), 2), + ] * 100, ["features", "label"]) + self.reg_df_test_distributed = self.session.createDataFrame([ + (Vectors.dense(1.0, 2.0, 3.0), 1.533e-04), + (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 9.999e-01), + (Vectors.dense(4.0, 5.0, 6.0), 1.533e-04), + (Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1.999e+00), + ], ["features", "expected_label"]) + + # Adding weight and validation + self.clf_params_with_eval_dist = {'validationIndicatorCol': 'isVal','early_stopping_rounds': 1, 'eval_metric': 'logloss'} + self.clf_params_with_weight_dist = {'weightCol': 'weight'} + 
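+        # These Spark-level params stand in for the sklearn fit() arguments:
+        # weightCol supplies sample_weight and validationIndicatorCol marks the
+        # rows that form the eval_set. An illustrative configuration:
+        # >>> clf = XgboostClassifier(weightCol='weight',
+        # ...                         validationIndicatorCol='isVal',
+        # ...                         early_stopping_rounds=1,
+        # ...                         eval_metric='logloss')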
self.cls_df_train_distributed_with_eval_weight = self.session.createDataFrame([ + (Vectors.dense(1.0, 2.0, 3.0), 0, False, 1.0), + (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1, False, 2.0), + (Vectors.dense(4.0, 5.0, 6.0), 0, True, 1.0), + (Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, True, 2.0), + ] * 100, ["features", "label", "isVal", "weight"]) + self.cls_df_test_distributed_with_eval_weight = self.session.createDataFrame([ + (Vectors.dense(1.0, 2.0, 3.0), [0.9955, 0.0044], [0.9904, 0.0096], [0.9903, 0.0097]), + ], ["features", "expected_prob_with_weight", "expected_prob_with_eval", + "expected_prob_with_weight_and_eval"]) + self.clf_best_score_eval = 0.009677 + self.clf_best_score_weight_and_eval = 0.006628 + + self.reg_params_with_eval_dist = {'validationIndicatorCol': 'isVal','early_stopping_rounds': 1, 'eval_metric': 'rmse'} + self.reg_params_with_weight_dist = {'weightCol': 'weight'} + self.reg_df_train_distributed_with_eval_weight = self.session.createDataFrame([ + (Vectors.dense(1.0, 2.0, 3.0), 0, False, 1.0), + (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1, False, 2.0), + (Vectors.dense(4.0, 5.0, 6.0), 0, True, 1.0), + (Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, True, 2.0), + ] * 100, ["features", "label", "isVal", "weight"]) + self.reg_df_test_distributed_with_eval_weight = self.session.createDataFrame([ + (Vectors.dense(1.0, 2.0, 3.0), 4.583e-05, 5.239e-05, 6.03e-05), + (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 9.9997e-01, 9.99947e-01, 9.9995e-01) + ], ["features", "expected_prediction_with_weight", "expected_prediction_with_eval", + "expected_prediction_with_weight_and_eval"]) + self.reg_best_score_eval = 5.2e-05 + self.reg_best_score_weight_and_eval = 4.9e-05 + + def test_regressor_basic_with_params(self): + regressor = XgboostRegressor(**self.reg_params) + model = regressor.fit(self.reg_df_train) + pred_result = model.transform(self.reg_df_test).collect() + for row in pred_result: + self.assertTrue( + np.isclose(row.prediction, + row.expected_prediction_with_params, atol=1e-3) + ) + + def test_callbacks(self): + from xgboost.callback import LearningRateScheduler + path = os.path.join(self.tempdir, str(uuid.uuid4())) + + def custom_learning_rate(boosting_round): + return 1.0 / (boosting_round + 1) + + cb = [LearningRateScheduler(custom_learning_rate)] + regressor = XgboostRegressor(callbacks=cb) + + # Test the save/load of the estimator instead of the model, since + # the callbacks param only exists in the estimator but not in the model + regressor.save(path) + regressor = XgboostRegressor.load(path) + + model = regressor.fit(self.reg_df_train) + pred_result = model.transform(self.reg_df_test).collect() + for row in pred_result: + self.assertTrue( + np.isclose(row.prediction, + row.expected_prediction_with_callbacks, atol=1e-3) + ) + + def test_classifier_distributed_basic(self): + classifier = XgboostClassifier(num_workers=self.n_workers, n_estimators=100, use_external_storage=False) + model = classifier.fit(self.cls_df_train_distributed) + pred_result = model.transform(self.cls_df_test_distributed).collect() + for row in pred_result: + self.assertTrue(np.isclose(row.expected_label, + row.prediction, atol=1e-3)) + self.assertTrue(np.allclose(row.expected_probability, row.probability, atol=1e-3)) + + def test_classifier_distributed_external_storage_basic(self): + classifier = XgboostClassifier(num_workers=self.n_workers, n_estimators=100, use_external_storage=True) + model = classifier.fit(self.cls_df_train_distributed) + pred_result = model.transform(self.cls_df_test_distributed).collect() + for 
row in pred_result: + self.assertTrue(np.isclose(row.expected_label, + row.prediction, atol=1e-3)) + self.assertTrue(np.allclose(row.expected_probability, row.probability, atol=1e-3)) + + def test_classifier_distributed_multiclass(self): + # There is no built-in multiclass option for external storage + classifier = XgboostClassifier(num_workers=self.n_workers, n_estimators=100, use_external_storage=False) + model = classifier.fit(self.cls_df_train_distributed_multiclass) + pred_result = model.transform(self.cls_df_test_distributed_multiclass).collect() + for row in pred_result: + self.assertTrue(np.isclose(row.expected_label, + row.prediction, atol=1e-3)) + self.assertTrue(np.allclose(row.expected_margins, row.rawPrediction, atol=1e-3)) + + def test_regressor_distributed_basic(self): + regressor = XgboostRegressor(num_workers=self.n_workers, n_estimators=100, use_external_storage=False) + model = regressor.fit(self.reg_df_train_distributed) + pred_result = model.transform(self.reg_df_test_distributed).collect() + for row in pred_result: + self.assertTrue(np.isclose(row.expected_label, + row.prediction, atol=1e-3)) + + def test_regressor_distributed_external_storage_basic(self): + regressor = XgboostRegressor(num_workers=self.n_workers, n_estimators=100, use_external_storage=True) + model = regressor.fit(self.reg_df_train_distributed) + pred_result = model.transform(self.reg_df_test_distributed).collect() + for row in pred_result: + self.assertTrue(np.isclose(row.expected_label, + row.prediction, atol=1e-3)) + + def check_use_gpu_param(self): + # Classifier + classifier = XgboostClassifier(num_workers=self.n_workers, n_estimators=100, use_gpu=True, use_external_storage=False) + self.assertTrue(hasattr(classifier, 'use_gpu')) + self.assertTrue(classifier.getOrDefault(classifier.use_gpu)) + clf_model = classifier.fit(self.cls_df_train_distributed) + pred_result = model.transform(self.cls_df_test_distributed).collect() + for row in pred_result: + self.assertTrue(np.isclose(row.expected_label, + row.prediction, atol=1e-3)) + self.assertTrue(np.allclose(row.expected_probability, row.probability, atol=1e-3)) + + regressor = XgboostRegressor(num_workers=self.n_workers, n_estimators=100, use_gpu=True, use_external_storage=False) + self.assertTrue(hasattr(regressor, 'use_gpu')) + self.assertTrue(regressor.getOrDefault(regressor.use_gpu)) + model = regressor.fit(self.reg_df_train_distributed) + pred_result = model.transform(self.reg_df_test_distributed).collect() + for row in pred_result: + self.assertTrue(np.isclose(row.expected_label, + row.prediction, atol=1e-3)) + + def test_classifier_distributed_weight_eval(self): + # with weight + classifier = XgboostClassifier(num_workers=self.n_workers, n_estimators=100, use_external_storage=False, **self.clf_params_with_weight_dist) + model = classifier.fit(self.cls_df_train_distributed_with_eval_weight) + pred_result = model.transform(self.cls_df_test_distributed_with_eval_weight).collect() + for row in pred_result: + self.assertTrue(np.allclose(row.probability, + row.expected_prob_with_weight, atol=1e-3)) + + # with eval only + classifier = XgboostClassifier(num_workers=self.n_workers, n_estimators=100, use_external_storage=False, **self.clf_params_with_eval_dist) + model = classifier.fit(self.cls_df_train_distributed_with_eval_weight) + pred_result = model.transform(self.cls_df_test_distributed_with_eval_weight).collect() + for row in pred_result: + self.assertTrue(np.allclose(row.probability, + row.expected_prob_with_eval, atol=1e-3)) + 
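+        # Early stopping records the best eval metric as a booster attribute,
+        # which the check below reads back via
+        # model.get_booster().attributes()["best_score"].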
self.assertEqual(float(model.get_booster().attributes()["best_score"]), self.clf_best_score_eval) + + # with both weight and eval + classifier = XgboostClassifier(num_workers=self.n_workers, n_estimators=100, use_external_storage=False, **self.clf_params_with_eval_dist, **self.clf_params_with_weight_dist) + model = classifier.fit(self.cls_df_train_distributed_with_eval_weight) + pred_result = model.transform(self.cls_df_test_distributed_with_eval_weight).collect() + for row in pred_result: + self.assertTrue(np.allclose(row.probability, + row.expected_prob_with_weight_and_eval, atol=1e-3)) + self.assertEqual(float(model.get_booster().attributes()["best_score"]), self.clf_best_score_weight_and_eval) + + def test_classifier_distributed_weight_eval_external_storage(self): + # with weight + classifier = XgboostClassifier(num_workers=self.n_workers, n_estimators=100, use_external_storage=True, **self.clf_params_with_weight_dist) + model = classifier.fit(self.cls_df_train_distributed_with_eval_weight) + pred_result = model.transform(self.cls_df_test_distributed_with_eval_weight).collect() + for row in pred_result: + self.assertTrue(np.allclose(row.probability, + row.expected_prob_with_weight, atol=1e-3)) + + # with eval only + classifier = XgboostClassifier(num_workers=self.n_workers, n_estimators=100, use_external_storage=True, **self.clf_params_with_eval_dist) + model = classifier.fit(self.cls_df_train_distributed_with_eval_weight) + pred_result = model.transform(self.cls_df_test_distributed_with_eval_weight).collect() + for row in pred_result: + self.assertTrue(np.allclose(row.probability, + row.expected_prob_with_eval, atol=1e-3)) + self.assertEqual(float(model.get_booster().attributes()["best_score"]), self.clf_best_score_eval) + + # with both weight and eval + classifier = XgboostClassifier(num_workers=self.n_workers, n_estimators=100, use_external_storage=True, **self.clf_params_with_eval_dist, **self.clf_params_with_weight_dist) + model = classifier.fit(self.cls_df_train_distributed_with_eval_weight) + pred_result = model.transform(self.cls_df_test_distributed_with_eval_weight).collect() + for row in pred_result: + self.assertTrue(np.allclose(row.probability, + row.expected_prob_with_weight_and_eval, atol=1e-3)) + self.assertEqual(float(model.get_booster().attributes()["best_score"]), self.clf_best_score_weight_and_eval) + + def test_regressor_distributed_weight_eval(self): + # with weight + regressor = XgboostRegressor(num_workers=self.n_workers, n_estimators=100, use_external_storage=False, **self.reg_params_with_weight_dist) + model = regressor.fit(self.reg_df_train_distributed_with_eval_weight) + pred_result = model.transform(self.reg_df_test_distributed_with_eval_weight).collect() + for row in pred_result: + self.assertTrue( + np.isclose(row.prediction, + row.expected_prediction_with_weight, atol=1e-3)) + # with eval only + regressor = XgboostRegressor(num_workers=self.n_workers, n_estimators=100, use_external_storage=False, **self.reg_params_with_eval_dist) + model = regressor.fit(self.reg_df_train_distributed_with_eval_weight) + pred_result = model.transform(self.reg_df_test_distributed_with_eval_weight).collect() + for row in pred_result: + self.assertTrue( + np.isclose(row.prediction, + row.expected_prediction_with_eval, atol=1e-3)) + self.assertEqual(float(model.get_booster().attributes()["best_score"]), self.reg_best_score_eval) + # with both weight and eval + regressor = XgboostRegressor(num_workers=self.n_workers, n_estimators=100, use_external_storage=False, 
**self.reg_params_with_eval_dist, **self.reg_params_with_weight_dist) + model = regressor.fit(self.reg_df_train_distributed_with_eval_weight) + pred_result = model.transform(self.reg_df_test_distributed_with_eval_weight).collect() + for row in pred_result: + self.assertTrue( + np.isclose(row.prediction, + row.expected_prediction_with_weight_and_eval, atol=1e-3)) + self.assertEqual(float(model.get_booster().attributes()["best_score"]), self.reg_best_score_weight_and_eval) + + def test_regressor_distributed_weight_eval_external_storage(self): + # with weight + regressor = XgboostRegressor(num_workers=self.n_workers, n_estimators=100, use_external_storage=True, **self.reg_params_with_weight_dist) + model = regressor.fit(self.reg_df_train_distributed_with_eval_weight) + pred_result = model.transform(self.reg_df_test_distributed_with_eval_weight).collect() + for row in pred_result: + self.assertTrue( + np.isclose(row.prediction, + row.expected_prediction_with_weight, atol=1e-3)) + # with eval only + regressor = XgboostRegressor(num_workers=self.n_workers, n_estimators=100, use_external_storage=True, **self.reg_params_with_eval_dist) + model = regressor.fit(self.reg_df_train_distributed_with_eval_weight) + pred_result = model.transform(self.reg_df_test_distributed_with_eval_weight).collect() + for row in pred_result: + self.assertTrue( + np.isclose(row.prediction, + row.expected_prediction_with_eval, atol=1e-3)) + self.assertEqual(float(model.get_booster().attributes()["best_score"]), self.reg_best_score_eval) + # with both weight and eval + regressor = XgboostRegressor(num_workers=self.n_workers, n_estimators=100, use_external_storage=True, **self.reg_params_with_eval_dist, **self.reg_params_with_weight_dist) + model = regressor.fit(self.reg_df_train_distributed_with_eval_weight) + pred_result = model.transform(self.reg_df_test_distributed_with_eval_weight).collect() + for row in pred_result: + self.assertTrue( + np.isclose(row.prediction, + row.expected_prediction_with_weight_and_eval, atol=1e-3)) + self.assertEqual(float(model.get_booster().attributes()["best_score"]), self.reg_best_score_weight_and_eval) + + def test_num_estimators(self): + classifier = XgboostClassifier(num_workers=self.n_workers, n_estimators=10, use_external_storage=False) + model = classifier.fit(self.cls_df_train_distributed) + pred_result = model.transform(self.cls_df_test_distributed_lower_estimators).collect() + print(pred_result) + for row in pred_result: + self.assertTrue(np.isclose(row.expected_label, + row.prediction, atol=1e-3)) + self.assertTrue(np.allclose(row.expected_probability, row.probability, atol=1e-3)) + + def test_missing_value_zero_with_external_storage(self): + classifier = XgboostClassifier(num_workers=self.n_workers, n_estimators=10, use_external_storage=False, + missing=0.0) + classifier.fit(self.cls_df_train_distributed) + + def test_distributed_params(self): + classifier = XgboostClassifier(num_workers=self.n_workers, max_depth=7) + model = classifier.fit(self.cls_df_train_distributed) + self.assertTrue(hasattr(classifier, 'max_depth')) + self.assertEqual(classifier.getOrDefault(classifier.max_depth), 7) + booster_config = json.loads(model.get_booster().save_config()) + max_depth = booster_config["learner"]["gradient_booster"]["updater"]["grow_histmaker"]["train_param"]["max_depth"] + self.assertEqual(int(max_depth), 7) + + def test_repartition(self): + # The following test case has a few partitioned datasets that are either + # well partitioned relative to the number of workers that the user 
wants + # or poorly partitioned. We only want to repartition when the dataset + # is poorly partitioned so _repartition_needed is true in those instances. + + classifier = XgboostClassifier(num_workers=self.n_workers) + basic = self.cls_df_train_distributed + self.assertTrue(classifier._repartition_needed(basic)) + bad_repartitioned = basic.repartition(self.n_workers + 1) + self.assertTrue(classifier._repartition_needed(bad_repartitioned)) + good_repartitioned = basic.repartition(self.n_workers) + self.assertFalse(classifier._repartition_needed(good_repartitioned)) + + # Now testing if force_repartition returns True regardless of whether the data is well partitioned + classifier = XgboostClassifier(num_workers=self.n_workers, force_repartition=True) + good_repartitioned = basic.repartition(self.n_workers) + self.assertTrue(classifier._repartition_needed(good_repartitioned)) diff --git a/tests/python/test_spark/xgboost_local_test.py b/tests/python/test_spark/xgboost_local_test.py new file mode 100644 index 000000000000..fb8673e72cf4 --- /dev/null +++ b/tests/python/test_spark/xgboost_local_test.py @@ -0,0 +1,654 @@ +import logging +import random +import uuid + +import numpy as np +from pyspark.ml import Pipeline, PipelineModel +from pyspark.ml.evaluation import BinaryClassificationEvaluator, \ + MulticlassClassificationEvaluator +from pyspark.ml.linalg import Vectors +from pyspark.ml.tuning import CrossValidator, ParamGridBuilder + +from xgboost.spark import (XgboostClassifier, XgboostClassifierModel, + XgboostRegressor, XgboostRegressorModel) +from .utils_test import SparkTestCase +from xgboost import XGBClassifier, XGBRegressor + +logging.getLogger("py4j").setLevel(logging.INFO) + + +class XgboostLocalTest(SparkTestCase): + + def setUp(self): + logging.getLogger().setLevel('INFO') + random.seed(2020) + + # The following code use xgboost python library to train xgb model and predict. + # + # >>> import numpy as np + # >>> import xgboost + # >>> X = np.array([[1.0, 2.0, 3.0], [0.0, 1.0, 5.5]]) + # >>> y = np.array([0, 1]) + # >>> reg1 = xgboost.XGBRegressor() + # >>> reg1.fit(X, y) + # >>> reg1.predict(X) + # array([8.8375784e-04, 9.9911624e-01], dtype=float32) + # >>> def custom_lr(boosting_round): + # ... return 1.0 / (boosting_round + 1) + # ... 
+ # >>> reg1.fit(X, y, callbacks=[xgboost.callback.LearningRateScheduler(custom_lr)]) + # >>> reg1.predict(X) + # array([0.02406844, 0.9759315 ], dtype=float32) + # >>> reg2 = xgboost.XGBRegressor(max_depth=5, n_estimators=10) + # >>> reg2.fit(X, y) + # >>> reg2.predict(X, ntree_limit=5) + # array([0.22185266, 0.77814734], dtype=float32) + self.reg_params = {'max_depth': 5, 'n_estimators': 10, 'ntree_limit': 5} + self.reg_df_train = self.session.createDataFrame([ + (Vectors.dense(1.0, 2.0, 3.0), 0), + (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1) + ], ["features", "label"]) + self.reg_df_test = self.session.createDataFrame([ + (Vectors.dense(1.0, 2.0, 3.0), 0.0, 0.2219, 0.02406), + (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1.0, 0.7781, 0.9759) + ], ["features", "expected_prediction", "expected_prediction_with_params", + "expected_prediction_with_callbacks"]) + + # >>> X = np.array([[1.0, 2.0, 3.0], [0.0, 1.0, 5.5]]) + # >>> y = np.array([0, 1]) + # >>> cl1 = xgboost.XGBClassifier() + # >>> cl1.fit(X, y) + # >>> cl1.predict(X) + # array([0, 0]) + # >>> cl1.predict_proba(X) + # array([[0.5, 0.5], + # [0.5, 0.5]], dtype=float32) + # >>> cl2 = xgboost.XGBClassifier(max_depth=5, n_estimators=10, scale_pos_weight=4) + # >>> cl2.fit(X, y) + # >>> cl2.predict(X) + # array([1, 1]) + # >>> cl2.predict_proba(X) + # array([[0.27574146, 0.72425854 ], + # [0.27574146, 0.72425854 ]], dtype=float32) + self.cls_params = {'max_depth': 5, 'n_estimators': 10, 'scale_pos_weight': 4} + + cls_df_train_data = [ + (Vectors.dense(1.0, 2.0, 3.0), 0), + (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1) + ] + self.cls_df_train = self.session.createDataFrame( + cls_df_train_data, ["features", "label"]) + self.cls_df_train_large = self.session.createDataFrame( + cls_df_train_data * 100, ["features", "label"]) + self.cls_df_test = self.session.createDataFrame([ + (Vectors.dense(1.0, 2.0, 3.0), 0, [0.5, 0.5], 1, [0.27574146, 0.72425854]), + (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 0, [0.5, 0.5], 1, [0.27574146, 0.72425854]) + ], ["features", + "expected_prediction", "expected_probability", + "expected_prediction_with_params", "expected_probability_with_params"]) + + # kwargs test (using the above data, train, we get the same results) + self.cls_params_kwargs = {'tree_method': 'approx', 'sketch_eps':0.03} + + # >>> X = np.array([[1.0, 2.0, 3.0], [1.0, 2.0, 4.0], [0.0, 1.0, 5.5], [-1.0, -2.0, 1.0]]) + # >>> y = np.array([0, 0, 1, 2]) + # >>> cl = xgboost.XGBClassifier() + # >>> cl.fit(X, y) + # >>> cl.predict_proba(np.array([[1.0, 2.0, 3.0]])) + # array([[0.5374299 , 0.23128504, 0.23128504]], dtype=float32) + multi_cls_df_train_data = [ + (Vectors.dense(1.0, 2.0, 3.0), 0), + (Vectors.dense(1.0, 2.0, 4.0), 0), + (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1), + (Vectors.dense(-1.0, -2.0, 1.0), 2), + ] + self.multi_cls_df_train = self.session.createDataFrame( + multi_cls_df_train_data, ["features", "label"]) + self.multi_cls_df_train_large = self.session.createDataFrame( + multi_cls_df_train_data * 100, ["features", "label"]) + self.multi_cls_df_test = self.session.createDataFrame([ + (Vectors.dense(1.0, 2.0, 3.0), [0.5374, 0.2312, 0.2312]), + ], ["features", "expected_probability"]) + + # Test regressor with weight and eval set + # >>> import numpy as np + # >>> import xgboost + # >>> X = np.array([[1.0, 2.0, 3.0], [0.0, 1.0, 5.5], [4.0, 5.0, 6.0], [0.0, 6.0, 7.5]]) + # >>> w = np.array([1.0, 2.0, 1.0, 2.0]) + # >>> y = np.array([0, 1, 2, 3]) + # >>> reg1 = xgboost.XGBRegressor() + # >>> reg1.fit(X, y, sample_weight=w) + # >>> reg1.predict(X) + # >>> 
array([1.0679445e-03, 1.0000550e+00, ... + # >>> X_train = np.array([[1.0, 2.0, 3.0], [0.0, 1.0, 5.5]]) + # >>> X_val = np.array([[4.0, 5.0, 6.0], [0.0, 6.0, 7.5]]) + # >>> y_train = np.array([0, 1]) + # >>> y_val = np.array([2, 3]) + # >>> w_train = np.array([1.0, 2.0]) + # >>> w_val = np.array([1.0, 2.0]) + # >>> reg2 = xgboost.XGBRegressor() + # >>> reg2.fit(X_train, y_train, eval_set=[(X_val, y_val)], + # >>> early_stopping_rounds=1, eval_metric='rmse') + # >>> reg2.predict(X) + # >>> array([8.8370638e-04, 9.9911624e-01, ... + # >>> reg2.best_score + # 2.0000002682208837 + # >>> reg3 = xgboost.XGBRegressor() + # >>> reg3.fit(X_train, y_train, sample_weight=w_train, eval_set=[(X_val, y_val)], + # >>> sample_weight_eval_set=[w_val], + # >>> early_stopping_rounds=1, eval_metric='rmse') + # >>> reg3.predict(X) + # >>> array([0.03155671, 0.98874104,... + # >>> reg3.best_score + # 1.9970891552124017 + self.reg_df_train_with_eval_weight = self.session.createDataFrame([ + (Vectors.dense(1.0, 2.0, 3.0), 0, False, 1.0), + (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1, False, 2.0), + (Vectors.dense(4.0, 5.0, 6.0), 2, True, 1.0), + (Vectors.sparse(3, {1: 6.0, 2: 7.5}), 3, True, 2.0), + ], ["features", "label", "isVal", "weight"]) + self.reg_params_with_eval = {'validationIndicatorCol': 'isVal', + 'early_stopping_rounds': 1, 'eval_metric': 'rmse'} + self.reg_df_test_with_eval_weight = self.session.createDataFrame([ + (Vectors.dense(1.0, 2.0, 3.0), 0.001068, 0.00088, 0.03155), + (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1.000055, 0.9991, 0.9887) + ], ["features", "expected_prediction_with_weight", "expected_prediction_with_eval", + "expected_prediction_with_weight_and_eval"]) + self.reg_with_eval_best_score = 2.0 + self.reg_with_eval_and_weight_best_score = 1.997 + + # Test classifier with weight and eval set + # >>> import numpy as np + # >>> import xgboost + # >>> X = np.array([[1.0, 2.0, 3.0], [0.0, 1.0, 5.5], [4.0, 5.0, 6.0], [0.0, 6.0, 7.5]]) + # >>> w = np.array([1.0, 2.0, 1.0, 2.0]) + # >>> y = np.array([0, 1, 0, 1]) + # >>> cls1 = xgboost.XGBClassifier() + # >>> cls1.fit(X, y, sample_weight=w) + # >>> cls1.predict_proba(X) + # array([[0.3333333, 0.6666667],... + # >>> X_train = np.array([[1.0, 2.0, 3.0], [0.0, 1.0, 5.5]]) + # >>> X_val = np.array([[4.0, 5.0, 6.0], [0.0, 6.0, 7.5]]) + # >>> y_train = np.array([0, 1]) + # >>> y_val = np.array([0, 1]) + # >>> w_train = np.array([1.0, 2.0]) + # >>> w_val = np.array([1.0, 2.0]) + # >>> cls2 = xgboost.XGBClassifier() + # >>> cls2.fit(X_train, y_train, eval_set=[(X_val, y_val)], + # >>> early_stopping_rounds=1, eval_metric='logloss') + # >>> cls2.predict_proba(X) + # array([[0.5, 0.5],... + # >>> cls2.best_score + # 0.6931 + # >>> cls3 = xgboost.XGBClassifier() + # >>> cls3.fit(X_train, y_train, sample_weight=w_train, eval_set=[(X_val, y_val)], + # >>> sample_weight_eval_set=[w_val], + # >>> early_stopping_rounds=1, eval_metric='logloss') + # >>> cls3.predict_proba(X) + # array([[0.3344962, 0.6655038],... 
+ # >>> cls3.best_score + # 0.6365 + self.cls_df_train_with_eval_weight = self.session.createDataFrame([ + (Vectors.dense(1.0, 2.0, 3.0), 0, False, 1.0), + (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1, False, 2.0), + (Vectors.dense(4.0, 5.0, 6.0), 0, True, 1.0), + (Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, True, 2.0), + ], ["features", "label", "isVal", "weight"]) + self.cls_params_with_eval = {'validationIndicatorCol': 'isVal', + 'early_stopping_rounds': 1, 'eval_metric': 'logloss'} + self.cls_df_test_with_eval_weight = self.session.createDataFrame([ + (Vectors.dense(1.0, 2.0, 3.0), [0.3333, 0.6666], [0.5, 0.5], [0.3097, 0.6903]), + ], ["features", "expected_prob_with_weight", "expected_prob_with_eval", + "expected_prob_with_weight_and_eval"]) + self.cls_with_eval_best_score = 0.6931 + self.cls_with_eval_and_weight_best_score = 0.6378 + + # Test classifier with both base margin and without + # >>> import numpy as np + # >>> import xgboost + # >>> X = np.array([[1.0, 2.0, 3.0], [0.0, 1.0, 5.5], [4.0, 5.0, 6.0], [0.0, 6.0, 7.5]]) + # >>> w = np.array([1.0, 2.0, 1.0, 2.0]) + # >>> y = np.array([0, 1, 0, 1]) + # >>> base_margin = np.array([1,0,0,1]) + # + # This is without the base margin + # >>> cls1 = xgboost.XGBClassifier() + # >>> cls1.fit(X, y, sample_weight=w) + # >>> cls1.predict_proba(np.array([[1.0, 2.0, 3.0]])) + # array([[0.3333333, 0.6666667]], dtype=float32) + # >>> cls1.predict(np.array([[1.0, 2.0, 3.0]])) + # array([1]) + # + # This is with the same base margin for predict + # >>> cls2 = xgboost.XGBClassifier() + # >>> cls2.fit(X, y, sample_weight=w, base_margin=base_margin) + # >>> cls2.predict_proba(np.array([[1.0, 2.0, 3.0]]), base_margin=[0]) + # array([[0.44142532, 0.5585747 ]], dtype=float32) + # >>> cls2.predict(np.array([[1.0, 2.0, 3.0]]), base_margin=[0]) + # array([1]) + # + # This is with a different base margin for predict + # # >>> cls2 = xgboost.XGBClassifier() + # >>> cls2.fit(X, y, sample_weight=w, base_margin=base_margin) + # >>> cls2.predict_proba(np.array([[1.0, 2.0, 3.0]]), base_margin=[1]) + # array([[0.2252, 0.7747 ]], dtype=float32) + # >>> cls2.predict(np.array([[1.0, 2.0, 3.0]]), base_margin=[0]) + # array([1]) + self.cls_df_train_without_base_margin = self.session.createDataFrame([ + (Vectors.dense(1.0, 2.0, 3.0), 0, 1.0), + (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1, 2.0), + (Vectors.dense(4.0, 5.0, 6.0), 0, 1.0), + (Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, 2.0), + ], ["features", "label", "weight"]) + self.cls_df_test_without_base_margin = self.session.createDataFrame([ + (Vectors.dense(1.0, 2.0, 3.0), [0.3333, 0.6666], 1), + ], ["features", "expected_prob_without_base_margin", "expected_prediction_without_base_margin"]) + + self.cls_df_train_with_same_base_margin = self.session.createDataFrame([ + (Vectors.dense(1.0, 2.0, 3.0), 0, 1.0, 1), + (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1, 2.0, 0), + (Vectors.dense(4.0, 5.0, 6.0), 0, 1.0, 0), + (Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, 2.0, 1), + ], ["features", "label", "weight", "baseMarginCol"]) + self.cls_df_test_with_same_base_margin = self.session.createDataFrame([ + (Vectors.dense(1.0, 2.0, 3.0), 0, [0.4415, 0.5585], 1), + ], ["features", "baseMarginCol", "expected_prob_with_base_margin", "expected_prediction_with_base_margin"]) + + self.cls_df_train_with_different_base_margin = self.session.createDataFrame([ + (Vectors.dense(1.0, 2.0, 3.0), 0, 1.0, 1), + (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1, 2.0, 0), + (Vectors.dense(4.0, 5.0, 6.0), 0, 1.0, 0), + (Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, 2.0, 1), + ], 
["features", "label", "weight", "baseMarginCol"]) + self.cls_df_test_with_different_base_margin = self.session.createDataFrame([ + (Vectors.dense(1.0, 2.0, 3.0), 1, [0.2252, 0.7747], 1), + ], ["features", "baseMarginCol", "expected_prob_with_base_margin", "expected_prediction_with_base_margin"]) + + def get_local_tmp_dir(self): + return "/tmp/xgboost_local_test/" + str(uuid.uuid4()) + + def test_regressor_params_basic(self): + py_reg = XgboostRegressor() + self.assertTrue(hasattr(py_reg, 'n_estimators')) + self.assertEqual(py_reg.n_estimators.parent, py_reg.uid) + self.assertFalse(hasattr(py_reg, 'gpu_id')) + self.assertEqual(py_reg.getOrDefault(py_reg.n_estimators), 100) + self.assertEqual(py_reg._get_xgb_model_creator()().n_estimators, 100) + py_reg2 = XgboostRegressor(n_estimators=200) + self.assertEqual(py_reg2.getOrDefault(py_reg2.n_estimators), 200) + self.assertEqual(py_reg2._get_xgb_model_creator()().n_estimators, 200) + py_reg3 = py_reg2.copy({py_reg2.max_depth: 10}) + self.assertEqual(py_reg3.getOrDefault(py_reg3.n_estimators), 200) + self.assertEqual(py_reg3.getOrDefault(py_reg3.max_depth), 10) + + def test_classifier_params_basic(self): + py_cls = XgboostClassifier() + self.assertTrue(hasattr(py_cls, 'n_estimators')) + self.assertEqual(py_cls.n_estimators.parent, py_cls.uid) + self.assertFalse(hasattr(py_cls, 'gpu_id')) + self.assertEqual(py_cls.getOrDefault(py_cls.n_estimators), 100) + self.assertEqual(py_cls._get_xgb_model_creator()().n_estimators, 100) + py_cls2 = XgboostClassifier(n_estimators=200) + self.assertEqual(py_cls2.getOrDefault(py_cls2.n_estimators), 200) + self.assertEqual(py_cls2._get_xgb_model_creator()().n_estimators, 200) + py_cls3 = py_cls2.copy({py_cls2.max_depth: 10}) + self.assertEqual(py_cls3.getOrDefault(py_cls3.n_estimators), 200) + self.assertEqual(py_cls3.getOrDefault(py_cls3.max_depth), 10) + + def test_classifier_kwargs_basic(self): + py_cls = XgboostClassifier(**self.cls_params_kwargs) + self.assertTrue(hasattr(py_cls, 'n_estimators')) + self.assertEqual(py_cls.n_estimators.parent, py_cls.uid) + self.assertFalse(hasattr(py_cls, 'gpu_id')) + self.assertTrue(hasattr(py_cls, 'arbitraryParamsDict')) + expected_kwargs = {'sketch_eps':0.03} + self.assertEqual(py_cls.getOrDefault(py_cls.arbitraryParamsDict), expected_kwargs) + self.assertTrue("sketch_eps" in py_cls._get_xgb_model_creator()().get_params()) + # We want all of the new params to be in the .get_params() call and be an attribute of py_cls, but not of the actual model + self.assertTrue("arbitraryParamsDict" not in py_cls._get_xgb_model_creator()().get_params()) + + # Testing overwritten params + py_cls = XgboostClassifier() + py_cls.setParams(x=1, y=2) + py_cls.setParams(y=1, z=2) + self.assertTrue("x" in py_cls._get_xgb_model_creator()().get_params()) + self.assertEqual(py_cls._get_xgb_model_creator()().get_params()["x"], 1) + self.assertTrue("y" in py_cls._get_xgb_model_creator()().get_params()) + self.assertEqual(py_cls._get_xgb_model_creator()().get_params()["y"], 1) + self.assertTrue("z" in py_cls._get_xgb_model_creator()().get_params()) + self.assertEqual(py_cls._get_xgb_model_creator()().get_params()["z"], 2) + + @staticmethod + def test_param_value_converter(): + py_cls = XgboostClassifier(missing=np.float64(1.0), sketch_eps=np.float64(0.3)) + # don't check by isintance(v, float) because for numpy scalar it will also return True + assert py_cls.getOrDefault(py_cls.missing).__class__.__name__ == 'float' + assert 
py_cls.getOrDefault(py_cls.arbitraryParamsDict)['sketch_eps'].__class__.__name__ \ + == 'float64' + + def test_regressor_basic(self): + regressor = XgboostRegressor() + model = regressor.fit(self.reg_df_train) + pred_result = model.transform(self.reg_df_test).collect() + for row in pred_result: + self.assertTrue(np.isclose(row.prediction, row.expected_prediction, atol=1e-3)) + + def test_classifier_basic(self): + classifier = XgboostClassifier() + model = classifier.fit(self.cls_df_train) + pred_result = model.transform(self.cls_df_test).collect() + for row in pred_result: + self.assertEqual(row.prediction, row.expected_prediction) + self.assertTrue(np.allclose(row.probability, row.expected_probability, rtol=1e-3)) + + def test_multi_classifier(self): + classifier = XgboostClassifier() + model = classifier.fit(self.multi_cls_df_train) + pred_result = model.transform(self.multi_cls_df_test).collect() + for row in pred_result: + self.assertTrue(np.allclose(row.probability, row.expected_probability, rtol=1e-3)) + + def _check_sub_dict_match(self, sub_dist, whole_dict): + for k in sub_dist: + self.assertTrue(k in whole_dict) + self.assertEqual(sub_dist[k], whole_dict[k]) + + def test_regressor_with_params(self): + regressor = XgboostRegressor(**self.reg_params) + all_params = dict(**(regressor._gen_xgb_params_dict()), + **(regressor._gen_fit_params_dict()), + **(regressor._gen_predict_params_dict())) + self._check_sub_dict_match(self.reg_params, all_params) + + model = regressor.fit(self.reg_df_train) + all_params = dict(**(model._gen_xgb_params_dict()), + **(model._gen_fit_params_dict()), + **(model._gen_predict_params_dict())) + self._check_sub_dict_match(self.reg_params, all_params) + pred_result = model.transform(self.reg_df_test).collect() + for row in pred_result: + self.assertTrue( + np.isclose(row.prediction, + row.expected_prediction_with_params, atol=1e-3) + ) + + def test_classifier_with_params(self): + classifier = XgboostClassifier(**self.cls_params) + all_params = dict(**(classifier._gen_xgb_params_dict()), + **(classifier._gen_fit_params_dict()), + **(classifier._gen_predict_params_dict())) + self._check_sub_dict_match(self.cls_params, all_params) + + model = classifier.fit(self.cls_df_train) + all_params = dict(**(model._gen_xgb_params_dict()), + **(model._gen_fit_params_dict()), + **(model._gen_predict_params_dict())) + self._check_sub_dict_match(self.cls_params, all_params) + pred_result = model.transform(self.cls_df_test).collect() + for row in pred_result: + self.assertEqual(row.prediction, row.expected_prediction_with_params) + self.assertTrue(np.allclose(row.probability, row.expected_probability_with_params, rtol=1e-3)) + + def test_regressor_model_save_load(self): + path = 'file:' + self.get_local_tmp_dir() + regressor = XgboostRegressor(**self.reg_params) + model = regressor.fit(self.reg_df_train) + model.save(path) + loaded_model = XgboostRegressorModel.load(path) + self.assertEqual(model.uid, loaded_model.uid) + for k, v in self.reg_params.items(): + self.assertEqual(loaded_model.getOrDefault(k), v) + + pred_result = loaded_model.transform(self.reg_df_test).collect() + for row in pred_result: + self.assertTrue( + np.isclose(row.prediction, row.expected_prediction_with_params, atol=1e-3)) + + with self.assertRaisesRegex(AssertionError, 'Expected class name'): + XgboostClassifierModel.load(path) + + def test_classifier_model_save_load(self): + path = 'file:' + self.get_local_tmp_dir() + regressor = XgboostClassifier(**self.cls_params) + model = 
regressor.fit(self.cls_df_train) + model.save(path) + loaded_model = XgboostClassifierModel.load(path) + self.assertEqual(model.uid, loaded_model.uid) + for k, v in self.cls_params.items(): + self.assertEqual(loaded_model.getOrDefault(k), v) + + pred_result = loaded_model.transform(self.cls_df_test).collect() + for row in pred_result: + self.assertTrue( + np.allclose(row.probability, row.expected_probability_with_params, atol=1e-3)) + + with self.assertRaisesRegex(AssertionError, 'Expected class name'): + XgboostRegressorModel.load(path) + + @staticmethod + def _get_params_map(params_kv, estimator): + return {getattr(estimator, k): v for k, v in params_kv.items()} + + def test_regressor_model_pipeline_save_load(self): + path = 'file:' + self.get_local_tmp_dir() + regressor = XgboostRegressor() + pipeline = Pipeline(stages=[regressor]) + pipeline = pipeline.copy(extra=self._get_params_map(self.reg_params, regressor)) + model = pipeline.fit(self.reg_df_train) + model.save(path) + + loaded_model = PipelineModel.load(path) + for k, v in self.reg_params.items(): + self.assertEqual(loaded_model.stages[0].getOrDefault(k), v) + + pred_result = loaded_model.transform(self.reg_df_test).collect() + for row in pred_result: + self.assertTrue( + np.isclose(row.prediction, row.expected_prediction_with_params, atol=1e-3)) + + def test_classifier_model_pipeline_save_load(self): + path = 'file:' + self.get_local_tmp_dir() + classifier = XgboostClassifier() + pipeline = Pipeline(stages=[classifier]) + pipeline = pipeline.copy(extra=self._get_params_map(self.cls_params, classifier)) + model = pipeline.fit(self.cls_df_train) + model.save(path) + + loaded_model = PipelineModel.load(path) + for k, v in self.cls_params.items(): + self.assertEqual(loaded_model.stages[0].getOrDefault(k), v) + + pred_result = loaded_model.transform(self.cls_df_test).collect() + for row in pred_result: + self.assertTrue( + np.allclose(row.probability, row.expected_probability_with_params, atol=1e-3)) + + def test_classifier_with_cross_validator(self): + xgb_classifer = XgboostClassifier() + paramMaps = ParamGridBuilder().addGrid(xgb_classifer.max_depth, [1, 2]).build() + cvBin = CrossValidator(estimator=xgb_classifer, estimatorParamMaps=paramMaps, + evaluator=BinaryClassificationEvaluator(), seed=1) + cvBinModel = cvBin.fit(self.cls_df_train_large) + cvBinModel.transform(self.cls_df_test) + cvMulti = CrossValidator(estimator=xgb_classifer, estimatorParamMaps=paramMaps, + evaluator=MulticlassClassificationEvaluator(), seed=1) + cvMultiModel = cvMulti.fit(self.multi_cls_df_train_large) + cvMultiModel.transform(self.multi_cls_df_test) + + def test_callbacks(self): + from xgboost.callback import LearningRateScheduler + + path = self.get_local_tmp_dir() + + def custom_learning_rate(boosting_round): + return 1.0 / (boosting_round + 1) + + cb = [LearningRateScheduler(custom_learning_rate)] + regressor = XgboostRegressor(callbacks=cb) + + # Test the save/load of the estimator instead of the model, since + # the callbacks param only exists in the estimator but not in the model + regressor.save(path) + regressor = XgboostRegressor.load(path) + + model = regressor.fit(self.reg_df_train) + pred_result = model.transform(self.reg_df_test).collect() + for row in pred_result: + self.assertTrue( + np.isclose(row.prediction, + row.expected_prediction_with_callbacks, atol=1e-3) + ) + + def test_train_with_initial_model(self): + path = self.get_local_tmp_dir() + reg1 = XgboostRegressor(**self.reg_params) + model = reg1.fit(self.reg_df_train) + 
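+        # The xgb_model param lets a new estimator continue boosting from a
+        # previously trained booster instead of starting from scratch; reg2
+        # below is seeded with the booster extracted from the first model.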
init_booster = model.get_booster() + reg2 = XgboostRegressor(max_depth=2, n_estimators=2, xgb_model=init_booster) + model21 = reg2.fit(self.reg_df_train) + pred_res21 = model21.transform(self.reg_df_test).collect() + reg2.save(path) + reg2 = XgboostRegressor.load(path) + self.assertTrue(reg2.getOrDefault(reg2.xgb_model) is not None) + model22 = reg2.fit(self.reg_df_train) + pred_res22 = model22.transform(self.reg_df_test).collect() + # Test the transform result is the same for original and loaded model + for row1, row2 in zip(pred_res21, pred_res22): + self.assertTrue(np.isclose(row1.prediction, row2.prediction, atol=1e-3)) + + def test_classifier_with_base_margin(self): + cls_without_base_margin = XgboostClassifier(weightCol = "weight") + model_without_base_margin = cls_without_base_margin.fit(self.cls_df_train_without_base_margin) + pred_result_without_base_margin = model_without_base_margin.transform(self.cls_df_test_without_base_margin).collect() + for row in pred_result_without_base_margin: + self.assertTrue(np.isclose(row.prediction, + row.expected_prediction_without_base_margin, atol=1e-3)) + self.assertTrue(np.allclose(row.probability, + row.expected_prob_without_base_margin, atol=1e-3)) + + cls_with_same_base_margin = XgboostClassifier(weightCol = "weight", baseMarginCol = "baseMarginCol") + model_with_same_base_margin = cls_with_same_base_margin.fit(self.cls_df_train_with_same_base_margin) + pred_result_with_same_base_margin = model_with_same_base_margin.transform(self.cls_df_test_with_same_base_margin).collect() + for row in pred_result_with_same_base_margin: + self.assertTrue(np.isclose(row.prediction, + row.expected_prediction_with_base_margin, atol=1e-3)) + self.assertTrue(np.allclose(row.probability, + row.expected_prob_with_base_margin, atol=1e-3)) + + cls_with_different_base_margin = XgboostClassifier(weightCol = "weight", baseMarginCol = "baseMarginCol") + model_with_different_base_margin = cls_with_different_base_margin.fit(self.cls_df_train_with_different_base_margin) + pred_result_with_different_base_margin = model_with_different_base_margin.transform(self.cls_df_test_with_different_base_margin).collect() + for row in pred_result_with_different_base_margin: + self.assertTrue(np.isclose(row.prediction, + row.expected_prediction_with_base_margin, atol=1e-3)) + self.assertTrue(np.allclose(row.probability, + row.expected_prob_with_base_margin, atol=1e-3)) + + def test_regressor_with_weight_eval(self): + # with weight + regressor_with_weight = XgboostRegressor(weightCol='weight') + model_with_weight = regressor_with_weight.fit(self.reg_df_train_with_eval_weight) + pred_result_with_weight = model_with_weight \ + .transform(self.reg_df_test_with_eval_weight).collect() + for row in pred_result_with_weight: + self.assertTrue(np.isclose(row.prediction, + row.expected_prediction_with_weight, atol=1e-3)) + # with eval + regressor_with_eval = XgboostRegressor(**self.reg_params_with_eval) + model_with_eval = regressor_with_eval.fit(self.reg_df_train_with_eval_weight) + self.assertTrue(np.isclose(model_with_eval._xgb_sklearn_model.best_score, + self.reg_with_eval_best_score, atol=1e-3), + f"Expected best score: {self.reg_with_eval_best_score}, " + f"but get {model_with_eval._xgb_sklearn_model.best_score}") + pred_result_with_eval = model_with_eval \ + .transform(self.reg_df_test_with_eval_weight).collect() + for row in pred_result_with_eval: + self.assertTrue(np.isclose(row.prediction, + row.expected_prediction_with_eval, atol=1e-3), + f"Expect prediction is 
{row.expected_prediction_with_eval}," + f"but get {row.prediction}") + # with weight and eval + regressor_with_weight_eval = XgboostRegressor( + weightCol='weight', **self.reg_params_with_eval) + model_with_weight_eval = regressor_with_weight_eval.fit(self.reg_df_train_with_eval_weight) + pred_result_with_weight_eval = model_with_weight_eval \ + .transform(self.reg_df_test_with_eval_weight).collect() + self.assertTrue(np.isclose(model_with_weight_eval._xgb_sklearn_model.best_score, + self.reg_with_eval_and_weight_best_score, atol=1e-3)) + for row in pred_result_with_weight_eval: + self.assertTrue(np.isclose(row.prediction, + row.expected_prediction_with_weight_and_eval, atol=1e-3)) + + def test_classifier_with_weight_eval(self): + # with weight + classifier_with_weight = XgboostClassifier(weightCol='weight') + model_with_weight = classifier_with_weight.fit(self.cls_df_train_with_eval_weight) + pred_result_with_weight = model_with_weight \ + .transform(self.cls_df_test_with_eval_weight).collect() + for row in pred_result_with_weight: + self.assertTrue(np.allclose(row.probability, + row.expected_prob_with_weight, atol=1e-3)) + # with eval + classifier_with_eval = XgboostClassifier(**self.cls_params_with_eval) + model_with_eval = classifier_with_eval.fit(self.cls_df_train_with_eval_weight) + self.assertTrue(np.isclose(model_with_eval._xgb_sklearn_model.best_score, + self.cls_with_eval_best_score, atol=1e-3)) + pred_result_with_eval = model_with_eval \ + .transform(self.cls_df_test_with_eval_weight).collect() + for row in pred_result_with_eval: + self.assertTrue(np.allclose(row.probability, + row.expected_prob_with_eval, atol=1e-3)) + # with weight and eval + # Added scale_pos_weight because in 1.4.2, the original answer returns 0.5 which + # doesn't really indicate this working correctly. 
+ classifier_with_weight_eval = XgboostClassifier( + weightCol='weight', scale_pos_weight=4, **self.cls_params_with_eval) + model_with_weight_eval = classifier_with_weight_eval \ + .fit(self.cls_df_train_with_eval_weight) + pred_result_with_weight_eval = model_with_weight_eval \ + .transform(self.cls_df_test_with_eval_weight).collect() + self.assertTrue(np.isclose(model_with_weight_eval._xgb_sklearn_model.best_score, + self.cls_with_eval_and_weight_best_score, atol=1e-3)) + for row in pred_result_with_weight_eval: + self.assertTrue(np.allclose(row.probability, + row.expected_prob_with_weight_and_eval, atol=1e-3)) + + def test_num_workers_param(self): + regressor = XgboostRegressor(num_workers=-1) + self.assertRaises(ValueError, regressor._validate_params) + classifier = XgboostClassifier(num_workers=0) + self.assertRaises(ValueError, classifier._validate_params) + + def test_use_gpu_param(self): + classifier = XgboostClassifier(use_gpu=True, tree_method="exact") + self.assertRaises(ValueError, classifier._validate_params) + regressor = XgboostRegressor(use_gpu=True, tree_method="exact") + self.assertRaises(ValueError, regressor._validate_params) + regressor = XgboostRegressor(use_gpu=True, tree_method="gpu_hist") + regressor = XgboostRegressor(use_gpu=True) + classifier = XgboostClassifier(use_gpu=True, tree_method="gpu_hist") + classifier = XgboostClassifier(use_gpu=True) + + def test_convert_to_model(self): + classifier = XgboostClassifier() + clf_model = classifier.fit(self.cls_df_train) + + regressor = XgboostRegressor() + reg_model = regressor.fit(self.reg_df_train) + + # Check that regardless of what booster, _convert_to_model converts to the correct class type + self.assertEqual(type(classifier._convert_to_model(clf_model.get_booster())), XGBClassifier) + self.assertEqual(type(classifier._convert_to_model(reg_model.get_booster())), XGBClassifier) + self.assertEqual(type(regressor._convert_to_model(clf_model.get_booster())), XGBRegressor) + self.assertEqual(type(regressor._convert_to_model(reg_model.get_booster())), XGBRegressor) + + def test_feature_importances(self): + reg1 = XgboostRegressor(**self.reg_params) + model = reg1.fit(self.reg_df_train) + booster = model.get_booster() + self.assertEqual(model.get_feature_importances(), booster.get_score()) + self.assertEqual( + model.get_feature_importances(importance_type='gain'), + booster.get_score(importance_type='gain') + ) + From a04a1d09891c9a6272e385b69116e8d5cb43b907 Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Sun, 26 Jun 2022 21:59:22 +0800 Subject: [PATCH 02/73] fix Signed-off-by: Weichen Xu --- python-package/xgboost/spark/utils.py | 4 ++-- tests/python/test_spark/discover_gpu.sh | 3 +++ tests/python/test_spark/utils_test.py | 14 +++++++++----- .../test_spark/xgboost_local_cluster_test.py | 9 +++++---- 4 files changed, 19 insertions(+), 11 deletions(-) create mode 100755 tests/python/test_spark/discover_gpu.sh diff --git a/python-package/xgboost/spark/utils.py b/python-package/xgboost/spark/utils.py index 940fbab8d322..cc69f7682adb 100644 --- a/python-package/xgboost/spark/utils.py +++ b/python-package/xgboost/spark/utils.py @@ -87,8 +87,8 @@ def _start_tracker(context, n_workers): """ env = {'DMLC_NUM_WORKER': n_workers} host = get_host_ip(context) - rabit_context = RabitTracker(hostIP=host, nslave=n_workers) - env.update(rabit_context.slave_envs()) + rabit_context = RabitTracker(host_ip=host, n_workers=n_workers) + env.update(rabit_context.worker_envs()) rabit_context.start(n_workers) thread = 
Thread(target=rabit_context.join) thread.daemon = True diff --git a/tests/python/test_spark/discover_gpu.sh b/tests/python/test_spark/discover_gpu.sh new file mode 100755 index 000000000000..42dd0551784d --- /dev/null +++ b/tests/python/test_spark/discover_gpu.sh @@ -0,0 +1,3 @@ +#!/bin/bash + +echo "{\"name\":\"gpu\",\"addresses\":[\"0\",\"1\",\"2\",\"3\"]}" diff --git a/tests/python/test_spark/utils_test.py b/tests/python/test_spark/utils_test.py index ca94c99a4068..dd1eaa60dd73 100644 --- a/tests/python/test_spark/utils_test.py +++ b/tests/python/test_spark/utils_test.py @@ -80,10 +80,6 @@ def setup_env(cls, spark_config): for k, v in spark_config.items(): builder.config(k, v) spark = builder.getOrCreate() - if spark_config['spark.master'].startswith('local-cluster'): - # We run a dummy job so that we block until the workers have connected to the master - spark.sparkContext.parallelize(range(2), 2).barrier().mapPartitions(lambda _: []).collect() - logging.getLogger('pyspark').setLevel(logging.INFO) cls.sc = spark.sparkContext @@ -118,10 +114,18 @@ class SparkLocalClusterTestCase(TestSparkContext, TestTempDir, unittest.TestCase @classmethod def setUpClass(cls): cls.setup_env({ - 'spark.master': 'local-cluster[2, 1, 1024]', + 'spark.master': 'local-cluster[2, 2, 1024]', 'spark.python.worker.reuse': 'false', + 'spark.cores.max': '4', + 'spark.task.cpus': '1', + 'spark.executor.cores': '2', + 'spark.worker.resource.gpu.amount': '4', + 'spark.task.resource.gpu.amount': '2', + 'spark.executor.resource.gpu.amount': '4', + 'spark.worker.resource.gpu.discoveryScript': 'test_spark/discover_gpu.sh' }) cls.make_tempdir() + # We run a dummy job so that we block until the workers have connected to the master cls.sc.parallelize(range(4), 4).barrier().mapPartitions(lambda _: []).collect() @classmethod diff --git a/tests/python/test_spark/xgboost_local_cluster_test.py b/tests/python/test_spark/xgboost_local_cluster_test.py index aee638fddb87..4a8810fcc1f8 100644 --- a/tests/python/test_spark/xgboost_local_cluster_test.py +++ b/tests/python/test_spark/xgboost_local_cluster_test.py @@ -1,4 +1,5 @@ import random +import unittest import numpy as np from pyspark.ml.linalg import Vectors @@ -207,13 +208,14 @@ def test_regressor_distributed_external_storage_basic(self): self.assertTrue(np.isclose(row.expected_label, row.prediction, atol=1e-3)) - def check_use_gpu_param(self): + @unittest.skip + def test_check_use_gpu_param(self): # Classifier classifier = XgboostClassifier(num_workers=self.n_workers, n_estimators=100, use_gpu=True, use_external_storage=False) self.assertTrue(hasattr(classifier, 'use_gpu')) self.assertTrue(classifier.getOrDefault(classifier.use_gpu)) clf_model = classifier.fit(self.cls_df_train_distributed) - pred_result = model.transform(self.cls_df_test_distributed).collect() + pred_result = clf_model.transform(self.cls_df_test_distributed).collect() for row in pred_result: self.assertTrue(np.isclose(row.expected_label, row.prediction, atol=1e-3)) @@ -225,8 +227,7 @@ def check_use_gpu_param(self): model = regressor.fit(self.reg_df_train_distributed) pred_result = model.transform(self.reg_df_test_distributed).collect() for row in pred_result: - self.assertTrue(np.isclose(row.expected_label, - row.prediction, atol=1e-3)) + self.assertTrue(np.isclose(row.expected_label, row.prediction, atol=1e-3)) def test_classifier_distributed_weight_eval(self): # with weight From d2dbb8dbe337aa25fd1a5aa96c55376c75c4ceb5 Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Sun, 26 Jun 2022 22:08:26 +0800 Subject: 
[PATCH 03/73] clean Signed-off-by: Weichen Xu --- python-package/xgboost/spark/core.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python-package/xgboost/spark/core.py b/python-package/xgboost/spark/core.py index 48d4c75a4212..7d38dca348fe 100644 --- a/python-package/xgboost/spark/core.py +++ b/python-package/xgboost/spark/core.py @@ -37,7 +37,7 @@ ] _unsupported_xgb_params = [ - 'gpu_id', # [ML-12862] + 'gpu_id', # we have "use_gpu" pyspark param instead. ] _unsupported_fit_params = { 'sample_weight', # Supported by spark param weightCol @@ -47,10 +47,10 @@ 'base_margin' # Supported by spark param baseMarginCol } _unsupported_predict_params = { - # [ML-12913], for classification, we can use rawPrediction as margin + # for classification, we can use rawPrediction as margin 'output_margin', - 'validate_features', # [ML-12923] - 'base_margin' # [ML-12689] + 'validate_features', # TODO + 'base_margin' # TODO } _created_params = {"num_workers", "use_gpu"} From 34827efae8f7083420dcbed0d0b2aee3fd6c85c3 Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Sun, 26 Jun 2022 22:32:35 +0800 Subject: [PATCH 04/73] remove external mode Signed-off-by: Weichen Xu --- python-package/xgboost/spark/core.py | 30 +----- python-package/xgboost/spark/data.py | 25 +---- python-package/xgboost/spark/estimator.py | 8 -- tests/python/test_spark/data_test.py | 12 +-- .../test_spark/xgboost_local_cluster_test.py | 102 +++--------------- 5 files changed, 24 insertions(+), 153 deletions(-) diff --git a/python-package/xgboost/spark/core.py b/python-package/xgboost/spark/core.py index 7d38dca348fe..6a539a1dda7a 100644 --- a/python-package/xgboost/spark/core.py +++ b/python-package/xgboost/spark/core.py @@ -74,19 +74,6 @@ class _XgboostParams(HasFeaturesCol, HasLabelCol, HasWeightCol, "want to force the input dataset to be repartitioned before XGBoost training." + "Note: The auto repartitioning judgement is not fully accurate, so it is recommended" + "to have force_repartition be True.") - use_external_storage = Param( - Params._dummy(), "use_external_storage", - "A boolean variable (that is False by default). External storage is a parameter" + - "for distributed training that allows external storage (disk) to be used when." + - "you have an exceptionally large dataset. This should be set to false for" + - "small datasets. Note that base margin and weighting doesn't work if this is True." + - "Also note that you may use precision if you use external storage." 
- ) - external_storage_precision = Param( - Params._dummy(), "external_storage_precision", - "The number of significant digits for data storage on disk when using external storage.", - TypeConverters.toInt - ) @classmethod def _xgb_cls(cls): @@ -137,8 +124,6 @@ def _set_distributed_params(self): self.set(self.num_workers, 1) self.set(self.use_gpu, False) self.set(self.force_repartition, False) - self.set(self.use_external_storage, False) - self.set(self.external_storage_precision, 5) # Check if this needs to be modified # Parameters for xgboost.XGBModel().fit() @classmethod @@ -386,21 +371,16 @@ def _train_booster(pandas_df_iter): from pyspark import BarrierTaskContext context = BarrierTaskContext.get() - use_external_storage = self.getOrDefault(self.use_external_storage) - external_storage_precision = self.getOrDefault(self.external_storage_precision) - external_storage_path_prefix = None - if use_external_storage: - external_storage_path_prefix = tempfile.mkdtemp() dtrain, dval = None, [] if has_validation: dtrain, dval = convert_partition_data_to_dmatrix( - pandas_df_iter, has_weight, has_validation, - use_external_storage, external_storage_path_prefix, external_storage_precision) + pandas_df_iter, has_weight, has_validation + ) dval = [(dtrain, "training"), (dval, "validation")] else: dtrain = convert_partition_data_to_dmatrix( - pandas_df_iter, has_weight, has_validation, - use_external_storage, external_storage_path_prefix, external_storage_precision) + pandas_df_iter, has_weight, has_validation + ) booster_params, kwargs_params = self._get_dist_booster_params( train_params) @@ -420,8 +400,6 @@ def _train_booster(pandas_df_iter): **kwargs_params) context.barrier() - if use_external_storage: - shutil.rmtree(external_storage_path_prefix) if context.partitionId() == 0: yield pd.DataFrame( data={'booster_bytes': [cloudpickle.dumps(booster)]}) diff --git a/python-package/xgboost/spark/data.py b/python-package/xgboost/spark/data.py index b0ca24c7e36c..ae8e5c25309f 100644 --- a/python-package/xgboost/spark/data.py +++ b/python-package/xgboost/spark/data.py @@ -164,6 +164,7 @@ def _row_tuple_list_to_feature_matrix_y_w(data_iterator, train, has_weight, expected_feature_dims = _check_feature_dims(num_feature_dims, expected_feature_dims) + # TODO: Improve performance, avoid use python list values_list.append(pdf["values"].to_list()) if train: label_list.append(pdf["label"].to_list()) @@ -226,29 +227,7 @@ def _process_data_iter(data_iterator: Iterator[pd.DataFrame], def convert_partition_data_to_dmatrix(partition_data_iter, has_weight, - has_validation, - use_external_storage=False, - file_prefix=None, - external_storage_precision=5): - # if we are using external storage, we use a different approach for making the dmatrix - if use_external_storage: - if has_validation: - train_file, validation_file = _stream_data_into_libsvm_file( - partition_data_iter, has_weight, - has_validation, file_prefix, external_storage_precision) - training_dmatrix = _create_dmatrix_from_file( - train_file, "{}/train.cache".format(file_prefix)) - val_dmatrix = _create_dmatrix_from_file( - validation_file, "{}/val.cache".format(file_prefix)) - return training_dmatrix, val_dmatrix - else: - train_file = _stream_data_into_libsvm_file( - partition_data_iter, has_weight, - has_validation, file_prefix, external_storage_precision) - training_dmatrix = _create_dmatrix_from_file( - train_file, "{}/train.cache".format(file_prefix)) - return training_dmatrix - + has_validation): # if we are not using external storage, we use the 
standard method of parsing data. train_val_data = prepare_train_val_data(partition_data_iter, has_weight, has_validation) if has_validation: diff --git a/python-package/xgboost/spark/estimator.py b/python-package/xgboost/spark/estimator.py index d54903e69fec..509b5f7c3fa7 100644 --- a/python-package/xgboost/spark/estimator.py +++ b/python-package/xgboost/spark/estimator.py @@ -43,10 +43,6 @@ class XgboostRegressor(_XgboostEstimator): Each XGBoost worker corresponds to one spark task. :param use_gpu: Boolean that specifies whether the executors are running on GPU instances. - :param use_external_storage: Boolean that specifices whether you want to use - external storage when training in a distributed manner. This allows using disk - as cache. Setting this to true is useful when you want better memory utilization - but is not needed for small test datasets. :param baseMarginCol: To specify the base margins of the training and validation dataset, set :py:attr:`xgboost.spark.XgboostRegressor.baseMarginCol` parameter instead of setting `base_margin` and `base_margin_eval_set` in the @@ -141,10 +137,6 @@ class XgboostClassifier(_XgboostEstimator, HasProbabilityCol, Each XGBoost worker corresponds to one spark task. :param use_gpu: Boolean that specifies whether the executors are running on GPU instances. - :param use_external_storage: Boolean that specifices whether you want to use - external storage when training in a distributed manner. This allows using disk - as cache. Setting this to true is useful when you want better memory utilization - but is not needed for small test datasets. :param baseMarginCol: To specify the base margins of the training and validation dataset, set :py:attr:`xgboost.spark.XgboostClassifier.baseMarginCol` parameter instead of setting `base_margin` and `base_margin_eval_set` in the diff --git a/tests/python/test_spark/data_test.py b/tests/python/test_spark/data_test.py index 014a394eed21..29b0108ae9ba 100644 --- a/tests/python/test_spark/data_test.py +++ b/tests/python/test_spark/data_test.py @@ -92,9 +92,9 @@ def test_external_storage(self): # Creating the dmatrix based on storage temporary_path = tempfile.mkdtemp() - storage_dmatrix = convert_partition_data_to_dmatrix([pd.DataFrame(data)], has_weight=False, - has_validation=False, use_external_storage=True, - file_prefix=temporary_path) + storage_dmatrix = convert_partition_data_to_dmatrix( + [pd.DataFrame(data)], has_weight=False, has_validation=False + ) # Testing without weights normal_booster = worker_train({}, normal_dmatrix) @@ -110,9 +110,9 @@ def test_external_storage(self): data["weight"] = [0.2, 0.8] * 100 temporary_path = tempfile.mkdtemp() - storage_dmatrix = convert_partition_data_to_dmatrix([pd.DataFrame(data)], has_weight=True, - has_validation=False, use_external_storage=True, - file_prefix=temporary_path) + storage_dmatrix = convert_partition_data_to_dmatrix( + [pd.DataFrame(data)], has_weight=True, has_validation=False + ) normal_booster = worker_train({}, normal_dmatrix) storage_booster = worker_train({}, storage_dmatrix) diff --git a/tests/python/test_spark/xgboost_local_cluster_test.py b/tests/python/test_spark/xgboost_local_cluster_test.py index 4a8810fcc1f8..fb25f3bbf3a3 100644 --- a/tests/python/test_spark/xgboost_local_cluster_test.py +++ b/tests/python/test_spark/xgboost_local_cluster_test.py @@ -165,16 +165,7 @@ def custom_learning_rate(boosting_round): ) def test_classifier_distributed_basic(self): - classifier = XgboostClassifier(num_workers=self.n_workers, n_estimators=100, 
use_external_storage=False) - model = classifier.fit(self.cls_df_train_distributed) - pred_result = model.transform(self.cls_df_test_distributed).collect() - for row in pred_result: - self.assertTrue(np.isclose(row.expected_label, - row.prediction, atol=1e-3)) - self.assertTrue(np.allclose(row.expected_probability, row.probability, atol=1e-3)) - - def test_classifier_distributed_external_storage_basic(self): - classifier = XgboostClassifier(num_workers=self.n_workers, n_estimators=100, use_external_storage=True) + classifier = XgboostClassifier(num_workers=self.n_workers, n_estimators=100) model = classifier.fit(self.cls_df_train_distributed) pred_result = model.transform(self.cls_df_test_distributed).collect() for row in pred_result: @@ -184,7 +175,7 @@ def test_classifier_distributed_external_storage_basic(self): def test_classifier_distributed_multiclass(self): # There is no built-in multiclass option for external storage - classifier = XgboostClassifier(num_workers=self.n_workers, n_estimators=100, use_external_storage=False) + classifier = XgboostClassifier(num_workers=self.n_workers, n_estimators=100) model = classifier.fit(self.cls_df_train_distributed_multiclass) pred_result = model.transform(self.cls_df_test_distributed_multiclass).collect() for row in pred_result: @@ -193,25 +184,16 @@ def test_classifier_distributed_multiclass(self): self.assertTrue(np.allclose(row.expected_margins, row.rawPrediction, atol=1e-3)) def test_regressor_distributed_basic(self): - regressor = XgboostRegressor(num_workers=self.n_workers, n_estimators=100, use_external_storage=False) - model = regressor.fit(self.reg_df_train_distributed) - pred_result = model.transform(self.reg_df_test_distributed).collect() - for row in pred_result: - self.assertTrue(np.isclose(row.expected_label, - row.prediction, atol=1e-3)) - - def test_regressor_distributed_external_storage_basic(self): - regressor = XgboostRegressor(num_workers=self.n_workers, n_estimators=100, use_external_storage=True) + regressor = XgboostRegressor(num_workers=self.n_workers, n_estimators=100) model = regressor.fit(self.reg_df_train_distributed) pred_result = model.transform(self.reg_df_test_distributed).collect() for row in pred_result: - self.assertTrue(np.isclose(row.expected_label, - row.prediction, atol=1e-3)) + self.assertTrue(np.isclose(row.expected_label, row.prediction, atol=1e-3)) @unittest.skip def test_check_use_gpu_param(self): # Classifier - classifier = XgboostClassifier(num_workers=self.n_workers, n_estimators=100, use_gpu=True, use_external_storage=False) + classifier = XgboostClassifier(num_workers=self.n_workers, n_estimators=100, use_gpu=True) self.assertTrue(hasattr(classifier, 'use_gpu')) self.assertTrue(classifier.getOrDefault(classifier.use_gpu)) clf_model = classifier.fit(self.cls_df_train_distributed) @@ -221,7 +203,7 @@ def test_check_use_gpu_param(self): row.prediction, atol=1e-3)) self.assertTrue(np.allclose(row.expected_probability, row.probability, atol=1e-3)) - regressor = XgboostRegressor(num_workers=self.n_workers, n_estimators=100, use_gpu=True, use_external_storage=False) + regressor = XgboostRegressor(num_workers=self.n_workers, n_estimators=100, use_gpu=True) self.assertTrue(hasattr(regressor, 'use_gpu')) self.assertTrue(regressor.getOrDefault(regressor.use_gpu)) model = regressor.fit(self.reg_df_train_distributed) @@ -231,34 +213,7 @@ def test_check_use_gpu_param(self): def test_classifier_distributed_weight_eval(self): # with weight - classifier = XgboostClassifier(num_workers=self.n_workers, 
n_estimators=100, use_external_storage=False, **self.clf_params_with_weight_dist) - model = classifier.fit(self.cls_df_train_distributed_with_eval_weight) - pred_result = model.transform(self.cls_df_test_distributed_with_eval_weight).collect() - for row in pred_result: - self.assertTrue(np.allclose(row.probability, - row.expected_prob_with_weight, atol=1e-3)) - - # with eval only - classifier = XgboostClassifier(num_workers=self.n_workers, n_estimators=100, use_external_storage=False, **self.clf_params_with_eval_dist) - model = classifier.fit(self.cls_df_train_distributed_with_eval_weight) - pred_result = model.transform(self.cls_df_test_distributed_with_eval_weight).collect() - for row in pred_result: - self.assertTrue(np.allclose(row.probability, - row.expected_prob_with_eval, atol=1e-3)) - self.assertEqual(float(model.get_booster().attributes()["best_score"]), self.clf_best_score_eval) - - # with both weight and eval - classifier = XgboostClassifier(num_workers=self.n_workers, n_estimators=100, use_external_storage=False, **self.clf_params_with_eval_dist, **self.clf_params_with_weight_dist) - model = classifier.fit(self.cls_df_train_distributed_with_eval_weight) - pred_result = model.transform(self.cls_df_test_distributed_with_eval_weight).collect() - for row in pred_result: - self.assertTrue(np.allclose(row.probability, - row.expected_prob_with_weight_and_eval, atol=1e-3)) - self.assertEqual(float(model.get_booster().attributes()["best_score"]), self.clf_best_score_weight_and_eval) - - def test_classifier_distributed_weight_eval_external_storage(self): - # with weight - classifier = XgboostClassifier(num_workers=self.n_workers, n_estimators=100, use_external_storage=True, **self.clf_params_with_weight_dist) + classifier = XgboostClassifier(num_workers=self.n_workers, n_estimators=100, **self.clf_params_with_weight_dist) model = classifier.fit(self.cls_df_train_distributed_with_eval_weight) pred_result = model.transform(self.cls_df_test_distributed_with_eval_weight).collect() for row in pred_result: @@ -266,7 +221,7 @@ def test_classifier_distributed_weight_eval_external_storage(self): row.expected_prob_with_weight, atol=1e-3)) # with eval only - classifier = XgboostClassifier(num_workers=self.n_workers, n_estimators=100, use_external_storage=True, **self.clf_params_with_eval_dist) + classifier = XgboostClassifier(num_workers=self.n_workers, n_estimators=100, **self.clf_params_with_eval_dist) model = classifier.fit(self.cls_df_train_distributed_with_eval_weight) pred_result = model.transform(self.cls_df_test_distributed_with_eval_weight).collect() for row in pred_result: @@ -275,7 +230,7 @@ def test_classifier_distributed_weight_eval_external_storage(self): self.assertEqual(float(model.get_booster().attributes()["best_score"]), self.clf_best_score_eval) # with both weight and eval - classifier = XgboostClassifier(num_workers=self.n_workers, n_estimators=100, use_external_storage=True, **self.clf_params_with_eval_dist, **self.clf_params_with_weight_dist) + classifier = XgboostClassifier(num_workers=self.n_workers, n_estimators=100, **self.clf_params_with_eval_dist, **self.clf_params_with_weight_dist) model = classifier.fit(self.cls_df_train_distributed_with_eval_weight) pred_result = model.transform(self.cls_df_test_distributed_with_eval_weight).collect() for row in pred_result: @@ -285,7 +240,7 @@ def test_classifier_distributed_weight_eval_external_storage(self): def test_regressor_distributed_weight_eval(self): # with weight - regressor = XgboostRegressor(num_workers=self.n_workers, 
n_estimators=100, use_external_storage=False, **self.reg_params_with_weight_dist) + regressor = XgboostRegressor(num_workers=self.n_workers, n_estimators=100, **self.reg_params_with_weight_dist) model = regressor.fit(self.reg_df_train_distributed_with_eval_weight) pred_result = model.transform(self.reg_df_test_distributed_with_eval_weight).collect() for row in pred_result: @@ -293,7 +248,7 @@ def test_regressor_distributed_weight_eval(self): np.isclose(row.prediction, row.expected_prediction_with_weight, atol=1e-3)) # with eval only - regressor = XgboostRegressor(num_workers=self.n_workers, n_estimators=100, use_external_storage=False, **self.reg_params_with_eval_dist) + regressor = XgboostRegressor(num_workers=self.n_workers, n_estimators=100, **self.reg_params_with_eval_dist) model = regressor.fit(self.reg_df_train_distributed_with_eval_weight) pred_result = model.transform(self.reg_df_test_distributed_with_eval_weight).collect() for row in pred_result: @@ -311,36 +266,8 @@ def test_regressor_distributed_weight_eval(self): row.expected_prediction_with_weight_and_eval, atol=1e-3)) self.assertEqual(float(model.get_booster().attributes()["best_score"]), self.reg_best_score_weight_and_eval) - def test_regressor_distributed_weight_eval_external_storage(self): - # with weight - regressor = XgboostRegressor(num_workers=self.n_workers, n_estimators=100, use_external_storage=True, **self.reg_params_with_weight_dist) - model = regressor.fit(self.reg_df_train_distributed_with_eval_weight) - pred_result = model.transform(self.reg_df_test_distributed_with_eval_weight).collect() - for row in pred_result: - self.assertTrue( - np.isclose(row.prediction, - row.expected_prediction_with_weight, atol=1e-3)) - # with eval only - regressor = XgboostRegressor(num_workers=self.n_workers, n_estimators=100, use_external_storage=True, **self.reg_params_with_eval_dist) - model = regressor.fit(self.reg_df_train_distributed_with_eval_weight) - pred_result = model.transform(self.reg_df_test_distributed_with_eval_weight).collect() - for row in pred_result: - self.assertTrue( - np.isclose(row.prediction, - row.expected_prediction_with_eval, atol=1e-3)) - self.assertEqual(float(model.get_booster().attributes()["best_score"]), self.reg_best_score_eval) - # with both weight and eval - regressor = XgboostRegressor(num_workers=self.n_workers, n_estimators=100, use_external_storage=True, **self.reg_params_with_eval_dist, **self.reg_params_with_weight_dist) - model = regressor.fit(self.reg_df_train_distributed_with_eval_weight) - pred_result = model.transform(self.reg_df_test_distributed_with_eval_weight).collect() - for row in pred_result: - self.assertTrue( - np.isclose(row.prediction, - row.expected_prediction_with_weight_and_eval, atol=1e-3)) - self.assertEqual(float(model.get_booster().attributes()["best_score"]), self.reg_best_score_weight_and_eval) - def test_num_estimators(self): - classifier = XgboostClassifier(num_workers=self.n_workers, n_estimators=10, use_external_storage=False) + classifier = XgboostClassifier(num_workers=self.n_workers, n_estimators=10) model = classifier.fit(self.cls_df_train_distributed) pred_result = model.transform(self.cls_df_test_distributed_lower_estimators).collect() print(pred_result) @@ -349,11 +276,6 @@ def test_num_estimators(self): row.prediction, atol=1e-3)) self.assertTrue(np.allclose(row.expected_probability, row.probability, atol=1e-3)) - def test_missing_value_zero_with_external_storage(self): - classifier = XgboostClassifier(num_workers=self.n_workers, n_estimators=10, 
use_external_storage=False, - missing=0.0) - classifier.fit(self.cls_df_train_distributed) - def test_distributed_params(self): classifier = XgboostClassifier(num_workers=self.n_workers, max_depth=7) model = classifier.fit(self.cls_df_train_distributed) From 50ebb1f7b8a65e8a1eedc2ab1665398c801d7ad7 Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Sun, 26 Jun 2022 22:41:48 +0800 Subject: [PATCH 05/73] update doc style Signed-off-by: Weichen Xu --- python-package/xgboost/spark/estimator.py | 53 +++++++++++++++-------- python-package/xgboost/spark/model.py | 23 +++++++--- python-package/xgboost/spark/utils.py | 6 ++- 3 files changed, 59 insertions(+), 23 deletions(-) diff --git a/python-package/xgboost/spark/estimator.py b/python-package/xgboost/spark/estimator.py index 509b5f7c3fa7..e8c8bbd5350b 100644 --- a/python-package/xgboost/spark/estimator.py +++ b/python-package/xgboost/spark/estimator.py @@ -23,27 +23,35 @@ class XgboostRegressor(_XgboostEstimator): XgboostRegressor doesn't support `validate_features` and `output_margin` param. - :param callbacks: The export and import of the callback functions are at best effort. + callbacks: + The export and import of the callback functions are at best effort. For details, see :py:attr:`xgboost.spark.XgboostRegressor.callbacks` param doc. - :param missing: The parameter `missing` in XgboostRegressor has different semantics with + missing: + The parameter `missing` in XgboostRegressor has different semantics with that in `xgboost.XGBRegressor`. For details, see :py:attr:`xgboost.spark.XgboostRegressor.missing` param doc. - :param validationIndicatorCol: For params related to `xgboost.XGBRegressor` training + validationIndicatorCol + For params related to `xgboost.XGBRegressor` training with evaluation dataset's supervision, set :py:attr:`xgboost.spark.XgboostRegressor.validationIndicatorCol` parameter instead of setting the `eval_set` parameter in `xgboost.XGBRegressor` fit method. - :param weightCol: To specify the weight of the training and validation dataset, set + weightCol: + To specify the weight of the training and validation dataset, set :py:attr:`xgboost.spark.XgboostRegressor.weightCol` parameter instead of setting `sample_weight` and `sample_weight_eval_set` parameter in `xgboost.XGBRegressor` fit method. - :param xgb_model: Set the value to be the instance returned by + xgb_model: + Set the value to be the instance returned by :func:`xgboost.spark.XgboostRegressorModel.get_booster`. - :param num_workers: Integer that specifies the number of XGBoost workers to use. + num_workers: + Integer that specifies the number of XGBoost workers to use. Each XGBoost worker corresponds to one spark task. - :param use_gpu: Boolean that specifies whether the executors are running on GPU + use_gpu: + Boolean that specifies whether the executors are running on GPU instances. - :param baseMarginCol: To specify the base margins of the training and validation + baseMarginCol: + To specify the base margins of the training and validation dataset, set :py:attr:`xgboost.spark.XgboostRegressor.baseMarginCol` parameter instead of setting `base_margin` and `base_margin_eval_set` in the `xgboost.XGBRegressor` fit method. Note: this isn't available for distributed @@ -114,30 +122,41 @@ class XgboostClassifier(_XgboostEstimator, HasProbabilityCol, XgboostClassifier doesn't support `validate_features` and `output_margin` param. - :param callbacks: The export and import of the callback functions are at best effort. 
For + Parameters + ---------- + callbacks: + The export and import of the callback functions are at best effort. For details, see :py:attr:`xgboost.spark.XgboostClassifier.callbacks` param doc. - :param missing: The parameter `missing` in XgboostClassifier has different semantics with + missing: + The parameter `missing` in XgboostClassifier has different semantics with that in `xgboost.XGBClassifier`. For details, see :py:attr:`xgboost.spark.XgboostClassifier.missing` param doc. - :param rawPredictionCol: The `output_margin=True` is implicitly supported by the + rawPredictionCol: + The `output_margin=True` is implicitly supported by the `rawPredictionCol` output column, which is always returned with the predicted margin values. - :param validationIndicatorCol: For params related to `xgboost.XGBClassifier` training with + validationIndicatorCol: + For params related to `xgboost.XGBClassifier` training with evaluation dataset's supervision, set :py:attr:`xgboost.spark.XgboostClassifier.validationIndicatorCol` parameter instead of setting the `eval_set` parameter in `xgboost.XGBClassifier` fit method. - :param weightCol: To specify the weight of the training and validation dataset, set + weightCol: + To specify the weight of the training and validation dataset, set :py:attr:`xgboost.spark.XgboostClassifier.weightCol` parameter instead of setting `sample_weight` and `sample_weight_eval_set` parameter in `xgboost.XGBClassifier` fit method. - :param xgb_model: Set the value to be the instance returned by + xgb_model: + Set the value to be the instance returned by :func:`xgboost.spark.XgboostClassifierModel.get_booster`. - :param num_workers: Integer that specifies the number of XGBoost workers to use. + num_workers: + Integer that specifies the number of XGBoost workers to use. Each XGBoost worker corresponds to one spark task. - :param use_gpu: Boolean that specifies whether the executors are running on GPU + use_gpu: + Boolean that specifies whether the executors are running on GPU instances. - :param baseMarginCol: To specify the base margins of the training and validation + baseMarginCol: + To specify the base margins of the training and validation dataset, set :py:attr:`xgboost.spark.XgboostClassifier.baseMarginCol` parameter instead of setting `base_margin` and `base_margin_eval_set` in the `xgboost.XGBClassifier` fit method. Note: this isn't available for distributed diff --git a/python-package/xgboost/spark/model.py b/python-package/xgboost/spark/model.py index 1edb44635490..31994061b146 100644 --- a/python-package/xgboost/spark/model.py +++ b/python-package/xgboost/spark/model.py @@ -17,8 +17,13 @@ def get_xgb_model_creator(model_cls, xgb_params): Returns a function that can be used to create an xgboost.XGBModel instance. This function is used for creating the model instance on the worker, and is shared by _XgboostEstimator and XgboostModel. - :param model_cls: a subclass of xgboost.XGBModel - :param xgb_params: a dict of params to initialize the model_cls + + Parameters + ---------- + model_cls: + a subclass of xgboost.XGBModel + xgb_params: + a dict of params to initialize the model_cls """ return lambda: model_cls(**xgb_params) # pylint: disable=W0108 @@ -34,8 +39,12 @@ def _get_or_create_tmp_dir(): def serialize_xgb_model(model): """ Serialize the input model to a string. 
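# A minimal sketch (outside the diff) of the serialize/deserialize round trip these
# helpers document, assuming the module layout introduced by this patch series
# (xgboost.spark.model) and that the helpers behave as their docstrings describe.
import numpy as np
from xgboost import XGBRegressor
from xgboost.spark.model import serialize_xgb_model, deserialize_xgb_model

X, y = np.random.rand(64, 4), np.random.rand(64)
fitted = XGBRegressor(n_estimators=5).fit(X, y)

ser = serialize_xgb_model(fitted)                            # model text written via a temp file
clone = deserialize_xgb_model(ser, lambda: XGBRegressor())   # creator builds an empty model to load into
assert np.allclose(fitted.predict(X), clone.predict(X))      # round trip preserves predictions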
- :param model: an xgboost.XGBModel instance, - such as xgboost.XGBClassifier or xgboost.XGBRegressor instance + + Parameters + ---------- + model: + an xgboost.XGBModel instance, such as + xgboost.XGBClassifier or xgboost.XGBRegressor instance """ # TODO: change to use string io tmp_file_name = os.path.join(_get_or_create_tmp_dir(), f'{uuid.uuid4()}.json') @@ -61,7 +70,11 @@ def deserialize_xgb_model(ser_model_string, xgb_model_creator): def serialize_booster(booster): """ Serialize the input booster to a string. - :param booster: an xgboost.core.Booster instance + + Parameters + ---------- + booster: + an xgboost.core.Booster instance """ # TODO: change to use string io tmp_file_name = os.path.join(_get_or_create_tmp_dir(), f'{uuid.uuid4()}.json') diff --git a/python-package/xgboost/spark/utils.py b/python-package/xgboost/spark/utils.py index cc69f7682adb..6f4be9ec838b 100644 --- a/python-package/xgboost/spark/utils.py +++ b/python-package/xgboost/spark/utils.py @@ -144,7 +144,11 @@ def _getConfBoolean(sqlContext, key, defaultValue): or return the default value if the conf is not set. This expects the conf value to be a boolean or string; if the value is a string, this checks for all capitalization patterns of "true" and "false" to match Scala. - :param key: string for conf name + + Parameters + ---------- + key: + string for conf name """ # Convert default value to str to avoid a Spark 2.3.1 + Python 3 bug: SPARK-25397 val = sqlContext.getConf(key, str(defaultValue)) From f386ee18b374963f8e8c08bea8194aaf881bf44f Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Sun, 26 Jun 2022 22:43:30 +0800 Subject: [PATCH 06/73] black Signed-off-by: Weichen Xu --- python-package/xgboost/spark/__init__.py | 17 +- python-package/xgboost/spark/core.py | 499 ++++++++------ python-package/xgboost/spark/data.py | 163 +++-- python-package/xgboost/spark/estimator.py | 13 +- python-package/xgboost/spark/model.py | 106 +-- python-package/xgboost/spark/utils.py | 59 +- tests/python/test_spark/data_test.py | 75 ++- tests/python/test_spark/utils_test.py | 51 +- .../test_spark/xgboost_local_cluster_test.py | 382 +++++++---- tests/python/test_spark/xgboost_local_test.py | 626 ++++++++++++------ 10 files changed, 1302 insertions(+), 689 deletions(-) diff --git a/python-package/xgboost/spark/__init__.py b/python-package/xgboost/spark/__init__.py index 6e58401aa34f..06e3499e65e6 100644 --- a/python-package/xgboost/spark/__init__.py +++ b/python-package/xgboost/spark/__init__.py @@ -8,9 +8,16 @@ except ImportError: raise RuntimeError("xgboost spark python API requires pyspark package installed.") -from .estimator import (XgboostClassifier, XgboostClassifierModel, - XgboostRegressor, XgboostRegressorModel) - -__all__ = ['XgboostClassifier', 'XgboostClassifierModel', - 'XgboostRegressor', 'XgboostRegressorModel'] +from .estimator import ( + XgboostClassifier, + XgboostClassifierModel, + XgboostRegressor, + XgboostRegressorModel, +) +__all__ = [ + "XgboostClassifier", + "XgboostClassifierModel", + "XgboostRegressor", + "XgboostRegressorModel", +] diff --git a/python-package/xgboost/spark/core.py b/python-package/xgboost/spark/core.py index 6a539a1dda7a..47f0d4e34904 100644 --- a/python-package/xgboost/spark/core.py +++ b/python-package/xgboost/spark/core.py @@ -5,8 +5,15 @@ import pandas as pd from scipy.special import expit, softmax from pyspark.ml import Estimator, Model -from pyspark.ml.param.shared import HasFeaturesCol, HasLabelCol, HasWeightCol, \ - HasPredictionCol, HasProbabilityCol, HasRawPredictionCol, 
HasValidationIndicatorCol +from pyspark.ml.param.shared import ( + HasFeaturesCol, + HasLabelCol, + HasWeightCol, + HasPredictionCol, + HasProbabilityCol, + HasRawPredictionCol, + HasValidationIndicatorCol, +) from pyspark.ml.param import Param, Params, TypeConverters from pyspark.ml.util import MLReadable, MLWritable from pyspark.sql.functions import col, pandas_udf, countDistinct, struct @@ -17,63 +24,94 @@ import xgboost from xgboost.training import train as worker_train from .utils import get_logger, _get_max_num_concurrent_tasks -from .data import prepare_predict_data, prepare_train_val_data, convert_partition_data_to_dmatrix -from .model import (XgboostReader, XgboostWriter, XgboostModelReader, - XgboostModelWriter, deserialize_xgb_model, - get_xgb_model_creator, serialize_xgb_model) -from .utils import (_get_default_params_from_func, get_class_name, - HasArbitraryParamsDict, HasBaseMarginCol, RabitContext, - _get_rabit_args, _get_args_from_message_list, - _get_spark_session) +from .data import ( + prepare_predict_data, + prepare_train_val_data, + convert_partition_data_to_dmatrix, +) +from .model import ( + XgboostReader, + XgboostWriter, + XgboostModelReader, + XgboostModelWriter, + deserialize_xgb_model, + get_xgb_model_creator, + serialize_xgb_model, +) +from .utils import ( + _get_default_params_from_func, + get_class_name, + HasArbitraryParamsDict, + HasBaseMarginCol, + RabitContext, + _get_rabit_args, + _get_args_from_message_list, + _get_spark_session, +) from pyspark.ml.functions import array_to_vector, vector_to_array # Put pyspark specific params here, they won't be passed to XGBoost. # like `validationIndicatorCol`, `baseMarginCol` _pyspark_specific_params = [ - 'featuresCol', 'labelCol', 'weightCol', 'rawPredictionCol', - 'predictionCol', 'probabilityCol', 'validationIndicatorCol' - 'baseMarginCol' + "featuresCol", + "labelCol", + "weightCol", + "rawPredictionCol", + "predictionCol", + "probabilityCol", + "validationIndicatorCol" "baseMarginCol", ] _unsupported_xgb_params = [ - 'gpu_id', # we have "use_gpu" pyspark param instead. + "gpu_id", # we have "use_gpu" pyspark param instead. ] _unsupported_fit_params = { - 'sample_weight', # Supported by spark param weightCol + "sample_weight", # Supported by spark param weightCol # Supported by spark param weightCol # and validationIndicatorCol - 'eval_set', - 'sample_weight_eval_set', - 'base_margin' # Supported by spark param baseMarginCol + "eval_set", + "sample_weight_eval_set", + "base_margin", # Supported by spark param baseMarginCol } _unsupported_predict_params = { # for classification, we can use rawPrediction as margin - 'output_margin', - 'validate_features', # TODO - 'base_margin' # TODO + "output_margin", + "validate_features", # TODO + "base_margin", # TODO } _created_params = {"num_workers", "use_gpu"} -class _XgboostParams(HasFeaturesCol, HasLabelCol, HasWeightCol, - HasPredictionCol, HasValidationIndicatorCol, - HasArbitraryParamsDict, HasBaseMarginCol): +class _XgboostParams( + HasFeaturesCol, + HasLabelCol, + HasWeightCol, + HasPredictionCol, + HasValidationIndicatorCol, + HasArbitraryParamsDict, + HasBaseMarginCol, +): num_workers = Param( - Params._dummy(), "num_workers", + Params._dummy(), + "num_workers", "The number of XGBoost workers. Each XGBoost worker corresponds to one spark task.", - TypeConverters.toInt) + TypeConverters.toInt, + ) use_gpu = Param( - Params._dummy(), "use_gpu", - "A boolean variable. Set use_gpu=true if the executors " + - "are running on GPU instances. 
Currently, only one GPU per task is supported." + Params._dummy(), + "use_gpu", + "A boolean variable. Set use_gpu=true if the executors " + + "are running on GPU instances. Currently, only one GPU per task is supported.", ) force_repartition = Param( - Params._dummy(), "force_repartition", - "A boolean variable. Set force_repartition=true if you " + - "want to force the input dataset to be repartitioned before XGBoost training." + - "Note: The auto repartitioning judgement is not fully accurate, so it is recommended" + - "to have force_repartition be True.") + Params._dummy(), + "force_repartition", + "A boolean variable. Set force_repartition=true if you " + + "want to force the input dataset to be repartitioned before XGBoost training." + + "Note: The auto repartitioning judgement is not fully accurate, so it is recommended" + + "to have force_repartition be True.", + ) @classmethod def _xgb_cls(cls): @@ -84,8 +122,7 @@ def _xgb_cls(cls): raise NotImplementedError() def _get_xgb_model_creator(self): - arbitaryParamsDict = self.getOrDefault( - self.getParam("arbitraryParamsDict")) + arbitaryParamsDict = self.getOrDefault(self.getParam("arbitraryParamsDict")) total_params = {**self._gen_xgb_params_dict(), **arbitaryParamsDict} # Once we have already added all of the elements of kwargs, we can just remove it del total_params["arbitraryParamsDict"] @@ -99,8 +136,7 @@ def _get_xgb_params_default(cls): xgb_model_default = cls._xgb_cls()() params_dict = xgb_model_default.get_params() filtered_params_dict = { - k: params_dict[k] - for k in params_dict if k not in _unsupported_xgb_params + k: params_dict[k] for k in params_dict if k not in _unsupported_xgb_params } return filtered_params_dict @@ -111,10 +147,11 @@ def _set_xgb_params_default(self): def _gen_xgb_params_dict(self): xgb_params = {} - non_xgb_params = \ - set(_pyspark_specific_params) | \ - self._get_fit_params_default().keys() | \ - self._get_predict_params_default().keys() + non_xgb_params = ( + set(_pyspark_specific_params) + | self._get_fit_params_default().keys() + | self._get_predict_params_default().keys() + ) for param in self.extractParamMap(): if param.name not in non_xgb_params: xgb_params[param.name] = self.getOrDefault(param) @@ -128,8 +165,9 @@ def _set_distributed_params(self): # Parameters for xgboost.XGBModel().fit() @classmethod def _get_fit_params_default(cls): - fit_params = _get_default_params_from_func(cls._xgb_cls().fit, - _unsupported_fit_params) + fit_params = _get_default_params_from_func( + cls._xgb_cls().fit, _unsupported_fit_params + ) return fit_params def _set_fit_params_default(self): @@ -151,7 +189,8 @@ def _gen_fit_params_dict(self): @classmethod def _get_predict_params_default(cls): predict_params = _get_default_params_from_func( - cls._xgb_cls().predict, _unsupported_predict_params) + cls._xgb_cls().predict, _unsupported_predict_params + ) return predict_params def _set_predict_params_default(self): @@ -174,52 +213,65 @@ def _validate_params(self): if init_model is not None: if init_model is not None and not isinstance(init_model, Booster): raise ValueError( - 'The xgb_model param must be set with a `xgboost.core.Booster` ' - 'instance.') + "The xgb_model param must be set with a `xgboost.core.Booster` " + "instance." + ) if self.getOrDefault(self.num_workers) < 1: raise ValueError( f"Number of workers was {self.getOrDefault(self.num_workers)}." 
- f"It cannot be less than 1 [Default is 1]") + f"It cannot be less than 1 [Default is 1]" + ) if self.getOrDefault(self.num_workers) > 1 and not self.getOrDefault( - self.use_gpu): - cpu_per_task = _get_spark_session().sparkContext.getConf().get( - 'spark.task.cpus') + self.use_gpu + ): + cpu_per_task = ( + _get_spark_session().sparkContext.getConf().get("spark.task.cpus") + ) if cpu_per_task and int(cpu_per_task) > 1: get_logger(self.__class__.__name__).warning( - f'You configured {cpu_per_task} CPU cores for each spark task, but in ' - f'XGBoost training, every Spark task will only use one CPU core.' + f"You configured {cpu_per_task} CPU cores for each spark task, but in " + f"XGBoost training, every Spark task will only use one CPU core." ) - if self.getOrDefault(self.force_repartition) and self.getOrDefault( - self.num_workers) == 1: + if ( + self.getOrDefault(self.force_repartition) + and self.getOrDefault(self.num_workers) == 1 + ): get_logger(self.__class__.__name__).warning( "You set force_repartition to true when there is no need for a repartition." - "Therefore, that parameter will be ignored.") + "Therefore, that parameter will be ignored." + ) if self.getOrDefault(self.use_gpu): tree_method = self.getParam("tree_method") - if self.getOrDefault( - tree_method - ) is not None and self.getOrDefault(tree_method) != "gpu_hist": + if ( + self.getOrDefault(tree_method) is not None + and self.getOrDefault(tree_method) != "gpu_hist" + ): raise ValueError( f"tree_method should be 'gpu_hist' or None when use_gpu is True," - f"found {self.getOrDefault(tree_method)}.") + f"found {self.getOrDefault(tree_method)}." + ) - gpu_per_task = _get_spark_session().sparkContext.getConf().get( - 'spark.task.resource.gpu.amount') + gpu_per_task = ( + _get_spark_session() + .sparkContext.getConf() + .get("spark.task.resource.gpu.amount") + ) if not gpu_per_task or int(gpu_per_task) < 1: raise RuntimeError( - "The spark cluster does not have the necessary GPU" + - "configuration for the spark task. Therefore, we cannot" + - "run xgboost training using GPU.") + "The spark cluster does not have the necessary GPU" + + "configuration for the spark task. Therefore, we cannot" + + "run xgboost training using GPU." + ) if int(gpu_per_task) > 1: get_logger(self.__class__.__name__).warning( - f'You configured {gpu_per_task} GPU cores for each spark task, but in ' - f'XGBoost training, every Spark task will only use one GPU core.' + f"You configured {gpu_per_task} GPU cores for each spark task, but in " + f"XGBoost training, every Spark task will only use one GPU core." ) @@ -273,16 +325,17 @@ def _convert_to_model(self, booster): else: return None # check if this else statement is needed. 
- def _query_plan_contains_valid_repartition(self, query_plan, - num_partitions): + def _query_plan_contains_valid_repartition(self, query_plan, num_partitions): """ Returns true if the latest element in the logical plan is a valid repartition """ start = query_plan.index("== Optimized Logical Plan ==") start += len("== Optimized Logical Plan ==") + 1 num_workers = self.getOrDefault(self.num_workers) - if query_plan[start:start + len("Repartition")] == "Repartition" and \ - num_workers == num_partitions: + if ( + query_plan[start : start + len("Repartition")] == "Repartition" + and num_workers == num_partitions + ): return True return False @@ -297,9 +350,9 @@ def _repartition_needed(self, dataset): try: num_partitions = dataset.rdd.getNumPartitions() query_plan = dataset._sc._jvm.PythonSQLUtils.explainString( - dataset._jdf.queryExecution(), "extended") - if self._query_plan_contains_valid_repartition( - query_plan, num_partitions): + dataset._jdf.queryExecution(), "extended" + ) + if self._query_plan_contains_valid_repartition(query_plan, num_partitions): return False except: # noqa: E722 pass @@ -311,8 +364,7 @@ def _get_distributed_config(self, dataset, params): """ classification = self._xgb_cls() == XGBClassifier - num_classes = int( - dataset.select(countDistinct('label')).collect()[0][0]) + num_classes = int(dataset.select(countDistinct("label")).collect()[0][0]) if classification and num_classes == 2: params["objective"] = "binary:logistic" elif classification and num_classes > 2: @@ -342,8 +394,9 @@ def _get_dist_booster_params(cls, train_params): booster_params[key] = value return booster_params, kwargs_params - def _fit_distributed(self, xgb_model_creator, dataset, has_weight, - has_validation, fit_params): + def _fit_distributed( + self, xgb_model_creator, dataset, has_weight, has_validation, fit_params + ): """ Takes in the dataset, the other parameters, and produces a valid booster """ @@ -352,14 +405,17 @@ def _fit_distributed(self, xgb_model_creator, dataset, has_weight, max_concurrent_tasks = _get_max_num_concurrent_tasks(sc) if num_workers > max_concurrent_tasks: - get_logger(self.__class__.__name__) \ - .warning(f'The num_workers {num_workers} set for xgboost distributed ' - f'training is greater than current max number of concurrent ' - f'spark task slots, you need wait until more task slots available ' - f'or you need increase spark cluster workers.') + get_logger(self.__class__.__name__).warning( + f"The num_workers {num_workers} set for xgboost distributed " + f"training is greater than current max number of concurrent " + f"spark task slots, you need wait until more task slots available " + f"or you need increase spark cluster workers." 
+ ) if self._repartition_needed(dataset): - dataset = dataset.withColumn("values", col("values").cast(ArrayType(FloatType()))) + dataset = dataset.withColumn( + "values", col("values").cast(ArrayType(FloatType())) + ) dataset = dataset.repartition(num_workers) train_params = self._get_distributed_config(dataset, fit_params) @@ -369,6 +425,7 @@ def _train_booster(pandas_df_iter): the Rabit Ring protocol """ from pyspark import BarrierTaskContext + context = BarrierTaskContext.get() dtrain, dval = None, [] @@ -382,8 +439,7 @@ def _train_booster(pandas_df_iter): pandas_df_iter, has_weight, has_validation ) - booster_params, kwargs_params = self._get_dist_booster_params( - train_params) + booster_params, kwargs_params = self._get_dist_booster_params(train_params) context.barrier() _rabit_args = "" if context.partitionId() == 0: @@ -393,23 +449,25 @@ def _train_booster(pandas_df_iter): _rabit_args = _get_args_from_message_list(messages) evals_result = {} with RabitContext(_rabit_args, context): - booster = worker_train(params=booster_params, - dtrain=dtrain, - evals=dval, - evals_result=evals_result, - **kwargs_params) + booster = worker_train( + params=booster_params, + dtrain=dtrain, + evals=dval, + evals_result=evals_result, + **kwargs_params, + ) context.barrier() if context.partitionId() == 0: - yield pd.DataFrame( - data={'booster_bytes': [cloudpickle.dumps(booster)]}) - - result_ser_booster = dataset.mapInPandas( - _train_booster, - schema='booster_bytes binary').rdd.barrier().mapPartitions( - lambda x: x).collect()[0][0] - result_xgb_model = self._convert_to_model( - cloudpickle.loads(result_ser_booster)) + yield pd.DataFrame(data={"booster_bytes": [cloudpickle.dumps(booster)]}) + + result_ser_booster = ( + dataset.mapInPandas(_train_booster, schema="booster_bytes binary") + .rdd.barrier() + .mapPartitions(lambda x: x) + .collect()[0][0] + ) + result_xgb_model = self._convert_to_model(cloudpickle.loads(result_ser_booster)) return self._copyValues(self._create_pyspark_model(result_xgb_model)) def _fit(self, dataset): @@ -417,32 +475,35 @@ def _fit(self, dataset): # Unwrap the VectorUDT type column "feature" to 4 primitive columns: # ['features.type', 'features.size', 'features.indices', 'features.values'] features_col = col(self.getOrDefault(self.featuresCol)) - label_col = col(self.getOrDefault(self.labelCol)).alias('label') - features_array_col = vector_to_array(features_col, dtype="float32").alias("values") + label_col = col(self.getOrDefault(self.labelCol)).alias("label") + features_array_col = vector_to_array(features_col, dtype="float32").alias( + "values" + ) select_cols = [features_array_col, label_col] has_weight = False has_validation = False has_base_margin = False - if self.isDefined(self.weightCol) and self.getOrDefault( - self.weightCol): + if self.isDefined(self.weightCol) and self.getOrDefault(self.weightCol): has_weight = True - select_cols.append( - col(self.getOrDefault(self.weightCol)).alias('weight')) + select_cols.append(col(self.getOrDefault(self.weightCol)).alias("weight")) - if self.isDefined(self.validationIndicatorCol) and \ - self.getOrDefault(self.validationIndicatorCol): + if self.isDefined(self.validationIndicatorCol) and self.getOrDefault( + self.validationIndicatorCol + ): has_validation = True select_cols.append( - col(self.getOrDefault( - self.validationIndicatorCol)).alias('validationIndicator')) + col(self.getOrDefault(self.validationIndicatorCol)).alias( + "validationIndicator" + ) + ) - if self.isDefined(self.baseMarginCol) and self.getOrDefault( - 
self.baseMarginCol): + if self.isDefined(self.baseMarginCol) and self.getOrDefault(self.baseMarginCol): has_base_margin = True select_cols.append( - col(self.getOrDefault(self.baseMarginCol)).alias("baseMargin")) + col(self.getOrDefault(self.baseMarginCol)).alias("baseMargin") + ) dataset = dataset.select(*select_cols) # create local var `xgb_model_creator` to avoid pickle `self` object to remote worker @@ -450,51 +511,68 @@ def _fit(self, dataset): fit_params = self._gen_fit_params_dict() if self.getOrDefault(self.num_workers) > 1: - return self._fit_distributed(xgb_model_creator, dataset, has_weight, - has_validation, fit_params) + return self._fit_distributed( + xgb_model_creator, dataset, has_weight, has_validation, fit_params + ) # Note: fit_params will be pickled to remote, it may include `xgb_model` param # which is used as initial model in training. The initial model will be a # `Booster` instance which support pickling. def train_func(pandas_df_iter): xgb_model = xgb_model_creator() - train_val_data = prepare_train_val_data(pandas_df_iter, has_weight, - has_validation, - has_base_margin) + train_val_data = prepare_train_val_data( + pandas_df_iter, has_weight, has_validation, has_base_margin + ) # We don't need to handle callbacks param in fit_params specially. # User need to ensure callbacks is pickle-able. if has_validation: - train_X, train_y, train_w, train_base_margin, val_X, val_y, val_w, _ = \ - train_val_data + ( + train_X, + train_y, + train_w, + train_base_margin, + val_X, + val_y, + val_w, + _, + ) = train_val_data eval_set = [(val_X, val_y)] sample_weight_eval_set = [val_w] # base_margin_eval_set = [val_base_margin] <- the underline # Note that on XGBoost 1.2.0, the above doesn't exist. - xgb_model.fit(train_X, - train_y, - sample_weight=train_w, - base_margin=train_base_margin, - eval_set=eval_set, - sample_weight_eval_set=sample_weight_eval_set, - **fit_params) + xgb_model.fit( + train_X, + train_y, + sample_weight=train_w, + base_margin=train_base_margin, + eval_set=eval_set, + sample_weight_eval_set=sample_weight_eval_set, + **fit_params, + ) else: train_X, train_y, train_w, train_base_margin = train_val_data - xgb_model.fit(train_X, - train_y, - sample_weight=train_w, - base_margin=train_base_margin, - **fit_params) + xgb_model.fit( + train_X, + train_y, + sample_weight=train_w, + base_margin=train_base_margin, + **fit_params, + ) ser_model_string = serialize_xgb_model(xgb_model) - yield pd.DataFrame(data={'model_string': [ser_model_string]}) + yield pd.DataFrame(data={"model_string": [ser_model_string]}) # Train on 1 remote worker, return the string of the serialized model - result_ser_model_string = dataset.repartition(1) \ - .mapInPandas(train_func, schema='model_string string').collect()[0][0] + result_ser_model_string = ( + dataset.repartition(1) + .mapInPandas(train_func, schema="model_string string") + .collect()[0][0] + ) # Load model - result_xgb_model = deserialize_xgb_model(result_ser_model_string, - xgb_model_creator) + result_xgb_model = deserialize_xgb_model( + result_ser_model_string, xgb_model_creator + ) return self._copyValues(self._create_pyspark_model(result_xgb_model)) def write(self): @@ -516,7 +594,7 @@ def get_booster(self): """ return self._xgb_sklearn_model.get_booster() - def get_feature_importances(self, importance_type='weight'): + def get_feature_importances(self, importance_type="weight"): """Get feature importance of each feature. 
Importance type can be defined as: @@ -567,9 +645,8 @@ def _transform(self, dataset): xgb_sklearn_model = self._xgb_sklearn_model predict_params = self._gen_predict_params_dict() - @pandas_udf('double') - def predict_udf(iterator: Iterator[Tuple[pd.Series]]) \ - -> Iterator[pd.Series]: + @pandas_udf("double") + def predict_udf(iterator: Iterator[Tuple[pd.Series]]) -> Iterator[pd.Series]: # deserialize model from ser_model_string, avoid pickling model to remote worker X, _, _, _ = prepare_predict_data(iterator, False) # Note: In every spark job task, pandas UDF will run in separate python process @@ -578,31 +655,30 @@ def predict_udf(iterator: Iterator[Tuple[pd.Series]]) \ preds = xgb_sklearn_model.predict(X, **predict_params) yield pd.Series(preds) - @pandas_udf('double') - def predict_udf_base_margin(iterator: Iterator[Tuple[pd.Series, pd.Series]]) \ - -> Iterator[pd.Series]: + @pandas_udf("double") + def predict_udf_base_margin( + iterator: Iterator[Tuple[pd.Series, pd.Series]] + ) -> Iterator[pd.Series]: # deserialize model from ser_model_string, avoid pickling model to remote worker X, _, _, b_m = prepare_predict_data(iterator, True) # Note: In every spark job task, pandas UDF will run in separate python process # so it is safe here to call the thread-unsafe model.predict method if len(X) > 0: - preds = xgb_sklearn_model.predict(X, - base_margin=b_m, - **predict_params) + preds = xgb_sklearn_model.predict(X, base_margin=b_m, **predict_params) yield pd.Series(preds) features_col = col(self.getOrDefault(self.featuresCol)) - features_col = struct(vector_to_array(features_col, dtype="float32").alias("values")) + features_col = struct( + vector_to_array(features_col, dtype="float32").alias("values") + ) has_base_margin = False - if self.isDefined(self.baseMarginCol) and self.getOrDefault( - self.baseMarginCol): + if self.isDefined(self.baseMarginCol) and self.getOrDefault(self.baseMarginCol): has_base_margin = True if has_base_margin: base_margin_col = col(self.getOrDefault(self.baseMarginCol)) - pred_col = predict_udf_base_margin(features_col, - base_margin_col) + pred_col = predict_udf_base_margin(features_col, base_margin_col) else: pred_col = predict_udf(features_col) @@ -611,8 +687,7 @@ def predict_udf_base_margin(iterator: Iterator[Tuple[pd.Series, pd.Series]]) \ return dataset.withColumn(predictionColName, pred_col) -class XgboostClassifierModel(_XgboostModel, HasProbabilityCol, - HasRawPredictionCol): +class XgboostClassifierModel(_XgboostModel, HasProbabilityCol, HasRawPredictionCol): """ The model returned by :func:`xgboost.spark.XgboostClassifier.fit` @@ -630,25 +705,25 @@ def _transform(self, dataset): predict_params = self._gen_predict_params_dict() @pandas_udf( - 'rawPrediction array, prediction double, probability array' + "rawPrediction array, prediction double, probability array" ) - def predict_udf(iterator: Iterator[Tuple[pd.Series]]) \ - -> Iterator[pd.DataFrame]: + def predict_udf(iterator: Iterator[Tuple[pd.Series]]) -> Iterator[pd.DataFrame]: # deserialize model from ser_model_string, avoid pickling model to remote worker X, _, _, _ = prepare_predict_data(iterator, False) # Note: In every spark job task, pandas UDF will run in separate python process # so it is safe here to call the thread-unsafe model.predict method if len(X) > 0: - margins = xgb_sklearn_model.predict(X, - output_margin=True, - **predict_params) + margins = xgb_sklearn_model.predict( + X, output_margin=True, **predict_params + ) if margins.ndim == 1: # binomial case classone_probs = expit(margins) 
classzero_probs = 1.0 - classone_probs raw_preds = np.vstack((-margins, margins)).transpose() class_probs = np.vstack( - (classzero_probs, classone_probs)).transpose() + (classzero_probs, classone_probs) + ).transpose() else: # multinomial case raw_preds = margins @@ -659,32 +734,34 @@ def predict_udf(iterator: Iterator[Tuple[pd.Series]]) \ preds = np.argmax(class_probs, axis=1) yield pd.DataFrame( data={ - 'rawPrediction': pd.Series(raw_preds.tolist()), - 'prediction': pd.Series(preds), - 'probability': pd.Series(class_probs.tolist()) - }) + "rawPrediction": pd.Series(raw_preds.tolist()), + "prediction": pd.Series(preds), + "probability": pd.Series(class_probs.tolist()), + } + ) @pandas_udf( - 'rawPrediction array, prediction double, probability array' + "rawPrediction array, prediction double, probability array" ) - def predict_udf_base_margin(iterator: Iterator[Tuple[pd.Series, pd.Series]])\ - -> Iterator[pd.DataFrame]: + def predict_udf_base_margin( + iterator: Iterator[Tuple[pd.Series, pd.Series]] + ) -> Iterator[pd.DataFrame]: # deserialize model from ser_model_string, avoid pickling model to remote worker X, _, _, b_m = prepare_predict_data(iterator, True) # Note: In every spark job task, pandas UDF will run in separate python process # so it is safe here to call the thread-unsafe model.predict method if len(X) > 0: - margins = xgb_sklearn_model.predict(X, - base_margin=b_m, - output_margin=True, - **predict_params) + margins = xgb_sklearn_model.predict( + X, base_margin=b_m, output_margin=True, **predict_params + ) if margins.ndim == 1: # binomial case classone_probs = expit(margins) classzero_probs = 1.0 - classone_probs raw_preds = np.vstack((-margins, margins)).transpose() class_probs = np.vstack( - (classzero_probs, classone_probs)).transpose() + (classzero_probs, classone_probs) + ).transpose() else: # multinomial case raw_preds = margins @@ -695,27 +772,28 @@ def predict_udf_base_margin(iterator: Iterator[Tuple[pd.Series, pd.Series]])\ preds = np.argmax(class_probs, axis=1) yield pd.DataFrame( data={ - 'rawPrediction': pd.Series(raw_preds.tolist()), - 'prediction': pd.Series(preds), - 'probability': pd.Series(class_probs.tolist()) - }) + "rawPrediction": pd.Series(raw_preds.tolist()), + "prediction": pd.Series(preds), + "probability": pd.Series(class_probs.tolist()), + } + ) features_col = col(self.getOrDefault(self.featuresCol)) - features_col = struct(vector_to_array(features_col, dtype="float32").alias("values")) + features_col = struct( + vector_to_array(features_col, dtype="float32").alias("values") + ) has_base_margin = False - if self.isDefined(self.baseMarginCol) and self.getOrDefault( - self.baseMarginCol): + if self.isDefined(self.baseMarginCol) and self.getOrDefault(self.baseMarginCol): has_base_margin = True if has_base_margin: base_margin_col = col(self.getOrDefault(self.baseMarginCol)) - pred_struct = predict_udf_base_margin(features_col, - base_margin_col) + pred_struct = predict_udf_base_margin(features_col, base_margin_col) else: pred_struct = predict_udf(features_col) - pred_struct_col = '_prediction_struct' + pred_struct_col = "_prediction_struct" rawPredictionColName = self.getOrDefault(self.rawPredictionCol) predictionColName = self.getOrDefault(self.predictionCol) @@ -724,20 +802,21 @@ def predict_udf_base_margin(iterator: Iterator[Tuple[pd.Series, pd.Series]])\ if rawPredictionColName: dataset = dataset.withColumn( rawPredictionColName, - array_to_vector(col(pred_struct_col).rawPrediction)) + array_to_vector(col(pred_struct_col).rawPrediction), + ) if 
predictionColName: - dataset = dataset.withColumn(predictionColName, - col(pred_struct_col).prediction) + dataset = dataset.withColumn( + predictionColName, col(pred_struct_col).prediction + ) if probabilityColName: dataset = dataset.withColumn( - probabilityColName, - array_to_vector(col(pred_struct_col).probability)) + probabilityColName, array_to_vector(col(pred_struct_col).probability) + ) return dataset.drop(pred_struct_col) -def _set_pyspark_xgb_cls_param_attrs(pyspark_estimator_class, - pyspark_model_class): +def _set_pyspark_xgb_cls_param_attrs(pyspark_estimator_class, pyspark_model_class): params_dict = pyspark_estimator_class._get_xgb_params_default() def param_value_converter(v): @@ -757,32 +836,42 @@ def set_param_attrs(attr_name, param_obj_): setattr(pyspark_model_class, attr_name, param_obj_) for name in params_dict.keys(): - if name == 'missing': - doc = 'Specify the missing value in the features, default np.nan. ' \ - 'We recommend using 0.0 as the missing value for better performance. ' \ - 'Note: In a spark DataFrame, the inactive values in a sparse vector ' \ - 'mean 0 instead of missing values, unless missing=0 is specified.' + if name == "missing": + doc = ( + "Specify the missing value in the features, default np.nan. " + "We recommend using 0.0 as the missing value for better performance. " + "Note: In a spark DataFrame, the inactive values in a sparse vector " + "mean 0 instead of missing values, unless missing=0 is specified." + ) else: - doc = f'Refer to XGBoost doc of ' \ - f'{get_class_name(pyspark_estimator_class._xgb_cls())} for this param {name}' + doc = ( + f"Refer to XGBoost doc of " + f"{get_class_name(pyspark_estimator_class._xgb_cls())} for this param {name}" + ) param_obj = Param(Params._dummy(), name=name, doc=doc) set_param_attrs(name, param_obj) fit_params_dict = pyspark_estimator_class._get_fit_params_default() for name in fit_params_dict.keys(): - doc = f'Refer to XGBoost doc of {get_class_name(pyspark_estimator_class._xgb_cls())}' \ - f'.fit() for this param {name}' - if name == 'callbacks': - doc += 'The callbacks can be arbitrary functions. It is saved using cloudpickle ' \ - 'which is not a fully self-contained format. It may fail to load with ' \ - 'different versions of dependencies.' + doc = ( + f"Refer to XGBoost doc of {get_class_name(pyspark_estimator_class._xgb_cls())}" + f".fit() for this param {name}" + ) + if name == "callbacks": + doc += ( + "The callbacks can be arbitrary functions. It is saved using cloudpickle " + "which is not a fully self-contained format. It may fail to load with " + "different versions of dependencies." 
+ ) param_obj = Param(Params._dummy(), name=name, doc=doc) set_param_attrs(name, param_obj) predict_params_dict = pyspark_estimator_class._get_predict_params_default() for name in predict_params_dict.keys(): - doc = f'Refer to XGBoost doc of {get_class_name(pyspark_estimator_class._xgb_cls())}' \ - f'.predict() for this param {name}' + doc = ( + f"Refer to XGBoost doc of {get_class_name(pyspark_estimator_class._xgb_cls())}" + f".predict() for this param {name}" + ) param_obj = Param(Params._dummy(), name=name, doc=doc) set_param_attrs(name, param_obj) diff --git a/python-package/xgboost/spark/data.py b/python-package/xgboost/spark/data.py index ae8e5c25309f..3060d4dc9184 100644 --- a/python-package/xgboost/spark/data.py +++ b/python-package/xgboost/spark/data.py @@ -14,7 +14,8 @@ def _dump_libsvm(features, labels, weights=None, external_storage_precision=5): def gen_label_str(row_idx): if weights is not None: return "{label:.{esp}g}:{weight:.{esp}g}".format( - label=labels[row_idx], esp=esp, weight=weights[row_idx]) + label=labels[row_idx], esp=esp, weight=weights[row_idx] + ) else: return "{label:.{esp}g}".format(label=labels[row_idx], esp=esp) @@ -42,15 +43,16 @@ def gen_feature_value_str(feature_idx, feature_val): # This is the updated version that handles weights -def _stream_train_val_data(features, labels, weights, main_file, - external_storage_precision): +def _stream_train_val_data( + features, labels, weights, main_file, external_storage_precision +): lines = _dump_libsvm(features, labels, weights, external_storage_precision) main_file.writelines(lines) -def _stream_data_into_libsvm_file(data_iterator, has_weight, - has_validation, file_prefix, - external_storage_precision): +def _stream_data_into_libsvm_file( + data_iterator, has_weight, has_validation, file_prefix, external_storage_precision +): # getting the file names for storage train_file_name = file_prefix + "/data.txt.train" train_file = open(train_file_name, "w") @@ -58,20 +60,22 @@ def _stream_data_into_libsvm_file(data_iterator, has_weight, validation_file_name = file_prefix + "/data.txt.val" validation_file = open(validation_file_name, "w") - train_val_data = _process_data_iter(data_iterator, - train=True, - has_weight=has_weight, - has_validation=has_validation) + train_val_data = _process_data_iter( + data_iterator, train=True, has_weight=has_weight, has_validation=has_validation + ) if has_validation: train_X, train_y, train_w, _, val_X, val_y, val_w, _ = train_val_data - _stream_train_val_data(train_X, train_y, train_w, train_file, - external_storage_precision) - _stream_train_val_data(val_X, val_y, val_w, validation_file, - external_storage_precision) + _stream_train_val_data( + train_X, train_y, train_w, train_file, external_storage_precision + ) + _stream_train_val_data( + val_X, val_y, val_w, validation_file, external_storage_precision + ) else: train_X, train_y, train_w, _ = train_val_data - _stream_train_val_data(train_X, train_y, train_w, train_file, - external_storage_precision) + _stream_train_val_data( + train_X, train_y, train_w, train_file, external_storage_precision + ) if has_validation: train_file.close() @@ -92,29 +96,32 @@ def _create_dmatrix_from_file(file_name, cache_name): return DMatrix(file_name + "#" + cache_name) -def prepare_train_val_data(data_iterator, - has_weight, - has_validation, - has_fit_base_margin=False): +def prepare_train_val_data( + data_iterator, has_weight, has_validation, has_fit_base_margin=False +): def gen_data_pdf(): for pdf in data_iterator: yield pdf - return 
_process_data_iter(gen_data_pdf(), - train=True, - has_weight=has_weight, - has_validation=has_validation, - has_fit_base_margin=has_fit_base_margin, - has_predict_base_margin=False) + return _process_data_iter( + gen_data_pdf(), + train=True, + has_weight=has_weight, + has_validation=has_validation, + has_fit_base_margin=has_fit_base_margin, + has_predict_base_margin=False, + ) def prepare_predict_data(data_iterator, has_predict_base_margin): - return _process_data_iter(data_iterator, - train=False, - has_weight=False, - has_validation=False, - has_fit_base_margin=False, - has_predict_base_margin=has_predict_base_margin) + return _process_data_iter( + data_iterator, + train=False, + has_weight=False, + has_validation=False, + has_fit_base_margin=False, + has_predict_base_margin=has_predict_base_margin, + ) def _check_feature_dims(num_dims, expected_dims): @@ -124,16 +131,21 @@ def _check_feature_dims(num_dims, expected_dims): if expected_dims is None: return num_dims if num_dims != expected_dims: - raise ValueError("Rows contain different feature dimensions: " - "Expecting {}, got {}.".format( - expected_dims, num_dims)) + raise ValueError( + "Rows contain different feature dimensions: " + "Expecting {}, got {}.".format(expected_dims, num_dims) + ) return expected_dims -def _row_tuple_list_to_feature_matrix_y_w(data_iterator, train, has_weight, - has_fit_base_margin, - has_predict_base_margin, - has_validation: bool = False): +def _row_tuple_list_to_feature_matrix_y_w( + data_iterator, + train, + has_weight, + has_fit_base_margin, + has_predict_base_margin, + has_validation: bool = False, +): """ Construct a feature matrix in ndarray format, label array y and weight array w from the row_tuple_list. @@ -161,8 +173,9 @@ def _row_tuple_list_to_feature_matrix_y_w(data_iterator, train, has_weight, num_feature_dims = len(pdf["values"].values[0]) - expected_feature_dims = _check_feature_dims(num_feature_dims, - expected_feature_dims) + expected_feature_dims = _check_feature_dims( + num_feature_dims, expected_feature_dims + ) # TODO: Improve performance, avoid use python list values_list.append(pdf["values"].to_list()) @@ -189,24 +202,32 @@ def _row_tuple_list_to_feature_matrix_y_w(data_iterator, train, has_weight, feature_matrix = np.concatenate(values_list) y = np.concatenate(label_list) if train else None w = np.concatenate(weight_list) if has_weight else None - b_m = np.concatenate(base_margin_list) if ( - has_fit_base_margin or has_predict_base_margin) else None + b_m = ( + np.concatenate(base_margin_list) + if (has_fit_base_margin or has_predict_base_margin) + else None + ) if has_validation: feature_matrix_val = np.concatenate(values_val_list) y_val = np.concatenate(label_val_list) if train else None w_val = np.concatenate(weight_val_list) if has_weight else None - b_m_val = np.concatenate(base_margin_val_list) if ( - has_fit_base_margin or has_predict_base_margin) else None + b_m_val = ( + np.concatenate(base_margin_val_list) + if (has_fit_base_margin or has_predict_base_margin) + else None + ) return feature_matrix, y, w, b_m, feature_matrix_val, y_val, w_val, b_m_val return feature_matrix, y, w, b_m -def _process_data_iter(data_iterator: Iterator[pd.DataFrame], - train: bool, - has_weight: bool, - has_validation: bool, - has_fit_base_margin: bool = False, - has_predict_base_margin: bool = False): +def _process_data_iter( + data_iterator: Iterator[pd.DataFrame], + train: bool, + has_weight: bool, + has_validation: bool, + has_fit_base_margin: bool = False, + has_predict_base_margin: bool 
= False, +): """ If input is for train and has_validation=True, it will split the train data into train dataset and validation dataset, and return (train_X, train_y, train_w, train_b_m <- @@ -214,22 +235,40 @@ def _process_data_iter(data_iterator: Iterator[pd.DataFrame], otherwise return (X, y, w, b_m <- base margin) """ if train and has_validation: - train_X, train_y, train_w, train_b_m, val_X, val_y, val_w, val_b_m = \ - _row_tuple_list_to_feature_matrix_y_w( - data_iterator, train, has_weight, has_fit_base_margin, - has_predict_base_margin, has_validation) + ( + train_X, + train_y, + train_w, + train_b_m, + val_X, + val_y, + val_w, + val_b_m, + ) = _row_tuple_list_to_feature_matrix_y_w( + data_iterator, + train, + has_weight, + has_fit_base_margin, + has_predict_base_margin, + has_validation, + ) return train_X, train_y, train_w, train_b_m, val_X, val_y, val_w, val_b_m else: - return _row_tuple_list_to_feature_matrix_y_w(data_iterator, train, has_weight, - has_fit_base_margin, has_predict_base_margin, - has_validation) + return _row_tuple_list_to_feature_matrix_y_w( + data_iterator, + train, + has_weight, + has_fit_base_margin, + has_predict_base_margin, + has_validation, + ) -def convert_partition_data_to_dmatrix(partition_data_iter, - has_weight, - has_validation): +def convert_partition_data_to_dmatrix(partition_data_iter, has_weight, has_validation): # if we are not using external storage, we use the standard method of parsing data. - train_val_data = prepare_train_val_data(partition_data_iter, has_weight, has_validation) + train_val_data = prepare_train_val_data( + partition_data_iter, has_weight, has_validation + ) if has_validation: train_X, train_y, train_w, _, val_X, val_y, val_w, _ = train_val_data training_dmatrix = DMatrix(data=train_X, label=train_y, weight=train_w) diff --git a/python-package/xgboost/spark/estimator.py b/python-package/xgboost/spark/estimator.py index e8c8bbd5350b..804fd24950be 100644 --- a/python-package/xgboost/spark/estimator.py +++ b/python-package/xgboost/spark/estimator.py @@ -1,7 +1,11 @@ from pyspark.ml.param.shared import HasProbabilityCol, HasRawPredictionCol from xgboost import XGBClassifier, XGBRegressor -from .core import (_XgboostEstimator, XgboostClassifierModel, - XgboostRegressorModel, _set_pyspark_xgb_cls_param_attrs) +from .core import ( + _XgboostEstimator, + XgboostClassifierModel, + XgboostRegressorModel, + _set_pyspark_xgb_cls_param_attrs, +) class XgboostRegressor(_XgboostEstimator): @@ -83,6 +87,7 @@ class XgboostRegressor(_XgboostEstimator): >>> xgb_reg_model.transform(df_test) """ + def __init__(self, **kwargs): super().__init__() self.setParams(**kwargs) @@ -99,8 +104,7 @@ def _pyspark_model_cls(cls): _set_pyspark_xgb_cls_param_attrs(XgboostRegressor, XgboostRegressorModel) -class XgboostClassifier(_XgboostEstimator, HasProbabilityCol, - HasRawPredictionCol): +class XgboostClassifier(_XgboostEstimator, HasProbabilityCol, HasRawPredictionCol): """ XgboostClassifier is a PySpark ML estimator. 
It implements the XGBoost classification algorithm based on XGBoost python library, and it can be used in PySpark Pipeline @@ -187,6 +191,7 @@ class XgboostClassifier(_XgboostEstimator, HasProbabilityCol, >>> xgb_clf_model.transform(df_test).show() """ + def __init__(self, **kwargs): super().__init__() self.setParams(**kwargs) diff --git a/python-package/xgboost/spark/model.py b/python-package/xgboost/spark/model.py index 31994061b146..b8ef24c0e7da 100644 --- a/python-package/xgboost/spark/model.py +++ b/python-package/xgboost/spark/model.py @@ -5,8 +5,7 @@ from pyspark import cloudpickle from pyspark import SparkFiles from pyspark.sql import SparkSession -from pyspark.ml.util import (DefaultParamsReader, DefaultParamsWriter, - MLReader, MLWriter) +from pyspark.ml.util import DefaultParamsReader, DefaultParamsWriter, MLReader, MLWriter from xgboost.core import Booster from .utils import get_logger, get_class_name @@ -30,7 +29,7 @@ def get_xgb_model_creator(model_cls, xgb_params): def _get_or_create_tmp_dir(): root_dir = SparkFiles.getRootDirectory() - xgb_tmp_dir = os.path.join(root_dir, 'xgboost-tmp') + xgb_tmp_dir = os.path.join(root_dir, "xgboost-tmp") if not os.path.exists(xgb_tmp_dir): os.makedirs(xgb_tmp_dir) return xgb_tmp_dir @@ -47,7 +46,7 @@ def serialize_xgb_model(model): xgboost.XGBClassifier or xgboost.XGBRegressor instance """ # TODO: change to use string io - tmp_file_name = os.path.join(_get_or_create_tmp_dir(), f'{uuid.uuid4()}.json') + tmp_file_name = os.path.join(_get_or_create_tmp_dir(), f"{uuid.uuid4()}.json") model.save_model(tmp_file_name) with open(tmp_file_name) as f: ser_model_string = f.read() @@ -60,7 +59,7 @@ def deserialize_xgb_model(ser_model_string, xgb_model_creator): """ xgb_model = xgb_model_creator() # TODO: change to use string io - tmp_file_name = os.path.join(_get_or_create_tmp_dir(), f'{uuid.uuid4()}.json') + tmp_file_name = os.path.join(_get_or_create_tmp_dir(), f"{uuid.uuid4()}.json") with open(tmp_file_name, "w") as f: f.write(ser_model_string) xgb_model.load_model(tmp_file_name) @@ -77,7 +76,7 @@ def serialize_booster(booster): an xgboost.core.Booster instance """ # TODO: change to use string io - tmp_file_name = os.path.join(_get_or_create_tmp_dir(), f'{uuid.uuid4()}.json') + tmp_file_name = os.path.join(_get_or_create_tmp_dir(), f"{uuid.uuid4()}.json") booster.save_model(tmp_file_name) with open(tmp_file_name) as f: ser_model_string = f.read() @@ -90,7 +89,7 @@ def deserialize_booster(ser_model_string): """ booster = Booster() # TODO: change to use string io - tmp_file_name = os.path.join(_get_or_create_tmp_dir(), f'{uuid.uuid4()}.json') + tmp_file_name = os.path.join(_get_or_create_tmp_dir(), f"{uuid.uuid4()}.json") with open(tmp_file_name, "w") as f: f.write(ser_model_string) booster.load_model(tmp_file_name) @@ -105,7 +104,6 @@ def _get_spark_session(): class XgboostSharedReadWrite: - @staticmethod def saveMetadata(instance, path, sc, logger, extraMetadata=None): """ @@ -113,7 +111,7 @@ def saveMetadata(instance, path, sc, logger, extraMetadata=None): xgboost.spark._XgboostModel. 
""" instance._validate_params() - skipParams = ['callbacks', 'xgb_model'] + skipParams = ["callbacks", "xgb_model"] jsonParams = {} for p, v in instance._paramMap.items(): if p.name not in skipParams: @@ -122,22 +120,27 @@ def saveMetadata(instance, path, sc, logger, extraMetadata=None): extraMetadata = extraMetadata or {} callbacks = instance.getOrDefault(instance.callbacks) if callbacks is not None: - logger.warning('The callbacks parameter is saved using cloudpickle and it ' - 'is not a fully self-contained format. It may fail to load ' - 'with different versions of dependencies.') - serialized_callbacks = \ - base64.encodebytes(cloudpickle.dumps(callbacks)).decode('ascii') - extraMetadata['serialized_callbacks'] = serialized_callbacks + logger.warning( + "The callbacks parameter is saved using cloudpickle and it " + "is not a fully self-contained format. It may fail to load " + "with different versions of dependencies." + ) + serialized_callbacks = base64.encodebytes( + cloudpickle.dumps(callbacks) + ).decode("ascii") + extraMetadata["serialized_callbacks"] = serialized_callbacks init_booster = instance.getOrDefault(instance.xgb_model) if init_booster is not None: - extraMetadata['init_booster'] = _INIT_BOOSTER_SAVE_PATH + extraMetadata["init_booster"] = _INIT_BOOSTER_SAVE_PATH DefaultParamsWriter.saveMetadata( - instance, path, sc, extraMetadata=extraMetadata, paramMap=jsonParams) + instance, path, sc, extraMetadata=extraMetadata, paramMap=jsonParams + ) if init_booster is not None: ser_init_booster = serialize_booster(init_booster) save_path = os.path.join(path, _INIT_BOOSTER_SAVE_PATH) _get_spark_session().createDataFrame( - [(ser_init_booster,)], ['init_booster']).write.parquet(save_path) + [(ser_init_booster,)], ["init_booster"] + ).write.parquet(save_path) @staticmethod def loadMetadataAndInstance(pyspark_xgb_cls, path, sc, logger): @@ -148,24 +151,29 @@ def loadMetadataAndInstance(pyspark_xgb_cls, path, sc, logger): :return: a tuple of (metadata, instance) """ metadata = DefaultParamsReader.loadMetadata( - path, sc, expectedClassName=get_class_name(pyspark_xgb_cls)) + path, sc, expectedClassName=get_class_name(pyspark_xgb_cls) + ) pyspark_xgb = pyspark_xgb_cls() DefaultParamsReader.getAndSetParams(pyspark_xgb, metadata) - if 'serialized_callbacks' in metadata: - serialized_callbacks = metadata['serialized_callbacks'] + if "serialized_callbacks" in metadata: + serialized_callbacks = metadata["serialized_callbacks"] try: - callbacks = \ - cloudpickle.loads(base64.decodebytes(serialized_callbacks.encode('ascii'))) + callbacks = cloudpickle.loads( + base64.decodebytes(serialized_callbacks.encode("ascii")) + ) pyspark_xgb.set(pyspark_xgb.callbacks, callbacks) - except Exception as e: # pylint: disable=W0703 - logger.warning('Fails to load the callbacks param due to {}. Please set the ' - 'callbacks param manually for the loaded estimator.'.format(e)) - - if 'init_booster' in metadata: - load_path = os.path.join(path, metadata['init_booster']) - ser_init_booster = _get_spark_session().read.parquet(load_path) \ - .collect()[0].init_booster + except Exception as e: # pylint: disable=W0703 + logger.warning( + "Fails to load the callbacks param due to {}. 
Please set the " + "callbacks param manually for the loaded estimator.".format(e) + ) + + if "init_booster" in metadata: + load_path = os.path.join(path, metadata["init_booster"]) + ser_init_booster = ( + _get_spark_session().read.parquet(load_path).collect()[0].init_booster + ) init_booster = deserialize_booster(ser_init_booster) pyspark_xgb.set(pyspark_xgb.xgb_model, init_booster) @@ -174,35 +182,33 @@ def loadMetadataAndInstance(pyspark_xgb_cls, path, sc, logger): class XgboostWriter(MLWriter): - def __init__(self, instance): super().__init__() self.instance = instance - self.logger = get_logger(self.__class__.__name__, level='WARN') + self.logger = get_logger(self.__class__.__name__, level="WARN") def saveImpl(self, path): XgboostSharedReadWrite.saveMetadata(self.instance, path, self.sc, self.logger) class XgboostReader(MLReader): - def __init__(self, cls): super().__init__() self.cls = cls - self.logger = get_logger(self.__class__.__name__, level='WARN') + self.logger = get_logger(self.__class__.__name__, level="WARN") def load(self, path): - _, pyspark_xgb = XgboostSharedReadWrite \ - .loadMetadataAndInstance(self.cls, path, self.sc, self.logger) + _, pyspark_xgb = XgboostSharedReadWrite.loadMetadataAndInstance( + self.cls, path, self.sc, self.logger + ) return pyspark_xgb class XgboostModelWriter(MLWriter): - def __init__(self, instance): super().__init__() self.instance = instance - self.logger = get_logger(self.__class__.__name__, level='WARN') + self.logger = get_logger(self.__class__.__name__, level="WARN") def saveImpl(self, path): """ @@ -215,15 +221,15 @@ def saveImpl(self, path): model_save_path = os.path.join(path, "model.json") ser_xgb_model = serialize_xgb_model(xgb_model) _get_spark_session().createDataFrame( - [(ser_xgb_model,)], ['xgb_sklearn_model']).write.parquet(model_save_path) + [(ser_xgb_model,)], ["xgb_sklearn_model"] + ).write.parquet(model_save_path) class XgboostModelReader(MLReader): - def __init__(self, cls): super().__init__() self.cls = cls - self.logger = get_logger(self.__class__.__name__, level='WARN') + self.logger = get_logger(self.__class__.__name__, level="WARN") def load(self, path): """ @@ -232,14 +238,20 @@ def load(self, path): :return: XgboostRegressorModel or XgboostClassifierModel instance """ _, py_model = XgboostSharedReadWrite.loadMetadataAndInstance( - self.cls, path, self.sc, self.logger) + self.cls, path, self.sc, self.logger + ) xgb_params = py_model._gen_xgb_params_dict() model_load_path = os.path.join(path, "model.json") - ser_xgb_model = _get_spark_session().read.parquet(model_load_path) \ - .collect()[0].xgb_sklearn_model - xgb_model = deserialize_xgb_model(ser_xgb_model, - lambda: self.cls._xgb_cls()(**xgb_params)) + ser_xgb_model = ( + _get_spark_session() + .read.parquet(model_load_path) + .collect()[0] + .xgb_sklearn_model + ) + xgb_model = deserialize_xgb_model( + ser_xgb_model, lambda: self.cls._xgb_cls()(**xgb_params) + ) py_model._xgb_sklearn_model = xgb_model return py_model diff --git a/python-package/xgboost/spark/utils.py b/python-package/xgboost/spark/utils.py index 6f4be9ec838b..5ad7b1ddbce1 100644 --- a/python-package/xgboost/spark/utils.py +++ b/python-package/xgboost/spark/utils.py @@ -23,8 +23,10 @@ def _get_default_params_from_func(func, unsupported_set): filtered_params_dict = dict() for parameter in sig.parameters.values(): # Remove parameters without a default value and those in the unsupported_set - if parameter.default is not parameter.empty \ - and parameter.name not in unsupported_set: + if ( + 
parameter.default is not parameter.empty + and parameter.name not in unsupported_set + ): filtered_params_dict[parameter.name] = parameter.default return filtered_params_dict @@ -36,10 +38,13 @@ class HasArbitraryParamsDict(Params): input. """ - arbitraryParamsDict = Param(Params._dummy(), "arbitraryParamsDict", - "This parameter holds all of the user defined parameters that" - " the sklearn implementation of XGBoost can't recognize. " - "It is stored as a dictionary.") + arbitraryParamsDict = Param( + Params._dummy(), + "arbitraryParamsDict", + "This parameter holds all of the user defined parameters that" + " the sklearn implementation of XGBoost can't recognize. " + "It is stored as a dictionary.", + ) def setArbitraryParamsDict(self, value): return self._set(arbitraryParamsDict=value) @@ -53,9 +58,12 @@ class HasBaseMarginCol(Params): This is a Params based class that is extended by _XGBoostParams and holds the variable to store the base margin column part of XGboost. """ + baseMarginCol = Param( - Params._dummy(), "baseMarginCol", - "This stores the name for the column of the base margin") + Params._dummy(), + "baseMarginCol", + "This stores the name for the column of the base margin", + ) def setBaseMarginCol(self, value): return self._set(baseMarginCol=value) @@ -69,10 +77,10 @@ class RabitContext: A context controlling rabit initialization and finalization. This isn't specificially necessary (note Part 3), but it is more understandable coding-wise. """ + def __init__(self, args, context): self.args = args - self.args.append( - ('DMLC_TASK_ID=' + str(context.partitionId())).encode()) + self.args.append(("DMLC_TASK_ID=" + str(context.partitionId())).encode()) def __enter__(self): rabit.init(self.args) @@ -85,7 +93,7 @@ def _start_tracker(context, n_workers): """ Start Rabit tracker with n_workers """ - env = {'DMLC_NUM_WORKER': n_workers} + env = {"DMLC_NUM_WORKER": n_workers} host = get_host_ip(context) rabit_context = RabitTracker(host_ip=host, n_workers=n_workers) env.update(rabit_context.worker_envs()) @@ -101,7 +109,7 @@ def _get_rabit_args(context, n_workers): Get rabit context arguments to send to each worker. """ env = _start_tracker(context, n_workers) - rabit_args = [('%s=%s' % item).encode() for item in env.items()] + rabit_args = [("%s=%s" % item).encode() for item in env.items()] return rabit_args @@ -109,9 +117,7 @@ def get_host_ip(context): """ Gets the hostIP for Spark. This essentially gets the IP of the first worker. """ - task_ip_list = [ - info.address.split(":")[0] for info in context.getTaskInfos() - ] + task_ip_list = [info.address.split(":")[0] for info in context.getTaskInfos()] return task_ip_list[0] @@ -124,9 +130,7 @@ def _get_args_from_message_list(messages): if message != "": output = message break - return [ - elem.split("'")[1].encode() for elem in output.strip('][').split(', ') - ] + return [elem.split("'")[1].encode() for elem in output.strip("][").split(", ")] def _get_spark_session(): @@ -134,7 +138,8 @@ def _get_spark_session(): if pyspark.TaskContext.get() is not None: # This is a safety check. raise RuntimeError( - '_get_spark_session should not be invoked from executor side.') + "_get_spark_session should not be invoked from executor side." + ) return SparkSession.builder.getOrCreate() @@ -154,17 +159,19 @@ def _getConfBoolean(sqlContext, key, defaultValue): val = sqlContext.getConf(key, str(defaultValue)) # Convert val to str to handle unicode issues across Python 2 and 3. 
lowercase_val = str(val.lower()) - if lowercase_val == 'true': + if lowercase_val == "true": return True - elif lowercase_val == 'false': + elif lowercase_val == "false": return False else: - raise Exception("_getConfBoolean expected a boolean conf value but found value of type {} " - "with value: {}".format(type(val), val)) + raise Exception( + "_getConfBoolean expected a boolean conf value but found value of type {} " + "with value: {}".format(type(val), val) + ) -def get_logger(name, level='INFO'): - """ Gets a logger by name, or creates and configures it for the first time. """ +def get_logger(name, level="INFO"): + """Gets a logger by name, or creates and configures it for the first time.""" logger = logging.getLogger(name) logger.setLevel(level) # If the logger is configured, skip the configure @@ -177,7 +184,7 @@ def get_logger(name, level='INFO'): def _get_max_num_concurrent_tasks(sc): """Gets the current max number of concurrent tasks.""" # spark 3.1 and above has a different API for fetching max concurrent tasks - if sc._jsc.sc().version() >= '3.1': + if sc._jsc.sc().version() >= "3.1": return sc._jsc.sc().maxNumConcurrentTasks( sc._jsc.sc().resourceProfileManager().resourceProfileFromId(0) ) diff --git a/tests/python/test_spark/data_test.py b/tests/python/test_spark/data_test.py index 29b0108ae9ba..1136e030e47a 100644 --- a/tests/python/test_spark/data_test.py +++ b/tests/python/test_spark/data_test.py @@ -4,17 +4,21 @@ import pandas as pd from scipy.sparse import csr_matrix -from xgboost.spark.data import _row_tuple_list_to_feature_matrix_y_w, convert_partition_data_to_dmatrix, _dump_libsvm +from xgboost.spark.data import ( + _row_tuple_list_to_feature_matrix_y_w, + convert_partition_data_to_dmatrix, + _dump_libsvm, +) from xgboost import DMatrix, XGBClassifier from xgboost.training import train as worker_train from .utils_test import SparkTestCase import logging + logging.getLogger("py4j").setLevel(logging.INFO) class DataTest(SparkTestCase): - def test_sparse_dense_vector(self): def row_tup_iter(data): pdf = pd.DataFrame(data) @@ -25,7 +29,12 @@ def row_tup_iter(data): expected_ndarray = np.array([[1.0, 2.0, 3.0], [0.0, 1.0, 5.5]]) data = {"values": [[1.0, 2.0, 3.0], [0.0, 1.0, 5.5]]} feature_matrix, y, w, _ = _row_tuple_list_to_feature_matrix_y_w( - list(row_tup_iter(data)), train=False, has_weight=False, has_fit_base_margin=False, has_predict_base_margin=False) + list(row_tup_iter(data)), + train=False, + has_weight=False, + has_fit_base_margin=False, + has_predict_base_margin=False, + ) self.assertIsNone(y) self.assertIsNone(w) # self.assertTrue(isinstance(feature_matrix, csr_matrix)) @@ -33,19 +42,29 @@ def row_tup_iter(data): data["label"] = [1, 0] feature_matrix, y, w, _ = _row_tuple_list_to_feature_matrix_y_w( - row_tup_iter(data), train=True, has_weight=False, has_fit_base_margin=False, has_predict_base_margin=False) + row_tup_iter(data), + train=True, + has_weight=False, + has_fit_base_margin=False, + has_predict_base_margin=False, + ) self.assertIsNone(w) # self.assertTrue(isinstance(feature_matrix, csr_matrix)) self.assertTrue(np.allclose(feature_matrix, expected_ndarray)) - self.assertTrue(np.array_equal(y, np.array(data['label']))) + self.assertTrue(np.array_equal(y, np.array(data["label"]))) data["weight"] = [0.2, 0.8] feature_matrix, y, w, _ = _row_tuple_list_to_feature_matrix_y_w( - list(row_tup_iter(data)), train=True, has_weight=True, has_fit_base_margin=False, has_predict_base_margin=False) + list(row_tup_iter(data)), + train=True, + has_weight=True, + 
has_fit_base_margin=False, + has_predict_base_margin=False, + ) # self.assertTrue(isinstance(feature_matrix, csr_matrix)) self.assertTrue(np.allclose(feature_matrix, expected_ndarray)) - self.assertTrue(np.array_equal(y, np.array(data['label']))) - self.assertTrue(np.array_equal(w, np.array(data['weight']))) + self.assertTrue(np.array_equal(y, np.array(data["label"]))) + self.assertTrue(np.array_equal(w, np.array(data["weight"]))) def test_dmatrix_creator(self): @@ -59,8 +78,13 @@ def row_tup_iter(data): expected_labels = np.array([1, 0] * 100) expected_dmatrix = DMatrix(data=expected_features, label=expected_labels) - data = {"values": [[1.0, 2.0, 3.0], [0.0, 1.0, 5.5]] * 100, "label": [1, 0] * 100} - output_dmatrix = convert_partition_data_to_dmatrix([pd.DataFrame(data)], has_weight=False, has_validation=False) + data = { + "values": [[1.0, 2.0, 3.0], [0.0, 1.0, 5.5]] * 100, + "label": [1, 0] * 100, + } + output_dmatrix = convert_partition_data_to_dmatrix( + [pd.DataFrame(data)], has_weight=False, has_validation=False + ) # You can't compare DMatrix outputs, so the only way is to predict on the two seperate DMatrices using # the same classifier and making sure the outputs are equal model = XGBClassifier() @@ -71,10 +95,14 @@ def row_tup_iter(data): # DMatrix creation with weights expected_weight = np.array([0.2, 0.8] * 100) - expected_dmatrix = DMatrix(data=expected_features, label=expected_labels, weight=expected_weight) + expected_dmatrix = DMatrix( + data=expected_features, label=expected_labels, weight=expected_weight + ) data["weight"] = [0.2, 0.8] * 100 - output_dmatrix = convert_partition_data_to_dmatrix([pd.DataFrame(data)], has_weight=True, has_validation=False) + output_dmatrix = convert_partition_data_to_dmatrix( + [pd.DataFrame(data)], has_weight=True, has_validation=False + ) model.fit(expected_features, expected_labels, sample_weight=expected_weight) expected_preds = model.get_booster().predict(expected_dmatrix) @@ -88,7 +116,10 @@ def test_external_storage(self): normal_dmatrix = DMatrix(features, labels) test_dmatrix = DMatrix(features) - data = {"values": [[1.0, 2.0, 3.0], [0.0, 1.0, 5.5]] * 100, "label": [1, 0] * 100} + data = { + "values": [[1.0, 2.0, 3.0], [0.0, 1.0, 5.5]] * 100, + "label": [1, 0] * 100, + } # Creating the dmatrix based on storage temporary_path = tempfile.mkdtemp() @@ -98,7 +129,7 @@ def test_external_storage(self): # Testing without weights normal_booster = worker_train({}, normal_dmatrix) - storage_booster = worker_train({}, storage_dmatrix) + storage_booster = worker_train({}, storage_dmatrix) normal_preds = normal_booster.predict(test_dmatrix) storage_preds = storage_booster.predict(test_dmatrix) self.assertTrue(np.allclose(normal_preds, storage_preds, atol=1e-3)) @@ -124,8 +155,8 @@ def test_external_storage(self): def test_dump_libsvm(self): num_features = 3 features_test_list = [ - [[1,2,3],[0,1,5.5]], - csr_matrix(([1, 2, 3], [0, 2, 2], [0, 2, 3]), shape=(2, 3)) + [[1, 2, 3], [0, 1, 5.5]], + csr_matrix(([1, 2, 3], [0, 2, 2], [0, 2, 3]), shape=(2, 3)), ] labels = [0, 1] @@ -164,8 +195,16 @@ def test_dump_libsvm(self): loaded_feature[int(split[0])] = float(split[1]) self.assertListEqual(loaded_feature, list(features_array[i])) - features = [[1.34234,2.342321,3.34322],[0.344234,1.123123,5.534322],[3.553423e10,3.5632e10,0.00000000000012345]] - features_prec = [[1.34, 2.34, 3.34], [0.344, 1.12, 5.53],[3.55e10, 3.56e10, 1.23e-13]] + features = [ + [1.34234, 2.342321, 3.34322], + [0.344234, 1.123123, 5.534322], + [3.553423e10, 3.5632e10, 
0.00000000000012345], + ] + features_prec = [ + [1.34, 2.34, 3.34], + [0.344, 1.12, 5.53], + [3.55e10, 3.56e10, 1.23e-13], + ] labels = [0, 1] output = _dump_libsvm(features, labels, external_storage_precision=3) for i, line in enumerate(output): diff --git a/tests/python/test_spark/utils_test.py b/tests/python/test_spark/utils_test.py index dd1eaa60dd73..8ea006ad8335 100644 --- a/tests/python/test_spark/utils_test.py +++ b/tests/python/test_spark/utils_test.py @@ -18,9 +18,7 @@ class UtilsTest(unittest.TestCase): - def test_get_default_params(self): - class Foo: def func1(self, x, y, key1=None, key2="val2", key3=0, key4=None): pass @@ -30,11 +28,16 @@ def func1(self, x, y, key1=None, key2="val2", key3=0, key4=None): "key1": None, "key3": 0, } - actual_default_params = _get_default_params_from_func(Foo.func1, unsupported_params) - self.assertEqual(len(expected_default_params.keys()), len(actual_default_params.keys())) + actual_default_params = _get_default_params_from_func( + Foo.func1, unsupported_params + ) + self.assertEqual( + len(expected_default_params.keys()), len(actual_default_params.keys()) + ) for k, v in actual_default_params.items(): self.assertEqual(expected_default_params[k], v) + @contextlib.contextmanager def patch_stdout(): """patch stdout and give an output""" @@ -76,11 +79,11 @@ def remove_tempdir(cls): class TestSparkContext(object): @classmethod def setup_env(cls, spark_config): - builder = SparkSession.builder.appName('xgboost spark python API Tests') + builder = SparkSession.builder.appName("xgboost spark python API Tests") for k, v in spark_config.items(): builder.config(k, v) spark = builder.getOrCreate() - logging.getLogger('pyspark').setLevel(logging.INFO) + logging.getLogger("pyspark").setLevel(logging.INFO) cls.sc = spark.sparkContext cls.sql = SQLContext(cls.sc) @@ -96,13 +99,14 @@ def tear_down_env(cls): class SparkTestCase(TestSparkContext, TestTempDir, unittest.TestCase): - @classmethod def setUpClass(cls): - cls.setup_env({ - 'spark.master': 'local[2]', - 'spark.python.worker.reuse': 'false', - }) + cls.setup_env( + { + "spark.master": "local[2]", + "spark.python.worker.reuse": "false", + } + ) @classmethod def tearDownClass(cls): @@ -110,20 +114,21 @@ def tearDownClass(cls): class SparkLocalClusterTestCase(TestSparkContext, TestTempDir, unittest.TestCase): - @classmethod def setUpClass(cls): - cls.setup_env({ - 'spark.master': 'local-cluster[2, 2, 1024]', - 'spark.python.worker.reuse': 'false', - 'spark.cores.max': '4', - 'spark.task.cpus': '1', - 'spark.executor.cores': '2', - 'spark.worker.resource.gpu.amount': '4', - 'spark.task.resource.gpu.amount': '2', - 'spark.executor.resource.gpu.amount': '4', - 'spark.worker.resource.gpu.discoveryScript': 'test_spark/discover_gpu.sh' - }) + cls.setup_env( + { + "spark.master": "local-cluster[2, 2, 1024]", + "spark.python.worker.reuse": "false", + "spark.cores.max": "4", + "spark.task.cpus": "1", + "spark.executor.cores": "2", + "spark.worker.resource.gpu.amount": "4", + "spark.task.resource.gpu.amount": "2", + "spark.executor.resource.gpu.amount": "4", + "spark.worker.resource.gpu.discoveryScript": "test_spark/discover_gpu.sh", + } + ) cls.make_tempdir() # We run a dummy job so that we block until the workers have connected to the master cls.sc.parallelize(range(4), 4).barrier().mapPartitions(lambda _: []).collect() diff --git a/tests/python/test_spark/xgboost_local_cluster_test.py b/tests/python/test_spark/xgboost_local_cluster_test.py index fb25f3bbf3a3..7b4b4643a4aa 100644 --- 
a/tests/python/test_spark/xgboost_local_cluster_test.py +++ b/tests/python/test_spark/xgboost_local_cluster_test.py @@ -13,7 +13,6 @@ class XgboostLocalClusterTestCase(SparkLocalClusterTestCase): - def setUp(self): random.seed(2020) @@ -38,96 +37,178 @@ def setUp(self): # >>> reg2.fit(X, y) # >>> reg2.predict(X, ntree_limit=5) # array([0.22185263, 0.77814734], dtype=float32) - self.reg_params = {'max_depth': 5, 'n_estimators': 10, 'ntree_limit': 5} - self.reg_df_train = self.session.createDataFrame([ - (Vectors.dense(1.0, 2.0, 3.0), 0), - (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1) - ], ["features", "label"]) - self.reg_df_test = self.session.createDataFrame([ - (Vectors.dense(1.0, 2.0, 3.0), 0.0, 0.2219, 0.02406), - (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1.0, 0.7781, 0.9759) - ], ["features", "expected_prediction", "expected_prediction_with_params", - "expected_prediction_with_callbacks"]) + self.reg_params = {"max_depth": 5, "n_estimators": 10, "ntree_limit": 5} + self.reg_df_train = self.session.createDataFrame( + [ + (Vectors.dense(1.0, 2.0, 3.0), 0), + (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1), + ], + ["features", "label"], + ) + self.reg_df_test = self.session.createDataFrame( + [ + (Vectors.dense(1.0, 2.0, 3.0), 0.0, 0.2219, 0.02406), + (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1.0, 0.7781, 0.9759), + ], + [ + "features", + "expected_prediction", + "expected_prediction_with_params", + "expected_prediction_with_callbacks", + ], + ) # Distributed section # Binary classification - self.cls_df_train_distributed = self.session.createDataFrame([ + self.cls_df_train_distributed = self.session.createDataFrame( + [ (Vectors.dense(1.0, 2.0, 3.0), 0), (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1), (Vectors.dense(4.0, 5.0, 6.0), 0), (Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1), - ] * 100, ["features", "label"]) - self.cls_df_test_distributed = self.session.createDataFrame([ - (Vectors.dense(1.0, 2.0, 3.0), 0, [0.9949826, 0.0050174]), - (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1, [0.0050174, 0.9949826]), - (Vectors.dense(4.0, 5.0, 6.0), 0, [0.9949826, 0.0050174]), - (Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, [0.0050174, 0.9949826]), - ], ["features", "expected_label", "expected_probability"]) + ] + * 100, + ["features", "label"], + ) + self.cls_df_test_distributed = self.session.createDataFrame( + [ + (Vectors.dense(1.0, 2.0, 3.0), 0, [0.9949826, 0.0050174]), + (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1, [0.0050174, 0.9949826]), + (Vectors.dense(4.0, 5.0, 6.0), 0, [0.9949826, 0.0050174]), + (Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, [0.0050174, 0.9949826]), + ], + ["features", "expected_label", "expected_probability"], + ) # Binary classification with different num_estimators - self.cls_df_test_distributed_lower_estimators = self.session.createDataFrame([ - (Vectors.dense(1.0, 2.0, 3.0), 0, [0.9735, 0.0265]), - (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1, [0.0265, 0.9735]), - (Vectors.dense(4.0, 5.0, 6.0), 0, [0.9735, 0.0265]), - (Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, [0.0265 , 0.9735]), - ], ["features", "expected_label", "expected_probability"]) + self.cls_df_test_distributed_lower_estimators = self.session.createDataFrame( + [ + (Vectors.dense(1.0, 2.0, 3.0), 0, [0.9735, 0.0265]), + (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1, [0.0265, 0.9735]), + (Vectors.dense(4.0, 5.0, 6.0), 0, [0.9735, 0.0265]), + (Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, [0.0265, 0.9735]), + ], + ["features", "expected_label", "expected_probability"], + ) # Multiclass classification - self.cls_df_train_distributed_multiclass = 
self.session.createDataFrame([ + self.cls_df_train_distributed_multiclass = self.session.createDataFrame( + [ (Vectors.dense(1.0, 2.0, 3.0), 0), (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1), (Vectors.dense(4.0, 5.0, 6.0), 0), (Vectors.sparse(3, {1: 6.0, 2: 7.5}), 2), - ] * 100, ["features", "label"]) - self.cls_df_test_distributed_multiclass = self.session.createDataFrame([ - (Vectors.dense(1.0, 2.0, 3.0), 0, [ 4.294563, -2.449409, -2.449409 ]), - (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1, [-2.3796105, 3.669014, -2.449409 ]), - (Vectors.dense(4.0, 5.0, 6.0), 0, [ 4.294563, -2.449409, -2.449409 ]), - (Vectors.sparse(3, {1: 6.0, 2: 7.5}), 2, [-2.3796105, -2.449409, 3.669014 ]), - ], ["features", "expected_label", "expected_margins"]) + ] + * 100, + ["features", "label"], + ) + self.cls_df_test_distributed_multiclass = self.session.createDataFrame( + [ + (Vectors.dense(1.0, 2.0, 3.0), 0, [4.294563, -2.449409, -2.449409]), + ( + Vectors.sparse(3, {1: 1.0, 2: 5.5}), + 1, + [-2.3796105, 3.669014, -2.449409], + ), + (Vectors.dense(4.0, 5.0, 6.0), 0, [4.294563, -2.449409, -2.449409]), + ( + Vectors.sparse(3, {1: 6.0, 2: 7.5}), + 2, + [-2.3796105, -2.449409, 3.669014], + ), + ], + ["features", "expected_label", "expected_margins"], + ) # Regression - self.reg_df_train_distributed = self.session.createDataFrame([ + self.reg_df_train_distributed = self.session.createDataFrame( + [ (Vectors.dense(1.0, 2.0, 3.0), 0), (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1), (Vectors.dense(4.0, 5.0, 6.0), 0), (Vectors.sparse(3, {1: 6.0, 2: 7.5}), 2), - ] * 100, ["features", "label"]) - self.reg_df_test_distributed = self.session.createDataFrame([ - (Vectors.dense(1.0, 2.0, 3.0), 1.533e-04), - (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 9.999e-01), - (Vectors.dense(4.0, 5.0, 6.0), 1.533e-04), - (Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1.999e+00), - ], ["features", "expected_label"]) + ] + * 100, + ["features", "label"], + ) + self.reg_df_test_distributed = self.session.createDataFrame( + [ + (Vectors.dense(1.0, 2.0, 3.0), 1.533e-04), + (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 9.999e-01), + (Vectors.dense(4.0, 5.0, 6.0), 1.533e-04), + (Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1.999e00), + ], + ["features", "expected_label"], + ) # Adding weight and validation - self.clf_params_with_eval_dist = {'validationIndicatorCol': 'isVal','early_stopping_rounds': 1, 'eval_metric': 'logloss'} - self.clf_params_with_weight_dist = {'weightCol': 'weight'} - self.cls_df_train_distributed_with_eval_weight = self.session.createDataFrame([ + self.clf_params_with_eval_dist = { + "validationIndicatorCol": "isVal", + "early_stopping_rounds": 1, + "eval_metric": "logloss", + } + self.clf_params_with_weight_dist = {"weightCol": "weight"} + self.cls_df_train_distributed_with_eval_weight = self.session.createDataFrame( + [ (Vectors.dense(1.0, 2.0, 3.0), 0, False, 1.0), (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1, False, 2.0), (Vectors.dense(4.0, 5.0, 6.0), 0, True, 1.0), (Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, True, 2.0), - ] * 100, ["features", "label", "isVal", "weight"]) - self.cls_df_test_distributed_with_eval_weight = self.session.createDataFrame([ - (Vectors.dense(1.0, 2.0, 3.0), [0.9955, 0.0044], [0.9904, 0.0096], [0.9903, 0.0097]), - ], ["features", "expected_prob_with_weight", "expected_prob_with_eval", - "expected_prob_with_weight_and_eval"]) + ] + * 100, + ["features", "label", "isVal", "weight"], + ) + self.cls_df_test_distributed_with_eval_weight = self.session.createDataFrame( + [ + ( + Vectors.dense(1.0, 2.0, 3.0), + [0.9955, 0.0044], + [0.9904, 0.0096], + 
[0.9903, 0.0097], + ), + ], + [ + "features", + "expected_prob_with_weight", + "expected_prob_with_eval", + "expected_prob_with_weight_and_eval", + ], + ) self.clf_best_score_eval = 0.009677 self.clf_best_score_weight_and_eval = 0.006628 - self.reg_params_with_eval_dist = {'validationIndicatorCol': 'isVal','early_stopping_rounds': 1, 'eval_metric': 'rmse'} - self.reg_params_with_weight_dist = {'weightCol': 'weight'} - self.reg_df_train_distributed_with_eval_weight = self.session.createDataFrame([ + self.reg_params_with_eval_dist = { + "validationIndicatorCol": "isVal", + "early_stopping_rounds": 1, + "eval_metric": "rmse", + } + self.reg_params_with_weight_dist = {"weightCol": "weight"} + self.reg_df_train_distributed_with_eval_weight = self.session.createDataFrame( + [ (Vectors.dense(1.0, 2.0, 3.0), 0, False, 1.0), (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1, False, 2.0), (Vectors.dense(4.0, 5.0, 6.0), 0, True, 1.0), (Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, True, 2.0), - ] * 100, ["features", "label", "isVal", "weight"]) - self.reg_df_test_distributed_with_eval_weight = self.session.createDataFrame([ - (Vectors.dense(1.0, 2.0, 3.0), 4.583e-05, 5.239e-05, 6.03e-05), - (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 9.9997e-01, 9.99947e-01, 9.9995e-01) - ], ["features", "expected_prediction_with_weight", "expected_prediction_with_eval", - "expected_prediction_with_weight_and_eval"]) + ] + * 100, + ["features", "label", "isVal", "weight"], + ) + self.reg_df_test_distributed_with_eval_weight = self.session.createDataFrame( + [ + (Vectors.dense(1.0, 2.0, 3.0), 4.583e-05, 5.239e-05, 6.03e-05), + ( + Vectors.sparse(3, {1: 1.0, 2: 5.5}), + 9.9997e-01, + 9.99947e-01, + 9.9995e-01, + ), + ], + [ + "features", + "expected_prediction_with_weight", + "expected_prediction_with_eval", + "expected_prediction_with_weight_and_eval", + ], + ) self.reg_best_score_eval = 5.2e-05 self.reg_best_score_weight_and_eval = 4.9e-05 @@ -137,12 +218,14 @@ def test_regressor_basic_with_params(self): pred_result = model.transform(self.reg_df_test).collect() for row in pred_result: self.assertTrue( - np.isclose(row.prediction, - row.expected_prediction_with_params, atol=1e-3) + np.isclose( + row.prediction, row.expected_prediction_with_params, atol=1e-3 + ) ) def test_callbacks(self): from xgboost.callback import LearningRateScheduler + path = os.path.join(self.tempdir, str(uuid.uuid4())) def custom_learning_rate(boosting_round): @@ -160,8 +243,9 @@ def custom_learning_rate(boosting_round): pred_result = model.transform(self.reg_df_test).collect() for row in pred_result: self.assertTrue( - np.isclose(row.prediction, - row.expected_prediction_with_callbacks, atol=1e-3) + np.isclose( + row.prediction, row.expected_prediction_with_callbacks, atol=1e-3 + ) ) def test_classifier_distributed_basic(self): @@ -169,9 +253,10 @@ def test_classifier_distributed_basic(self): model = classifier.fit(self.cls_df_train_distributed) pred_result = model.transform(self.cls_df_test_distributed).collect() for row in pred_result: - self.assertTrue(np.isclose(row.expected_label, - row.prediction, atol=1e-3)) - self.assertTrue(np.allclose(row.expected_probability, row.probability, atol=1e-3)) + self.assertTrue(np.isclose(row.expected_label, row.prediction, atol=1e-3)) + self.assertTrue( + np.allclose(row.expected_probability, row.probability, atol=1e-3) + ) def test_classifier_distributed_multiclass(self): # There is no built-in multiclass option for external storage @@ -179,9 +264,10 @@ def test_classifier_distributed_multiclass(self): model = 
classifier.fit(self.cls_df_train_distributed_multiclass) pred_result = model.transform(self.cls_df_test_distributed_multiclass).collect() for row in pred_result: - self.assertTrue(np.isclose(row.expected_label, - row.prediction, atol=1e-3)) - self.assertTrue(np.allclose(row.expected_margins, row.rawPrediction, atol=1e-3)) + self.assertTrue(np.isclose(row.expected_label, row.prediction, atol=1e-3)) + self.assertTrue( + np.allclose(row.expected_margins, row.rawPrediction, atol=1e-3) + ) def test_regressor_distributed_basic(self): regressor = XgboostRegressor(num_workers=self.n_workers, n_estimators=100) @@ -193,18 +279,23 @@ def test_regressor_distributed_basic(self): @unittest.skip def test_check_use_gpu_param(self): # Classifier - classifier = XgboostClassifier(num_workers=self.n_workers, n_estimators=100, use_gpu=True) - self.assertTrue(hasattr(classifier, 'use_gpu')) + classifier = XgboostClassifier( + num_workers=self.n_workers, n_estimators=100, use_gpu=True + ) + self.assertTrue(hasattr(classifier, "use_gpu")) self.assertTrue(classifier.getOrDefault(classifier.use_gpu)) clf_model = classifier.fit(self.cls_df_train_distributed) pred_result = clf_model.transform(self.cls_df_test_distributed).collect() for row in pred_result: - self.assertTrue(np.isclose(row.expected_label, - row.prediction, atol=1e-3)) - self.assertTrue(np.allclose(row.expected_probability, row.probability, atol=1e-3)) - - regressor = XgboostRegressor(num_workers=self.n_workers, n_estimators=100, use_gpu=True) - self.assertTrue(hasattr(regressor, 'use_gpu')) + self.assertTrue(np.isclose(row.expected_label, row.prediction, atol=1e-3)) + self.assertTrue( + np.allclose(row.expected_probability, row.probability, atol=1e-3) + ) + + regressor = XgboostRegressor( + num_workers=self.n_workers, n_estimators=100, use_gpu=True + ) + self.assertTrue(hasattr(regressor, "use_gpu")) self.assertTrue(regressor.getOrDefault(regressor.use_gpu)) model = regressor.fit(self.reg_df_train_distributed) pred_result = model.transform(self.reg_df_test_distributed).collect() @@ -213,76 +304,143 @@ def test_check_use_gpu_param(self): def test_classifier_distributed_weight_eval(self): # with weight - classifier = XgboostClassifier(num_workers=self.n_workers, n_estimators=100, **self.clf_params_with_weight_dist) + classifier = XgboostClassifier( + num_workers=self.n_workers, + n_estimators=100, + **self.clf_params_with_weight_dist + ) model = classifier.fit(self.cls_df_train_distributed_with_eval_weight) - pred_result = model.transform(self.cls_df_test_distributed_with_eval_weight).collect() + pred_result = model.transform( + self.cls_df_test_distributed_with_eval_weight + ).collect() for row in pred_result: - self.assertTrue(np.allclose(row.probability, - row.expected_prob_with_weight, atol=1e-3)) + self.assertTrue( + np.allclose(row.probability, row.expected_prob_with_weight, atol=1e-3) + ) # with eval only - classifier = XgboostClassifier(num_workers=self.n_workers, n_estimators=100, **self.clf_params_with_eval_dist) + classifier = XgboostClassifier( + num_workers=self.n_workers, + n_estimators=100, + **self.clf_params_with_eval_dist + ) model = classifier.fit(self.cls_df_train_distributed_with_eval_weight) - pred_result = model.transform(self.cls_df_test_distributed_with_eval_weight).collect() + pred_result = model.transform( + self.cls_df_test_distributed_with_eval_weight + ).collect() for row in pred_result: - self.assertTrue(np.allclose(row.probability, - row.expected_prob_with_eval, atol=1e-3)) - 
self.assertEqual(float(model.get_booster().attributes()["best_score"]), self.clf_best_score_eval) + self.assertTrue( + np.allclose(row.probability, row.expected_prob_with_eval, atol=1e-3) + ) + self.assertEqual( + float(model.get_booster().attributes()["best_score"]), + self.clf_best_score_eval, + ) # with both weight and eval - classifier = XgboostClassifier(num_workers=self.n_workers, n_estimators=100, **self.clf_params_with_eval_dist, **self.clf_params_with_weight_dist) + classifier = XgboostClassifier( + num_workers=self.n_workers, + n_estimators=100, + **self.clf_params_with_eval_dist, + **self.clf_params_with_weight_dist + ) model = classifier.fit(self.cls_df_train_distributed_with_eval_weight) - pred_result = model.transform(self.cls_df_test_distributed_with_eval_weight).collect() + pred_result = model.transform( + self.cls_df_test_distributed_with_eval_weight + ).collect() for row in pred_result: - self.assertTrue(np.allclose(row.probability, - row.expected_prob_with_weight_and_eval, atol=1e-3)) - self.assertEqual(float(model.get_booster().attributes()["best_score"]), self.clf_best_score_weight_and_eval) + self.assertTrue( + np.allclose( + row.probability, row.expected_prob_with_weight_and_eval, atol=1e-3 + ) + ) + self.assertEqual( + float(model.get_booster().attributes()["best_score"]), + self.clf_best_score_weight_and_eval, + ) def test_regressor_distributed_weight_eval(self): # with weight - regressor = XgboostRegressor(num_workers=self.n_workers, n_estimators=100, **self.reg_params_with_weight_dist) + regressor = XgboostRegressor( + num_workers=self.n_workers, + n_estimators=100, + **self.reg_params_with_weight_dist + ) model = regressor.fit(self.reg_df_train_distributed_with_eval_weight) - pred_result = model.transform(self.reg_df_test_distributed_with_eval_weight).collect() + pred_result = model.transform( + self.reg_df_test_distributed_with_eval_weight + ).collect() for row in pred_result: self.assertTrue( - np.isclose(row.prediction, - row.expected_prediction_with_weight, atol=1e-3)) + np.isclose( + row.prediction, row.expected_prediction_with_weight, atol=1e-3 + ) + ) # with eval only - regressor = XgboostRegressor(num_workers=self.n_workers, n_estimators=100, **self.reg_params_with_eval_dist) + regressor = XgboostRegressor( + num_workers=self.n_workers, + n_estimators=100, + **self.reg_params_with_eval_dist + ) model = regressor.fit(self.reg_df_train_distributed_with_eval_weight) - pred_result = model.transform(self.reg_df_test_distributed_with_eval_weight).collect() + pred_result = model.transform( + self.reg_df_test_distributed_with_eval_weight + ).collect() for row in pred_result: self.assertTrue( - np.isclose(row.prediction, - row.expected_prediction_with_eval, atol=1e-3)) - self.assertEqual(float(model.get_booster().attributes()["best_score"]), self.reg_best_score_eval) + np.isclose(row.prediction, row.expected_prediction_with_eval, atol=1e-3) + ) + self.assertEqual( + float(model.get_booster().attributes()["best_score"]), + self.reg_best_score_eval, + ) # with both weight and eval - regressor = XgboostRegressor(num_workers=self.n_workers, n_estimators=100, use_external_storage=False, **self.reg_params_with_eval_dist, **self.reg_params_with_weight_dist) + regressor = XgboostRegressor( + num_workers=self.n_workers, + n_estimators=100, + use_external_storage=False, + **self.reg_params_with_eval_dist, + **self.reg_params_with_weight_dist + ) model = regressor.fit(self.reg_df_train_distributed_with_eval_weight) - pred_result = 
model.transform(self.reg_df_test_distributed_with_eval_weight).collect() + pred_result = model.transform( + self.reg_df_test_distributed_with_eval_weight + ).collect() for row in pred_result: self.assertTrue( - np.isclose(row.prediction, - row.expected_prediction_with_weight_and_eval, atol=1e-3)) - self.assertEqual(float(model.get_booster().attributes()["best_score"]), self.reg_best_score_weight_and_eval) + np.isclose( + row.prediction, + row.expected_prediction_with_weight_and_eval, + atol=1e-3, + ) + ) + self.assertEqual( + float(model.get_booster().attributes()["best_score"]), + self.reg_best_score_weight_and_eval, + ) def test_num_estimators(self): classifier = XgboostClassifier(num_workers=self.n_workers, n_estimators=10) model = classifier.fit(self.cls_df_train_distributed) - pred_result = model.transform(self.cls_df_test_distributed_lower_estimators).collect() + pred_result = model.transform( + self.cls_df_test_distributed_lower_estimators + ).collect() print(pred_result) for row in pred_result: - self.assertTrue(np.isclose(row.expected_label, - row.prediction, atol=1e-3)) - self.assertTrue(np.allclose(row.expected_probability, row.probability, atol=1e-3)) + self.assertTrue(np.isclose(row.expected_label, row.prediction, atol=1e-3)) + self.assertTrue( + np.allclose(row.expected_probability, row.probability, atol=1e-3) + ) def test_distributed_params(self): classifier = XgboostClassifier(num_workers=self.n_workers, max_depth=7) model = classifier.fit(self.cls_df_train_distributed) - self.assertTrue(hasattr(classifier, 'max_depth')) + self.assertTrue(hasattr(classifier, "max_depth")) self.assertEqual(classifier.getOrDefault(classifier.max_depth), 7) booster_config = json.loads(model.get_booster().save_config()) - max_depth = booster_config["learner"]["gradient_booster"]["updater"]["grow_histmaker"]["train_param"]["max_depth"] + max_depth = booster_config["learner"]["gradient_booster"]["updater"][ + "grow_histmaker" + ]["train_param"]["max_depth"] self.assertEqual(int(max_depth), 7) def test_repartition(self): @@ -300,6 +458,8 @@ def test_repartition(self): self.assertFalse(classifier._repartition_needed(good_repartitioned)) # Now testing if force_repartition returns True regardless of whether the data is well partitioned - classifier = XgboostClassifier(num_workers=self.n_workers, force_repartition=True) + classifier = XgboostClassifier( + num_workers=self.n_workers, force_repartition=True + ) good_repartitioned = basic.repartition(self.n_workers) self.assertTrue(classifier._repartition_needed(good_repartitioned)) diff --git a/tests/python/test_spark/xgboost_local_test.py b/tests/python/test_spark/xgboost_local_test.py index fb8673e72cf4..6d8d8b1eb791 100644 --- a/tests/python/test_spark/xgboost_local_test.py +++ b/tests/python/test_spark/xgboost_local_test.py @@ -4,13 +4,19 @@ import numpy as np from pyspark.ml import Pipeline, PipelineModel -from pyspark.ml.evaluation import BinaryClassificationEvaluator, \ - MulticlassClassificationEvaluator +from pyspark.ml.evaluation import ( + BinaryClassificationEvaluator, + MulticlassClassificationEvaluator, +) from pyspark.ml.linalg import Vectors from pyspark.ml.tuning import CrossValidator, ParamGridBuilder -from xgboost.spark import (XgboostClassifier, XgboostClassifierModel, - XgboostRegressor, XgboostRegressorModel) +from xgboost.spark import ( + XgboostClassifier, + XgboostClassifierModel, + XgboostRegressor, + XgboostRegressorModel, +) from .utils_test import SparkTestCase from xgboost import XGBClassifier, XGBRegressor @@ -18,9 +24,8 
@@ class XgboostLocalTest(SparkTestCase): - def setUp(self): - logging.getLogger().setLevel('INFO') + logging.getLogger().setLevel("INFO") random.seed(2020) # The following code use xgboost python library to train xgb model and predict. @@ -43,16 +48,26 @@ def setUp(self): # >>> reg2.fit(X, y) # >>> reg2.predict(X, ntree_limit=5) # array([0.22185266, 0.77814734], dtype=float32) - self.reg_params = {'max_depth': 5, 'n_estimators': 10, 'ntree_limit': 5} - self.reg_df_train = self.session.createDataFrame([ + self.reg_params = {"max_depth": 5, "n_estimators": 10, "ntree_limit": 5} + self.reg_df_train = self.session.createDataFrame( + [ (Vectors.dense(1.0, 2.0, 3.0), 0), - (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1) - ], ["features", "label"]) - self.reg_df_test = self.session.createDataFrame([ - (Vectors.dense(1.0, 2.0, 3.0), 0.0, 0.2219, 0.02406), - (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1.0, 0.7781, 0.9759) - ], ["features", "expected_prediction", "expected_prediction_with_params", - "expected_prediction_with_callbacks"]) + (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1), + ], + ["features", "label"], + ) + self.reg_df_test = self.session.createDataFrame( + [ + (Vectors.dense(1.0, 2.0, 3.0), 0.0, 0.2219, 0.02406), + (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1.0, 0.7781, 0.9759), + ], + [ + "features", + "expected_prediction", + "expected_prediction_with_params", + "expected_prediction_with_callbacks", + ], + ) # >>> X = np.array([[1.0, 2.0, 3.0], [0.0, 1.0, 5.5]]) # >>> y = np.array([0, 1]) @@ -70,25 +85,46 @@ def setUp(self): # >>> cl2.predict_proba(X) # array([[0.27574146, 0.72425854 ], # [0.27574146, 0.72425854 ]], dtype=float32) - self.cls_params = {'max_depth': 5, 'n_estimators': 10, 'scale_pos_weight': 4} + self.cls_params = {"max_depth": 5, "n_estimators": 10, "scale_pos_weight": 4} cls_df_train_data = [ (Vectors.dense(1.0, 2.0, 3.0), 0), - (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1) + (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1), ] self.cls_df_train = self.session.createDataFrame( - cls_df_train_data, ["features", "label"]) + cls_df_train_data, ["features", "label"] + ) self.cls_df_train_large = self.session.createDataFrame( - cls_df_train_data * 100, ["features", "label"]) - self.cls_df_test = self.session.createDataFrame([ - (Vectors.dense(1.0, 2.0, 3.0), 0, [0.5, 0.5], 1, [0.27574146, 0.72425854]), - (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 0, [0.5, 0.5], 1, [0.27574146, 0.72425854]) - ], ["features", - "expected_prediction", "expected_probability", - "expected_prediction_with_params", "expected_probability_with_params"]) + cls_df_train_data * 100, ["features", "label"] + ) + self.cls_df_test = self.session.createDataFrame( + [ + ( + Vectors.dense(1.0, 2.0, 3.0), + 0, + [0.5, 0.5], + 1, + [0.27574146, 0.72425854], + ), + ( + Vectors.sparse(3, {1: 1.0, 2: 5.5}), + 0, + [0.5, 0.5], + 1, + [0.27574146, 0.72425854], + ), + ], + [ + "features", + "expected_prediction", + "expected_probability", + "expected_prediction_with_params", + "expected_probability_with_params", + ], + ) # kwargs test (using the above data, train, we get the same results) - self.cls_params_kwargs = {'tree_method': 'approx', 'sketch_eps':0.03} + self.cls_params_kwargs = {"tree_method": "approx", "sketch_eps": 0.03} # >>> X = np.array([[1.0, 2.0, 3.0], [1.0, 2.0, 4.0], [0.0, 1.0, 5.5], [-1.0, -2.0, 1.0]]) # >>> y = np.array([0, 0, 1, 2]) @@ -103,12 +139,17 @@ def setUp(self): (Vectors.dense(-1.0, -2.0, 1.0), 2), ] self.multi_cls_df_train = self.session.createDataFrame( - multi_cls_df_train_data, ["features", "label"]) + 
multi_cls_df_train_data, ["features", "label"] + ) self.multi_cls_df_train_large = self.session.createDataFrame( - multi_cls_df_train_data * 100, ["features", "label"]) - self.multi_cls_df_test = self.session.createDataFrame([ - (Vectors.dense(1.0, 2.0, 3.0), [0.5374, 0.2312, 0.2312]), - ], ["features", "expected_probability"]) + multi_cls_df_train_data * 100, ["features", "label"] + ) + self.multi_cls_df_test = self.session.createDataFrame( + [ + (Vectors.dense(1.0, 2.0, 3.0), [0.5374, 0.2312, 0.2312]), + ], + ["features", "expected_probability"], + ) # Test regressor with weight and eval set # >>> import numpy as np @@ -141,19 +182,32 @@ def setUp(self): # >>> array([0.03155671, 0.98874104,... # >>> reg3.best_score # 1.9970891552124017 - self.reg_df_train_with_eval_weight = self.session.createDataFrame([ + self.reg_df_train_with_eval_weight = self.session.createDataFrame( + [ (Vectors.dense(1.0, 2.0, 3.0), 0, False, 1.0), (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1, False, 2.0), (Vectors.dense(4.0, 5.0, 6.0), 2, True, 1.0), (Vectors.sparse(3, {1: 6.0, 2: 7.5}), 3, True, 2.0), - ], ["features", "label", "isVal", "weight"]) - self.reg_params_with_eval = {'validationIndicatorCol': 'isVal', - 'early_stopping_rounds': 1, 'eval_metric': 'rmse'} - self.reg_df_test_with_eval_weight = self.session.createDataFrame([ - (Vectors.dense(1.0, 2.0, 3.0), 0.001068, 0.00088, 0.03155), - (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1.000055, 0.9991, 0.9887) - ], ["features", "expected_prediction_with_weight", "expected_prediction_with_eval", - "expected_prediction_with_weight_and_eval"]) + ], + ["features", "label", "isVal", "weight"], + ) + self.reg_params_with_eval = { + "validationIndicatorCol": "isVal", + "early_stopping_rounds": 1, + "eval_metric": "rmse", + } + self.reg_df_test_with_eval_weight = self.session.createDataFrame( + [ + (Vectors.dense(1.0, 2.0, 3.0), 0.001068, 0.00088, 0.03155), + (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1.000055, 0.9991, 0.9887), + ], + [ + "features", + "expected_prediction_with_weight", + "expected_prediction_with_eval", + "expected_prediction_with_weight_and_eval", + ], + ) self.reg_with_eval_best_score = 2.0 self.reg_with_eval_and_weight_best_score = 1.997 @@ -188,18 +242,36 @@ def setUp(self): # array([[0.3344962, 0.6655038],... 
# >>> cls3.best_score # 0.6365 - self.cls_df_train_with_eval_weight = self.session.createDataFrame([ + self.cls_df_train_with_eval_weight = self.session.createDataFrame( + [ (Vectors.dense(1.0, 2.0, 3.0), 0, False, 1.0), (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1, False, 2.0), (Vectors.dense(4.0, 5.0, 6.0), 0, True, 1.0), (Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, True, 2.0), - ], ["features", "label", "isVal", "weight"]) - self.cls_params_with_eval = {'validationIndicatorCol': 'isVal', - 'early_stopping_rounds': 1, 'eval_metric': 'logloss'} - self.cls_df_test_with_eval_weight = self.session.createDataFrame([ - (Vectors.dense(1.0, 2.0, 3.0), [0.3333, 0.6666], [0.5, 0.5], [0.3097, 0.6903]), - ], ["features", "expected_prob_with_weight", "expected_prob_with_eval", - "expected_prob_with_weight_and_eval"]) + ], + ["features", "label", "isVal", "weight"], + ) + self.cls_params_with_eval = { + "validationIndicatorCol": "isVal", + "early_stopping_rounds": 1, + "eval_metric": "logloss", + } + self.cls_df_test_with_eval_weight = self.session.createDataFrame( + [ + ( + Vectors.dense(1.0, 2.0, 3.0), + [0.3333, 0.6666], + [0.5, 0.5], + [0.3097, 0.6903], + ), + ], + [ + "features", + "expected_prob_with_weight", + "expected_prob_with_eval", + "expected_prob_with_weight_and_eval", + ], + ) self.cls_with_eval_best_score = 0.6931 self.cls_with_eval_and_weight_best_score = 0.6378 @@ -210,7 +282,7 @@ def setUp(self): # >>> w = np.array([1.0, 2.0, 1.0, 2.0]) # >>> y = np.array([0, 1, 0, 1]) # >>> base_margin = np.array([1,0,0,1]) - # + # # This is without the base margin # >>> cls1 = xgboost.XGBClassifier() # >>> cls1.fit(X, y, sample_weight=w) @@ -218,7 +290,7 @@ def setUp(self): # array([[0.3333333, 0.6666667]], dtype=float32) # >>> cls1.predict(np.array([[1.0, 2.0, 3.0]])) # array([1]) - # + # # This is with the same base margin for predict # >>> cls2 = xgboost.XGBClassifier() # >>> cls2.fit(X, y, sample_weight=w, base_margin=base_margin) @@ -234,44 +306,76 @@ def setUp(self): # array([[0.2252, 0.7747 ]], dtype=float32) # >>> cls2.predict(np.array([[1.0, 2.0, 3.0]]), base_margin=[0]) # array([1]) - self.cls_df_train_without_base_margin = self.session.createDataFrame([ + self.cls_df_train_without_base_margin = self.session.createDataFrame( + [ (Vectors.dense(1.0, 2.0, 3.0), 0, 1.0), (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1, 2.0), (Vectors.dense(4.0, 5.0, 6.0), 0, 1.0), (Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, 2.0), - ], ["features", "label", "weight"]) - self.cls_df_test_without_base_margin = self.session.createDataFrame([ - (Vectors.dense(1.0, 2.0, 3.0), [0.3333, 0.6666], 1), - ], ["features", "expected_prob_without_base_margin", "expected_prediction_without_base_margin"]) + ], + ["features", "label", "weight"], + ) + self.cls_df_test_without_base_margin = self.session.createDataFrame( + [ + (Vectors.dense(1.0, 2.0, 3.0), [0.3333, 0.6666], 1), + ], + [ + "features", + "expected_prob_without_base_margin", + "expected_prediction_without_base_margin", + ], + ) - self.cls_df_train_with_same_base_margin = self.session.createDataFrame([ + self.cls_df_train_with_same_base_margin = self.session.createDataFrame( + [ (Vectors.dense(1.0, 2.0, 3.0), 0, 1.0, 1), (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1, 2.0, 0), (Vectors.dense(4.0, 5.0, 6.0), 0, 1.0, 0), (Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, 2.0, 1), - ], ["features", "label", "weight", "baseMarginCol"]) - self.cls_df_test_with_same_base_margin = self.session.createDataFrame([ - (Vectors.dense(1.0, 2.0, 3.0), 0, [0.4415, 0.5585], 1), - ], ["features", "baseMarginCol", 
"expected_prob_with_base_margin", "expected_prediction_with_base_margin"]) + ], + ["features", "label", "weight", "baseMarginCol"], + ) + self.cls_df_test_with_same_base_margin = self.session.createDataFrame( + [ + (Vectors.dense(1.0, 2.0, 3.0), 0, [0.4415, 0.5585], 1), + ], + [ + "features", + "baseMarginCol", + "expected_prob_with_base_margin", + "expected_prediction_with_base_margin", + ], + ) - self.cls_df_train_with_different_base_margin = self.session.createDataFrame([ + self.cls_df_train_with_different_base_margin = self.session.createDataFrame( + [ (Vectors.dense(1.0, 2.0, 3.0), 0, 1.0, 1), (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1, 2.0, 0), (Vectors.dense(4.0, 5.0, 6.0), 0, 1.0, 0), (Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, 2.0, 1), - ], ["features", "label", "weight", "baseMarginCol"]) - self.cls_df_test_with_different_base_margin = self.session.createDataFrame([ - (Vectors.dense(1.0, 2.0, 3.0), 1, [0.2252, 0.7747], 1), - ], ["features", "baseMarginCol", "expected_prob_with_base_margin", "expected_prediction_with_base_margin"]) + ], + ["features", "label", "weight", "baseMarginCol"], + ) + self.cls_df_test_with_different_base_margin = self.session.createDataFrame( + [ + (Vectors.dense(1.0, 2.0, 3.0), 1, [0.2252, 0.7747], 1), + ], + [ + "features", + "baseMarginCol", + "expected_prob_with_base_margin", + "expected_prediction_with_base_margin", + ], + ) def get_local_tmp_dir(self): return "/tmp/xgboost_local_test/" + str(uuid.uuid4()) def test_regressor_params_basic(self): py_reg = XgboostRegressor() - self.assertTrue(hasattr(py_reg, 'n_estimators')) + self.assertTrue(hasattr(py_reg, "n_estimators")) self.assertEqual(py_reg.n_estimators.parent, py_reg.uid) - self.assertFalse(hasattr(py_reg, 'gpu_id')) + self.assertFalse(hasattr(py_reg, "gpu_id")) self.assertEqual(py_reg.getOrDefault(py_reg.n_estimators), 100) self.assertEqual(py_reg._get_xgb_model_creator()().n_estimators, 100) py_reg2 = XgboostRegressor(n_estimators=200) @@ -283,9 +387,9 @@ def test_regressor_params_basic(self): def test_classifier_params_basic(self): py_cls = XgboostClassifier() - self.assertTrue(hasattr(py_cls, 'n_estimators')) + self.assertTrue(hasattr(py_cls, "n_estimators")) self.assertEqual(py_cls.n_estimators.parent, py_cls.uid) - self.assertFalse(hasattr(py_cls, 'gpu_id')) + self.assertFalse(hasattr(py_cls, "gpu_id")) self.assertEqual(py_cls.getOrDefault(py_cls.n_estimators), 100) self.assertEqual(py_cls._get_xgb_model_creator()().n_estimators, 100) py_cls2 = XgboostClassifier(n_estimators=200) @@ -297,15 +401,19 @@ def test_classifier_params_basic(self): def test_classifier_kwargs_basic(self): py_cls = XgboostClassifier(**self.cls_params_kwargs) - self.assertTrue(hasattr(py_cls, 'n_estimators')) + self.assertTrue(hasattr(py_cls, "n_estimators")) self.assertEqual(py_cls.n_estimators.parent, py_cls.uid) - self.assertFalse(hasattr(py_cls, 'gpu_id')) - self.assertTrue(hasattr(py_cls, 'arbitraryParamsDict')) - expected_kwargs = {'sketch_eps':0.03} - self.assertEqual(py_cls.getOrDefault(py_cls.arbitraryParamsDict), expected_kwargs) + self.assertFalse(hasattr(py_cls, "gpu_id")) + self.assertTrue(hasattr(py_cls, "arbitraryParamsDict")) + expected_kwargs = {"sketch_eps": 0.03} + self.assertEqual( + py_cls.getOrDefault(py_cls.arbitraryParamsDict), expected_kwargs + ) self.assertTrue("sketch_eps" in py_cls._get_xgb_model_creator()().get_params()) # We want all of the new params to be in the .get_params() call and be an attribute of py_cls, but not of the actual model - self.assertTrue("arbitraryParamsDict" not in 
py_cls._get_xgb_model_creator()().get_params()) + self.assertTrue( + "arbitraryParamsDict" not in py_cls._get_xgb_model_creator()().get_params() + ) # Testing overwritten params py_cls = XgboostClassifier() @@ -322,16 +430,22 @@ def test_classifier_kwargs_basic(self): def test_param_value_converter(): py_cls = XgboostClassifier(missing=np.float64(1.0), sketch_eps=np.float64(0.3)) # don't check by isintance(v, float) because for numpy scalar it will also return True - assert py_cls.getOrDefault(py_cls.missing).__class__.__name__ == 'float' - assert py_cls.getOrDefault(py_cls.arbitraryParamsDict)['sketch_eps'].__class__.__name__ \ - == 'float64' + assert py_cls.getOrDefault(py_cls.missing).__class__.__name__ == "float" + assert ( + py_cls.getOrDefault(py_cls.arbitraryParamsDict)[ + "sketch_eps" + ].__class__.__name__ + == "float64" + ) def test_regressor_basic(self): regressor = XgboostRegressor() model = regressor.fit(self.reg_df_train) pred_result = model.transform(self.reg_df_test).collect() for row in pred_result: - self.assertTrue(np.isclose(row.prediction, row.expected_prediction, atol=1e-3)) + self.assertTrue( + np.isclose(row.prediction, row.expected_prediction, atol=1e-3) + ) def test_classifier_basic(self): classifier = XgboostClassifier() @@ -339,14 +453,18 @@ def test_classifier_basic(self): pred_result = model.transform(self.cls_df_test).collect() for row in pred_result: self.assertEqual(row.prediction, row.expected_prediction) - self.assertTrue(np.allclose(row.probability, row.expected_probability, rtol=1e-3)) + self.assertTrue( + np.allclose(row.probability, row.expected_probability, rtol=1e-3) + ) def test_multi_classifier(self): classifier = XgboostClassifier() model = classifier.fit(self.multi_cls_df_train) pred_result = model.transform(self.multi_cls_df_test).collect() for row in pred_result: - self.assertTrue(np.allclose(row.probability, row.expected_probability, rtol=1e-3)) + self.assertTrue( + np.allclose(row.probability, row.expected_probability, rtol=1e-3) + ) def _check_sub_dict_match(self, sub_dist, whole_dict): for k in sub_dist: @@ -355,42 +473,55 @@ def _check_sub_dict_match(self, sub_dist, whole_dict): def test_regressor_with_params(self): regressor = XgboostRegressor(**self.reg_params) - all_params = dict(**(regressor._gen_xgb_params_dict()), - **(regressor._gen_fit_params_dict()), - **(regressor._gen_predict_params_dict())) + all_params = dict( + **(regressor._gen_xgb_params_dict()), + **(regressor._gen_fit_params_dict()), + **(regressor._gen_predict_params_dict()), + ) self._check_sub_dict_match(self.reg_params, all_params) model = regressor.fit(self.reg_df_train) - all_params = dict(**(model._gen_xgb_params_dict()), - **(model._gen_fit_params_dict()), - **(model._gen_predict_params_dict())) + all_params = dict( + **(model._gen_xgb_params_dict()), + **(model._gen_fit_params_dict()), + **(model._gen_predict_params_dict()), + ) self._check_sub_dict_match(self.reg_params, all_params) pred_result = model.transform(self.reg_df_test).collect() for row in pred_result: self.assertTrue( - np.isclose(row.prediction, - row.expected_prediction_with_params, atol=1e-3) + np.isclose( + row.prediction, row.expected_prediction_with_params, atol=1e-3 + ) ) def test_classifier_with_params(self): classifier = XgboostClassifier(**self.cls_params) - all_params = dict(**(classifier._gen_xgb_params_dict()), - **(classifier._gen_fit_params_dict()), - **(classifier._gen_predict_params_dict())) + all_params = dict( + **(classifier._gen_xgb_params_dict()), + 
**(classifier._gen_fit_params_dict()), + **(classifier._gen_predict_params_dict()), + ) self._check_sub_dict_match(self.cls_params, all_params) model = classifier.fit(self.cls_df_train) - all_params = dict(**(model._gen_xgb_params_dict()), - **(model._gen_fit_params_dict()), - **(model._gen_predict_params_dict())) + all_params = dict( + **(model._gen_xgb_params_dict()), + **(model._gen_fit_params_dict()), + **(model._gen_predict_params_dict()), + ) self._check_sub_dict_match(self.cls_params, all_params) pred_result = model.transform(self.cls_df_test).collect() for row in pred_result: self.assertEqual(row.prediction, row.expected_prediction_with_params) - self.assertTrue(np.allclose(row.probability, row.expected_probability_with_params, rtol=1e-3)) + self.assertTrue( + np.allclose( + row.probability, row.expected_probability_with_params, rtol=1e-3 + ) + ) def test_regressor_model_save_load(self): - path = 'file:' + self.get_local_tmp_dir() + path = "file:" + self.get_local_tmp_dir() regressor = XgboostRegressor(**self.reg_params) model = regressor.fit(self.reg_df_train) model.save(path) @@ -402,13 +533,16 @@ def test_regressor_model_save_load(self): pred_result = loaded_model.transform(self.reg_df_test).collect() for row in pred_result: self.assertTrue( - np.isclose(row.prediction, row.expected_prediction_with_params, atol=1e-3)) + np.isclose( + row.prediction, row.expected_prediction_with_params, atol=1e-3 + ) + ) - with self.assertRaisesRegex(AssertionError, 'Expected class name'): + with self.assertRaisesRegex(AssertionError, "Expected class name"): XgboostClassifierModel.load(path) def test_classifier_model_save_load(self): - path = 'file:' + self.get_local_tmp_dir() + path = "file:" + self.get_local_tmp_dir() regressor = XgboostClassifier(**self.cls_params) model = regressor.fit(self.cls_df_train) model.save(path) @@ -420,9 +554,12 @@ def test_classifier_model_save_load(self): pred_result = loaded_model.transform(self.cls_df_test).collect() for row in pred_result: self.assertTrue( - np.allclose(row.probability, row.expected_probability_with_params, atol=1e-3)) + np.allclose( + row.probability, row.expected_probability_with_params, atol=1e-3 + ) + ) - with self.assertRaisesRegex(AssertionError, 'Expected class name'): + with self.assertRaisesRegex(AssertionError, "Expected class name"): XgboostRegressorModel.load(path) @staticmethod @@ -430,7 +567,7 @@ def _get_params_map(params_kv, estimator): return {getattr(estimator, k): v for k, v in params_kv.items()} def test_regressor_model_pipeline_save_load(self): - path = 'file:' + self.get_local_tmp_dir() + path = "file:" + self.get_local_tmp_dir() regressor = XgboostRegressor() pipeline = Pipeline(stages=[regressor]) pipeline = pipeline.copy(extra=self._get_params_map(self.reg_params, regressor)) @@ -444,13 +581,18 @@ def test_regressor_model_pipeline_save_load(self): pred_result = loaded_model.transform(self.reg_df_test).collect() for row in pred_result: self.assertTrue( - np.isclose(row.prediction, row.expected_prediction_with_params, atol=1e-3)) + np.isclose( + row.prediction, row.expected_prediction_with_params, atol=1e-3 + ) + ) def test_classifier_model_pipeline_save_load(self): - path = 'file:' + self.get_local_tmp_dir() + path = "file:" + self.get_local_tmp_dir() classifier = XgboostClassifier() pipeline = Pipeline(stages=[classifier]) - pipeline = pipeline.copy(extra=self._get_params_map(self.cls_params, classifier)) + pipeline = pipeline.copy( + extra=self._get_params_map(self.cls_params, classifier) + ) model = 
pipeline.fit(self.cls_df_train) model.save(path) @@ -461,17 +603,28 @@ def test_classifier_model_pipeline_save_load(self): pred_result = loaded_model.transform(self.cls_df_test).collect() for row in pred_result: self.assertTrue( - np.allclose(row.probability, row.expected_probability_with_params, atol=1e-3)) + np.allclose( + row.probability, row.expected_probability_with_params, atol=1e-3 + ) + ) def test_classifier_with_cross_validator(self): xgb_classifer = XgboostClassifier() paramMaps = ParamGridBuilder().addGrid(xgb_classifer.max_depth, [1, 2]).build() - cvBin = CrossValidator(estimator=xgb_classifer, estimatorParamMaps=paramMaps, - evaluator=BinaryClassificationEvaluator(), seed=1) + cvBin = CrossValidator( + estimator=xgb_classifer, + estimatorParamMaps=paramMaps, + evaluator=BinaryClassificationEvaluator(), + seed=1, + ) cvBinModel = cvBin.fit(self.cls_df_train_large) cvBinModel.transform(self.cls_df_test) - cvMulti = CrossValidator(estimator=xgb_classifer, estimatorParamMaps=paramMaps, - evaluator=MulticlassClassificationEvaluator(), seed=1) + cvMulti = CrossValidator( + estimator=xgb_classifer, + estimatorParamMaps=paramMaps, + evaluator=MulticlassClassificationEvaluator(), + seed=1, + ) cvMultiModel = cvMulti.fit(self.multi_cls_df_train_large) cvMultiModel.transform(self.multi_cls_df_test) @@ -495,8 +648,9 @@ def custom_learning_rate(boosting_round): pred_result = model.transform(self.reg_df_test).collect() for row in pred_result: self.assertTrue( - np.isclose(row.prediction, - row.expected_prediction_with_callbacks, atol=1e-3) + np.isclose( + row.prediction, row.expected_prediction_with_callbacks, atol=1e-3 + ) ) def test_train_with_initial_model(self): @@ -517,101 +671,190 @@ def test_train_with_initial_model(self): self.assertTrue(np.isclose(row1.prediction, row2.prediction, atol=1e-3)) def test_classifier_with_base_margin(self): - cls_without_base_margin = XgboostClassifier(weightCol = "weight") - model_without_base_margin = cls_without_base_margin.fit(self.cls_df_train_without_base_margin) - pred_result_without_base_margin = model_without_base_margin.transform(self.cls_df_test_without_base_margin).collect() + cls_without_base_margin = XgboostClassifier(weightCol="weight") + model_without_base_margin = cls_without_base_margin.fit( + self.cls_df_train_without_base_margin + ) + pred_result_without_base_margin = model_without_base_margin.transform( + self.cls_df_test_without_base_margin + ).collect() for row in pred_result_without_base_margin: - self.assertTrue(np.isclose(row.prediction, - row.expected_prediction_without_base_margin, atol=1e-3)) - self.assertTrue(np.allclose(row.probability, - row.expected_prob_without_base_margin, atol=1e-3)) - - cls_with_same_base_margin = XgboostClassifier(weightCol = "weight", baseMarginCol = "baseMarginCol") - model_with_same_base_margin = cls_with_same_base_margin.fit(self.cls_df_train_with_same_base_margin) - pred_result_with_same_base_margin = model_with_same_base_margin.transform(self.cls_df_test_with_same_base_margin).collect() + self.assertTrue( + np.isclose( + row.prediction, + row.expected_prediction_without_base_margin, + atol=1e-3, + ) + ) + self.assertTrue( + np.allclose( + row.probability, row.expected_prob_without_base_margin, atol=1e-3 + ) + ) + + cls_with_same_base_margin = XgboostClassifier( + weightCol="weight", baseMarginCol="baseMarginCol" + ) + model_with_same_base_margin = cls_with_same_base_margin.fit( + self.cls_df_train_with_same_base_margin + ) + pred_result_with_same_base_margin = 
model_with_same_base_margin.transform( + self.cls_df_test_with_same_base_margin + ).collect() for row in pred_result_with_same_base_margin: - self.assertTrue(np.isclose(row.prediction, - row.expected_prediction_with_base_margin, atol=1e-3)) - self.assertTrue(np.allclose(row.probability, - row.expected_prob_with_base_margin, atol=1e-3)) - - cls_with_different_base_margin = XgboostClassifier(weightCol = "weight", baseMarginCol = "baseMarginCol") - model_with_different_base_margin = cls_with_different_base_margin.fit(self.cls_df_train_with_different_base_margin) - pred_result_with_different_base_margin = model_with_different_base_margin.transform(self.cls_df_test_with_different_base_margin).collect() + self.assertTrue( + np.isclose( + row.prediction, row.expected_prediction_with_base_margin, atol=1e-3 + ) + ) + self.assertTrue( + np.allclose( + row.probability, row.expected_prob_with_base_margin, atol=1e-3 + ) + ) + + cls_with_different_base_margin = XgboostClassifier( + weightCol="weight", baseMarginCol="baseMarginCol" + ) + model_with_different_base_margin = cls_with_different_base_margin.fit( + self.cls_df_train_with_different_base_margin + ) + pred_result_with_different_base_margin = ( + model_with_different_base_margin.transform( + self.cls_df_test_with_different_base_margin + ).collect() + ) for row in pred_result_with_different_base_margin: - self.assertTrue(np.isclose(row.prediction, - row.expected_prediction_with_base_margin, atol=1e-3)) - self.assertTrue(np.allclose(row.probability, - row.expected_prob_with_base_margin, atol=1e-3)) + self.assertTrue( + np.isclose( + row.prediction, row.expected_prediction_with_base_margin, atol=1e-3 + ) + ) + self.assertTrue( + np.allclose( + row.probability, row.expected_prob_with_base_margin, atol=1e-3 + ) + ) def test_regressor_with_weight_eval(self): # with weight - regressor_with_weight = XgboostRegressor(weightCol='weight') - model_with_weight = regressor_with_weight.fit(self.reg_df_train_with_eval_weight) - pred_result_with_weight = model_with_weight \ - .transform(self.reg_df_test_with_eval_weight).collect() + regressor_with_weight = XgboostRegressor(weightCol="weight") + model_with_weight = regressor_with_weight.fit( + self.reg_df_train_with_eval_weight + ) + pred_result_with_weight = model_with_weight.transform( + self.reg_df_test_with_eval_weight + ).collect() for row in pred_result_with_weight: - self.assertTrue(np.isclose(row.prediction, - row.expected_prediction_with_weight, atol=1e-3)) + self.assertTrue( + np.isclose( + row.prediction, row.expected_prediction_with_weight, atol=1e-3 + ) + ) # with eval regressor_with_eval = XgboostRegressor(**self.reg_params_with_eval) model_with_eval = regressor_with_eval.fit(self.reg_df_train_with_eval_weight) - self.assertTrue(np.isclose(model_with_eval._xgb_sklearn_model.best_score, - self.reg_with_eval_best_score, atol=1e-3), - f"Expected best score: {self.reg_with_eval_best_score}, " - f"but get {model_with_eval._xgb_sklearn_model.best_score}") - pred_result_with_eval = model_with_eval \ - .transform(self.reg_df_test_with_eval_weight).collect() + self.assertTrue( + np.isclose( + model_with_eval._xgb_sklearn_model.best_score, + self.reg_with_eval_best_score, + atol=1e-3, + ), + f"Expected best score: {self.reg_with_eval_best_score}, " + f"but get {model_with_eval._xgb_sklearn_model.best_score}", + ) + pred_result_with_eval = model_with_eval.transform( + self.reg_df_test_with_eval_weight + ).collect() for row in pred_result_with_eval: - self.assertTrue(np.isclose(row.prediction, - 
row.expected_prediction_with_eval, atol=1e-3), - f"Expect prediction is {row.expected_prediction_with_eval}," - f"but get {row.prediction}") + self.assertTrue( + np.isclose( + row.prediction, row.expected_prediction_with_eval, atol=1e-3 + ), + f"Expect prediction is {row.expected_prediction_with_eval}," + f"but get {row.prediction}", + ) # with weight and eval regressor_with_weight_eval = XgboostRegressor( - weightCol='weight', **self.reg_params_with_eval) - model_with_weight_eval = regressor_with_weight_eval.fit(self.reg_df_train_with_eval_weight) - pred_result_with_weight_eval = model_with_weight_eval \ - .transform(self.reg_df_test_with_eval_weight).collect() - self.assertTrue(np.isclose(model_with_weight_eval._xgb_sklearn_model.best_score, - self.reg_with_eval_and_weight_best_score, atol=1e-3)) + weightCol="weight", **self.reg_params_with_eval + ) + model_with_weight_eval = regressor_with_weight_eval.fit( + self.reg_df_train_with_eval_weight + ) + pred_result_with_weight_eval = model_with_weight_eval.transform( + self.reg_df_test_with_eval_weight + ).collect() + self.assertTrue( + np.isclose( + model_with_weight_eval._xgb_sklearn_model.best_score, + self.reg_with_eval_and_weight_best_score, + atol=1e-3, + ) + ) for row in pred_result_with_weight_eval: - self.assertTrue(np.isclose(row.prediction, - row.expected_prediction_with_weight_and_eval, atol=1e-3)) + self.assertTrue( + np.isclose( + row.prediction, + row.expected_prediction_with_weight_and_eval, + atol=1e-3, + ) + ) def test_classifier_with_weight_eval(self): # with weight - classifier_with_weight = XgboostClassifier(weightCol='weight') - model_with_weight = classifier_with_weight.fit(self.cls_df_train_with_eval_weight) - pred_result_with_weight = model_with_weight \ - .transform(self.cls_df_test_with_eval_weight).collect() + classifier_with_weight = XgboostClassifier(weightCol="weight") + model_with_weight = classifier_with_weight.fit( + self.cls_df_train_with_eval_weight + ) + pred_result_with_weight = model_with_weight.transform( + self.cls_df_test_with_eval_weight + ).collect() for row in pred_result_with_weight: - self.assertTrue(np.allclose(row.probability, - row.expected_prob_with_weight, atol=1e-3)) + self.assertTrue( + np.allclose(row.probability, row.expected_prob_with_weight, atol=1e-3) + ) # with eval classifier_with_eval = XgboostClassifier(**self.cls_params_with_eval) model_with_eval = classifier_with_eval.fit(self.cls_df_train_with_eval_weight) - self.assertTrue(np.isclose(model_with_eval._xgb_sklearn_model.best_score, - self.cls_with_eval_best_score, atol=1e-3)) - pred_result_with_eval = model_with_eval \ - .transform(self.cls_df_test_with_eval_weight).collect() + self.assertTrue( + np.isclose( + model_with_eval._xgb_sklearn_model.best_score, + self.cls_with_eval_best_score, + atol=1e-3, + ) + ) + pred_result_with_eval = model_with_eval.transform( + self.cls_df_test_with_eval_weight + ).collect() for row in pred_result_with_eval: - self.assertTrue(np.allclose(row.probability, - row.expected_prob_with_eval, atol=1e-3)) + self.assertTrue( + np.allclose(row.probability, row.expected_prob_with_eval, atol=1e-3) + ) # with weight and eval # Added scale_pos_weight because in 1.4.2, the original answer returns 0.5 which # doesn't really indicate this working correctly. 
classifier_with_weight_eval = XgboostClassifier( - weightCol='weight', scale_pos_weight=4, **self.cls_params_with_eval) - model_with_weight_eval = classifier_with_weight_eval \ - .fit(self.cls_df_train_with_eval_weight) - pred_result_with_weight_eval = model_with_weight_eval \ - .transform(self.cls_df_test_with_eval_weight).collect() - self.assertTrue(np.isclose(model_with_weight_eval._xgb_sklearn_model.best_score, - self.cls_with_eval_and_weight_best_score, atol=1e-3)) + weightCol="weight", scale_pos_weight=4, **self.cls_params_with_eval + ) + model_with_weight_eval = classifier_with_weight_eval.fit( + self.cls_df_train_with_eval_weight + ) + pred_result_with_weight_eval = model_with_weight_eval.transform( + self.cls_df_test_with_eval_weight + ).collect() + self.assertTrue( + np.isclose( + model_with_weight_eval._xgb_sklearn_model.best_score, + self.cls_with_eval_and_weight_best_score, + atol=1e-3, + ) + ) for row in pred_result_with_weight_eval: - self.assertTrue(np.allclose(row.probability, - row.expected_prob_with_weight_and_eval, atol=1e-3)) + self.assertTrue( + np.allclose( + row.probability, row.expected_prob_with_weight_and_eval, atol=1e-3 + ) + ) def test_num_workers_param(self): regressor = XgboostRegressor(num_workers=-1) @@ -637,10 +880,18 @@ def test_convert_to_model(self): reg_model = regressor.fit(self.reg_df_train) # Check that regardless of what booster, _convert_to_model converts to the correct class type - self.assertEqual(type(classifier._convert_to_model(clf_model.get_booster())), XGBClassifier) - self.assertEqual(type(classifier._convert_to_model(reg_model.get_booster())), XGBClassifier) - self.assertEqual(type(regressor._convert_to_model(clf_model.get_booster())), XGBRegressor) - self.assertEqual(type(regressor._convert_to_model(reg_model.get_booster())), XGBRegressor) + self.assertEqual( + type(classifier._convert_to_model(clf_model.get_booster())), XGBClassifier + ) + self.assertEqual( + type(classifier._convert_to_model(reg_model.get_booster())), XGBClassifier + ) + self.assertEqual( + type(regressor._convert_to_model(clf_model.get_booster())), XGBRegressor + ) + self.assertEqual( + type(regressor._convert_to_model(reg_model.get_booster())), XGBRegressor + ) def test_feature_importances(self): reg1 = XgboostRegressor(**self.reg_params) @@ -648,7 +899,6 @@ def test_feature_importances(self): booster = model.get_booster() self.assertEqual(model.get_feature_importances(), booster.get_score()) self.assertEqual( - model.get_feature_importances(importance_type='gain'), - booster.get_score(importance_type='gain') + model.get_feature_importances(importance_type="gain"), + booster.get_score(importance_type="gain"), ) - From 5d4122ca2812eea75c7c499ad1648b1e35507c26 Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Sun, 26 Jun 2022 22:45:11 +0800 Subject: [PATCH 07/73] update Signed-off-by: Weichen Xu --- python-package/xgboost/spark/__init__.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/python-package/xgboost/spark/__init__.py b/python-package/xgboost/spark/__init__.py index 06e3499e65e6..98ea12c9bc18 100644 --- a/python-package/xgboost/spark/__init__.py +++ b/python-package/xgboost/spark/__init__.py @@ -1,6 +1,4 @@ -"""XGBoost: eXtreme Gradient Boosting library. 
- -Contributors: https://github.com/dmlc/xgboost/blob/master/CONTRIBUTORS.md +"""PySpark XGBoost integration interface """ try: From 0424c1985c81dd8bd8bc2fd54bc4a4f00b5b5211 Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Fri, 1 Jul 2022 17:44:18 +0800 Subject: [PATCH 08/73] refactor Signed-off-by: Weichen Xu --- python-package/xgboost/spark/__init__.py | 16 +-- python-package/xgboost/spark/core.py | 16 +-- python-package/xgboost/spark/estimator.py | 18 ++-- python-package/xgboost/spark/params.py | 42 ++++++++ python-package/xgboost/spark/utils.py | 73 +------------ .../test_spark/xgboost_local_cluster_test.py | 38 +++---- tests/python/test_spark/xgboost_local_test.py | 100 +++++++++--------- 7 files changed, 138 insertions(+), 165 deletions(-) create mode 100644 python-package/xgboost/spark/params.py diff --git a/python-package/xgboost/spark/__init__.py b/python-package/xgboost/spark/__init__.py index 98ea12c9bc18..9d1e4b3c91b3 100644 --- a/python-package/xgboost/spark/__init__.py +++ b/python-package/xgboost/spark/__init__.py @@ -7,15 +7,15 @@ raise RuntimeError("xgboost spark python API requires pyspark package installed.") from .estimator import ( - XgboostClassifier, - XgboostClassifierModel, - XgboostRegressor, - XgboostRegressorModel, + SparkXGBClassifier, + SparkXGBClassifierModel, + SparkXGBRegressor, + SparkXGBRegressorModel, ) __all__ = [ - "XgboostClassifier", - "XgboostClassifierModel", - "XgboostRegressor", - "XgboostRegressorModel", + "SparkXGBClassifier", + "SparkXGBClassifierModel", + "SparkXGBRegressor", + "SparkXGBRegressorModel", ] diff --git a/python-package/xgboost/spark/core.py b/python-package/xgboost/spark/core.py index 47f0d4e34904..1192db5c23e0 100644 --- a/python-package/xgboost/spark/core.py +++ b/python-package/xgboost/spark/core.py @@ -41,13 +41,15 @@ from .utils import ( _get_default_params_from_func, get_class_name, - HasArbitraryParamsDict, - HasBaseMarginCol, RabitContext, _get_rabit_args, _get_args_from_message_list, _get_spark_session, ) +from .params import ( + HasArbitraryParamsDict, + HasBaseMarginCol, +) from pyspark.ml.functions import array_to_vector, vector_to_array @@ -275,7 +277,7 @@ def _validate_params(self): ) -class _XgboostEstimator(Estimator, _XgboostParams, MLReadable, MLWritable): +class _SparkXGBEstimator(Estimator, _XgboostParams, MLReadable, MLWritable): def __init__(self): super().__init__() self._set_xgb_params_default() @@ -472,8 +474,6 @@ def _train_booster(pandas_df_iter): def _fit(self, dataset): self._validate_params() - # Unwrap the VectorUDT type column "feature" to 4 primitive columns: - # ['features.type', 'features.size', 'features.indices', 'features.values'] features_col = col(self.getOrDefault(self.featuresCol)) label_col = col(self.getOrDefault(self.labelCol)).alias("label") features_array_col = vector_to_array(features_col, dtype="float32").alias( @@ -583,7 +583,7 @@ def read(cls): return XgboostReader(cls) -class _XgboostModel(Model, _XgboostParams, MLReadable, MLWritable): +class _SparkXGBModel(Model, _XgboostParams, MLReadable, MLWritable): def __init__(self, xgb_sklearn_model=None): super().__init__() self._xgb_sklearn_model = xgb_sklearn_model @@ -628,7 +628,7 @@ def _transform(self, dataset): raise NotImplementedError() -class XgboostRegressorModel(_XgboostModel): +class SparkXGBRegressorModel(_SparkXGBModel): """ The model returned by :func:`xgboost.spark.XgboostRegressor.fit` @@ -687,7 +687,7 @@ def predict_udf_base_margin( return dataset.withColumn(predictionColName, pred_col) -class 
XgboostClassifierModel(_XgboostModel, HasProbabilityCol, HasRawPredictionCol): +class SparkXGBClassifierModel(_SparkXGBModel, HasProbabilityCol, HasRawPredictionCol): """ The model returned by :func:`xgboost.spark.XgboostClassifier.fit` diff --git a/python-package/xgboost/spark/estimator.py b/python-package/xgboost/spark/estimator.py index 804fd24950be..eb00e3518024 100644 --- a/python-package/xgboost/spark/estimator.py +++ b/python-package/xgboost/spark/estimator.py @@ -1,14 +1,14 @@ from pyspark.ml.param.shared import HasProbabilityCol, HasRawPredictionCol from xgboost import XGBClassifier, XGBRegressor from .core import ( - _XgboostEstimator, - XgboostClassifierModel, - XgboostRegressorModel, + _SparkXGBEstimator, + SparkXGBClassifierModel, + SparkXGBRegressorModel, _set_pyspark_xgb_cls_param_attrs, ) -class XgboostRegressor(_XgboostEstimator): +class SparkXGBRegressor(_SparkXGBEstimator): """ XgboostRegressor is a PySpark ML estimator. It implements the XGBoost regression algorithm based on XGBoost python library, and it can be used in PySpark Pipeline @@ -98,13 +98,13 @@ def _xgb_cls(cls): @classmethod def _pyspark_model_cls(cls): - return XgboostRegressorModel + return SparkXGBRegressorModel -_set_pyspark_xgb_cls_param_attrs(XgboostRegressor, XgboostRegressorModel) +_set_pyspark_xgb_cls_param_attrs(SparkXGBRegressor, SparkXGBRegressorModel) -class XgboostClassifier(_XgboostEstimator, HasProbabilityCol, HasRawPredictionCol): +class SparkXGBClassifier(_SparkXGBEstimator, HasProbabilityCol, HasRawPredictionCol): """ XgboostClassifier is a PySpark ML estimator. It implements the XGBoost classification algorithm based on XGBoost python library, and it can be used in PySpark Pipeline @@ -202,7 +202,7 @@ def _xgb_cls(cls): @classmethod def _pyspark_model_cls(cls): - return XgboostClassifierModel + return SparkXGBClassifierModel -_set_pyspark_xgb_cls_param_attrs(XgboostClassifier, XgboostClassifierModel) +_set_pyspark_xgb_cls_param_attrs(SparkXGBClassifier, SparkXGBClassifierModel) diff --git a/python-package/xgboost/spark/params.py b/python-package/xgboost/spark/params.py new file mode 100644 index 000000000000..c61e785ba6ae --- /dev/null +++ b/python-package/xgboost/spark/params.py @@ -0,0 +1,42 @@ +from pyspark.ml.param.shared import Param, Params + + +class HasArbitraryParamsDict(Params): + """ + This is a Params based class that is extended by _XGBoostParams + and holds the variable to store the **kwargs parts of the XGBoost + input. + """ + + arbitraryParamsDict = Param( + Params._dummy(), + "arbitraryParamsDict", + "This parameter holds all of the user defined parameters that" + " the sklearn implementation of XGBoost can't recognize. " + "It is stored as a dictionary.", + ) + + def setArbitraryParamsDict(self, value): + return self._set(arbitraryParamsDict=value) + + def getArbitraryParamsDict(self, value): + return self.getOrDefault(self.arbitraryParamsDict) + + +class HasBaseMarginCol(Params): + """ + This is a Params based class that is extended by _XGBoostParams + and holds the variable to store the base margin column part of XGboost. 
+ """ + + baseMarginCol = Param( + Params._dummy(), + "baseMarginCol", + "This stores the name for the column of the base margin", + ) + + def setBaseMarginCol(self, value): + return self._set(baseMarginCol=value) + + def getBaseMarginCol(self, value): + return self.getOrDefault(self.baseMarginCol) diff --git a/python-package/xgboost/spark/utils.py b/python-package/xgboost/spark/utils.py index 5ad7b1ddbce1..132897d9ad32 100644 --- a/python-package/xgboost/spark/utils.py +++ b/python-package/xgboost/spark/utils.py @@ -7,7 +7,6 @@ from xgboost.tracker import RabitTracker import pyspark from pyspark.sql.session import SparkSession -from pyspark.ml.param.shared import Param, Params def get_class_name(cls): @@ -31,47 +30,6 @@ def _get_default_params_from_func(func, unsupported_set): return filtered_params_dict -class HasArbitraryParamsDict(Params): - """ - This is a Params based class that is extended by _XGBoostParams - and holds the variable to store the **kwargs parts of the XGBoost - input. - """ - - arbitraryParamsDict = Param( - Params._dummy(), - "arbitraryParamsDict", - "This parameter holds all of the user defined parameters that" - " the sklearn implementation of XGBoost can't recognize. " - "It is stored as a dictionary.", - ) - - def setArbitraryParamsDict(self, value): - return self._set(arbitraryParamsDict=value) - - def getArbitraryParamsDict(self, value): - return self.getOrDefault(self.arbitraryParamsDict) - - -class HasBaseMarginCol(Params): - """ - This is a Params based class that is extended by _XGBoostParams - and holds the variable to store the base margin column part of XGboost. - """ - - baseMarginCol = Param( - Params._dummy(), - "baseMarginCol", - "This stores the name for the column of the base margin", - ) - - def setBaseMarginCol(self, value): - return self._set(baseMarginCol=value) - - def getBaseMarginCol(self, value): - return self.getOrDefault(self.baseMarginCol) - - class RabitContext: """ A context controlling rabit initialization and finalization. @@ -94,7 +52,7 @@ def _start_tracker(context, n_workers): Start Rabit tracker with n_workers """ env = {"DMLC_NUM_WORKER": n_workers} - host = get_host_ip(context) + host = _get_host_ip(context) rabit_context = RabitTracker(host_ip=host, n_workers=n_workers) env.update(rabit_context.worker_envs()) rabit_context.start(n_workers) @@ -113,7 +71,7 @@ def _get_rabit_args(context, n_workers): return rabit_args -def get_host_ip(context): +def _get_host_ip(context): """ Gets the hostIP for Spark. This essentially gets the IP of the first worker. """ @@ -143,33 +101,6 @@ def _get_spark_session(): return SparkSession.builder.getOrCreate() -def _getConfBoolean(sqlContext, key, defaultValue): - """ - Get the conf "key" from the given sqlContext, - or return the default value if the conf is not set. - This expects the conf value to be a boolean or string; if the value is a string, - this checks for all capitalization patterns of "true" and "false" to match Scala. - - Parameters - ---------- - key: - string for conf name - """ - # Convert default value to str to avoid a Spark 2.3.1 + Python 3 bug: SPARK-25397 - val = sqlContext.getConf(key, str(defaultValue)) - # Convert val to str to handle unicode issues across Python 2 and 3. 
- lowercase_val = str(val.lower()) - if lowercase_val == "true": - return True - elif lowercase_val == "false": - return False - else: - raise Exception( - "_getConfBoolean expected a boolean conf value but found value of type {} " - "with value: {}".format(type(val), val) - ) - - def get_logger(name, level="INFO"): """Gets a logger by name, or creates and configures it for the first time.""" logger = logging.getLogger(name) diff --git a/tests/python/test_spark/xgboost_local_cluster_test.py b/tests/python/test_spark/xgboost_local_cluster_test.py index 7b4b4643a4aa..30a0f6db06b7 100644 --- a/tests/python/test_spark/xgboost_local_cluster_test.py +++ b/tests/python/test_spark/xgboost_local_cluster_test.py @@ -4,7 +4,7 @@ import numpy as np from pyspark.ml.linalg import Vectors -from xgboost.spark import XgboostClassifier, XgboostRegressor +from xgboost.spark import SparkXGBClassifier, SparkXGBRegressor from .utils_test import SparkLocalClusterTestCase from xgboost.spark.utils import _get_max_num_concurrent_tasks import json @@ -213,7 +213,7 @@ def setUp(self): self.reg_best_score_weight_and_eval = 4.9e-05 def test_regressor_basic_with_params(self): - regressor = XgboostRegressor(**self.reg_params) + regressor = SparkXGBRegressor(**self.reg_params) model = regressor.fit(self.reg_df_train) pred_result = model.transform(self.reg_df_test).collect() for row in pred_result: @@ -232,12 +232,12 @@ def custom_learning_rate(boosting_round): return 1.0 / (boosting_round + 1) cb = [LearningRateScheduler(custom_learning_rate)] - regressor = XgboostRegressor(callbacks=cb) + regressor = SparkXGBRegressor(callbacks=cb) # Test the save/load of the estimator instead of the model, since # the callbacks param only exists in the estimator but not in the model regressor.save(path) - regressor = XgboostRegressor.load(path) + regressor = SparkXGBRegressor.load(path) model = regressor.fit(self.reg_df_train) pred_result = model.transform(self.reg_df_test).collect() @@ -249,7 +249,7 @@ def custom_learning_rate(boosting_round): ) def test_classifier_distributed_basic(self): - classifier = XgboostClassifier(num_workers=self.n_workers, n_estimators=100) + classifier = SparkXGBClassifier(num_workers=self.n_workers, n_estimators=100) model = classifier.fit(self.cls_df_train_distributed) pred_result = model.transform(self.cls_df_test_distributed).collect() for row in pred_result: @@ -260,7 +260,7 @@ def test_classifier_distributed_basic(self): def test_classifier_distributed_multiclass(self): # There is no built-in multiclass option for external storage - classifier = XgboostClassifier(num_workers=self.n_workers, n_estimators=100) + classifier = SparkXGBClassifier(num_workers=self.n_workers, n_estimators=100) model = classifier.fit(self.cls_df_train_distributed_multiclass) pred_result = model.transform(self.cls_df_test_distributed_multiclass).collect() for row in pred_result: @@ -270,7 +270,7 @@ def test_classifier_distributed_multiclass(self): ) def test_regressor_distributed_basic(self): - regressor = XgboostRegressor(num_workers=self.n_workers, n_estimators=100) + regressor = SparkXGBRegressor(num_workers=self.n_workers, n_estimators=100) model = regressor.fit(self.reg_df_train_distributed) pred_result = model.transform(self.reg_df_test_distributed).collect() for row in pred_result: @@ -279,7 +279,7 @@ def test_regressor_distributed_basic(self): @unittest.skip def test_check_use_gpu_param(self): # Classifier - classifier = XgboostClassifier( + classifier = SparkXGBClassifier( num_workers=self.n_workers, n_estimators=100, 
use_gpu=True ) self.assertTrue(hasattr(classifier, "use_gpu")) @@ -292,7 +292,7 @@ def test_check_use_gpu_param(self): np.allclose(row.expected_probability, row.probability, atol=1e-3) ) - regressor = XgboostRegressor( + regressor = SparkXGBRegressor( num_workers=self.n_workers, n_estimators=100, use_gpu=True ) self.assertTrue(hasattr(regressor, "use_gpu")) @@ -304,7 +304,7 @@ def test_check_use_gpu_param(self): def test_classifier_distributed_weight_eval(self): # with weight - classifier = XgboostClassifier( + classifier = SparkXGBClassifier( num_workers=self.n_workers, n_estimators=100, **self.clf_params_with_weight_dist @@ -319,7 +319,7 @@ def test_classifier_distributed_weight_eval(self): ) # with eval only - classifier = XgboostClassifier( + classifier = SparkXGBClassifier( num_workers=self.n_workers, n_estimators=100, **self.clf_params_with_eval_dist @@ -338,7 +338,7 @@ def test_classifier_distributed_weight_eval(self): ) # with both weight and eval - classifier = XgboostClassifier( + classifier = SparkXGBClassifier( num_workers=self.n_workers, n_estimators=100, **self.clf_params_with_eval_dist, @@ -361,7 +361,7 @@ def test_classifier_distributed_weight_eval(self): def test_regressor_distributed_weight_eval(self): # with weight - regressor = XgboostRegressor( + regressor = SparkXGBRegressor( num_workers=self.n_workers, n_estimators=100, **self.reg_params_with_weight_dist @@ -377,7 +377,7 @@ def test_regressor_distributed_weight_eval(self): ) ) # with eval only - regressor = XgboostRegressor( + regressor = SparkXGBRegressor( num_workers=self.n_workers, n_estimators=100, **self.reg_params_with_eval_dist @@ -395,7 +395,7 @@ def test_regressor_distributed_weight_eval(self): self.reg_best_score_eval, ) # with both weight and eval - regressor = XgboostRegressor( + regressor = SparkXGBRegressor( num_workers=self.n_workers, n_estimators=100, use_external_storage=False, @@ -420,7 +420,7 @@ def test_regressor_distributed_weight_eval(self): ) def test_num_estimators(self): - classifier = XgboostClassifier(num_workers=self.n_workers, n_estimators=10) + classifier = SparkXGBClassifier(num_workers=self.n_workers, n_estimators=10) model = classifier.fit(self.cls_df_train_distributed) pred_result = model.transform( self.cls_df_test_distributed_lower_estimators @@ -433,7 +433,7 @@ def test_num_estimators(self): ) def test_distributed_params(self): - classifier = XgboostClassifier(num_workers=self.n_workers, max_depth=7) + classifier = SparkXGBClassifier(num_workers=self.n_workers, max_depth=7) model = classifier.fit(self.cls_df_train_distributed) self.assertTrue(hasattr(classifier, "max_depth")) self.assertEqual(classifier.getOrDefault(classifier.max_depth), 7) @@ -449,7 +449,7 @@ def test_repartition(self): # or poorly partitioned. We only want to repartition when the dataset # is poorly partitioned so _repartition_needed is true in those instances. 
- classifier = XgboostClassifier(num_workers=self.n_workers) + classifier = SparkXGBClassifier(num_workers=self.n_workers) basic = self.cls_df_train_distributed self.assertTrue(classifier._repartition_needed(basic)) bad_repartitioned = basic.repartition(self.n_workers + 1) @@ -458,7 +458,7 @@ def test_repartition(self): self.assertFalse(classifier._repartition_needed(good_repartitioned)) # Now testing if force_repartition returns True regardless of whether the data is well partitioned - classifier = XgboostClassifier( + classifier = SparkXGBClassifier( num_workers=self.n_workers, force_repartition=True ) good_repartitioned = basic.repartition(self.n_workers) diff --git a/tests/python/test_spark/xgboost_local_test.py b/tests/python/test_spark/xgboost_local_test.py index 6d8d8b1eb791..c9baef62f9e8 100644 --- a/tests/python/test_spark/xgboost_local_test.py +++ b/tests/python/test_spark/xgboost_local_test.py @@ -12,10 +12,10 @@ from pyspark.ml.tuning import CrossValidator, ParamGridBuilder from xgboost.spark import ( - XgboostClassifier, - XgboostClassifierModel, - XgboostRegressor, - XgboostRegressorModel, + SparkXGBClassifier, + SparkXGBClassifierModel, + SparkXGBRegressor, + SparkXGBRegressorModel, ) from .utils_test import SparkTestCase from xgboost import XGBClassifier, XGBRegressor @@ -372,13 +372,13 @@ def get_local_tmp_dir(self): return "/tmp/xgboost_local_test/" + str(uuid.uuid4()) def test_regressor_params_basic(self): - py_reg = XgboostRegressor() + py_reg = SparkXGBRegressor() self.assertTrue(hasattr(py_reg, "n_estimators")) self.assertEqual(py_reg.n_estimators.parent, py_reg.uid) self.assertFalse(hasattr(py_reg, "gpu_id")) self.assertEqual(py_reg.getOrDefault(py_reg.n_estimators), 100) self.assertEqual(py_reg._get_xgb_model_creator()().n_estimators, 100) - py_reg2 = XgboostRegressor(n_estimators=200) + py_reg2 = SparkXGBRegressor(n_estimators=200) self.assertEqual(py_reg2.getOrDefault(py_reg2.n_estimators), 200) self.assertEqual(py_reg2._get_xgb_model_creator()().n_estimators, 200) py_reg3 = py_reg2.copy({py_reg2.max_depth: 10}) @@ -386,13 +386,13 @@ def test_regressor_params_basic(self): self.assertEqual(py_reg3.getOrDefault(py_reg3.max_depth), 10) def test_classifier_params_basic(self): - py_cls = XgboostClassifier() + py_cls = SparkXGBClassifier() self.assertTrue(hasattr(py_cls, "n_estimators")) self.assertEqual(py_cls.n_estimators.parent, py_cls.uid) self.assertFalse(hasattr(py_cls, "gpu_id")) self.assertEqual(py_cls.getOrDefault(py_cls.n_estimators), 100) self.assertEqual(py_cls._get_xgb_model_creator()().n_estimators, 100) - py_cls2 = XgboostClassifier(n_estimators=200) + py_cls2 = SparkXGBClassifier(n_estimators=200) self.assertEqual(py_cls2.getOrDefault(py_cls2.n_estimators), 200) self.assertEqual(py_cls2._get_xgb_model_creator()().n_estimators, 200) py_cls3 = py_cls2.copy({py_cls2.max_depth: 10}) @@ -400,7 +400,7 @@ def test_classifier_params_basic(self): self.assertEqual(py_cls3.getOrDefault(py_cls3.max_depth), 10) def test_classifier_kwargs_basic(self): - py_cls = XgboostClassifier(**self.cls_params_kwargs) + py_cls = SparkXGBClassifier(**self.cls_params_kwargs) self.assertTrue(hasattr(py_cls, "n_estimators")) self.assertEqual(py_cls.n_estimators.parent, py_cls.uid) self.assertFalse(hasattr(py_cls, "gpu_id")) @@ -416,7 +416,7 @@ def test_classifier_kwargs_basic(self): ) # Testing overwritten params - py_cls = XgboostClassifier() + py_cls = SparkXGBClassifier() py_cls.setParams(x=1, y=2) py_cls.setParams(y=1, z=2) self.assertTrue("x" in 
py_cls._get_xgb_model_creator()().get_params()) @@ -428,7 +428,7 @@ def test_classifier_kwargs_basic(self): @staticmethod def test_param_value_converter(): - py_cls = XgboostClassifier(missing=np.float64(1.0), sketch_eps=np.float64(0.3)) + py_cls = SparkXGBClassifier(missing=np.float64(1.0), sketch_eps=np.float64(0.3)) # don't check by isintance(v, float) because for numpy scalar it will also return True assert py_cls.getOrDefault(py_cls.missing).__class__.__name__ == "float" assert ( @@ -439,7 +439,7 @@ def test_param_value_converter(): ) def test_regressor_basic(self): - regressor = XgboostRegressor() + regressor = SparkXGBRegressor() model = regressor.fit(self.reg_df_train) pred_result = model.transform(self.reg_df_test).collect() for row in pred_result: @@ -448,7 +448,7 @@ def test_regressor_basic(self): ) def test_classifier_basic(self): - classifier = XgboostClassifier() + classifier = SparkXGBClassifier() model = classifier.fit(self.cls_df_train) pred_result = model.transform(self.cls_df_test).collect() for row in pred_result: @@ -458,7 +458,7 @@ def test_classifier_basic(self): ) def test_multi_classifier(self): - classifier = XgboostClassifier() + classifier = SparkXGBClassifier() model = classifier.fit(self.multi_cls_df_train) pred_result = model.transform(self.multi_cls_df_test).collect() for row in pred_result: @@ -472,7 +472,7 @@ def _check_sub_dict_match(self, sub_dist, whole_dict): self.assertEqual(sub_dist[k], whole_dict[k]) def test_regressor_with_params(self): - regressor = XgboostRegressor(**self.reg_params) + regressor = SparkXGBRegressor(**self.reg_params) all_params = dict( **(regressor._gen_xgb_params_dict()), **(regressor._gen_fit_params_dict()), @@ -496,7 +496,7 @@ def test_regressor_with_params(self): ) def test_classifier_with_params(self): - classifier = XgboostClassifier(**self.cls_params) + classifier = SparkXGBClassifier(**self.cls_params) all_params = dict( **(classifier._gen_xgb_params_dict()), **(classifier._gen_fit_params_dict()), @@ -522,10 +522,10 @@ def test_classifier_with_params(self): def test_regressor_model_save_load(self): path = "file:" + self.get_local_tmp_dir() - regressor = XgboostRegressor(**self.reg_params) + regressor = SparkXGBRegressor(**self.reg_params) model = regressor.fit(self.reg_df_train) model.save(path) - loaded_model = XgboostRegressorModel.load(path) + loaded_model = SparkXGBRegressorModel.load(path) self.assertEqual(model.uid, loaded_model.uid) for k, v in self.reg_params.items(): self.assertEqual(loaded_model.getOrDefault(k), v) @@ -539,14 +539,14 @@ def test_regressor_model_save_load(self): ) with self.assertRaisesRegex(AssertionError, "Expected class name"): - XgboostClassifierModel.load(path) + SparkXGBClassifierModel.load(path) def test_classifier_model_save_load(self): path = "file:" + self.get_local_tmp_dir() - regressor = XgboostClassifier(**self.cls_params) + regressor = SparkXGBClassifier(**self.cls_params) model = regressor.fit(self.cls_df_train) model.save(path) - loaded_model = XgboostClassifierModel.load(path) + loaded_model = SparkXGBClassifierModel.load(path) self.assertEqual(model.uid, loaded_model.uid) for k, v in self.cls_params.items(): self.assertEqual(loaded_model.getOrDefault(k), v) @@ -560,7 +560,7 @@ def test_classifier_model_save_load(self): ) with self.assertRaisesRegex(AssertionError, "Expected class name"): - XgboostRegressorModel.load(path) + SparkXGBRegressorModel.load(path) @staticmethod def _get_params_map(params_kv, estimator): @@ -568,7 +568,7 @@ def _get_params_map(params_kv, estimator): def 
test_regressor_model_pipeline_save_load(self): path = "file:" + self.get_local_tmp_dir() - regressor = XgboostRegressor() + regressor = SparkXGBRegressor() pipeline = Pipeline(stages=[regressor]) pipeline = pipeline.copy(extra=self._get_params_map(self.reg_params, regressor)) model = pipeline.fit(self.reg_df_train) @@ -588,7 +588,7 @@ def test_regressor_model_pipeline_save_load(self): def test_classifier_model_pipeline_save_load(self): path = "file:" + self.get_local_tmp_dir() - classifier = XgboostClassifier() + classifier = SparkXGBClassifier() pipeline = Pipeline(stages=[classifier]) pipeline = pipeline.copy( extra=self._get_params_map(self.cls_params, classifier) @@ -609,7 +609,7 @@ def test_classifier_model_pipeline_save_load(self): ) def test_classifier_with_cross_validator(self): - xgb_classifer = XgboostClassifier() + xgb_classifer = SparkXGBClassifier() paramMaps = ParamGridBuilder().addGrid(xgb_classifer.max_depth, [1, 2]).build() cvBin = CrossValidator( estimator=xgb_classifer, @@ -637,12 +637,12 @@ def custom_learning_rate(boosting_round): return 1.0 / (boosting_round + 1) cb = [LearningRateScheduler(custom_learning_rate)] - regressor = XgboostRegressor(callbacks=cb) + regressor = SparkXGBRegressor(callbacks=cb) # Test the save/load of the estimator instead of the model, since # the callbacks param only exists in the estimator but not in the model regressor.save(path) - regressor = XgboostRegressor.load(path) + regressor = SparkXGBRegressor.load(path) model = regressor.fit(self.reg_df_train) pred_result = model.transform(self.reg_df_test).collect() @@ -655,14 +655,14 @@ def custom_learning_rate(boosting_round): def test_train_with_initial_model(self): path = self.get_local_tmp_dir() - reg1 = XgboostRegressor(**self.reg_params) + reg1 = SparkXGBRegressor(**self.reg_params) model = reg1.fit(self.reg_df_train) init_booster = model.get_booster() - reg2 = XgboostRegressor(max_depth=2, n_estimators=2, xgb_model=init_booster) + reg2 = SparkXGBRegressor(max_depth=2, n_estimators=2, xgb_model=init_booster) model21 = reg2.fit(self.reg_df_train) pred_res21 = model21.transform(self.reg_df_test).collect() reg2.save(path) - reg2 = XgboostRegressor.load(path) + reg2 = SparkXGBRegressor.load(path) self.assertTrue(reg2.getOrDefault(reg2.xgb_model) is not None) model22 = reg2.fit(self.reg_df_train) pred_res22 = model22.transform(self.reg_df_test).collect() @@ -671,7 +671,7 @@ def test_train_with_initial_model(self): self.assertTrue(np.isclose(row1.prediction, row2.prediction, atol=1e-3)) def test_classifier_with_base_margin(self): - cls_without_base_margin = XgboostClassifier(weightCol="weight") + cls_without_base_margin = SparkXGBClassifier(weightCol="weight") model_without_base_margin = cls_without_base_margin.fit( self.cls_df_train_without_base_margin ) @@ -692,7 +692,7 @@ def test_classifier_with_base_margin(self): ) ) - cls_with_same_base_margin = XgboostClassifier( + cls_with_same_base_margin = SparkXGBClassifier( weightCol="weight", baseMarginCol="baseMarginCol" ) model_with_same_base_margin = cls_with_same_base_margin.fit( @@ -713,7 +713,7 @@ def test_classifier_with_base_margin(self): ) ) - cls_with_different_base_margin = XgboostClassifier( + cls_with_different_base_margin = SparkXGBClassifier( weightCol="weight", baseMarginCol="baseMarginCol" ) model_with_different_base_margin = cls_with_different_base_margin.fit( @@ -738,7 +738,7 @@ def test_classifier_with_base_margin(self): def test_regressor_with_weight_eval(self): # with weight - regressor_with_weight = 
XgboostRegressor(weightCol="weight") + regressor_with_weight = SparkXGBRegressor(weightCol="weight") model_with_weight = regressor_with_weight.fit( self.reg_df_train_with_eval_weight ) @@ -752,7 +752,7 @@ def test_regressor_with_weight_eval(self): ) ) # with eval - regressor_with_eval = XgboostRegressor(**self.reg_params_with_eval) + regressor_with_eval = SparkXGBRegressor(**self.reg_params_with_eval) model_with_eval = regressor_with_eval.fit(self.reg_df_train_with_eval_weight) self.assertTrue( np.isclose( @@ -775,7 +775,7 @@ def test_regressor_with_weight_eval(self): f"but get {row.prediction}", ) # with weight and eval - regressor_with_weight_eval = XgboostRegressor( + regressor_with_weight_eval = SparkXGBRegressor( weightCol="weight", **self.reg_params_with_eval ) model_with_weight_eval = regressor_with_weight_eval.fit( @@ -802,7 +802,7 @@ def test_regressor_with_weight_eval(self): def test_classifier_with_weight_eval(self): # with weight - classifier_with_weight = XgboostClassifier(weightCol="weight") + classifier_with_weight = SparkXGBClassifier(weightCol="weight") model_with_weight = classifier_with_weight.fit( self.cls_df_train_with_eval_weight ) @@ -814,7 +814,7 @@ def test_classifier_with_weight_eval(self): np.allclose(row.probability, row.expected_prob_with_weight, atol=1e-3) ) # with eval - classifier_with_eval = XgboostClassifier(**self.cls_params_with_eval) + classifier_with_eval = SparkXGBClassifier(**self.cls_params_with_eval) model_with_eval = classifier_with_eval.fit(self.cls_df_train_with_eval_weight) self.assertTrue( np.isclose( @@ -833,7 +833,7 @@ def test_classifier_with_weight_eval(self): # with weight and eval # Added scale_pos_weight because in 1.4.2, the original answer returns 0.5 which # doesn't really indicate this working correctly. 
- classifier_with_weight_eval = XgboostClassifier( + classifier_with_weight_eval = SparkXGBClassifier( weightCol="weight", scale_pos_weight=4, **self.cls_params_with_eval ) model_with_weight_eval = classifier_with_weight_eval.fit( @@ -857,26 +857,26 @@ def test_classifier_with_weight_eval(self): ) def test_num_workers_param(self): - regressor = XgboostRegressor(num_workers=-1) + regressor = SparkXGBRegressor(num_workers=-1) self.assertRaises(ValueError, regressor._validate_params) - classifier = XgboostClassifier(num_workers=0) + classifier = SparkXGBClassifier(num_workers=0) self.assertRaises(ValueError, classifier._validate_params) def test_use_gpu_param(self): - classifier = XgboostClassifier(use_gpu=True, tree_method="exact") + classifier = SparkXGBClassifier(use_gpu=True, tree_method="exact") self.assertRaises(ValueError, classifier._validate_params) - regressor = XgboostRegressor(use_gpu=True, tree_method="exact") + regressor = SparkXGBRegressor(use_gpu=True, tree_method="exact") self.assertRaises(ValueError, regressor._validate_params) - regressor = XgboostRegressor(use_gpu=True, tree_method="gpu_hist") - regressor = XgboostRegressor(use_gpu=True) - classifier = XgboostClassifier(use_gpu=True, tree_method="gpu_hist") - classifier = XgboostClassifier(use_gpu=True) + regressor = SparkXGBRegressor(use_gpu=True, tree_method="gpu_hist") + regressor = SparkXGBRegressor(use_gpu=True) + classifier = SparkXGBClassifier(use_gpu=True, tree_method="gpu_hist") + classifier = SparkXGBClassifier(use_gpu=True) def test_convert_to_model(self): - classifier = XgboostClassifier() + classifier = SparkXGBClassifier() clf_model = classifier.fit(self.cls_df_train) - regressor = XgboostRegressor() + regressor = SparkXGBRegressor() reg_model = regressor.fit(self.reg_df_train) # Check that regardless of what booster, _convert_to_model converts to the correct class type @@ -894,7 +894,7 @@ def test_convert_to_model(self): ) def test_feature_importances(self): - reg1 = XgboostRegressor(**self.reg_params) + reg1 = SparkXGBRegressor(**self.reg_params) model = reg1.fit(self.reg_df_train) booster = model.get_booster() self.assertEqual(model.get_feature_importances(), booster.get_score()) From f8f33bdcff96b668867799ec4f46b66788803923 Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Fri, 1 Jul 2022 18:52:46 +0800 Subject: [PATCH 09/73] update params code Signed-off-by: Weichen Xu --- python-package/xgboost/spark/core.py | 58 ++++++++++++++------------ python-package/xgboost/spark/params.py | 6 +-- 2 files changed, 34 insertions(+), 30 deletions(-) diff --git a/python-package/xgboost/spark/core.py b/python-package/xgboost/spark/core.py index 1192db5c23e0..da20d4cb80d1 100644 --- a/python-package/xgboost/spark/core.py +++ b/python-package/xgboost/spark/core.py @@ -62,7 +62,12 @@ "rawPredictionCol", "predictionCol", "probabilityCol", - "validationIndicatorCol" "baseMarginCol", + "validationIndicatorCol", + "baseMarginCol", + "arbitraryParamsDict", + "force_repartition", + "num_workers", + "use_gpu", ] _unsupported_xgb_params = [ @@ -82,8 +87,6 @@ "base_margin", # TODO } -_created_params = {"num_workers", "use_gpu"} - class _XgboostParams( HasFeaturesCol, @@ -124,13 +127,8 @@ def _xgb_cls(cls): raise NotImplementedError() def _get_xgb_model_creator(self): - arbitaryParamsDict = self.getOrDefault(self.getParam("arbitraryParamsDict")) - total_params = {**self._gen_xgb_params_dict(), **arbitaryParamsDict} - # Once we have already added all of the elements of kwargs, we can just remove it - del 
total_params["arbitraryParamsDict"] - for param in _created_params: - del total_params[param] - return get_xgb_model_creator(self._xgb_cls(), total_params) + xgb_params = self._gen_xgb_params_dict() + return get_xgb_model_creator(self._xgb_cls(), xgb_params) # Parameters for xgboost.XGBModel() @classmethod @@ -145,7 +143,6 @@ def _get_xgb_params_default(cls): def _set_xgb_params_default(self): filtered_params_dict = self._get_xgb_params_default() self._setDefault(**filtered_params_dict) - self._setDefault(**{"arbitraryParamsDict": {}}) def _gen_xgb_params_dict(self): xgb_params = {} @@ -157,12 +154,10 @@ def _gen_xgb_params_dict(self): for param in self.extractParamMap(): if param.name not in non_xgb_params: xgb_params[param.name] = self.getOrDefault(param) - return xgb_params - def _set_distributed_params(self): - self.set(self.num_workers, 1) - self.set(self.use_gpu, False) - self.set(self.force_repartition, False) + arbitraryParamsDict = self.getOrDefault(self.getParam("arbitraryParamsDict")) + xgb_params.update(arbitraryParamsDict) + return xgb_params # Parameters for xgboost.XGBModel().fit() @classmethod @@ -283,18 +278,28 @@ def __init__(self): self._set_xgb_params_default() self._set_fit_params_default() self._set_predict_params_default() - self._set_distributed_params() + # Note: The default value for arbitraryParamsDict must always be empty dict. + # For additional settings added into "arbitraryParamsDict" by default, + # they are added in `setParams`. + self._setDefault( + num_workers=1, + use_gpu=False, + force_repartition=False, + arbitraryParamsDict={} + ) def setParams(self, **kwargs): - _user_defined = {} + _extra_params = {} + if 'arbitraryParamsDict' in kwargs: + raise ValueError("Wrong param name: 'arbitraryParamsDict'.") + for k, v in kwargs.items(): if self.hasParam(k): self._set(**{str(k): v}) else: - _user_defined[k] = v - _defined_args = self.getOrDefault(self.getParam("arbitraryParamsDict")) - _defined_args.update(_user_defined) - self._set(**{"arbitraryParamsDict": _defined_args}) + _extra_params[k] = v + _existing_extra_params = self.getOrDefault(self.arbitraryParamsDict) + self._set(arbitraryParamsDict={**_existing_extra_params, **_extra_params}) @classmethod def _pyspark_model_cls(cls): @@ -360,11 +365,11 @@ def _repartition_needed(self, dataset): pass return True - def _get_distributed_config(self, dataset, params): + def _get_distributed_config(self, dataset, fit_params): """ This just gets the configuration params for distributed xgboost """ - + params = fit_params.copy() classification = self._xgb_cls() == XGBClassifier num_classes = int(dataset.select(countDistinct("label")).collect()[0][0]) if classification and num_classes == 2: @@ -381,9 +386,8 @@ def _get_distributed_config(self, dataset, params): # On open-source spark, we need get the gpu id from the task allocated gpu resources. 
params["gpu_id"] = 0 params["num_boost_round"] = self.getOrDefault(self.n_estimators) - xgb_params = self._gen_xgb_params_dict() - xgb_params.update(params) - return xgb_params + params.update(self._gen_xgb_params_dict()) + return params @classmethod def _get_dist_booster_params(cls, train_params): diff --git a/python-package/xgboost/spark/params.py b/python-package/xgboost/spark/params.py index c61e785ba6ae..24d9a2d52aba 100644 --- a/python-package/xgboost/spark/params.py +++ b/python-package/xgboost/spark/params.py @@ -11,9 +11,9 @@ class HasArbitraryParamsDict(Params): arbitraryParamsDict = Param( Params._dummy(), "arbitraryParamsDict", - "This parameter holds all of the user defined parameters that" - " the sklearn implementation of XGBoost can't recognize. " - "It is stored as a dictionary.", + "arbitraryParamsDict This parameter holds all of the additional parameters which are " + "not exposed as the the XGBoost Spark estimator params but can be recognized by " + "underlying XGBoost library. It is stored as a dictionary.", ) def setArbitraryParamsDict(self, value): From ba4787f6458d9cbb4a0d6b27e77e3283e8e18a40 Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Fri, 1 Jul 2022 19:06:14 +0800 Subject: [PATCH 10/73] pyspark param alias Signed-off-by: Weichen Xu --- python-package/xgboost/spark/core.py | 23 +++++++++++++++---- tests/python/test_spark/xgboost_local_test.py | 5 ++++ 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/python-package/xgboost/spark/core.py b/python-package/xgboost/spark/core.py index da20d4cb80d1..5eeaa46ffe27 100644 --- a/python-package/xgboost/spark/core.py +++ b/python-package/xgboost/spark/core.py @@ -1,5 +1,3 @@ -import shutil -import tempfile from typing import Iterator, Tuple import numpy as np import pandas as pd @@ -70,6 +68,17 @@ "use_gpu", ] +_pyspark_param_alias_map = { + "features_col": "featuresCol", + "label_col": "labelCol", + "weight_col": "weightCol", + "raw_prediction_ol": "rawPredictionCol", + "prediction_col": "predictionCol", + "probability_col": "probabilityCol", + "validation_indicator_col": "validationIndicatorCol", + "baseMarginCol": "baseMarginCol", +} + _unsupported_xgb_params = [ "gpu_id", # we have "use_gpu" pyspark param instead. 
] @@ -294,6 +303,12 @@ def setParams(self, **kwargs): raise ValueError("Wrong param name: 'arbitraryParamsDict'.") for k, v in kwargs.items(): + if k in _pyspark_param_alias_map: + real_k = _pyspark_param_alias_map[k] + if real_k in kwargs: + raise ValueError(f"You should set only one of param '{k}' and '{real_k}'") + k = real_k + if self.hasParam(k): self._set(**{str(k): v}) else: @@ -365,7 +380,7 @@ def _repartition_needed(self, dataset): pass return True - def _get_distributed_config(self, dataset, fit_params): + def _get_distributed_train_params(self, dataset, fit_params): """ This just gets the configuration params for distributed xgboost """ @@ -423,7 +438,7 @@ def _fit_distributed( "values", col("values").cast(ArrayType(FloatType())) ) dataset = dataset.repartition(num_workers) - train_params = self._get_distributed_config(dataset, fit_params) + train_params = self._get_distributed_train_params(dataset, fit_params) def _train_booster(pandas_df_iter): """ diff --git a/tests/python/test_spark/xgboost_local_test.py b/tests/python/test_spark/xgboost_local_test.py index c9baef62f9e8..7377866c5d7b 100644 --- a/tests/python/test_spark/xgboost_local_test.py +++ b/tests/python/test_spark/xgboost_local_test.py @@ -426,6 +426,11 @@ def test_classifier_kwargs_basic(self): self.assertTrue("z" in py_cls._get_xgb_model_creator()().get_params()) self.assertEqual(py_cls._get_xgb_model_creator()().get_params()["z"], 2) + def test_param_alias(self): + py_cls = SparkXGBClassifier(featuresCol="f1", label_col="l1") + self.assertEqual(py_cls.getOrDefault(py_cls.featuresCol), "f1") + self.assertEqual(py_cls.getOrDefault(py_cls.labelCol), "l1") + @staticmethod def test_param_value_converter(): py_cls = SparkXGBClassifier(missing=np.float64(1.0), sketch_eps=np.float64(0.3)) From 94474867a208184195d393f77a2d443556b2d29c Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Fri, 1 Jul 2022 20:54:30 +0800 Subject: [PATCH 11/73] fix Signed-off-by: Weichen Xu --- python-package/xgboost/spark/core.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/python-package/xgboost/spark/core.py b/python-package/xgboost/spark/core.py index 5eeaa46ffe27..99774ab060a7 100644 --- a/python-package/xgboost/spark/core.py +++ b/python-package/xgboost/spark/core.py @@ -384,7 +384,8 @@ def _get_distributed_train_params(self, dataset, fit_params): """ This just gets the configuration params for distributed xgboost """ - params = fit_params.copy() + params = self._gen_xgb_params_dict() + params.update(fit_params) classification = self._xgb_cls() == XGBClassifier num_classes = int(dataset.select(countDistinct("label")).collect()[0][0]) if classification and num_classes == 2: @@ -395,13 +396,14 @@ def _get_distributed_train_params(self, dataset, fit_params): else: params["objective"] = "reg:squarederror" + params["num_boost_round"] = self.getOrDefault(self.n_estimators) + if self.getOrDefault(self.use_gpu): params["tree_method"] = "gpu_hist" # TODO: fix this. This only works on databricks runtime. # On open-source spark, we need get the gpu id from the task allocated gpu resources. 
params["gpu_id"] = 0 - params["num_boost_round"] = self.getOrDefault(self.n_estimators) - params.update(self._gen_xgb_params_dict()) + return params @classmethod From 016b1b77e257c0397509ba7e98ccfd446fbd30af Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Fri, 1 Jul 2022 21:09:07 +0800 Subject: [PATCH 12/73] add gpu param check test Signed-off-by: Weichen Xu --- python-package/xgboost/spark/core.py | 4 ++-- tests/python/test_spark/xgboost_local_test.py | 6 ++++++ 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/python-package/xgboost/spark/core.py b/python-package/xgboost/spark/core.py index 99774ab060a7..2a425cd0027f 100644 --- a/python-package/xgboost/spark/core.py +++ b/python-package/xgboost/spark/core.py @@ -418,7 +418,7 @@ def _get_dist_booster_params(cls, train_params): return booster_params, kwargs_params def _fit_distributed( - self, xgb_model_creator, dataset, has_weight, has_validation, fit_params + self, dataset, has_weight, has_validation, fit_params ): """ Takes in the dataset, the other parameters, and produces a valid booster @@ -533,7 +533,7 @@ def _fit(self, dataset): if self.getOrDefault(self.num_workers) > 1: return self._fit_distributed( - xgb_model_creator, dataset, has_weight, has_validation, fit_params + dataset, has_weight, has_validation, fit_params ) # Note: fit_params will be pickled to remote, it may include `xgb_model` param diff --git a/tests/python/test_spark/xgboost_local_test.py b/tests/python/test_spark/xgboost_local_test.py index 7377866c5d7b..fe20bfbb745d 100644 --- a/tests/python/test_spark/xgboost_local_test.py +++ b/tests/python/test_spark/xgboost_local_test.py @@ -431,6 +431,12 @@ def test_param_alias(self): self.assertEqual(py_cls.getOrDefault(py_cls.featuresCol), "f1") self.assertEqual(py_cls.getOrDefault(py_cls.labelCol), "l1") + def test_gpu_param_setting(self): + py_cls = SparkXGBClassifier(use_gpu=True) + train_params = py_cls._get_distributed_train_params(self.cls_df_train, {}) + assert train_params["gpu_id"] == 0 + assert train_params["tree_method"] == "gpu_hist" + @staticmethod def test_param_value_converter(): py_cls = SparkXGBClassifier(missing=np.float64(1.0), sketch_eps=np.float64(0.3)) From 72af029704139a9acf742e7469bb4a1fbb23f9a7 Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Fri, 1 Jul 2022 21:24:42 +0800 Subject: [PATCH 13/73] update _repartition_needed --- python-package/xgboost/spark/core.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/python-package/xgboost/spark/core.py b/python-package/xgboost/spark/core.py index 2a425cd0027f..af8e41572c94 100644 --- a/python-package/xgboost/spark/core.py +++ b/python-package/xgboost/spark/core.py @@ -347,17 +347,14 @@ def _convert_to_model(self, booster): else: return None # check if this else statement is needed. 
- def _query_plan_contains_valid_repartition(self, query_plan, num_partitions): + def _query_plan_contains_valid_repartition(self, query_plan): """ Returns true if the latest element in the logical plan is a valid repartition """ + # TODO: Improve the method start = query_plan.index("== Optimized Logical Plan ==") start += len("== Optimized Logical Plan ==") + 1 - num_workers = self.getOrDefault(self.num_workers) - if ( - query_plan[start : start + len("Repartition")] == "Repartition" - and num_workers == num_partitions - ): + if query_plan[start: start + len("Repartition")] == "Repartition": return True return False @@ -369,8 +366,10 @@ def _repartition_needed(self, dataset): """ if self.getOrDefault(self.force_repartition): return True + num_partitions = dataset.rdd.getNumPartitions() + if self.getOrDefault(self.num_workers) != num_partitions: + return True try: - num_partitions = dataset.rdd.getNumPartitions() query_plan = dataset._sc._jvm.PythonSQLUtils.explainString( dataset._jdf.queryExecution(), "extended" ) From 55fa0522725f1bfe20d5856d081820c4c49ece2b Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Fri, 1 Jul 2022 23:05:31 +0800 Subject: [PATCH 14/73] merge fit/fit_distributed --- python-package/xgboost/spark/core.py | 152 +++++------------- tests/python/test_spark/xgboost_local_test.py | 2 + 2 files changed, 40 insertions(+), 114 deletions(-) diff --git a/python-package/xgboost/spark/core.py b/python-package/xgboost/spark/core.py index af8e41572c94..f2c9bcba86b6 100644 --- a/python-package/xgboost/spark/core.py +++ b/python-package/xgboost/spark/core.py @@ -76,7 +76,7 @@ "prediction_col": "predictionCol", "probability_col": "probabilityCol", "validation_indicator_col": "validationIndicatorCol", - "baseMarginCol": "baseMarginCol", + "base_margin_col": "baseMarginCol", } _unsupported_xgb_params = [ @@ -351,7 +351,6 @@ def _query_plan_contains_valid_repartition(self, query_plan): """ Returns true if the latest element in the logical plan is a valid repartition """ - # TODO: Improve the method start = query_plan.index("== Optimized Logical Plan ==") start += len("== Optimized Logical Plan ==") + 1 if query_plan[start: start + len("Repartition")] == "Repartition": @@ -416,12 +415,43 @@ def _get_dist_booster_params(cls, train_params): booster_params[key] = value return booster_params, kwargs_params - def _fit_distributed( - self, dataset, has_weight, has_validation, fit_params - ): - """ - Takes in the dataset, the other parameters, and produces a valid booster - """ + def _fit(self, dataset): + self._validate_params() + features_col = col(self.getOrDefault(self.featuresCol)) + label_col = col(self.getOrDefault(self.labelCol)).alias("label") + features_array_col = vector_to_array(features_col, dtype="float32").alias( + "values" + ) + select_cols = [features_array_col, label_col] + + has_weight = False + has_validation = False + has_base_margin = False + + if self.isDefined(self.weightCol) and self.getOrDefault(self.weightCol): + has_weight = True + select_cols.append(col(self.getOrDefault(self.weightCol)).alias("weight")) + + if self.isDefined(self.validationIndicatorCol) and self.getOrDefault( + self.validationIndicatorCol + ): + has_validation = True + select_cols.append( + col(self.getOrDefault(self.validationIndicatorCol)).alias( + "validationIndicator" + ) + ) + + if self.isDefined(self.baseMarginCol) and self.getOrDefault( + self.baseMarginCol): + # TODO: fix baseMargin support + has_base_margin = True + select_cols.append( + 
col(self.getOrDefault(self.baseMarginCol)).alias("baseMargin")) + + dataset = dataset.select(*select_cols) + fit_params = self._gen_fit_params_dict() + num_workers = self.getOrDefault(self.num_workers) sc = _get_spark_session().sparkContext max_concurrent_tasks = _get_max_num_concurrent_tasks(sc) @@ -435,9 +465,6 @@ def _fit_distributed( ) if self._repartition_needed(dataset): - dataset = dataset.withColumn( - "values", col("values").cast(ArrayType(FloatType())) - ) dataset = dataset.repartition(num_workers) train_params = self._get_distributed_train_params(dataset, fit_params) @@ -492,109 +519,6 @@ def _train_booster(pandas_df_iter): result_xgb_model = self._convert_to_model(cloudpickle.loads(result_ser_booster)) return self._copyValues(self._create_pyspark_model(result_xgb_model)) - def _fit(self, dataset): - self._validate_params() - features_col = col(self.getOrDefault(self.featuresCol)) - label_col = col(self.getOrDefault(self.labelCol)).alias("label") - features_array_col = vector_to_array(features_col, dtype="float32").alias( - "values" - ) - select_cols = [features_array_col, label_col] - - has_weight = False - has_validation = False - has_base_margin = False - - if self.isDefined(self.weightCol) and self.getOrDefault(self.weightCol): - has_weight = True - select_cols.append(col(self.getOrDefault(self.weightCol)).alias("weight")) - - if self.isDefined(self.validationIndicatorCol) and self.getOrDefault( - self.validationIndicatorCol - ): - has_validation = True - select_cols.append( - col(self.getOrDefault(self.validationIndicatorCol)).alias( - "validationIndicator" - ) - ) - - if self.isDefined(self.baseMarginCol) and self.getOrDefault(self.baseMarginCol): - has_base_margin = True - select_cols.append( - col(self.getOrDefault(self.baseMarginCol)).alias("baseMargin") - ) - - dataset = dataset.select(*select_cols) - # create local var `xgb_model_creator` to avoid pickle `self` object to remote worker - xgb_model_creator = self._get_xgb_model_creator() # pylint: disable=E1111 - fit_params = self._gen_fit_params_dict() - - if self.getOrDefault(self.num_workers) > 1: - return self._fit_distributed( - dataset, has_weight, has_validation, fit_params - ) - - # Note: fit_params will be pickled to remote, it may include `xgb_model` param - # which is used as initial model in training. The initial model will be a - # `Booster` instance which support pickling. - def train_func(pandas_df_iter): - xgb_model = xgb_model_creator() - train_val_data = prepare_train_val_data( - pandas_df_iter, has_weight, has_validation, has_base_margin - ) - # We don't need to handle callbacks param in fit_params specially. - # User need to ensure callbacks is pickle-able. - if has_validation: - ( - train_X, - train_y, - train_w, - train_base_margin, - val_X, - val_y, - val_w, - _, - ) = train_val_data - eval_set = [(val_X, val_y)] - sample_weight_eval_set = [val_w] - # base_margin_eval_set = [val_base_margin] <- the underline - # Note that on XGBoost 1.2.0, the above doesn't exist. 
- xgb_model.fit( - train_X, - train_y, - sample_weight=train_w, - base_margin=train_base_margin, - eval_set=eval_set, - sample_weight_eval_set=sample_weight_eval_set, - **fit_params, - ) - else: - train_X, train_y, train_w, train_base_margin = train_val_data - xgb_model.fit( - train_X, - train_y, - sample_weight=train_w, - base_margin=train_base_margin, - **fit_params, - ) - - ser_model_string = serialize_xgb_model(xgb_model) - yield pd.DataFrame(data={"model_string": [ser_model_string]}) - - # Train on 1 remote worker, return the string of the serialized model - result_ser_model_string = ( - dataset.repartition(1) - .mapInPandas(train_func, schema="model_string string") - .collect()[0][0] - ) - - # Load model - result_xgb_model = deserialize_xgb_model( - result_ser_model_string, xgb_model_creator - ) - return self._copyValues(self._create_pyspark_model(result_xgb_model)) - def write(self): return XgboostWriter(self) diff --git a/tests/python/test_spark/xgboost_local_test.py b/tests/python/test_spark/xgboost_local_test.py index fe20bfbb745d..1aeddb688eb5 100644 --- a/tests/python/test_spark/xgboost_local_test.py +++ b/tests/python/test_spark/xgboost_local_test.py @@ -3,6 +3,7 @@ import uuid import numpy as np +import unittest from pyspark.ml import Pipeline, PipelineModel from pyspark.ml.evaluation import ( BinaryClassificationEvaluator, @@ -681,6 +682,7 @@ def test_train_with_initial_model(self): for row1, row2 in zip(pred_res21, pred_res22): self.assertTrue(np.isclose(row1.prediction, row2.prediction, atol=1e-3)) + @unittest.skip def test_classifier_with_base_margin(self): cls_without_base_margin = SparkXGBClassifier(weightCol="weight") model_without_base_margin = cls_without_base_margin.fit( From 7d9c37d5e48e99e6c9cd4bbe5ee6eeeec9e090d0 Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Sat, 2 Jul 2022 11:25:26 +0800 Subject: [PATCH 15/73] fix base margin support Signed-off-by: Weichen Xu --- python-package/xgboost/spark/core.py | 69 ++++++++++--------- python-package/xgboost/spark/data.py | 29 ++++---- tests/python/test_spark/xgboost_local_test.py | 31 +++------ 3 files changed, 64 insertions(+), 65 deletions(-) diff --git a/python-package/xgboost/spark/core.py b/python-package/xgboost/spark/core.py index f2c9bcba86b6..c1ad8b800d10 100644 --- a/python-package/xgboost/spark/core.py +++ b/python-package/xgboost/spark/core.py @@ -24,7 +24,6 @@ from .utils import get_logger, _get_max_num_concurrent_tasks from .data import ( prepare_predict_data, - prepare_train_val_data, convert_partition_data_to_dmatrix, ) from .model import ( @@ -68,6 +67,13 @@ "use_gpu", ] +_sklearn_estimator_specific_params = [ + "enable_categorical", + "missing", + "n_estimators", + "use_label_encoder", +] + _pyspark_param_alias_map = { "features_col": "featuresCol", "label_col": "labelCol", @@ -157,6 +163,7 @@ def _gen_xgb_params_dict(self): xgb_params = {} non_xgb_params = ( set(_pyspark_specific_params) + | set(_sklearn_estimator_specific_params) | self._get_fit_params_default().keys() | self._get_predict_params_default().keys() ) @@ -444,7 +451,6 @@ def _fit(self, dataset): if self.isDefined(self.baseMarginCol) and self.getOrDefault( self.baseMarginCol): - # TODO: fix baseMargin support has_base_margin = True select_cols.append( col(self.getOrDefault(self.baseMarginCol)).alias("baseMargin")) @@ -477,17 +483,6 @@ def _train_booster(pandas_df_iter): context = BarrierTaskContext.get() - dtrain, dval = None, [] - if has_validation: - dtrain, dval = convert_partition_data_to_dmatrix( - pandas_df_iter, has_weight, 
has_validation - ) - dval = [(dtrain, "training"), (dval, "validation")] - else: - dtrain = convert_partition_data_to_dmatrix( - pandas_df_iter, has_weight, has_validation - ) - booster_params, kwargs_params = self._get_dist_booster_params(train_params) context.barrier() _rabit_args = "" @@ -498,6 +493,18 @@ def _train_booster(pandas_df_iter): _rabit_args = _get_args_from_message_list(messages) evals_result = {} with RabitContext(_rabit_args, context): + dtrain, dval = None, [] + if has_validation: + dtrain, dval = convert_partition_data_to_dmatrix( + pandas_df_iter, has_weight, has_validation, has_base_margin + ) + # TODO: Question: do we need to add dtrain to dval list ? + dval = [(dtrain, "training"), (dval, "validation")] + else: + dtrain = convert_partition_data_to_dmatrix( + pandas_df_iter, has_weight, has_validation, has_base_margin + ) + booster = worker_train( params=booster_params, dtrain=dtrain, @@ -590,7 +597,7 @@ def _transform(self, dataset): predict_params = self._gen_predict_params_dict() @pandas_udf("double") - def predict_udf(iterator: Iterator[Tuple[pd.Series]]) -> Iterator[pd.Series]: + def predict_udf(iterator: Iterator[pd.DataFrame]) -> Iterator[pd.Series]: # deserialize model from ser_model_string, avoid pickling model to remote worker X, _, _, _ = prepare_predict_data(iterator, False) # Note: In every spark job task, pandas UDF will run in separate python process @@ -601,7 +608,7 @@ def predict_udf(iterator: Iterator[Tuple[pd.Series]]) -> Iterator[pd.Series]: @pandas_udf("double") def predict_udf_base_margin( - iterator: Iterator[Tuple[pd.Series, pd.Series]] + iterator: Iterator[pd.DataFrame] ) -> Iterator[pd.Series]: # deserialize model from ser_model_string, avoid pickling model to remote worker X, _, _, b_m = prepare_predict_data(iterator, True) @@ -611,20 +618,21 @@ def predict_udf_base_margin( preds = xgb_sklearn_model.predict(X, base_margin=b_m, **predict_params) yield pd.Series(preds) - features_col = col(self.getOrDefault(self.featuresCol)) - features_col = struct( - vector_to_array(features_col, dtype="float32").alias("values") - ) + features_col = vector_to_array( + col(self.getOrDefault(self.featuresCol)), dtype="float32" + ).alias("values") has_base_margin = False if self.isDefined(self.baseMarginCol) and self.getOrDefault(self.baseMarginCol): has_base_margin = True if has_base_margin: - base_margin_col = col(self.getOrDefault(self.baseMarginCol)) - pred_col = predict_udf_base_margin(features_col, base_margin_col) + base_margin_col = col(self.getOrDefault(self.baseMarginCol)).alias("baseMargin") + pred_col = predict_udf_base_margin( + struct(features_col, base_margin_col) + ) else: - pred_col = predict_udf(features_col) + pred_col = predict_udf(struct(features_col)) predictionColName = self.getOrDefault(self.predictionCol) @@ -651,7 +659,7 @@ def _transform(self, dataset): @pandas_udf( "rawPrediction array, prediction double, probability array" ) - def predict_udf(iterator: Iterator[Tuple[pd.Series]]) -> Iterator[pd.DataFrame]: + def predict_udf(iterator: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]: # deserialize model from ser_model_string, avoid pickling model to remote worker X, _, _, _ = prepare_predict_data(iterator, False) # Note: In every spark job task, pandas UDF will run in separate python process @@ -688,7 +696,7 @@ def predict_udf(iterator: Iterator[Tuple[pd.Series]]) -> Iterator[pd.DataFrame]: "rawPrediction array, prediction double, probability array" ) def predict_udf_base_margin( - iterator: Iterator[Tuple[pd.Series, pd.Series]] + 
iterator: Iterator[pd.DataFrame] ) -> Iterator[pd.DataFrame]: # deserialize model from ser_model_string, avoid pickling model to remote worker X, _, _, b_m = prepare_predict_data(iterator, True) @@ -722,20 +730,19 @@ def predict_udf_base_margin( } ) - features_col = col(self.getOrDefault(self.featuresCol)) - features_col = struct( - vector_to_array(features_col, dtype="float32").alias("values") - ) + features_col = vector_to_array( + col(self.getOrDefault(self.featuresCol)), dtype="float32" + ).alias("values") has_base_margin = False if self.isDefined(self.baseMarginCol) and self.getOrDefault(self.baseMarginCol): has_base_margin = True if has_base_margin: - base_margin_col = col(self.getOrDefault(self.baseMarginCol)) - pred_struct = predict_udf_base_margin(features_col, base_margin_col) + base_margin_col = col(self.getOrDefault(self.baseMarginCol)).alias("baseMargin") + pred_struct = predict_udf_base_margin(struct(features_col, base_margin_col)) else: - pred_struct = predict_udf(features_col) + pred_struct = predict_udf(struct(features_col)) pred_struct_col = "_prediction_struct" diff --git a/python-package/xgboost/spark/data.py b/python-package/xgboost/spark/data.py index 3060d4dc9184..f93304e2cdeb 100644 --- a/python-package/xgboost/spark/data.py +++ b/python-package/xgboost/spark/data.py @@ -97,7 +97,7 @@ def _create_dmatrix_from_file(file_name, cache_name): def prepare_train_val_data( - data_iterator, has_weight, has_validation, has_fit_base_margin=False + data_iterator, has_weight, has_validation, has_fit_base_margin ): def gen_data_pdf(): for pdf in data_iterator: @@ -162,9 +162,6 @@ def _row_tuple_list_to_feature_matrix_y_w( # Process rows for pdf in data_iterator: - if type(pdf) == tuple: - pdf = pd.concat(list(pdf), axis=1, names=["values", "baseMargin"]) - if len(pdf) == 0: continue if train and has_validation: @@ -184,7 +181,7 @@ def _row_tuple_list_to_feature_matrix_y_w( if has_weight: weight_list.append(pdf["weight"].to_list()) if has_fit_base_margin or has_predict_base_margin: - base_margin_list.append(pdf.iloc[:, -1].to_list()) + base_margin_list.append(pdf["baseMargin"].to_list()) if has_validation: values_val_list.append(pdf_val["values"].to_list()) if train: @@ -192,7 +189,7 @@ def _row_tuple_list_to_feature_matrix_y_w( if has_weight: weight_val_list.append(pdf_val["weight"].to_list()) if has_fit_base_margin or has_predict_base_margin: - base_margin_val_list.append(pdf_val.iloc[:, -1].to_list()) + base_margin_val_list.append(pdf_val["baseMargin"].to_list()) # Construct feature_matrix if expected_feature_dims is None: @@ -264,17 +261,23 @@ def _process_data_iter( ) -def convert_partition_data_to_dmatrix(partition_data_iter, has_weight, has_validation): +def convert_partition_data_to_dmatrix( + partition_data_iter, has_weight, has_validation, has_base_margin +): # if we are not using external storage, we use the standard method of parsing data. 
train_val_data = prepare_train_val_data( - partition_data_iter, has_weight, has_validation + partition_data_iter, has_weight, has_validation, has_base_margin ) if has_validation: - train_X, train_y, train_w, _, val_X, val_y, val_w, _ = train_val_data - training_dmatrix = DMatrix(data=train_X, label=train_y, weight=train_w) - val_dmatrix = DMatrix(data=val_X, label=val_y, weight=val_w) + train_X, train_y, train_w, train_b_m, val_X, val_y, val_w, val_b_m = train_val_data + training_dmatrix = DMatrix( + data=train_X, label=train_y, weight=train_w, base_margin=train_b_m + ) + val_dmatrix = DMatrix(data=val_X, label=val_y, weight=val_w, base_margin=val_b_m) return training_dmatrix, val_dmatrix else: - train_X, train_y, train_w, _ = train_val_data - training_dmatrix = DMatrix(data=train_X, label=train_y, weight=train_w) + train_X, train_y, train_w, train_b_m = train_val_data + training_dmatrix = DMatrix( + data=train_X, label=train_y, weight=train_w, base_margin=train_b_m + ) return training_dmatrix diff --git a/tests/python/test_spark/xgboost_local_test.py b/tests/python/test_spark/xgboost_local_test.py index 1aeddb688eb5..62c611fa2d6c 100644 --- a/tests/python/test_spark/xgboost_local_test.py +++ b/tests/python/test_spark/xgboost_local_test.py @@ -334,7 +334,7 @@ def setUp(self): (Vectors.dense(4.0, 5.0, 6.0), 0, 1.0, 0), (Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, 2.0, 1), ], - ["features", "label", "weight", "baseMarginCol"], + ["features", "label", "weight", "base_margin"], ) self.cls_df_test_with_same_base_margin = self.session.createDataFrame( [ @@ -342,7 +342,7 @@ def setUp(self): ], [ "features", - "baseMarginCol", + "base_margin", "expected_prob_with_base_margin", "expected_prediction_with_base_margin", ], @@ -355,7 +355,7 @@ def setUp(self): (Vectors.dense(4.0, 5.0, 6.0), 0, 1.0, 0), (Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, 2.0, 1), ], - ["features", "label", "weight", "baseMarginCol"], + ["features", "label", "weight", "base_margin"], ) self.cls_df_test_with_different_base_margin = self.session.createDataFrame( [ @@ -363,7 +363,7 @@ def setUp(self): ], [ "features", - "baseMarginCol", + "base_margin", "expected_prob_with_base_margin", "expected_prediction_with_base_margin", ], @@ -682,7 +682,6 @@ def test_train_with_initial_model(self): for row1, row2 in zip(pred_res21, pred_res22): self.assertTrue(np.isclose(row1.prediction, row2.prediction, atol=1e-3)) - @unittest.skip def test_classifier_with_base_margin(self): cls_without_base_margin = SparkXGBClassifier(weightCol="weight") model_without_base_margin = cls_without_base_margin.fit( @@ -699,14 +698,12 @@ def test_classifier_with_base_margin(self): atol=1e-3, ) ) - self.assertTrue( - np.allclose( - row.probability, row.expected_prob_without_base_margin, atol=1e-3 - ) + np.testing.assert_allclose( + row.probability, row.expected_prob_without_base_margin, atol=1e-3 ) cls_with_same_base_margin = SparkXGBClassifier( - weightCol="weight", baseMarginCol="baseMarginCol" + weightCol="weight", baseMarginCol="base_margin" ) model_with_same_base_margin = cls_with_same_base_margin.fit( self.cls_df_train_with_same_base_margin @@ -720,14 +717,10 @@ def test_classifier_with_base_margin(self): row.prediction, row.expected_prediction_with_base_margin, atol=1e-3 ) ) - self.assertTrue( - np.allclose( - row.probability, row.expected_prob_with_base_margin, atol=1e-3 - ) - ) + np.testing.assert_allclose(row.probability, row.expected_prob_with_base_margin, atol=1e-3) cls_with_different_base_margin = SparkXGBClassifier( - weightCol="weight", 
baseMarginCol="baseMarginCol" + weightCol="weight", baseMarginCol="base_margin" ) model_with_different_base_margin = cls_with_different_base_margin.fit( self.cls_df_train_with_different_base_margin @@ -743,11 +736,7 @@ def test_classifier_with_base_margin(self): row.prediction, row.expected_prediction_with_base_margin, atol=1e-3 ) ) - self.assertTrue( - np.allclose( - row.probability, row.expected_prob_with_base_margin, atol=1e-3 - ) - ) + np.testing.assert_allclose(row.probability, row.expected_prob_with_base_margin, atol=1e-3) def test_regressor_with_weight_eval(self): # with weight From 14392a49a7e85bd7ef76a4b9dd54240875040c8f Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Sat, 2 Jul 2022 11:31:33 +0800 Subject: [PATCH 16/73] remove dump file code Signed-off-by: Weichen Xu --- python-package/xgboost/spark/data.py | 91 ---------------------------- tests/python/test_spark/data_test.py | 71 ++-------------------- 2 files changed, 4 insertions(+), 158 deletions(-) diff --git a/python-package/xgboost/spark/data.py b/python-package/xgboost/spark/data.py index f93304e2cdeb..4d825843b339 100644 --- a/python-package/xgboost/spark/data.py +++ b/python-package/xgboost/spark/data.py @@ -2,100 +2,9 @@ from typing import Iterator import numpy as np import pandas as pd -from scipy.sparse import csr_matrix from xgboost import DMatrix -# Since sklearn's SVM converter doesn't address weights, this one does address weights: -def _dump_libsvm(features, labels, weights=None, external_storage_precision=5): - esp = external_storage_precision - lines = [] - - def gen_label_str(row_idx): - if weights is not None: - return "{label:.{esp}g}:{weight:.{esp}g}".format( - label=labels[row_idx], esp=esp, weight=weights[row_idx] - ) - else: - return "{label:.{esp}g}".format(label=labels[row_idx], esp=esp) - - def gen_feature_value_str(feature_idx, feature_val): - return "{idx:.{esp}g}:{value:.{esp}g}".format( - idx=feature_idx, esp=esp, value=feature_val - ) - - is_csr_matrix = isinstance(features, csr_matrix) - - for i in range(len(labels)): - current = [gen_label_str(i)] - if is_csr_matrix: - idx_start = features.indptr[i] - idx_end = features.indptr[i + 1] - for idx in range(idx_start, idx_end): - j = features.indices[idx] - val = features.data[idx] - current.append(gen_feature_value_str(j, val)) - else: - for j, val in enumerate(features[i]): - current.append(gen_feature_value_str(j, val)) - lines.append(" ".join(current) + "\n") - return lines - - -# This is the updated version that handles weights -def _stream_train_val_data( - features, labels, weights, main_file, external_storage_precision -): - lines = _dump_libsvm(features, labels, weights, external_storage_precision) - main_file.writelines(lines) - - -def _stream_data_into_libsvm_file( - data_iterator, has_weight, has_validation, file_prefix, external_storage_precision -): - # getting the file names for storage - train_file_name = file_prefix + "/data.txt.train" - train_file = open(train_file_name, "w") - if has_validation: - validation_file_name = file_prefix + "/data.txt.val" - validation_file = open(validation_file_name, "w") - - train_val_data = _process_data_iter( - data_iterator, train=True, has_weight=has_weight, has_validation=has_validation - ) - if has_validation: - train_X, train_y, train_w, _, val_X, val_y, val_w, _ = train_val_data - _stream_train_val_data( - train_X, train_y, train_w, train_file, external_storage_precision - ) - _stream_train_val_data( - val_X, val_y, val_w, validation_file, external_storage_precision - ) - else: - train_X, 
train_y, train_w, _ = train_val_data - _stream_train_val_data( - train_X, train_y, train_w, train_file, external_storage_precision - ) - - if has_validation: - train_file.close() - validation_file.close() - return train_file_name, validation_file_name - else: - train_file.close() - return train_file_name - - -def _create_dmatrix_from_file(file_name, cache_name): - if os.path.exists(cache_name): - os.remove(cache_name) - if os.path.exists(cache_name + ".row.page"): - os.remove(cache_name + ".row.page") - if os.path.exists(cache_name + ".sorted.col.page"): - os.remove(cache_name + ".sorted.col.page") - return DMatrix(file_name + "#" + cache_name) - - def prepare_train_val_data( data_iterator, has_weight, has_validation, has_fit_base_margin ): diff --git a/tests/python/test_spark/data_test.py b/tests/python/test_spark/data_test.py index 1136e030e47a..98b99fa11784 100644 --- a/tests/python/test_spark/data_test.py +++ b/tests/python/test_spark/data_test.py @@ -2,12 +2,10 @@ import shutil import numpy as np import pandas as pd -from scipy.sparse import csr_matrix from xgboost.spark.data import ( _row_tuple_list_to_feature_matrix_y_w, convert_partition_data_to_dmatrix, - _dump_libsvm, ) from xgboost import DMatrix, XGBClassifier @@ -83,7 +81,7 @@ def row_tup_iter(data): "label": [1, 0] * 100, } output_dmatrix = convert_partition_data_to_dmatrix( - [pd.DataFrame(data)], has_weight=False, has_validation=False + [pd.DataFrame(data)], has_weight=False, has_validation=False, has_base_margin=False ) # You can't compare DMatrix outputs, so the only way is to predict on the two seperate DMatrices using # the same classifier and making sure the outputs are equal @@ -101,7 +99,7 @@ def row_tup_iter(data): data["weight"] = [0.2, 0.8] * 100 output_dmatrix = convert_partition_data_to_dmatrix( - [pd.DataFrame(data)], has_weight=True, has_validation=False + [pd.DataFrame(data)], has_weight=True, has_validation=False, has_base_margin=False ) model.fit(expected_features, expected_labels, sample_weight=expected_weight) @@ -124,7 +122,7 @@ def test_external_storage(self): # Creating the dmatrix based on storage temporary_path = tempfile.mkdtemp() storage_dmatrix = convert_partition_data_to_dmatrix( - [pd.DataFrame(data)], has_weight=False, has_validation=False + [pd.DataFrame(data)], has_weight=False, has_validation=False, has_base_margin=False ) # Testing without weights @@ -142,7 +140,7 @@ def test_external_storage(self): temporary_path = tempfile.mkdtemp() storage_dmatrix = convert_partition_data_to_dmatrix( - [pd.DataFrame(data)], has_weight=True, has_validation=False + [pd.DataFrame(data)], has_weight=True, has_validation=False, has_base_margin=False ) normal_booster = worker_train({}, normal_dmatrix) @@ -151,64 +149,3 @@ def test_external_storage(self): storage_preds = storage_booster.predict(test_dmatrix) self.assertTrue(np.allclose(normal_preds, storage_preds, atol=1e-3)) shutil.rmtree(temporary_path) - - def test_dump_libsvm(self): - num_features = 3 - features_test_list = [ - [[1, 2, 3], [0, 1, 5.5]], - csr_matrix(([1, 2, 3], [0, 2, 2], [0, 2, 3]), shape=(2, 3)), - ] - labels = [0, 1] - - for features in features_test_list: - if isinstance(features, csr_matrix): - features_array = features.toarray() - else: - features_array = features - # testing without weights - # The format should be label index:feature_value index:feature_value... 
- # Note: from initial testing, it seems all of the indices must be listed regardless of whether - # they exist or not - output = _dump_libsvm(features, labels) - for i, line in enumerate(output): - split_line = line.split(" ") - self.assertEqual(float(split_line[0]), labels[i]) - split_line = [elem.split(":") for elem in split_line[1:]] - loaded_feature = [0.0] * num_features - for split in split_line: - loaded_feature[int(split[0])] = float(split[1]) - self.assertListEqual(loaded_feature, list(features_array[i])) - - weights = [0.2, 0.8] - # testing with weights - # The format should be label:weight index:feature_value index:feature_value... - output = _dump_libsvm(features, labels, weights) - for i, line in enumerate(output): - split_line = line.split(" ") - split_line = [elem.split(":") for elem in split_line] - self.assertEqual(float(split_line[0][0]), labels[i]) - self.assertEqual(float(split_line[0][1]), weights[i]) - - split_line = split_line[1:] - loaded_feature = [0.0] * num_features - for split in split_line: - loaded_feature[int(split[0])] = float(split[1]) - self.assertListEqual(loaded_feature, list(features_array[i])) - - features = [ - [1.34234, 2.342321, 3.34322], - [0.344234, 1.123123, 5.534322], - [3.553423e10, 3.5632e10, 0.00000000000012345], - ] - features_prec = [ - [1.34, 2.34, 3.34], - [0.344, 1.12, 5.53], - [3.55e10, 3.56e10, 1.23e-13], - ] - labels = [0, 1] - output = _dump_libsvm(features, labels, external_storage_precision=3) - for i, line in enumerate(output): - split_line = line.split(" ") - self.assertEqual(float(split_line[0]), labels[i]) - split_line = [elem.split(":") for elem in split_line[1:]] - self.assertListEqual([float(v[1]) for v in split_line], features_prec[i]) From 5bee8314baf4fe385af419699b3368188a874532 Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Sat, 2 Jul 2022 12:18:39 +0800 Subject: [PATCH 17/73] fix verbose param --- python-package/xgboost/spark/core.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/python-package/xgboost/spark/core.py b/python-package/xgboost/spark/core.py index c1ad8b800d10..30cb914fff91 100644 --- a/python-package/xgboost/spark/core.py +++ b/python-package/xgboost/spark/core.py @@ -1,4 +1,4 @@ -from typing import Iterator, Tuple +from typing import Iterator import numpy as np import pandas as pd from scipy.special import expit, softmax @@ -15,7 +15,6 @@ from pyspark.ml.param import Param, Params, TypeConverters from pyspark.ml.util import MLReadable, MLWritable from pyspark.sql.functions import col, pandas_udf, countDistinct, struct -from pyspark.sql.types import ArrayType, FloatType from xgboost import XGBClassifier, XGBRegressor from xgboost.core import Booster import cloudpickle @@ -31,9 +30,7 @@ XgboostWriter, XgboostModelReader, XgboostModelWriter, - deserialize_xgb_model, get_xgb_model_creator, - serialize_xgb_model, ) from .utils import ( _get_default_params_from_func, @@ -68,7 +65,7 @@ ] _sklearn_estimator_specific_params = [ - "enable_categorical", + "enable_categorical", # TODO: support this "missing", "n_estimators", "use_label_encoder", @@ -385,12 +382,16 @@ def _repartition_needed(self, dataset): pass return True - def _get_distributed_train_params(self, dataset, fit_params): + def _get_distributed_train_params(self, dataset): """ This just gets the configuration params for distributed xgboost """ params = self._gen_xgb_params_dict() + fit_params = self._gen_fit_params_dict() + verbose_eval = fit_params.pop("verbose", None) + params.update(fit_params) + 
params["verbose_eval"] = verbose_eval classification = self._xgb_cls() == XGBClassifier num_classes = int(dataset.select(countDistinct("label")).collect()[0][0]) if classification and num_classes == 2: @@ -401,6 +402,7 @@ def _get_distributed_train_params(self, dataset, fit_params): else: params["objective"] = "reg:squarederror" + # TODO: support "num_parallel_tree" for random forest params["num_boost_round"] = self.getOrDefault(self.n_estimators) if self.getOrDefault(self.use_gpu): @@ -412,7 +414,7 @@ def _get_distributed_train_params(self, dataset, fit_params): return params @classmethod - def _get_dist_booster_params(cls, train_params): + def _get_xgb_train_call_args(cls, train_params): non_booster_params = _get_default_params_from_func(xgboost.train, {}) booster_params, kwargs_params = {}, {} for key, value in train_params.items(): @@ -456,7 +458,6 @@ def _fit(self, dataset): col(self.getOrDefault(self.baseMarginCol)).alias("baseMargin")) dataset = dataset.select(*select_cols) - fit_params = self._gen_fit_params_dict() num_workers = self.getOrDefault(self.num_workers) sc = _get_spark_session().sparkContext @@ -472,7 +473,7 @@ def _fit(self, dataset): if self._repartition_needed(dataset): dataset = dataset.repartition(num_workers) - train_params = self._get_distributed_train_params(dataset, fit_params) + train_params = self._get_distributed_train_params(dataset) def _train_booster(pandas_df_iter): """ @@ -483,7 +484,8 @@ def _train_booster(pandas_df_iter): context = BarrierTaskContext.get() - booster_params, kwargs_params = self._get_dist_booster_params(train_params) + booster_params, train_call_kwargs_params = \ + self._get_xgb_train_call_args(train_params) context.barrier() _rabit_args = "" if context.partitionId() == 0: @@ -510,7 +512,7 @@ def _train_booster(pandas_df_iter): dtrain=dtrain, evals=dval, evals_result=evals_result, - **kwargs_params, + **train_call_kwargs_params, ) context.barrier() From a75ee881f50b5c8b9012be76c272db4a4ffff794 Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Sat, 2 Jul 2022 12:20:36 +0800 Subject: [PATCH 18/73] update _unsupported_xgb_params Signed-off-by: Weichen Xu --- python-package/xgboost/spark/core.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python-package/xgboost/spark/core.py b/python-package/xgboost/spark/core.py index 30cb914fff91..42387bb39ed0 100644 --- a/python-package/xgboost/spark/core.py +++ b/python-package/xgboost/spark/core.py @@ -65,10 +65,8 @@ ] _sklearn_estimator_specific_params = [ - "enable_categorical", # TODO: support this "missing", "n_estimators", - "use_label_encoder", ] _pyspark_param_alias_map = { @@ -84,7 +82,10 @@ _unsupported_xgb_params = [ "gpu_id", # we have "use_gpu" pyspark param instead. 
+ "enable_categorical", # TODO: support this + "use_label_encoder", ] + _unsupported_fit_params = { "sample_weight", # Supported by spark param weightCol # Supported by spark param weightCol # and validationIndicatorCol From d71e7e0ad14896dfbf944f4ce40ba2ecbe03b012 Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Sun, 3 Jul 2022 11:10:42 +0800 Subject: [PATCH 19/73] set nthread to be spark.task.cpus --- python-package/xgboost/spark/core.py | 39 ++++++++++++++-------------- python-package/xgboost/spark/data.py | 13 +++++++--- 2 files changed, 29 insertions(+), 23 deletions(-) diff --git a/python-package/xgboost/spark/core.py b/python-package/xgboost/spark/core.py index 42387bb39ed0..1e6c9bcff84c 100644 --- a/python-package/xgboost/spark/core.py +++ b/python-package/xgboost/spark/core.py @@ -84,6 +84,8 @@ "gpu_id", # we have "use_gpu" pyspark param instead. "enable_categorical", # TODO: support this "use_label_encoder", + "n_jobs", # Do not allow user to set it, will use `spark.task.cpus` value instead. + "nthread", # Ditto ] _unsupported_fit_params = { @@ -93,6 +95,7 @@ "sample_weight_eval_set", "base_margin", # Supported by spark param baseMarginCol } + _unsupported_predict_params = { # for classification, we can use rawPrediction as margin "output_margin", @@ -234,18 +237,6 @@ def _validate_params(self): f"It cannot be less than 1 [Default is 1]" ) - if self.getOrDefault(self.num_workers) > 1 and not self.getOrDefault( - self.use_gpu - ): - cpu_per_task = ( - _get_spark_session().sparkContext.getConf().get("spark.task.cpus") - ) - if cpu_per_task and int(cpu_per_task) > 1: - get_logger(self.__class__.__name__).warning( - f"You configured {cpu_per_task} CPU cores for each spark task, but in " - f"XGBoost training, every Spark task will only use one CPU core." 
- ) - if ( self.getOrDefault(self.force_repartition) and self.getOrDefault(self.num_workers) == 1 @@ -317,6 +308,10 @@ def setParams(self, **kwargs): if self.hasParam(k): self._set(**{str(k): v}) else: + if k in _unsupported_xgb_params or \ + k in _unsupported_fit_params or \ + k in _unsupported_predict_params: + raise ValueError(f"Unsupported param '{k}'.") _extra_params[k] = v _existing_extra_params = self.getOrDefault(self.arbitraryParamsDict) self._set(arbitraryParamsDict={**_existing_extra_params, **_extra_params}) @@ -416,10 +411,10 @@ def _get_distributed_train_params(self, dataset): @classmethod def _get_xgb_train_call_args(cls, train_params): - non_booster_params = _get_default_params_from_func(xgboost.train, {}) + xgb_train_default_args = _get_default_params_from_func(xgboost.train, {}) booster_params, kwargs_params = {}, {} for key, value in train_params.items(): - if key in non_booster_params: + if key in xgb_train_default_args: kwargs_params[key] = value else: booster_params[key] = value @@ -475,6 +470,13 @@ def _fit(self, dataset): if self._repartition_needed(dataset): dataset = dataset.repartition(num_workers) train_params = self._get_distributed_train_params(dataset) + booster_params, train_call_kwargs_params = \ + self._get_xgb_train_call_args(train_params) + + cpu_per_task = int( + _get_spark_session().sparkContext.getConf().get("spark.task.cpus", "1") + ) + booster_params['nthread'] = cpu_per_task def _train_booster(pandas_df_iter): """ @@ -484,9 +486,6 @@ def _train_booster(pandas_df_iter): from pyspark import BarrierTaskContext context = BarrierTaskContext.get() - - booster_params, train_call_kwargs_params = \ - self._get_xgb_train_call_args(train_params) context.barrier() _rabit_args = "" if context.partitionId() == 0: @@ -499,13 +498,15 @@ def _train_booster(pandas_df_iter): dtrain, dval = None, [] if has_validation: dtrain, dval = convert_partition_data_to_dmatrix( - pandas_df_iter, has_weight, has_validation, has_base_margin + pandas_df_iter, has_weight, has_validation, has_base_margin, + cpu_per_task=cpu_per_task, ) # TODO: Question: do we need to add dtrain to dval list ? dval = [(dtrain, "training"), (dval, "validation")] else: dtrain = convert_partition_data_to_dmatrix( - pandas_df_iter, has_weight, has_validation, has_base_margin + pandas_df_iter, has_weight, has_validation, has_base_margin, + cpu_per_task=cpu_per_task, ) booster = worker_train( diff --git a/python-package/xgboost/spark/data.py b/python-package/xgboost/spark/data.py index 4d825843b339..4aad679b89fe 100644 --- a/python-package/xgboost/spark/data.py +++ b/python-package/xgboost/spark/data.py @@ -171,7 +171,7 @@ def _process_data_iter( def convert_partition_data_to_dmatrix( - partition_data_iter, has_weight, has_validation, has_base_margin + partition_data_iter, has_weight, has_validation, has_base_margin, cpu_per_task=1 ): # if we are not using external storage, we use the standard method of parsing data. 
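    # `cpu_per_task` is expected to mirror the Spark `spark.task.cpus` setting; it is
    # forwarded to DMatrix(nthread=...) below so that DMatrix construction does not
    # use more threads than the cores allotted to this Spark task.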
train_val_data = prepare_train_val_data( @@ -180,13 +180,18 @@ def convert_partition_data_to_dmatrix( if has_validation: train_X, train_y, train_w, train_b_m, val_X, val_y, val_w, val_b_m = train_val_data training_dmatrix = DMatrix( - data=train_X, label=train_y, weight=train_w, base_margin=train_b_m + data=train_X, label=train_y, weight=train_w, base_margin=train_b_m, + nthread=cpu_per_task, + ) + val_dmatrix = DMatrix( + data=val_X, label=val_y, weight=val_w, base_margin=val_b_m, + nthread=cpu_per_task, ) - val_dmatrix = DMatrix(data=val_X, label=val_y, weight=val_w, base_margin=val_b_m) return training_dmatrix, val_dmatrix else: train_X, train_y, train_w, train_b_m = train_val_data training_dmatrix = DMatrix( - data=train_X, label=train_y, weight=train_w, base_margin=train_b_m + data=train_X, label=train_y, weight=train_w, base_margin=train_b_m, + nthread=cpu_per_task, ) return training_dmatrix From 75cfe91cbd7c636c50a7e71be8b68a1ed4334883 Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Sun, 3 Jul 2022 14:17:29 +0800 Subject: [PATCH 20/73] support feature_types and feature_names --- python-package/xgboost/spark/core.py | 28 ++++++++++++++++++++++---- python-package/xgboost/spark/data.py | 9 +++++---- python-package/xgboost/spark/params.py | 4 ++-- 3 files changed, 31 insertions(+), 10 deletions(-) diff --git a/python-package/xgboost/spark/core.py b/python-package/xgboost/spark/core.py index 1e6c9bcff84c..2bef2ac14397 100644 --- a/python-package/xgboost/spark/core.py +++ b/python-package/xgboost/spark/core.py @@ -64,9 +64,11 @@ "use_gpu", ] -_sklearn_estimator_specific_params = [ +_non_booster_params = [ "missing", "n_estimators", + "feature_types", + "feature_names", ] _pyspark_param_alias_map = { @@ -133,6 +135,16 @@ class _XgboostParams( + "Note: The auto repartitioning judgement is not fully accurate, so it is recommended" + "to have force_repartition be True.", ) + feature_names = Param( + Params._dummy(), + "feature_names", + "A list of str to specify feature names." + ) + feature_types = Param( + Params._dummy(), + "feature_names", + "A list of str to specify feature types." 
+ ) @classmethod def _xgb_cls(cls): @@ -164,7 +176,7 @@ def _gen_xgb_params_dict(self): xgb_params = {} non_xgb_params = ( set(_pyspark_specific_params) - | set(_sklearn_estimator_specific_params) + | set(_non_booster_params) | self._get_fit_params_default().keys() | self._get_predict_params_default().keys() ) @@ -290,6 +302,8 @@ def __init__(self): num_workers=1, use_gpu=False, force_repartition=False, + feature_names=None, + feature_types=None, arbitraryParamsDict={} ) @@ -424,6 +438,7 @@ def _fit(self, dataset): self._validate_params() features_col = col(self.getOrDefault(self.featuresCol)) label_col = col(self.getOrDefault(self.labelCol)).alias("label") + features_array_col = vector_to_array(features_col, dtype="float32").alias( "values" ) @@ -476,6 +491,11 @@ def _fit(self, dataset): cpu_per_task = int( _get_spark_session().sparkContext.getConf().get("spark.task.cpus", "1") ) + dmatrix_kwargs = { + "nthread": cpu_per_task, + "feature_types": self.getOrDefault(self.feature_types), + "feature_names": self.getOrDefault(self.feature_names), + } booster_params['nthread'] = cpu_per_task def _train_booster(pandas_df_iter): @@ -499,14 +519,14 @@ def _train_booster(pandas_df_iter): if has_validation: dtrain, dval = convert_partition_data_to_dmatrix( pandas_df_iter, has_weight, has_validation, has_base_margin, - cpu_per_task=cpu_per_task, + dmatrix_kwargs=dmatrix_kwargs, ) # TODO: Question: do we need to add dtrain to dval list ? dval = [(dtrain, "training"), (dval, "validation")] else: dtrain = convert_partition_data_to_dmatrix( pandas_df_iter, has_weight, has_validation, has_base_margin, - cpu_per_task=cpu_per_task, + dmatrix_kwargs=dmatrix_kwargs, ) booster = worker_train( diff --git a/python-package/xgboost/spark/data.py b/python-package/xgboost/spark/data.py index 4aad679b89fe..bfd21834f4ee 100644 --- a/python-package/xgboost/spark/data.py +++ b/python-package/xgboost/spark/data.py @@ -171,8 +171,9 @@ def _process_data_iter( def convert_partition_data_to_dmatrix( - partition_data_iter, has_weight, has_validation, has_base_margin, cpu_per_task=1 + partition_data_iter, has_weight, has_validation, has_base_margin, dmatrix_kwargs=None ): + dmatrix_kwargs = dmatrix_kwargs or {} # if we are not using external storage, we use the standard method of parsing data. 
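    # `dmatrix_kwargs` is forwarded verbatim to every DMatrix built below; the caller in
    # core.py currently fills it with nthread plus the optional feature_names and
    # feature_types lists, so train and validation matrices are constructed consistently.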
train_val_data = prepare_train_val_data( partition_data_iter, has_weight, has_validation, has_base_margin @@ -181,17 +182,17 @@ def convert_partition_data_to_dmatrix( train_X, train_y, train_w, train_b_m, val_X, val_y, val_w, val_b_m = train_val_data training_dmatrix = DMatrix( data=train_X, label=train_y, weight=train_w, base_margin=train_b_m, - nthread=cpu_per_task, + **dmatrix_kwargs, ) val_dmatrix = DMatrix( data=val_X, label=val_y, weight=val_w, base_margin=val_b_m, - nthread=cpu_per_task, + **dmatrix_kwargs, ) return training_dmatrix, val_dmatrix else: train_X, train_y, train_w, train_b_m = train_val_data training_dmatrix = DMatrix( data=train_X, label=train_y, weight=train_w, base_margin=train_b_m, - nthread=cpu_per_task, + **dmatrix_kwargs, ) return training_dmatrix diff --git a/python-package/xgboost/spark/params.py b/python-package/xgboost/spark/params.py index 24d9a2d52aba..028968a874ad 100644 --- a/python-package/xgboost/spark/params.py +++ b/python-package/xgboost/spark/params.py @@ -19,7 +19,7 @@ class HasArbitraryParamsDict(Params): def setArbitraryParamsDict(self, value): return self._set(arbitraryParamsDict=value) - def getArbitraryParamsDict(self, value): + def getArbitraryParamsDict(self): return self.getOrDefault(self.arbitraryParamsDict) @@ -38,5 +38,5 @@ class HasBaseMarginCol(Params): def setBaseMarginCol(self, value): return self._set(baseMarginCol=value) - def getBaseMarginCol(self, value): + def getBaseMarginCol(self): return self.getOrDefault(self.baseMarginCol) From 7f68346ea01b5125cc049342bda0ecbd0d3748b1 Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Sun, 3 Jul 2022 15:18:15 +0800 Subject: [PATCH 21/73] update _repartition_needed --- python-package/xgboost/spark/core.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/python-package/xgboost/spark/core.py b/python-package/xgboost/spark/core.py index 2bef2ac14397..2a59ebaefa75 100644 --- a/python-package/xgboost/spark/core.py +++ b/python-package/xgboost/spark/core.py @@ -361,13 +361,21 @@ def _convert_to_model(self, booster): else: return None # check if this else statement is needed. 
- def _query_plan_contains_valid_repartition(self, query_plan): + def _query_plan_contains_valid_repartition(self, dataset): """ Returns true if the latest element in the logical plan is a valid repartition """ + num_partitions = dataset.rdd.getNumPartitions() + query_plan = dataset._sc._jvm.PythonSQLUtils.explainString( + dataset._jdf.queryExecution(), "extended" + ) start = query_plan.index("== Optimized Logical Plan ==") start += len("== Optimized Logical Plan ==") + 1 - if query_plan[start: start + len("Repartition")] == "Repartition": + num_workers = self.getOrDefault(self.num_workers) + if ( + query_plan[start : start + len("Repartition")] == "Repartition" + and num_workers == num_partitions + ): return True return False @@ -379,14 +387,8 @@ def _repartition_needed(self, dataset): """ if self.getOrDefault(self.force_repartition): return True - num_partitions = dataset.rdd.getNumPartitions() - if self.getOrDefault(self.num_workers) != num_partitions: - return True try: - query_plan = dataset._sc._jvm.PythonSQLUtils.explainString( - dataset._jdf.queryExecution(), "extended" - ) - if self._query_plan_contains_valid_repartition(query_plan, num_partitions): + if self._query_plan_contains_valid_repartition(dataset): return False except: # noqa: E722 pass From dfffe8ef396c126ca309f0751cef1d5e38b8cb86 Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Sun, 3 Jul 2022 16:10:29 +0800 Subject: [PATCH 22/73] support use array as features column --- python-package/xgboost/spark/core.py | 50 +++++++++++++++---- tests/python/test_spark/xgboost_local_test.py | 35 ++++++++++++- 2 files changed, 73 insertions(+), 12 deletions(-) diff --git a/python-package/xgboost/spark/core.py b/python-package/xgboost/spark/core.py index 2a59ebaefa75..5b4d71da5f12 100644 --- a/python-package/xgboost/spark/core.py +++ b/python-package/xgboost/spark/core.py @@ -46,6 +46,9 @@ ) from pyspark.ml.functions import array_to_vector, vector_to_array +from pyspark.sql.types import \ + ArrayType, DoubleType, FloatType, IntegerType, LongType, ShortType +from pyspark.ml.linalg import VectorUDT # Put pyspark specific params here, they won't be passed to XGBoost. # like `validationIndicatorCol`, `baseMarginCol` @@ -62,13 +65,13 @@ "force_repartition", "num_workers", "use_gpu", + "feature_types", + "feature_names", ] _non_booster_params = [ "missing", "n_estimators", - "feature_types", - "feature_names", ] _pyspark_param_alias_map = { @@ -289,6 +292,32 @@ def _validate_params(self): ) +def _validate_and_convert_feature_col_as_array_col(dataset, features_col_name): + features_col_datatype = dataset.schema[features_col_name].dataType + features_col = col(features_col_name) + if isinstance(features_col_datatype, ArrayType): + if not isinstance( + features_col_datatype.elementType, + (DoubleType, FloatType, LongType, IntegerType, ShortType) + ): + raise ValueError( + "If feature column is array type, its elements must be number type." + ) + features_array_col = features_col.cast(ArrayType(FloatType())).alias("values") + elif isinstance(features_col_datatype, VectorUDT): + features_array_col = vector_to_array(features_col, dtype="float32").alias( + "values" + ) + else: + raise ValueError( + "feature column must be array type or `pyspark.ml.linalg.Vector` type, " + "if you want to use multiple numeric columns as features, please use " + "`pyspark.ml.feature.VectorAssembler` to assemble them into a vector " + "type column first."
+ ) + return features_array_col + + class _SparkXGBEstimator(Estimator, _XgboostParams, MLReadable, MLWritable): def __init__(self): super().__init__() @@ -438,11 +467,10 @@ def _get_xgb_train_call_args(cls, train_params): def _fit(self, dataset): self._validate_params() - features_col = col(self.getOrDefault(self.featuresCol)) label_col = col(self.getOrDefault(self.labelCol)).alias("label") - features_array_col = vector_to_array(features_col, dtype="float32").alias( - "values" + features_array_col = _validate_and_convert_feature_col_as_array_col( + dataset, self.getOrDefault(self.featuresCol) ) select_cols = [features_array_col, label_col] @@ -644,9 +672,9 @@ def predict_udf_base_margin( preds = xgb_sklearn_model.predict(X, base_margin=b_m, **predict_params) yield pd.Series(preds) - features_col = vector_to_array( - col(self.getOrDefault(self.featuresCol)), dtype="float32" - ).alias("values") + features_col = _validate_and_convert_feature_col_as_array_col( + dataset, self.getOrDefault(self.featuresCol) + ) has_base_margin = False if self.isDefined(self.baseMarginCol) and self.getOrDefault(self.baseMarginCol): @@ -756,9 +784,9 @@ def predict_udf_base_margin( } ) - features_col = vector_to_array( - col(self.getOrDefault(self.featuresCol)), dtype="float32" - ).alias("values") + features_col = _validate_and_convert_feature_col_as_array_col( + dataset, self.getOrDefault(self.featuresCol) + ) has_base_margin = False if self.isDefined(self.baseMarginCol) and self.getOrDefault(self.baseMarginCol): diff --git a/tests/python/test_spark/xgboost_local_test.py b/tests/python/test_spark/xgboost_local_test.py index 62c611fa2d6c..17f5f5ca9f5e 100644 --- a/tests/python/test_spark/xgboost_local_test.py +++ b/tests/python/test_spark/xgboost_local_test.py @@ -3,7 +3,8 @@ import uuid import numpy as np -import unittest +from pyspark.ml.functions import vector_to_array +from pyspark.sql import functions as spark_sql_func from pyspark.ml import Pipeline, PipelineModel from pyspark.ml.evaluation import ( BinaryClassificationEvaluator, @@ -904,3 +905,35 @@ def test_feature_importances(self): model.get_feature_importances(importance_type="gain"), booster.get_score(importance_type="gain"), ) + + def test_regressor_array_col_as_feature(self): + train_dataset = self.reg_df_train.withColumn( + "features", vector_to_array(spark_sql_func.col("features")) + ) + test_dataset = self.reg_df_test.withColumn( + "features", vector_to_array(spark_sql_func.col("features")) + ) + regressor = SparkXGBRegressor() + model = regressor.fit(train_dataset) + pred_result = model.transform(test_dataset).collect() + for row in pred_result: + self.assertTrue( + np.isclose(row.prediction, row.expected_prediction, atol=1e-3) + ) + + def test_classifier_array_col_as_feature(self): + train_dataset = self.cls_df_train.withColumn( + "features", vector_to_array(spark_sql_func.col("features")) + ) + test_dataset = self.cls_df_test.withColumn( + "features", vector_to_array(spark_sql_func.col("features")) + ) + classifier = SparkXGBClassifier() + model = classifier.fit(train_dataset) + + pred_result = model.transform(test_dataset).collect() + for row in pred_result: + self.assertEqual(row.prediction, row.expected_prediction) + self.assertTrue( + np.allclose(row.probability, row.expected_probability, rtol=1e-3) + ) From 9a8790985cd6993639fcd019c10ab3162b4dc84e Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Sun, 3 Jul 2022 16:26:47 +0800 Subject: [PATCH 23/73] gpu mode support oss spark --- python-package/xgboost/spark/core.py | 9 ++++++--- 1 file 
changed, 6 insertions(+), 3 deletions(-) diff --git a/python-package/xgboost/spark/core.py b/python-package/xgboost/spark/core.py index 5b4d71da5f12..9812f07515a6 100644 --- a/python-package/xgboost/spark/core.py +++ b/python-package/xgboost/spark/core.py @@ -448,9 +448,6 @@ def _get_distributed_train_params(self, dataset): if self.getOrDefault(self.use_gpu): params["tree_method"] = "gpu_hist" - # TODO: fix this. This only works on databricks runtime. - # On open-source spark, we need get the gpu id from the task allocated gpu resources. - params["gpu_id"] = 0 return params @@ -527,6 +524,7 @@ def _fit(self, dataset): "feature_names": self.getOrDefault(self.feature_names), } booster_params['nthread'] = cpu_per_task + use_gpu = self.getOrDefault(self.use_gpu) def _train_booster(pandas_df_iter): """ @@ -537,6 +535,11 @@ def _train_booster(pandas_df_iter): context = BarrierTaskContext.get() context.barrier() + + if use_gpu: + # Set booster worker to use the first GPU allocated to the spark task. + booster_params["gpu_id"] = int(context._resources["gpu"].addresses[0].strip()) + _rabit_args = "" if context.partitionId() == 0: _rabit_args = str(_get_rabit_args(context, num_workers)) From 2aeaee8a5a46b867bf75ddc4dc073cd9fd811860 Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Sun, 3 Jul 2022 17:10:06 +0800 Subject: [PATCH 24/73] update comment --- python-package/xgboost/spark/core.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/python-package/xgboost/spark/core.py b/python-package/xgboost/spark/core.py index 9812f07515a6..ee8730460de6 100644 --- a/python-package/xgboost/spark/core.py +++ b/python-package/xgboost/spark/core.py @@ -657,8 +657,6 @@ def _transform(self, dataset): def predict_udf(iterator: Iterator[pd.DataFrame]) -> Iterator[pd.Series]: # deserialize model from ser_model_string, avoid pickling model to remote worker X, _, _, _ = prepare_predict_data(iterator, False) - # Note: In every spark job task, pandas UDF will run in separate python process - # so it is safe here to call the thread-unsafe model.predict method if len(X) > 0: preds = xgb_sklearn_model.predict(X, **predict_params) yield pd.Series(preds) From 10bf6b254a94bd45fa612a9fb24761410291ff2b Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Sun, 3 Jul 2022 22:19:09 +0800 Subject: [PATCH 25/73] avoid call pd.Series.to_list --- python-package/xgboost/spark/data.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/python-package/xgboost/spark/data.py b/python-package/xgboost/spark/data.py index bfd21834f4ee..bf258f744dfa 100644 --- a/python-package/xgboost/spark/data.py +++ b/python-package/xgboost/spark/data.py @@ -83,22 +83,22 @@ def _row_tuple_list_to_feature_matrix_y_w( num_feature_dims, expected_feature_dims ) - # TODO: Improve performance, avoid use python list + # Note: each element in `pdf["values"]` is an numpy array. 
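            # Series.to_numpy() avoids materializing per-element Python objects the way
            # to_list() does (for a plain numeric column it typically avoids a copy), which
            # is why the label/weight/baseMargin columns switch to to_numpy() while
            # "values" keeps to_list(), since each of its cells is already an ndarray.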
values_list.append(pdf["values"].to_list()) if train: - label_list.append(pdf["label"].to_list()) + label_list.append(pdf["label"].to_numpy()) if has_weight: - weight_list.append(pdf["weight"].to_list()) + weight_list.append(pdf["weight"].to_numpy()) if has_fit_base_margin or has_predict_base_margin: - base_margin_list.append(pdf["baseMargin"].to_list()) + base_margin_list.append(pdf["baseMargin"].to_numpy()) if has_validation: values_val_list.append(pdf_val["values"].to_list()) if train: - label_val_list.append(pdf_val["label"].to_list()) + label_val_list.append(pdf_val["label"].to_numpy()) if has_weight: - weight_val_list.append(pdf_val["weight"].to_list()) + weight_val_list.append(pdf_val["weight"].to_numpy()) if has_fit_base_margin or has_predict_base_margin: - base_margin_val_list.append(pdf_val["baseMargin"].to_list()) + base_margin_val_list.append(pdf_val["baseMargin"].to_numpy()) # Construct feature_matrix if expected_feature_dims is None: From 53b1a5b2e66d731a5b7ec417ed111bf1f79d60c9 Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Sun, 3 Jul 2022 23:01:28 +0800 Subject: [PATCH 26/73] avoid data concatenation in predict_udf --- python-package/xgboost/spark/core.py | 155 +++++++++------------------ python-package/xgboost/spark/data.py | 11 -- 2 files changed, 53 insertions(+), 113 deletions(-) diff --git a/python-package/xgboost/spark/core.py b/python-package/xgboost/spark/core.py index ee8730460de6..181677caa6dd 100644 --- a/python-package/xgboost/spark/core.py +++ b/python-package/xgboost/spark/core.py @@ -22,7 +22,6 @@ from xgboost.training import train as worker_train from .utils import get_logger, _get_max_num_concurrent_tasks from .data import ( - prepare_predict_data, convert_partition_data_to_dmatrix, ) from .model import ( @@ -105,7 +104,7 @@ # for classification, we can use rawPrediction as margin "output_margin", "validate_features", # TODO - "base_margin", # TODO + "base_margin", # Use pyspark baseMarginCol param instead. 
} @@ -653,39 +652,28 @@ def _transform(self, dataset): xgb_sklearn_model = self._xgb_sklearn_model predict_params = self._gen_predict_params_dict() - @pandas_udf("double") - def predict_udf(iterator: Iterator[pd.DataFrame]) -> Iterator[pd.Series]: - # deserialize model from ser_model_string, avoid pickling model to remote worker - X, _, _, _ = prepare_predict_data(iterator, False) - if len(X) > 0: - preds = xgb_sklearn_model.predict(X, **predict_params) - yield pd.Series(preds) + has_base_margin = False + if self.isDefined(self.baseMarginCol) and self.getOrDefault(self.baseMarginCol): + has_base_margin = True + base_margin_col = col(self.getOrDefault(self.baseMarginCol)).alias("baseMargin") @pandas_udf("double") - def predict_udf_base_margin( - iterator: Iterator[pd.DataFrame] - ) -> Iterator[pd.Series]: - # deserialize model from ser_model_string, avoid pickling model to remote worker - X, _, _, b_m = prepare_predict_data(iterator, True) - # Note: In every spark job task, pandas UDF will run in separate python process - # so it is safe here to call the thread-unsafe model.predict method - if len(X) > 0: - preds = xgb_sklearn_model.predict(X, base_margin=b_m, **predict_params) - yield pd.Series(preds) + def predict_udf(input_data: pd.DataFrame) -> pd.Series: + X = np.array(input_data["values"].tolist()) + if has_base_margin: + base_margin = input_data["baseMargin"].to_numpy() + else: + base_margin = None + + preds = xgb_sklearn_model.predict(X, base_margin=base_margin, **predict_params) + return pd.Series(preds) features_col = _validate_and_convert_feature_col_as_array_col( dataset, self.getOrDefault(self.featuresCol) ) - has_base_margin = False - if self.isDefined(self.baseMarginCol) and self.getOrDefault(self.baseMarginCol): - has_base_margin = True - if has_base_margin: - base_margin_col = col(self.getOrDefault(self.baseMarginCol)).alias("baseMargin") - pred_col = predict_udf_base_margin( - struct(features_col, base_margin_col) - ) + pred_col = predict_udf(struct(features_col, base_margin_col)) else: pred_col = predict_udf(struct(features_col)) @@ -711,91 +699,54 @@ def _transform(self, dataset): xgb_sklearn_model = self._xgb_sklearn_model predict_params = self._gen_predict_params_dict() - @pandas_udf( - "rawPrediction array, prediction double, probability array" - ) - def predict_udf(iterator: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]: - # deserialize model from ser_model_string, avoid pickling model to remote worker - X, _, _, _ = prepare_predict_data(iterator, False) - # Note: In every spark job task, pandas UDF will run in separate python process - # so it is safe here to call the thread-unsafe model.predict method - if len(X) > 0: - margins = xgb_sklearn_model.predict( - X, output_margin=True, **predict_params - ) - if margins.ndim == 1: - # binomial case - classone_probs = expit(margins) - classzero_probs = 1.0 - classone_probs - raw_preds = np.vstack((-margins, margins)).transpose() - class_probs = np.vstack( - (classzero_probs, classone_probs) - ).transpose() - else: - # multinomial case - raw_preds = margins - class_probs = softmax(raw_preds, axis=1) - - # It seems that they use argmax of class probs, - # not of margin to get the prediction (Note: scala implementation) - preds = np.argmax(class_probs, axis=1) - yield pd.DataFrame( - data={ - "rawPrediction": pd.Series(raw_preds.tolist()), - "prediction": pd.Series(preds), - "probability": pd.Series(class_probs.tolist()), - } - ) + has_base_margin = False + if self.isDefined(self.baseMarginCol) and 
self.getOrDefault(self.baseMarginCol): + has_base_margin = True + base_margin_col = col(self.getOrDefault(self.baseMarginCol)).alias("baseMargin") @pandas_udf( "rawPrediction array, prediction double, probability array" ) - def predict_udf_base_margin( - iterator: Iterator[pd.DataFrame] - ) -> Iterator[pd.DataFrame]: - # deserialize model from ser_model_string, avoid pickling model to remote worker - X, _, _, b_m = prepare_predict_data(iterator, True) - # Note: In every spark job task, pandas UDF will run in separate python process - # so it is safe here to call the thread-unsafe model.predict method - if len(X) > 0: - margins = xgb_sklearn_model.predict( - X, base_margin=b_m, output_margin=True, **predict_params - ) - if margins.ndim == 1: - # binomial case - classone_probs = expit(margins) - classzero_probs = 1.0 - classone_probs - raw_preds = np.vstack((-margins, margins)).transpose() - class_probs = np.vstack( - (classzero_probs, classone_probs) - ).transpose() - else: - # multinomial case - raw_preds = margins - class_probs = softmax(raw_preds, axis=1) - - # It seems that they use argmax of class probs, - # not of margin to get the prediction (Note: scala implementation) - preds = np.argmax(class_probs, axis=1) - yield pd.DataFrame( - data={ - "rawPrediction": pd.Series(raw_preds.tolist()), - "prediction": pd.Series(preds), - "probability": pd.Series(class_probs.tolist()), - } - ) + def predict_udf(input_data: pd.DataFrame) -> pd.DataFrame: + X = np.array(input_data["values"].tolist()) + if has_base_margin: + base_margin = input_data["baseMargin"].to_numpy() + else: + base_margin = None + + margins = xgb_sklearn_model.predict( + X, base_margin=base_margin, output_margin=True, **predict_params + ) + if margins.ndim == 1: + # binomial case + classone_probs = expit(margins) + classzero_probs = 1.0 - classone_probs + raw_preds = np.vstack((-margins, margins)).transpose() + class_probs = np.vstack( + (classzero_probs, classone_probs) + ).transpose() + else: + # multinomial case + raw_preds = margins + class_probs = softmax(raw_preds, axis=1) + + # It seems that they use argmax of class probs, + # not of margin to get the prediction (Note: scala implementation) + preds = np.argmax(class_probs, axis=1) + return pd.DataFrame( + data={ + "rawPrediction": pd.Series(raw_preds.tolist()), + "prediction": pd.Series(preds), + "probability": pd.Series(class_probs.tolist()), + } + ) features_col = _validate_and_convert_feature_col_as_array_col( dataset, self.getOrDefault(self.featuresCol) ) - has_base_margin = False - if self.isDefined(self.baseMarginCol) and self.getOrDefault(self.baseMarginCol): - has_base_margin = True - if has_base_margin: - base_margin_col = col(self.getOrDefault(self.baseMarginCol)).alias("baseMargin") - pred_struct = predict_udf_base_margin(struct(features_col, base_margin_col)) + pred_struct = predict_udf(struct(features_col, base_margin_col)) else: pred_struct = predict_udf(struct(features_col)) diff --git a/python-package/xgboost/spark/data.py b/python-package/xgboost/spark/data.py index bf258f744dfa..999f80025fea 100644 --- a/python-package/xgboost/spark/data.py +++ b/python-package/xgboost/spark/data.py @@ -22,17 +22,6 @@ def gen_data_pdf(): ) -def prepare_predict_data(data_iterator, has_predict_base_margin): - return _process_data_iter( - data_iterator, - train=False, - has_weight=False, - has_validation=False, - has_fit_base_margin=False, - has_predict_base_margin=has_predict_base_margin, - ) - - def _check_feature_dims(num_dims, expected_dims): """ Check all feature 
vectors has the same dimension From c907164771596416999f15f3f8493808cbffde9f Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Sun, 3 Jul 2022 23:22:03 +0800 Subject: [PATCH 27/73] update comment --- python-package/xgboost/spark/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python-package/xgboost/spark/core.py b/python-package/xgboost/spark/core.py index 181677caa6dd..753502ce5553 100644 --- a/python-package/xgboost/spark/core.py +++ b/python-package/xgboost/spark/core.py @@ -86,7 +86,7 @@ _unsupported_xgb_params = [ "gpu_id", # we have "use_gpu" pyspark param instead. - "enable_categorical", # TODO: support this + "enable_categorical", # Use feature_types param to specify categorical feature instead "use_label_encoder", "n_jobs", # Do not allow user to set it, will use `spark.task.cpus` value instead. "nthread", # Ditto From 3a92fac499ad645b7a0eb9ea1fb6d5f592519fc3 Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Tue, 5 Jul 2022 22:36:20 +0800 Subject: [PATCH 28/73] fix --- python-package/xgboost/spark/core.py | 8 ++++++-- tests/python/test_spark/xgboost_local_test.py | 2 +- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/python-package/xgboost/spark/core.py b/python-package/xgboost/spark/core.py index 753502ce5553..832d008133b1 100644 --- a/python-package/xgboost/spark/core.py +++ b/python-package/xgboost/spark/core.py @@ -665,7 +665,10 @@ def predict_udf(input_data: pd.DataFrame) -> pd.Series: else: base_margin = None - preds = xgb_sklearn_model.predict(X, base_margin=base_margin, **predict_params) + preds = xgb_sklearn_model.predict( + X, base_margin=base_margin, validate_features=False + **predict_params + ) return pd.Series(preds) features_col = _validate_and_convert_feature_col_as_array_col( @@ -715,7 +718,8 @@ def predict_udf(input_data: pd.DataFrame) -> pd.DataFrame: base_margin = None margins = xgb_sklearn_model.predict( - X, base_margin=base_margin, output_margin=True, **predict_params + X, base_margin=base_margin, output_margin=True, validate_features=False + **predict_params ) if margins.ndim == 1: # binomial case diff --git a/tests/python/test_spark/xgboost_local_test.py b/tests/python/test_spark/xgboost_local_test.py index 17f5f5ca9f5e..ada0176e0881 100644 --- a/tests/python/test_spark/xgboost_local_test.py +++ b/tests/python/test_spark/xgboost_local_test.py @@ -435,7 +435,7 @@ def test_param_alias(self): def test_gpu_param_setting(self): py_cls = SparkXGBClassifier(use_gpu=True) - train_params = py_cls._get_distributed_train_params(self.cls_df_train, {}) + train_params = py_cls._get_distributed_train_params(self.cls_df_train) assert train_params["gpu_id"] == 0 assert train_params["tree_method"] == "gpu_hist" From 9e6ad5506e6788a868fb8c99a1c89f919bf70e97 Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Tue, 5 Jul 2022 23:29:14 +0800 Subject: [PATCH 29/73] fix tests --- python-package/xgboost/spark/core.py | 11 +++++---- .../test_spark/xgboost_local_cluster_test.py | 18 ++++++++------ tests/python/test_spark/xgboost_local_test.py | 24 +++++++++++++------ 3 files changed, 34 insertions(+), 19 deletions(-) diff --git a/python-package/xgboost/spark/core.py b/python-package/xgboost/spark/core.py index 832d008133b1..9e56c84a9f2b 100644 --- a/python-package/xgboost/spark/core.py +++ b/python-package/xgboost/spark/core.py @@ -157,7 +157,7 @@ def _xgb_cls(cls): raise NotImplementedError() def _get_xgb_model_creator(self): - xgb_params = self._gen_xgb_params_dict() + xgb_params = self._gen_xgb_params_dict(gen_xgb_sklearn_estimator_param=True) 
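        # The gen_xgb_sklearn_estimator_param=True flag above keeps estimator-level params
        # such as "missing" and "n_estimators" in the dict; they are valid XGBClassifier /
        # XGBRegressor constructor arguments even though they are stripped from the booster
        # params passed to xgboost.train().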
return get_xgb_model_creator(self._xgb_cls(), xgb_params) # Parameters for xgboost.XGBModel() @@ -174,14 +174,15 @@ def _set_xgb_params_default(self): filtered_params_dict = self._get_xgb_params_default() self._setDefault(**filtered_params_dict) - def _gen_xgb_params_dict(self): + def _gen_xgb_params_dict(self, gen_xgb_sklearn_estimator_param=False): xgb_params = {} non_xgb_params = ( set(_pyspark_specific_params) - | set(_non_booster_params) | self._get_fit_params_default().keys() | self._get_predict_params_default().keys() ) + if not gen_xgb_sklearn_estimator_param: + non_xgb_params |= set(_non_booster_params) for param in self.extractParamMap(): if param.name not in non_xgb_params: xgb_params[param.name] = self.getOrDefault(param) @@ -666,7 +667,7 @@ def predict_udf(input_data: pd.DataFrame) -> pd.Series: base_margin = None preds = xgb_sklearn_model.predict( - X, base_margin=base_margin, validate_features=False + X, base_margin=base_margin, validate_features=False, **predict_params ) return pd.Series(preds) @@ -718,7 +719,7 @@ def predict_udf(input_data: pd.DataFrame) -> pd.DataFrame: base_margin = None margins = xgb_sklearn_model.predict( - X, base_margin=base_margin, output_margin=True, validate_features=False + X, base_margin=base_margin, output_margin=True, validate_features=False, **predict_params ) if margins.ndim == 1: diff --git a/tests/python/test_spark/xgboost_local_cluster_test.py b/tests/python/test_spark/xgboost_local_cluster_test.py index 30a0f6db06b7..b42f0231ba96 100644 --- a/tests/python/test_spark/xgboost_local_cluster_test.py +++ b/tests/python/test_spark/xgboost_local_cluster_test.py @@ -174,7 +174,7 @@ def setUp(self): ], ) self.clf_best_score_eval = 0.009677 - self.clf_best_score_weight_and_eval = 0.006628 + self.clf_best_score_weight_and_eval = 0.006626 self.reg_params_with_eval_dist = { "validationIndicatorCol": "isVal", @@ -209,8 +209,8 @@ def setUp(self): "expected_prediction_with_weight_and_eval", ], ) - self.reg_best_score_eval = 5.2e-05 - self.reg_best_score_weight_and_eval = 4.9e-05 + self.reg_best_score_eval = 5.239e-05 + self.reg_best_score_weight_and_eval = 4.810e-05 def test_regressor_basic_with_params(self): regressor = SparkXGBRegressor(**self.reg_params) @@ -332,9 +332,10 @@ def test_classifier_distributed_weight_eval(self): self.assertTrue( np.allclose(row.probability, row.expected_prob_with_eval, atol=1e-3) ) - self.assertEqual( + assert np.isclose( float(model.get_booster().attributes()["best_score"]), self.clf_best_score_eval, + rtol=1e-3 ) # with both weight and eval @@ -354,9 +355,10 @@ def test_classifier_distributed_weight_eval(self): row.probability, row.expected_prob_with_weight_and_eval, atol=1e-3 ) ) - self.assertEqual( + np.isclose( float(model.get_booster().attributes()["best_score"]), self.clf_best_score_weight_and_eval, + rtol=1e-3 ) def test_regressor_distributed_weight_eval(self): @@ -390,9 +392,10 @@ def test_regressor_distributed_weight_eval(self): self.assertTrue( np.isclose(row.prediction, row.expected_prediction_with_eval, atol=1e-3) ) - self.assertEqual( + assert np.isclose( float(model.get_booster().attributes()["best_score"]), self.reg_best_score_eval, + rtol=1e-3 ) # with both weight and eval regressor = SparkXGBRegressor( @@ -414,9 +417,10 @@ def test_regressor_distributed_weight_eval(self): atol=1e-3, ) ) - self.assertEqual( + assert np.isclose( float(model.get_booster().attributes()["best_score"]), self.reg_best_score_weight_and_eval, + rtol=1e-3 ) def test_num_estimators(self): diff --git 
a/tests/python/test_spark/xgboost_local_test.py b/tests/python/test_spark/xgboost_local_test.py index ada0176e0881..99585b92f5c7 100644 --- a/tests/python/test_spark/xgboost_local_test.py +++ b/tests/python/test_spark/xgboost_local_test.py @@ -21,6 +21,7 @@ ) from .utils_test import SparkTestCase from xgboost import XGBClassifier, XGBRegressor +from xgboost.spark.core import _non_booster_params logging.getLogger("py4j").setLevel(logging.INFO) @@ -479,10 +480,11 @@ def test_multi_classifier(self): np.allclose(row.probability, row.expected_probability, rtol=1e-3) ) - def _check_sub_dict_match(self, sub_dist, whole_dict): + def _check_sub_dict_match(self, sub_dist, whole_dict, excluding_keys): for k in sub_dist: - self.assertTrue(k in whole_dict) - self.assertEqual(sub_dist[k], whole_dict[k]) + if k not in excluding_keys: + self.assertTrue(k in whole_dict, f"check on {k} failed") + self.assertEqual(sub_dist[k], whole_dict[k], f"check on {k} failed") def test_regressor_with_params(self): regressor = SparkXGBRegressor(**self.reg_params) @@ -491,7 +493,9 @@ def test_regressor_with_params(self): **(regressor._gen_fit_params_dict()), **(regressor._gen_predict_params_dict()), ) - self._check_sub_dict_match(self.reg_params, all_params) + self._check_sub_dict_match( + self.reg_params, all_params, excluding_keys=_non_booster_params + ) model = regressor.fit(self.reg_df_train) all_params = dict( @@ -499,7 +503,9 @@ def test_regressor_with_params(self): **(model._gen_fit_params_dict()), **(model._gen_predict_params_dict()), ) - self._check_sub_dict_match(self.reg_params, all_params) + self._check_sub_dict_match( + self.reg_params, all_params, excluding_keys=_non_booster_params + ) pred_result = model.transform(self.reg_df_test).collect() for row in pred_result: self.assertTrue( @@ -515,7 +521,9 @@ def test_classifier_with_params(self): **(classifier._gen_fit_params_dict()), **(classifier._gen_predict_params_dict()), ) - self._check_sub_dict_match(self.cls_params, all_params) + self._check_sub_dict_match( + self.cls_params, all_params, excluding_keys=_non_booster_params + ) model = classifier.fit(self.cls_df_train) all_params = dict( @@ -523,7 +531,9 @@ def test_classifier_with_params(self): **(model._gen_fit_params_dict()), **(model._gen_predict_params_dict()), ) - self._check_sub_dict_match(self.cls_params, all_params) + self._check_sub_dict_match( + self.cls_params, all_params, excluding_keys=_non_booster_params + ) pred_result = model.transform(self.cls_df_test).collect() for row in pred_result: self.assertEqual(row.prediction, row.expected_prediction_with_params) From d24512d8571e7c76ea29000d76f8462df8adcaee Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Tue, 5 Jul 2022 23:39:32 +0800 Subject: [PATCH 30/73] forbid camel case param in setParams --- python-package/xgboost/spark/core.py | 10 +++++++++- tests/python/test_spark/xgboost_local_test.py | 5 ++++- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/python-package/xgboost/spark/core.py b/python-package/xgboost/spark/core.py index 9e56c84a9f2b..bace95898a26 100644 --- a/python-package/xgboost/spark/core.py +++ b/python-package/xgboost/spark/core.py @@ -84,6 +84,10 @@ "base_margin_col": "baseMarginCol", } +_inverse_pyspark_param_alias_map = { + v: k for k, v in _pyspark_param_alias_map.items() +} + _unsupported_xgb_params = [ "gpu_id", # we have "use_gpu" pyspark param instead. 
"enable_categorical", # Use feature_types param to specify categorical feature instead @@ -338,10 +342,14 @@ def __init__(self): def setParams(self, **kwargs): _extra_params = {} - if 'arbitraryParamsDict' in kwargs: + if 'arbitraryParamsDict' in kwargs or 'arbitrary_params_dict' in kwargs: raise ValueError("Wrong param name: 'arbitraryParamsDict'.") for k, v in kwargs.items(): + if k in _inverse_pyspark_param_alias_map: + raise ValueError( + f"Please use param name {_inverse_pyspark_param_alias_map[k]} instead." + ) if k in _pyspark_param_alias_map: real_k = _pyspark_param_alias_map[k] if real_k in kwargs: diff --git a/tests/python/test_spark/xgboost_local_test.py b/tests/python/test_spark/xgboost_local_test.py index 99585b92f5c7..8100e75382af 100644 --- a/tests/python/test_spark/xgboost_local_test.py +++ b/tests/python/test_spark/xgboost_local_test.py @@ -3,6 +3,7 @@ import uuid import numpy as np +import pytest from pyspark.ml.functions import vector_to_array from pyspark.sql import functions as spark_sql_func from pyspark.ml import Pipeline, PipelineModel @@ -430,9 +431,11 @@ def test_classifier_kwargs_basic(self): self.assertEqual(py_cls._get_xgb_model_creator()().get_params()["z"], 2) def test_param_alias(self): - py_cls = SparkXGBClassifier(featuresCol="f1", label_col="l1") + py_cls = SparkXGBClassifier(features_col="f1", label_col="l1") self.assertEqual(py_cls.getOrDefault(py_cls.featuresCol), "f1") self.assertEqual(py_cls.getOrDefault(py_cls.labelCol), "l1") + with pytest.raises(ValueError, match="Please use param name features_col instead"): + SparkXGBClassifier(featuresCol="f1") def test_gpu_param_setting(self): py_cls = SparkXGBClassifier(use_gpu=True) From b3fa18592bfeebd847c653e4ac2c90c9e0b8d136 Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Tue, 5 Jul 2022 23:53:52 +0800 Subject: [PATCH 31/73] rename 2 camel case params --- python-package/xgboost/spark/core.py | 43 +++++++++---------- python-package/xgboost/spark/estimator.py | 12 +++--- python-package/xgboost/spark/params.py | 18 ++++---- tests/python/test_spark/xgboost_local_test.py | 27 ++++++------ 4 files changed, 49 insertions(+), 51 deletions(-) diff --git a/python-package/xgboost/spark/core.py b/python-package/xgboost/spark/core.py index bace95898a26..abfa00c889cf 100644 --- a/python-package/xgboost/spark/core.py +++ b/python-package/xgboost/spark/core.py @@ -50,7 +50,7 @@ from pyspark.ml.linalg import VectorUDT # Put pyspark specific params here, they won't be passed to XGBoost. -# like `validationIndicatorCol`, `baseMarginCol` +# like `validationIndicatorCol`, `base_margin_col` _pyspark_specific_params = [ "featuresCol", "labelCol", @@ -59,8 +59,8 @@ "predictionCol", "probabilityCol", "validationIndicatorCol", - "baseMarginCol", - "arbitraryParamsDict", + "base_margin_col", + "arbitrary_params_dict", "force_repartition", "num_workers", "use_gpu", @@ -81,7 +81,6 @@ "prediction_col": "predictionCol", "probability_col": "probabilityCol", "validation_indicator_col": "validationIndicatorCol", - "base_margin_col": "baseMarginCol", } _inverse_pyspark_param_alias_map = { @@ -101,14 +100,14 @@ # Supported by spark param weightCol # and validationIndicatorCol "eval_set", "sample_weight_eval_set", - "base_margin", # Supported by spark param baseMarginCol + "base_margin", # Supported by spark param base_margin_col } _unsupported_predict_params = { # for classification, we can use rawPrediction as margin "output_margin", "validate_features", # TODO - "base_margin", # Use pyspark baseMarginCol param instead. 
+ "base_margin", # Use pyspark base_margin_col param instead. } @@ -191,8 +190,8 @@ def _gen_xgb_params_dict(self, gen_xgb_sklearn_estimator_param=False): if param.name not in non_xgb_params: xgb_params[param.name] = self.getOrDefault(param) - arbitraryParamsDict = self.getOrDefault(self.getParam("arbitraryParamsDict")) - xgb_params.update(arbitraryParamsDict) + arbitrary_params_dict = self.getOrDefault(self.getParam("arbitrary_params_dict")) + xgb_params.update(arbitrary_params_dict) return xgb_params # Parameters for xgboost.XGBModel().fit() @@ -328,8 +327,8 @@ def __init__(self): self._set_xgb_params_default() self._set_fit_params_default() self._set_predict_params_default() - # Note: The default value for arbitraryParamsDict must always be empty dict. - # For additional settings added into "arbitraryParamsDict" by default, + # Note: The default value for arbitrary_params_dict must always be empty dict. + # For additional settings added into "arbitrary_params_dict" by default, # they are added in `setParams`. self._setDefault( num_workers=1, @@ -337,13 +336,13 @@ def __init__(self): force_repartition=False, feature_names=None, feature_types=None, - arbitraryParamsDict={} + arbitrary_params_dict={} ) def setParams(self, **kwargs): _extra_params = {} - if 'arbitraryParamsDict' in kwargs or 'arbitrary_params_dict' in kwargs: - raise ValueError("Wrong param name: 'arbitraryParamsDict'.") + if 'arbitrary_params_dict' in kwargs: + raise ValueError("Invalid param name: 'arbitrary_params_dict'.") for k, v in kwargs.items(): if k in _inverse_pyspark_param_alias_map: @@ -364,8 +363,8 @@ def setParams(self, **kwargs): k in _unsupported_predict_params: raise ValueError(f"Unsupported param '{k}'.") _extra_params[k] = v - _existing_extra_params = self.getOrDefault(self.arbitraryParamsDict) - self._set(arbitraryParamsDict={**_existing_extra_params, **_extra_params}) + _existing_extra_params = self.getOrDefault(self.arbitrary_params_dict) + self._set(arbitrary_params_dict={**_existing_extra_params, **_extra_params}) @classmethod def _pyspark_model_cls(cls): @@ -497,11 +496,11 @@ def _fit(self, dataset): ) ) - if self.isDefined(self.baseMarginCol) and self.getOrDefault( - self.baseMarginCol): + if self.isDefined(self.base_margin_col) and self.getOrDefault( + self.base_margin_col): has_base_margin = True select_cols.append( - col(self.getOrDefault(self.baseMarginCol)).alias("baseMargin")) + col(self.getOrDefault(self.base_margin_col)).alias("baseMargin")) dataset = dataset.select(*select_cols) @@ -662,9 +661,9 @@ def _transform(self, dataset): predict_params = self._gen_predict_params_dict() has_base_margin = False - if self.isDefined(self.baseMarginCol) and self.getOrDefault(self.baseMarginCol): + if self.isDefined(self.base_margin_col) and self.getOrDefault(self.base_margin_col): has_base_margin = True - base_margin_col = col(self.getOrDefault(self.baseMarginCol)).alias("baseMargin") + base_margin_col = col(self.getOrDefault(self.base_margin_col)).alias("baseMargin") @pandas_udf("double") def predict_udf(input_data: pd.DataFrame) -> pd.Series: @@ -712,9 +711,9 @@ def _transform(self, dataset): predict_params = self._gen_predict_params_dict() has_base_margin = False - if self.isDefined(self.baseMarginCol) and self.getOrDefault(self.baseMarginCol): + if self.isDefined(self.base_margin_col) and self.getOrDefault(self.base_margin_col): has_base_margin = True - base_margin_col = col(self.getOrDefault(self.baseMarginCol)).alias("baseMargin") + base_margin_col = 
col(self.getOrDefault(self.base_margin_col)).alias("baseMargin") @pandas_udf( "rawPrediction array, prediction double, probability array" diff --git a/python-package/xgboost/spark/estimator.py b/python-package/xgboost/spark/estimator.py index eb00e3518024..798e809eaedc 100644 --- a/python-package/xgboost/spark/estimator.py +++ b/python-package/xgboost/spark/estimator.py @@ -23,7 +23,7 @@ class SparkXGBRegressor(_SparkXGBEstimator): see doc below for more details. XgboostRegressor doesn't support setting `base_margin` explicitly as well, but support - another param called `baseMarginCol`. see doc below for more details. + another param called `base_margin_col`. see doc below for more details. XgboostRegressor doesn't support `validate_features` and `output_margin` param. @@ -54,9 +54,9 @@ class SparkXGBRegressor(_SparkXGBEstimator): use_gpu: Boolean that specifies whether the executors are running on GPU instances. - baseMarginCol: + base_margin_col: To specify the base margins of the training and validation - dataset, set :py:attr:`xgboost.spark.XgboostRegressor.baseMarginCol` parameter + dataset, set :py:attr:`xgboost.spark.XgboostRegressor.base_margin_col` parameter instead of setting `base_margin` and `base_margin_eval_set` in the `xgboost.XGBRegressor` fit method. Note: this isn't available for distributed training. @@ -119,7 +119,7 @@ class SparkXGBClassifier(_SparkXGBEstimator, HasProbabilityCol, HasRawPrediction see doc below for more details. XgboostClassifier doesn't support setting `base_margin` explicitly as well, but support - another param called `baseMarginCol`. see doc below for more details. + another param called `base_margin_col`. see doc below for more details. XgboostClassifier doesn't support setting `output_margin`, but we can get output margin from the raw prediction column. See `rawPredictionCol` param doc below for more details. @@ -159,9 +159,9 @@ class SparkXGBClassifier(_SparkXGBEstimator, HasProbabilityCol, HasRawPrediction use_gpu: Boolean that specifies whether the executors are running on GPU instances. - baseMarginCol: + base_margin_col: To specify the base margins of the training and validation - dataset, set :py:attr:`xgboost.spark.XgboostClassifier.baseMarginCol` parameter + dataset, set :py:attr:`xgboost.spark.XgboostClassifier.base_margin_col` parameter instead of setting `base_margin` and `base_margin_eval_set` in the `xgboost.XGBClassifier` fit method. Note: this isn't available for distributed training. diff --git a/python-package/xgboost/spark/params.py b/python-package/xgboost/spark/params.py index 028968a874ad..3172deac47cc 100644 --- a/python-package/xgboost/spark/params.py +++ b/python-package/xgboost/spark/params.py @@ -8,19 +8,19 @@ class HasArbitraryParamsDict(Params): input. """ - arbitraryParamsDict = Param( + arbitrary_params_dict = Param( Params._dummy(), - "arbitraryParamsDict", - "arbitraryParamsDict This parameter holds all of the additional parameters which are " + "arbitrary_params_dict", + "arbitrary_params_dict This parameter holds all of the additional parameters which are " "not exposed as the the XGBoost Spark estimator params but can be recognized by " "underlying XGBoost library. 
It is stored as a dictionary.", ) def setArbitraryParamsDict(self, value): - return self._set(arbitraryParamsDict=value) + return self._set(arbitrary_params_dict=value) def getArbitraryParamsDict(self): - return self.getOrDefault(self.arbitraryParamsDict) + return self.getOrDefault(self.arbitrary_params_dict) class HasBaseMarginCol(Params): @@ -29,14 +29,14 @@ class HasBaseMarginCol(Params): and holds the variable to store the base margin column part of XGboost. """ - baseMarginCol = Param( + base_margin_col = Param( Params._dummy(), - "baseMarginCol", + "base_margin_col", "This stores the name for the column of the base margin", ) def setBaseMarginCol(self, value): - return self._set(baseMarginCol=value) + return self._set(base_margin_col=value) def getBaseMarginCol(self): - return self.getOrDefault(self.baseMarginCol) + return self.getOrDefault(self.base_margin_col) diff --git a/tests/python/test_spark/xgboost_local_test.py b/tests/python/test_spark/xgboost_local_test.py index 8100e75382af..4fe5204d9c5b 100644 --- a/tests/python/test_spark/xgboost_local_test.py +++ b/tests/python/test_spark/xgboost_local_test.py @@ -196,7 +196,7 @@ def setUp(self): ["features", "label", "isVal", "weight"], ) self.reg_params_with_eval = { - "validationIndicatorCol": "isVal", + "validation_indicator_col": "isVal", "early_stopping_rounds": 1, "eval_metric": "rmse", } @@ -256,7 +256,7 @@ def setUp(self): ["features", "label", "isVal", "weight"], ) self.cls_params_with_eval = { - "validationIndicatorCol": "isVal", + "validation_indicator_col": "isVal", "early_stopping_rounds": 1, "eval_metric": "logloss", } @@ -408,15 +408,15 @@ def test_classifier_kwargs_basic(self): self.assertTrue(hasattr(py_cls, "n_estimators")) self.assertEqual(py_cls.n_estimators.parent, py_cls.uid) self.assertFalse(hasattr(py_cls, "gpu_id")) - self.assertTrue(hasattr(py_cls, "arbitraryParamsDict")) + self.assertTrue(hasattr(py_cls, "arbitrary_params_dict")) expected_kwargs = {"sketch_eps": 0.03} self.assertEqual( - py_cls.getOrDefault(py_cls.arbitraryParamsDict), expected_kwargs + py_cls.getOrDefault(py_cls.arbitrary_params_dict), expected_kwargs ) self.assertTrue("sketch_eps" in py_cls._get_xgb_model_creator()().get_params()) # We want all of the new params to be in the .get_params() call and be an attribute of py_cls, but not of the actual model self.assertTrue( - "arbitraryParamsDict" not in py_cls._get_xgb_model_creator()().get_params() + "arbitrary_params_dict" not in py_cls._get_xgb_model_creator()().get_params() ) # Testing overwritten params @@ -440,7 +440,6 @@ def test_param_alias(self): def test_gpu_param_setting(self): py_cls = SparkXGBClassifier(use_gpu=True) train_params = py_cls._get_distributed_train_params(self.cls_df_train) - assert train_params["gpu_id"] == 0 assert train_params["tree_method"] == "gpu_hist" @staticmethod @@ -449,7 +448,7 @@ def test_param_value_converter(): # don't check by isintance(v, float) because for numpy scalar it will also return True assert py_cls.getOrDefault(py_cls.missing).__class__.__name__ == "float" assert ( - py_cls.getOrDefault(py_cls.arbitraryParamsDict)[ + py_cls.getOrDefault(py_cls.arbitrary_params_dict)[ "sketch_eps" ].__class__.__name__ == "float64" @@ -697,7 +696,7 @@ def test_train_with_initial_model(self): self.assertTrue(np.isclose(row1.prediction, row2.prediction, atol=1e-3)) def test_classifier_with_base_margin(self): - cls_without_base_margin = SparkXGBClassifier(weightCol="weight") + cls_without_base_margin = SparkXGBClassifier(weight_col="weight") 
+        cls_without_base_margin = SparkXGBClassifier(weight_col="weight")
model_without_base_margin = cls_without_base_margin.fit( self.cls_df_train_without_base_margin ) @@ -717,7 +716,7 @@ def test_classifier_with_base_margin(self): ) cls_with_same_base_margin = SparkXGBClassifier( - weightCol="weight", baseMarginCol="base_margin" + weight_col="weight", base_margin_col="base_margin" ) model_with_same_base_margin = cls_with_same_base_margin.fit( self.cls_df_train_with_same_base_margin @@ -734,7 +733,7 @@ def test_classifier_with_base_margin(self): np.testing.assert_allclose(row.probability, row.expected_prob_with_base_margin, atol=1e-3) cls_with_different_base_margin = SparkXGBClassifier( - weightCol="weight", baseMarginCol="base_margin" + weight_col="weight", base_margin_col="base_margin" ) model_with_different_base_margin = cls_with_different_base_margin.fit( self.cls_df_train_with_different_base_margin @@ -754,7 +753,7 @@ def test_classifier_with_base_margin(self): def test_regressor_with_weight_eval(self): # with weight - regressor_with_weight = SparkXGBRegressor(weightCol="weight") + regressor_with_weight = SparkXGBRegressor(weight_col="weight") model_with_weight = regressor_with_weight.fit( self.reg_df_train_with_eval_weight ) @@ -792,7 +791,7 @@ def test_regressor_with_weight_eval(self): ) # with weight and eval regressor_with_weight_eval = SparkXGBRegressor( - weightCol="weight", **self.reg_params_with_eval + weight_col="weight", **self.reg_params_with_eval ) model_with_weight_eval = regressor_with_weight_eval.fit( self.reg_df_train_with_eval_weight @@ -818,7 +817,7 @@ def test_regressor_with_weight_eval(self): def test_classifier_with_weight_eval(self): # with weight - classifier_with_weight = SparkXGBClassifier(weightCol="weight") + classifier_with_weight = SparkXGBClassifier(weight_col="weight") model_with_weight = classifier_with_weight.fit( self.cls_df_train_with_eval_weight ) @@ -850,7 +849,7 @@ def test_classifier_with_weight_eval(self): # Added scale_pos_weight because in 1.4.2, the original answer returns 0.5 which # doesn't really indicate this working correctly. 
classifier_with_weight_eval = SparkXGBClassifier( - weightCol="weight", scale_pos_weight=4, **self.cls_params_with_eval + weight_col="weight", scale_pos_weight=4, **self.cls_params_with_eval ) model_with_weight_eval = classifier_with_weight_eval.fit( self.cls_df_train_with_eval_weight From 8cee5cb099945c27ae75ab64f091d15f83da57fb Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Wed, 6 Jul 2022 00:02:49 +0800 Subject: [PATCH 32/73] address comments --- python-package/xgboost/spark/__init__.py | 4 ++-- python-package/xgboost/spark/core.py | 17 ++++++++++++++--- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/python-package/xgboost/spark/__init__.py b/python-package/xgboost/spark/__init__.py index 9d1e4b3c91b3..6af666b185dc 100644 --- a/python-package/xgboost/spark/__init__.py +++ b/python-package/xgboost/spark/__init__.py @@ -3,8 +3,8 @@ try: import pyspark -except ImportError: - raise RuntimeError("xgboost spark python API requires pyspark package installed.") +except ImportError as e: + raise ImportError("pyspark package needs to be installed to use this module") from e from .estimator import ( SparkXGBClassifier, diff --git a/python-package/xgboost/spark/core.py b/python-package/xgboost/spark/core.py index abfa00c889cf..5e649e9bbb47 100644 --- a/python-package/xgboost/spark/core.py +++ b/python-package/xgboost/spark/core.py @@ -1,4 +1,3 @@ -from typing import Iterator import numpy as np import pandas as pd from scipy.special import expit, softmax @@ -45,8 +44,9 @@ ) from pyspark.ml.functions import array_to_vector, vector_to_array -from pyspark.sql.types import \ +from pyspark.sql.types import ( ArrayType, DoubleType, FloatType, IntegerType, LongType, ShortType +) from pyspark.ml.linalg import VectorUDT # Put pyspark specific params here, they won't be passed to XGBoost. @@ -147,7 +147,7 @@ class _XgboostParams( ) feature_types = Param( Params._dummy(), - "feature_names", + "feature_types", "A list of str to specify feature types." ) @@ -400,6 +400,17 @@ def _convert_to_model(self, booster): def _query_plan_contains_valid_repartition(self, dataset): """ Returns true if the latest element in the logical plan is a valid repartition + The logic plan string format is like: + + == Optimized Logical Plan == + Repartition 4, true + +- LogicalRDD [features#12, label#13L], false + + i.e., the top line in the logical plan is the last operation to execute. + so, in this method, we check the first line, if it is a "Repartition" operation, + and the result dataframe has the same partition number with num_workers param, + then it means the dataframe is well repartitioned and we don't need to + repartition the dataframe again. 
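The check described above can be sketched as a small standalone helper. This is illustrative only: the helper name and arguments are hypothetical, and plan_text stands for the explain string produced by the explainString call in the method body.

def _plan_has_valid_repartition(plan_text, num_partitions, num_workers):
    # The first line under the optimized plan header looks like "Repartition 4, true"
    # when a repartition was the last operation applied to the DataFrame.
    optimized = plan_text.split("== Optimized Logical Plan ==")[1]
    top_line = optimized.strip().splitlines()[0]
    # Keep the current partitioning only if the top operation is a Repartition and
    # the partition count already matches the requested number of XGBoost workers.
    return top_line.startswith("Repartition") and num_partitions == num_workers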
""" num_partitions = dataset.rdd.getNumPartitions() query_plan = dataset._sc._jvm.PythonSQLUtils.explainString( From 18956d138b20a878b50ee504ddf669711ca9e0e6 Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Wed, 6 Jul 2022 00:04:47 +0800 Subject: [PATCH 33/73] update doc --- python-package/xgboost/spark/core.py | 4 +- python-package/xgboost/spark/estimator.py | 58 +++++++++++------------ python-package/xgboost/spark/model.py | 2 +- 3 files changed, 32 insertions(+), 32 deletions(-) diff --git a/python-package/xgboost/spark/core.py b/python-package/xgboost/spark/core.py index 5e649e9bbb47..ba2a3557bbf7 100644 --- a/python-package/xgboost/spark/core.py +++ b/python-package/xgboost/spark/core.py @@ -656,7 +656,7 @@ def _transform(self, dataset): class SparkXGBRegressorModel(_SparkXGBModel): """ - The model returned by :func:`xgboost.spark.XgboostRegressor.fit` + The model returned by :func:`xgboost.spark.SparkXGBRegressor.fit` .. Note:: This API is experimental. """ @@ -706,7 +706,7 @@ def predict_udf(input_data: pd.DataFrame) -> pd.Series: class SparkXGBClassifierModel(_SparkXGBModel, HasProbabilityCol, HasRawPredictionCol): """ - The model returned by :func:`xgboost.spark.XgboostClassifier.fit` + The model returned by :func:`xgboost.spark.SparkXGBClassifier.fit` .. Note:: This API is experimental. """ diff --git a/python-package/xgboost/spark/estimator.py b/python-package/xgboost/spark/estimator.py index 798e809eaedc..f44f9b3abd62 100644 --- a/python-package/xgboost/spark/estimator.py +++ b/python-package/xgboost/spark/estimator.py @@ -10,44 +10,44 @@ class SparkXGBRegressor(_SparkXGBEstimator): """ - XgboostRegressor is a PySpark ML estimator. It implements the XGBoost regression + SparkXGBRegressor is a PySpark ML estimator. It implements the XGBoost regression algorithm based on XGBoost python library, and it can be used in PySpark Pipeline and PySpark ML meta algorithms like CrossValidator/TrainValidationSplit/OneVsRest. - XgboostRegressor automatically supports most of the parameters in + SparkXGBRegressor automatically supports most of the parameters in `xgboost.XGBRegressor` constructor and most of the parameters used in `xgboost.XGBRegressor` fit and predict method (see `API docs `_ for details). - XgboostRegressor doesn't support setting `gpu_id` but support another param `use_gpu`, + SparkXGBRegressor doesn't support setting `gpu_id` but support another param `use_gpu`, see doc below for more details. - XgboostRegressor doesn't support setting `base_margin` explicitly as well, but support + SparkXGBRegressor doesn't support setting `base_margin` explicitly as well, but support another param called `base_margin_col`. see doc below for more details. - XgboostRegressor doesn't support `validate_features` and `output_margin` param. + SparkXGBRegressor doesn't support `validate_features` and `output_margin` param. callbacks: The export and import of the callback functions are at best effort. - For details, see :py:attr:`xgboost.spark.XgboostRegressor.callbacks` param doc. + For details, see :py:attr:`xgboost.spark.SparkXGBRegressor.callbacks` param doc. missing: - The parameter `missing` in XgboostRegressor has different semantics with + The parameter `missing` in SparkXGBRegressor has different semantics with that in `xgboost.XGBRegressor`. For details, see - :py:attr:`xgboost.spark.XgboostRegressor.missing` param doc. + :py:attr:`xgboost.spark.SparkXGBRegressor.missing` param doc. 
validationIndicatorCol For params related to `xgboost.XGBRegressor` training with evaluation dataset's supervision, set - :py:attr:`xgboost.spark.XgboostRegressor.validationIndicatorCol` + :py:attr:`xgboost.spark.SparkXGBRegressor.validationIndicatorCol` parameter instead of setting the `eval_set` parameter in `xgboost.XGBRegressor` fit method. weightCol: To specify the weight of the training and validation dataset, set - :py:attr:`xgboost.spark.XgboostRegressor.weightCol` parameter instead of setting + :py:attr:`xgboost.spark.SparkXGBRegressor.weightCol` parameter instead of setting `sample_weight` and `sample_weight_eval_set` parameter in `xgboost.XGBRegressor` fit method. xgb_model: Set the value to be the instance returned by - :func:`xgboost.spark.XgboostRegressorModel.get_booster`. + :func:`xgboost.spark.SparkXGBRegressorModel.get_booster`. num_workers: Integer that specifies the number of XGBoost workers to use. Each XGBoost worker corresponds to one spark task. @@ -56,7 +56,7 @@ class SparkXGBRegressor(_SparkXGBEstimator): instances. base_margin_col: To specify the base margins of the training and validation - dataset, set :py:attr:`xgboost.spark.XgboostRegressor.base_margin_col` parameter + dataset, set :py:attr:`xgboost.spark.SparkXGBRegressor.base_margin_col` parameter instead of setting `base_margin` and `base_margin_eval_set` in the `xgboost.XGBRegressor` fit method. Note: this isn't available for distributed training. @@ -68,7 +68,7 @@ class SparkXGBRegressor(_SparkXGBEstimator): **Examples** - >>> from xgboost.spark import XgboostRegressor + >>> from xgboost.spark import SparkXGBRegressor >>> from pyspark.ml.linalg import Vectors >>> df_train = spark.createDataFrame([ ... (Vectors.dense(1.0, 2.0, 3.0), 0, False, 1.0), @@ -80,7 +80,7 @@ class SparkXGBRegressor(_SparkXGBEstimator): ... (Vectors.dense(1.0, 2.0, 3.0), ), ... (Vectors.sparse(3, {1: 1.0, 2: 5.5}), ) ... ], ["features"]) - >>> xgb_regressor = XgboostRegressor(max_depth=5, missing=0.0, + >>> xgb_regressor = SparkXGBRegressor(max_depth=5, missing=0.0, ... validationIndicatorCol='isVal', weightCol='weight', ... early_stopping_rounds=1, eval_metric='rmse') >>> xgb_reg_model = xgb_regressor.fit(df_train) @@ -106,35 +106,35 @@ def _pyspark_model_cls(cls): class SparkXGBClassifier(_SparkXGBEstimator, HasProbabilityCol, HasRawPredictionCol): """ - XgboostClassifier is a PySpark ML estimator. It implements the XGBoost classification + SparkXGBClassifier is a PySpark ML estimator. It implements the XGBoost classification algorithm based on XGBoost python library, and it can be used in PySpark Pipeline and PySpark ML meta algorithms like CrossValidator/TrainValidationSplit/OneVsRest. - XgboostClassifier automatically supports most of the parameters in + SparkXGBClassifier automatically supports most of the parameters in `xgboost.XGBClassifier` constructor and most of the parameters used in `xgboost.XGBClassifier` fit and predict method (see `API docs `_ for details). - XgboostClassifier doesn't support setting `gpu_id` but support another param `use_gpu`, + SparkXGBClassifier doesn't support setting `gpu_id` but support another param `use_gpu`, see doc below for more details. - XgboostClassifier doesn't support setting `base_margin` explicitly as well, but support + SparkXGBClassifier doesn't support setting `base_margin` explicitly as well, but support another param called `base_margin_col`. see doc below for more details. 
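For illustration, a minimal sketch of the renamed parameter in use, assuming a training DataFrame that carries a numeric base_margin column as exercised by test_classifier_with_base_margin above; df_train and df_test are placeholder DataFrames.

from xgboost.spark import SparkXGBClassifier

# base_margin_col names a per-row margin column and replaces the base_margin /
# base_margin_eval_set keyword arguments of xgboost.XGBClassifier.fit.
clf = SparkXGBClassifier(weight_col="weight", base_margin_col="base_margin")
model = clf.fit(df_train)  # columns: features, label, weight, base_margin
pred_df = model.transform(df_test)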
- XgboostClassifier doesn't support setting `output_margin`, but we can get output margin + SparkXGBClassifier doesn't support setting `output_margin`, but we can get output margin from the raw prediction column. See `rawPredictionCol` param doc below for more details. - XgboostClassifier doesn't support `validate_features` and `output_margin` param. + SparkXGBClassifier doesn't support `validate_features` and `output_margin` param. Parameters ---------- callbacks: The export and import of the callback functions are at best effort. For - details, see :py:attr:`xgboost.spark.XgboostClassifier.callbacks` param doc. + details, see :py:attr:`xgboost.spark.SparkXGBClassifier.callbacks` param doc. missing: - The parameter `missing` in XgboostClassifier has different semantics with + The parameter `missing` in SparkXGBClassifier has different semantics with that in `xgboost.XGBClassifier`. For details, see - :py:attr:`xgboost.spark.XgboostClassifier.missing` param doc. + :py:attr:`xgboost.spark.SparkXGBClassifier.missing` param doc. rawPredictionCol: The `output_margin=True` is implicitly supported by the `rawPredictionCol` output column, which is always returned with the predicted margin @@ -142,17 +142,17 @@ class SparkXGBClassifier(_SparkXGBEstimator, HasProbabilityCol, HasRawPrediction validationIndicatorCol: For params related to `xgboost.XGBClassifier` training with evaluation dataset's supervision, - set :py:attr:`xgboost.spark.XgboostClassifier.validationIndicatorCol` + set :py:attr:`xgboost.spark.SparkXGBClassifier.validationIndicatorCol` parameter instead of setting the `eval_set` parameter in `xgboost.XGBClassifier` fit method. weightCol: To specify the weight of the training and validation dataset, set - :py:attr:`xgboost.spark.XgboostClassifier.weightCol` parameter instead of setting + :py:attr:`xgboost.spark.SparkXGBClassifier.weightCol` parameter instead of setting `sample_weight` and `sample_weight_eval_set` parameter in `xgboost.XGBClassifier` fit method. xgb_model: Set the value to be the instance returned by - :func:`xgboost.spark.XgboostClassifierModel.get_booster`. + :func:`xgboost.spark.SparkXGBClassifierModel.get_booster`. num_workers: Integer that specifies the number of XGBoost workers to use. Each XGBoost worker corresponds to one spark task. @@ -161,7 +161,7 @@ class SparkXGBClassifier(_SparkXGBEstimator, HasProbabilityCol, HasRawPrediction instances. base_margin_col: To specify the base margins of the training and validation - dataset, set :py:attr:`xgboost.spark.XgboostClassifier.base_margin_col` parameter + dataset, set :py:attr:`xgboost.spark.SparkXGBClassifier.base_margin_col` parameter instead of setting `base_margin` and `base_margin_eval_set` in the `xgboost.XGBClassifier` fit method. Note: this isn't available for distributed training. @@ -173,7 +173,7 @@ class SparkXGBClassifier(_SparkXGBEstimator, HasProbabilityCol, HasRawPrediction **Examples** - >>> from xgboost.spark import XgboostClassifier + >>> from xgboost.spark import SparkXGBClassifier >>> from pyspark.ml.linalg import Vectors >>> df_train = spark.createDataFrame([ ... (Vectors.dense(1.0, 2.0, 3.0), 0, False, 1.0), @@ -184,7 +184,7 @@ class SparkXGBClassifier(_SparkXGBEstimator, HasProbabilityCol, HasRawPrediction >>> df_test = spark.createDataFrame([ ... (Vectors.dense(1.0, 2.0, 3.0), ), ... ], ["features"]) - >>> xgb_classifier = XgboostClassifier(max_depth=5, missing=0.0, + >>> xgb_classifier = SparkXGBClassifier(max_depth=5, missing=0.0, ... validationIndicatorCol='isVal', weightCol='weight', ... 
early_stopping_rounds=1, eval_metric='logloss') >>> xgb_clf_model = xgb_classifier.fit(df_train) diff --git a/python-package/xgboost/spark/model.py b/python-package/xgboost/spark/model.py index b8ef24c0e7da..105d69bb877b 100644 --- a/python-package/xgboost/spark/model.py +++ b/python-package/xgboost/spark/model.py @@ -235,7 +235,7 @@ def load(self, path): """ Load metadata and model for a :py:class:`_XgboostModel` - :return: XgboostRegressorModel or XgboostClassifierModel instance + :return: SparkXGBRegressorModel or SparkXGBClassifierModel instance """ _, py_model = XgboostSharedReadWrite.loadMetadataAndInstance( self.cls, path, self.sc, self.logger From d4f048d1d7df869eab90a3e89e34bbb22274b733 Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Wed, 6 Jul 2022 00:06:16 +0800 Subject: [PATCH 34/73] remove feature_types pyspark param --- python-package/xgboost/spark/core.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/python-package/xgboost/spark/core.py b/python-package/xgboost/spark/core.py index ba2a3557bbf7..a997599fad01 100644 --- a/python-package/xgboost/spark/core.py +++ b/python-package/xgboost/spark/core.py @@ -64,7 +64,6 @@ "force_repartition", "num_workers", "use_gpu", - "feature_types", "feature_names", ] @@ -145,11 +144,6 @@ class _XgboostParams( "feature_names", "A list of str to specify feature names." ) - feature_types = Param( - Params._dummy(), - "feature_types", - "A list of str to specify feature types." - ) @classmethod def _xgb_cls(cls): From 9e66d2fe1cdc4d4537974ac95b0053da18018d96 Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Wed, 6 Jul 2022 00:26:53 +0800 Subject: [PATCH 35/73] fix-test --- python-package/xgboost/spark/estimator.py | 4 +++- tests/python/test_spark/xgboost_local_cluster_test.py | 4 ++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/python-package/xgboost/spark/estimator.py b/python-package/xgboost/spark/estimator.py index f44f9b3abd62..cb5e68a857d2 100644 --- a/python-package/xgboost/spark/estimator.py +++ b/python-package/xgboost/spark/estimator.py @@ -25,7 +25,9 @@ class SparkXGBRegressor(_SparkXGBEstimator): SparkXGBRegressor doesn't support setting `base_margin` explicitly as well, but support another param called `base_margin_col`. see doc below for more details. - SparkXGBRegressor doesn't support `validate_features` and `output_margin` param. + SparkXGBRegressor doesn't support following params: + `gpu_id`, `enable_categorical`, `use_label_encoder`, `n_jobs`, `nthread`, + `validate_features`, `output_margin`, `base_margin` param. callbacks: The export and import of the callback functions are at best effort. 
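To make the gpu_id note above concrete, here is a short sketch of GPU training; it assumes the cluster assigns one GPU per Spark task, as the use_gpu param doc requires, and df_train is a placeholder DataFrame of features and labels.

from xgboost.spark import SparkXGBClassifier

# gpu_id is not exposed as an estimator param; use_gpu=True switches training to
# tree_method="gpu_hist" internally (see test_gpu_param_setting in the local tests).
classifier = SparkXGBClassifier(use_gpu=True, num_workers=4)
model = classifier.fit(df_train)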
diff --git a/tests/python/test_spark/xgboost_local_cluster_test.py b/tests/python/test_spark/xgboost_local_cluster_test.py index b42f0231ba96..32b095eaba7f 100644 --- a/tests/python/test_spark/xgboost_local_cluster_test.py +++ b/tests/python/test_spark/xgboost_local_cluster_test.py @@ -146,7 +146,7 @@ def setUp(self): "early_stopping_rounds": 1, "eval_metric": "logloss", } - self.clf_params_with_weight_dist = {"weightCol": "weight"} + self.clf_params_with_weight_dist = {"weight_col": "weight"} self.cls_df_train_distributed_with_eval_weight = self.session.createDataFrame( [ (Vectors.dense(1.0, 2.0, 3.0), 0, False, 1.0), @@ -181,7 +181,7 @@ def setUp(self): "early_stopping_rounds": 1, "eval_metric": "rmse", } - self.reg_params_with_weight_dist = {"weightCol": "weight"} + self.reg_params_with_weight_dist = {"weight_col": "weight"} self.reg_df_train_distributed_with_eval_weight = self.session.createDataFrame( [ (Vectors.dense(1.0, 2.0, 3.0), 0, False, 1.0), From c60ccadc7610579cee751141900bcf523e1f8163 Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Wed, 6 Jul 2022 00:28:46 +0800 Subject: [PATCH 36/73] update-doc --- python-package/xgboost/spark/estimator.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/python-package/xgboost/spark/estimator.py b/python-package/xgboost/spark/estimator.py index cb5e68a857d2..f44f9b3abd62 100644 --- a/python-package/xgboost/spark/estimator.py +++ b/python-package/xgboost/spark/estimator.py @@ -25,9 +25,7 @@ class SparkXGBRegressor(_SparkXGBEstimator): SparkXGBRegressor doesn't support setting `base_margin` explicitly as well, but support another param called `base_margin_col`. see doc below for more details. - SparkXGBRegressor doesn't support following params: - `gpu_id`, `enable_categorical`, `use_label_encoder`, `n_jobs`, `nthread`, - `validate_features`, `output_margin`, `base_margin` param. + SparkXGBRegressor doesn't support `validate_features` and `output_margin` param. callbacks: The export and import of the callback functions are at best effort. 
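The snake_case column params exercised by these cluster tests replace the earlier camelCase spellings (weightCol, validationIndicatorCol). A minimal sketch of a weighted run with an evaluation split, assuming a DataFrame shaped like the test fixtures (features, label, isVal, weight); df_train and df_test are placeholders.

from xgboost.spark import SparkXGBRegressor

regressor = SparkXGBRegressor(
    weight_col="weight",                # replaces the sample_weight fit kwarg
    validation_indicator_col="isVal",   # rows with isVal=True form the eval set
    early_stopping_rounds=1,
    eval_metric="rmse",
)
model = regressor.fit(df_train)
preds = model.transform(df_test)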
From 4f247aba4863964f44da0ad3df51d93c41216b87 Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Wed, 6 Jul 2022 08:39:34 +0800 Subject: [PATCH 37/73] fix tests --- python-package/setup.py | 3 ++- tests/ci_build/conda_env/macos_cpu_test.yml | 1 + tests/python/test_spark/xgboost_local_cluster_test.py | 4 ++-- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/python-package/setup.py b/python-package/setup.py index 93f36de7b2d1..edea60087cdb 100644 --- a/python-package/setup.py +++ b/python-package/setup.py @@ -349,7 +349,8 @@ def run(self) -> None: 'scikit-learn': ['scikit-learn'], 'dask': ['dask', 'pandas', 'distributed'], 'datatable': ['datatable'], - 'plotting': ['graphviz', 'matplotlib'] + 'plotting': ['graphviz', 'matplotlib'], + 'pyspark': ['pyspark'], }, maintainer='Hyunsu Cho', maintainer_email='chohyu01@cs.washington.edu', diff --git a/tests/ci_build/conda_env/macos_cpu_test.yml b/tests/ci_build/conda_env/macos_cpu_test.yml index 38ac8aa1f421..7594c48894f0 100644 --- a/tests/ci_build/conda_env/macos_cpu_test.yml +++ b/tests/ci_build/conda_env/macos_cpu_test.yml @@ -35,6 +35,7 @@ dependencies: - py-ubjson - cffi - pyarrow +- pyspark - pip: - sphinx_rtd_theme - datatable diff --git a/tests/python/test_spark/xgboost_local_cluster_test.py b/tests/python/test_spark/xgboost_local_cluster_test.py index 32b095eaba7f..66c3e8bc3e14 100644 --- a/tests/python/test_spark/xgboost_local_cluster_test.py +++ b/tests/python/test_spark/xgboost_local_cluster_test.py @@ -142,7 +142,7 @@ def setUp(self): # Adding weight and validation self.clf_params_with_eval_dist = { - "validationIndicatorCol": "isVal", + "validation_indicator_col": "isVal", "early_stopping_rounds": 1, "eval_metric": "logloss", } @@ -177,7 +177,7 @@ def setUp(self): self.clf_best_score_weight_and_eval = 0.006626 self.reg_params_with_eval_dist = { - "validationIndicatorCol": "isVal", + "validation_indicator_col": "isVal", "early_stopping_rounds": 1, "eval_metric": "rmse", } From 40afa4f1547e4cb28d4f842b843eb916810e58da Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Wed, 6 Jul 2022 09:11:25 +0800 Subject: [PATCH 38/73] fix lint, refactor --- python-package/xgboost/spark/core.py | 44 ++++++++++++++++------- python-package/xgboost/spark/data.py | 28 ++++++++------- python-package/xgboost/spark/estimator.py | 2 ++ python-package/xgboost/spark/model.py | 36 +++++++++++++------ python-package/xgboost/spark/params.py | 13 +------ python-package/xgboost/spark/utils.py | 8 ++++- 6 files changed, 83 insertions(+), 48 deletions(-) diff --git a/python-package/xgboost/spark/core.py b/python-package/xgboost/spark/core.py index a997599fad01..1516dfd93a35 100644 --- a/python-package/xgboost/spark/core.py +++ b/python-package/xgboost/spark/core.py @@ -1,3 +1,8 @@ +"""Xgboost pyspark integration submodule for core code.""" +# pylint: disable=import-error, consider-using-f-string, too-many-arguments, too-many-locals +# pylint: disable=invalid-name, fixme, too-many-lines, unbalanced-tuple-unpacking, no-else-return +# pylint: disable=protected-access, logging-fstring-interpolation, no-name-in-module +# pylint: disable=wrong-import-order, ungrouped-imports import numpy as np import pandas as pd from scipy.special import expit, softmax @@ -21,13 +26,13 @@ from xgboost.training import train as worker_train from .utils import get_logger, _get_max_num_concurrent_tasks from .data import ( - convert_partition_data_to_dmatrix, + _convert_partition_data_to_dmatrix, ) from .model import ( - XgboostReader, - XgboostWriter, - XgboostModelReader, - 
XgboostModelWriter, + SparkXGBReader, + SparkXGBWriter, + SparkXGBModelReader, + SparkXGBModelWriter, get_xgb_model_creator, ) from .utils import ( @@ -334,6 +339,9 @@ def __init__(self): ) def setParams(self, **kwargs): + """ + Set params for the estimator. + """ _extra_params = {} if 'arbitrary_params_dict' in kwargs: raise ValueError("Invalid param name: 'arbitrary_params_dict'.") @@ -431,7 +439,7 @@ def _repartition_needed(self, dataset): try: if self._query_plan_contains_valid_repartition(dataset): return False - except: # noqa: E722 + except Exception: # noqa: E722 pass return True @@ -562,14 +570,14 @@ def _train_booster(pandas_df_iter): with RabitContext(_rabit_args, context): dtrain, dval = None, [] if has_validation: - dtrain, dval = convert_partition_data_to_dmatrix( + dtrain, dval = _convert_partition_data_to_dmatrix( pandas_df_iter, has_weight, has_validation, has_base_margin, dmatrix_kwargs=dmatrix_kwargs, ) # TODO: Question: do we need to add dtrain to dval list ? dval = [(dtrain, "training"), (dval, "validation")] else: - dtrain = convert_partition_data_to_dmatrix( + dtrain = _convert_partition_data_to_dmatrix( pandas_df_iter, has_weight, has_validation, has_base_margin, dmatrix_kwargs=dmatrix_kwargs, ) @@ -596,11 +604,17 @@ def _train_booster(pandas_df_iter): return self._copyValues(self._create_pyspark_model(result_xgb_model)) def write(self): - return XgboostWriter(self) + """ + Return the writer for saving the estimator. + """ + return SparkXGBWriter(self) @classmethod def read(cls): - return XgboostReader(cls) + """ + Return the reader for loading the estimator. + """ + return SparkXGBReader(cls) class _SparkXGBModel(Model, _XgboostParams, MLReadable, MLWritable): @@ -638,11 +652,17 @@ def get_feature_importances(self, importance_type="weight"): return self.get_booster().get_score(importance_type=importance_type) def write(self): - return XgboostModelWriter(self) + """ + Return the writer for saving the model. + """ + return SparkXGBModelWriter(self) @classmethod def read(cls): - return XgboostModelReader(cls) + """ + Return the reader for loading the model. 
+ """ + return SparkXGBModelReader(cls) def _transform(self, dataset): raise NotImplementedError() diff --git a/python-package/xgboost/spark/data.py b/python-package/xgboost/spark/data.py index 999f80025fea..7d7020778034 100644 --- a/python-package/xgboost/spark/data.py +++ b/python-package/xgboost/spark/data.py @@ -1,11 +1,13 @@ -import os +"""Xgboost pyspark integration submodule for data related functions.""" +# pylint: disable=import-error, consider-using-f-string, too-many-arguments, too-many-locals, +# pylint: disable=invalid-name, fixme, too-many-lines, unbalanced-tuple-unpacking, no-else-return from typing import Iterator import numpy as np import pandas as pd from xgboost import DMatrix -def prepare_train_val_data( +def _prepare_train_val_data( data_iterator, has_weight, has_validation, has_fit_base_margin ): def gen_data_pdf(): @@ -148,23 +150,23 @@ def _process_data_iter( has_validation, ) return train_X, train_y, train_w, train_b_m, val_X, val_y, val_w, val_b_m - else: - return _row_tuple_list_to_feature_matrix_y_w( - data_iterator, - train, - has_weight, - has_fit_base_margin, - has_predict_base_margin, - has_validation, - ) + + return _row_tuple_list_to_feature_matrix_y_w( + data_iterator, + train, + has_weight, + has_fit_base_margin, + has_predict_base_margin, + has_validation, + ) -def convert_partition_data_to_dmatrix( +def _convert_partition_data_to_dmatrix( partition_data_iter, has_weight, has_validation, has_base_margin, dmatrix_kwargs=None ): dmatrix_kwargs = dmatrix_kwargs or {} # if we are not using external storage, we use the standard method of parsing data. - train_val_data = prepare_train_val_data( + train_val_data = _prepare_train_val_data( partition_data_iter, has_weight, has_validation, has_base_margin ) if has_validation: diff --git a/python-package/xgboost/spark/estimator.py b/python-package/xgboost/spark/estimator.py index f44f9b3abd62..6f59a1990723 100644 --- a/python-package/xgboost/spark/estimator.py +++ b/python-package/xgboost/spark/estimator.py @@ -1,3 +1,5 @@ +"""Xgboost pyspark integration submodule for estimator API.""" +# pylint: disable=import-error from pyspark.ml.param.shared import HasProbabilityCol, HasRawPredictionCol from xgboost import XGBClassifier, XGBRegressor from .core import ( diff --git a/python-package/xgboost/spark/model.py b/python-package/xgboost/spark/model.py index 105d69bb877b..6aa4d71498dd 100644 --- a/python-package/xgboost/spark/model.py +++ b/python-package/xgboost/spark/model.py @@ -1,3 +1,7 @@ +"""Xgboost pyspark integration submodule for model API.""" +# pylint: disable=import-error, consider-using-f-string, unspecified-encoding, +# pylint: disable=invalid-name, fixme, +# pylint: disable=protected-access, too-few-public-methods import base64 import os import uuid @@ -48,7 +52,7 @@ def serialize_xgb_model(model): # TODO: change to use string io tmp_file_name = os.path.join(_get_or_create_tmp_dir(), f"{uuid.uuid4()}.json") model.save_model(tmp_file_name) - with open(tmp_file_name) as f: + with open(tmp_file_name, 'r') as f: ser_model_string = f.read() return ser_model_string @@ -103,7 +107,7 @@ def _get_spark_session(): return SparkSession.builder.getOrCreate() -class XgboostSharedReadWrite: +class SparkXGBSharedReadWrite: @staticmethod def saveMetadata(instance, path, sc, logger, extraMetadata=None): """ @@ -181,30 +185,39 @@ def loadMetadataAndInstance(pyspark_xgb_cls, path, sc, logger): return metadata, pyspark_xgb -class XgboostWriter(MLWriter): +class SparkXGBWriter(MLWriter): + """ + Spark Xgboost estimator writer. 
+ """ def __init__(self, instance): super().__init__() self.instance = instance self.logger = get_logger(self.__class__.__name__, level="WARN") def saveImpl(self, path): - XgboostSharedReadWrite.saveMetadata(self.instance, path, self.sc, self.logger) + SparkXGBSharedReadWrite.saveMetadata(self.instance, path, self.sc, self.logger) -class XgboostReader(MLReader): +class SparkXGBReader(MLReader): + """ + Spark Xgboost estimator reader. + """ def __init__(self, cls): super().__init__() self.cls = cls self.logger = get_logger(self.__class__.__name__, level="WARN") def load(self, path): - _, pyspark_xgb = XgboostSharedReadWrite.loadMetadataAndInstance( + _, pyspark_xgb = SparkXGBSharedReadWrite.loadMetadataAndInstance( self.cls, path, self.sc, self.logger ) return pyspark_xgb -class XgboostModelWriter(MLWriter): +class SparkXGBModelWriter(MLWriter): + """ + Spark Xgboost model writer. + """ def __init__(self, instance): super().__init__() self.instance = instance @@ -217,7 +230,7 @@ def saveImpl(self, path): - save model to path/model.json """ xgb_model = self.instance._xgb_sklearn_model - XgboostSharedReadWrite.saveMetadata(self.instance, path, self.sc, self.logger) + SparkXGBSharedReadWrite.saveMetadata(self.instance, path, self.sc, self.logger) model_save_path = os.path.join(path, "model.json") ser_xgb_model = serialize_xgb_model(xgb_model) _get_spark_session().createDataFrame( @@ -225,7 +238,10 @@ def saveImpl(self, path): ).write.parquet(model_save_path) -class XgboostModelReader(MLReader): +class SparkXGBModelReader(MLReader): + """ + Spark Xgboost model reader. + """ def __init__(self, cls): super().__init__() self.cls = cls @@ -237,7 +253,7 @@ def load(self, path): :return: SparkXGBRegressorModel or SparkXGBClassifierModel instance """ - _, py_model = XgboostSharedReadWrite.loadMetadataAndInstance( + _, py_model = SparkXGBSharedReadWrite.loadMetadataAndInstance( self.cls, path, self.sc, self.logger ) diff --git a/python-package/xgboost/spark/params.py b/python-package/xgboost/spark/params.py index 3172deac47cc..1d4cbac80f5e 100644 --- a/python-package/xgboost/spark/params.py +++ b/python-package/xgboost/spark/params.py @@ -1,3 +1,4 @@ +"""Xgboost pyspark integration submodule for params.""" from pyspark.ml.param.shared import Param, Params @@ -16,12 +17,6 @@ class HasArbitraryParamsDict(Params): "underlying XGBoost library. It is stored as a dictionary.", ) - def setArbitraryParamsDict(self, value): - return self._set(arbitrary_params_dict=value) - - def getArbitraryParamsDict(self): - return self.getOrDefault(self.arbitrary_params_dict) - class HasBaseMarginCol(Params): """ @@ -34,9 +29,3 @@ class HasBaseMarginCol(Params): "base_margin_col", "This stores the name for the column of the base margin", ) - - def setBaseMarginCol(self, value): - return self._set(base_margin_col=value) - - def getBaseMarginCol(self): - return self.getOrDefault(self.base_margin_col) diff --git a/python-package/xgboost/spark/utils.py b/python-package/xgboost/spark/utils.py index 132897d9ad32..f9c41350d30b 100644 --- a/python-package/xgboost/spark/utils.py +++ b/python-package/xgboost/spark/utils.py @@ -1,3 +1,6 @@ +"""Xgboost pyspark integration submodule for helper functions.""" +# pylint: disable=import-error, consider-using-f-string, protected-access, wrong-import-order +# pylint: disable=invalid-name import inspect from threading import Thread import sys @@ -10,6 +13,9 @@ def get_class_name(cls): + """ + Return the class name. 
+ """ return f"{cls.__module__}.{cls.__name__}" @@ -19,7 +25,7 @@ def _get_default_params_from_func(func, unsupported_set): Only the parameters with a default value will be included. """ sig = inspect.signature(func) - filtered_params_dict = dict() + filtered_params_dict = {} for parameter in sig.parameters.values(): # Remove parameters without a default value and those in the unsupported_set if ( From 508a36b126866d995e5ed0b34372e502191a16b1 Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Wed, 6 Jul 2022 09:21:08 +0800 Subject: [PATCH 39/73] fix test --- tests/python/test_spark/data_test.py | 4 ++-- tests/python/test_spark/utils_test.py | 3 --- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/tests/python/test_spark/data_test.py b/tests/python/test_spark/data_test.py index 98b99fa11784..3679525399b3 100644 --- a/tests/python/test_spark/data_test.py +++ b/tests/python/test_spark/data_test.py @@ -5,7 +5,7 @@ from xgboost.spark.data import ( _row_tuple_list_to_feature_matrix_y_w, - convert_partition_data_to_dmatrix, + _convert_partition_data_to_dmatrix, ) from xgboost import DMatrix, XGBClassifier @@ -80,7 +80,7 @@ def row_tup_iter(data): "values": [[1.0, 2.0, 3.0], [0.0, 1.0, 5.5]] * 100, "label": [1, 0] * 100, } - output_dmatrix = convert_partition_data_to_dmatrix( + output_dmatrix = _convert_partition_data_to_dmatrix( [pd.DataFrame(data)], has_weight=False, has_validation=False, has_base_margin=False ) # You can't compare DMatrix outputs, so the only way is to predict on the two seperate DMatrices using diff --git a/tests/python/test_spark/utils_test.py b/tests/python/test_spark/utils_test.py index 8ea006ad8335..1edf8c5b6abf 100644 --- a/tests/python/test_spark/utils_test.py +++ b/tests/python/test_spark/utils_test.py @@ -1,8 +1,6 @@ -import unittest import contextlib import logging import shutil -import subprocess import sys import tempfile @@ -12,7 +10,6 @@ from pyspark.sql import SQLContext from pyspark.sql import SparkSession -from pyspark.taskcontext import TaskContext from xgboost.spark.utils import _get_default_params_from_func From 39e2b45cf8daa262072d8845f763c402809aa2a9 Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Wed, 6 Jul 2022 09:22:24 +0800 Subject: [PATCH 40/73] fix test --- tests/python/test_spark/data_test.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/python/test_spark/data_test.py b/tests/python/test_spark/data_test.py index 3679525399b3..3feb8ba096e0 100644 --- a/tests/python/test_spark/data_test.py +++ b/tests/python/test_spark/data_test.py @@ -98,7 +98,7 @@ def row_tup_iter(data): ) data["weight"] = [0.2, 0.8] * 100 - output_dmatrix = convert_partition_data_to_dmatrix( + output_dmatrix = _convert_partition_data_to_dmatrix( [pd.DataFrame(data)], has_weight=True, has_validation=False, has_base_margin=False ) @@ -121,7 +121,7 @@ def test_external_storage(self): # Creating the dmatrix based on storage temporary_path = tempfile.mkdtemp() - storage_dmatrix = convert_partition_data_to_dmatrix( + storage_dmatrix = _convert_partition_data_to_dmatrix( [pd.DataFrame(data)], has_weight=False, has_validation=False, has_base_margin=False ) @@ -139,7 +139,7 @@ def test_external_storage(self): data["weight"] = [0.2, 0.8] * 100 temporary_path = tempfile.mkdtemp() - storage_dmatrix = convert_partition_data_to_dmatrix( + storage_dmatrix = _convert_partition_data_to_dmatrix( [pd.DataFrame(data)], has_weight=True, has_validation=False, has_base_margin=False ) From e657a21325494ffe1d0bdf827a7d90cf7093810e Mon Sep 17 00:00:00 2001 From: 
Weichen Xu Date: Wed, 6 Jul 2022 09:54:11 +0800 Subject: [PATCH 41/73] support feature weights --- python-package/xgboost/spark/core.py | 3 +++ tests/python/test_spark/xgboost_local_test.py | 9 +++++++++ 2 files changed, 12 insertions(+) diff --git a/python-package/xgboost/spark/core.py b/python-package/xgboost/spark/core.py index 1516dfd93a35..f05a989d3637 100644 --- a/python-package/xgboost/spark/core.py +++ b/python-package/xgboost/spark/core.py @@ -75,6 +75,8 @@ _non_booster_params = [ "missing", "n_estimators", + "feature_types", + "feature_weights", ] _pyspark_param_alias_map = { @@ -542,6 +544,7 @@ def _fit(self, dataset): "nthread": cpu_per_task, "feature_types": self.getOrDefault(self.feature_types), "feature_names": self.getOrDefault(self.feature_names), + "feature_weights": self.getOrDefault(self.feature_weights), } booster_params['nthread'] = cpu_per_task use_gpu = self.getOrDefault(self.use_gpu) diff --git a/tests/python/test_spark/xgboost_local_test.py b/tests/python/test_spark/xgboost_local_test.py index 4fe5204d9c5b..df65743fe3bb 100644 --- a/tests/python/test_spark/xgboost_local_test.py +++ b/tests/python/test_spark/xgboost_local_test.py @@ -949,3 +949,12 @@ def test_classifier_array_col_as_feature(self): self.assertTrue( np.allclose(row.probability, row.expected_probability, rtol=1e-3) ) + + def test_classifier_with_feature_names_types_weights(self): + classifier = SparkXGBClassifier( + feature_names=["a1", "a2", "a3"], + feature_types=["i", "int", "float"], + feature_weights=[2.0, 5.0, 3.0] + ) + model = classifier.fit(self.cls_df_train) + model.transform(self.cls_df_test).collect() From 60e2561941ad932544aac6b63f0151d16a6a7f5a Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Wed, 6 Jul 2022 10:01:48 +0800 Subject: [PATCH 42/73] fix-ci --- python-package/xgboost/spark/core.py | 3 ++- python-package/xgboost/spark/model.py | 18 ++++++++++++------ python-package/xgboost/spark/params.py | 1 + tests/ci_build/conda_env/aarch64_test.yml | 1 + tests/ci_build/conda_env/cpu_test.yml | 1 + tests/ci_build/conda_env/win64_cpu_test.yml | 1 + tests/ci_build/conda_env/win64_test.yml | 1 + 7 files changed, 19 insertions(+), 7 deletions(-) diff --git a/python-package/xgboost/spark/core.py b/python-package/xgboost/spark/core.py index f05a989d3637..5dc9c3908caa 100644 --- a/python-package/xgboost/spark/core.py +++ b/python-package/xgboost/spark/core.py @@ -2,7 +2,8 @@ # pylint: disable=import-error, consider-using-f-string, too-many-arguments, too-many-locals # pylint: disable=invalid-name, fixme, too-many-lines, unbalanced-tuple-unpacking, no-else-return # pylint: disable=protected-access, logging-fstring-interpolation, no-name-in-module -# pylint: disable=wrong-import-order, ungrouped-imports +# pylint: disable=wrong-import-order, ungrouped-imports, too-few-public-methods, broad-except +# pylint: disable=too-many-statements import numpy as np import pandas as pd from scipy.special import expit, softmax diff --git a/python-package/xgboost/spark/model.py b/python-package/xgboost/spark/model.py index 6aa4d71498dd..b5ad14ee0fff 100644 --- a/python-package/xgboost/spark/model.py +++ b/python-package/xgboost/spark/model.py @@ -1,6 +1,6 @@ """Xgboost pyspark integration submodule for model API.""" # pylint: disable=import-error, consider-using-f-string, unspecified-encoding, -# pylint: disable=invalid-name, fixme, +# pylint: disable=invalid-name, fixme, unnecessary-lambda # pylint: disable=protected-access, too-few-public-methods import base64 import os @@ -107,7 +107,7 @@ def 
_get_spark_session(): return SparkSession.builder.getOrCreate() -class SparkXGBSharedReadWrite: +class _SparkXGBSharedReadWrite: @staticmethod def saveMetadata(instance, path, sc, logger, extraMetadata=None): """ @@ -195,7 +195,10 @@ def __init__(self, instance): self.logger = get_logger(self.__class__.__name__, level="WARN") def saveImpl(self, path): - SparkXGBSharedReadWrite.saveMetadata(self.instance, path, self.sc, self.logger) + """ + save model. + """ + _SparkXGBSharedReadWrite.saveMetadata(self.instance, path, self.sc, self.logger) class SparkXGBReader(MLReader): @@ -208,7 +211,10 @@ def __init__(self, cls): self.logger = get_logger(self.__class__.__name__, level="WARN") def load(self, path): - _, pyspark_xgb = SparkXGBSharedReadWrite.loadMetadataAndInstance( + """ + load model. + """ + _, pyspark_xgb = _SparkXGBSharedReadWrite.loadMetadataAndInstance( self.cls, path, self.sc, self.logger ) return pyspark_xgb @@ -230,7 +236,7 @@ def saveImpl(self, path): - save model to path/model.json """ xgb_model = self.instance._xgb_sklearn_model - SparkXGBSharedReadWrite.saveMetadata(self.instance, path, self.sc, self.logger) + _SparkXGBSharedReadWrite.saveMetadata(self.instance, path, self.sc, self.logger) model_save_path = os.path.join(path, "model.json") ser_xgb_model = serialize_xgb_model(xgb_model) _get_spark_session().createDataFrame( @@ -253,7 +259,7 @@ def load(self, path): :return: SparkXGBRegressorModel or SparkXGBClassifierModel instance """ - _, py_model = SparkXGBSharedReadWrite.loadMetadataAndInstance( + _, py_model = _SparkXGBSharedReadWrite.loadMetadataAndInstance( self.cls, path, self.sc, self.logger ) diff --git a/python-package/xgboost/spark/params.py b/python-package/xgboost/spark/params.py index 1d4cbac80f5e..2682f93e17b5 100644 --- a/python-package/xgboost/spark/params.py +++ b/python-package/xgboost/spark/params.py @@ -1,4 +1,5 @@ """Xgboost pyspark integration submodule for params.""" +# pylint: disable=import-error, too-few-public-methods from pyspark.ml.param.shared import Param, Params diff --git a/tests/ci_build/conda_env/aarch64_test.yml b/tests/ci_build/conda_env/aarch64_test.yml index 99e8f4840985..e57bc90ce08c 100644 --- a/tests/ci_build/conda_env/aarch64_test.yml +++ b/tests/ci_build/conda_env/aarch64_test.yml @@ -28,6 +28,7 @@ dependencies: - llvmlite - cffi - pyarrow +- pyspark - pip: - shap - awscli diff --git a/tests/ci_build/conda_env/cpu_test.yml b/tests/ci_build/conda_env/cpu_test.yml index 3180a66857aa..6932b21e8c9b 100644 --- a/tests/ci_build/conda_env/cpu_test.yml +++ b/tests/ci_build/conda_env/cpu_test.yml @@ -36,6 +36,7 @@ dependencies: - cffi - pyarrow - protobuf<=3.20 +- pyspark - pip: - shap - ipython # required by shap at import time. 
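For reference, a persistence round-trip through the reader and writer classes renamed above; a sketch only, assuming a fitted classifier model, a writable model_path, and that SparkXGBClassifierModel is exported from xgboost.spark as in __init__.py.

from xgboost.spark import SparkXGBClassifierModel

model.save(model_path)  # writes model metadata plus the booster under model_path/model.json
loaded = SparkXGBClassifierModel.load(model_path)  # returns a SparkXGBClassifierModel
loaded.transform(df_test).show()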
diff --git a/tests/ci_build/conda_env/win64_cpu_test.yml b/tests/ci_build/conda_env/win64_cpu_test.yml index 7789e94a6fcb..88fb5d7d0b56 100644 --- a/tests/ci_build/conda_env/win64_cpu_test.yml +++ b/tests/ci_build/conda_env/win64_cpu_test.yml @@ -20,3 +20,4 @@ dependencies: - py-ubjson - cffi - pyarrow +- pyspark diff --git a/tests/ci_build/conda_env/win64_test.yml b/tests/ci_build/conda_env/win64_test.yml index 3f62c034c6e0..9f7415b36283 100644 --- a/tests/ci_build/conda_env/win64_test.yml +++ b/tests/ci_build/conda_env/win64_test.yml @@ -18,3 +18,4 @@ dependencies: - py-ubjson - cffi - pyarrow +- pyspark From 70b2da2dcd0494cf091ea87ef42d61207b844162 Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Wed, 6 Jul 2022 12:32:17 +0800 Subject: [PATCH 43/73] update-ci-conda-env --- tests/ci_build/conda_env/aarch64_test.yml | 1 + tests/ci_build/conda_env/cpu_test.yml | 1 + tests/ci_build/conda_env/macos_cpu_test.yml | 1 + tests/ci_build/conda_env/win64_cpu_test.yml | 1 + tests/ci_build/conda_env/win64_test.yml | 2 ++ 5 files changed, 6 insertions(+) diff --git a/tests/ci_build/conda_env/aarch64_test.yml b/tests/ci_build/conda_env/aarch64_test.yml index e57bc90ce08c..7c6228a8de26 100644 --- a/tests/ci_build/conda_env/aarch64_test.yml +++ b/tests/ci_build/conda_env/aarch64_test.yml @@ -29,6 +29,7 @@ dependencies: - cffi - pyarrow - pyspark +- cloudpickle - pip: - shap - awscli diff --git a/tests/ci_build/conda_env/cpu_test.yml b/tests/ci_build/conda_env/cpu_test.yml index 6932b21e8c9b..13b92e49b00a 100644 --- a/tests/ci_build/conda_env/cpu_test.yml +++ b/tests/ci_build/conda_env/cpu_test.yml @@ -37,6 +37,7 @@ dependencies: - pyarrow - protobuf<=3.20 - pyspark +- cloudpickle - pip: - shap - ipython # required by shap at import time. diff --git a/tests/ci_build/conda_env/macos_cpu_test.yml b/tests/ci_build/conda_env/macos_cpu_test.yml index 7594c48894f0..0b30cc81b256 100644 --- a/tests/ci_build/conda_env/macos_cpu_test.yml +++ b/tests/ci_build/conda_env/macos_cpu_test.yml @@ -36,6 +36,7 @@ dependencies: - cffi - pyarrow - pyspark +- cloudpickle - pip: - sphinx_rtd_theme - datatable diff --git a/tests/ci_build/conda_env/win64_cpu_test.yml b/tests/ci_build/conda_env/win64_cpu_test.yml index 88fb5d7d0b56..2269ac7b1150 100644 --- a/tests/ci_build/conda_env/win64_cpu_test.yml +++ b/tests/ci_build/conda_env/win64_cpu_test.yml @@ -21,3 +21,4 @@ dependencies: - cffi - pyarrow - pyspark +- cloudpickle diff --git a/tests/ci_build/conda_env/win64_test.yml b/tests/ci_build/conda_env/win64_test.yml index 9f7415b36283..f94e470632ce 100644 --- a/tests/ci_build/conda_env/win64_test.yml +++ b/tests/ci_build/conda_env/win64_test.yml @@ -19,3 +19,5 @@ dependencies: - cffi - pyarrow - pyspark +- cloudpickle + From b40bc147f67f0dc884f8589ccf4ccf19c800f26b Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Wed, 6 Jul 2022 15:14:43 +0800 Subject: [PATCH 44/73] add gpu test --- tests/python-gpu/conftest.py | 3 +- .../test_spark_with_gpu}/discover_gpu.sh | 0 .../test_spark_with_gpu.py | 91 +++++++++++++++++++ tests/python/test_spark/utils_test.py | 4 - 4 files changed, 93 insertions(+), 5 deletions(-) rename tests/{python/test_spark => python-gpu/test_spark_with_gpu}/discover_gpu.sh (100%) create mode 100644 tests/python-gpu/test_spark_with_gpu/test_spark_with_gpu.py diff --git a/tests/python-gpu/conftest.py b/tests/python-gpu/conftest.py index 6b7eb531a3e4..3e4e1626b886 100644 --- a/tests/python-gpu/conftest.py +++ b/tests/python-gpu/conftest.py @@ -58,5 +58,6 @@ def pytest_collection_modifyitems(config, items): # mark dask 
tests as `mgpu`. mgpu_mark = pytest.mark.mgpu for item in items: - if item.nodeid.startswith("python-gpu/test_gpu_with_dask.py"): + if item.nodeid.startswith("python-gpu/test_gpu_with_dask.py") or \ + item.nodeid.startswith("python-gpu/test_spark_with_gpu/test_spark_with_gpu.py"): item.add_marker(mgpu_mark) diff --git a/tests/python/test_spark/discover_gpu.sh b/tests/python-gpu/test_spark_with_gpu/discover_gpu.sh similarity index 100% rename from tests/python/test_spark/discover_gpu.sh rename to tests/python-gpu/test_spark_with_gpu/discover_gpu.sh diff --git a/tests/python-gpu/test_spark_with_gpu/test_spark_with_gpu.py b/tests/python-gpu/test_spark_with_gpu/test_spark_with_gpu.py new file mode 100644 index 000000000000..27adfd5308b8 --- /dev/null +++ b/tests/python-gpu/test_spark_with_gpu/test_spark_with_gpu.py @@ -0,0 +1,91 @@ +from pyspark.sql import SparkSession +import logging +import pytest +from xgboost.spark import SparkXGBRegressor, SparkXGBClassifier +import sklearn +from pyspark.ml.linalg import Vectors + + +@pytest.fixture(scope="module", autouse=True) +def spark_session_with_gpu(): + spark_config = { + "spark.master": "local-cluster[1, 4, 1024]", + "spark.python.worker.reuse": "false", + "spark.cores.max": "4", + "spark.task.cpus": "1", + "spark.executor.cores": "4", + "spark.worker.resource.gpu.amount": "4", + "spark.task.resource.gpu.amount": "1", + "spark.executor.resource.gpu.amount": "4", + "spark.worker.resource.gpu.discoveryScript": "test_spark_with_gpu/discover_gpu.sh", + } + builder = SparkSession.builder.appName("xgboost spark python API Tests with GPU") + for k, v in spark_config.items(): + builder.config(k, v) + spark = builder.getOrCreate() + logging.getLogger("pyspark").setLevel(logging.INFO) + # We run a dummy job so that we block until the workers have connected to the master + spark.sparkContext.parallelize(range(4), 4).barrier().mapPartitions(lambda _: []).collect() + yield spark + spark.stop() + + +@pytest.fixture +def spark_iris_dataset(spark_session_with_gpu): + spark = spark_session_with_gpu + data = sklearn.datasets.load_iris() + train_rows = [ + (Vectors.dense(features), float(label)) for features, label in zip(data.data[0::2], data.target[0::2]) + ] + train_df = spark.createDataFrame(spark.sparkContext.parallelize(train_rows, 4), ["features", "label"]) + test_rows = [ + (Vectors.dense(features), float(label)) for features, label in zip(data.data[1::2], data.target[1::2]) + ] + test_df = spark.createDataFrame(spark.sparkContext.parallelize(test_rows, 4), ["features", "label"]) + return train_df, test_df + + +@pytest.fixture +def spark_diabetes_dataset(spark_session_with_gpu): + spark = spark_session_with_gpu + data = sklearn.datasets.load_diabetes() + train_rows = [ + (Vectors.dense(features), float(label)) for features, label in zip(data.data[0::2], data.target[0::2]) + ] + train_df = spark.createDataFrame(spark.sparkContext.parallelize(train_rows, 4), ["features", "label"]) + test_rows = [ + (Vectors.dense(features), float(label)) for features, label in zip(data.data[1::2], data.target[1::2]) + ] + test_df = spark.createDataFrame(spark.sparkContext.parallelize(test_rows, 4), ["features", "label"]) + return train_df, test_df + + +def test_sparkxgb_classifier_with_gpu(spark_iris_dataset): + from pyspark.ml.evaluation import MulticlassClassificationEvaluator + classifier = SparkXGBClassifier( + use_gpu=True, + num_workers=4, + ) + train_df, test_df = spark_iris_dataset + model = classifier.fit(train_df) + pred_result_df = model.transform(test_df) + 
evaluator = MulticlassClassificationEvaluator(metricName="f1") + f1 = evaluator.evaluate(pred_result_df) + assert f1 >= 0.97 + + +def test_sparkxgb_regressor_with_gpu(spark_diabetes_dataset): + from pyspark.ml.evaluation import RegressionEvaluator + regressor = SparkXGBRegressor( + use_gpu=True, + num_workers=4, + ) + train_df, test_df = spark_diabetes_dataset + model = regressor.fit(train_df) + pred_result_df = model.transform(test_df) + evaluator = RegressionEvaluator(metricName="rmse") + rmse = evaluator.evaluate(pred_result_df) + assert rmse <= 65.0 + + + diff --git a/tests/python/test_spark/utils_test.py b/tests/python/test_spark/utils_test.py index 1edf8c5b6abf..f20391127904 100644 --- a/tests/python/test_spark/utils_test.py +++ b/tests/python/test_spark/utils_test.py @@ -120,10 +120,6 @@ def setUpClass(cls): "spark.cores.max": "4", "spark.task.cpus": "1", "spark.executor.cores": "2", - "spark.worker.resource.gpu.amount": "4", - "spark.task.resource.gpu.amount": "2", - "spark.executor.resource.gpu.amount": "4", - "spark.worker.resource.gpu.discoveryScript": "test_spark/discover_gpu.sh", } ) cls.make_tempdir() From 6274aa09f2d23e43ee7b938453d6b63ddba7ee94 Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Wed, 6 Jul 2022 18:22:39 +0800 Subject: [PATCH 45/73] clean test --- .../test_spark_with_gpu.py | 2 -- .../test_spark/xgboost_local_cluster_test.py | 26 ------------------- 2 files changed, 28 deletions(-) diff --git a/tests/python-gpu/test_spark_with_gpu/test_spark_with_gpu.py b/tests/python-gpu/test_spark_with_gpu/test_spark_with_gpu.py index 27adfd5308b8..5de56d2b5f93 100644 --- a/tests/python-gpu/test_spark_with_gpu/test_spark_with_gpu.py +++ b/tests/python-gpu/test_spark_with_gpu/test_spark_with_gpu.py @@ -87,5 +87,3 @@ def test_sparkxgb_regressor_with_gpu(spark_diabetes_dataset): rmse = evaluator.evaluate(pred_result_df) assert rmse <= 65.0 - - diff --git a/tests/python/test_spark/xgboost_local_cluster_test.py b/tests/python/test_spark/xgboost_local_cluster_test.py index 66c3e8bc3e14..fd10008c9956 100644 --- a/tests/python/test_spark/xgboost_local_cluster_test.py +++ b/tests/python/test_spark/xgboost_local_cluster_test.py @@ -276,32 +276,6 @@ def test_regressor_distributed_basic(self): for row in pred_result: self.assertTrue(np.isclose(row.expected_label, row.prediction, atol=1e-3)) - @unittest.skip - def test_check_use_gpu_param(self): - # Classifier - classifier = SparkXGBClassifier( - num_workers=self.n_workers, n_estimators=100, use_gpu=True - ) - self.assertTrue(hasattr(classifier, "use_gpu")) - self.assertTrue(classifier.getOrDefault(classifier.use_gpu)) - clf_model = classifier.fit(self.cls_df_train_distributed) - pred_result = clf_model.transform(self.cls_df_test_distributed).collect() - for row in pred_result: - self.assertTrue(np.isclose(row.expected_label, row.prediction, atol=1e-3)) - self.assertTrue( - np.allclose(row.expected_probability, row.probability, atol=1e-3) - ) - - regressor = SparkXGBRegressor( - num_workers=self.n_workers, n_estimators=100, use_gpu=True - ) - self.assertTrue(hasattr(regressor, "use_gpu")) - self.assertTrue(regressor.getOrDefault(regressor.use_gpu)) - model = regressor.fit(self.reg_df_train_distributed) - pred_result = model.transform(self.reg_df_test_distributed).collect() - for row in pred_result: - self.assertTrue(np.isclose(row.expected_label, row.prediction, atol=1e-3)) - def test_classifier_distributed_weight_eval(self): # with weight classifier = SparkXGBClassifier( From 320725679ce1aea7c8d3d675b3b1649e27f60e96 Mon Sep 17 00:00:00 
2001 From: Weichen Xu Date: Wed, 6 Jul 2022 19:13:19 +0800 Subject: [PATCH 46/73] ignore mypy --- python-package/xgboost/spark/core.py | 1 + python-package/xgboost/spark/data.py | 1 + python-package/xgboost/spark/estimator.py | 1 + python-package/xgboost/spark/model.py | 1 + python-package/xgboost/spark/params.py | 1 + python-package/xgboost/spark/utils.py | 1 + 6 files changed, 6 insertions(+) diff --git a/python-package/xgboost/spark/core.py b/python-package/xgboost/spark/core.py index 5dc9c3908caa..9eb834b00d5a 100644 --- a/python-package/xgboost/spark/core.py +++ b/python-package/xgboost/spark/core.py @@ -1,3 +1,4 @@ +# type: ignore """Xgboost pyspark integration submodule for core code.""" # pylint: disable=import-error, consider-using-f-string, too-many-arguments, too-many-locals # pylint: disable=invalid-name, fixme, too-many-lines, unbalanced-tuple-unpacking, no-else-return diff --git a/python-package/xgboost/spark/data.py b/python-package/xgboost/spark/data.py index 7d7020778034..3e43f516f4a4 100644 --- a/python-package/xgboost/spark/data.py +++ b/python-package/xgboost/spark/data.py @@ -1,3 +1,4 @@ +# type: ignore """Xgboost pyspark integration submodule for data related functions.""" # pylint: disable=import-error, consider-using-f-string, too-many-arguments, too-many-locals, # pylint: disable=invalid-name, fixme, too-many-lines, unbalanced-tuple-unpacking, no-else-return diff --git a/python-package/xgboost/spark/estimator.py b/python-package/xgboost/spark/estimator.py index 6f59a1990723..cc03d6b57055 100644 --- a/python-package/xgboost/spark/estimator.py +++ b/python-package/xgboost/spark/estimator.py @@ -1,3 +1,4 @@ +# type: ignore """Xgboost pyspark integration submodule for estimator API.""" # pylint: disable=import-error from pyspark.ml.param.shared import HasProbabilityCol, HasRawPredictionCol diff --git a/python-package/xgboost/spark/model.py b/python-package/xgboost/spark/model.py index b5ad14ee0fff..c3b8712f0f11 100644 --- a/python-package/xgboost/spark/model.py +++ b/python-package/xgboost/spark/model.py @@ -1,3 +1,4 @@ +# type: ignore """Xgboost pyspark integration submodule for model API.""" # pylint: disable=import-error, consider-using-f-string, unspecified-encoding, # pylint: disable=invalid-name, fixme, unnecessary-lambda diff --git a/python-package/xgboost/spark/params.py b/python-package/xgboost/spark/params.py index 2682f93e17b5..27edd5c8f4ae 100644 --- a/python-package/xgboost/spark/params.py +++ b/python-package/xgboost/spark/params.py @@ -1,3 +1,4 @@ +# type: ignore """Xgboost pyspark integration submodule for params.""" # pylint: disable=import-error, too-few-public-methods from pyspark.ml.param.shared import Param, Params diff --git a/python-package/xgboost/spark/utils.py b/python-package/xgboost/spark/utils.py index f9c41350d30b..aa76657dd270 100644 --- a/python-package/xgboost/spark/utils.py +++ b/python-package/xgboost/spark/utils.py @@ -1,3 +1,4 @@ +# type: ignore """Xgboost pyspark integration submodule for helper functions.""" # pylint: disable=import-error, consider-using-f-string, protected-access, wrong-import-order # pylint: disable=invalid-name From 76c1f8d20f634fce1888e0fe920a23394361e528 Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Wed, 6 Jul 2022 19:14:24 +0800 Subject: [PATCH 47/73] ignore mypy --- python-package/xgboost/spark/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python-package/xgboost/spark/__init__.py b/python-package/xgboost/spark/__init__.py index 6af666b185dc..4d4cd41627d0 100644 --- 
a/python-package/xgboost/spark/__init__.py +++ b/python-package/xgboost/spark/__init__.py @@ -1,3 +1,4 @@ +# type: ignore """PySpark XGBoost integration interface """ From 035ac68a2d3e615b2e7c3419205c9d93ea25ae20 Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Wed, 6 Jul 2022 20:18:43 +0800 Subject: [PATCH 48/73] update doc --- python-package/xgboost/spark/estimator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python-package/xgboost/spark/estimator.py b/python-package/xgboost/spark/estimator.py index cc03d6b57055..fdd633b3cc96 100644 --- a/python-package/xgboost/spark/estimator.py +++ b/python-package/xgboost/spark/estimator.py @@ -84,7 +84,7 @@ class SparkXGBRegressor(_SparkXGBEstimator): ... (Vectors.sparse(3, {1: 1.0, 2: 5.5}), ) ... ], ["features"]) >>> xgb_regressor = SparkXGBRegressor(max_depth=5, missing=0.0, - ... validationIndicatorCol='isVal', weightCol='weight', + ... validation_indicator_col='isVal', weight_col='weight', ... early_stopping_rounds=1, eval_metric='rmse') >>> xgb_reg_model = xgb_regressor.fit(df_train) >>> xgb_reg_model.transform(df_test) @@ -188,7 +188,7 @@ class SparkXGBClassifier(_SparkXGBEstimator, HasProbabilityCol, HasRawPrediction ... (Vectors.dense(1.0, 2.0, 3.0), ), ... ], ["features"]) >>> xgb_classifier = SparkXGBClassifier(max_depth=5, missing=0.0, - ... validationIndicatorCol='isVal', weightCol='weight', + ... validation_indicator_col='isVal', weight_col='weight', ... early_stopping_rounds=1, eval_metric='logloss') >>> xgb_clf_model = xgb_classifier.fit(df_train) >>> xgb_clf_model.transform(df_test).show() From a2ead7e711467a06aa8ed58c5a8f22a43a83d2cc Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Wed, 6 Jul 2022 20:26:11 +0800 Subject: [PATCH 49/73] handle missing param --- python-package/xgboost/spark/core.py | 17 +++++------------ 1 file changed, 5 insertions(+), 12 deletions(-) diff --git a/python-package/xgboost/spark/core.py b/python-package/xgboost/spark/core.py index 9eb834b00d5a..4201a02d3685 100644 --- a/python-package/xgboost/spark/core.py +++ b/python-package/xgboost/spark/core.py @@ -547,6 +547,7 @@ def _fit(self, dataset): "feature_types": self.getOrDefault(self.feature_types), "feature_names": self.getOrDefault(self.feature_names), "feature_weights": self.getOrDefault(self.feature_weights), + "missing": self.getOrDefault(self.missing), } booster_params['nthread'] = cpu_per_task use_gpu = self.getOrDefault(self.use_gpu) @@ -835,18 +836,10 @@ def set_param_attrs(attr_name, param_obj_): setattr(pyspark_model_class, attr_name, param_obj_) for name in params_dict.keys(): - if name == "missing": - doc = ( - "Specify the missing value in the features, default np.nan. " - "We recommend using 0.0 as the missing value for better performance. " - "Note: In a spark DataFrame, the inactive values in a sparse vector " - "mean 0 instead of missing values, unless missing=0 is specified." 
- ) - else: - doc = ( - f"Refer to XGBoost doc of " - f"{get_class_name(pyspark_estimator_class._xgb_cls())} for this param {name}" - ) + doc = ( + f"Refer to XGBoost doc of " + f"{get_class_name(pyspark_estimator_class._xgb_cls())} for this param {name}" + ) param_obj = Param(Params._dummy(), name=name, doc=doc) set_param_attrs(name, param_obj) From 147243265e2b3314d55d31a4d281c0a4eea5d686 Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Wed, 6 Jul 2022 20:35:57 +0800 Subject: [PATCH 50/73] update doc --- python-package/xgboost/spark/estimator.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/python-package/xgboost/spark/estimator.py b/python-package/xgboost/spark/estimator.py index fdd633b3cc96..92d5cfd52c2c 100644 --- a/python-package/xgboost/spark/estimator.py +++ b/python-package/xgboost/spark/estimator.py @@ -33,10 +33,6 @@ class SparkXGBRegressor(_SparkXGBEstimator): callbacks: The export and import of the callback functions are at best effort. For details, see :py:attr:`xgboost.spark.SparkXGBRegressor.callbacks` param doc. - missing: - The parameter `missing` in SparkXGBRegressor has different semantics with - that in `xgboost.XGBRegressor`. For details, see - :py:attr:`xgboost.spark.SparkXGBRegressor.missing` param doc. validationIndicatorCol For params related to `xgboost.XGBRegressor` training with evaluation dataset's supervision, set @@ -134,10 +130,6 @@ class SparkXGBClassifier(_SparkXGBEstimator, HasProbabilityCol, HasRawPrediction callbacks: The export and import of the callback functions are at best effort. For details, see :py:attr:`xgboost.spark.SparkXGBClassifier.callbacks` param doc. - missing: - The parameter `missing` in SparkXGBClassifier has different semantics with - that in `xgboost.XGBClassifier`. For details, see - :py:attr:`xgboost.spark.SparkXGBClassifier.missing` param doc. 
rawPredictionCol: The `output_margin=True` is implicitly supported by the `rawPredictionCol` output column, which is always returned with the predicted margin From 7b5afa1e92e36e651753b475dfd624237ce2fb1f Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Thu, 7 Jul 2022 01:25:12 +0000 Subject: [PATCH 51/73] [CI] Install PySpark in Python env --- tests/ci_build/Dockerfile.gpu | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/ci_build/Dockerfile.gpu b/tests/ci_build/Dockerfile.gpu index 31fc1ffa82dc..ea4452564963 100644 --- a/tests/ci_build/Dockerfile.gpu +++ b/tests/ci_build/Dockerfile.gpu @@ -10,7 +10,7 @@ SHELL ["/bin/bash", "-c"] # Use Bash as shell RUN \ apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub && \ apt-get update && \ - apt-get install -y wget unzip bzip2 libgomp1 build-essential && \ + apt-get install -y wget unzip bzip2 libgomp1 build-essential openjdk-8-jdk-headless && \ # Python wget -O Miniconda3.sh https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \ bash Miniconda3.sh -b -p /opt/python @@ -19,11 +19,14 @@ ENV PATH=/opt/python/bin:$PATH # Create new Conda environment with cuDF, Dask, and cuPy RUN \ - conda create -n gpu_test -c rapidsai-nightly -c rapidsai -c nvidia -c conda-forge -c defaults \ + conda install -c conda-forge mamba && \ + mamba create -n gpu_test -c rapidsai-nightly -c rapidsai -c nvidia -c conda-forge -c defaults \ python=3.8 cudf=22.04* rmm=22.04* cudatoolkit=$CUDA_VERSION_ARG dask dask-cuda=22.04* dask-cudf=22.04* cupy \ - numpy pytest scipy scikit-learn pandas matplotlib wheel python-kubernetes urllib3 graphviz hypothesis + numpy pytest scipy scikit-learn pandas matplotlib wheel python-kubernetes urllib3 graphviz hypothesis \ + pyspark cloudpickle cuda-python=11.7.0 ENV GOSU_VERSION 1.10 +ENV JAVA_HOME /usr/lib/jvm/java-8-openjdk-amd64/ # Install lightweight sudo (not bound to TTY) RUN set -ex; \ From f8aea440ed79a90da54d6d5eade596c78baa17a8 Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Thu, 7 Jul 2022 12:03:08 +0800 Subject: [PATCH 52/73] fix spark-xgb-model params --- python-package/xgboost/spark/core.py | 30 +++-------- python-package/xgboost/spark/model.py | 16 ------ tests/python/test_spark/xgboost_local_test.py | 52 ++++++++----------- 3 files changed, 27 insertions(+), 71 deletions(-) diff --git a/python-package/xgboost/spark/core.py b/python-package/xgboost/spark/core.py index 4201a02d3685..e0df2398fb9d 100644 --- a/python-package/xgboost/spark/core.py +++ b/python-package/xgboost/spark/core.py @@ -162,10 +162,6 @@ def _xgb_cls(cls): """ raise NotImplementedError() - def _get_xgb_model_creator(self): - xgb_params = self._gen_xgb_params_dict(gen_xgb_sklearn_estimator_param=True) - return get_xgb_model_creator(self._xgb_cls(), xgb_params) - # Parameters for xgboost.XGBModel() @classmethod def _get_xgb_params_default(cls): @@ -383,25 +379,11 @@ def _pyspark_model_cls(cls): def _create_pyspark_model(self, xgb_model): return self._pyspark_model_cls()(xgb_model) - @classmethod - def _convert_to_classifier(cls, booster): - clf = XGBClassifier() - clf._Booster = booster - return clf - - @classmethod - def _convert_to_regressor(cls, booster): - reg = XGBRegressor() - reg._Booster = booster - return reg - - def _convert_to_model(self, booster): - if self._xgb_cls() == XGBRegressor: - return self._convert_to_regressor(booster) - elif self._xgb_cls() == XGBClassifier: - return self._convert_to_classifier(booster) - else: - return None # 
check if this else statement is needed. + def _convert_to_sklearn_model(self, booster): + xgb_sklearn_params = self._gen_xgb_params_dict(gen_xgb_sklearn_estimator_param=True) + sklearn_model = self._xgb_cls()(**xgb_sklearn_params) + sklearn_model._Booster = booster + return sklearn_model def _query_plan_contains_valid_repartition(self, dataset): """ @@ -606,7 +588,7 @@ def _train_booster(pandas_df_iter): .mapPartitions(lambda x: x) .collect()[0][0] ) - result_xgb_model = self._convert_to_model(cloudpickle.loads(result_ser_booster)) + result_xgb_model = self._convert_to_sklearn_model(cloudpickle.loads(result_ser_booster)) return self._copyValues(self._create_pyspark_model(result_xgb_model)) def write(self): diff --git a/python-package/xgboost/spark/model.py b/python-package/xgboost/spark/model.py index c3b8712f0f11..6a59463edbc4 100644 --- a/python-package/xgboost/spark/model.py +++ b/python-package/xgboost/spark/model.py @@ -16,22 +16,6 @@ from .utils import get_logger, get_class_name -def get_xgb_model_creator(model_cls, xgb_params): - """ - Returns a function that can be used to create an xgboost.XGBModel instance. - This function is used for creating the model instance on the worker, and is - shared by _XgboostEstimator and XgboostModel. - - Parameters - ---------- - model_cls: - a subclass of xgboost.XGBModel - xgb_params: - a dict of params to initialize the model_cls - """ - return lambda: model_cls(**xgb_params) # pylint: disable=W0108 - - def _get_or_create_tmp_dir(): root_dir = SparkFiles.getRootDirectory() xgb_tmp_dir = os.path.join(root_dir, "xgboost-tmp") diff --git a/tests/python/test_spark/xgboost_local_test.py b/tests/python/test_spark/xgboost_local_test.py index df65743fe3bb..6d45d9f70f48 100644 --- a/tests/python/test_spark/xgboost_local_test.py +++ b/tests/python/test_spark/xgboost_local_test.py @@ -381,10 +381,8 @@ def test_regressor_params_basic(self): self.assertEqual(py_reg.n_estimators.parent, py_reg.uid) self.assertFalse(hasattr(py_reg, "gpu_id")) self.assertEqual(py_reg.getOrDefault(py_reg.n_estimators), 100) - self.assertEqual(py_reg._get_xgb_model_creator()().n_estimators, 100) py_reg2 = SparkXGBRegressor(n_estimators=200) self.assertEqual(py_reg2.getOrDefault(py_reg2.n_estimators), 200) - self.assertEqual(py_reg2._get_xgb_model_creator()().n_estimators, 200) py_reg3 = py_reg2.copy({py_reg2.max_depth: 10}) self.assertEqual(py_reg3.getOrDefault(py_reg3.n_estimators), 200) self.assertEqual(py_reg3.getOrDefault(py_reg3.max_depth), 10) @@ -395,10 +393,8 @@ def test_classifier_params_basic(self): self.assertEqual(py_cls.n_estimators.parent, py_cls.uid) self.assertFalse(hasattr(py_cls, "gpu_id")) self.assertEqual(py_cls.getOrDefault(py_cls.n_estimators), 100) - self.assertEqual(py_cls._get_xgb_model_creator()().n_estimators, 100) py_cls2 = SparkXGBClassifier(n_estimators=200) self.assertEqual(py_cls2.getOrDefault(py_cls2.n_estimators), 200) - self.assertEqual(py_cls2._get_xgb_model_creator()().n_estimators, 200) py_cls3 = py_cls2.copy({py_cls2.max_depth: 10}) self.assertEqual(py_cls3.getOrDefault(py_cls3.n_estimators), 200) self.assertEqual(py_cls3.getOrDefault(py_cls3.max_depth), 10) @@ -413,22 +409,15 @@ def test_classifier_kwargs_basic(self): self.assertEqual( py_cls.getOrDefault(py_cls.arbitrary_params_dict), expected_kwargs ) - self.assertTrue("sketch_eps" in py_cls._get_xgb_model_creator()().get_params()) - # We want all of the new params to be in the .get_params() call and be an attribute of py_cls, but not of the actual model - self.assertTrue( - 
"arbitrary_params_dict" not in py_cls._get_xgb_model_creator()().get_params() - ) # Testing overwritten params py_cls = SparkXGBClassifier() py_cls.setParams(x=1, y=2) - py_cls.setParams(y=1, z=2) - self.assertTrue("x" in py_cls._get_xgb_model_creator()().get_params()) - self.assertEqual(py_cls._get_xgb_model_creator()().get_params()["x"], 1) - self.assertTrue("y" in py_cls._get_xgb_model_creator()().get_params()) - self.assertEqual(py_cls._get_xgb_model_creator()().get_params()["y"], 1) - self.assertTrue("z" in py_cls._get_xgb_model_creator()().get_params()) - self.assertEqual(py_cls._get_xgb_model_creator()().get_params()["z"], 2) + py_cls.setParams(y=3, z=4) + xgb_params = py_cls._gen_xgb_params_dict() + assert xgb_params["x"] == 1 + assert xgb_params["y"] == 3 + assert xgb_params["z"] == 4 def test_param_alias(self): py_cls = SparkXGBClassifier(features_col="f1", label_col="l1") @@ -887,26 +876,27 @@ def test_use_gpu_param(self): classifier = SparkXGBClassifier(use_gpu=True, tree_method="gpu_hist") classifier = SparkXGBClassifier(use_gpu=True) - def test_convert_to_model(self): - classifier = SparkXGBClassifier() + def test_convert_to_sklearn_model(self): + classifier = SparkXGBClassifier(n_estimators=200, missing=2.0, max_depth=3, sketch_eps=0.5) clf_model = classifier.fit(self.cls_df_train) - regressor = SparkXGBRegressor() + regressor = SparkXGBRegressor(n_estimators=200, missing=2.0, max_depth=3, sketch_eps=0.5) reg_model = regressor.fit(self.reg_df_train) # Check that regardless of what booster, _convert_to_model converts to the correct class type - self.assertEqual( - type(classifier._convert_to_model(clf_model.get_booster())), XGBClassifier - ) - self.assertEqual( - type(classifier._convert_to_model(reg_model.get_booster())), XGBClassifier - ) - self.assertEqual( - type(regressor._convert_to_model(clf_model.get_booster())), XGBRegressor - ) - self.assertEqual( - type(regressor._convert_to_model(reg_model.get_booster())), XGBRegressor - ) + sklearn_classifier = classifier._convert_to_sklearn_model(clf_model.get_booster()) + assert isinstance(sklearn_classifier, XGBClassifier) + assert sklearn_classifier.n_estimators == 200 + assert sklearn_classifier.missing == 2.0 + assert sklearn_classifier.max_depth == 3 + assert sklearn_classifier.get_params()["sketch_eps"] == 0.5 + + sklearn_regressor = regressor._convert_to_sklearn_model(reg_model.get_booster()) + assert isinstance(sklearn_regressor, XGBRegressor) + assert sklearn_regressor.n_estimators == 200 + assert sklearn_regressor.missing == 2.0 + assert sklearn_regressor.max_depth == 3 + assert sklearn_classifier.get_params()["sketch_eps"] == 0.5 def test_feature_importances(self): reg1 = SparkXGBRegressor(**self.reg_params) From c83843780ffd088ed2e490fc39eb7c58aef440ff Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Thu, 7 Jul 2022 12:18:20 +0800 Subject: [PATCH 53/73] add spark config for printing full error stack and avoid task retries --- .../python-gpu/test_spark_with_gpu/test_spark_with_gpu.py | 4 ++++ tests/python/test_spark/utils_test.py | 8 ++++++++ 2 files changed, 12 insertions(+) diff --git a/tests/python-gpu/test_spark_with_gpu/test_spark_with_gpu.py b/tests/python-gpu/test_spark_with_gpu/test_spark_with_gpu.py index 5de56d2b5f93..418a5e6046b5 100644 --- a/tests/python-gpu/test_spark_with_gpu/test_spark_with_gpu.py +++ b/tests/python-gpu/test_spark_with_gpu/test_spark_with_gpu.py @@ -11,6 +11,10 @@ def spark_session_with_gpu(): spark_config = { "spark.master": "local-cluster[1, 4, 1024]", "spark.python.worker.reuse": 
"false", + "spark.driver.host": "127.0.0.1", + "spark.task.maxFailures": "1", + "spark.sql.execution.pyspark.udf.simplifiedTraceback.enabled": "false", + "spark.sql.pyspark.jvmStacktrace.enabled": "true", "spark.cores.max": "4", "spark.task.cpus": "1", "spark.executor.cores": "4", diff --git a/tests/python/test_spark/utils_test.py b/tests/python/test_spark/utils_test.py index f20391127904..48be73c9d622 100644 --- a/tests/python/test_spark/utils_test.py +++ b/tests/python/test_spark/utils_test.py @@ -102,6 +102,10 @@ def setUpClass(cls): { "spark.master": "local[2]", "spark.python.worker.reuse": "false", + "spark.driver.host": "127.0.0.1", + "spark.task.maxFailures": "1", + "spark.sql.execution.pyspark.udf.simplifiedTraceback.enabled": "false", + "spark.sql.pyspark.jvmStacktrace.enabled": "true", } ) @@ -117,6 +121,10 @@ def setUpClass(cls): { "spark.master": "local-cluster[2, 2, 1024]", "spark.python.worker.reuse": "false", + "spark.driver.host": "127.0.0.1", + "spark.task.maxFailures": "1", + "spark.sql.execution.pyspark.udf.simplifiedTraceback.enabled": "false", + "spark.sql.pyspark.jvmStacktrace.enabled": "true", "spark.cores.max": "4", "spark.task.cpus": "1", "spark.executor.cores": "2", From fac1c2a77bc7f46acab727b6a0592d1b84b2258f Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Thu, 7 Jul 2022 12:38:05 +0800 Subject: [PATCH 54/73] clean test temp dir --- tests/python/test_spark/utils_test.py | 2 ++ tests/python/test_spark/xgboost_local_test.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/python/test_spark/utils_test.py b/tests/python/test_spark/utils_test.py index 48be73c9d622..1b3add9e7116 100644 --- a/tests/python/test_spark/utils_test.py +++ b/tests/python/test_spark/utils_test.py @@ -108,9 +108,11 @@ def setUpClass(cls): "spark.sql.pyspark.jvmStacktrace.enabled": "true", } ) + cls.make_tempdir() @classmethod def tearDownClass(cls): + cls.remove_tempdir() cls.tear_down_env() diff --git a/tests/python/test_spark/xgboost_local_test.py b/tests/python/test_spark/xgboost_local_test.py index 6d45d9f70f48..7021d6b6cf9b 100644 --- a/tests/python/test_spark/xgboost_local_test.py +++ b/tests/python/test_spark/xgboost_local_test.py @@ -373,7 +373,7 @@ def setUp(self): ) def get_local_tmp_dir(self): - return "/tmp/xgboost_local_test/" + str(uuid.uuid4()) + return self.tempdir + str(uuid.uuid4()) def test_regressor_params_basic(self): py_reg = SparkXGBRegressor() From 62bf8bb746111e485beefbdbc49de355ea640962 Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Thu, 7 Jul 2022 12:48:25 +0800 Subject: [PATCH 55/73] fix import --- python-package/xgboost/spark/core.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python-package/xgboost/spark/core.py b/python-package/xgboost/spark/core.py index e0df2398fb9d..f15b7175e5bb 100644 --- a/python-package/xgboost/spark/core.py +++ b/python-package/xgboost/spark/core.py @@ -35,7 +35,6 @@ SparkXGBWriter, SparkXGBModelReader, SparkXGBModelWriter, - get_xgb_model_creator, ) from .utils import ( _get_default_params_from_func, From 1b77f239b8ac1d3673f29c540f4c7069378a76e7 Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Thu, 7 Jul 2022 14:11:39 +0800 Subject: [PATCH 56/73] add pyspark env config in CI --- tests/ci_build/test_python.sh | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/tests/ci_build/test_python.sh b/tests/ci_build/test_python.sh index 8ecfb766126a..b0d01f1c3449 100755 --- a/tests/ci_build/test_python.sh +++ b/tests/ci_build/test_python.sh @@ -34,6 +34,18 @@ function install_xgboost { fi } +function 
setup_pyspark_envs { + export PYSPARK_DRIVER_PYTHON=`which python` + export PYSPARK_PYTHON=`which python` + export SPARK_TESTING=1 +} + +function unset_pyspark_envs { + unset PYSPARK_DRIVER_PYTHON + unset PYSPARK_PYTHON + unset SPARK_TESTING +} + function uninstall_xgboost { pip uninstall -y xgboost } @@ -43,14 +55,18 @@ case "$suite" in gpu) source activate gpu_test install_xgboost + setup_pyspark_envs pytest -v -s -rxXs --fulltrace --durations=0 -m "not mgpu" ${args} tests/python-gpu + unset_pyspark_envs uninstall_xgboost ;; mgpu) source activate gpu_test install_xgboost + setup_pyspark_envs pytest -v -s -rxXs --fulltrace --durations=0 -m "mgpu" ${args} tests/python-gpu + unset_pyspark_envs cd tests/distributed ./runtests-gpu.sh @@ -61,7 +77,9 @@ case "$suite" in source activate cpu_test install_xgboost export RAY_OBJECT_STORE_ALLOW_SLOW_STORAGE=1 + setup_pyspark_envs pytest -v -s -rxXs --fulltrace --durations=0 ${args} tests/python + unset_pyspark_envs cd tests/distributed ./runtests.sh uninstall_xgboost @@ -70,7 +88,9 @@ case "$suite" in cpu-arm64) source activate aarch64_test install_xgboost + setup_pyspark_envs pytest -v -s -rxXs --fulltrace --durations=0 ${args} tests/python/test_basic.py tests/python/test_basic_models.py tests/python/test_model_compatibility.py + unset_pyspark_envs uninstall_xgboost ;; From 0487223bace8a2b945f2ef7da85af42b1462e6ab Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Thu, 7 Jul 2022 15:24:03 +0800 Subject: [PATCH 57/73] update discoveryScript config path --- tests/python-gpu/test_spark_with_gpu/test_spark_with_gpu.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python-gpu/test_spark_with_gpu/test_spark_with_gpu.py b/tests/python-gpu/test_spark_with_gpu/test_spark_with_gpu.py index 418a5e6046b5..f3753bfb4684 100644 --- a/tests/python-gpu/test_spark_with_gpu/test_spark_with_gpu.py +++ b/tests/python-gpu/test_spark_with_gpu/test_spark_with_gpu.py @@ -21,7 +21,7 @@ def spark_session_with_gpu(): "spark.worker.resource.gpu.amount": "4", "spark.task.resource.gpu.amount": "1", "spark.executor.resource.gpu.amount": "4", - "spark.worker.resource.gpu.discoveryScript": "test_spark_with_gpu/discover_gpu.sh", + "spark.worker.resource.gpu.discoveryScript": "tests/python-gpu/test_spark_with_gpu/discover_gpu.sh", } builder = SparkSession.builder.appName("xgboost spark python API Tests with GPU") for k, v in spark_config.items(): From 37938d89bb62f5891af54bc697a9121d1c4dae0e Mon Sep 17 00:00:00 2001 From: fis Date: Thu, 7 Jul 2022 17:50:25 +0800 Subject: [PATCH 58/73] Skip tests. 
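
The PySpark test modules are made self-skipping when pyspark is not importable or when running on Windows. A minimal sketch of the module-level guard these files converge on by the end of this series (assuming `tests/python` is on `sys.path` so the local `testing` helper module is importable; the helper `no_spark()` is added to testing.py in this patch):

    import sys

    import pytest

    import testing as tm

    # Skip the whole module when pyspark is not installed.
    if tm.no_spark()["condition"]:
        pytest.skip(msg=tm.no_spark()["reason"], allow_module_level=True)
    # PySpark local / local-cluster test setups are not exercised on Windows.
    if sys.platform.startswith("win"):
        pytest.skip("Skipping PySpark tests on Windows", allow_module_level=True)
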
--- tests/ci_build/conda_env/win64_cpu_test.yml | 2 -- tests/ci_build/conda_env/win64_test.yml | 3 -- tests/python-gpu/conftest.py | 19 ++++++---- .../test_spark_with_gpu.py | 35 +++++++++++++------ tests/python/test_spark/data_test.py | 20 ++++++++--- .../test_spark/xgboost_local_cluster_test.py | 27 ++++++++------ tests/python/test_spark/xgboost_local_test.py | 26 ++++++++++---- tests/python/testing.py | 9 +++++ 8 files changed, 98 insertions(+), 43 deletions(-) diff --git a/tests/ci_build/conda_env/win64_cpu_test.yml b/tests/ci_build/conda_env/win64_cpu_test.yml index 2269ac7b1150..7789e94a6fcb 100644 --- a/tests/ci_build/conda_env/win64_cpu_test.yml +++ b/tests/ci_build/conda_env/win64_cpu_test.yml @@ -20,5 +20,3 @@ dependencies: - py-ubjson - cffi - pyarrow -- pyspark -- cloudpickle diff --git a/tests/ci_build/conda_env/win64_test.yml b/tests/ci_build/conda_env/win64_test.yml index f94e470632ce..3f62c034c6e0 100644 --- a/tests/ci_build/conda_env/win64_test.yml +++ b/tests/ci_build/conda_env/win64_test.yml @@ -18,6 +18,3 @@ dependencies: - py-ubjson - cffi - pyarrow -- pyspark -- cloudpickle - diff --git a/tests/python-gpu/conftest.py b/tests/python-gpu/conftest.py index 3e4e1626b886..789f96fc5511 100644 --- a/tests/python-gpu/conftest.py +++ b/tests/python-gpu/conftest.py @@ -44,13 +44,15 @@ def pytest_addoption(parser): def pytest_collection_modifyitems(config, items): - if config.getoption('--use-rmm-pool'): + if config.getoption("--use-rmm-pool"): blocklist = [ - 'python-gpu/test_gpu_demos.py::test_dask_training', - 'python-gpu/test_gpu_prediction.py::TestGPUPredict::test_shap', - 'python-gpu/test_gpu_linear.py::TestGPULinear' + "python-gpu/test_gpu_demos.py::test_dask_training", + "python-gpu/test_gpu_prediction.py::TestGPUPredict::test_shap", + "python-gpu/test_gpu_linear.py::TestGPULinear", ] - skip_mark = pytest.mark.skip(reason='This test is not run when --use-rmm-pool flag is active') + skip_mark = pytest.mark.skip( + reason="This test is not run when --use-rmm-pool flag is active" + ) for item in items: if any(item.nodeid.startswith(x) for x in blocklist): item.add_marker(skip_mark) @@ -58,6 +60,9 @@ def pytest_collection_modifyitems(config, items): # mark dask tests as `mgpu`. 
mgpu_mark = pytest.mark.mgpu for item in items: - if item.nodeid.startswith("python-gpu/test_gpu_with_dask.py") or \ - item.nodeid.startswith("python-gpu/test_spark_with_gpu/test_spark_with_gpu.py"): + if item.nodeid.startswith( + "python-gpu/test_gpu_with_dask.py" + ) or item.nodeid.startswith( + "python-gpu/test_spark_with_gpu/test_spark_with_gpu.py" + ): item.add_marker(mgpu_mark) diff --git a/tests/python-gpu/test_spark_with_gpu/test_spark_with_gpu.py b/tests/python-gpu/test_spark_with_gpu/test_spark_with_gpu.py index f3753bfb4684..9d68751a266a 100644 --- a/tests/python-gpu/test_spark_with_gpu/test_spark_with_gpu.py +++ b/tests/python-gpu/test_spark_with_gpu/test_spark_with_gpu.py @@ -29,7 +29,9 @@ def spark_session_with_gpu(): spark = builder.getOrCreate() logging.getLogger("pyspark").setLevel(logging.INFO) # We run a dummy job so that we block until the workers have connected to the master - spark.sparkContext.parallelize(range(4), 4).barrier().mapPartitions(lambda _: []).collect() + spark.sparkContext.parallelize(range(4), 4).barrier().mapPartitions( + lambda _: [] + ).collect() yield spark spark.stop() @@ -39,13 +41,19 @@ def spark_iris_dataset(spark_session_with_gpu): spark = spark_session_with_gpu data = sklearn.datasets.load_iris() train_rows = [ - (Vectors.dense(features), float(label)) for features, label in zip(data.data[0::2], data.target[0::2]) + (Vectors.dense(features), float(label)) + for features, label in zip(data.data[0::2], data.target[0::2]) ] - train_df = spark.createDataFrame(spark.sparkContext.parallelize(train_rows, 4), ["features", "label"]) + train_df = spark.createDataFrame( + spark.sparkContext.parallelize(train_rows, 4), ["features", "label"] + ) test_rows = [ - (Vectors.dense(features), float(label)) for features, label in zip(data.data[1::2], data.target[1::2]) + (Vectors.dense(features), float(label)) + for features, label in zip(data.data[1::2], data.target[1::2]) ] - test_df = spark.createDataFrame(spark.sparkContext.parallelize(test_rows, 4), ["features", "label"]) + test_df = spark.createDataFrame( + spark.sparkContext.parallelize(test_rows, 4), ["features", "label"] + ) return train_df, test_df @@ -54,18 +62,25 @@ def spark_diabetes_dataset(spark_session_with_gpu): spark = spark_session_with_gpu data = sklearn.datasets.load_diabetes() train_rows = [ - (Vectors.dense(features), float(label)) for features, label in zip(data.data[0::2], data.target[0::2]) + (Vectors.dense(features), float(label)) + for features, label in zip(data.data[0::2], data.target[0::2]) ] - train_df = spark.createDataFrame(spark.sparkContext.parallelize(train_rows, 4), ["features", "label"]) + train_df = spark.createDataFrame( + spark.sparkContext.parallelize(train_rows, 4), ["features", "label"] + ) test_rows = [ - (Vectors.dense(features), float(label)) for features, label in zip(data.data[1::2], data.target[1::2]) + (Vectors.dense(features), float(label)) + for features, label in zip(data.data[1::2], data.target[1::2]) ] - test_df = spark.createDataFrame(spark.sparkContext.parallelize(test_rows, 4), ["features", "label"]) + test_df = spark.createDataFrame( + spark.sparkContext.parallelize(test_rows, 4), ["features", "label"] + ) return train_df, test_df def test_sparkxgb_classifier_with_gpu(spark_iris_dataset): from pyspark.ml.evaluation import MulticlassClassificationEvaluator + classifier = SparkXGBClassifier( use_gpu=True, num_workers=4, @@ -80,6 +95,7 @@ def test_sparkxgb_classifier_with_gpu(spark_iris_dataset): def 
test_sparkxgb_regressor_with_gpu(spark_diabetes_dataset): from pyspark.ml.evaluation import RegressionEvaluator + regressor = SparkXGBRegressor( use_gpu=True, num_workers=4, @@ -90,4 +106,3 @@ def test_sparkxgb_regressor_with_gpu(spark_diabetes_dataset): evaluator = RegressionEvaluator(metricName="rmse") rmse = evaluator.evaluate(pred_result_df) assert rmse <= 65.0 - diff --git a/tests/python/test_spark/data_test.py b/tests/python/test_spark/data_test.py index 3feb8ba096e0..9773a54092ab 100644 --- a/tests/python/test_spark/data_test.py +++ b/tests/python/test_spark/data_test.py @@ -81,7 +81,10 @@ def row_tup_iter(data): "label": [1, 0] * 100, } output_dmatrix = _convert_partition_data_to_dmatrix( - [pd.DataFrame(data)], has_weight=False, has_validation=False, has_base_margin=False + [pd.DataFrame(data)], + has_weight=False, + has_validation=False, + has_base_margin=False, ) # You can't compare DMatrix outputs, so the only way is to predict on the two seperate DMatrices using # the same classifier and making sure the outputs are equal @@ -99,7 +102,10 @@ def row_tup_iter(data): data["weight"] = [0.2, 0.8] * 100 output_dmatrix = _convert_partition_data_to_dmatrix( - [pd.DataFrame(data)], has_weight=True, has_validation=False, has_base_margin=False + [pd.DataFrame(data)], + has_weight=True, + has_validation=False, + has_base_margin=False, ) model.fit(expected_features, expected_labels, sample_weight=expected_weight) @@ -122,7 +128,10 @@ def test_external_storage(self): # Creating the dmatrix based on storage temporary_path = tempfile.mkdtemp() storage_dmatrix = _convert_partition_data_to_dmatrix( - [pd.DataFrame(data)], has_weight=False, has_validation=False, has_base_margin=False + [pd.DataFrame(data)], + has_weight=False, + has_validation=False, + has_base_margin=False, ) # Testing without weights @@ -140,7 +149,10 @@ def test_external_storage(self): temporary_path = tempfile.mkdtemp() storage_dmatrix = _convert_partition_data_to_dmatrix( - [pd.DataFrame(data)], has_weight=True, has_validation=False, has_base_margin=False + [pd.DataFrame(data)], + has_weight=True, + has_validation=False, + has_base_margin=False, ) normal_booster = worker_train({}, normal_dmatrix) diff --git a/tests/python/test_spark/xgboost_local_cluster_test.py b/tests/python/test_spark/xgboost_local_cluster_test.py index fd10008c9956..4f47eb449090 100644 --- a/tests/python/test_spark/xgboost_local_cluster_test.py +++ b/tests/python/test_spark/xgboost_local_cluster_test.py @@ -1,15 +1,22 @@ +import sys import random -import unittest +import json +import uuid +import os +import pytest import numpy as np -from pyspark.ml.linalg import Vectors +import testing as tm + +if tm.no_dask()["condition"]: + pytest.skip(msg=tm.no_spark()["reason"], allow_module_level=True) +if sys.platform.startswith("win"): + pytest.skip("Skipping dask tests on Windows", allow_module_level=True) -from xgboost.spark import SparkXGBClassifier, SparkXGBRegressor from .utils_test import SparkLocalClusterTestCase +from xgboost.spark import SparkXGBClassifier, SparkXGBRegressor from xgboost.spark.utils import _get_max_num_concurrent_tasks -import json -import uuid -import os +from pyspark.ml.linalg import Vectors class XgboostLocalClusterTestCase(SparkLocalClusterTestCase): @@ -309,7 +316,7 @@ def test_classifier_distributed_weight_eval(self): assert np.isclose( float(model.get_booster().attributes()["best_score"]), self.clf_best_score_eval, - rtol=1e-3 + rtol=1e-3, ) # with both weight and eval @@ -332,7 +339,7 @@ def 
test_classifier_distributed_weight_eval(self): np.isclose( float(model.get_booster().attributes()["best_score"]), self.clf_best_score_weight_and_eval, - rtol=1e-3 + rtol=1e-3, ) def test_regressor_distributed_weight_eval(self): @@ -369,7 +376,7 @@ def test_regressor_distributed_weight_eval(self): assert np.isclose( float(model.get_booster().attributes()["best_score"]), self.reg_best_score_eval, - rtol=1e-3 + rtol=1e-3, ) # with both weight and eval regressor = SparkXGBRegressor( @@ -394,7 +401,7 @@ def test_regressor_distributed_weight_eval(self): assert np.isclose( float(model.get_booster().attributes()["best_score"]), self.reg_best_score_weight_and_eval, - rtol=1e-3 + rtol=1e-3, ) def test_num_estimators(self): diff --git a/tests/python/test_spark/xgboost_local_test.py b/tests/python/test_spark/xgboost_local_test.py index 7021d6b6cf9b..422413cbc742 100644 --- a/tests/python/test_spark/xgboost_local_test.py +++ b/tests/python/test_spark/xgboost_local_test.py @@ -423,7 +423,9 @@ def test_param_alias(self): py_cls = SparkXGBClassifier(features_col="f1", label_col="l1") self.assertEqual(py_cls.getOrDefault(py_cls.featuresCol), "f1") self.assertEqual(py_cls.getOrDefault(py_cls.labelCol), "l1") - with pytest.raises(ValueError, match="Please use param name features_col instead"): + with pytest.raises( + ValueError, match="Please use param name features_col instead" + ): SparkXGBClassifier(featuresCol="f1") def test_gpu_param_setting(self): @@ -719,7 +721,9 @@ def test_classifier_with_base_margin(self): row.prediction, row.expected_prediction_with_base_margin, atol=1e-3 ) ) - np.testing.assert_allclose(row.probability, row.expected_prob_with_base_margin, atol=1e-3) + np.testing.assert_allclose( + row.probability, row.expected_prob_with_base_margin, atol=1e-3 + ) cls_with_different_base_margin = SparkXGBClassifier( weight_col="weight", base_margin_col="base_margin" @@ -738,7 +742,9 @@ def test_classifier_with_base_margin(self): row.prediction, row.expected_prediction_with_base_margin, atol=1e-3 ) ) - np.testing.assert_allclose(row.probability, row.expected_prob_with_base_margin, atol=1e-3) + np.testing.assert_allclose( + row.probability, row.expected_prob_with_base_margin, atol=1e-3 + ) def test_regressor_with_weight_eval(self): # with weight @@ -877,14 +883,20 @@ def test_use_gpu_param(self): classifier = SparkXGBClassifier(use_gpu=True) def test_convert_to_sklearn_model(self): - classifier = SparkXGBClassifier(n_estimators=200, missing=2.0, max_depth=3, sketch_eps=0.5) + classifier = SparkXGBClassifier( + n_estimators=200, missing=2.0, max_depth=3, sketch_eps=0.5 + ) clf_model = classifier.fit(self.cls_df_train) - regressor = SparkXGBRegressor(n_estimators=200, missing=2.0, max_depth=3, sketch_eps=0.5) + regressor = SparkXGBRegressor( + n_estimators=200, missing=2.0, max_depth=3, sketch_eps=0.5 + ) reg_model = regressor.fit(self.reg_df_train) # Check that regardless of what booster, _convert_to_model converts to the correct class type - sklearn_classifier = classifier._convert_to_sklearn_model(clf_model.get_booster()) + sklearn_classifier = classifier._convert_to_sklearn_model( + clf_model.get_booster() + ) assert isinstance(sklearn_classifier, XGBClassifier) assert sklearn_classifier.n_estimators == 200 assert sklearn_classifier.missing == 2.0 @@ -944,7 +956,7 @@ def test_classifier_with_feature_names_types_weights(self): classifier = SparkXGBClassifier( feature_names=["a1", "a2", "a3"], feature_types=["i", "int", "float"], - feature_weights=[2.0, 5.0, 3.0] + feature_weights=[2.0, 5.0, 3.0], ) 
model = classifier.fit(self.cls_df_train) model.transform(self.cls_df_test).collect() diff --git a/tests/python/testing.py b/tests/python/testing.py index 8ff105e0c816..d1e19330b766 100644 --- a/tests/python/testing.py +++ b/tests/python/testing.py @@ -56,6 +56,15 @@ def no_dask(): return {"condition": not DASK_INSTALLED, "reason": "Dask is not installed"} +def no_spark(): + try: + import pyspark # noqa + SPARK_INSTALLED = True + except ImportError: + SPARK_INSTALLED = False + return {"condition": not SPARK_INSTALLED, "reason": "Spark is not installed"} + + def no_pandas(): return {'condition': not PANDAS_INSTALLED, 'reason': 'Pandas is not installed.'} From 024fc9e16d0b77f6427ce1b8a22128646a202783 Mon Sep 17 00:00:00 2001 From: fis Date: Thu, 7 Jul 2022 17:51:42 +0800 Subject: [PATCH 59/73] Missing dependency. --- python-package/setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python-package/setup.py b/python-package/setup.py index edea60087cdb..4f5ca6d74314 100644 --- a/python-package/setup.py +++ b/python-package/setup.py @@ -350,7 +350,7 @@ def run(self) -> None: 'dask': ['dask', 'pandas', 'distributed'], 'datatable': ['datatable'], 'plotting': ['graphviz', 'matplotlib'], - 'pyspark': ['pyspark'], + "pyspark": ["pyspark", "scikit-learn"], }, maintainer='Hyunsu Cho', maintainer_email='chohyu01@cs.washington.edu', From 4bb74047f72b0d03dc2fccee7a2d60048b5ddeb9 Mon Sep 17 00:00:00 2001 From: fis Date: Thu, 7 Jul 2022 17:52:50 +0800 Subject: [PATCH 60/73] black. --- python-package/xgboost/spark/core.py | 106 +++++++++++++++++--------- python-package/xgboost/spark/data.py | 32 ++++++-- python-package/xgboost/spark/model.py | 6 +- 3 files changed, 100 insertions(+), 44 deletions(-) diff --git a/python-package/xgboost/spark/core.py b/python-package/xgboost/spark/core.py index f15b7175e5bb..bc812cbf89c7 100644 --- a/python-package/xgboost/spark/core.py +++ b/python-package/xgboost/spark/core.py @@ -51,7 +51,12 @@ from pyspark.ml.functions import array_to_vector, vector_to_array from pyspark.sql.types import ( - ArrayType, DoubleType, FloatType, IntegerType, LongType, ShortType + ArrayType, + DoubleType, + FloatType, + IntegerType, + LongType, + ShortType, ) from pyspark.ml.linalg import VectorUDT @@ -90,9 +95,7 @@ "validation_indicator_col": "validationIndicatorCol", } -_inverse_pyspark_param_alias_map = { - v: k for k, v in _pyspark_param_alias_map.items() -} +_inverse_pyspark_param_alias_map = {v: k for k, v in _pyspark_param_alias_map.items()} _unsupported_xgb_params = [ "gpu_id", # we have "use_gpu" pyspark param instead. @@ -148,9 +151,7 @@ class _XgboostParams( + "to have force_repartition be True.", ) feature_names = Param( - Params._dummy(), - "feature_names", - "A list of str to specify feature names." + Params._dummy(), "feature_names", "A list of str to specify feature names." 
) @classmethod @@ -188,7 +189,9 @@ def _gen_xgb_params_dict(self, gen_xgb_sklearn_estimator_param=False): if param.name not in non_xgb_params: xgb_params[param.name] = self.getOrDefault(param) - arbitrary_params_dict = self.getOrDefault(self.getParam("arbitrary_params_dict")) + arbitrary_params_dict = self.getOrDefault( + self.getParam("arbitrary_params_dict") + ) xgb_params.update(arbitrary_params_dict) return xgb_params @@ -299,7 +302,7 @@ def _validate_and_convert_feature_col_as_array_col(dataset, features_col_name): if isinstance(features_col_datatype, ArrayType): if not isinstance( features_col_datatype.elementType, - (DoubleType, FloatType, LongType, IntegerType, ShortType) + (DoubleType, FloatType, LongType, IntegerType, ShortType), ): raise ValueError( "If feature column is array type, its elements must be number type." @@ -334,7 +337,7 @@ def __init__(self): force_repartition=False, feature_names=None, feature_types=None, - arbitrary_params_dict={} + arbitrary_params_dict={}, ) def setParams(self, **kwargs): @@ -342,7 +345,7 @@ def setParams(self, **kwargs): Set params for the estimator. """ _extra_params = {} - if 'arbitrary_params_dict' in kwargs: + if "arbitrary_params_dict" in kwargs: raise ValueError("Invalid param name: 'arbitrary_params_dict'.") for k, v in kwargs.items(): @@ -353,15 +356,19 @@ def setParams(self, **kwargs): if k in _pyspark_param_alias_map: real_k = _pyspark_param_alias_map[k] if real_k in kwargs: - raise ValueError(f"You should set only one of param '{k}' and '{real_k}'") + raise ValueError( + f"You should set only one of param '{k}' and '{real_k}'" + ) k = real_k if self.hasParam(k): self._set(**{str(k): v}) else: - if k in _unsupported_xgb_params or \ - k in _unsupported_fit_params or \ - k in _unsupported_predict_params: + if ( + k in _unsupported_xgb_params + or k in _unsupported_fit_params + or k in _unsupported_predict_params + ): raise ValueError(f"Unsupported param '{k}'.") _extra_params[k] = v _existing_extra_params = self.getOrDefault(self.arbitrary_params_dict) @@ -379,7 +386,9 @@ def _create_pyspark_model(self, xgb_model): return self._pyspark_model_cls()(xgb_model) def _convert_to_sklearn_model(self, booster): - xgb_sklearn_params = self._gen_xgb_params_dict(gen_xgb_sklearn_estimator_param=True) + xgb_sklearn_params = self._gen_xgb_params_dict( + gen_xgb_sklearn_estimator_param=True + ) sklearn_model = self._xgb_cls()(**xgb_sklearn_params) sklearn_model._Booster = booster return sklearn_model @@ -495,10 +504,12 @@ def _fit(self, dataset): ) if self.isDefined(self.base_margin_col) and self.getOrDefault( - self.base_margin_col): + self.base_margin_col + ): has_base_margin = True select_cols.append( - col(self.getOrDefault(self.base_margin_col)).alias("baseMargin")) + col(self.getOrDefault(self.base_margin_col)).alias("baseMargin") + ) dataset = dataset.select(*select_cols) @@ -517,20 +528,21 @@ def _fit(self, dataset): if self._repartition_needed(dataset): dataset = dataset.repartition(num_workers) train_params = self._get_distributed_train_params(dataset) - booster_params, train_call_kwargs_params = \ - self._get_xgb_train_call_args(train_params) + booster_params, train_call_kwargs_params = self._get_xgb_train_call_args( + train_params + ) cpu_per_task = int( _get_spark_session().sparkContext.getConf().get("spark.task.cpus", "1") ) dmatrix_kwargs = { "nthread": cpu_per_task, - "feature_types": self.getOrDefault(self.feature_types), - "feature_names": self.getOrDefault(self.feature_names), + "feature_types": 
self.getOrDefault(self.feature_types), + "feature_names": self.getOrDefault(self.feature_names), "feature_weights": self.getOrDefault(self.feature_weights), "missing": self.getOrDefault(self.missing), } - booster_params['nthread'] = cpu_per_task + booster_params["nthread"] = cpu_per_task use_gpu = self.getOrDefault(self.use_gpu) def _train_booster(pandas_df_iter): @@ -545,7 +557,9 @@ def _train_booster(pandas_df_iter): if use_gpu: # Set booster worker to use the first GPU allocated to the spark task. - booster_params["gpu_id"] = int(context._resources["gpu"].addresses[0].strip()) + booster_params["gpu_id"] = int( + context._resources["gpu"].addresses[0].strip() + ) _rabit_args = "" if context.partitionId() == 0: @@ -558,14 +572,20 @@ def _train_booster(pandas_df_iter): dtrain, dval = None, [] if has_validation: dtrain, dval = _convert_partition_data_to_dmatrix( - pandas_df_iter, has_weight, has_validation, has_base_margin, + pandas_df_iter, + has_weight, + has_validation, + has_base_margin, dmatrix_kwargs=dmatrix_kwargs, ) # TODO: Question: do we need to add dtrain to dval list ? dval = [(dtrain, "training"), (dval, "validation")] else: dtrain = _convert_partition_data_to_dmatrix( - pandas_df_iter, has_weight, has_validation, has_base_margin, + pandas_df_iter, + has_weight, + has_validation, + has_base_margin, dmatrix_kwargs=dmatrix_kwargs, ) @@ -587,7 +607,9 @@ def _train_booster(pandas_df_iter): .mapPartitions(lambda x: x) .collect()[0][0] ) - result_xgb_model = self._convert_to_sklearn_model(cloudpickle.loads(result_ser_booster)) + result_xgb_model = self._convert_to_sklearn_model( + cloudpickle.loads(result_ser_booster) + ) return self._copyValues(self._create_pyspark_model(result_xgb_model)) def write(self): @@ -673,9 +695,13 @@ def _transform(self, dataset): predict_params = self._gen_predict_params_dict() has_base_margin = False - if self.isDefined(self.base_margin_col) and self.getOrDefault(self.base_margin_col): + if self.isDefined(self.base_margin_col) and self.getOrDefault( + self.base_margin_col + ): has_base_margin = True - base_margin_col = col(self.getOrDefault(self.base_margin_col)).alias("baseMargin") + base_margin_col = col(self.getOrDefault(self.base_margin_col)).alias( + "baseMargin" + ) @pandas_udf("double") def predict_udf(input_data: pd.DataFrame) -> pd.Series: @@ -686,8 +712,7 @@ def predict_udf(input_data: pd.DataFrame) -> pd.Series: base_margin = None preds = xgb_sklearn_model.predict( - X, base_margin=base_margin, validate_features=False, - **predict_params + X, base_margin=base_margin, validate_features=False, **predict_params ) return pd.Series(preds) @@ -723,9 +748,13 @@ def _transform(self, dataset): predict_params = self._gen_predict_params_dict() has_base_margin = False - if self.isDefined(self.base_margin_col) and self.getOrDefault(self.base_margin_col): + if self.isDefined(self.base_margin_col) and self.getOrDefault( + self.base_margin_col + ): has_base_margin = True - base_margin_col = col(self.getOrDefault(self.base_margin_col)).alias("baseMargin") + base_margin_col = col(self.getOrDefault(self.base_margin_col)).alias( + "baseMargin" + ) @pandas_udf( "rawPrediction array, prediction double, probability array" @@ -738,17 +767,18 @@ def predict_udf(input_data: pd.DataFrame) -> pd.DataFrame: base_margin = None margins = xgb_sklearn_model.predict( - X, base_margin=base_margin, output_margin=True, validate_features=False, - **predict_params + X, + base_margin=base_margin, + output_margin=True, + validate_features=False, + **predict_params, ) if margins.ndim 
== 1: # binomial case classone_probs = expit(margins) classzero_probs = 1.0 - classone_probs raw_preds = np.vstack((-margins, margins)).transpose() - class_probs = np.vstack( - (classzero_probs, classone_probs) - ).transpose() + class_probs = np.vstack((classzero_probs, classone_probs)).transpose() else: # multinomial case raw_preds = margins diff --git a/python-package/xgboost/spark/data.py b/python-package/xgboost/spark/data.py index 3e43f516f4a4..47fb541dc60a 100644 --- a/python-package/xgboost/spark/data.py +++ b/python-package/xgboost/spark/data.py @@ -163,7 +163,11 @@ def _process_data_iter( def _convert_partition_data_to_dmatrix( - partition_data_iter, has_weight, has_validation, has_base_margin, dmatrix_kwargs=None + partition_data_iter, + has_weight, + has_validation, + has_base_margin, + dmatrix_kwargs=None, ): dmatrix_kwargs = dmatrix_kwargs or {} # if we are not using external storage, we use the standard method of parsing data. @@ -171,20 +175,38 @@ def _convert_partition_data_to_dmatrix( partition_data_iter, has_weight, has_validation, has_base_margin ) if has_validation: - train_X, train_y, train_w, train_b_m, val_X, val_y, val_w, val_b_m = train_val_data + ( + train_X, + train_y, + train_w, + train_b_m, + val_X, + val_y, + val_w, + val_b_m, + ) = train_val_data training_dmatrix = DMatrix( - data=train_X, label=train_y, weight=train_w, base_margin=train_b_m, + data=train_X, + label=train_y, + weight=train_w, + base_margin=train_b_m, **dmatrix_kwargs, ) val_dmatrix = DMatrix( - data=val_X, label=val_y, weight=val_w, base_margin=val_b_m, + data=val_X, + label=val_y, + weight=val_w, + base_margin=val_b_m, **dmatrix_kwargs, ) return training_dmatrix, val_dmatrix else: train_X, train_y, train_w, train_b_m = train_val_data training_dmatrix = DMatrix( - data=train_X, label=train_y, weight=train_w, base_margin=train_b_m, + data=train_X, + label=train_y, + weight=train_w, + base_margin=train_b_m, **dmatrix_kwargs, ) return training_dmatrix diff --git a/python-package/xgboost/spark/model.py b/python-package/xgboost/spark/model.py index 6a59463edbc4..6564c8e4a65b 100644 --- a/python-package/xgboost/spark/model.py +++ b/python-package/xgboost/spark/model.py @@ -37,7 +37,7 @@ def serialize_xgb_model(model): # TODO: change to use string io tmp_file_name = os.path.join(_get_or_create_tmp_dir(), f"{uuid.uuid4()}.json") model.save_model(tmp_file_name) - with open(tmp_file_name, 'r') as f: + with open(tmp_file_name, "r") as f: ser_model_string = f.read() return ser_model_string @@ -174,6 +174,7 @@ class SparkXGBWriter(MLWriter): """ Spark Xgboost estimator writer. """ + def __init__(self, instance): super().__init__() self.instance = instance @@ -190,6 +191,7 @@ class SparkXGBReader(MLReader): """ Spark Xgboost estimator reader. """ + def __init__(self, cls): super().__init__() self.cls = cls @@ -209,6 +211,7 @@ class SparkXGBModelWriter(MLWriter): """ Spark Xgboost model writer. """ + def __init__(self, instance): super().__init__() self.instance = instance @@ -233,6 +236,7 @@ class SparkXGBModelReader(MLReader): """ Spark Xgboost model reader. """ + def __init__(self, cls): super().__init__() self.cls = cls From b92959e4ad03e665d42565226bf175eea5baeff6 Mon Sep 17 00:00:00 2001 From: fis Date: Thu, 7 Jul 2022 18:21:15 +0800 Subject: [PATCH 61/73] missing dependency. 
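
Context for this change: the integration serializes the trained booster with cloudpickle when collecting it from the Spark barrier job back to the driver (see the `cloudpickle.loads(...)` call in core.py above), so cloudpickle has to ship with the optional extra alongside pyspark and scikit-learn. With the extras declared here, installing everything the submodule needs would look roughly like:

    pip install "xgboost[pyspark]"
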
--- python-package/setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python-package/setup.py b/python-package/setup.py index 4f5ca6d74314..5e0a2ad7ecca 100644 --- a/python-package/setup.py +++ b/python-package/setup.py @@ -350,7 +350,7 @@ def run(self) -> None: 'dask': ['dask', 'pandas', 'distributed'], 'datatable': ['datatable'], 'plotting': ['graphviz', 'matplotlib'], - "pyspark": ["pyspark", "scikit-learn"], + "pyspark": ["pyspark", "scikit-learn", "cloudpickle"], }, maintainer='Hyunsu Cho', maintainer_email='chohyu01@cs.washington.edu', From b7bb690cd25b50a733a9a6125b455b30d222c54b Mon Sep 17 00:00:00 2001 From: fis Date: Thu, 7 Jul 2022 18:26:43 +0800 Subject: [PATCH 62/73] Rest of them. --- .../test_spark_with_gpu/test_spark_with_gpu.py | 16 ++++++++++++++-- tests/python/test_spark/data_test.py | 10 ++++++++++ tests/python/test_spark/utils_test.py | 8 ++++++++ tests/python/test_spark/xgboost_local_test.py | 9 +++++++++ 4 files changed, 41 insertions(+), 2 deletions(-) diff --git a/tests/python-gpu/test_spark_with_gpu/test_spark_with_gpu.py b/tests/python-gpu/test_spark_with_gpu/test_spark_with_gpu.py index 9d68751a266a..159189142216 100644 --- a/tests/python-gpu/test_spark_with_gpu/test_spark_with_gpu.py +++ b/tests/python-gpu/test_spark_with_gpu/test_spark_with_gpu.py @@ -1,9 +1,21 @@ -from pyspark.sql import SparkSession +import sys + import logging import pytest -from xgboost.spark import SparkXGBRegressor, SparkXGBClassifier import sklearn + +sys.path.append("tests/python") +import testing as tm + +if tm.no_dask()["condition"]: + pytest.skip(msg=tm.no_spark()["reason"], allow_module_level=True) +if sys.platform.startswith("win"): + pytest.skip("Skipping dask tests on Windows", allow_module_level=True) + + +from pyspark.sql import SparkSession from pyspark.ml.linalg import Vectors +from xgboost.spark import SparkXGBRegressor, SparkXGBClassifier @pytest.fixture(scope="module", autouse=True) diff --git a/tests/python/test_spark/data_test.py b/tests/python/test_spark/data_test.py index 9773a54092ab..1a9d8c1eb14e 100644 --- a/tests/python/test_spark/data_test.py +++ b/tests/python/test_spark/data_test.py @@ -1,8 +1,18 @@ +import sys import tempfile import shutil + +import pytest import numpy as np import pandas as pd +import testing as tm + +if tm.no_dask()["condition"]: + pytest.skip(msg=tm.no_spark()["reason"], allow_module_level=True) +if sys.platform.startswith("win"): + pytest.skip("Skipping dask tests on Windows", allow_module_level=True) + from xgboost.spark.data import ( _row_tuple_list_to_feature_matrix_y_w, _convert_partition_data_to_dmatrix, diff --git a/tests/python/test_spark/utils_test.py b/tests/python/test_spark/utils_test.py index 1b3add9e7116..bd50e953efee 100644 --- a/tests/python/test_spark/utils_test.py +++ b/tests/python/test_spark/utils_test.py @@ -5,9 +5,17 @@ import tempfile import unittest +import pytest from six import StringIO +import testing as tm + +if tm.no_dask()["condition"]: + pytest.skip(msg=tm.no_spark()["reason"], allow_module_level=True) +if sys.platform.startswith("win"): + pytest.skip("Skipping dask tests on Windows", allow_module_level=True) + from pyspark.sql import SQLContext from pyspark.sql import SparkSession diff --git a/tests/python/test_spark/xgboost_local_test.py b/tests/python/test_spark/xgboost_local_test.py index 422413cbc742..6ee924d08316 100644 --- a/tests/python/test_spark/xgboost_local_test.py +++ b/tests/python/test_spark/xgboost_local_test.py @@ -1,9 +1,18 @@ +import sys import logging import random 
import uuid import numpy as np import pytest + +import testing as tm + +if tm.no_dask()["condition"]: + pytest.skip(msg=tm.no_spark()["reason"], allow_module_level=True) +if sys.platform.startswith("win"): + pytest.skip("Skipping dask tests on Windows", allow_module_level=True) + from pyspark.ml.functions import vector_to_array from pyspark.sql import functions as spark_sql_func from pyspark.ml import Pipeline, PipelineModel From 888524743700486a75bbc0c88ad3e22dbc1e5a36 Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Thu, 7 Jul 2022 16:43:34 +0000 Subject: [PATCH 63/73] Fix message --- tests/python-gpu/test_spark_with_gpu/test_spark_with_gpu.py | 2 +- tests/python/test_spark/data_test.py | 2 +- tests/python/test_spark/utils_test.py | 2 +- tests/python/test_spark/xgboost_local_cluster_test.py | 2 +- tests/python/test_spark/xgboost_local_test.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/python-gpu/test_spark_with_gpu/test_spark_with_gpu.py b/tests/python-gpu/test_spark_with_gpu/test_spark_with_gpu.py index 159189142216..ab6faed2c41b 100644 --- a/tests/python-gpu/test_spark_with_gpu/test_spark_with_gpu.py +++ b/tests/python-gpu/test_spark_with_gpu/test_spark_with_gpu.py @@ -10,7 +10,7 @@ if tm.no_dask()["condition"]: pytest.skip(msg=tm.no_spark()["reason"], allow_module_level=True) if sys.platform.startswith("win"): - pytest.skip("Skipping dask tests on Windows", allow_module_level=True) + pytest.skip("Skipping PySpark tests on Windows", allow_module_level=True) from pyspark.sql import SparkSession diff --git a/tests/python/test_spark/data_test.py b/tests/python/test_spark/data_test.py index 1a9d8c1eb14e..594277dafb2f 100644 --- a/tests/python/test_spark/data_test.py +++ b/tests/python/test_spark/data_test.py @@ -11,7 +11,7 @@ if tm.no_dask()["condition"]: pytest.skip(msg=tm.no_spark()["reason"], allow_module_level=True) if sys.platform.startswith("win"): - pytest.skip("Skipping dask tests on Windows", allow_module_level=True) + pytest.skip("Skipping PySpark tests on Windows", allow_module_level=True) from xgboost.spark.data import ( _row_tuple_list_to_feature_matrix_y_w, diff --git a/tests/python/test_spark/utils_test.py b/tests/python/test_spark/utils_test.py index bd50e953efee..3998a07e1eb6 100644 --- a/tests/python/test_spark/utils_test.py +++ b/tests/python/test_spark/utils_test.py @@ -14,7 +14,7 @@ if tm.no_dask()["condition"]: pytest.skip(msg=tm.no_spark()["reason"], allow_module_level=True) if sys.platform.startswith("win"): - pytest.skip("Skipping dask tests on Windows", allow_module_level=True) + pytest.skip("Skipping PySpark tests on Windows", allow_module_level=True) from pyspark.sql import SQLContext from pyspark.sql import SparkSession diff --git a/tests/python/test_spark/xgboost_local_cluster_test.py b/tests/python/test_spark/xgboost_local_cluster_test.py index 4f47eb449090..71238c53a64b 100644 --- a/tests/python/test_spark/xgboost_local_cluster_test.py +++ b/tests/python/test_spark/xgboost_local_cluster_test.py @@ -11,7 +11,7 @@ if tm.no_dask()["condition"]: pytest.skip(msg=tm.no_spark()["reason"], allow_module_level=True) if sys.platform.startswith("win"): - pytest.skip("Skipping dask tests on Windows", allow_module_level=True) + pytest.skip("Skipping PySpark tests on Windows", allow_module_level=True) from .utils_test import SparkLocalClusterTestCase from xgboost.spark import SparkXGBClassifier, SparkXGBRegressor diff --git a/tests/python/test_spark/xgboost_local_test.py b/tests/python/test_spark/xgboost_local_test.py index 
6ee924d08316..91dfa8449d48 100644 --- a/tests/python/test_spark/xgboost_local_test.py +++ b/tests/python/test_spark/xgboost_local_test.py @@ -11,7 +11,7 @@ if tm.no_dask()["condition"]: pytest.skip(msg=tm.no_spark()["reason"], allow_module_level=True) if sys.platform.startswith("win"): - pytest.skip("Skipping dask tests on Windows", allow_module_level=True) + pytest.skip("Skipping PySpark tests on Windows", allow_module_level=True) from pyspark.ml.functions import vector_to_array from pyspark.sql import functions as spark_sql_func From 22f2103d32af45b03cbca2973e160f9b28ca4203 Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Fri, 8 Jul 2022 09:41:23 +0800 Subject: [PATCH 64/73] fix CI config --- .github/workflows/python_tests.yml | 3 ++- tests/ci_build/Dockerfile.cpu | 3 ++- tests/python/test_spark/data_test.py | 5 ----- 3 files changed, 4 insertions(+), 7 deletions(-) diff --git a/.github/workflows/python_tests.yml b/.github/workflows/python_tests.yml index cbb6dc1e9789..db814b902d7a 100644 --- a/.github/workflows/python_tests.yml +++ b/.github/workflows/python_tests.yml @@ -131,6 +131,7 @@ jobs: - name: Install Python package shell: bash -l {0} run: | + echo JAVA_HOME=$JAVA_HOME cd python-package python --version python setup.py bdist_wheel --universal @@ -139,4 +140,4 @@ jobs: - name: Test Python package shell: bash -l {0} run: | - pytest -s -v ./tests/python + PYSPARK_DRIVER_PYTHON=`which python` PYSPARK_PYTHON=`which python` SPARK_TESTING=1 pytest -s -v ./tests/python diff --git a/tests/ci_build/Dockerfile.cpu b/tests/ci_build/Dockerfile.cpu index 49346f7fcb14..2528018350f8 100644 --- a/tests/ci_build/Dockerfile.cpu +++ b/tests/ci_build/Dockerfile.cpu @@ -10,7 +10,7 @@ RUN \ apt-get install -y software-properties-common && \ add-apt-repository ppa:ubuntu-toolchain-r/test && \ apt-get update && \ - apt-get install -y tar unzip wget git build-essential doxygen graphviz llvm libasan2 libidn11 ninja-build gcc-8 g++-8 && \ + apt-get install -y tar unzip wget git build-essential doxygen graphviz llvm libasan2 libidn11 ninja-build gcc-8 g++-8 openjdk-8-jdk-headless && \ # CMake wget -nv -nc https://cmake.org/files/v3.14/cmake-3.14.0-Linux-x86_64.sh --no-check-certificate && \ bash cmake-3.14.0-Linux-x86_64.sh --skip-license --prefix=/usr && \ @@ -24,6 +24,7 @@ ENV CXX=g++-8 ENV CPP=cpp-8 ENV GOSU_VERSION 1.10 +ENV JAVA_HOME /usr/lib/jvm/java-8-openjdk-amd64/ # Create new Conda environment COPY conda_env/cpu_test.yml /scripts/ diff --git a/tests/python/test_spark/data_test.py b/tests/python/test_spark/data_test.py index 594277dafb2f..61bae2dcd4cc 100644 --- a/tests/python/test_spark/data_test.py +++ b/tests/python/test_spark/data_test.py @@ -32,8 +32,6 @@ def row_tup_iter(data): pdf = pd.DataFrame(data) yield pdf - # row1 = Vectors.dense(1.0, 2.0, 3.0),), - # row2 = Vectors.sparse(3, {1: 1.0, 2: 5.5}) expected_ndarray = np.array([[1.0, 2.0, 3.0], [0.0, 1.0, 5.5]]) data = {"values": [[1.0, 2.0, 3.0], [0.0, 1.0, 5.5]]} feature_matrix, y, w, _ = _row_tuple_list_to_feature_matrix_y_w( @@ -45,7 +43,6 @@ def row_tup_iter(data): ) self.assertIsNone(y) self.assertIsNone(w) - # self.assertTrue(isinstance(feature_matrix, csr_matrix)) self.assertTrue(np.allclose(feature_matrix, expected_ndarray)) data["label"] = [1, 0] @@ -57,7 +54,6 @@ def row_tup_iter(data): has_predict_base_margin=False, ) self.assertIsNone(w) - # self.assertTrue(isinstance(feature_matrix, csr_matrix)) self.assertTrue(np.allclose(feature_matrix, expected_ndarray)) self.assertTrue(np.array_equal(y, np.array(data["label"]))) @@ -69,7 
+65,6 @@ def row_tup_iter(data): has_fit_base_margin=False, has_predict_base_margin=False, ) - # self.assertTrue(isinstance(feature_matrix, csr_matrix)) self.assertTrue(np.allclose(feature_matrix, expected_ndarray)) self.assertTrue(np.array_equal(y, np.array(data["label"]))) self.assertTrue(np.array_equal(w, np.array(data["weight"]))) From 277984d7b209900185252854512cfe5f17034e58 Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Fri, 8 Jul 2022 10:26:17 +0800 Subject: [PATCH 65/73] setup python-test action jdk --- .github/workflows/python_tests.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/workflows/python_tests.yml b/.github/workflows/python_tests.yml index db814b902d7a..347401f750da 100644 --- a/.github/workflows/python_tests.yml +++ b/.github/workflows/python_tests.yml @@ -102,6 +102,10 @@ jobs: with: submodules: 'true' + - uses: actions/setup-java@v1 + with: + java-version: 1.8 + - uses: conda-incubator/setup-miniconda@v2 with: auto-update-conda: true From 0f9ffa0710291366b1ffc9fa371f24962eb69505 Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Fri, 8 Jul 2022 12:30:07 +0800 Subject: [PATCH 66/73] update macos ci config --- .github/workflows/python_tests.yml | 2 +- tests/python/test_spark/utils_test.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/python_tests.yml b/.github/workflows/python_tests.yml index 347401f750da..fe708a90a079 100644 --- a/.github/workflows/python_tests.yml +++ b/.github/workflows/python_tests.yml @@ -104,7 +104,7 @@ jobs: - uses: actions/setup-java@v1 with: - java-version: 1.8 + java-version: 1.11 - uses: conda-incubator/setup-miniconda@v2 with: diff --git a/tests/python/test_spark/utils_test.py b/tests/python/test_spark/utils_test.py index 3998a07e1eb6..ad3e376c4118 100644 --- a/tests/python/test_spark/utils_test.py +++ b/tests/python/test_spark/utils_test.py @@ -129,7 +129,7 @@ class SparkLocalClusterTestCase(TestSparkContext, TestTempDir, unittest.TestCase def setUpClass(cls): cls.setup_env( { - "spark.master": "local-cluster[2, 2, 1024]", + "spark.master": "local-cluster[2, 2, 512]", "spark.python.worker.reuse": "false", "spark.driver.host": "127.0.0.1", "spark.task.maxFailures": "1", From 3534c7517928be481870cafac33e9854a9df5c7d Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Fri, 8 Jul 2022 14:00:34 +0800 Subject: [PATCH 67/73] fix callback test --- .github/workflows/python_tests.yml | 1 + .../test_spark/{data_test.py => test_data.py} | 4 ++-- .../{xgboost_local_test.py => test_spark_local.py} | 13 +++++++------ ..._cluster_test.py => test_spark_local_cluster.py} | 4 ++-- tests/python/test_spark/{utils_test.py => utils.py} | 4 ++-- 5 files changed, 14 insertions(+), 12 deletions(-) rename tests/python/test_spark/{data_test.py => test_data.py} (98%) rename tests/python/test_spark/{xgboost_local_test.py => test_spark_local.py} (99%) rename tests/python/test_spark/{xgboost_local_cluster_test.py => test_spark_local_cluster.py} (99%) rename tests/python/test_spark/{utils_test.py => utils.py} (97%) diff --git a/.github/workflows/python_tests.yml b/.github/workflows/python_tests.yml index fe708a90a079..a970289b1c27 100644 --- a/.github/workflows/python_tests.yml +++ b/.github/workflows/python_tests.yml @@ -92,6 +92,7 @@ jobs: python-tests-on-macos: name: Test XGBoost Python package on ${{ matrix.config.os }} runs-on: ${{ matrix.config.os }} + timeout-minutes: 90 strategy: matrix: config: diff --git a/tests/python/test_spark/data_test.py b/tests/python/test_spark/test_data.py similarity index 98% rename 
From 3534c7517928be481870cafac33e9854a9df5c7d Mon Sep 17 00:00:00 2001
From: Weichen Xu
Date: Fri, 8 Jul 2022 14:00:34 +0800
Subject: [PATCH 67/73] fix callback test

---
 .github/workflows/python_tests.yml                   |  1 +
 .../test_spark/{data_test.py => test_data.py}        |  4 ++--
 .../{xgboost_local_test.py => test_spark_local.py}   | 13 +++++++------
 ..._cluster_test.py => test_spark_local_cluster.py}  |  4 ++--
 tests/python/test_spark/{utils_test.py => utils.py}  |  4 ++--
 5 files changed, 14 insertions(+), 12 deletions(-)
 rename tests/python/test_spark/{data_test.py => test_data.py} (98%)
 rename tests/python/test_spark/{xgboost_local_test.py => test_spark_local.py} (99%)
 rename tests/python/test_spark/{xgboost_local_cluster_test.py => test_spark_local_cluster.py} (99%)
 rename tests/python/test_spark/{utils_test.py => utils.py} (97%)

diff --git a/.github/workflows/python_tests.yml b/.github/workflows/python_tests.yml
index fe708a90a079..a970289b1c27 100644
--- a/.github/workflows/python_tests.yml
+++ b/.github/workflows/python_tests.yml
@@ -92,6 +92,7 @@ jobs:
   python-tests-on-macos:
     name: Test XGBoost Python package on ${{ matrix.config.os }}
     runs-on: ${{ matrix.config.os }}
+    timeout-minutes: 90
     strategy:
       matrix:
         config:
diff --git a/tests/python/test_spark/data_test.py b/tests/python/test_spark/test_data.py
similarity index 98%
rename from tests/python/test_spark/data_test.py
rename to tests/python/test_spark/test_data.py
index 61bae2dcd4cc..a5c8351e7af5 100644
--- a/tests/python/test_spark/data_test.py
+++ b/tests/python/test_spark/test_data.py
@@ -8,7 +8,7 @@
 
 import testing as tm
 
-if tm.no_dask()["condition"]:
+if tm.no_spark()["condition"]:
     pytest.skip(msg=tm.no_spark()["reason"], allow_module_level=True)
 if sys.platform.startswith("win"):
     pytest.skip("Skipping PySpark tests on Windows", allow_module_level=True)
@@ -20,7 +20,7 @@
 from xgboost import DMatrix, XGBClassifier
 from xgboost.training import train as worker_train
 
-from .utils_test import SparkTestCase
+from .utils import SparkTestCase
 import logging
 
 logging.getLogger("py4j").setLevel(logging.INFO)
diff --git a/tests/python/test_spark/xgboost_local_test.py b/tests/python/test_spark/test_spark_local.py
similarity index 99%
rename from tests/python/test_spark/xgboost_local_test.py
rename to tests/python/test_spark/test_spark_local.py
index 91dfa8449d48..a66d7b7d4e7c 100644
--- a/tests/python/test_spark/xgboost_local_test.py
+++ b/tests/python/test_spark/test_spark_local.py
@@ -8,7 +8,7 @@
 
 import testing as tm
 
-if tm.no_dask()["condition"]:
+if tm.no_spark()["condition"]:
     pytest.skip(msg=tm.no_spark()["reason"], allow_module_level=True)
 if sys.platform.startswith("win"):
     pytest.skip("Skipping PySpark tests on Windows", allow_module_level=True)
@@ -29,13 +29,17 @@
     SparkXGBRegressor,
     SparkXGBRegressorModel,
 )
-from .utils_test import SparkTestCase
+from .utils import SparkTestCase
 
 from xgboost import XGBClassifier, XGBRegressor
 from xgboost.spark.core import _non_booster_params
 
 logging.getLogger("py4j").setLevel(logging.INFO)
 
+def custom_learning_rate_callback(boosting_round):
+    return 1.0 / (boosting_round + 1)
+
+
 class XgboostLocalTest(SparkTestCase):
     def setUp(self):
         logging.getLogger().setLevel("INFO")
@@ -658,10 +662,7 @@ def test_callbacks(self):
 
         path = self.get_local_tmp_dir()
 
-        def custom_learning_rate(boosting_round):
-            return 1.0 / (boosting_round + 1)
-
-        cb = [LearningRateScheduler(custom_learning_rate)]
+        cb = [LearningRateScheduler(custom_learning_rate_callback)]
         regressor = SparkXGBRegressor(callbacks=cb)
 
         # Test the save/load of the estimator instead of the model, since
diff --git a/tests/python/test_spark/xgboost_local_cluster_test.py b/tests/python/test_spark/test_spark_local_cluster.py
similarity index 99%
rename from tests/python/test_spark/xgboost_local_cluster_test.py
rename to tests/python/test_spark/test_spark_local_cluster.py
index 71238c53a64b..4c8776a876d4 100644
--- a/tests/python/test_spark/xgboost_local_cluster_test.py
+++ b/tests/python/test_spark/test_spark_local_cluster.py
@@ -8,12 +8,12 @@
 import numpy as np
 import testing as tm
 
-if tm.no_dask()["condition"]:
+if tm.no_spark()["condition"]:
     pytest.skip(msg=tm.no_spark()["reason"], allow_module_level=True)
 if sys.platform.startswith("win"):
     pytest.skip("Skipping PySpark tests on Windows", allow_module_level=True)
 
-from .utils_test import SparkLocalClusterTestCase
+from .utils import SparkLocalClusterTestCase
 from xgboost.spark import SparkXGBClassifier, SparkXGBRegressor
 from xgboost.spark.utils import _get_max_num_concurrent_tasks
 from pyspark.ml.linalg import Vectors
diff --git a/tests/python/test_spark/utils_test.py b/tests/python/test_spark/utils.py
similarity index 97%
rename from tests/python/test_spark/utils_test.py
rename to tests/python/test_spark/utils.py
index ad3e376c4118..5f388c6af844 100644
--- a/tests/python/test_spark/utils_test.py
+++ b/tests/python/test_spark/utils.py
@@ -11,7 +11,7 @@
 
 import testing as tm
 
-if tm.no_dask()["condition"]:
+if tm.no_spark()["condition"]:
     pytest.skip(msg=tm.no_spark()["reason"], allow_module_level=True)
 if sys.platform.startswith("win"):
     pytest.skip("Skipping PySpark tests on Windows", allow_module_level=True)
@@ -129,7 +129,7 @@ class SparkLocalClusterTestCase(TestSparkContext, TestTempDir, unittest.TestCase
     def setUpClass(cls):
         cls.setup_env(
             {
-                "spark.master": "local-cluster[2, 2, 512]",
+                "spark.master": "local-cluster[2, 2, 1024]",
                 "spark.python.worker.reuse": "false",
                 "spark.driver.host": "127.0.0.1",
                 "spark.task.maxFailures": "1",
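The renamed test modules above now gate on tm.no_spark() instead of tm.no_dask() before importing any PySpark code. The real helper lives in tests/python/testing.py and is not shown in this series; a minimal sketch of the shape such a guard usually takes (the exact body in the repository may differ) is:

import importlib.util

def no_spark():
    # Consumed as pytest.skip(msg=..., allow_module_level=True) at import time.
    return {
        "condition": importlib.util.find_spec("pyspark") is None,
        "reason": "PySpark is not installed",
    }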
From 197c267f059f29a1ee5c467b9d0ff5728ddb0088 Mon Sep 17 00:00:00 2001
From: Weichen Xu
Date: Fri, 8 Jul 2022 14:26:08 +0800
Subject: [PATCH 68/73] disable spark test on macos

---
 .github/workflows/python_tests.yml                  |  7 +------
 tests/python/test_spark/test_data.py                |  2 +-
 tests/python/test_spark/test_spark_local.py         | 11 +++++------
 tests/python/test_spark/test_spark_local_cluster.py |  2 +-
 tests/python/test_spark/utils.py                    |  2 +-
 5 files changed, 9 insertions(+), 15 deletions(-)

diff --git a/.github/workflows/python_tests.yml b/.github/workflows/python_tests.yml
index a970289b1c27..38ca58b9a3c3 100644
--- a/.github/workflows/python_tests.yml
+++ b/.github/workflows/python_tests.yml
@@ -103,10 +103,6 @@ jobs:
       with:
         submodules: 'true'
 
-    - uses: actions/setup-java@v1
-      with:
-        java-version: 1.11
-
     - uses: conda-incubator/setup-miniconda@v2
       with:
         auto-update-conda: true
@@ -136,7 +132,6 @@ jobs:
     - name: Install Python package
       shell: bash -l {0}
       run: |
-        echo JAVA_HOME=$JAVA_HOME
         cd python-package
         python --version
         python setup.py bdist_wheel --universal
@@ -145,4 +140,4 @@ jobs:
     - name: Test Python package
       shell: bash -l {0}
       run: |
-        PYSPARK_DRIVER_PYTHON=`which python` PYSPARK_PYTHON=`which python` SPARK_TESTING=1 pytest -s -v ./tests/python
+        pytest -s -v ./tests/python
diff --git a/tests/python/test_spark/test_data.py b/tests/python/test_spark/test_data.py
index a5c8351e7af5..9b6aa1b72305 100644
--- a/tests/python/test_spark/test_data.py
+++ b/tests/python/test_spark/test_data.py
@@ -10,7 +10,7 @@
 if tm.no_spark()["condition"]:
     pytest.skip(msg=tm.no_spark()["reason"], allow_module_level=True)
-if sys.platform.startswith("win"):
+if sys.platform.startswith("win") or sys.platform.startswith("darwin"):
     pytest.skip("Skipping PySpark tests on Windows", allow_module_level=True)
 
 from xgboost.spark.data import (
diff --git a/tests/python/test_spark/test_spark_local.py b/tests/python/test_spark/test_spark_local.py
index a66d7b7d4e7c..f57e9b874d8a 100644
--- a/tests/python/test_spark/test_spark_local.py
+++ b/tests/python/test_spark/test_spark_local.py
@@ -10,7 +10,7 @@
 if tm.no_spark()["condition"]:
     pytest.skip(msg=tm.no_spark()["reason"], allow_module_level=True)
-if sys.platform.startswith("win"):
+if sys.platform.startswith("win") or sys.platform.startswith("darwin"):
     pytest.skip("Skipping PySpark tests on Windows", allow_module_level=True)
 
 from pyspark.ml.functions import vector_to_array
@@ -36,10 +36,6 @@
 
 logging.getLogger("py4j").setLevel(logging.INFO)
 
-def custom_learning_rate_callback(boosting_round):
-    return 1.0 / (boosting_round + 1)
-
-
 class XgboostLocalTest(SparkTestCase):
     def setUp(self):
         logging.getLogger().setLevel("INFO")
@@ -662,7 +658,10 @@ def test_callbacks(self):
 
         path = self.get_local_tmp_dir()
 
-        cb = [LearningRateScheduler(custom_learning_rate_callback)]
+        def custom_learning_rate(boosting_round):
+            return 1.0 / (boosting_round + 1)
+
+        cb = [LearningRateScheduler(custom_learning_rate)]
         regressor = SparkXGBRegressor(callbacks=cb)
 
         # Test the save/load of the estimator instead of the model, since
diff --git a/tests/python/test_spark/test_spark_local_cluster.py b/tests/python/test_spark/test_spark_local_cluster.py
index 4c8776a876d4..60448fde818b 100644
--- a/tests/python/test_spark/test_spark_local_cluster.py
+++ b/tests/python/test_spark/test_spark_local_cluster.py
@@ -10,7 +10,7 @@
 if tm.no_spark()["condition"]:
     pytest.skip(msg=tm.no_spark()["reason"], allow_module_level=True)
-if sys.platform.startswith("win"):
+if sys.platform.startswith("win") or sys.platform.startswith("darwin"):
     pytest.skip("Skipping PySpark tests on Windows", allow_module_level=True)
 
 from .utils import SparkLocalClusterTestCase
diff --git a/tests/python/test_spark/utils.py b/tests/python/test_spark/utils.py
index 5f388c6af844..df15a8af9393 100644
--- a/tests/python/test_spark/utils.py
+++ b/tests/python/test_spark/utils.py
@@ -13,7 +13,7 @@
 if tm.no_spark()["condition"]:
     pytest.skip(msg=tm.no_spark()["reason"], allow_module_level=True)
-if sys.platform.startswith("win"):
+if sys.platform.startswith("win") or sys.platform.startswith("darwin"):
     pytest.skip("Skipping PySpark tests on Windows", allow_module_level=True)
 
 from pyspark.sql import SQLContext
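Patch 67 hoisted the learning-rate callback to module scope and patch 68 moves it back into the test body; either way it is ordinary XGBoost callback machinery handed to the PySpark estimator. A small usage sketch follows; the features/label column names and the n_estimators value are illustrative assumptions, not taken from the tests.

from xgboost.callback import LearningRateScheduler
from xgboost.spark import SparkXGBRegressor

def custom_learning_rate(boosting_round):
    # Decay eta as 1 / (round + 1): 1.0, 0.5, 0.333, ...
    return 1.0 / (boosting_round + 1)

regressor = SparkXGBRegressor(
    features_col="features",
    label_col="label",
    n_estimators=10,
    callbacks=[LearningRateScheduler(custom_learning_rate)],
)
# model = regressor.fit(train_df)  # train_df: a Spark DataFrame with the two columns above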
From ee78b562f1662e4da5e4fc20ac4813b587785c0d Mon Sep 17 00:00:00 2001
From: Weichen Xu
Date: Fri, 8 Jul 2022 23:18:43 +0800
Subject: [PATCH 69/73] remove disable lint

---
 python-package/xgboost/spark/core.py      | 5 -----
 python-package/xgboost/spark/data.py      | 2 --
 python-package/xgboost/spark/estimator.py | 1 -
 python-package/xgboost/spark/model.py     | 3 ---
 python-package/xgboost/spark/params.py    | 1 -
 python-package/xgboost/spark/utils.py     | 2 --
 tests/python/test_spark/utils.py          | 2 --
 7 files changed, 16 deletions(-)

diff --git a/python-package/xgboost/spark/core.py b/python-package/xgboost/spark/core.py
index bc812cbf89c7..324c3c266f23 100644
--- a/python-package/xgboost/spark/core.py
+++ b/python-package/xgboost/spark/core.py
@@ -1,10 +1,5 @@
 # type: ignore
 """Xgboost pyspark integration submodule for core code."""
-# pylint: disable=import-error, consider-using-f-string, too-many-arguments, too-many-locals
-# pylint: disable=invalid-name, fixme, too-many-lines, unbalanced-tuple-unpacking, no-else-return
-# pylint: disable=protected-access, logging-fstring-interpolation, no-name-in-module
-# pylint: disable=wrong-import-order, ungrouped-imports, too-few-public-methods, broad-except
-# pylint: disable=too-many-statements
 import numpy as np
 import pandas as pd
 from scipy.special import expit, softmax
diff --git a/python-package/xgboost/spark/data.py b/python-package/xgboost/spark/data.py
index 47fb541dc60a..ac8d6e9e3b47 100644
--- a/python-package/xgboost/spark/data.py
+++ b/python-package/xgboost/spark/data.py
@@ -1,7 +1,5 @@
 # type: ignore
 """Xgboost pyspark integration submodule for data related functions."""
-# pylint: disable=import-error, consider-using-f-string, too-many-arguments, too-many-locals,
-# pylint: disable=invalid-name, fixme, too-many-lines, unbalanced-tuple-unpacking, no-else-return
 from typing import Iterator
 import numpy as np
 import pandas as pd
diff --git a/python-package/xgboost/spark/estimator.py b/python-package/xgboost/spark/estimator.py
index 92d5cfd52c2c..4ac24cb5f6bd 100644
--- a/python-package/xgboost/spark/estimator.py
+++ b/python-package/xgboost/spark/estimator.py
@@ -1,6 +1,5 @@
 # type: ignore
 """Xgboost pyspark integration submodule for estimator API."""
-# pylint: disable=import-error
 from pyspark.ml.param.shared import HasProbabilityCol, HasRawPredictionCol
 from xgboost import XGBClassifier, XGBRegressor
 from .core import (
diff --git a/python-package/xgboost/spark/model.py b/python-package/xgboost/spark/model.py
index 6564c8e4a65b..eb70b4899de5 100644
--- a/python-package/xgboost/spark/model.py
+++ b/python-package/xgboost/spark/model.py
@@ -1,8 +1,5 @@
 # type: ignore
 """Xgboost pyspark integration submodule for model API."""
-# pylint: disable=import-error, consider-using-f-string, unspecified-encoding,
-# pylint: disable=invalid-name, fixme, unnecessary-lambda
-# pylint: disable=protected-access, too-few-public-methods
 import base64
 import os
 import uuid
diff --git a/python-package/xgboost/spark/params.py b/python-package/xgboost/spark/params.py
index 27edd5c8f4ae..ff5f651a3dfa 100644
--- a/python-package/xgboost/spark/params.py
+++ b/python-package/xgboost/spark/params.py
@@ -1,6 +1,5 @@
 # type: ignore
 """Xgboost pyspark integration submodule for params."""
-# pylint: disable=import-error, too-few-public-methods
 from pyspark.ml.param.shared import Param, Params
diff --git a/python-package/xgboost/spark/utils.py b/python-package/xgboost/spark/utils.py
index aa76657dd270..b324f60dd21b 100644
--- a/python-package/xgboost/spark/utils.py
+++ b/python-package/xgboost/spark/utils.py
@@ -1,7 +1,5 @@
 # type: ignore
 """Xgboost pyspark integration submodule for helper functions."""
-# pylint: disable=import-error, consider-using-f-string, protected-access, wrong-import-order
-# pylint: disable=invalid-name
 import inspect
 from threading import Thread
 import sys
diff --git a/tests/python/test_spark/utils.py b/tests/python/test_spark/utils.py
index df15a8af9393..549aadf5e6bd 100644
--- a/tests/python/test_spark/utils.py
+++ b/tests/python/test_spark/utils.py
@@ -91,7 +91,6 @@ def setup_env(cls, spark_config):
         logging.getLogger("pyspark").setLevel(logging.INFO)
 
         cls.sc = spark.sparkContext
-        cls.sql = SQLContext(cls.sc)
         cls.session = spark
 
     @classmethod
@@ -100,7 +99,6 @@ def tear_down_env(cls):
         cls.session = None
         cls.sc.stop()
         cls.sc = None
-        cls.sql = None

From 42d0ce58cec5ed938b298a3ff94411b2a5a6b3d3 Mon Sep 17 00:00:00 2001
From: Weichen Xu
Date: Sun, 10 Jul 2022 16:59:40 +0800
Subject: [PATCH 70/73] address pylint errors

---
 python-package/xgboost/spark/core.py      | 56 ++++++++++++-----------
 python-package/xgboost/spark/data.py      | 51 +++++++--------------
 python-package/xgboost/spark/estimator.py |  1 +
 python-package/xgboost/spark/model.py     | 37 ++++++++-------
 python-package/xgboost/spark/params.py    |  4 +-
 python-package/xgboost/spark/utils.py     | 17 ++++---
 6 files changed, 81 insertions(+), 85 deletions(-)

diff --git a/python-package/xgboost/spark/core.py b/python-package/xgboost/spark/core.py
index 324c3c266f23..b582639ca3f3 100644
--- a/python-package/xgboost/spark/core.py
+++ b/python-package/xgboost/spark/core.py
@@ -1,9 +1,14 @@
 # type: ignore
 """Xgboost pyspark integration submodule for core code."""
+# pylint: disable=fixme, too-many-ancestors, protected-access, no-member
+import cloudpickle
 import numpy as np
 import pandas as pd
-from scipy.special import expit, softmax
+from scipy.special import expit, softmax  # pylint: disable=no-name-in-module
+
+from pyspark.ml.functions import array_to_vector, vector_to_array
 from pyspark.ml import Estimator, Model
+from pyspark.ml.linalg import VectorUDT
 from pyspark.ml.param.shared import (
     HasFeaturesCol,
     HasLabelCol,
@@ -16,12 +21,20 @@
 from pyspark.ml.param import Param, Params, TypeConverters
 from pyspark.ml.util import MLReadable, MLWritable
 from pyspark.sql.functions import col, pandas_udf, countDistinct, struct
+from pyspark.sql.types import (
+    ArrayType,
+    DoubleType,
+    FloatType,
+    IntegerType,
+    LongType,
+    ShortType,
+)
+
+import xgboost
 from xgboost import XGBClassifier, XGBRegressor
 from xgboost.core import Booster
-import cloudpickle
-import xgboost
 from xgboost.training import train as worker_train
-from .utils import get_logger, _get_max_num_concurrent_tasks
+
 from .data import (
     _convert_partition_data_to_dmatrix,
 )
@@ -32,6 +45,7 @@
     SparkXGBModelWriter,
 )
 from .utils import (
+    get_logger, _get_max_num_concurrent_tasks,
     _get_default_params_from_func, get_class_name,
     RabitContext,
@@ -44,17 +58,6 @@
     HasBaseMarginCol,
 )
 
-from pyspark.ml.functions import array_to_vector, vector_to_array
-from pyspark.sql.types import (
-    ArrayType,
-    DoubleType,
-    FloatType,
-    IntegerType,
-    LongType,
-    ShortType,
-)
-from pyspark.ml.linalg import VectorUDT
-
 # Put pyspark specific params here, they won't be passed to XGBoost.
 # like `validationIndicatorCol`, `base_margin_col`
 _pyspark_specific_params = [
@@ -116,7 +119,7 @@
 }
 
 
-class _XgboostParams(
+class _SparkXGBParams(
     HasFeaturesCol,
     HasLabelCol,
     HasWeightCol,
@@ -286,8 +289,9 @@ def _validate_params(self):
 
             if int(gpu_per_task) > 1:
                 get_logger(self.__class__.__name__).warning(
-                    f"You configured {gpu_per_task} GPU cores for each spark task, but in "
-                    f"XGBoost training, every Spark task will only use one GPU core."
+                    "You configured %s GPU cores for each spark task, but in "
+                    "XGBoost training, every Spark task will only use one GPU core.",
+                    gpu_per_task
                 )
 
 
@@ -317,7 +321,7 @@ def _validate_and_convert_feature_col_as_array_col(dataset, features_col_name):
     return features_array_col
 
 
-class _SparkXGBEstimator(Estimator, _XgboostParams, MLReadable, MLWritable):
+class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
     def __init__(self):
         super().__init__()
         self._set_xgb_params_default()
@@ -335,7 +339,7 @@ def __init__(self):
             arbitrary_params_dict={},
         )
 
-    def setParams(self, **kwargs):
+    def setParams(self, **kwargs):  # pylint: disable=invalid-name
        """
         Set params for the estimator.
         """
@@ -373,7 +377,7 @@ def setParams(self, **kwargs):
     def _pyspark_model_cls(cls):
         """
         Subclasses should override this method and
-        returns a _XgboostModel subclass
+        returns a _SparkXGBModel subclass
         """
         raise NotImplementedError()
 
@@ -472,6 +476,7 @@ def _get_xgb_train_call_args(cls, train_params):
         return booster_params, kwargs_params
 
     def _fit(self, dataset):
+        # pylint: disable=too-many-statements
         self._validate_params()
         label_col = col(self.getOrDefault(self.labelCol)).alias("label")
 
@@ -621,7 +626,7 @@ def read(cls):
         return SparkXGBReader(cls)
 
 
-class _SparkXGBModel(Model, _XgboostParams, MLReadable, MLWritable):
+class _SparkXGBModel(Model, _SparkXGBParams, MLReadable, MLWritable):
     def __init__(self, xgb_sklearn_model=None):
         super().__init__()
         self._xgb_sklearn_model = xgb_sklearn_model
@@ -829,12 +834,11 @@ def param_value_converter(v):
         if isinstance(v, np.generic):
             # convert numpy scalar values to corresponding python scalar values
             return np.array(v).item()
-        elif isinstance(v, dict):
+        if isinstance(v, dict):
             return {k: param_value_converter(nv) for k, nv in v.items()}
-        elif isinstance(v, list):
+        if isinstance(v, list):
             return [param_value_converter(nv) for nv in v]
-        else:
-            return v
+        return v
 
     def set_param_attrs(attr_name, param_obj_):
         param_obj_.typeConverter = param_value_converter
diff --git a/python-package/xgboost/spark/data.py b/python-package/xgboost/spark/data.py
index ac8d6e9e3b47..3ca8b88df3db 100644
--- a/python-package/xgboost/spark/data.py
+++ b/python-package/xgboost/spark/data.py
@@ -31,12 +31,12 @@ def _check_feature_dims(num_dims, expected_dims):
         return num_dims
     if num_dims != expected_dims:
         raise ValueError(
-            "Rows contain different feature dimensions: "
-            "Expecting {}, got {}.".format(expected_dims, num_dims)
+            f"Rows contain different feature dimensions: Expecting {expected_dims}, got {num_dims}."
         )
     return expected_dims
 
 
+# pylint: disable=too-many-arguments
 def _row_tuple_list_to_feature_matrix_y_w(
     data_iterator,
     train,
@@ -54,6 +54,7 @@ def _row_tuple_list_to_feature_matrix_y_w(
     Note: the row_tuple_list will be cleared during
     executing for reducing peak memory consumption
     """
+    # pylint: disable=too-many-locals
     expected_feature_dims = None
     label_list, weight_list, base_margin_list = [], [], []
     label_val_list, weight_val_list, base_margin_val_list = [], [], []
@@ -114,6 +115,7 @@ def _row_tuple_list_to_feature_matrix_y_w(
         )
         return feature_matrix, y, w, b_m, feature_matrix_val, y_val, w_val, b_m_val
     return feature_matrix, y, w, b_m
+# pylint: enable=too-many-arguments
 
 
 def _process_data_iter(
@@ -130,26 +132,6 @@ def _process_data_iter(
     train base margin, val_X, val_y, val_w, val_b_m <- validation base margin)
     otherwise return (X, y, w, b_m <- base margin)
     """
-    if train and has_validation:
-        (
-            train_X,
-            train_y,
-            train_w,
-            train_b_m,
-            val_X,
-            val_y,
-            val_w,
-            val_b_m,
-        ) = _row_tuple_list_to_feature_matrix_y_w(
-            data_iterator,
-            train,
-            has_weight,
-            has_fit_base_margin,
-            has_predict_base_margin,
-            has_validation,
-        )
-        return train_X, train_y, train_w, train_b_m, val_X, val_y, val_w, val_b_m
-
     return _row_tuple_list_to_feature_matrix_y_w(
         data_iterator,
         train,
@@ -167,6 +149,7 @@ def _convert_partition_data_to_dmatrix(
     has_base_margin,
     dmatrix_kwargs=None,
 ):
+    # pylint: disable=too-many-locals
     dmatrix_kwargs = dmatrix_kwargs or {}
     # if we are not using external storage, we use the standard method of parsing data.
     train_val_data = _prepare_train_val_data(
     )
     if has_validation:
         (
-            train_X,
+            train_x,
             train_y,
             train_w,
             train_b_m,
@@ -184,7 +167,7 @@ def _convert_partition_data_to_dmatrix(
             val_b_m,
         ) = train_val_data
         training_dmatrix = DMatrix(
-            data=train_X,
+            data=train_x,
             label=train_y,
             weight=train_w,
             base_margin=train_b_m,
@@ -198,13 +181,13 @@ def _convert_partition_data_to_dmatrix(
             **dmatrix_kwargs,
         )
         return training_dmatrix, val_dmatrix
-    else:
-        train_X, train_y, train_w, train_b_m = train_val_data
-        training_dmatrix = DMatrix(
-            data=train_X,
-            label=train_y,
-            weight=train_w,
-            base_margin=train_b_m,
-            **dmatrix_kwargs,
-        )
-        return training_dmatrix
+
+    train_x, train_y, train_w, train_b_m = train_val_data
+    training_dmatrix = DMatrix(
+        data=train_x,
+        label=train_y,
+        weight=train_w,
+        base_margin=train_b_m,
+        **dmatrix_kwargs,
+    )
+    return training_dmatrix
diff --git a/python-package/xgboost/spark/estimator.py b/python-package/xgboost/spark/estimator.py
index 4ac24cb5f6bd..3f50ab2bf2b9 100644
--- a/python-package/xgboost/spark/estimator.py
+++ b/python-package/xgboost/spark/estimator.py
@@ -1,5 +1,6 @@
 # type: ignore
 """Xgboost pyspark integration submodule for estimator API."""
+# pylint: disable=too-many-ancestors
 from pyspark.ml.param.shared import HasProbabilityCol, HasRawPredictionCol
 from xgboost import XGBClassifier, XGBRegressor
 from .core import (
diff --git a/python-package/xgboost/spark/model.py b/python-package/xgboost/spark/model.py
index eb70b4899de5..506fdb1ba7e3 100644
--- a/python-package/xgboost/spark/model.py
+++ b/python-package/xgboost/spark/model.py
@@ -1,5 +1,6 @@
 # type: ignore
 """Xgboost pyspark integration submodule for model API."""
+# pylint: disable=fixme, invalid-name, protected-access
 import base64
 import os
 import uuid
@@ -34,7 +35,7 @@ def serialize_xgb_model(model):
     # TODO: change to use string io
     tmp_file_name = os.path.join(_get_or_create_tmp_dir(), f"{uuid.uuid4()}.json")
     model.save_model(tmp_file_name)
-    with open(tmp_file_name, "r") as f:
+    with open(tmp_file_name, "r", encoding="utf-8") as f:
         ser_model_string = f.read()
     return ser_model_string
 
@@ -46,7 +47,7 @@ def deserialize_xgb_model(ser_model_string, xgb_model_creator):
     xgb_model = xgb_model_creator()
     # TODO: change to use string io
     tmp_file_name = os.path.join(_get_or_create_tmp_dir(), f"{uuid.uuid4()}.json")
-    with open(tmp_file_name, "w") as f:
+    with open(tmp_file_name, "w", encoding="utf-8") as f:
         f.write(ser_model_string)
     xgb_model.load_model(tmp_file_name)
     return xgb_model
@@ -64,7 +65,7 @@ def serialize_booster(booster):
     # TODO: change to use string io
     tmp_file_name = os.path.join(_get_or_create_tmp_dir(), f"{uuid.uuid4()}.json")
     booster.save_model(tmp_file_name)
-    with open(tmp_file_name) as f:
+    with open(tmp_file_name, encoding="utf-8") as f:
         ser_model_string = f.read()
     return ser_model_string
 
@@ -76,7 +77,7 @@ def deserialize_booster(ser_model_string):
     booster = Booster()
     # TODO: change to use string io
     tmp_file_name = os.path.join(_get_or_create_tmp_dir(), f"{uuid.uuid4()}.json")
-    with open(tmp_file_name, "w") as f:
+    with open(tmp_file_name, "w", encoding="utf-8") as f:
         f.write(ser_model_string)
     booster.load_model(tmp_file_name)
     return booster
@@ -93,13 +94,13 @@ class _SparkXGBSharedReadWrite:
     @staticmethod
     def saveMetadata(instance, path, sc, logger, extraMetadata=None):
         """
-        Save the metadata of an xgboost.spark._XgboostEstimator or
-        xgboost.spark._XgboostModel.
+        Save the metadata of an xgboost.spark._SparkXGBEstimator or
+        xgboost.spark._SparkXGBModel.
         """
         instance._validate_params()
         skipParams = ["callbacks", "xgb_model"]
         jsonParams = {}
-        for p, v in instance._paramMap.items():
+        for p, v in instance._paramMap.items():  # pylint: disable=protected-access
             if p.name not in skipParams:
                 jsonParams[p.name] = v
 
@@ -131,8 +132,8 @@ def saveMetadata(instance, path, sc, logger, extraMetadata=None):
     @staticmethod
     def loadMetadataAndInstance(pyspark_xgb_cls, path, sc, logger):
         """
-        Load the metadata and the instance of an xgboost.spark._XgboostEstimator or
-        xgboost.spark._XgboostModel.
+        Load the metadata and the instance of an xgboost.spark._SparkXGBEstimator or
+        xgboost.spark._SparkXGBModel.
 
         :return: a tuple of (metadata, instance)
         """
@@ -151,8 +152,8 @@ def loadMetadataAndInstance(pyspark_xgb_cls, path, sc, logger):
                 pyspark_xgb.set(pyspark_xgb.callbacks, callbacks)
             except Exception as e:  # pylint: disable=W0703
                 logger.warning(
-                    "Fails to load the callbacks param due to {}. Please set the "
-                    "callbacks param manually for the loaded estimator.".format(e)
+                    f"Fails to load the callbacks param due to {e}. Please set the "
+                    "callbacks param manually for the loaded estimator."
                 )
 
         if "init_booster" in metadata:
@@ -163,7 +164,7 @@ def loadMetadataAndInstance(pyspark_xgb_cls, path, sc, logger):
             init_booster = deserialize_booster(ser_init_booster)
             pyspark_xgb.set(pyspark_xgb.xgb_model, init_booster)
 
-        pyspark_xgb._resetUid(metadata["uid"])
+        pyspark_xgb._resetUid(metadata["uid"])  # pylint: disable=protected-access
         return metadata, pyspark_xgb
 
 
@@ -216,7 +217,7 @@ def __init__(self, instance):
 
     def saveImpl(self, path):
         """
-        Save metadata and model for a :py:class:`_XgboostModel`
+        Save metadata and model for a :py:class:`_SparkXGBModel`
         - save metadata to path/metadata
         - save model to path/model.json
         """
@@ -241,7 +242,7 @@ def __init__(self, cls):
 
     def load(self, path):
         """
-        Load metadata and model for a :py:class:`_XgboostModel`
+        Load metadata and model for a :py:class:`_SparkXGBModel`
 
         :return: SparkXGBRegressorModel or SparkXGBClassifierModel instance
         """
@@ -249,7 +250,7 @@ def load(self, path):
             self.cls, path, self.sc, self.logger
         )
 
-        xgb_params = py_model._gen_xgb_params_dict()
+        xgb_sklearn_params = py_model._gen_xgb_params_dict(gen_xgb_sklearn_estimator_param=True)
         model_load_path = os.path.join(path, "model.json")
 
         ser_xgb_model = (
@@ -258,8 +259,12 @@ def load(self, path):
             .collect()[0]
             .xgb_sklearn_model
         )
+
+        def create_xgb_model():
+            return lambda: self.cls._xgb_cls()(**xgb_sklearn_params)
+
         xgb_model = deserialize_xgb_model(
-            ser_xgb_model, lambda: self.cls._xgb_cls()(**xgb_params)
+            ser_xgb_model, create_xgb_model
         )
         py_model._xgb_sklearn_model = xgb_model
         return py_model
diff --git a/python-package/xgboost/spark/params.py b/python-package/xgboost/spark/params.py
index ff5f651a3dfa..cd12c2e24f87 100644
--- a/python-package/xgboost/spark/params.py
+++ b/python-package/xgboost/spark/params.py
@@ -5,7 +5,7 @@
 
 class HasArbitraryParamsDict(Params):
     """
-    This is a Params based class that is extended by _XGBoostParams
+    This is a Params based class that is extended by _SparkXGBParams
     and holds the variable to store the **kwargs parts of the XGBoost
     input.
     """
@@ -21,7 +21,7 @@ class HasBaseMarginCol(Params):
     """
-    This is a Params based class that is extended by _XGBoostParams
+    This is a Params based class that is extended by _SparkXGBParams
     and holds the variable to store the base margin column part of XGboost.
     """
diff --git a/python-package/xgboost/spark/utils.py b/python-package/xgboost/spark/utils.py
index b324f60dd21b..b358e9be5db9 100644
--- a/python-package/xgboost/spark/utils.py
+++ b/python-package/xgboost/spark/utils.py
@@ -5,11 +5,12 @@
 import sys
 import logging
 
-from xgboost import rabit
-from xgboost.tracker import RabitTracker
 import pyspark
 from pyspark.sql.session import SparkSession
 
+from xgboost import rabit
+from xgboost.tracker import RabitTracker
+
 
 def get_class_name(cls):
     """
@@ -71,6 +72,7 @@ def _get_rabit_args(context, n_workers):
     """
     Get rabit context arguments to send to each worker.
     """
+    # pylint: disable=consider-using-f-string
     env = _start_tracker(context, n_workers)
     rabit_args = [("%s=%s" % item).encode() for item in env.items()]
     return rabit_args
@@ -117,11 +119,12 @@ def get_logger(name, level="INFO"):
     return logger
 
 
-def _get_max_num_concurrent_tasks(sc):
+def _get_max_num_concurrent_tasks(spark_context):
     """Gets the current max number of concurrent tasks."""
+    # pylint: disable=protected-access
     # spark 3.1 and above has a different API for fetching max concurrent tasks
-    if sc._jsc.sc().version() >= "3.1":
-        return sc._jsc.sc().maxNumConcurrentTasks(
-            sc._jsc.sc().resourceProfileManager().resourceProfileFromId(0)
+    if spark_context._jsc.sc().version() >= "3.1":
+        return spark_context._jsc.sc().maxNumConcurrentTasks(
+            spark_context._jsc.sc().resourceProfileManager().resourceProfileFromId(0)
         )
-    return sc._jsc.sc().maxNumConcurrentTasks()
+    return spark_context._jsc.sc().maxNumConcurrentTasks()

From dc02023247601fad4de3a0938ff4e843535c8b52 Mon Sep 17 00:00:00 2001
From: Weichen Xu
Date: Sun, 10 Jul 2022 17:12:27 +0800
Subject: [PATCH 71/73] address pylint errors

---
 python-package/xgboost/spark/core.py  | 15 ++++++++-------
 python-package/xgboost/spark/data.py  |  9 ++++-----
 python-package/xgboost/spark/model.py |  2 +-
 3 files changed, 13 insertions(+), 13 deletions(-)

diff --git a/python-package/xgboost/spark/core.py b/python-package/xgboost/spark/core.py
index b582639ca3f3..cf440082d930 100644
--- a/python-package/xgboost/spark/core.py
+++ b/python-package/xgboost/spark/core.py
@@ -1,6 +1,6 @@
 # type: ignore
 """Xgboost pyspark integration submodule for core code."""
-# pylint: disable=fixme, too-many-ancestors, protected-access, no-member
+# pylint: disable=fixme, too-many-ancestors, protected-access, no-member, invalid-name
 import cloudpickle
 import numpy as np
 import pandas as pd
@@ -432,7 +432,7 @@ def _repartition_needed(self, dataset):
         try:
             if self._query_plan_contains_valid_repartition(dataset):
                 return False
-        except Exception:  # noqa: E722
+        except Exception:  # pylint: disable=broad-except
             pass
         return True
 
@@ -476,7 +476,7 @@ def _get_xgb_train_call_args(cls, train_params):
         return booster_params, kwargs_params
 
     def _fit(self, dataset):
-        # pylint: disable=too-many-statements
+        # pylint: disable=too-many-statements, too-many-locals
         self._validate_params()
         label_col = col(self.getOrDefault(self.labelCol)).alias("label")
 
@@ -519,10 +519,11 @@ def _fit(self, dataset):
 
         if num_workers > max_concurrent_tasks:
             get_logger(self.__class__.__name__).warning(
-                f"The num_workers {num_workers} set for xgboost distributed "
-                f"training is greater than current max number of concurrent "
-                f"spark task slots, you need wait until more task slots available "
-                f"or you need increase spark cluster workers."
+                "The num_workers %s set for xgboost distributed "
+                "training is greater than current max number of concurrent "
+                "spark task slots, you need wait until more task slots available "
+                "or you need increase spark cluster workers.",
+                num_workers
             )
diff --git a/python-package/xgboost/spark/data.py b/python-package/xgboost/spark/data.py
index 3ca8b88df3db..16fe038edf15 100644
--- a/python-package/xgboost/spark/data.py
+++ b/python-package/xgboost/spark/data.py
@@ -1,5 +1,6 @@
 # type: ignore
 """Xgboost pyspark integration submodule for data related functions."""
+# pylint: disable=too-many-arguments
 from typing import Iterator
 import numpy as np
 import pandas as pd
@@ -36,7 +37,6 @@ def _check_feature_dims(num_dims, expected_dims):
     return expected_dims
 
 
-# pylint: disable=too-many-arguments
 def _row_tuple_list_to_feature_matrix_y_w(
     data_iterator,
     train,
@@ -115,7 +115,6 @@ def _row_tuple_list_to_feature_matrix_y_w(
         )
         return feature_matrix, y, w, b_m, feature_matrix_val, y_val, w_val, b_m_val
     return feature_matrix, y, w, b_m
-# pylint: enable=too-many-arguments
 
 
 def _process_data_iter(
@@ -148,7 +148,7 @@ def _convert_partition_data_to_dmatrix(
     has_base_margin,
     dmatrix_kwargs=None,
 ):
-    # pylint: disable=too-many-locals
+    # pylint: disable=too-many-locals, unbalanced-tuple-unpacking
     dmatrix_kwargs = dmatrix_kwargs or {}
     # if we are not using external storage, we use the standard method of parsing data.
     train_val_data = _prepare_train_val_data(
@@ -161,7 +160,7 @@ def _convert_partition_data_to_dmatrix(
             train_y,
             train_w,
             train_b_m,
-            val_X,
+            val_x,
             val_y,
             val_w,
             val_b_m,
@@ -174,7 +173,7 @@ def _convert_partition_data_to_dmatrix(
             **dmatrix_kwargs,
         )
         val_dmatrix = DMatrix(
-            data=val_X,
+            data=val_x,
             label=val_y,
             weight=val_w,
             base_margin=val_b_m,
diff --git a/python-package/xgboost/spark/model.py b/python-package/xgboost/spark/model.py
index 506fdb1ba7e3..c829d9717d08 100644
--- a/python-package/xgboost/spark/model.py
+++ b/python-package/xgboost/spark/model.py
@@ -261,7 +261,7 @@ def load(self, path):
         )
 
         def create_xgb_model():
-            return lambda: self.cls._xgb_cls()(**xgb_sklearn_params)
+            return self.cls._xgb_cls()(**xgb_sklearn_params)
 
         xgb_model = deserialize_xgb_model(
             ser_xgb_model, create_xgb_model
From 520d754648339431eff5fbd75e0cc9d046bcb79d Mon Sep 17 00:00:00 2001
From: Weichen Xu
Date: Sun, 10 Jul 2022 23:39:05 +0800
Subject: [PATCH 72/73] address pylint issues

---
 .github/workflows/main.yml             | 2 +-
 python-package/xgboost/spark/core.py   | 1 +
 python-package/xgboost/spark/model.py  | 2 +-
 python-package/xgboost/spark/params.py | 1 +
 4 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index d7ec12c78cd0..5d2f66988bd1 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -141,7 +141,7 @@ jobs:
     - name: Install Python packages
       run: |
         python -m pip install wheel setuptools
-        python -m pip install pylint cpplint numpy scipy scikit-learn
+        python -m pip install pylint cpplint numpy scipy scikit-learn pyspark
     - name: Run lint
       run: |
         make lint
diff --git a/python-package/xgboost/spark/core.py b/python-package/xgboost/spark/core.py
index cf440082d930..68a15a534f33 100644
--- a/python-package/xgboost/spark/core.py
+++ b/python-package/xgboost/spark/core.py
@@ -1,6 +1,7 @@
 # type: ignore
 """Xgboost pyspark integration submodule for core code."""
 # pylint: disable=fixme, too-many-ancestors, protected-access, no-member, invalid-name
+# pylint: disable=too-few-public-methods
 import cloudpickle
 import numpy as np
 import pandas as pd
diff --git a/python-package/xgboost/spark/model.py b/python-package/xgboost/spark/model.py
index c829d9717d08..4573d91df098 100644
--- a/python-package/xgboost/spark/model.py
+++ b/python-package/xgboost/spark/model.py
@@ -1,6 +1,6 @@
 # type: ignore
 """Xgboost pyspark integration submodule for model API."""
-# pylint: disable=fixme, invalid-name, protected-access
+# pylint: disable=fixme, invalid-name, protected-access, too-few-public-methods
 import base64
 import os
 import uuid
diff --git a/python-package/xgboost/spark/params.py b/python-package/xgboost/spark/params.py
index cd12c2e24f87..9528eb69dd70 100644
--- a/python-package/xgboost/spark/params.py
+++ b/python-package/xgboost/spark/params.py
@@ -1,5 +1,6 @@
 # type: ignore
 """Xgboost pyspark integration submodule for params."""
+# pylint: disable=too-few-public-methods
 from pyspark.ml.param.shared import Param, Params

From 5237d81150a77c100180a909d577bf09698e83b6 Mon Sep 17 00:00:00 2001
From: Weichen Xu
Date: Mon, 11 Jul 2022 10:18:38 +0800
Subject: [PATCH 73/73] fix lint ci action

---
 .github/workflows/main.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index 5d2f66988bd1..7821a6e9cf9b 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -141,7 +141,7 @@ jobs:
     - name: Install Python packages
       run: |
         python -m pip install wheel setuptools
-        python -m pip install pylint cpplint numpy scipy scikit-learn pyspark
+        python -m pip install pylint cpplint numpy scipy scikit-learn pyspark pandas cloudpickle
     - name: Run lint
       run: |
         make lint
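With the lint configuration settled, the series is complete and xgboost.spark exposes the SparkXGBClassifier and SparkXGBRegressor estimators plus their fitted models. A compact usage sketch of that API follows; the SparkSession settings, the toy DataFrame and the column names are invented for illustration.

from pyspark.ml.linalg import Vectors
from pyspark.sql import SparkSession
from xgboost.spark import SparkXGBClassifier

spark = SparkSession.builder.master("local[2]").getOrCreate()
train_df = spark.createDataFrame(
    [(Vectors.dense(1.0, 2.0, 3.0), 0), (Vectors.dense(4.0, 5.0, 6.0), 1)],
    ["features", "label"],
)
clf = SparkXGBClassifier(n_estimators=10, num_workers=2)
model = clf.fit(train_df)
model.transform(train_df).select("prediction", "probability").show()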