Add DaskDeviceQuantileDMatrix demo. (#6156)
trivialfis committed Sep 24, 2020
1 parent 678ea40 commit 78d72ef
Showing 3 changed files with 63 additions and 10 deletions.
56 changes: 47 additions & 9 deletions demo/dask/gpu_training.py
@@ -2,16 +2,13 @@
from dask.distributed import Client
from dask import array as da
import xgboost as xgb
from xgboost import dask as dxgb
from xgboost.dask import DaskDMatrix
import cupy as cp
import argparse


def main(client):
    # generate some random data for demonstration
    m = 100000
    n = 100
    X = da.random.random(size=(m, n), chunks=100)
    y = da.random.random(size=(m, ), chunks=100)

def using_dask_matrix(client: Client, X, y):
    # DaskDMatrix acts like a normal DMatrix and works as a proxy for local
    # DMatrix objects scattered across the workers.
    dtrain = DaskDMatrix(client, X, y)
@@ -31,15 +28,56 @@ def main(client):

    # you can pass output directly into `predict` too.
    prediction = xgb.dask.predict(client, bst, dtrain)
    prediction = prediction.compute()
    print('Evaluation history:', history)
    return prediction


def using_quantile_device_dmatrix(client: Client, X, y):
    '''`DaskDeviceQuantileDMatrix` is a data type specialized for the `gpu_hist`
    tree method, which reduces memory overhead.  When training on a GPU
    pipeline, it is preferred over `DaskDMatrix`.

    .. versionadded:: 1.2.0
    '''
    # Input must be on GPU for `DaskDeviceQuantileDMatrix`.
    X = X.map_blocks(cp.array)
    y = y.map_blocks(cp.array)

    # `DaskDeviceQuantileDMatrix` is used instead of `DaskDMatrix`; note that it
    # cannot be used for anything other than training.
    dtrain = dxgb.DaskDeviceQuantileDMatrix(client, X, y)
    output = xgb.dask.train(client,
                            {'verbosity': 2,
                             'tree_method': 'gpu_hist'},
                            dtrain,
                            num_boost_round=4)

    prediction = xgb.dask.predict(client, output, X)
    return prediction


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--ddqdm', choices=[0, 1], type=int, default=1,
        help='Whether to use `DaskDeviceQuantileDMatrix`.')
    args = parser.parse_args()

    # `LocalCUDACluster` is used for assigning GPUs to XGBoost processes.  Here
    # `n_workers` represents the number of GPUs since we use one GPU per worker
    # process.
    with LocalCUDACluster(n_workers=2, threads_per_worker=4) as cluster:
        with Client(cluster) as client:
            main(client)
            # generate some random data for demonstration
            m = 100000
            n = 100
            X = da.random.random(size=(m, n), chunks=100)
            y = da.random.random(size=(m, ), chunks=100)

            if args.ddqdm == 1:
                print('Using DaskDeviceQuantileDMatrix')
                from_ddqdm = using_quantile_device_dmatrix(client, X, y)
            else:
                print('Using DMatrix')
                from_dmatrix = using_dask_matrix(client, X, y)
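
The two demo paths above differ mainly in where the input data lives.  A minimal
sketch of that distinction, assuming a connected `client` and dask arrays `X`, `y`
as generated above; the helper `build_dtrain` is purely illustrative and is not
part of the demo:

import cupy as cp
from xgboost import dask as dxgb

def build_dtrain(client, X, y, use_ddqdm=True):
    # Illustrative helper: pick between the two DMatrix flavours shown above.
    if use_ddqdm:
        # DaskDeviceQuantileDMatrix expects GPU-resident input, so each dask
        # block is converted to a cupy array first.
        X = X.map_blocks(cp.array)
        y = y.map_blocks(cp.array)
        return dxgb.DaskDeviceQuantileDMatrix(client, X, y)
    # DaskDMatrix accepts host data and acts as a proxy for local DMatrix
    # objects scattered across the workers.
    return dxgb.DaskDMatrix(client, X, y)
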
3 changes: 2 additions & 1 deletion python-package/xgboost/dask.py
@@ -854,7 +854,8 @@ def predict(client, model, data, missing=numpy.nan, **kwargs):
    model: A Booster or a dictionary returned by `xgboost.dask.train`.
        The trained model.
    data: DaskDMatrix/dask.dataframe.DataFrame/dask.array.Array
        Input data used for prediction.  When the input is a dataframe object,
        the prediction output is a series.
    missing: float
        Used when input data is not DaskDMatrix.  Specify the value
        considered as missing.
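
Per the updated docstring, a dataframe input yields a series as the prediction
output.  A rough sketch of what that looks like, assuming a connected `client`
and an `output` dict (or booster) returned by `xgboost.dask.train`; the dataframe
below is made up for illustration:

import dask.dataframe as dd
import pandas as pd
import xgboost as xgb

# Illustrative data only; any dask dataframe with numeric feature columns works.
df = dd.from_pandas(pd.DataFrame({'f0': range(8), 'f1': range(8)}),
                    npartitions=2)

# `output` can be the dict returned by xgboost.dask.train or just the booster.
prediction = xgb.dask.predict(client, output, df)

# With a dataframe input the result is a lazy dask series, per the docstring.
result = prediction.compute()
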
14 changes: 14 additions & 0 deletions tests/python-gpu/test_gpu_demos.py
@@ -6,8 +6,22 @@
import testing as tm
import test_demos as td # noqa


@pytest.mark.skipif(**tm.no_cupy())
def test_data_iterator():
    script = os.path.join(td.PYTHON_DEMO_DIR, 'data_iterator.py')
    cmd = ['python', script]
    subprocess.check_call(cmd)


@pytest.mark.skipif(**tm.no_dask())
@pytest.mark.skipif(**tm.no_dask_cuda())
@pytest.mark.skipif(**tm.no_cupy())
@pytest.mark.mgpu
def test_dask_training():
    script = os.path.join(tm.PROJECT_ROOT, 'demo', 'dask', 'gpu_training.py')
    cmd = ['python', script, '--ddqdm=1']
    subprocess.check_call(cmd)

    cmd = ['python', script, '--ddqdm=0']
    subprocess.check_call(cmd)
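
The demo can also be exercised by hand outside of pytest, mirroring the two
commands used by `test_dask_training`; a sketch assuming it is run from an
XGBoost checkout:

import os
import subprocess

# Path relative to the repository root; adjust for your local layout.
script = os.path.join('demo', 'dask', 'gpu_training.py')

# Run both code paths covered by test_dask_training.
subprocess.check_call(['python', script, '--ddqdm=1'])
subprocess.check_call(['python', script, '--ddqdm=0'])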
