[pyspark] Make Xgboost estimator support using sparse matrix as optimization #8145

Merged: 18 commits, Aug 18, 2022
99 changes: 86 additions & 13 deletions python-package/xgboost/spark/core.py
@@ -44,7 +44,9 @@
SparkXGBReader,
SparkXGBWriter,
)
from .params import HasArbitraryParamsDict, HasBaseMarginCol, HasFeaturesCols
from .params import (
HasArbitraryParamsDict, HasBaseMarginCol, HasFeaturesCols, HasEnableSparseDataOptim
)
from .utils import (
RabitContext,
_get_args_from_message_list,
@@ -124,6 +126,7 @@ class _SparkXGBParams(
HasArbitraryParamsDict,
HasBaseMarginCol,
HasFeaturesCols,
HasEnableSparseDataOptim,
):
num_workers = Param(
Params._dummy(),
@@ -363,6 +366,23 @@ def _validate_and_convert_feature_col_as_array_col(dataset, features_col_name):
return features_array_col


def _get_unwrap_udt_fn():
try:
from pyspark.sql.functions import unwrap_udt
return unwrap_udt
except ImportError:
pass

try:
from pyspark.databricks.sql.functions import unwrap_udt
return unwrap_udt
except ImportError:
raise RuntimeError(
"Cannot import pyspark `unwrap_udt` function. Please install pyspark>=3.4 "
"or run on Databricks Runtime."
)


class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
def __init__(self):
super().__init__()
@@ -527,17 +547,69 @@ def _fit(self, dataset):

select_cols = [label_col]
features_cols_names = None
if self.getOrDefault(self.features_cols):
features_cols_names = self.getOrDefault(self.features_cols)
features_cols = _validate_and_convert_feature_col_as_float_col_list(
dataset, features_cols_names
)
select_cols.extend(features_cols)
enable_sparse_data_optim = self.getOrDefault(self.enable_sparse_data_optim)
Contributor: The _fit function is almost 200 lines, which is huge. Could we split this function?

Contributor (author): I am working on another Ranker estimator PR. We can do this refactoring after these feature PRs are merged; otherwise fixing conflicts is annoying.

if enable_sparse_data_optim:
from pyspark.ml.linalg import VectorUDT

if self.getOrDefault(self.missing) != 0.0:
Contributor: Could we put this check into _validate_params?

Contributor (author): @trivialfis According to your comment (#8145 (comment)), it seems we don't need to add this restriction?

Member: I think we need the restriction that missing must be 0. Otherwise there would be two kinds of missing/invalid values: zeros removed by Spark and the missing value removed by XGBoost.

Member: Just to be sure that we are on the same page: for the xgboost.DMatrix class, CSR is the same as dense. The difference is in the Spark interface. Feel free to add related documentation.

# If a DMatrix is constructed from a csr / csc matrix, inactive elements in the
# matrix are regarded as missing values. However, in pyspark it is hard to control
# which elements of a sparse vector column are active or inactive: some Spark
# transformers such as VectorAssembler may compress vectors to dense or sparse
# format automatically, and when a Spark ML vector is compressed to a sparse
# vector, all zero-valued elements become inactive. So we force the missing param
# to be 0 when the enable_sparse_data_optim config is True.
raise ValueError(
"If enable_sparse_data_optim is True, missing param != 0 is not supported."
)

if self.getOrDefault(self.features_cols):
Member: @wbo4958 Would you like to take a look at this?
raise ValueError(
"If enable_sparse_data_optim is True, you cannot set multiple feature columns "
"but you should set one feature column with values of "
"`pyspark.ml.linalg.Vector` type."
)
features_col_name = self.getOrDefault(self.featuresCol)
features_col_datatype = dataset.schema[features_col_name].dataType
if not isinstance(features_col_datatype, VectorUDT):
raise ValueError(
"If enable_sparse_data_optim is True, the feature column values must be "
"`pyspark.ml.linalg.Vector` type."
)

unwrap_udt = _get_unwrap_udt_fn()
features_unwrapped_vec_col = unwrap_udt(col(features_col_name))

# After a `pyspark.ml.linalg.VectorUDT` type column is unwrapped, it becomes
# a pyspark struct type column with the following struct fields:
# - `type`: byte
# - `size`: int
# - `indices`: array<int>
# - `values`: array<double>
# For a sparse vector, the `type` field is 0, the `size` field is the vector length,
# the `indices` field is the array of active element indices, and the `values` field
# is the array of active element values.
# For a dense vector, the `type` field is 1, the `size` and `indices` fields are None,
# and the `values` field is the array of all vector element values.

Member: Thank you for the detailed comments!
select_cols.extend([
features_unwrapped_vec_col.type.alias("featureVectorType"),
features_unwrapped_vec_col.size.alias("featureVectorSize"),
features_unwrapped_vec_col.indices.alias("featureVectorIndices"),
features_unwrapped_vec_col.values.alias("featureVectorValues"),
])
else:
features_array_col = _validate_and_convert_feature_col_as_array_col(
dataset, self.getOrDefault(self.featuresCol)
)
select_cols.append(features_array_col)
if self.getOrDefault(self.features_cols):
features_cols_names = self.getOrDefault(self.features_cols)
features_cols = _validate_and_convert_feature_col_as_float_col_list(
dataset, features_cols_names
)
select_cols.extend(features_cols)
else:
features_array_col = _validate_and_convert_feature_col_as_array_col(
dataset, self.getOrDefault(self.featuresCol)
)
select_cols.append(features_array_col)
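
For illustration, a minimal standalone sketch of what the unwrapped struct described in the comment above looks like. This is only a sketch, assuming PySpark >= 3.4, an active SparkSession named `spark`, and a toy column named "features"; it is not part of the diff.

from pyspark.ml.linalg import Vectors
from pyspark.sql.functions import col, unwrap_udt

df = spark.createDataFrame(
    [(Vectors.sparse(3, [0, 2], [1.0, 3.0]),), (Vectors.dense([4.0, 5.0, 6.0]),)],
    ["features"],
)
# Each vector becomes a struct with fields `type`, `size`, `indices`, `values`;
# type 0 marks a sparse vector, type 1 a dense vector.
unwrapped = df.select(unwrap_udt(col("features")).alias("raw"))
unwrapped.select("raw.type", "raw.size", "raw.indices", "raw.values").show()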

if self.isDefined(self.weightCol) and self.getOrDefault(self.weightCol):
select_cols.append(
@@ -589,7 +661,7 @@ def _fit(self, dataset):
"feature_types": self.getOrDefault(self.feature_types),
"feature_names": self.getOrDefault(self.feature_names),
"feature_weights": self.getOrDefault(self.feature_weights),
"missing": self.getOrDefault(self.missing),
"missing": float(self.getOrDefault(self.missing)),
}
booster_params["nthread"] = cpu_per_task
use_gpu = self.getOrDefault(self.use_gpu)
@@ -627,7 +699,8 @@ def _train_booster(pandas_df_iter):
evals_result = {}
with RabitContext(_rabit_args, context):
dtrain, dvalid = create_dmatrix_from_partitions(
pandas_df_iter, features_cols_names, gpu_id, dmatrix_kwargs
pandas_df_iter, features_cols_names, gpu_id, dmatrix_kwargs,
enable_sparse_data_optim=enable_sparse_data_optim,
)
if dvalid is not None:
dval = [(dtrain, "training"), (dvalid, "validation")]
78 changes: 77 additions & 1 deletion python-package/xgboost/spark/data.py
@@ -106,6 +106,7 @@ def create_dmatrix_from_partitions(
feature_cols: Optional[Sequence[str]],
gpu_id: Optional[int],
kwargs: Dict[str, Any], # use dict to make sure this parameter is passed.
enable_sparse_data_optim: bool,
) -> Tuple[DMatrix, Optional[DMatrix]]:
"""Create DMatrix from spark data partitions. This is not particularly efficient as
we need to convert the pandas series format to numpy then concatenate all the data.
@@ -139,6 +140,76 @@ def append_m(part: pd.DataFrame, name: str, is_valid: bool) -> None:
else:
train_data[name].append(array)

def append_m_sparse(part: pd.DataFrame, name: str, is_valid: bool) -> None:
from scipy.sparse import csr_matrix
nonlocal n_features

if name == alias.data or name in part.columns:
if name == alias.data:
# variables for constructing csr_matrix
csr_indices_list, csr_indptr_list, csr_values_list = [], [0], []

for vec_type, vec_size_, vec_indices, vec_values in zip(
part.featureVectorType,
part.featureVectorSize,
part.featureVectorIndices,
part.featureVectorValues
):
if vec_type == 0:
# sparse vector
vec_size = int(vec_size_)
csr_indices = vec_indices
csr_values = vec_values
else:
# dense vector
# Note: According to spark ML VectorUDT format,
# when type field is 1, the size field is also empty.
# we need to check the values field to get vector length.
vec_size = len(vec_values)
csr_indices = np.arange(vec_size, dtype=np.int32)
csr_values = vec_values

if n_features == 0:
n_features = vec_size
assert n_features == vec_size

# remove zero elements from csr_indices / csr_values
Member: Is this necessary? When missing is set to 0, XGBoost can remove those values.

Contributor (author): I remember that in the DMatrix ctor, if the data argument is a csc/csr matrix, it ignores the "missing" argument and regards all inactive elements of the sparse matrix as missing values (ref: #341 (comment)). If so, keeping the zero elements or removing them represents two different semantics: keeping them means they are regarded as zero-valued features, while removing them means they are regarded as missing values. Is my understanding correct?

Member: The ctor for a CSR matrix should be able to handle missing values (but not for CSC, which would raise a warning):
def _warn_unused_missing(data: DataType, missing: Optional[FloatCompatible]) -> None:

Member: Which is a good reminder that I should clarify the difference.

Contributor (author): @trivialfis OK, so for missing handling, CSR is the same as dense input (it respects the "missing" param), but CSC is different (it ignores the "missing" param and regards inactive elements as missing), right? We should document this in the DMatrix doc.

Member: I will update the CSC implementation instead.

Contributor: @trivialfis What does "updating the CSC implementation" mean? For the current implementation, does it indeed remove the whole instance if one element of the instance is a missing value?
n_actives = len(csr_indices)
nz_csr_indices = np.empty(n_actives, dtype=np.int32)
nz_csr_values = np.empty(n_actives, dtype=np.float64)

active_i = 0
nz_i = 0
while active_i < n_actives:
if csr_values[active_i] != 0.0:
nz_csr_indices[nz_i] = csr_indices[active_i]
nz_csr_values[nz_i] = csr_values[active_i]
nz_i += 1
active_i += 1

nz_csr_indices = nz_csr_indices[:nz_i]
nz_csr_values = nz_csr_values[:nz_i]

csr_indices_list.append(nz_csr_indices)
csr_indptr_list.append(csr_indptr_list[-1] + len(nz_csr_indices))
csr_values_list.append(nz_csr_values)

csr_indptr_arr = np.array(csr_indptr_list)
csr_indices_arr = np.concatenate(csr_indices_list)
csr_values_arr = np.concatenate(csr_values_list)

array = csr_matrix(
(csr_values_arr, csr_indices_arr, csr_indptr_arr),
shape=(len(part), n_features)
)
else:
array = part[name]

if is_valid:
valid_data[name].append(array)
else:
train_data[name].append(array)
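
To make the semantics discussed in the thread above concrete, here is a small standalone sketch (not taken from the diff; the shapes and values are made up) of building a DMatrix from a scipy CSR matrix with explicit zeros already stripped and missing set to 0.0, mirroring what append_m_sparse produces:

import numpy as np
from scipy.sparse import csr_matrix
import xgboost as xgb

# Two rows, three features: row 0 stores values at columns 0 and 2,
# row 1 stores a value at column 1; explicit zeros have been dropped,
# as append_m_sparse does above.
values = np.array([3.0, 7.0, 5.0])
indices = np.array([0, 2, 1], dtype=np.int32)
indptr = np.array([0, 2, 3])
X = csr_matrix((values, indices, indptr), shape=(2, 3))

# The Spark layer forces missing=0.0 when enable_sparse_data_optim is on,
# so unstored entries and explicit zeros carry the same meaning to XGBoost.
dtrain = xgb.DMatrix(X, label=np.array([0.0, 1.0]), missing=0.0)
print(dtrain.num_row(), dtrain.num_col())  # 2 3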

def append_dqm(part: pd.DataFrame, name: str, is_valid: bool) -> None:
"""Preprocessing for DeviceQuantileDMatrix"""
nonlocal n_features
@@ -164,13 +235,18 @@ def make(values: Dict[str, List[np.ndarray]], kwargs: Dict[str, Any]) -> DMatrix
label = concat_or_none(values.get(alias.label, None))
weight = concat_or_none(values.get(alias.weight, None))
margin = concat_or_none(values.get(alias.margin, None))

return DMatrix(
data=data, label=label, weight=weight, base_margin=margin, **kwargs
)

is_dmatrix = feature_cols is None
if is_dmatrix:
cache_partitions(iterator, append_m)
if enable_sparse_data_optim:
append_fn = append_m_sparse
else:
append_fn = append_m
cache_partitions(iterator, append_fn)
dtrain = make(train_data, kwargs)
else:
cache_partitions(iterator, append_dqm)
22 changes: 22 additions & 0 deletions python-package/xgboost/spark/params.py
@@ -50,3 +50,25 @@ class HasFeaturesCols(Params):
def __init__(self):
super().__init__()
self._setDefault(features_cols=[])


class HasEnableSparseDataOptim(Params):
"""
This is a Params based class that is extended by _SparkXGBParams
and holds the variable to store the boolean config of enabling sparse data optimization.
"""

enable_sparse_data_optim = Param(
Params._dummy(),
"enable_sparse_data_optim",
"This stores the boolean config of enabling sparse data optimization, if enabled, "
"Xgboost DMatrix object will be constructed from sparse matrix instead of "
"dense matrix. This config is disabled by default. If most of examples in your "
"training dataset contains sparse features, we suggest to enable this config.",
typeConverter=TypeConverters.toBoolean,
)

def __init__(self):
super().__init__()
self._setDefault(enable_sparse_data_optim=False)
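
For context, a minimal usage sketch of the new config. This is illustrative only; it assumes a DataFrame `train_df` with a "features" column of pyspark.ml.linalg.Vector values and a numeric "label" column.

from xgboost.spark import SparkXGBClassifier

classifier = SparkXGBClassifier(
    features_col="features",
    label_col="label",
    enable_sparse_data_optim=True,
    # missing is required to be 0 when enable_sparse_data_optim is True
    # (see the check added in _fit above).
    missing=0.0,
)
model = classifier.fit(train_df)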
30 changes: 28 additions & 2 deletions tests/python/test_spark/test_data.py
@@ -5,6 +5,7 @@
import pandas as pd
import pytest
import testing as tm
from unittest import mock

if tm.no_spark()["condition"]:
pytest.skip(msg=tm.no_spark()["reason"], allow_module_level=True)
@@ -62,10 +63,10 @@ def run_dmatrix_ctor(is_dqm: bool) -> None:
kwargs = {"feature_types": feature_types}
if is_dqm:
cols = [f"feat-{i}" for i in range(n_features)]
train_Xy, valid_Xy = create_dmatrix_from_partitions(iter(dfs), cols, 0, kwargs)
train_Xy, valid_Xy = create_dmatrix_from_partitions(iter(dfs), cols, 0, kwargs, False)
else:
train_Xy, valid_Xy = create_dmatrix_from_partitions(
iter(dfs), None, None, kwargs
iter(dfs), None, None, kwargs, False
)

assert valid_Xy is not None
@@ -100,3 +101,28 @@ def run_dmatrix_ctor(is_dqm: bool) -> None:

def test_dmatrix_ctor() -> None:
run_dmatrix_ctor(False)


def test_dmatrix_ctor_with_sparse_optim():
from scipy.sparse import csr_matrix
pd1 = pd.DataFrame({
"featureVectorType": [0, 1],
"featureVectorSize": [3, None],
"featureVectorIndices": [np.array([0, 2], dtype=np.int32), None],
"featureVectorValues": [np.array([3.0, 0.0], dtype=np.float64), np.array([13.0, 14.0, 0.0], dtype=np.float64)],
})
pd2 = pd.DataFrame({
"featureVectorType": [1, 0],
"featureVectorSize": [None, 3],
"featureVectorIndices": [None, np.array([1, 2], dtype=np.int32)],
"featureVectorValues": [np.array([0.0, 24.0, 25.0], dtype=np.float64), np.array([0.0, 35.0], dtype=np.float64)],
})

with mock.patch("xgboost.core.DMatrix.__init__", return_value=None) as mock_dmatrix_ctor:
create_dmatrix_from_partitions([pd1, pd2], None, None, {}, True)
sm = mock_dmatrix_ctor.call_args_list[0][1]["data"]
assert isinstance(sm, csr_matrix)
np.testing.assert_array_equal(sm.data, [3, 13, 14, 24, 25, 35])
np.testing.assert_array_equal(sm.indptr, [0, 1, 3, 5, 6])
np.testing.assert_array_equal(sm.indices, [0, 0, 1, 1, 2, 2])
assert sm.shape == (4, 3)
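
For reference, the expected values follow from stripping explicit zeros row by row: the four rows keep 1, 2, 2, and 1 nonzero entries ([3.0], [13.0, 14.0], [24.0, 25.0], [35.0]) at columns [0], [0, 1], [1, 2], and [2], which gives data [3, 13, 14, 24, 25, 35], indices [0, 0, 1, 1, 2, 2], indptr [0, 1, 3, 5, 6], and shape (4, 3).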