[pyspark] support a list of feature column names (#8117)

wbo4958 committed Aug 8, 2022
1 parent bcc8679 commit 03cc3b3
Showing 11 changed files with 366 additions and 171 deletions.
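The user-facing change in this commit: instead of assembling numeric columns into a single vector column, callers can hand the estimator a plain list of feature column names, which is routed to the new features_cols param and, per the validation added below, requires use_gpu. A minimal, hedged sketch follows; the SparkSession, dataframe, and column names are made up, and the constructor keywords rely on the alias mapping handled in setParams.

# --- usage sketch (not part of the diff) ---
# Assumes an xgboost build that ships the pyspark module and a GPU-enabled
# Spark setup; data and column names are hypothetical.
from pyspark.sql import SparkSession
from xgboost.spark import SparkXGBClassifier

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(
    [(1.0, 2.0, 3, 0.0), (4.0, 5.0, 6, 1.0)],
    ["f0", "f1", "f2", "label"],
)

# Passing a list to features_col routes it to the new features_cols param;
# this path is only supported together with use_gpu=True.
clf = SparkXGBClassifier(
    features_col=["f0", "f1", "f2"],
    label_col="label",
    use_gpu=True,
    num_workers=1,
)
model = clf.fit(df)
model.transform(df).show()
# --- end sketch ---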
154 changes: 114 additions & 40 deletions python-package/xgboost/spark/core.py
@@ -2,7 +2,7 @@
"""Xgboost pyspark integration submodule for core code."""
# pylint: disable=fixme, too-many-ancestors, protected-access, no-member, invalid-name
# pylint: disable=too-few-public-methods
from typing import Iterator, Tuple
from typing import Iterator, Optional, Tuple

import numpy as np
import pandas as pd
@@ -26,6 +26,7 @@
DoubleType,
FloatType,
IntegerType,
IntegralType,
LongType,
ShortType,
)
@@ -43,7 +44,7 @@
SparkXGBReader,
SparkXGBWriter,
)
from .params import HasArbitraryParamsDict, HasBaseMarginCol
from .params import HasArbitraryParamsDict, HasBaseMarginCol, HasFeaturesCols
from .utils import (
RabitContext,
_get_args_from_message_list,
@@ -73,14 +74,10 @@
"num_workers",
"use_gpu",
"feature_names",
"features_cols",
]

_non_booster_params = [
"missing",
"n_estimators",
"feature_types",
"feature_weights",
]
_non_booster_params = ["missing", "n_estimators", "feature_types", "feature_weights"]

_pyspark_param_alias_map = {
"features_col": "featuresCol",
@@ -126,6 +123,7 @@ class _SparkXGBParams(
HasValidationIndicatorCol,
HasArbitraryParamsDict,
HasBaseMarginCol,
HasFeaturesCols,
):
num_workers = Param(
Params._dummy(),
@@ -240,12 +238,11 @@ def _gen_predict_params_dict(self):

def _validate_params(self):
init_model = self.getOrDefault(self.xgb_model)
if init_model is not None:
if init_model is not None and not isinstance(init_model, Booster):
raise ValueError(
"The xgb_model param must be set with a `xgboost.core.Booster` "
"instance."
)
if init_model is not None and not isinstance(init_model, Booster):
raise ValueError(
"The xgb_model param must be set with a `xgboost.core.Booster` "
"instance."
)

if self.getOrDefault(self.num_workers) < 1:
raise ValueError(
@@ -262,6 +259,14 @@ def _validate_params(self):
"Therefore, that parameter will be ignored."
)

if self.getOrDefault(self.features_cols):
if not self.getOrDefault(self.use_gpu):
raise ValueError("features_cols param requires enabling use_gpu.")

get_logger(self.__class__.__name__).warning(
"If features_cols param set, then features_col param is ignored."
)

if self.getOrDefault(self.use_gpu):
tree_method = self.getParam("tree_method")
if (
@@ -315,6 +320,23 @@ def _validate_params(self):
)


def _validate_and_convert_feature_col_as_float_col_list(
dataset, features_col_names: list
) -> list:
"""Values in feature columns must be integral types or float/double types"""
feature_cols = []
for c in features_col_names:
if isinstance(dataset.schema[c].dataType, DoubleType):
feature_cols.append(col(c).cast(FloatType()).alias(c))
elif isinstance(dataset.schema[c].dataType, (FloatType, IntegralType)):
feature_cols.append(col(c))
else:
raise ValueError(
"Values in feature columns must be integral types or float/double types."
)
return feature_cols
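
To make the new validator's rule concrete: double columns are down-cast to float, float and integer columns pass through unchanged, and any other type is rejected. A small hedged sketch (session setup and column names are made up):

# --- casting-rule sketch (not part of the diff) ---
from pyspark.sql import SparkSession
from pyspark.sql.functions import col
from pyspark.sql.types import FloatType

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(1.0, 2)], ["d", "i"])  # d: double, i: bigint

selected = df.select(col("d").cast(FloatType()).alias("d"), col("i"))
selected.printSchema()
# root
#  |-- d: float (nullable = true)
#  |-- i: long (nullable = true)
# --- end sketch ---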


def _validate_and_convert_feature_col_as_array_col(dataset, features_col_name):
features_col_datatype = dataset.schema[features_col_name].dataType
features_col = col(features_col_name)
@@ -373,8 +395,14 @@ def setParams(self, **kwargs): # pylint: disable=invalid-name
f"Please use param name {_inverse_pyspark_param_alias_map[k]} instead."
)
if k in _pyspark_param_alias_map:
real_k = _pyspark_param_alias_map[k]
k = real_k
if k == _inverse_pyspark_param_alias_map[
self.featuresCol.name
] and isinstance(v, list):
real_k = self.features_cols.name
k = real_k
else:
real_k = _pyspark_param_alias_map[k]
k = real_k

if self.hasParam(k):
self._set(**{str(k): v})
@@ -497,10 +525,19 @@ def _fit(self, dataset):
self._validate_params()
label_col = col(self.getOrDefault(self.labelCol)).alias(alias.label)

features_array_col = _validate_and_convert_feature_col_as_array_col(
dataset, self.getOrDefault(self.featuresCol)
)
select_cols = [features_array_col, label_col]
select_cols = [label_col]
features_cols_names = None
if self.getOrDefault(self.features_cols):
features_cols_names = self.getOrDefault(self.features_cols)
features_cols = _validate_and_convert_feature_col_as_float_col_list(
dataset, features_cols_names
)
select_cols.extend(features_cols)
else:
features_array_col = _validate_and_convert_feature_col_as_array_col(
dataset, self.getOrDefault(self.featuresCol)
)
select_cols.append(features_array_col)

if self.isDefined(self.weightCol) and self.getOrDefault(self.weightCol):
select_cols.append(
@@ -569,10 +606,17 @@ def _train_booster(pandas_df_iter):
context = BarrierTaskContext.get()
context.barrier()

gpu_id = None
if use_gpu:
booster_params["gpu_id"] = (
context.partitionId() if is_local else _get_gpu_id(context)
)
gpu_id = context.partitionId() if is_local else _get_gpu_id(context)
booster_params["gpu_id"] = gpu_id

# max_bin is needed for the quantile DMatrix (qdm)
if (
features_cols_names is not None
and booster_params.get("max_bin", None) is not None
):
dmatrix_kwargs["max_bin"] = booster_params["max_bin"]

_rabit_args = ""
if context.partitionId() == 0:
@@ -583,9 +627,7 @@
evals_result = {}
with RabitContext(_rabit_args, context):
dtrain, dvalid = create_dmatrix_from_partitions(
pandas_df_iter,
None,
dmatrix_kwargs,
pandas_df_iter, features_cols_names, gpu_id, dmatrix_kwargs
)
if dvalid is not None:
dval = [(dtrain, "training"), (dvalid, "validation")]
@@ -685,6 +727,34 @@ def read(cls):
def _transform(self, dataset):
raise NotImplementedError()

def _get_feature_col(self, dataset) -> Tuple[list, Optional[list]]:
"""XGBoost model trained with features_cols parameter can also predict
vector or array feature type. But first we need to check features_cols
and then featuresCol
"""

feature_col_names = self.getOrDefault(self.features_cols)
features_col = []
if feature_col_names and set(feature_col_names).issubset(set(dataset.columns)):
# The model is trained with features_cols and the predicted dataset
# also contains all the columns specified by features_cols.
features_col = _validate_and_convert_feature_col_as_float_col_list(
dataset, feature_col_names
)
else:
# 1. The model was trained with features_cols, but the dataset doesn't contain
# all of those columns, so fall back to checking whether the dataframe
# has the featuresCol column.
# 2. The model was trained with featuresCol, in which case the dataset being
# predicted on must contain the featuresCol column.
feature_col_names = None
features_col.append(
_validate_and_convert_feature_col_as_array_col(
dataset, self.getOrDefault(self.featuresCol)
)
)
return features_col, feature_col_names


class SparkXGBRegressorModel(_SparkXGBModel):
"""
@@ -712,11 +782,17 @@ def _transform(self, dataset):
alias.margin
)

features_col, feature_col_names = self._get_feature_col(dataset)

@pandas_udf("double")
def predict_udf(iterator: Iterator[pd.DataFrame]) -> Iterator[pd.Series]:
model = xgb_sklearn_model
for data in iterator:
X = stack_series(data[alias.data])
if feature_col_names is not None:
X = data[feature_col_names]
else:
X = stack_series(data[alias.data])

if has_base_margin:
base_margin = data[alias.margin].to_numpy()
else:
@@ -730,14 +806,10 @@ def predict_udf(iterator: Iterator[pd.DataFrame]) -> Iterator[pd.Series]:
)
yield pd.Series(preds)

features_col = _validate_and_convert_feature_col_as_array_col(
dataset, self.getOrDefault(self.featuresCol)
)

if has_base_margin:
pred_col = predict_udf(struct(features_col, base_margin_col))
pred_col = predict_udf(struct(*features_col, base_margin_col))
else:
pred_col = predict_udf(struct(features_col))
pred_col = predict_udf(struct(*features_col))

predictionColName = self.getOrDefault(self.predictionCol)

@@ -783,6 +855,8 @@ def transform_margin(margins: np.ndarray):
class_probs = softmax(raw_preds, axis=1)
return raw_preds, class_probs

features_col, feature_col_names = self._get_feature_col(dataset)

@pandas_udf(
"rawPrediction array<double>, prediction double, probability array<double>"
)
Expand All @@ -791,7 +865,11 @@ def predict_udf(
) -> Iterator[pd.DataFrame]:
model = xgb_sklearn_model
for data in iterator:
X = stack_series(data[alias.data])
if feature_col_names is not None:
X = data[feature_col_names]
else:
X = stack_series(data[alias.data])

if has_base_margin:
base_margin = stack_series(data[alias.margin])
else:
Expand All @@ -817,14 +895,10 @@ def predict_udf(
}
)

features_col = _validate_and_convert_feature_col_as_array_col(
dataset, self.getOrDefault(self.featuresCol)
)

if has_base_margin:
pred_struct = predict_udf(struct(features_col, base_margin_col))
pred_struct = predict_udf(struct(*features_col, base_margin_col))
else:
pred_struct = predict_udf(struct(features_col))
pred_struct = predict_udf(struct(*features_col))

pred_struct_col = "_prediction_struct"

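A consequence of the _get_feature_col fallback above: a model trained with features_cols can score either a dataframe that still carries those raw numeric columns, or one where they have already been assembled into the vector column named by featuresCol. A hedged sketch; `model` is assumed to be a fitted SparkXGBClassifierModel trained with features_cols=["f0", "f1", "f2"], and the dataframes are hypothetical.

# --- prediction sketch (not part of the diff) ---
from pyspark.ml.feature import VectorAssembler

# Case 1: the scoring dataframe still has the raw numeric columns, so the
# model selects them directly (the features_cols path).
preds_raw = model.transform(df_raw)

# Case 2: the raw columns were assembled into the vector column named by
# featuresCol (default "features"), so the model falls back to the
# array/vector path.
assembler = VectorAssembler(inputCols=["f0", "f1", "f2"], outputCol="features")
df_vec = assembler.transform(df_raw).drop("f0", "f1", "f2")
preds_vec = model.transform(df_vec)
# --- end sketch ---
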
13 changes: 9 additions & 4 deletions python-package/xgboost/spark/data.py
@@ -63,9 +63,9 @@ def make_blob(part: pd.DataFrame, is_valid: bool) -> None:
class PartIter(DataIter):
"""Iterator for creating Quantile DMatrix from partitions."""

def __init__(self, data: Dict[str, List], on_device: bool) -> None:
def __init__(self, data: Dict[str, List], device_id: Optional[int]) -> None:
self._iter = 0
self._cuda = on_device
self._device_id = device_id
self._data = data

super().__init__()
@@ -74,9 +74,13 @@ def _fetch(self, data: Optional[Sequence[pd.DataFrame]]) -> Optional[pd.DataFrame]:
if not data:
return None

if self._cuda:
if self._device_id is not None:
import cudf # pylint: disable=import-error
import cupy as cp # pylint: disable=import-error

# We must set the device after importing cudf, which resets the current device to 0.
# See https://github.com/rapidsai/cudf/issues/11386
cp.cuda.runtime.setDevice(self._device_id)
return cudf.DataFrame(data[self._iter])

return data[self._iter]
@@ -100,6 +104,7 @@ def reset(self) -> None:
def create_dmatrix_from_partitions(
iterator: Iterator[pd.DataFrame],
feature_cols: Optional[Sequence[str]],
gpu_id: Optional[int],
kwargs: Dict[str, Any], # use dict to make sure this parameter is passed.
) -> Tuple[DMatrix, Optional[DMatrix]]:
"""Create DMatrix from spark data partitions. This is not particularly efficient as
@@ -169,7 +174,7 @@ def make(values: Dict[str, List[np.ndarray]], kwargs: Dict[str, Any]) -> DMatrix:
dtrain = make(train_data, kwargs)
else:
cache_partitions(iterator, append_dqm)
it = PartIter(train_data, True)
it = PartIter(train_data, gpu_id)
dtrain = DeviceQuantileDMatrix(it, **kwargs)

dvalid = make(valid_data, kwargs) if len(valid_data) != 0 else None
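The data.py change threads the chosen gpu_id into the partition iterator and pins the CUDA device after the cudf import, which otherwise resets the current device to 0; it also forwards max_bin to the quantile DMatrix. A minimal sketch of the same iterator pattern, assuming a CUDA-capable environment with cupy installed; the batch contents and max_bin value are made up.

# --- DataIter sketch (not part of the diff) ---
import cupy as cp  # requires a CUDA runtime
import xgboost as xgb


class DevicePartIter(xgb.core.DataIter):
    """Feeds batches to DeviceQuantileDMatrix while pinned to one GPU."""

    def __init__(self, batches, device_id):
        self._batches = batches  # list of (X, y) array pairs
        self._device_id = device_id
        self._it = 0
        super().__init__()

    def next(self, input_data):
        if self._it == len(self._batches):
            return 0  # no more batches
        # Pin the device before materializing data on the GPU; importing
        # cudf/cupy elsewhere may have switched the current device back to 0.
        cp.cuda.runtime.setDevice(self._device_id)
        X, y = self._batches[self._it]
        input_data(data=cp.asarray(X), label=cp.asarray(y))
        self._it += 1
        return 1

    def reset(self):
        self._it = 0


# dtrain = xgb.DeviceQuantileDMatrix(DevicePartIter(batches, gpu_id), max_bin=256)
# --- end sketch ---
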
19 changes: 19 additions & 0 deletions python-package/xgboost/spark/params.py
@@ -1,6 +1,7 @@
# type: ignore
"""Xgboost pyspark integration submodule for params."""
# pylint: disable=too-few-public-methods
from pyspark.ml.param import TypeConverters
from pyspark.ml.param.shared import Param, Params


@@ -31,3 +32,21 @@ class HasBaseMarginCol(Params):
"base_margin_col",
"This stores the name for the column of the base margin",
)


class HasFeaturesCols(Params):
"""
Mixin for param features_cols: a list of feature column names.
This parameter takes effect only when use_gpu is enabled.
"""

features_cols = Param(
Params._dummy(),
"features_cols",
"feature column names.",
typeConverter=TypeConverters.toListString,
)

def __init__(self):
super().__init__()
self._setDefault(features_cols=[])
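
The new mixin follows the standard pyspark.ml pattern for declaring a typed Param with a default. A sketch of the same pattern under a made-up param name, showing how an estimator that mixes it in would read and set the value.

# --- Params-mixin sketch (not part of the diff) ---
from pyspark.ml.param import Param, Params, TypeConverters


class HasMonotoneCols(Params):
    """Mixin declaring a list-of-strings param with an empty-list default."""

    monotone_cols = Param(
        Params._dummy(),
        "monotone_cols",
        "columns with a monotonicity constraint.",
        typeConverter=TypeConverters.toListString,
    )

    def __init__(self):
        super().__init__()
        self._setDefault(monotone_cols=[])


# An estimator mixing this in reads the value with
# self.getOrDefault(self.monotone_cols); callers set it through the estimator's
# param-setting API (e.g. keyword arguments handled by setParams).
# --- end sketch ---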
6 changes: 3 additions & 3 deletions tests/ci_build/lint_python.py
@@ -115,7 +115,7 @@ def print_summary_map(result_map: Dict[str, Dict[str, int]]) -> int:
"python-package/xgboost/dask.py",
"python-package/xgboost/spark",
"tests/python/test_spark/test_data.py",
"tests/python-gpu/test_spark_with_gpu/test_data.py",
"tests/python-gpu/test_gpu_spark/test_data.py",
"tests/ci_build/lint_python.py",
]
):
@@ -130,9 +130,9 @@ def print_summary_map(result_map: Dict[str, Dict[str, int]]) -> int:
"demo/guide-python/cat_in_the_dat.py",
"tests/python/test_data_iterator.py",
"tests/python/test_spark/test_data.py",
"tests/python-gpu/test_gpu_with_dask.py",
"tests/python-gpu/test_gpu_with_dask/test_gpu_with_dask.py",
"tests/python-gpu/test_gpu_data_iterator.py",
"tests/python-gpu/test_spark_with_gpu/test_data.py",
"tests/python-gpu/test_gpu_spark/test_data.py",
"tests/ci_build/lint_python.py",
]
):
4 changes: 2 additions & 2 deletions tests/python-gpu/conftest.py
@@ -61,8 +61,8 @@ def pytest_collection_modifyitems(config, items):
mgpu_mark = pytest.mark.mgpu
for item in items:
if item.nodeid.startswith(
"python-gpu/test_gpu_with_dask.py"
"python-gpu/test_gpu_with_dask/test_gpu_with_dask.py"
) or item.nodeid.startswith(
"python-gpu/test_spark_with_gpu/test_spark_with_gpu.py"
"python-gpu/test_gpu_spark/test_gpu_spark.py"
):
item.add_marker(mgpu_mark)
