[pyspark] Make Xgboost estimator support using sparse matrix as optimization #8145

Merged
merged 18 commits on Aug 18, 2022
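A hedged usage sketch (not taken from the PR description): it assumes an active SparkSession, a DataFrame `df` with a `label` column and a `features` column of `pyspark.ml.linalg.Vector` type, and an xgboost build that includes this patch.

from xgboost.spark import SparkXGBClassifier

classifier = SparkXGBClassifier(
    features_col="features",
    label_col="label",
    enable_sparse_data_optim=True,  # new parameter added by this PR
    missing=0.0,  # required: the sparse optimization only supports missing == 0
)
model = classifier.fit(df)          # `df` is an assumed, pre-existing DataFrame
predictions = model.transform(df)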
151 changes: 130 additions & 21 deletions python-package/xgboost/spark/core.py
@@ -1,7 +1,7 @@
# type: ignore
"""Xgboost pyspark integration submodule for core code."""
# pylint: disable=fixme, too-many-ancestors, protected-access, no-member, invalid-name
# pylint: disable=too-few-public-methods
# pylint: disable=too-few-public-methods, too-many-lines
from typing import Iterator, Optional, Tuple

import numpy as np
@@ -37,14 +37,24 @@
import xgboost
from xgboost import XGBClassifier, XGBRegressor

from .data import alias, create_dmatrix_from_partitions, stack_series
from .data import (
_read_csr_matrix_from_unwrapped_spark_vec,
alias,
create_dmatrix_from_partitions,
stack_series,
)
from .model import (
SparkXGBModelReader,
SparkXGBModelWriter,
SparkXGBReader,
SparkXGBWriter,
)
from .params import HasArbitraryParamsDict, HasBaseMarginCol, HasFeaturesCols
from .params import (
HasArbitraryParamsDict,
HasBaseMarginCol,
HasEnableSparseDataOptim,
HasFeaturesCols,
)
from .utils import (
RabitContext,
_get_args_from_message_list,
@@ -75,6 +85,7 @@
"use_gpu",
"feature_names",
"features_cols",
"enable_sparse_data_optim",
]

_non_booster_params = ["missing", "n_estimators", "feature_types", "feature_weights"]
@@ -124,6 +135,7 @@ class _SparkXGBParams(
HasArbitraryParamsDict,
HasBaseMarginCol,
HasFeaturesCols,
HasEnableSparseDataOptim,
):
num_workers = Param(
Params._dummy(),
@@ -237,6 +249,7 @@ def _gen_predict_params_dict(self):
return predict_params

def _validate_params(self):
# pylint: disable=too-many-branches
init_model = self.getOrDefault(self.xgb_model)
if init_model is not None and not isinstance(init_model, Booster):
raise ValueError(
@@ -267,6 +280,26 @@ def _validate_params(self):
"If features_cols param set, then features_col param is ignored."
)

if self.getOrDefault(self.enable_sparse_data_optim):
if self.getOrDefault(self.missing) != 0.0:
# If the DMatrix is constructed from a csr / csc matrix, the inactive elements
# of that matrix are treated as missing values. In pyspark, however, it is hard
# to control which elements of a sparse vector column are active or inactive:
# some Spark transformers such as VectorAssembler may automatically compress
# vectors into dense or sparse format, and when a Spark ML vector is compressed
# into a sparse vector, all zero-valued elements become inactive.
# So we force the missing param to be 0 when the enable_sparse_data_optim
# config is True.
raise ValueError(
"If enable_sparse_data_optim is True, missing param != 0 is not supported."
)
if self.getOrDefault(self.features_cols):
raise ValueError(
"If enable_sparse_data_optim is True, you cannot set multiple feature columns "
"but you should set one feature column with values of "
"`pyspark.ml.linalg.Vector` type."
)
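A minimal sketch (not part of the diff) of why the check above pins missing to 0, assuming only that pyspark is installed:

from pyspark.ml.linalg import Vectors

dense = Vectors.dense([1.0, 0.0, 3.0])
sparse = Vectors.sparse(3, [0, 2], [1.0, 3.0])
print(dense.toArray())   # [1. 0. 3.]
print(sparse.toArray())  # [1. 0. 3.] -- same vector, but index 1 has no active entry
# Transformers such as VectorAssembler may emit either encoding. In the sparse
# encoding the zero at index 1 is an inactive element and therefore "missing"
# to XGBoost, so only missing=0.0 keeps the two encodings equivalent.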

if self.getOrDefault(self.use_gpu):
tree_method = self.getParam("tree_method")
if (
@@ -363,6 +396,52 @@ def _validate_and_convert_feature_col_as_array_col(dataset, features_col_name):
return features_array_col


def _get_unwrap_udt_fn():
try:
from pyspark.sql.functions import unwrap_udt

return unwrap_udt
except ImportError:
pass

try:
from pyspark.databricks.sql.functions import unwrap_udt

return unwrap_udt
except ImportError as exc:
raise RuntimeError(
"Cannot import pyspark `unwrap_udt` function. Please install pyspark>=3.4 "
"or run on Databricks Runtime."
) from exc


def _get_unwrapped_vec_cols(feature_col):
unwrap_udt = _get_unwrap_udt_fn()
features_unwrapped_vec_col = unwrap_udt(feature_col)

# After a `pyspark.ml.linalg.VectorUDT` type column is unwrapped, it becomes
# a pyspark struct type column with the following fields:
# - `type`: byte
# - `size`: int
# - `indices`: array<int>
# - `values`: array<double>
# For a sparse vector, the `type` field is 0, the `size` field is the vector
# length, the `indices` field is the array of active element indices, and the
# `values` field is the array of active element values.
# For a dense vector, the `type` field is 1, the `size` and `indices` fields
# are None, and the `values` field is the array of all vector element values.
return [
features_unwrapped_vec_col.type.alias("featureVectorType"),
features_unwrapped_vec_col.size.alias("featureVectorSize"),
features_unwrapped_vec_col.indices.alias("featureVectorIndices"),
# Note: the values field is a double array; cast it to a float32 array
# to speed up the subsequent repartitioning.
features_unwrapped_vec_col.values.cast(ArrayType(FloatType())).alias(
"featureVectorValues"
),
]
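A hedged sketch of what the unwrapped struct looks like, assuming pyspark >= 3.4 (or Databricks Runtime for the fallback import) and a local SparkSession:

from pyspark.ml.linalg import Vectors
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, unwrap_udt

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.createDataFrame(
    [(Vectors.sparse(3, [0, 2], [1.0, 3.0]),), (Vectors.dense([4.0, 5.0, 6.0]),)],
    ["features"],
)
df.select(unwrap_udt(col("features")).alias("f")).select("f.*").show()
# sparse row -> type=0, size=3, indices=[0, 2], values=[1.0, 3.0]
# dense row  -> type=1, size=null, indices=null, values=[4.0, 5.0, 6.0]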


class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
def __init__(self):
super().__init__()
@@ -527,17 +606,28 @@ def _fit(self, dataset):

select_cols = [label_col]
features_cols_names = None
if self.getOrDefault(self.features_cols):
features_cols_names = self.getOrDefault(self.features_cols)
features_cols = _validate_and_convert_feature_col_as_float_col_list(
dataset, features_cols_names
)
select_cols.extend(features_cols)
enable_sparse_data_optim = self.getOrDefault(self.enable_sparse_data_optim)
Contributor:
The _fit function is almost 200 lines, which is huge. Could we split this function?

Contributor Author:
I am working on another Ranker estimator PR; we can do this refactor after these feature PRs are merged. Otherwise, fixing conflicts is annoying.

if enable_sparse_data_optim:
features_col_name = self.getOrDefault(self.featuresCol)
features_col_datatype = dataset.schema[features_col_name].dataType
if not isinstance(features_col_datatype, VectorUDT):
raise ValueError(
"If enable_sparse_data_optim is True, the feature column values must be "
"`pyspark.ml.linalg.Vector` type."
)
select_cols.extend(_get_unwrapped_vec_cols(col(features_col_name)))
else:
features_array_col = _validate_and_convert_feature_col_as_array_col(
dataset, self.getOrDefault(self.featuresCol)
)
select_cols.append(features_array_col)
if self.getOrDefault(self.features_cols):
features_cols_names = self.getOrDefault(self.features_cols)
features_cols = _validate_and_convert_feature_col_as_float_col_list(
dataset, features_cols_names
)
select_cols.extend(features_cols)
else:
features_array_col = _validate_and_convert_feature_col_as_array_col(
dataset, self.getOrDefault(self.featuresCol)
)
select_cols.append(features_array_col)

if self.isDefined(self.weightCol) and self.getOrDefault(self.weightCol):
select_cols.append(
@@ -589,7 +679,7 @@ def _fit(self, dataset):
"feature_types": self.getOrDefault(self.feature_types),
"feature_names": self.getOrDefault(self.feature_names),
"feature_weights": self.getOrDefault(self.feature_weights),
"missing": self.getOrDefault(self.missing),
"missing": float(self.getOrDefault(self.missing)),
}
booster_params["nthread"] = cpu_per_task
use_gpu = self.getOrDefault(self.use_gpu)
@@ -627,7 +717,11 @@ def _train_booster(pandas_df_iter):
evals_result = {}
with RabitContext(_rabit_args, context):
dtrain, dvalid = create_dmatrix_from_partitions(
pandas_df_iter, features_cols_names, gpu_id, dmatrix_kwargs
pandas_df_iter,
features_cols_names,
gpu_id,
dmatrix_kwargs,
enable_sparse_data_optim=enable_sparse_data_optim,
)
if dvalid is not None:
dval = [(dtrain, "training"), (dvalid, "validation")]
@@ -732,6 +826,12 @@ def _get_feature_col(self, dataset) -> (list, Optional[list]):
vector or array feature type. But first we need to check features_cols
and then featuresCol
"""
if self.getOrDefault(self.enable_sparse_data_optim):
feature_col_names = None
features_col = _get_unwrapped_vec_cols(
col(self.getOrDefault(self.featuresCol))
)
return features_col, feature_col_names

feature_col_names = self.getOrDefault(self.features_cols)
features_col = []
@@ -783,15 +883,19 @@ def _transform(self, dataset):
)

features_col, feature_col_names = self._get_feature_col(dataset)
enable_sparse_data_optim = self.getOrDefault(self.enable_sparse_data_optim)

@pandas_udf("double")
def predict_udf(iterator: Iterator[pd.DataFrame]) -> Iterator[pd.Series]:
model = xgb_sklearn_model
for data in iterator:
if feature_col_names is not None:
X = data[feature_col_names]
if enable_sparse_data_optim:
X = _read_csr_matrix_from_unwrapped_spark_vec(data)
else:
X = stack_series(data[alias.data])
if feature_col_names is not None:
X = data[feature_col_names]
else:
X = stack_series(data[alias.data])

if has_base_margin:
base_margin = data[alias.margin].to_numpy()
@@ -828,6 +932,7 @@ def _xgb_cls(cls):
return XGBClassifier

def _transform(self, dataset):
# pylint: disable=too-many-locals
# Save xgb_sklearn_model and predict_params to be local variable
# to avoid the `self` object to be pickled to remote.
xgb_sklearn_model = self._xgb_sklearn_model
@@ -856,6 +961,7 @@ def transform_margin(margins: np.ndarray):
return raw_preds, class_probs

features_col, feature_col_names = self._get_feature_col(dataset)
enable_sparse_data_optim = self.getOrDefault(self.enable_sparse_data_optim)

@pandas_udf(
"rawPrediction array<double>, prediction double, probability array<double>"
@@ -865,10 +971,13 @@ def predict_udf(
) -> Iterator[pd.DataFrame]:
model = xgb_sklearn_model
for data in iterator:
if feature_col_names is not None:
X = data[feature_col_names]
if enable_sparse_data_optim:
X = _read_csr_matrix_from_unwrapped_spark_vec(data)
else:
X = stack_series(data[alias.data])
if feature_col_names is not None:
X = data[feature_col_names]
else:
X = stack_series(data[alias.data])

if has_base_margin:
base_margin = stack_series(data[alias.margin])
72 changes: 70 additions & 2 deletions python-package/xgboost/spark/data.py
@@ -4,6 +4,7 @@

import numpy as np
import pandas as pd
from scipy.sparse import csr_matrix
from xgboost.compat import concat

from xgboost import DataIter, DeviceQuantileDMatrix, DMatrix
@@ -101,11 +102,55 @@ def reset(self) -> None:
self._iter = 0


def _read_csr_matrix_from_unwrapped_spark_vec(part: pd.DataFrame) -> csr_matrix:
# variables for constructing csr_matrix
csr_indices_list, csr_indptr_list, csr_values_list = [], [0], []

n_features = 0

for vec_type, vec_size_, vec_indices, vec_values in zip(
part.featureVectorType,
part.featureVectorSize,
part.featureVectorIndices,
part.featureVectorValues,
):
if vec_type == 0:
Contributor:
Please correct me if I'm wrong: now that missing is 0, do we still really need the sparse vector? Per my understanding, if one instance has a missing value, then the whole instance will be removed.

Contributor Author:
@trivialfis Is that true? If so, training on sparse data makes no sense; almost all instances would be removed.

Contributor Author:
I think not, @wbo4958. You can check my test cases test_regressor_with_sparse_optim and test_classifier_with_sparse_optim: every training instance contains the missing value 0, yet the trained model gives good prediction results when transforming.

# sparse vector
vec_size = int(vec_size_)
csr_indices = vec_indices
csr_values = vec_values
else:
# dense vector
# Note: according to the Spark ML VectorUDT format, when the type field is 1
# the size field is empty, so we use the length of the values field as the
# vector length.
vec_size = len(vec_values)
csr_indices = np.arange(vec_size, dtype=np.int32)
csr_values = vec_values

if n_features == 0:
n_features = vec_size
assert n_features == vec_size

csr_indices_list.append(csr_indices)
csr_indptr_list.append(csr_indptr_list[-1] + len(csr_indices))
csr_values_list.append(csr_values)

csr_indptr_arr = np.array(csr_indptr_list)
csr_indices_arr = np.concatenate(csr_indices_list)
csr_values_arr = np.concatenate(csr_values_list)

return csr_matrix(
(csr_values_arr, csr_indices_arr, csr_indptr_arr), shape=(len(part), n_features)
)
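A minimal, hedged sketch of the helper on a hand-built partition that mimics the unwrapped vector columns (one sparse row, one dense row). The DMatrix call at the end illustrates the point raised in the review thread above: with missing=0.0 the zero entries are treated as missing per cell, but no row is dropped.

import numpy as np
import pandas as pd
import xgboost
from xgboost.spark.data import _read_csr_matrix_from_unwrapped_spark_vec

part = pd.DataFrame(
    {
        "featureVectorType": [0, 1],
        "featureVectorSize": [3, None],
        "featureVectorIndices": [np.array([0, 2], dtype=np.int32), None],
        "featureVectorValues": [
            np.array([1.0, 3.0], dtype=np.float32),
            np.array([4.0, 5.0, 6.0], dtype=np.float32),
        ],
    }
)
csr = _read_csr_matrix_from_unwrapped_spark_vec(part)
print(csr.toarray())                       # [[1. 0. 3.], [4. 5. 6.]]
dtrain = xgboost.DMatrix(csr, missing=0.0)
print(dtrain.num_row(), dtrain.num_col())  # 2 3 -- both rows survive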


def create_dmatrix_from_partitions(
iterator: Iterator[pd.DataFrame],
feature_cols: Optional[Sequence[str]],
gpu_id: Optional[int],
kwargs: Dict[str, Any], # use dict to make sure this parameter is passed.
enable_sparse_data_optim: bool,
) -> Tuple[DMatrix, Optional[DMatrix]]:
"""Create DMatrix from spark data partitions. This is not particularly efficient as
we need to convert the pandas series format to numpy then concatenate all the data.
@@ -118,7 +163,7 @@ def create_dmatrix_from_partitions(
Metainfo for DMatrix.

"""

# pylint: disable=too-many-locals, too-many-statements
train_data: Dict[str, List[np.ndarray]] = defaultdict(list)
valid_data: Dict[str, List[np.ndarray]] = defaultdict(list)

@@ -139,6 +184,23 @@ def append_m(part: pd.DataFrame, name: str, is_valid: bool) -> None:
else:
train_data[name].append(array)

def append_m_sparse(part: pd.DataFrame, name: str, is_valid: bool) -> None:
nonlocal n_features

if name == alias.data or name in part.columns:
if name == alias.data:
array = _read_csr_matrix_from_unwrapped_spark_vec(part)
if n_features == 0:
n_features = array.shape[1]
assert n_features == array.shape[1]
else:
array = part[name]

if is_valid:
valid_data[name].append(array)
else:
train_data[name].append(array)

def append_dqm(part: pd.DataFrame, name: str, is_valid: bool) -> None:
"""Preprocessing for DeviceQuantileDMatrix"""
nonlocal n_features
@@ -164,13 +226,19 @@ def make(values: Dict[str, List[np.ndarray]], kwargs: Dict[str, Any]) -> DMatrix
label = concat_or_none(values.get(alias.label, None))
weight = concat_or_none(values.get(alias.weight, None))
margin = concat_or_none(values.get(alias.margin, None))

return DMatrix(
data=data, label=label, weight=weight, base_margin=margin, **kwargs
)

is_dmatrix = feature_cols is None
if is_dmatrix:
cache_partitions(iterator, append_m)
if enable_sparse_data_optim:
append_fn = append_m_sparse
assert "missing" in kwargs and kwargs["missing"] == 0.0
else:
append_fn = append_m
cache_partitions(iterator, append_fn)
dtrain = make(train_data, kwargs)
else:
cache_partitions(iterator, append_dqm)
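A closing, hedged illustration (not the PR's exact code path) of how per-partition CSR chunks can be stacked row-wise before building the training DMatrix, which is conceptually what the sparse branch above does with the matrices collected by append_m_sparse:

import numpy as np
import xgboost
from scipy.sparse import csr_matrix, vstack

part1 = csr_matrix(np.array([[1.0, 0.0, 3.0], [0.0, 2.0, 0.0]]))
part2 = csr_matrix(np.array([[4.0, 5.0, 6.0]]))
data = vstack([part1, part2], format="csr")  # shape (3, 3)
labels = np.array([0.0, 1.0, 1.0])
dtrain = xgboost.DMatrix(data=data, label=labels, missing=0.0)
print(dtrain.num_row(), dtrain.num_col())    # 3 3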