diff --git a/python-package/xgboost/spark/core.py b/python-package/xgboost/spark/core.py index 9cf5abab90d7..add414c8746c 100644 --- a/python-package/xgboost/spark/core.py +++ b/python-package/xgboost/spark/core.py @@ -2,7 +2,7 @@ """Xgboost pyspark integration submodule for core code.""" # pylint: disable=fixme, too-many-ancestors, protected-access, no-member, invalid-name # pylint: disable=too-few-public-methods -from typing import Iterator, Tuple +from typing import Iterator, Optional, Tuple import numpy as np import pandas as pd @@ -26,6 +26,7 @@ DoubleType, FloatType, IntegerType, + IntegralType, LongType, ShortType, ) @@ -43,7 +44,7 @@ SparkXGBReader, SparkXGBWriter, ) -from .params import HasArbitraryParamsDict, HasBaseMarginCol +from .params import HasArbitraryParamsDict, HasBaseMarginCol, HasFeaturesCols from .utils import ( RabitContext, _get_args_from_message_list, @@ -73,14 +74,10 @@ "num_workers", "use_gpu", "feature_names", + "features_cols", ] -_non_booster_params = [ - "missing", - "n_estimators", - "feature_types", - "feature_weights", -] +_non_booster_params = ["missing", "n_estimators", "feature_types", "feature_weights"] _pyspark_param_alias_map = { "features_col": "featuresCol", @@ -126,6 +123,7 @@ class _SparkXGBParams( HasValidationIndicatorCol, HasArbitraryParamsDict, HasBaseMarginCol, + HasFeaturesCols, ): num_workers = Param( Params._dummy(), @@ -240,12 +238,11 @@ def _gen_predict_params_dict(self): def _validate_params(self): init_model = self.getOrDefault(self.xgb_model) - if init_model is not None: - if init_model is not None and not isinstance(init_model, Booster): - raise ValueError( - "The xgb_model param must be set with a `xgboost.core.Booster` " - "instance." - ) + if init_model is not None and not isinstance(init_model, Booster): + raise ValueError( + "The xgb_model param must be set with a `xgboost.core.Booster` " + "instance." + ) if self.getOrDefault(self.num_workers) < 1: raise ValueError( @@ -262,6 +259,14 @@ def _validate_params(self): "Therefore, that parameter will be ignored." ) + if self.getOrDefault(self.features_cols): + if not self.getOrDefault(self.use_gpu): + raise ValueError("features_cols param requires enabling use_gpu.") + + get_logger(self.__class__.__name__).warning( + "If features_cols param set, then features_col param is ignored." + ) + if self.getOrDefault(self.use_gpu): tree_method = self.getParam("tree_method") if ( @@ -315,6 +320,23 @@ def _validate_params(self): ) +def _validate_and_convert_feature_col_as_float_col_list( + dataset, features_col_names: list +) -> list: + """Values in feature columns must be integral types or float/double types""" + feature_cols = [] + for c in features_col_names: + if isinstance(dataset.schema[c].dataType, DoubleType): + feature_cols.append(col(c).cast(FloatType()).alias(c)) + elif isinstance(dataset.schema[c].dataType, (FloatType, IntegralType)): + feature_cols.append(col(c)) + else: + raise ValueError( + "Values in feature columns must be integral types or float/double types." + ) + return feature_cols + + def _validate_and_convert_feature_col_as_array_col(dataset, features_col_name): features_col_datatype = dataset.schema[features_col_name].dataType features_col = col(features_col_name) @@ -373,8 +395,14 @@ def setParams(self, **kwargs): # pylint: disable=invalid-name f"Please use param name {_inverse_pyspark_param_alias_map[k]} instead." 
) if k in _pyspark_param_alias_map: - real_k = _pyspark_param_alias_map[k] - k = real_k + if k == _inverse_pyspark_param_alias_map[ + self.featuresCol.name + ] and isinstance(v, list): + real_k = self.features_cols.name + k = real_k + else: + real_k = _pyspark_param_alias_map[k] + k = real_k if self.hasParam(k): self._set(**{str(k): v}) @@ -497,10 +525,19 @@ def _fit(self, dataset): self._validate_params() label_col = col(self.getOrDefault(self.labelCol)).alias(alias.label) - features_array_col = _validate_and_convert_feature_col_as_array_col( - dataset, self.getOrDefault(self.featuresCol) - ) - select_cols = [features_array_col, label_col] + select_cols = [label_col] + features_cols_names = None + if self.getOrDefault(self.features_cols): + features_cols_names = self.getOrDefault(self.features_cols) + features_cols = _validate_and_convert_feature_col_as_float_col_list( + dataset, features_cols_names + ) + select_cols.extend(features_cols) + else: + features_array_col = _validate_and_convert_feature_col_as_array_col( + dataset, self.getOrDefault(self.featuresCol) + ) + select_cols.append(features_array_col) if self.isDefined(self.weightCol) and self.getOrDefault(self.weightCol): select_cols.append( @@ -569,10 +606,17 @@ def _train_booster(pandas_df_iter): context = BarrierTaskContext.get() context.barrier() + gpu_id = None if use_gpu: - booster_params["gpu_id"] = ( - context.partitionId() if is_local else _get_gpu_id(context) - ) + gpu_id = context.partitionId() if is_local else _get_gpu_id(context) + booster_params["gpu_id"] = gpu_id + + # max_bin is needed for qdm + if ( + features_cols_names is not None + and booster_params.get("max_bin", None) is not None + ): + dmatrix_kwargs["max_bin"] = booster_params["max_bin"] _rabit_args = "" if context.partitionId() == 0: @@ -583,9 +627,7 @@ def _train_booster(pandas_df_iter): evals_result = {} with RabitContext(_rabit_args, context): dtrain, dvalid = create_dmatrix_from_partitions( - pandas_df_iter, - None, - dmatrix_kwargs, + pandas_df_iter, features_cols_names, gpu_id, dmatrix_kwargs ) if dvalid is not None: dval = [(dtrain, "training"), (dvalid, "validation")] @@ -685,6 +727,34 @@ def read(cls): def _transform(self, dataset): raise NotImplementedError() + def _get_feature_col(self, dataset) -> (list, Optional[list]): + """XGBoost model trained with features_cols parameter can also predict + vector or array feature type. But first we need to check features_cols + and then featuresCol + """ + + feature_col_names = self.getOrDefault(self.features_cols) + features_col = [] + if feature_col_names and set(feature_col_names).issubset(set(dataset.columns)): + # The model is trained with features_cols and the predicted dataset + # also contains all the columns specified by features_cols. + features_col = _validate_and_convert_feature_col_as_float_col_list( + dataset, feature_col_names + ) + else: + # 1. The model was trained by features_cols, but the dataset doesn't contain + # all the columns specified by features_cols, so we need to check if + # the dataframe has the featuresCol + # 2. The model was trained by featuresCol, and the predicted dataset must contain + # featuresCol column. 
+ feature_col_names = None + features_col.append( + _validate_and_convert_feature_col_as_array_col( + dataset, self.getOrDefault(self.featuresCol) + ) + ) + return features_col, feature_col_names + class SparkXGBRegressorModel(_SparkXGBModel): """ @@ -712,11 +782,17 @@ def _transform(self, dataset): alias.margin ) + features_col, feature_col_names = self._get_feature_col(dataset) + @pandas_udf("double") def predict_udf(iterator: Iterator[pd.DataFrame]) -> Iterator[pd.Series]: model = xgb_sklearn_model for data in iterator: - X = stack_series(data[alias.data]) + if feature_col_names is not None: + X = data[feature_col_names] + else: + X = stack_series(data[alias.data]) + if has_base_margin: base_margin = data[alias.margin].to_numpy() else: @@ -730,14 +806,10 @@ def predict_udf(iterator: Iterator[pd.DataFrame]) -> Iterator[pd.Series]: ) yield pd.Series(preds) - features_col = _validate_and_convert_feature_col_as_array_col( - dataset, self.getOrDefault(self.featuresCol) - ) - if has_base_margin: - pred_col = predict_udf(struct(features_col, base_margin_col)) + pred_col = predict_udf(struct(*features_col, base_margin_col)) else: - pred_col = predict_udf(struct(features_col)) + pred_col = predict_udf(struct(*features_col)) predictionColName = self.getOrDefault(self.predictionCol) @@ -783,6 +855,8 @@ def transform_margin(margins: np.ndarray): class_probs = softmax(raw_preds, axis=1) return raw_preds, class_probs + features_col, feature_col_names = self._get_feature_col(dataset) + @pandas_udf( "rawPrediction array, prediction double, probability array" ) @@ -791,7 +865,11 @@ def predict_udf( ) -> Iterator[pd.DataFrame]: model = xgb_sklearn_model for data in iterator: - X = stack_series(data[alias.data]) + if feature_col_names is not None: + X = data[feature_col_names] + else: + X = stack_series(data[alias.data]) + if has_base_margin: base_margin = stack_series(data[alias.margin]) else: @@ -817,14 +895,10 @@ def predict_udf( } ) - features_col = _validate_and_convert_feature_col_as_array_col( - dataset, self.getOrDefault(self.featuresCol) - ) - if has_base_margin: - pred_struct = predict_udf(struct(features_col, base_margin_col)) + pred_struct = predict_udf(struct(*features_col, base_margin_col)) else: - pred_struct = predict_udf(struct(features_col)) + pred_struct = predict_udf(struct(*features_col)) pred_struct_col = "_prediction_struct" diff --git a/python-package/xgboost/spark/data.py b/python-package/xgboost/spark/data.py index e3fda4c14d03..eb825df73827 100644 --- a/python-package/xgboost/spark/data.py +++ b/python-package/xgboost/spark/data.py @@ -63,9 +63,9 @@ def make_blob(part: pd.DataFrame, is_valid: bool) -> None: class PartIter(DataIter): """Iterator for creating Quantile DMatrix from partitions.""" - def __init__(self, data: Dict[str, List], on_device: bool) -> None: + def __init__(self, data: Dict[str, List], device_id: Optional[int]) -> None: self._iter = 0 - self._cuda = on_device + self._device_id = device_id self._data = data super().__init__() @@ -74,9 +74,13 @@ def _fetch(self, data: Optional[Sequence[pd.DataFrame]]) -> Optional[pd.DataFram if not data: return None - if self._cuda: + if self._device_id is not None: import cudf # pylint: disable=import-error + import cupy as cp # pylint: disable=import-error + # We must set the device after import cudf, which will change the device id to 0 + # See https://github.com/rapidsai/cudf/issues/11386 + cp.cuda.runtime.setDevice(self._device_id) return cudf.DataFrame(data[self._iter]) return data[self._iter] @@ -100,6 +104,7 @@ def 
reset(self) -> None:
 def create_dmatrix_from_partitions(
     iterator: Iterator[pd.DataFrame],
     feature_cols: Optional[Sequence[str]],
+    gpu_id: Optional[int],
     kwargs: Dict[str, Any],  # use dict to make sure this parameter is passed.
 ) -> Tuple[DMatrix, Optional[DMatrix]]:
     """Create DMatrix from spark data partitions. This is not particularly efficient as
@@ -169,7 +174,7 @@ def make(values: Dict[str, List[np.ndarray]], kwargs: Dict[str, Any]) -> DMatrix
         dtrain = make(train_data, kwargs)
     else:
         cache_partitions(iterator, append_dqm)
-        it = PartIter(train_data, True)
+        it = PartIter(train_data, gpu_id)
         dtrain = DeviceQuantileDMatrix(it, **kwargs)

     dvalid = make(valid_data, kwargs) if len(valid_data) != 0 else None
diff --git a/python-package/xgboost/spark/params.py b/python-package/xgboost/spark/params.py
index 9528eb69dd70..7a9844e532c9 100644
--- a/python-package/xgboost/spark/params.py
+++ b/python-package/xgboost/spark/params.py
@@ -1,6 +1,7 @@
 # type: ignore
 """Xgboost pyspark integration submodule for params."""
 # pylint: disable=too-few-public-methods
+from pyspark.ml.param import TypeConverters
 from pyspark.ml.param.shared import Param, Params


@@ -31,3 +32,21 @@ class HasBaseMarginCol(Params):
         "base_margin_col",
         "This stores the name for the column of the base margin",
     )
+
+
+class HasFeaturesCols(Params):
+    """
+    Mixin for param features_cols: a list of feature column names.
+    This parameter takes effect only when use_gpu is enabled.
+    """
+
+    features_cols = Param(
+        Params._dummy(),
+        "features_cols",
+        "feature column names.",
+        typeConverter=TypeConverters.toListString,
+    )
+
+    def __init__(self):
+        super().__init__()
+        self._setDefault(features_cols=[])
diff --git a/tests/ci_build/lint_python.py b/tests/ci_build/lint_python.py
index 8008c5774250..cfc4b85988e2 100644
--- a/tests/ci_build/lint_python.py
+++ b/tests/ci_build/lint_python.py
@@ -115,7 +115,7 @@ def print_summary_map(result_map: Dict[str, Dict[str, int]]) -> int:
            "python-package/xgboost/dask.py",
            "python-package/xgboost/spark",
            "tests/python/test_spark/test_data.py",
-           "tests/python-gpu/test_spark_with_gpu/test_data.py",
+           "tests/python-gpu/test_gpu_spark/test_data.py",
            "tests/ci_build/lint_python.py",
        ]
    ):
@@ -130,9 +130,9 @@ def print_summary_map(result_map: Dict[str, Dict[str, int]]) -> int:
            "demo/guide-python/cat_in_the_dat.py",
            "tests/python/test_data_iterator.py",
            "tests/python/test_spark/test_data.py",
-           "tests/python-gpu/test_gpu_with_dask.py",
+           "tests/python-gpu/test_gpu_with_dask/test_gpu_with_dask.py",
            "tests/python-gpu/test_gpu_data_iterator.py",
-           "tests/python-gpu/test_spark_with_gpu/test_data.py",
+           "tests/python-gpu/test_gpu_spark/test_data.py",
            "tests/ci_build/lint_python.py",
        ]
    ):
diff --git a/tests/python-gpu/conftest.py b/tests/python-gpu/conftest.py
index 789f96fc5511..c97b1a1668bf 100644
--- a/tests/python-gpu/conftest.py
+++ b/tests/python-gpu/conftest.py
@@ -61,8 +61,8 @@ def pytest_collection_modifyitems(config, items):
     mgpu_mark = pytest.mark.mgpu
     for item in items:
         if item.nodeid.startswith(
-            "python-gpu/test_gpu_with_dask.py"
+            "python-gpu/test_gpu_with_dask/test_gpu_with_dask.py"
         ) or item.nodeid.startswith(
-            "python-gpu/test_spark_with_gpu/test_spark_with_gpu.py"
+            "python-gpu/test_gpu_spark/test_gpu_spark.py"
         ):
             item.add_marker(mgpu_mark)
diff --git a/tests/python-gpu/test_spark_with_gpu/discover_gpu.sh b/tests/python-gpu/test_gpu_spark/discover_gpu.sh
similarity index 100%
rename from tests/python-gpu/test_spark_with_gpu/discover_gpu.sh
rename to tests/python-gpu/test_gpu_spark/discover_gpu.sh
diff --git a/tests/python-gpu/test_spark_with_gpu/test_data.py b/tests/python-gpu/test_gpu_spark/test_data.py
similarity index 100%
rename from tests/python-gpu/test_spark_with_gpu/test_data.py
rename to tests/python-gpu/test_gpu_spark/test_data.py
diff --git a/tests/python-gpu/test_gpu_spark/test_gpu_spark.py b/tests/python-gpu/test_gpu_spark/test_gpu_spark.py
new file mode 100644
index 000000000000..ce5b9d8c8d42
--- /dev/null
+++ b/tests/python-gpu/test_gpu_spark/test_gpu_spark.py
@@ -0,0 +1,215 @@
+import logging
+import sys
+
+import pytest
+import sklearn
+
+sys.path.append("tests/python")
+import testing as tm
+
+if tm.no_spark()["condition"]:
+    pytest.skip(msg=tm.no_spark()["reason"], allow_module_level=True)
+if sys.platform.startswith("win"):
+    pytest.skip("Skipping PySpark tests on Windows", allow_module_level=True)
+
+from pyspark.ml.linalg import Vectors
+from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
+from pyspark.sql import SparkSession
+from xgboost.spark import SparkXGBClassifier, SparkXGBRegressor
+
+gpu_discovery_script_path = "tests/python-gpu/test_gpu_spark/discover_gpu.sh"
+executor_gpu_amount = 4
+executor_cores = 4
+num_workers = executor_gpu_amount
+
+
+@pytest.fixture(scope="module", autouse=True)
+def spark_session_with_gpu():
+    spark_config = {
+        "spark.master": f"local-cluster[1, {executor_gpu_amount}, 1024]",
+        "spark.python.worker.reuse": "false",
+        "spark.driver.host": "127.0.0.1",
+        "spark.task.maxFailures": "1",
+        "spark.sql.execution.pyspark.udf.simplifiedTraceback.enabled": "false",
+        "spark.sql.pyspark.jvmStacktrace.enabled": "true",
+        "spark.cores.max": executor_cores,
+        "spark.task.cpus": "1",
+        "spark.executor.cores": executor_cores,
+        "spark.worker.resource.gpu.amount": executor_gpu_amount,
+        "spark.task.resource.gpu.amount": "1",
+        "spark.executor.resource.gpu.amount": executor_gpu_amount,
+        "spark.worker.resource.gpu.discoveryScript": gpu_discovery_script_path,
+    }
+    builder = SparkSession.builder.appName("xgboost spark python API Tests with GPU")
+    for k, v in spark_config.items():
+        builder.config(k, v)
+    spark = builder.getOrCreate()
+    logging.getLogger("pyspark").setLevel(logging.INFO)
+    # We run a dummy job so that we block until the workers have connected to the master
+    spark.sparkContext.parallelize(
+        range(num_workers), num_workers
+    ).barrier().mapPartitions(lambda _: []).collect()
+    yield spark
+    spark.stop()
+
+
+@pytest.fixture
+def spark_iris_dataset(spark_session_with_gpu):
+    spark = spark_session_with_gpu
+    data = sklearn.datasets.load_iris()
+    train_rows = [
+        (Vectors.dense(features), float(label))
+        for features, label in zip(data.data[0::2], data.target[0::2])
+    ]
+    train_df = spark.createDataFrame(
+        spark.sparkContext.parallelize(train_rows, num_workers), ["features", "label"]
+    )
+    test_rows = [
+        (Vectors.dense(features), float(label))
+        for features, label in zip(data.data[1::2], data.target[1::2])
+    ]
+    test_df = spark.createDataFrame(
+        spark.sparkContext.parallelize(test_rows, num_workers), ["features", "label"]
+    )
+    return train_df, test_df
+
+
+@pytest.fixture
+def spark_iris_dataset_feature_cols(spark_session_with_gpu):
+    spark = spark_session_with_gpu
+    data = sklearn.datasets.load_iris()
+    train_rows = [
+        (*features.tolist(), float(label))
+        for features, label in zip(data.data[0::2], data.target[0::2])
+    ]
+    train_df = spark.createDataFrame(
+        spark.sparkContext.parallelize(train_rows, num_workers),
+        [*data.feature_names, "label"],
+    )
+    test_rows =
[ + (*features.tolist(), float(label)) + for features, label in zip(data.data[1::2], data.target[1::2]) + ] + test_df = spark.createDataFrame( + spark.sparkContext.parallelize(test_rows, num_workers), + [*data.feature_names, "label"], + ) + return train_df, test_df, data.feature_names + + +@pytest.fixture +def spark_diabetes_dataset(spark_session_with_gpu): + spark = spark_session_with_gpu + data = sklearn.datasets.load_diabetes() + train_rows = [ + (Vectors.dense(features), float(label)) + for features, label in zip(data.data[0::2], data.target[0::2]) + ] + train_df = spark.createDataFrame( + spark.sparkContext.parallelize(train_rows, num_workers), ["features", "label"] + ) + test_rows = [ + (Vectors.dense(features), float(label)) + for features, label in zip(data.data[1::2], data.target[1::2]) + ] + test_df = spark.createDataFrame( + spark.sparkContext.parallelize(test_rows, num_workers), ["features", "label"] + ) + return train_df, test_df + + +@pytest.fixture +def spark_diabetes_dataset_feature_cols(spark_session_with_gpu): + spark = spark_session_with_gpu + data = sklearn.datasets.load_diabetes() + train_rows = [ + (*features.tolist(), float(label)) + for features, label in zip(data.data[0::2], data.target[0::2]) + ] + train_df = spark.createDataFrame( + spark.sparkContext.parallelize(train_rows, num_workers), + [*data.feature_names, "label"], + ) + test_rows = [ + (*features.tolist(), float(label)) + for features, label in zip(data.data[1::2], data.target[1::2]) + ] + test_df = spark.createDataFrame( + spark.sparkContext.parallelize(test_rows, num_workers), + [*data.feature_names, "label"], + ) + return train_df, test_df, data.feature_names + + +def test_sparkxgb_classifier_with_gpu(spark_iris_dataset): + from pyspark.ml.evaluation import MulticlassClassificationEvaluator + + classifier = SparkXGBClassifier(use_gpu=True, num_workers=num_workers) + train_df, test_df = spark_iris_dataset + model = classifier.fit(train_df) + pred_result_df = model.transform(test_df) + evaluator = MulticlassClassificationEvaluator(metricName="f1") + f1 = evaluator.evaluate(pred_result_df) + assert f1 >= 0.97 + + +def test_sparkxgb_classifier_feature_cols_with_gpu(spark_iris_dataset_feature_cols): + from pyspark.ml.evaluation import MulticlassClassificationEvaluator + + train_df, test_df, feature_names = spark_iris_dataset_feature_cols + + classifier = SparkXGBClassifier( + features_col=feature_names, use_gpu=True, num_workers=num_workers + ) + + model = classifier.fit(train_df) + pred_result_df = model.transform(test_df) + evaluator = MulticlassClassificationEvaluator(metricName="f1") + f1 = evaluator.evaluate(pred_result_df) + assert f1 >= 0.97 + + +def test_cv_sparkxgb_classifier_feature_cols_with_gpu(spark_iris_dataset_feature_cols): + from pyspark.ml.evaluation import MulticlassClassificationEvaluator + + train_df, test_df, feature_names = spark_iris_dataset_feature_cols + + classifier = SparkXGBClassifier( + features_col=feature_names, use_gpu=True, num_workers=num_workers + ) + grid = ParamGridBuilder().addGrid(classifier.max_depth, [6, 8]).build() + evaluator = MulticlassClassificationEvaluator(metricName="f1") + cv = CrossValidator( + estimator=classifier, evaluator=evaluator, estimatorParamMaps=grid, numFolds=3 + ) + cvModel = cv.fit(train_df) + pred_result_df = cvModel.transform(test_df) + f1 = evaluator.evaluate(pred_result_df) + assert f1 >= 0.97 + + +def test_sparkxgb_regressor_with_gpu(spark_diabetes_dataset): + from pyspark.ml.evaluation import RegressionEvaluator + + regressor = 
SparkXGBRegressor(use_gpu=True, num_workers=num_workers) + train_df, test_df = spark_diabetes_dataset + model = regressor.fit(train_df) + pred_result_df = model.transform(test_df) + evaluator = RegressionEvaluator(metricName="rmse") + rmse = evaluator.evaluate(pred_result_df) + assert rmse <= 65.0 + + +def test_sparkxgb_regressor_feature_cols_with_gpu(spark_diabetes_dataset_feature_cols): + from pyspark.ml.evaluation import RegressionEvaluator + + train_df, test_df, feature_names = spark_diabetes_dataset_feature_cols + regressor = SparkXGBRegressor( + features_col=feature_names, use_gpu=True, num_workers=num_workers + ) + + model = regressor.fit(train_df) + pred_result_df = model.transform(test_df) + evaluator = RegressionEvaluator(metricName="rmse") + rmse = evaluator.evaluate(pred_result_df) + assert rmse <= 65.0 diff --git a/tests/python-gpu/test_gpu_with_dask.py b/tests/python-gpu/test_gpu_with_dask/test_gpu_with_dask.py similarity index 100% rename from tests/python-gpu/test_gpu_with_dask.py rename to tests/python-gpu/test_gpu_with_dask/test_gpu_with_dask.py diff --git a/tests/python-gpu/test_spark_with_gpu/test_spark_with_gpu.py b/tests/python-gpu/test_spark_with_gpu/test_spark_with_gpu.py deleted file mode 100644 index ab6faed2c41b..000000000000 --- a/tests/python-gpu/test_spark_with_gpu/test_spark_with_gpu.py +++ /dev/null @@ -1,120 +0,0 @@ -import sys - -import logging -import pytest -import sklearn - -sys.path.append("tests/python") -import testing as tm - -if tm.no_dask()["condition"]: - pytest.skip(msg=tm.no_spark()["reason"], allow_module_level=True) -if sys.platform.startswith("win"): - pytest.skip("Skipping PySpark tests on Windows", allow_module_level=True) - - -from pyspark.sql import SparkSession -from pyspark.ml.linalg import Vectors -from xgboost.spark import SparkXGBRegressor, SparkXGBClassifier - - -@pytest.fixture(scope="module", autouse=True) -def spark_session_with_gpu(): - spark_config = { - "spark.master": "local-cluster[1, 4, 1024]", - "spark.python.worker.reuse": "false", - "spark.driver.host": "127.0.0.1", - "spark.task.maxFailures": "1", - "spark.sql.execution.pyspark.udf.simplifiedTraceback.enabled": "false", - "spark.sql.pyspark.jvmStacktrace.enabled": "true", - "spark.cores.max": "4", - "spark.task.cpus": "1", - "spark.executor.cores": "4", - "spark.worker.resource.gpu.amount": "4", - "spark.task.resource.gpu.amount": "1", - "spark.executor.resource.gpu.amount": "4", - "spark.worker.resource.gpu.discoveryScript": "tests/python-gpu/test_spark_with_gpu/discover_gpu.sh", - } - builder = SparkSession.builder.appName("xgboost spark python API Tests with GPU") - for k, v in spark_config.items(): - builder.config(k, v) - spark = builder.getOrCreate() - logging.getLogger("pyspark").setLevel(logging.INFO) - # We run a dummy job so that we block until the workers have connected to the master - spark.sparkContext.parallelize(range(4), 4).barrier().mapPartitions( - lambda _: [] - ).collect() - yield spark - spark.stop() - - -@pytest.fixture -def spark_iris_dataset(spark_session_with_gpu): - spark = spark_session_with_gpu - data = sklearn.datasets.load_iris() - train_rows = [ - (Vectors.dense(features), float(label)) - for features, label in zip(data.data[0::2], data.target[0::2]) - ] - train_df = spark.createDataFrame( - spark.sparkContext.parallelize(train_rows, 4), ["features", "label"] - ) - test_rows = [ - (Vectors.dense(features), float(label)) - for features, label in zip(data.data[1::2], data.target[1::2]) - ] - test_df = spark.createDataFrame( - 
spark.sparkContext.parallelize(test_rows, 4), ["features", "label"] - ) - return train_df, test_df - - -@pytest.fixture -def spark_diabetes_dataset(spark_session_with_gpu): - spark = spark_session_with_gpu - data = sklearn.datasets.load_diabetes() - train_rows = [ - (Vectors.dense(features), float(label)) - for features, label in zip(data.data[0::2], data.target[0::2]) - ] - train_df = spark.createDataFrame( - spark.sparkContext.parallelize(train_rows, 4), ["features", "label"] - ) - test_rows = [ - (Vectors.dense(features), float(label)) - for features, label in zip(data.data[1::2], data.target[1::2]) - ] - test_df = spark.createDataFrame( - spark.sparkContext.parallelize(test_rows, 4), ["features", "label"] - ) - return train_df, test_df - - -def test_sparkxgb_classifier_with_gpu(spark_iris_dataset): - from pyspark.ml.evaluation import MulticlassClassificationEvaluator - - classifier = SparkXGBClassifier( - use_gpu=True, - num_workers=4, - ) - train_df, test_df = spark_iris_dataset - model = classifier.fit(train_df) - pred_result_df = model.transform(test_df) - evaluator = MulticlassClassificationEvaluator(metricName="f1") - f1 = evaluator.evaluate(pred_result_df) - assert f1 >= 0.97 - - -def test_sparkxgb_regressor_with_gpu(spark_diabetes_dataset): - from pyspark.ml.evaluation import RegressionEvaluator - - regressor = SparkXGBRegressor( - use_gpu=True, - num_workers=4, - ) - train_df, test_df = spark_diabetes_dataset - model = regressor.fit(train_df) - pred_result_df = model.transform(test_df) - evaluator = RegressionEvaluator(metricName="rmse") - rmse = evaluator.evaluate(pred_result_df) - assert rmse <= 65.0 diff --git a/tests/python/test_spark/test_data.py b/tests/python/test_spark/test_data.py index a3da4764ad4e..07896ece618b 100644 --- a/tests/python/test_spark/test_data.py +++ b/tests/python/test_spark/test_data.py @@ -62,9 +62,11 @@ def run_dmatrix_ctor(is_dqm: bool) -> None: kwargs = {"feature_types": feature_types} if is_dqm: cols = [f"feat-{i}" for i in range(n_features)] - train_Xy, valid_Xy = create_dmatrix_from_partitions(iter(dfs), cols, kwargs) + train_Xy, valid_Xy = create_dmatrix_from_partitions(iter(dfs), cols, 0, kwargs) else: - train_Xy, valid_Xy = create_dmatrix_from_partitions(iter(dfs), None, kwargs) + train_Xy, valid_Xy = create_dmatrix_from_partitions( + iter(dfs), None, None, kwargs + ) assert valid_Xy is not None assert valid_Xy.num_row() + train_Xy.num_row() == n_samples_per_batch * n_batches
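Usage sketch for the new features_cols path added above (not part of the diff): it mirrors test_sparkxgb_classifier_feature_cols_with_gpu, so it assumes a GPU-enabled Spark worker; the column names, sample values, and num_workers below are illustrative placeholders.

from pyspark.sql import SparkSession
from xgboost.spark import SparkXGBClassifier

spark = SparkSession.builder.getOrCreate()
# One numeric column per feature plus a label column; values are illustrative.
train_df = spark.createDataFrame(
    [(5.1, 3.5, 1.4, 0.2, 0.0), (6.2, 2.9, 4.3, 1.3, 1.0)],
    ["f0", "f1", "f2", "f3", "label"],
)
classifier = SparkXGBClassifier(
    # Passing a list to features_col is re-routed to the new features_cols param.
    features_col=["f0", "f1", "f2", "f3"],
    # features_cols only takes effect on the GPU path, so use_gpu must be enabled.
    use_gpu=True,
    num_workers=1,
)
model = classifier.fit(train_df)
# The fitted model can predict on the same per-feature columns, or fall back to a
# single vector/array features column (see _get_feature_col above).
predictions = model.transform(train_df)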