[pyspark] Use quantile dmatrix. #8284

Merged
merged 21 commits on Oct 12, 2022
2 changes: 1 addition & 1 deletion doc/parameter.rst
@@ -349,7 +349,7 @@ Specify the learning task and the corresponding learning objective. The objectiv
- ``reg:squaredlogerror``: regression with squared log loss :math:`\frac{1}{2}[log(pred + 1) - log(label + 1)]^2`. All input labels are required to be greater than -1. Also, see metric ``rmsle`` for possible issue with this objective.
- ``reg:logistic``: logistic regression.
- ``reg:pseudohubererror``: regression with Pseudo Huber loss, a twice differentiable alternative to absolute loss.
- ``reg:absoluteerror``: Regression with L1 error. When tree model is used, leaf value is refreshed after tree construction.
- ``reg:absoluteerror``: Regression with L1 error. When tree model is used, leaf value is refreshed after tree construction. If used in distributed training, the leaf value is calculated as the mean value from all workers, which is not guaranteed to be optimal.
- ``binary:logistic``: logistic regression for binary classification, output probability
- ``binary:logitraw``: logistic regression for binary classification, output score before logistic transformation
- ``binary:hinge``: hinge loss for binary classification. This makes predictions of 0 or 1, rather than producing probabilities.
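To make the new note concrete, here is a minimal, illustrative sketch of training with this objective through the single-node Python API; the data and the other parameter values are made up, only the ``reg:absoluteerror`` name comes from the documentation change above:

    import numpy as np
    import xgboost as xgb

    rng = np.random.default_rng(0)
    X = rng.normal(size=(256, 8))
    y = rng.normal(size=256)

    dtrain = xgb.DMatrix(X, label=y)
    # L1 (absolute error) objective: leaf values are refreshed after each tree is
    # built; in distributed training the refreshed value is the mean over workers.
    booster = xgb.train(
        {"objective": "reg:absoluteerror", "tree_method": "hist"},
        dtrain,
        num_boost_round=10,
    )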
43 changes: 31 additions & 12 deletions python-package/xgboost/core.py
@@ -105,6 +105,11 @@ def from_cstr_to_pystr(data: CStrPptr, length: c_bst_ulong) -> List[str]:
return res


def make_jcargs(**kwargs: Any) -> bytes:
"Make JSON-based arguments for C functions."
return from_pystr_to_cstr(json.dumps(kwargs))
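A small before/after sketch of what this helper replaces, based only on the call sites touched in this diff (``from_pystr_to_cstr`` and ``make_jcargs`` are the module internals shown here):

    import json
    from xgboost.core import from_pystr_to_cstr, make_jcargs

    # Before this change: hand-build the dict and serialize it at every call site.
    config_old = from_pystr_to_cstr(json.dumps({"format": "json"}))
    # After: keyword arguments become the JSON object; the resulting bytes are identical.
    config_new = make_jcargs(format="json")
    assert config_old == config_new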


IterRange = TypeVar("IterRange", Optional[Tuple[int, int]], Tuple[int, int])


@@ -1243,7 +1248,7 @@ def __init__(self) -> None: # pylint: disable=super-init-not-called
def _set_data_from_cuda_interface(self, data: DataType) -> None:
"""Set data from CUDA array interface."""
interface = data.__cuda_array_interface__
interface_str = bytes(json.dumps(interface, indent=2), "utf-8")
interface_str = bytes(json.dumps(interface), "utf-8")
_check_call(
_LIB.XGProxyDMatrixSetDataCudaArrayInterface(self.handle, interface_str)
)
@@ -1344,6 +1349,26 @@ def __init__( # pylint: disable=super-init-not-called
"Only one of the eval_qid or eval_group for each evaluation "
"dataset should be provided."
)
if isinstance(data, DataIter):
if any(
info is not None
for info in (
label,
weight,
base_margin,
feature_names,
feature_types,
group,
qid,
label_lower_bound,
label_upper_bound,
feature_weights,
)
):
raise ValueError(
"If data iterator is used as input, data like label should be "
"specified as batch argument."
)

self._init(
data,
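The new check above means that, when a ``DataIter`` is passed, meta information such as labels has to be supplied through the ``input_data`` callback inside the iterator rather than through the ``QuantileDMatrix`` constructor. A minimal sketch of that contract, assuming a single in-memory batch; the ``SingleBatch`` class and the random data are illustrative only:

    import numpy as np
    import xgboost as xgb

    class SingleBatch(xgb.DataIter):
        """Yield one in-memory batch; meta info is passed per batch."""

        def __init__(self, X: np.ndarray, y: np.ndarray) -> None:
            self._X, self._y = X, y
            self._it = 0
            super().__init__()

        def next(self, input_data) -> int:
            if self._it == 1:
                return 0  # no more batches
            input_data(data=self._X, label=self._y)  # label goes here, not to the ctor
            self._it += 1
            return 1

        def reset(self) -> None:
            self._it = 0

    X, y = np.random.rand(64, 4), np.random.rand(64)
    Xy = xgb.QuantileDMatrix(SingleBatch(X, y))        # OK
    # xgb.QuantileDMatrix(SingleBatch(X, y), label=y)  # raises ValueError per the check above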
@@ -1392,12 +1417,9 @@ def _init(
"in iterator to fix this error."
)

args = {
"nthread": self.nthread,
"missing": self.missing,
"max_bin": self.max_bin,
}
config = from_pystr_to_cstr(json.dumps(args))
config = make_jcargs(
nthread=self.nthread, missing=self.missing, max_bin=self.max_bin
)
ret = _LIB.XGQuantileDMatrixCreateFromCallback(
None,
it.proxy.handle,
@@ -2362,7 +2384,7 @@ def save_raw(self, raw_format: str = "deprecated") -> bytearray:
"""
length = c_bst_ulong()
cptr = ctypes.POINTER(ctypes.c_char)()
config = from_pystr_to_cstr(json.dumps({"format": raw_format}))
config = make_jcargs(format=raw_format)
_check_call(
_LIB.XGBoosterSaveModelToBuffer(
self.handle, config, ctypes.byref(length), ctypes.byref(cptr)
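For context, a usage sketch of the public method whose config now goes through ``make_jcargs``; the tiny training run is only there to produce a booster, and ``"json"`` is one of the documented raw formats:

    import numpy as np
    import xgboost as xgb

    dtrain = xgb.DMatrix(np.random.rand(32, 4), label=np.random.rand(32))
    booster = xgb.train({}, dtrain, num_boost_round=2)

    raw = booster.save_raw(raw_format="json")  # bytearray holding the model
    restored = xgb.Booster()
    restored.load_model(bytearray(raw))        # round-trip from the in-memory buffer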
@@ -2557,9 +2579,6 @@ def get_score(
`n_classes`, otherwise they're scalars.
"""
fmap = os.fspath(os.path.expanduser(fmap))
args = from_pystr_to_cstr(
json.dumps({"importance_type": importance_type, "feature_map": fmap})
)
features = ctypes.POINTER(ctypes.c_char_p)()
scores = ctypes.POINTER(ctypes.c_float)()
n_out_features = c_bst_ulong()
@@ -2569,7 +2588,7 @@
_check_call(
_LIB.XGBoosterFeatureScore(
self.handle,
args,
make_jcargs(importance_type=importance_type, feature_map=fmap),
ctypes.byref(n_out_features),
ctypes.byref(features),
ctypes.byref(out_dim),
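Similarly, a short sketch of the corresponding public call; ``booster`` is assumed to be any trained ``Booster``:

    scores = booster.get_score(importance_type="gain")
    # e.g. {"f0": 1.23, "f2": 0.57}; for multi-class models each value is a list per class.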
13 changes: 9 additions & 4 deletions python-package/xgboost/dask.py
@@ -573,6 +573,7 @@ def __init__(
label_upper_bound: Optional[List[Any]] = None,
feature_names: Optional[FeatureNames] = None,
feature_types: Optional[Union[Any, List[Any]]] = None,
feature_weights: Optional[Any] = None,
) -> None:
self._data = data
self._label = label
@@ -583,6 +584,7 @@ def __init__(
self._label_upper_bound = label_upper_bound
self._feature_names = feature_names
self._feature_types = feature_types
self._feature_weights = feature_weights

assert isinstance(self._data, collections.abc.Sequence)

@@ -633,6 +635,7 @@ def next(self, input_data: Callable) -> int:
label_upper_bound=self._get("_label_upper_bound"),
feature_names=feature_names,
feature_types=self._feature_types,
feature_weights=self._feature_weights,
)
self._iter += 1
return 1
@@ -731,19 +734,21 @@ def _create_quantile_dmatrix(
return d

unzipped_dict = _get_worker_parts(parts)
it = DaskPartitionIter(**unzipped_dict)
it = DaskPartitionIter(
**unzipped_dict,
feature_types=feature_types,
feature_names=feature_names,
feature_weights=feature_weights,
)

dmatrix = QuantileDMatrix(
it,
missing=missing,
feature_names=feature_names,
feature_types=feature_types,
nthread=nthread,
max_bin=max_bin,
ref=ref,
enable_categorical=enable_categorical,
)
dmatrix.set_info(feature_weights=feature_weights)
return dmatrix


13 changes: 6 additions & 7 deletions python-package/xgboost/spark/core.py
@@ -747,6 +747,7 @@ def _fit(self, dataset):
k: v for k, v in train_call_kwargs_params.items() if v is not None
}
dmatrix_kwargs = {k: v for k, v in dmatrix_kwargs.items() if v is not None}
use_qdm = booster_params.get("tree_method") in ("hist", "gpu_hist")

def _train_booster(pandas_df_iter):
"""Takes in an RDD partition and outputs a booster for that partition after
@@ -759,17 +760,14 @@ def _train_booster(pandas_df_iter):
context.barrier()

gpu_id = None

if use_qdm and (booster_params.get("max_bin", None) is not None):
dmatrix_kwargs["max_bin"] = booster_params["max_bin"]

if use_gpu:
gpu_id = context.partitionId() if is_local else _get_gpu_id(context)
booster_params["gpu_id"] = gpu_id

# max_bin is needed for qdm
if (
features_cols_names is not None
and booster_params.get("max_bin", None) is not None
):
dmatrix_kwargs["max_bin"] = booster_params["max_bin"]

_rabit_args = {}
if context.partitionId() == 0:
get_logger("XGBoostPySpark").info(
@@ -791,6 +789,7 @@
pandas_df_iter,
features_cols_names,
gpu_id,
use_qdm,
dmatrix_kwargs,
enable_sparse_data_optim=enable_sparse_data_optim,
has_validation_col=has_validation_col,
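From the user's side, the ``QuantileDMatrix`` path is now chosen purely from the booster parameters: it is taken when ``tree_method`` is ``hist`` or ``gpu_hist``. A hypothetical sketch with the PySpark estimator; the column names and the Spark ``DataFrame`` ``df`` are assumptions, not part of this diff:

    from xgboost.spark import SparkXGBRegressor

    # With tree_method="hist" (or "gpu_hist"), training partitions are streamed
    # into a QuantileDMatrix instead of being materialized as a plain DMatrix.
    reg = SparkXGBRegressor(
        features_col="features",
        label_col="label",
        tree_method="hist",
        max_bin=256,  # forwarded to the QuantileDMatrix via dmatrix_kwargs
    )
    model = reg.fit(df)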
96 changes: 74 additions & 22 deletions python-package/xgboost/spark/data.py
@@ -1,13 +1,13 @@
"""Utilities for processing spark partitions."""
from collections import defaultdict, namedtuple
from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Tuple
from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Tuple, Union

import numpy as np
import pandas as pd
from scipy.sparse import csr_matrix
from xgboost.compat import concat

from xgboost import DataIter, DeviceQuantileDMatrix, DMatrix
from xgboost import DataIter, DMatrix, QuantileDMatrix

from .utils import get_logger # type: ignore

@@ -67,10 +67,13 @@ def make_blob(part: pd.DataFrame, is_valid: bool) -> None:
class PartIter(DataIter):
"""Iterator for creating Quantile DMatrix from partitions."""

def __init__(self, data: Dict[str, List], device_id: Optional[int]) -> None:
def __init__(
self, data: Dict[str, List], device_id: Optional[int], **kwargs: Any
) -> None:
self._iter = 0
self._device_id = device_id
self._data = data
self._kwargs = kwargs

super().__init__()

@@ -98,6 +101,7 @@ def next(self, input_data: Callable) -> int:
weight=self._fetch(self._data.get(alias.weight, None)),
base_margin=self._fetch(self._data.get(alias.margin, None)),
qid=self._fetch(self._data.get(alias.qid, None)),
**self._kwargs,
)
self._iter += 1
return 1
@@ -153,6 +157,7 @@ def create_dmatrix_from_partitions( # pylint: disable=too-many-arguments
iterator: Iterator[pd.DataFrame],
feature_cols: Optional[Sequence[str]],
gpu_id: Optional[int],
use_qdm: bool,
kwargs: Dict[str, Any], # use dict to make sure this parameter is passed.
enable_sparse_data_optim: bool,
has_validation_col: bool,
@@ -164,9 +169,22 @@
----------
iterator :
Pyspark partition iterator.
feature_cols:
A sequence of feqture names, used only when rapids plugin is enabled.
Inline review (Contributor):
feqture -> feature. This parameter can be used even without the rapids plugin.

Reply (Member, author):
I haven't really done any testing for it, and it will likely trigger an assert error. We have both DMatrix and QuantileDMatrix to support; I will leave that to the next release.

gpu_id:
Device ordinal, used when GPU is enabled.
use_qdm :
Whether QuantileDMatrix should be used instead of DMatrix.
kwargs :
Metainfo for DMatrix.

enable_sparse_data_optim :
Whether sparse data should be unwrapped
has_validation:
Whether there's validation data.

Returns
-------
Training DMatrix and an optional validation DMatrix.
"""
# pylint: disable=too-many-locals, too-many-statements
train_data: Dict[str, List[np.ndarray]] = defaultdict(list)
@@ -207,15 +225,15 @@ def append_m_sparse(part: pd.DataFrame, name: str, is_valid: bool) -> None:
train_data[name].append(array)

def append_dqm(part: pd.DataFrame, name: str, is_valid: bool) -> None:
"""Preprocessing for DeviceQuantileDMatrix"""
"""Preprocessing for QuantileDMatrix"""
nonlocal n_features
if name == alias.data or name in part.columns:
if name == alias.data:
cname = feature_cols
if name == alias.data and feature_cols is not None:
array = part[feature_cols]
else:
cname = name
array = part[name]
array = stack_series(array)

array = part[cname]
if name == alias.data:
if n_features == 0:
n_features = array.shape[1]
@@ -240,32 +258,66 @@ def make(values: Dict[str, List[np.ndarray]], kwargs: Dict[str, Any]) -> DMatrix
data=data, label=label, weight=weight, base_margin=margin, qid=qid, **kwargs
)

is_dmatrix = feature_cols is None
if is_dmatrix:
if enable_sparse_data_optim:
append_fn = append_m_sparse
assert "missing" in kwargs and kwargs["missing"] == 0.0
else:
append_fn = append_m
if enable_sparse_data_optim:
append_fn = append_m_sparse
assert "missing" in kwargs and kwargs["missing"] == 0.0
else:
append_fn = append_m

def split_params() -> Tuple[Dict[str, Any], Dict[str, Union[int, float, bool]]]:
# FIXME(jiamingy): we really need a better way to bridge distributed frameworks
# to XGBoost native interface and prevent scattering parameters like this.
non_data_keys = (
"max_bin",
"missing",
"silent",
"nthread",
"enable_categorical",
)
non_data_params = {}
meta = {}
for k, v in kwargs.items():
if k in non_data_keys:
non_data_params[k] = v
else:
meta[k] = v
return meta, non_data_params

meta, params = split_params()

if feature_cols is not None: # rapidsai plugin
cache_partitions(iterator, append_dqm)
assert gpu_id is not None
assert use_qdm is True
it = PartIter(train_data, gpu_id, **meta)
dtrain: DMatrix = QuantileDMatrix(it, **params)
elif use_qdm:
cache_partitions(iterator, append_dqm)
it = PartIter(train_data, gpu_id, **meta)
dtrain = QuantileDMatrix(it, **params)
else:
cache_partitions(iterator, append_fn)
if len(train_data) == 0:
get_logger("XGBoostPySpark").warning(
"Detected an empty partition in the training data. "
"Consider to enable repartition_random_shuffle"
)
dtrain = make(train_data, kwargs)
else:
cache_partitions(iterator, append_dqm)
it = PartIter(train_data, gpu_id)
dtrain = DeviceQuantileDMatrix(it, **kwargs)

# Using has_validation_col here to indicate if there is validation col
# instead of getting it from iterator, since the iterator may be empty
# in some special case. That is to say, we must ensure every worker
# construct DMatrix even there is no any data since we need to ensure every
# construct DMatrix even there is no data since we need to ensure every
# worker do the AllReduce when constructing DMatrix, or else it may hang
# forever.
dvalid = make(valid_data, kwargs) if has_validation_col else None
if has_validation_col:
if use_qdm:
it = PartIter(valid_data, gpu_id, **meta)
dvalid: Optional[DMatrix] = QuantileDMatrix(it, **params, ref=dtrain)
else:
dvalid = make(valid_data, kwargs) if has_validation_col else None
else:
dvalid = None

if dvalid is not None:
assert dvalid.num_col() == dtrain.num_col()
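To summarize the new flow in this file: ``split_params`` divides the incoming ``kwargs`` into DMatrix-construction parameters and per-batch meta information, and ``use_qdm`` selects between ``QuantileDMatrix`` and ``DMatrix``. A small, self-contained illustration of the split using made-up values (it mirrors the ``non_data_keys`` tuple above rather than importing the private helper):

    kwargs = {"missing": 0.0, "nthread": 8, "max_bin": 256, "feature_types": ["q"] * 4}
    non_data_keys = ("max_bin", "missing", "silent", "nthread", "enable_categorical")

    meta = {k: v for k, v in kwargs.items() if k not in non_data_keys}
    params = {k: v for k, v in kwargs.items() if k in non_data_keys}

    assert meta == {"feature_types": ["q"] * 4}                       # fed to input_data per batch
    assert params == {"missing": 0.0, "nthread": 8, "max_bin": 256}   # QuantileDMatrix constructor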
4 changes: 3 additions & 1 deletion tests/python-gpu/test_gpu_spark/test_data.py
@@ -20,4 +20,6 @@

@pytest.mark.skipif(**tm.no_cudf())
def test_qdm_ctor() -> None:
run_dmatrix_ctor(True)
run_dmatrix_ctor(is_dqm=True, on_gpu=True)
with pytest.raises(AssertionError):
run_dmatrix_ctor(is_dqm=False, on_gpu=True)
2 changes: 1 addition & 1 deletion tests/python-gpu/test_gpu_spark/test_gpu_spark.py
@@ -7,7 +7,7 @@
sys.path.append("tests/python")
import testing as tm

if tm.no_dask()["condition"]:
if tm.no_spark()["condition"]:
pytest.skip(msg=tm.no_spark()["reason"], allow_module_level=True)
if sys.platform.startswith("win"):
pytest.skip("Skipping PySpark tests on Windows", allow_module_level=True)
5 changes: 1 addition & 4 deletions tests/python-gpu/test_gpu_with_dask/test_gpu_with_dask.py
@@ -188,12 +188,9 @@ def run_gpu_hist(

# See note on `ObjFunction::UpdateTreeLeaf`.
update_leaf = dataset.name.endswith("-l1")
if update_leaf and len(history) == 2:
if update_leaf:
assert history[0] + 1e-2 >= history[-1]
return
if update_leaf and len(history) > 2:
assert history[0] >= history[-1]
return
else:
assert tm.non_increasing(history)

Expand Down