diff --git a/doc/python/python_api.rst b/doc/python/python_api.rst index 33be6e6275e0..a5ddd7831006 100644 --- a/doc/python/python_api.rst +++ b/doc/python/python_api.rst @@ -22,6 +22,9 @@ Core Data Structure :members: :show-inheritance: +.. autoclass:: xgboost.QuantileDMatrix + :show-inheritance: + .. autoclass:: xgboost.DeviceQuantileDMatrix :show-inheritance: diff --git a/include/xgboost/c_api.h b/include/xgboost/c_api.h index 17cd5f4af36d..f66b8097f50e 100644 --- a/include/xgboost/c_api.h +++ b/include/xgboost/c_api.h @@ -415,28 +415,26 @@ XGB_EXTERN_C typedef void DataIterResetCallback(DataIterHandle handle); // NOLIN * * \return 0 when success, -1 when failure happens */ -XGB_DLL int XGDMatrixCreateFromCallback(DataIterHandle iter, - DMatrixHandle proxy, - DataIterResetCallback *reset, - XGDMatrixCallbackNext *next, - char const* c_json_config, - DMatrixHandle *out); +XGB_DLL int XGDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy, + DataIterResetCallback *reset, XGDMatrixCallbackNext *next, + char const *c_json_config, DMatrixHandle *out); /*! * \brief Create a Quantile DMatrix with data iterator. * - * Short note for how to use the second set of callback for GPU Hist tree method: + * Short note for how to use the second set of callback for (GPU)Hist tree method: * * - Step 0: Define a data iterator with 2 methods `reset`, and `next`. * - Step 1: Create a DMatrix proxy by `XGProxyDMatrixCreate` and hold the handle. * - Step 2: Pass the iterator handle, proxy handle and 2 methods into - * `XGDeviceQuantileDMatrixCreateFromCallback`. + * `XGQuantileDMatrixCreateFromCallback`. * - Step 3: Call appropriate data setters in `next` functions. * - * See test_iterative_device_dmatrix.cu or Python interface for examples. + * See test_iterative_dmatrix.cu or Python interface for examples. * * \param iter A handle to external data iterator. * \param proxy A DMatrix proxy handle created by `XGProxyDMatrixCreate`. 
+ * \param ref Reference DMatrix for providing quantile information. * \param reset Callback function resetting the iterator state. * \param next Callback function yielding the next batch of data. * \param missing Which value to represent missing value @@ -446,10 +444,20 @@ XGB_DLL int XGDMatrixCreateFromCallback(DataIterHandle iter, * * \return 0 when success, -1 when failure happens */ -XGB_DLL int XGDeviceQuantileDMatrixCreateFromCallback( - DataIterHandle iter, DMatrixHandle proxy, DataIterResetCallback *reset, - XGDMatrixCallbackNext *next, float missing, int nthread, int max_bin, - DMatrixHandle *out); +XGB_DLL int XGQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy, + DataIterHandle ref, DataIterResetCallback *reset, + XGDMatrixCallbackNext *next, char const *config, + DMatrixHandle *out); + +/*! + * \brief Create a Device Quantile DMatrix with data iterator. + * \deprecated since 2.0 + * \see XGQuantileDMatrixCreateFromCallback() + */ +XGB_DLL int XGDeviceQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy, + DataIterResetCallback *reset, + XGDMatrixCallbackNext *next, float missing, + int nthread, int max_bin, DMatrixHandle *out); /*! * \brief Set data on a DMatrix proxy. 
diff --git a/python-package/xgboost/__init__.py b/python-package/xgboost/__init__.py index 820d77ce0865..4f06bca3c8dd 100644 --- a/python-package/xgboost/__init__.py +++ b/python-package/xgboost/__init__.py @@ -6,6 +6,7 @@ from .core import ( DMatrix, DeviceQuantileDMatrix, + QuantileDMatrix, Booster, DataIter, build_info, @@ -33,6 +34,7 @@ # core "DMatrix", "DeviceQuantileDMatrix", + "QuantileDMatrix", "Booster", "DataIter", "train", diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index 85c5b8b77d49..253cdd92eed8 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -1146,7 +1146,7 @@ def feature_types(self, feature_types: Optional[Union[List[str], str]]) -> None: Parameters ---------- - feature_types : list or None + feature_types : Labels for features. None will reset existing feature names """ @@ -1189,7 +1189,7 @@ def feature_types(self, feature_types: Optional[Union[List[str], str]]) -> None: class _ProxyDMatrix(DMatrix): - """A placeholder class when DMatrix cannot be constructed (DeviceQuantileDMatrix, + """A placeholder class when DMatrix cannot be constructed (QuantileDMatrix, inplace_predict). """ @@ -1234,17 +1234,35 @@ def _set_data_from_csr(self, csr: scipy.sparse.csr_matrix) -> None: ) -class DeviceQuantileDMatrix(DMatrix): - """Device memory Data Matrix used in XGBoost for training with tree_method='gpu_hist'. Do - not use this for test/validation tasks as some information may be lost in - quantisation. This DMatrix is primarily designed to save memory in training from - device memory inputs by avoiding intermediate storage. Set max_bin to control the - number of bins during quantisation. See doc string in :py:obj:`xgboost.DMatrix` for - documents on meta info. +class QuantileDMatrix(DMatrix): + """A DMatrix variant that generates quantilized data directly from input for + ``hist`` and ``gpu_hist`` tree methods. 
This DMatrix is primarily designed to save + memory in training by avoiding intermediate storage. Set ``max_bin`` to control the + number of bins during quantisation, which should be consistent with the training + parameter ``max_bin``. When ``QuantileDMatrix`` is used for a validation/test dataset, + ``ref`` should be another ``QuantileDMatrix`` (or ``DMatrix``, but not recommended as + it defeats the purpose of saving memory) constructed from the training dataset. See + :py:obj:`xgboost.DMatrix` for documentation on meta info. - You can construct DeviceQuantileDMatrix from cupy/cudf/dlpack. + .. note:: - .. versionadded:: 1.1.0 + Do not use ``QuantileDMatrix`` as validation/test dataset without supplying a + reference (the training dataset) ``QuantileDMatrix`` using ``ref`` as some + information may be lost in quantisation. + + .. versionadded:: 2.0.0 + + Parameters + ---------- + max_bin : + The number of histogram bins, should be consistent with the training parameter + ``max_bin``. + + ref : + The training dataset that provides quantile information, needed when creating + validation/test dataset with ``QuantileDMatrix``. 
Supplying the training DMatrix + as a reference means that the same quantisation applied to the training data is + applied to the validation/test data """ @@ -1261,7 +1279,8 @@ def __init__( # pylint: disable=super-init-not-called feature_names: Optional[FeatureNames] = None, feature_types: Optional[FeatureTypes] = None, nthread: Optional[int] = None, - max_bin: int = 256, + max_bin: Optional[int] = None, + ref: Optional[DMatrix] = None, group: Optional[ArrayLike] = None, qid: Optional[ArrayLike] = None, label_lower_bound: Optional[ArrayLike] = None, @@ -1269,9 +1288,9 @@ def __init__( # pylint: disable=super-init-not-called feature_weights: Optional[ArrayLike] = None, enable_categorical: bool = False, ) -> None: - self.max_bin = max_bin + self.max_bin: int = max_bin if max_bin is not None else 256 self.missing = missing if missing is not None else np.nan - self.nthread = nthread if nthread is not None else 1 + self.nthread = nthread if nthread is not None else -1 self._silent = silent # unused, kept for compatibility if isinstance(data, ctypes.c_void_p): @@ -1280,12 +1299,13 @@ def __init__( # pylint: disable=super-init-not-called if qid is not None and group is not None: raise ValueError( - 'Only one of the eval_qid or eval_group for each evaluation ' - 'dataset should be provided.' + "Only one of the eval_qid or eval_group for each evaluation " + "dataset should be provided." 
) self._init( data, + ref=ref, label=label, weight=weight, base_margin=base_margin, @@ -1299,7 +1319,13 @@ def __init__( # pylint: disable=super-init-not-called enable_categorical=enable_categorical, ) - def _init(self, data: DataType, enable_categorical: bool, **meta: Any) -> None: + def _init( + self, + data: DataType, + ref: Optional[DMatrix], + enable_categorical: bool, + **meta: Any, + ) -> None: from .data import ( _is_dlpack, _transform_dlpack, @@ -1317,20 +1343,26 @@ def _init(self, data: DataType, enable_categorical: bool, **meta: Any) -> None: it = SingleBatchInternalIter(data=data, **meta) handle = ctypes.c_void_p() - reset_callback, next_callback = it.get_callbacks(False, enable_categorical) + reset_callback, next_callback = it.get_callbacks(True, enable_categorical) if it.cache_prefix is not None: raise ValueError( - "DeviceQuantileDMatrix doesn't cache data, remove the cache_prefix " + "QuantileDMatrix doesn't cache data, remove the cache_prefix " "in iterator to fix this error." ) - ret = _LIB.XGDeviceQuantileDMatrixCreateFromCallback( + + args = { + "nthread": self.nthread, + "missing": self.missing, + "max_bin": self.max_bin, + } + config = from_pystr_to_cstr(json.dumps(args)) + ret = _LIB.XGQuantileDMatrixCreateFromCallback( None, it.proxy.handle, + ref.handle if ref is not None else ref, reset_callback, next_callback, - ctypes.c_float(self.missing), - ctypes.c_int(self.nthread), - ctypes.c_int(self.max_bin), + config, ctypes.byref(handle), ) it.reraise() @@ -1339,6 +1371,20 @@ def _init(self, data: DataType, enable_categorical: bool, **meta: Any) -> None: self.handle = handle +class DeviceQuantileDMatrix(QuantileDMatrix): + """ Use `QuantileDMatrix` instead. + + .. deprecated:: 2.0.0 + + .. 
versionadded:: 1.1.0 + + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + warnings.warn("Please use `QuantileDMatrix` instead.", FutureWarning) + super().__init__(*args, **kwargs) + + Objective = Callable[[np.ndarray, DMatrix], Tuple[np.ndarray, np.ndarray]] Metric = Callable[[np.ndarray, DMatrix], Tuple[str, float]] diff --git a/python-package/xgboost/dask.py b/python-package/xgboost/dask.py index 1dc557834862..fac4a868a8c0 100644 --- a/python-package/xgboost/dask.py +++ b/python-package/xgboost/dask.py @@ -35,6 +35,7 @@ import logging import platform import socket +import warnings from collections import defaultdict from contextlib import contextmanager from functools import partial, update_wrapper @@ -64,10 +65,10 @@ from .core import ( Booster, DataIter, - DeviceQuantileDMatrix, DMatrix, Metric, Objective, + QuantileDMatrix, _deprecate_positional_args, _expect, _has_categorical, @@ -495,7 +496,7 @@ async def map_worker_partitions( client: Optional["distributed.Client"], func: Callable[..., _MapRetT], *refs: Any, - workers: List[str], + workers: Sequence[str], ) -> List[_MapRetT]: """Map a function onto partitions of each worker.""" # Note for function purity: @@ -628,22 +629,7 @@ def next(self, input_data: Callable) -> int: return 1 -class DaskDeviceQuantileDMatrix(DaskDMatrix): - """Specialized data type for `gpu_hist` tree method. This class is used to reduce - the memory usage by eliminating data copies. Internally the all partitions/chunks - of data are merged by weighted GK sketching. So the number of partitions from dask - may affect training accuracy as GK generates bounded error for each merge. See doc - string for :py:obj:`xgboost.DeviceQuantileDMatrix` and :py:obj:`xgboost.DMatrix` for - other parameters. - - .. versionadded:: 1.2.0 - - Parameters - ---------- - max_bin : Number of bins for histogram construction. 
- - """ - +class DaskQuantileDMatrix(DaskDMatrix): @_deprecate_positional_args def __init__( self, @@ -657,7 +643,8 @@ def __init__( silent: bool = False, # disable=unused-argument feature_names: Optional[FeatureNames] = None, feature_types: Optional[Union[Any, List[Any]]] = None, - max_bin: int = 256, + max_bin: Optional[int] = None, + ref: Optional[DMatrix] = None, group: Optional[_DaskCollection] = None, qid: Optional[_DaskCollection] = None, label_lower_bound: Optional[_DaskCollection] = None, @@ -684,14 +671,31 @@ def __init__( ) self.max_bin = max_bin self.is_quantile = True + self._ref: Optional[int] = id(ref) if ref is not None else None def _create_fn_args(self, worker_addr: str) -> Dict[str, Any]: args = super()._create_fn_args(worker_addr) args["max_bin"] = self.max_bin + if self._ref is not None: + args["ref"] = self._ref return args -def _create_device_quantile_dmatrix( +class DaskDeviceQuantileDMatrix(DaskQuantileDMatrix): + """Use `DaskQuantileDMatrix` instead. + + .. deprecated:: 2.0.0 + + .. versionadded:: 1.2.0 + + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + warnings.warn("Please use `DaskQuantileDMatrix` instead.", FutureWarning) + super().__init__(*args, **kwargs) + + +def _create_quantile_dmatrix( feature_names: Optional[FeatureNames], feature_types: Optional[Union[Any, List[Any]]], feature_weights: Optional[Any], @@ -700,18 +704,20 @@ def _create_device_quantile_dmatrix( parts: Optional[_DataParts], max_bin: int, enable_categorical: bool, -) -> DeviceQuantileDMatrix: + ref: Optional[DMatrix] = None, +) -> QuantileDMatrix: worker = distributed.get_worker() if parts is None: msg = f"worker {worker.address} has an empty DMatrix." 
LOGGER.warning(msg) import cupy - d = DeviceQuantileDMatrix( + d = QuantileDMatrix( cupy.zeros((0, 0)), feature_names=feature_names, feature_types=feature_types, max_bin=max_bin, + ref=ref, enable_categorical=enable_categorical, ) return d @@ -719,13 +725,14 @@ def _create_device_quantile_dmatrix( unzipped_dict = _get_worker_parts(parts) it = DaskPartitionIter(**unzipped_dict) - dmatrix = DeviceQuantileDMatrix( + dmatrix = QuantileDMatrix( it, missing=missing, feature_names=feature_names, feature_types=feature_types, nthread=nthread, max_bin=max_bin, + ref=ref, enable_categorical=enable_categorical, ) dmatrix.set_info(feature_weights=feature_weights) @@ -786,11 +793,9 @@ def concat_or_none(data: Sequence[Optional[T]]) -> Optional[T]: return dmatrix -def _dmatrix_from_list_of_parts( - is_quantile: bool, **kwargs: Any -) -> Union[DMatrix, DeviceQuantileDMatrix]: +def _dmatrix_from_list_of_parts(is_quantile: bool, **kwargs: Any) -> DMatrix: if is_quantile: - return _create_device_quantile_dmatrix(**kwargs) + return _create_quantile_dmatrix(**kwargs) return _create_dmatrix(**kwargs) @@ -921,7 +926,18 @@ def dispatched_train( if evals_id[i] == train_id: evals.append((Xy, evals_name[i])) continue - eval_Xy = _dmatrix_from_list_of_parts(**ref, nthread=n_threads) + if ref.get("ref", None) is not None: + if ref["ref"] != train_id: + raise ValueError( + "The training DMatrix should be used as a reference" + " to evaluation `QuantileDMatrix`." 
+ ) + del ref["ref"] + eval_Xy = _dmatrix_from_list_of_parts( + **ref, nthread=n_threads, ref=Xy + ) + else: + eval_Xy = _dmatrix_from_list_of_parts(**ref, nthread=n_threads) evals.append((eval_Xy, evals_name[i])) booster = worker_train( @@ -960,12 +976,14 @@ def dispatched_train( results = await map_worker_partitions( client, dispatched_train, + # extra function parameters params, _rabit_args, id(dtrain), evals_name, evals_id, *([dtrain] + evals_data), + # workers to be used for training workers=workers, ) return list(filter(lambda ret: ret is not None, results))[0] diff --git a/python-package/xgboost/data.py b/python-package/xgboost/data.py index 99e0c2219453..ea63ef9c855c 100644 --- a/python-package/xgboost/data.py +++ b/python-package/xgboost/data.py @@ -1167,6 +1167,7 @@ def _proxy_transform( if _is_dlpack(data): return _transform_dlpack(data), None, feature_names, feature_types if _is_numpy_array(data): + data, _ = _ensure_np_dtype(data, data.dtype) return data, None, feature_names, feature_types if _is_scipy_csr(data): return data, None, feature_names, feature_types diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 9ba53442adb0..e92af3c50f4e 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -281,11 +281,36 @@ XGB_DLL int XGDeviceQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatr int nthread, int max_bin, DMatrixHandle *out) { API_BEGIN(); + LOG(WARNING) << __func__ << " is deprecated. 
Use `XGQuantileDMatrixCreateFromCallback` instead."; *out = new std::shared_ptr{ xgboost::DMatrix::Create(iter, proxy, nullptr, reset, next, missing, nthread, max_bin)}; API_END(); } +XGB_DLL int XGQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy, + DataIterHandle ref, DataIterResetCallback *reset, + XGDMatrixCallbackNext *next, char const *config, + DMatrixHandle *out) { + API_BEGIN(); + std::shared_ptr _ref{nullptr}; + if (ref) { + auto pp_ref = static_cast *>(ref); + StringView err{"Invalid handle to ref."}; + CHECK(pp_ref) << err; + _ref = *pp_ref; + CHECK(_ref) << err; + } + + auto jconfig = Json::Load(StringView{config}); + auto missing = GetMissing(jconfig); + auto n_threads = OptionalArg(jconfig, "nthread", common::OmpGetNumThreads(0)); + auto max_bin = OptionalArg(jconfig, "max_bin", 256); + + *out = new std::shared_ptr{ + xgboost::DMatrix::Create(iter, proxy, _ref, reset, next, missing, n_threads, max_bin)}; + API_END(); +} + XGB_DLL int XGProxyDMatrixCreate(DMatrixHandle* out) { API_BEGIN(); *out = new std::shared_ptr(new xgboost::data::DMatrixProxy);; diff --git a/tests/python-gpu/test_device_quantile_dmatrix.py b/tests/python-gpu/test_device_quantile_dmatrix.py index 348d75842877..f91cf6bd9aa6 100644 --- a/tests/python-gpu/test_device_quantile_dmatrix.py +++ b/tests/python-gpu/test_device_quantile_dmatrix.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- import numpy as np import xgboost as xgb import pytest @@ -6,16 +5,14 @@ sys.path.append("tests/python") import testing as tm +import test_quantile_dmatrix as tqd class TestDeviceQuantileDMatrix: - def test_dmatrix_numpy_init(self): - data = np.random.randn(5, 5) - with pytest.raises(TypeError, match='is not supported'): - xgb.DeviceQuantileDMatrix(data, np.ones(5, dtype=np.float64)) + cputest = tqd.TestQuantileDMatrix() @pytest.mark.skipif(**tm.no_cupy()) - def test_dmatrix_feature_weights(self): + def test_dmatrix_feature_weights(self) -> None: import cupy as cp rng = 
cp.random.RandomState(1994) data = rng.randn(5, 5) @@ -29,7 +26,7 @@ def test_dmatrix_feature_weights(self): feature_weights.astype(np.float32)) @pytest.mark.skipif(**tm.no_cupy()) - def test_dmatrix_cupy_init(self): + def test_dmatrix_cupy_init(self) -> None: import cupy as cp data = cp.random.randn(5, 5) xgb.DeviceQuantileDMatrix(data, cp.ones(5, dtype=np.float64)) @@ -55,3 +52,10 @@ def test_metainfo(self) -> None: cp.testing.assert_allclose(fw, got_fw) cp.testing.assert_allclose(labels, got_labels) + + @pytest.mark.skipif(**tm.no_cupy()) + @pytest.mark.skipif(**tm.no_cudf()) + def test_ref_dmatrix(self) -> None: + import cupy as cp + rng = cp.random.RandomState(1994) + self.cputest.run_ref_dmatrix(rng, "gpu_hist", False) diff --git a/tests/python-gpu/test_gpu_with_dask.py b/tests/python-gpu/test_gpu_with_dask.py index dcb228adb238..3cb110bd6c6e 100644 --- a/tests/python-gpu/test_gpu_with_dask.py +++ b/tests/python-gpu/test_gpu_with_dask.py @@ -429,9 +429,10 @@ def test_interface_consistency(self) -> None: sig = OrderedDict(signature(dxgb.DaskDMatrix).parameters) del sig["client"] ddm_names = list(sig.keys()) - sig = OrderedDict(signature(dxgb.DaskDeviceQuantileDMatrix).parameters) + sig = OrderedDict(signature(dxgb.DaskQuantileDMatrix).parameters) del sig["client"] del sig["max_bin"] + del sig["ref"] ddqdm_names = list(sig.keys()) assert len(ddm_names) == len(ddqdm_names) @@ -442,9 +443,10 @@ def test_interface_consistency(self) -> None: sig = OrderedDict(signature(xgb.DMatrix).parameters) del sig["nthread"] # no nthread in dask dm_names = list(sig.keys()) - sig = OrderedDict(signature(xgb.DeviceQuantileDMatrix).parameters) + sig = OrderedDict(signature(xgb.QuantileDMatrix).parameters) del sig["nthread"] del sig["max_bin"] + del sig["ref"] dqdm_names = list(sig.keys()) # between single node @@ -499,7 +501,6 @@ def runit( for arg in rabit_args: if arg.decode('utf-8').startswith('DMLC_TRACKER_PORT'): port_env = arg.decode('utf-8') - port_env = arg.decode('utf-8') 
if arg.decode("utf-8").startswith("DMLC_TRACKER_URI"): uri_env = arg.decode("utf-8") port = port_env.split('=') diff --git a/tests/python/test_data_iterator.py b/tests/python/test_data_iterator.py index 233a3a4d0dce..5e0e4686002f 100644 --- a/tests/python/test_data_iterator.py +++ b/tests/python/test_data_iterator.py @@ -1,32 +1,12 @@ import xgboost as xgb from xgboost.data import SingleBatchInternalIter as SingleBatch import numpy as np -from testing import IteratorForTest, non_increasing -from typing import Tuple, List +from testing import IteratorForTest, non_increasing, make_batches import pytest from hypothesis import given, strategies, settings from scipy.sparse import csr_matrix -def make_batches( - n_samples_per_batch: int, n_features: int, n_batches: int, use_cupy: bool = False -) -> Tuple[List[np.ndarray], List[np.ndarray]]: - X = [] - y = [] - if use_cupy: - import cupy - - rng = cupy.random.RandomState(1994) - else: - rng = np.random.RandomState(1994) - for i in range(n_batches): - _X = rng.randn(n_samples_per_batch, n_features) - _y = rng.randn(n_samples_per_batch) - X.append(_X) - y.append(_y) - return X, y - - def test_single_batch(tree_method: str = "approx") -> None: from sklearn.datasets import load_breast_cancer @@ -111,8 +91,8 @@ def run_data_iterator( if not subsample: assert non_increasing(results_from_it["Train"]["rmse"]) - X, y = it.as_arrays() - Xy = xgb.DMatrix(X, y) + X, y, w = it.as_arrays() + Xy = xgb.DMatrix(X, y, weight=w) assert Xy.num_row() == n_samples_per_batch * n_batches assert Xy.num_col() == n_features diff --git a/tests/python/test_quantile_dmatrix.py b/tests/python/test_quantile_dmatrix.py new file mode 100644 index 000000000000..d17102c98c88 --- /dev/null +++ b/tests/python/test_quantile_dmatrix.py @@ -0,0 +1,212 @@ +from typing import Dict, List, Any + +import numpy as np +import pytest +from scipy import sparse +from testing import IteratorForTest, make_batches, make_batches_sparse, make_categorical + +import xgboost as 
xgb + + +class TestQuantileDMatrix: + def test_basic(self) -> None: + n_samples = 234 + n_features = 8 + + rng = np.random.default_rng() + X = rng.normal(loc=0, scale=3, size=n_samples * n_features).reshape( + n_samples, n_features + ) + y = rng.normal(0, 3, size=n_samples) + Xy = xgb.QuantileDMatrix(X, y) + assert Xy.num_row() == n_samples + assert Xy.num_col() == n_features + + X = sparse.random(n_samples, n_features, density=0.1, format="csr") + Xy = xgb.QuantileDMatrix(X, y) + assert Xy.num_row() == n_samples + assert Xy.num_col() == n_features + + X = sparse.random(n_samples, n_features, density=0.8, format="csr") + Xy = xgb.QuantileDMatrix(X, y) + assert Xy.num_row() == n_samples + assert Xy.num_col() == n_features + + @pytest.mark.parametrize("sparsity", [0.0, 0.1, 0.8, 0.9]) + def test_with_iterator(self, sparsity: float) -> None: + n_samples_per_batch = 317 + n_features = 8 + n_batches = 7 + + if sparsity == 0.0: + it = IteratorForTest( + *make_batches(n_samples_per_batch, n_features, n_batches, False), None + ) + else: + it = IteratorForTest( + *make_batches_sparse( + n_samples_per_batch, n_features, n_batches, sparsity + ), + None + ) + Xy = xgb.QuantileDMatrix(it) + assert Xy.num_row() == n_samples_per_batch * n_batches + assert Xy.num_col() == n_features + + @pytest.mark.parametrize("sparsity", [0.0, 0.1, 0.5, 0.8, 0.9]) + def test_training(self, sparsity: float) -> None: + n_samples_per_batch = 317 + n_features = 8 + n_batches = 7 + if sparsity == 0.0: + it = IteratorForTest( + *make_batches(n_samples_per_batch, n_features, n_batches, False), None + ) + else: + it = IteratorForTest( + *make_batches_sparse( + n_samples_per_batch, n_features, n_batches, sparsity + ), + None + ) + + parameters = {"tree_method": "hist", "max_bin": 256} + Xy_it = xgb.QuantileDMatrix(it, max_bin=parameters["max_bin"]) + from_it = xgb.train(parameters, Xy_it) + + X, y, w = it.as_arrays() + w_it = Xy_it.get_weight() + np.testing.assert_allclose(w_it, w) + + Xy_arr = 
xgb.DMatrix(X, y, weight=w) + from_arr = xgb.train(parameters, Xy_arr) + + np.testing.assert_allclose(from_arr.predict(Xy_it), from_it.predict(Xy_arr)) + + y -= y.min() + y += 0.01 + Xy = xgb.QuantileDMatrix(X, y, weight=w) + with pytest.raises(ValueError, match=r"Only.*hist.*"): + parameters = { + "tree_method": "approx", + "max_bin": 256, + "objective": "reg:gamma", + } + xgb.train(parameters, Xy) + + def run_ref_dmatrix(self, rng: Any, tree_method: str, enable_cat: bool) -> None: + n_samples, n_features = 2048, 17 + if enable_cat: + X, y = make_categorical( + n_samples, n_features, n_categories=13, onehot=False + ) + if tree_method == "gpu_hist": + import cudf + X = cudf.from_pandas(X) + y = cudf.from_pandas(y) + else: + X = rng.normal(loc=0, scale=3, size=n_samples * n_features).reshape( + n_samples, n_features + ) + y = rng.normal(0, 3, size=n_samples) + + # Use ref + Xy = xgb.QuantileDMatrix(X, y, enable_categorical=enable_cat) + Xy_valid = xgb.QuantileDMatrix(X, y, ref=Xy, enable_categorical=enable_cat) + qdm_results: Dict[str, Dict[str, List[float]]] = {} + xgb.train( + {"tree_method": tree_method}, + Xy, + evals=[(Xy, "Train"), (Xy_valid, "valid")], + evals_result=qdm_results, + ) + np.testing.assert_allclose( + qdm_results["Train"]["rmse"], qdm_results["valid"]["rmse"] + ) + # No ref + Xy_valid = xgb.QuantileDMatrix(X, y, enable_categorical=enable_cat) + qdm_results = {} + xgb.train( + {"tree_method": tree_method}, + Xy, + evals=[(Xy, "Train"), (Xy_valid, "valid")], + evals_result=qdm_results, + ) + np.testing.assert_allclose( + qdm_results["Train"]["rmse"], qdm_results["valid"]["rmse"] + ) + + # Different number of features + Xy = xgb.QuantileDMatrix(X, y, enable_categorical=enable_cat) + dXy = xgb.DMatrix(X, y, enable_categorical=enable_cat) + + n_samples, n_features = 256, 15 + X = rng.normal(loc=0, scale=3, size=n_samples * n_features).reshape( + n_samples, n_features + ) + y = rng.normal(0, 3, size=n_samples) + with pytest.raises(ValueError, 
match=r".*features\."): + xgb.QuantileDMatrix(X, y, ref=Xy, enable_categorical=enable_cat) + + # Compare training results + n_samples, n_features = 256, 17 + if enable_cat: + X, y = make_categorical(n_samples, n_features, 13, onehot=False) + if tree_method == "gpu_hist": + import cudf + X = cudf.from_pandas(X) + y = cudf.from_pandas(y) + else: + X = rng.normal(loc=0, scale=3, size=n_samples * n_features).reshape( + n_samples, n_features + ) + y = rng.normal(0, 3, size=n_samples) + Xy_valid = xgb.QuantileDMatrix(X, y, ref=Xy, enable_categorical=enable_cat) + # use DMatrix as ref + Xy_valid_d = xgb.QuantileDMatrix(X, y, ref=dXy, enable_categorical=enable_cat) + dXy_valid = xgb.DMatrix(X, y, enable_categorical=enable_cat) + + qdm_results = {} + xgb.train( + {"tree_method": tree_method}, + Xy, + evals=[(Xy, "Train"), (Xy_valid, "valid")], + evals_result=qdm_results, + ) + + dm_results: Dict[str, Dict[str, List[float]]] = {} + xgb.train( + {"tree_method": tree_method}, + dXy, + evals=[(dXy, "Train"), (dXy_valid, "valid"), (Xy_valid_d, "dvalid")], + evals_result=dm_results, + ) + np.testing.assert_allclose( + dm_results["Train"]["rmse"], qdm_results["Train"]["rmse"] + ) + np.testing.assert_allclose( + dm_results["valid"]["rmse"], qdm_results["valid"]["rmse"] + ) + np.testing.assert_allclose( + dm_results["dvalid"]["rmse"], qdm_results["valid"]["rmse"] + ) + + def test_ref_dmatrix(self) -> None: + rng = np.random.RandomState(1994) + self.run_ref_dmatrix(rng, "hist", True) + self.run_ref_dmatrix(rng, "hist", False) + + def test_predict(self) -> None: + n_samples, n_features = 16, 2 + X, y = make_categorical( + n_samples, n_features, n_categories=13, onehot=False + ) + Xy = xgb.DMatrix(X, y, enable_categorical=True) + + booster = xgb.train({"tree_method": "hist"}, Xy) + + Xy = xgb.DMatrix(X, y, enable_categorical=True) + a = booster.predict(Xy) + qXy = xgb.QuantileDMatrix(X, y, enable_categorical=True) + b = booster.predict(qXy) + np.testing.assert_allclose(a, b) diff --git 
a/tests/python/test_with_dask.py b/tests/python/test_with_dask.py index 0e25f1da0df2..178bfc18dbdf 100644 --- a/tests/python/test_with_dask.py +++ b/tests/python/test_with_dask.py @@ -1382,6 +1382,42 @@ def test_hist( num_rounds = 30 self.run_updater_test(client, params, num_rounds, dataset, 'hist') + def test_quantile_dmatrix(self, client: Client) -> None: + X, y = make_categorical(client, 10000, 30, 13) + + Xy = xgb.dask.DaskDMatrix(client, X, y, enable_categorical=True) + valid_Xy = xgb.dask.DaskDMatrix(client, X, y, enable_categorical=True) + + output = xgb.dask.train( + client, + {"tree_method": "hist"}, + Xy, + num_boost_round=10, + evals=[(Xy, "Train"), (valid_Xy, "Valid")] + ) + dmatrix_hist = output["history"] + + Xy = xgb.dask.DaskQuantileDMatrix(client, X, y, enable_categorical=True) + valid_Xy = xgb.dask.DaskQuantileDMatrix( + client, X, y, enable_categorical=True, ref=Xy + ) + + output = xgb.dask.train( + client, + {"tree_method": "hist"}, + Xy, + num_boost_round=10, + evals=[(Xy, "Train"), (valid_Xy, "Valid")] + ) + quantile_hist = output["history"] + + np.testing.assert_allclose( + quantile_hist["Train"]["rmse"], dmatrix_hist["Train"]["rmse"] + ) + np.testing.assert_allclose( + quantile_hist["Valid"]["rmse"], dmatrix_hist["Valid"]["rmse"] + ) + @given(params=exact_parameter_strategy, dataset=tm.dataset_strategy) @settings(deadline=None, suppress_health_check=suppress, print_blob=True) diff --git a/tests/python/testing.py b/tests/python/testing.py index d1e19330b766..28afb30b6b8b 100644 --- a/tests/python/testing.py +++ b/tests/python/testing.py @@ -1,11 +1,11 @@ from concurrent.futures import ThreadPoolExecutor import os import multiprocessing -from typing import Tuple, Union +from typing import Tuple, Union, List, Sequence, Callable import urllib import zipfile import sys -from typing import Optional +from typing import Optional, Dict, Any from contextlib import contextmanager from io import StringIO from xgboost.compat import SKLEARN_INSTALLED, 
PANDAS_INSTALLED @@ -180,79 +180,148 @@ def skip_s390x(): class IteratorForTest(xgb.core.DataIter): - def __init__(self, X, y): + def __init__( + self, + X: Sequence, + y: Sequence, + w: Optional[Sequence], + cache: Optional[str] = "./" + ) -> None: assert len(X) == len(y) self.X = X self.y = y + self.w = w self.it = 0 - super().__init__("./") + super().__init__(cache) - def next(self, input_data): + def next(self, input_data: Callable) -> int: if self.it == len(self.X): return 0 # Use copy to make sure the iterator doesn't hold a reference to the data. - input_data(data=self.X[self.it].copy(), label=self.y[self.it].copy()) - gc.collect() # clear up the copy, see if XGBoost access freed memory. + input_data( + data=self.X[self.it].copy(), + label=self.y[self.it].copy(), + weight=self.w[self.it].copy() if self.w else None, + ) + gc.collect() # clear up the copy, see if XGBoost access freed memory. self.it += 1 return 1 - def reset(self): + def reset(self) -> None: self.it = 0 - def as_arrays(self): - X = np.concatenate(self.X, axis=0) + def as_arrays( + self, + ) -> Tuple[Union[np.ndarray, sparse.csr_matrix], np.ndarray, np.ndarray]: + if isinstance(self.X[0], sparse.csr_matrix): + X = sparse.vstack(self.X, format="csr") + else: + X = np.concatenate(self.X, axis=0) y = np.concatenate(self.y, axis=0) - return X, y + w = np.concatenate(self.w, axis=0) + return X, y, w + + +def make_batches( + n_samples_per_batch: int, n_features: int, n_batches: int, use_cupy: bool = False +) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]: + X = [] + y = [] + w = [] + if use_cupy: + import cupy + + rng = cupy.random.RandomState(1994) + else: + rng = np.random.RandomState(1994) + for i in range(n_batches): + _X = rng.randn(n_samples_per_batch, n_features) + _y = rng.randn(n_samples_per_batch) + _w = rng.uniform(low=0, high=1, size=n_samples_per_batch) + X.append(_X) + y.append(_y) + w.append(_w) + return X, y, w + + +def make_batches_sparse( + n_samples_per_batch: int, 
n_features: int, n_batches: int, sparsity: float +) -> Tuple[List[sparse.csr_matrix], List[np.ndarray], List[np.ndarray]]: + X = [] + y = [] + w = [] + rng = np.random.RandomState(1994) + for i in range(n_batches): + _X = sparse.random( + n_samples_per_batch, + n_features, + 1.0 - sparsity, + format="csr", + dtype=np.float32, + random_state=rng, + ) + _y = rng.randn(n_samples_per_batch) + _w = rng.uniform(low=0, high=1, size=n_samples_per_batch) + X.append(_X) + y.append(_y) + w.append(_w) + return X, y, w # Contains a dataset in numpy format as well as the relevant objective and metric class TestDataset: - def __init__(self, name, get_dataset, objective, metric): + def __init__( + self, name: str, get_dataset: Callable, objective: str, metric: str + ) -> None: self.name = name self.objective = objective self.metric = metric self.X, self.y = get_dataset() - self.w = None + self.w: Optional[np.ndarray] = None self.margin: Optional[np.ndarray] = None - def set_params(self, params_in): + def set_params(self, params_in: Dict[str, Any]) -> Dict[str, Any]: params_in['objective'] = self.objective params_in['eval_metric'] = self.metric if self.objective == "multi:softmax": params_in["num_class"] = int(np.max(self.y) + 1) return params_in - def get_dmat(self): + def get_dmat(self) -> xgb.DMatrix: return xgb.DMatrix( self.X, self.y, self.w, base_margin=self.margin, enable_categorical=True ) - def get_device_dmat(self): + def get_device_dmat(self) -> xgb.DeviceQuantileDMatrix: w = None if self.w is None else cp.array(self.w) X = cp.array(self.X, dtype=np.float32) y = cp.array(self.y, dtype=np.float32) return xgb.DeviceQuantileDMatrix(X, y, w, base_margin=self.margin) - def get_external_dmat(self): + def get_external_dmat(self) -> xgb.DMatrix: n_samples = self.X.shape[0] n_batches = 10 per_batch = n_samples // n_batches + 1 predictor = [] response = [] + weight = [] for i in range(n_batches): beg = i * per_batch end = min((i + 1) * per_batch, n_samples) assert end != beg X = 
self.X[beg: end, ...] y = self.y[beg: end] + w = self.w[beg: end] if self.w is not None else None predictor.append(X) response.append(y) + if w is not None: + weight.append(w) - it = IteratorForTest(predictor, response) + it = IteratorForTest(predictor, response, weight if weight else None) return xgb.DMatrix(it) - def __repr__(self): + def __repr__(self) -> str: return self.name diff --git a/tests/python/with_omp_limit.py b/tests/python/with_omp_limit.py index 7fc59a4707fe..950ec03648f2 100644 --- a/tests/python/with_omp_limit.py +++ b/tests/python/with_omp_limit.py @@ -1,4 +1,3 @@ -import os import xgboost as xgb from sklearn.datasets import make_classification from sklearn.metrics import roc_auc_score