Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Quantile DMatrix for CPU. #8130

Merged
merged 2 commits into from Aug 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 3 additions & 0 deletions doc/python/python_api.rst
Expand Up @@ -22,6 +22,9 @@ Core Data Structure
:members:
:show-inheritance:

.. autoclass:: xgboost.QuantileDMatrix
:show-inheritance:

.. autoclass:: xgboost.DeviceQuantileDMatrix
:show-inheritance:

Expand Down
34 changes: 21 additions & 13 deletions include/xgboost/c_api.h
Expand Up @@ -415,28 +415,26 @@ XGB_EXTERN_C typedef void DataIterResetCallback(DataIterHandle handle); // NOLIN
*
* \return 0 when success, -1 when failure happens
*/
XGB_DLL int XGDMatrixCreateFromCallback(DataIterHandle iter,
DMatrixHandle proxy,
DataIterResetCallback *reset,
XGDMatrixCallbackNext *next,
char const* c_json_config,
DMatrixHandle *out);
XGB_DLL int XGDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy,
DataIterResetCallback *reset, XGDMatrixCallbackNext *next,
char const *c_json_config, DMatrixHandle *out);

/*!
* \brief Create a Quantile DMatrix with data iterator.
*
* Short note for how to use the second set of callback for GPU Hist tree method:
* Short note on how to use the second set of callbacks for the (GPU)Hist tree method:
*
* - Step 0: Define a data iterator with 2 methods `reset`, and `next`.
* - Step 1: Create a DMatrix proxy by `XGProxyDMatrixCreate` and hold the handle.
* - Step 2: Pass the iterator handle, proxy handle and 2 methods into
* `XGDeviceQuantileDMatrixCreateFromCallback`.
* `XGQuantileDMatrixCreateFromCallback`.
* - Step 3: Call appropriate data setters in `next` functions.
*
* See test_iterative_device_dmatrix.cu or Python interface for examples.
* See test_iterative_dmatrix.cu or Python interface for examples.
*
* \param iter A handle to external data iterator.
* \param proxy A DMatrix proxy handle created by `XGProxyDMatrixCreate`.
* \param ref Reference DMatrix for providing quantile information.
* \param reset Callback function resetting the iterator state.
* \param next Callback function yielding the next batch of data.
* \param missing Which value to represent missing value
Expand All @@ -446,10 +444,20 @@ XGB_DLL int XGDMatrixCreateFromCallback(DataIterHandle iter,
*
* \return 0 when success, -1 when failure happens
*/
XGB_DLL int XGDeviceQuantileDMatrixCreateFromCallback(
DataIterHandle iter, DMatrixHandle proxy, DataIterResetCallback *reset,
XGDMatrixCallbackNext *next, float missing, int nthread, int max_bin,
DMatrixHandle *out);
XGB_DLL int XGQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy,
DataIterHandle ref, DataIterResetCallback *reset,
XGDMatrixCallbackNext *next, char const *config,
DMatrixHandle *out);

/*!
* \brief Create a Device Quantile DMatrix with data iterator.
* \deprecated since 2.0
* \see XGQuantileDMatrixCreateFromCallback()
*/
XGB_DLL int XGDeviceQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy,
DataIterResetCallback *reset,
XGDMatrixCallbackNext *next, float missing,
int nthread, int max_bin, DMatrixHandle *out);

/*!
* \brief Set data on a DMatrix proxy.
Expand Down
2 changes: 2 additions & 0 deletions python-package/xgboost/__init__.py
Expand Up @@ -6,6 +6,7 @@
from .core import (
DMatrix,
DeviceQuantileDMatrix,
QuantileDMatrix,
Booster,
DataIter,
build_info,
Expand Down Expand Up @@ -33,6 +34,7 @@
# core
"DMatrix",
"DeviceQuantileDMatrix",
"QuantileDMatrix",
"Booster",
"DataIter",
"train",
Expand Down
92 changes: 69 additions & 23 deletions python-package/xgboost/core.py
Expand Up @@ -1146,7 +1146,7 @@ def feature_types(self, feature_types: Optional[Union[List[str], str]]) -> None:

Parameters
----------
feature_types : list or None
feature_types :
Types for features. None will reset existing feature types

"""
Expand Down Expand Up @@ -1189,7 +1189,7 @@ def feature_types(self, feature_types: Optional[Union[List[str], str]]) -> None:


class _ProxyDMatrix(DMatrix):
"""A placeholder class when DMatrix cannot be constructed (DeviceQuantileDMatrix,
"""A placeholder class when DMatrix cannot be constructed (QuantileDMatrix,
inplace_predict).

"""
Expand Down Expand Up @@ -1234,17 +1234,35 @@ def _set_data_from_csr(self, csr: scipy.sparse.csr_matrix) -> None:
)


class DeviceQuantileDMatrix(DMatrix):
"""Device memory Data Matrix used in XGBoost for training with tree_method='gpu_hist'. Do
not use this for test/validation tasks as some information may be lost in
quantisation. This DMatrix is primarily designed to save memory in training from
device memory inputs by avoiding intermediate storage. Set max_bin to control the
number of bins during quantisation. See doc string in :py:obj:`xgboost.DMatrix` for
documents on meta info.
class QuantileDMatrix(DMatrix):
"""A DMatrix variant that generates quantilized data directly from input for
``hist`` and ``gpu_hist`` tree methods. This DMatrix is primarily designed to save
memory in training by avoiding intermediate storage. Set ``max_bin`` to control the
number of bins during quantisation, which should be consistent with the training
parameter ``max_bin``. When ``QuantileDMatrix`` is used for validation/test dataset,
``ref`` should be another ``QuantileDMatrix`` (or ``DMatrix``, but not recommended as
it defeats the purpose of saving memory) constructed from training dataset. See
:py:obj:`xgboost.DMatrix` for documents on meta info.

You can construct DeviceQuantileDMatrix from cupy/cudf/dlpack.
.. note::

.. versionadded:: 1.1.0
Do not use ``QuantileDMatrix`` as validation/test dataset without supplying a
reference (the training dataset) ``QuantileDMatrix`` using ``ref`` as some
information may be lost in quantisation.

.. versionadded:: 2.0.0

Parameters
----------
max_bin :
The number of histogram bins, which should be consistent with the training
parameter ``max_bin``.

ref :
The training dataset that provides quantile information, needed when creating
validation/test dataset with ``QuantileDMatrix``. Supplying the training DMatrix
as a reference means that the same quantisation applied to the training data is
applied to the validation/test data.

"""

Expand All @@ -1261,17 +1279,18 @@ def __init__( # pylint: disable=super-init-not-called
feature_names: Optional[FeatureNames] = None,
feature_types: Optional[FeatureTypes] = None,
nthread: Optional[int] = None,
max_bin: int = 256,
max_bin: Optional[int] = None,
ref: Optional[DMatrix] = None,
group: Optional[ArrayLike] = None,
qid: Optional[ArrayLike] = None,
label_lower_bound: Optional[ArrayLike] = None,
label_upper_bound: Optional[ArrayLike] = None,
feature_weights: Optional[ArrayLike] = None,
enable_categorical: bool = False,
) -> None:
self.max_bin = max_bin
self.max_bin: int = max_bin if max_bin is not None else 256
self.missing = missing if missing is not None else np.nan
self.nthread = nthread if nthread is not None else 1
self.nthread = nthread if nthread is not None else -1
self._silent = silent # unused, kept for compatibility

if isinstance(data, ctypes.c_void_p):
Expand All @@ -1280,12 +1299,13 @@ def __init__( # pylint: disable=super-init-not-called

if qid is not None and group is not None:
raise ValueError(
'Only one of the eval_qid or eval_group for each evaluation '
'dataset should be provided.'
"Only one of the eval_qid or eval_group for each evaluation "
"dataset should be provided."
)

self._init(
data,
ref=ref,
label=label,
weight=weight,
base_margin=base_margin,
Expand All @@ -1299,7 +1319,13 @@ def __init__( # pylint: disable=super-init-not-called
enable_categorical=enable_categorical,
)

def _init(self, data: DataType, enable_categorical: bool, **meta: Any) -> None:
def _init(
self,
data: DataType,
ref: Optional[DMatrix],
enable_categorical: bool,
**meta: Any,
) -> None:
from .data import (
_is_dlpack,
_transform_dlpack,
Expand All @@ -1317,20 +1343,26 @@ def _init(self, data: DataType, enable_categorical: bool, **meta: Any) -> None:
it = SingleBatchInternalIter(data=data, **meta)

handle = ctypes.c_void_p()
reset_callback, next_callback = it.get_callbacks(False, enable_categorical)
reset_callback, next_callback = it.get_callbacks(True, enable_categorical)
if it.cache_prefix is not None:
raise ValueError(
"DeviceQuantileDMatrix doesn't cache data, remove the cache_prefix "
"QuantileDMatrix doesn't cache data, remove the cache_prefix "
"in iterator to fix this error."
)
ret = _LIB.XGDeviceQuantileDMatrixCreateFromCallback(

args = {
"nthread": self.nthread,
"missing": self.missing,
"max_bin": self.max_bin,
}
config = from_pystr_to_cstr(json.dumps(args))
ret = _LIB.XGQuantileDMatrixCreateFromCallback(
None,
it.proxy.handle,
ref.handle if ref is not None else ref,
reset_callback,
next_callback,
ctypes.c_float(self.missing),
ctypes.c_int(self.nthread),
ctypes.c_int(self.max_bin),
config,
ctypes.byref(handle),
)
it.reraise()
Expand All @@ -1339,6 +1371,20 @@ def _init(self, data: DataType, enable_categorical: bool, **meta: Any) -> None:
self.handle = handle


class DeviceQuantileDMatrix(QuantileDMatrix):
    """Deprecated subclass of `QuantileDMatrix`; use `QuantileDMatrix` instead.

    Constructing this class emits a ``FutureWarning`` and then forwards all
    positional and keyword arguments unchanged to ``QuantileDMatrix.__init__``.

    .. deprecated:: 2.0.0

    .. versionadded:: 1.1.0

    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # stacklevel=2 attributes the warning to the code that instantiates the
        # deprecated class rather than to this wrapper, so users can locate the
        # call site they need to migrate.
        warnings.warn(
            "Please use `QuantileDMatrix` instead.", FutureWarning, stacklevel=2
        )
        super().__init__(*args, **kwargs)


Objective = Callable[[np.ndarray, DMatrix], Tuple[np.ndarray, np.ndarray]]
Metric = Callable[[np.ndarray, DMatrix], Tuple[str, float]]

Expand Down