Support dask.
trivialfis committed Jun 16, 2021
1 parent 9182993 commit 42ca226
Showing 3 changed files with 41 additions and 18 deletions.
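In short, this commit threads a new enable_categorical flag from the dask frontend down to the per-worker DMatrix and DeviceQuantileDMatrix constructors. A minimal usage sketch of the feature being wired up (illustrative only, not code from the commit; X and y stand for a dask or dask-cudf DataFrame with "category"-dtype columns and a matching label collection, assumed defined elsewhere):

import xgboost as xgb
from xgboost import dask as dxgb
from dask.distributed import Client
from dask_cuda import LocalCUDACluster

# Hypothetical usage sketch: X is assumed to be a dask (or dask-cudf)
# DataFrame whose categorical columns use the "category" dtype, y a
# matching dask label collection; both are defined elsewhere.
with Client(LocalCUDACluster()) as client:
    m = dxgb.DaskDMatrix(client, X, y, enable_categorical=True)
    output = dxgb.train(
        client,
        {"tree_method": "gpu_hist"},
        m,
        num_boost_round=10,
        evals=[(m, "Train")],
    )
    booster = output["booster"]  # trees may now contain categorical splits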
18 changes: 14 additions & 4 deletions python-package/xgboost/dask.py
@@ -258,6 +258,7 @@ def __init__(
self.feature_names = feature_names
self.feature_types = feature_types
self.missing = missing
self.enable_categorical = enable_categorical

if qid is not None and weight is not None:
raise NotImplementedError("per-group weight is not implemented.")
@@ -307,7 +308,7 @@ async def _map_local_data(
qid: Optional[_DaskCollection] = None,
feature_weights: Optional[_DaskCollection] = None,
label_lower_bound: Optional[_DaskCollection] = None,
label_upper_bound: Optional[_DaskCollection] = None
label_upper_bound: Optional[_DaskCollection] = None,
) -> "DaskDMatrix":
'''Obtain references to local data.'''

@@ -426,6 +427,7 @@ def _create_fn_args(self, worker_addr: str) -> Dict[str, Any]:
'feature_weights': self.feature_weights,
'meta_names': self.meta_names,
'missing': self.missing,
'enable_categorical': self.enable_categorical,
'parts': self.worker_map.get(worker_addr, None),
'is_quantile': self.is_quantile}

@@ -664,6 +666,7 @@ def _create_device_quantile_dmatrix(
missing: float,
parts: Optional[_DataParts],
max_bin: int,
enable_categorical: bool,
) -> DeviceQuantileDMatrix:
worker = distributed.get_worker()
if parts is None:
@@ -676,6 +679,7 @@
feature_names=feature_names,
feature_types=feature_types,
max_bin=max_bin,
enable_categorical=enable_categorical,
)
return d

@@ -705,6 +709,7 @@ def _create_device_quantile_dmatrix(
feature_types=feature_types,
nthread=worker.nthreads,
max_bin=max_bin,
enable_categorical=enable_categorical,
)
dmatrix.set_info(feature_weights=feature_weights)
return dmatrix
@@ -716,6 +721,7 @@ def _create_dmatrix(
feature_weights: Optional[Any],
meta_names: List[str],
missing: float,
enable_categorical: bool,
parts: Optional[_DataParts]
) -> DMatrix:
'''Get data that is local to the worker from DaskDMatrix.
@@ -730,9 +736,12 @@
if list_of_parts is None:
msg = 'worker {address} has an empty DMatrix. '.format(address=worker.address)
LOGGER.warning(msg)
d = DMatrix(numpy.empty((0, 0)),
feature_names=feature_names,
feature_types=feature_types)
d = DMatrix(
numpy.empty((0, 0)),
feature_names=feature_names,
feature_types=feature_types,
enable_categorical=enable_categorical,
)
return d

T = TypeVar('T')
@@ -760,6 +769,7 @@ def concat_or_none(data: Tuple[Optional[T], ...]) -> Optional[T]:
feature_names=feature_names,
feature_types=feature_types,
nthread=worker.nthreads,
enable_categorical=enable_categorical,
)
dmatrix.set_info(
base_margin=_base_margin,
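Taken together, the dask.py hunks follow one pattern: the flag is stored on the scheduler-side DaskDMatrix, shipped to each worker through _create_fn_args, and forwarded by the worker-side _create_dmatrix and _create_device_quantile_dmatrix helpers. A condensed, self-contained illustration of that flow (a simplified sketch, not the actual xgboost class):

from typing import Any, Dict

class DaskDMatrixSketch:
    # Simplified stand-in for xgboost.dask.DaskDMatrix, shown only to
    # illustrate the argument-threading pattern in the hunks above.
    def __init__(self, missing: float, enable_categorical: bool = False) -> None:
        self.missing = missing
        # The new flag is stored on the scheduler-side handle...
        self.enable_categorical = enable_categorical
        self.worker_map: Dict[str, Any] = {}

    def _create_fn_args(self, worker_addr: str) -> Dict[str, Any]:
        # ...then packaged per worker, so the worker-side constructors can
        # rebuild a local matrix from that worker's partitions with the
        # same settings.
        return {
            "missing": self.missing,
            "enable_categorical": self.enable_categorical,
            "parts": self.worker_map.get(worker_addr, None),
        }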
5 changes: 1 addition & 4 deletions tests/python-gpu/test_from_cudf.py
@@ -232,7 +232,6 @@ def __init__(self, categorical):
'''
import cudf
self.rows = self.ROWS_PER_BATCH
self.categorical = categorical

if categorical:
self._data = []
@@ -276,9 +275,7 @@ def next(self, input_data):
if self.it == len(self._data):
# Return 0 when there's no more batch.
return 0
input_data(
data=self.data(), label=self.labels(), enable_categorical=self.categorical
)
input_data(data=self.data(), label=self.labels())
self.it += 1
return 1

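The test change above reflects the interface shift: enable_categorical is no longer passed through the per-batch input_data callback, because the flag now travels with the matrix constructor itself. A hedged sketch of the resulting iterator usage (the OneBatch class is hypothetical, X and y are assumed cupy/cudf inputs defined elsewhere, and a GPU is required):

import xgboost as xgb

class OneBatch(xgb.core.DataIter):
    # Hypothetical single-batch iterator mirroring the test's pattern.
    def __init__(self, data, labels) -> None:
        self.data, self.labels = data, labels
        self.it = 0
        super().__init__()

    def reset(self) -> None:
        self.it = 0

    def next(self, input_data) -> int:
        if self.it == 1:
            return 0  # no more batches
        # The categorical flag is no longer passed per batch here...
        input_data(data=self.data, label=self.labels)
        self.it += 1
        return 1

# ...it is supplied once, when the quantile matrix is built.
m = xgb.DeviceQuantileDMatrix(OneBatch(X, y), enable_categorical=True)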
36 changes: 26 additions & 10 deletions tests/python-gpu/test_gpu_with_dask.py
@@ -6,6 +6,8 @@
import asyncio
import xgboost
import subprocess
import tempfile
import json
from collections import OrderedDict
from inspect import signature
from hypothesis import given, strategies, settings, note
@@ -171,35 +173,49 @@ def test_categorical(rounds, local_cuda_cluster: LocalCUDACluster):
X, y = make_categorical(client, 1000, 30, 13)
X_onehot, _ = make_categorical(client, 1000, 30, 13, True)

by_etl_results = {}
by_builtin_results = {}

parameters = {"tree_method": "gpu_hist"}

m = dxgb.DaskDMatrix(X_onehot, y, enable_categorical=True)
xgb.train(
m = dxgb.DaskDMatrix(client, X_onehot, y, enable_categorical=True)
by_etl_results = dxgb.train(
client,
parameters,
m,
num_boost_round=rounds,
evals=[(m, "Train")],
evals_result=by_etl_results,
)
)["history"]

m = xgb.DMatrix(X, y, enable_categorical=True)
xgb.train(
m = dxgb.DaskDMatrix(client, X, y, enable_categorical=True)
output = dxgb.train(
client,
parameters,
m,
num_boost_round=rounds,
evals=[(m, "Train")],
evals_result=by_builtin_results,
)
by_builtin_results = output["history"]

np.testing.assert_allclose(
np.array(by_etl_results["Train"]["rmse"]),
np.array(by_builtin_results["Train"]["rmse"]),
rtol=1e-3,
)
assert tm.non_increasing(by_builtin_results["Train"]["rmse"])

model = output["booster"]
with tempfile.TemporaryDirectory() as tempdir:
path = os.path.join(tempdir, "model.json")
model.save_model(path)
with open(path, "r") as fd:
categorical = json.load(fd)

categories_sizes = np.array(
categorical["learner"]["gradient_booster"]["model"]["trees"][-1][
"categories_sizes"
]
)
assert categories_sizes.shape[0] != 0
np.testing.assert_allclose(categories_sizes, 1)


def to_cp(x: Any, DMatrixT: Type) -> Any:
import cupy
