diff --git a/python-package/xgboost/dask.py b/python-package/xgboost/dask.py index 4eb7c109ffe5..7931ad4a4820 100644 --- a/python-package/xgboost/dask.py +++ b/python-package/xgboost/dask.py @@ -258,6 +258,7 @@ def __init__( self.feature_names = feature_names self.feature_types = feature_types self.missing = missing + self.enable_categorical = enable_categorical if qid is not None and weight is not None: raise NotImplementedError("per-group weight is not implemented.") @@ -307,7 +308,7 @@ async def _map_local_data( qid: Optional[_DaskCollection] = None, feature_weights: Optional[_DaskCollection] = None, label_lower_bound: Optional[_DaskCollection] = None, - label_upper_bound: Optional[_DaskCollection] = None + label_upper_bound: Optional[_DaskCollection] = None, ) -> "DaskDMatrix": '''Obtain references to local data.''' @@ -426,6 +427,7 @@ def _create_fn_args(self, worker_addr: str) -> Dict[str, Any]: 'feature_weights': self.feature_weights, 'meta_names': self.meta_names, 'missing': self.missing, + 'enable_categorical': self.enable_categorical, 'parts': self.worker_map.get(worker_addr, None), 'is_quantile': self.is_quantile} @@ -664,6 +666,7 @@ def _create_device_quantile_dmatrix( missing: float, parts: Optional[_DataParts], max_bin: int, + enable_categorical: bool, ) -> DeviceQuantileDMatrix: worker = distributed.get_worker() if parts is None: @@ -676,6 +679,7 @@ def _create_device_quantile_dmatrix( feature_names=feature_names, feature_types=feature_types, max_bin=max_bin, + enable_categorical=enable_categorical, ) return d @@ -705,6 +709,7 @@ def _create_device_quantile_dmatrix( feature_types=feature_types, nthread=worker.nthreads, max_bin=max_bin, + enable_categorical=enable_categorical, ) dmatrix.set_info(feature_weights=feature_weights) return dmatrix @@ -716,6 +721,7 @@ def _create_dmatrix( feature_weights: Optional[Any], meta_names: List[str], missing: float, + enable_categorical: bool, parts: Optional[_DataParts] ) -> DMatrix: '''Get data that local to worker from DaskDMatrix. @@ -730,9 +736,12 @@ def _create_dmatrix( if list_of_parts is None: msg = 'worker {address} has an empty DMatrix. '.format(address=worker.address) LOGGER.warning(msg) - d = DMatrix(numpy.empty((0, 0)), - feature_names=feature_names, - feature_types=feature_types) + d = DMatrix( + numpy.empty((0, 0)), + feature_names=feature_names, + feature_types=feature_types, + enable_categorical=enable_categorical, + ) return d T = TypeVar('T') @@ -760,6 +769,7 @@ def concat_or_none(data: Tuple[Optional[T], ...]) -> Optional[T]: feature_names=feature_names, feature_types=feature_types, nthread=worker.nthreads, + enable_categorical=enable_categorical, ) dmatrix.set_info( base_margin=_base_margin, diff --git a/tests/python-gpu/test_from_cudf.py b/tests/python-gpu/test_from_cudf.py index d4019c048dbe..36bf6023071c 100644 --- a/tests/python-gpu/test_from_cudf.py +++ b/tests/python-gpu/test_from_cudf.py @@ -232,7 +232,6 @@ def __init__(self, categorical): ''' import cudf self.rows = self.ROWS_PER_BATCH - self.categorical = categorical if categorical: self._data = [] @@ -276,9 +275,7 @@ def next(self, input_data): if self.it == len(self._data): # Return 0 when there's no more batch. return 0 - input_data( - data=self.data(), label=self.labels(), enable_categorical=self.categorical - ) + input_data(data=self.data(), label=self.labels()) self.it += 1 return 1 diff --git a/tests/python-gpu/test_gpu_with_dask.py b/tests/python-gpu/test_gpu_with_dask.py index 78586a7880fe..9cd0a774b8f8 100644 --- a/tests/python-gpu/test_gpu_with_dask.py +++ b/tests/python-gpu/test_gpu_with_dask.py @@ -6,6 +6,8 @@ import asyncio import xgboost import subprocess +import tempfile +import json from collections import OrderedDict from inspect import signature from hypothesis import given, strategies, settings, note @@ -171,28 +173,27 @@ def test_categorical(rounds, local_cuda_cluster: LocalCUDACluster): X, y = make_categorical(client, 1000, 30, 13) X_onehot, _ = make_categorical(client, 1000, 30, 13, True) - by_etl_results = {} - by_builtin_results = {} - parameters = {"tree_method": "gpu_hist"} - m = dxgb.DaskDMatrix(X_onehot, y, enable_categorical=True) - xgb.train( + m = dxgb.DaskDMatrix(client, X_onehot, y, enable_categorical=True) + by_etl_results = dxgb.train( + client, parameters, m, num_boost_round=rounds, evals=[(m, "Train")], - evals_result=by_etl_results, - ) + )["history"] - m = xgb.DMatrix(X, y, enable_categorical=True) - xgb.train( + m = dxgb.DaskDMatrix(client, X, y, enable_categorical=True) + output = dxgb.train( + client, parameters, m, num_boost_round=rounds, evals=[(m, "Train")], - evals_result=by_builtin_results, ) + by_builtin_results = output["history"] + np.testing.assert_allclose( np.array(by_etl_results["Train"]["rmse"]), np.array(by_builtin_results["Train"]["rmse"]), @@ -200,6 +201,21 @@ def test_categorical(rounds, local_cuda_cluster: LocalCUDACluster): ) assert tm.non_increasing(by_builtin_results["Train"]["rmse"]) + model = output["booster"] + with tempfile.TemporaryDirectory() as tempdir: + path = os.path.join(tempdir, "model.json") + model.save_model(path) + with open(path, "r") as fd: + categorical = json.load(fd) + + categories_sizes = np.array( + categorical["learner"]["gradient_booster"]["model"]["trees"][-1][ + "categories_sizes" + ] + ) + assert categories_sizes.shape[0] != 0 + np.testing.assert_allclose(categories_sizes, 1) + def to_cp(x: Any, DMatrixT: Type) -> Any: import cupy