From 47350f6acb5b9696f640a3b8515940d3bd0a2868 Mon Sep 17 00:00:00 2001 From: Rory Mitchell Date: Tue, 15 Sep 2020 13:04:03 +1200 Subject: [PATCH] Allow kwargs in dask predict (#6117) --- python-package/xgboost/dask.py | 14 +++++++------- tests/python/test_with_dask.py | 32 ++++++++++++++++++++++++++------ 2 files changed, 33 insertions(+), 13 deletions(-) diff --git a/python-package/xgboost/dask.py b/python-package/xgboost/dask.py index a08c21367a94..b5479fdb9bb5 100644 --- a/python-package/xgboost/dask.py +++ b/python-package/xgboost/dask.py @@ -688,8 +688,8 @@ async def _direct_predict_impl(client, data, predict_fn): # pylint: disable=too-many-statements -async def _predict_async(client: Client, model, data, *args, - missing=numpy.nan): +async def _predict_async(client: Client, model, data, missing=numpy.nan, **kwargs): + if isinstance(model, Booster): booster = model elif isinstance(model, dict): @@ -704,7 +704,7 @@ def mapped_predict(partition, is_df): worker = distributed_get_worker() booster.set_param({'nthread': worker.nthreads}) m = DMatrix(partition, missing=missing, nthread=worker.nthreads) - predt = booster.predict(m, *args, validate_features=False) + predt = booster.predict(m, validate_features=False, **kwargs) if is_df: if lazy_isinstance(partition, 'cudf', 'core.dataframe.DataFrame'): import cudf # pylint: disable=import-error @@ -737,7 +737,7 @@ def dispatched_predict(worker_id): missing=missing, nthread=worker.nthreads) predt = booster.predict(data=local_x, validate_features=local_x.num_row() != 0, - *args) + **kwargs) columns = 1 if len(predt.shape) == 1 else predt.shape[1] ret = ((delayed(predt), columns), order) predictions.append(ret) @@ -784,7 +784,7 @@ async def map_function(func): return predictions -def predict(client, model, data, *args, missing=numpy.nan): +def predict(client, model, data, missing=numpy.nan, **kwargs): '''Run prediction with a trained booster. .. note:: @@ -813,8 +813,8 @@ def predict(client, model, data, *args, missing=numpy.nan): ''' _assert_dask_support() client = _xgb_get_client(client) - return client.sync(_predict_async, client, model, data, *args, - missing=missing) + return client.sync(_predict_async, client, model, data, + missing=missing, **kwargs) async def _inplace_predict_async(client, model, data, diff --git a/tests/python/test_with_dask.py b/tests/python/test_with_dask.py index 145fa0b524cd..72abf25c96e2 100644 --- a/tests/python/test_with_dask.py +++ b/tests/python/test_with_dask.py @@ -215,7 +215,7 @@ def test_dask_classifier(): classifier = xgb.dask.DaskXGBClassifier( verbosity=1, n_estimators=2) classifier.client = client - classifier.fit(X, y, eval_set=[(X, y)]) + classifier.fit(X, y, eval_set=[(X, y)]) prediction = classifier.predict(X) assert prediction.ndim == 1 @@ -276,7 +276,6 @@ def test_sklearn_grid_search(): def run_empty_dmatrix_reg(client, parameters): - def _check_outputs(out, predictions): assert isinstance(out['booster'], xgb.dask.Booster) assert len(out['history']['validation']['rmse']) == 2 @@ -424,7 +423,7 @@ async def run_dask_classifier_asyncio(scheduler_address): classifier = await xgb.dask.DaskXGBClassifier( verbosity=1, n_estimators=2) classifier.client = client - await classifier.fit(X, y, eval_set=[(X, y)]) + await classifier.fit(X, y, eval_set=[(X, y)]) prediction = await classifier.predict(X) assert prediction.ndim == 1 @@ -447,7 +446,6 @@ async def run_dask_classifier_asyncio(scheduler_address): assert probas.shape[0] == kRows assert probas.shape[1] == 10 - # Test with dataframe. X_d = dd.from_dask_array(X) y_d = dd.from_dask_array(y) @@ -472,6 +470,28 @@ def test_with_asyncio(): asyncio.run(run_dask_classifier_asyncio(address)) +def test_predict(): + with LocalCluster(n_workers=kWorkers) as cluster: + with Client(cluster) as client: + X, y = generate_array() + dtrain = DaskDMatrix(client, X, y) + booster = xgb.dask.train( + client, {}, dtrain, num_boost_round=2)['booster'] + + pred = xgb.dask.predict(client, model=booster, data=dtrain) + assert pred.ndim == 1 + assert pred.shape[0] == kRows + + margin = xgb.dask.predict(client, model=booster, data=dtrain, output_margin=True) + assert margin.ndim == 1 + assert margin.shape[0] == kRows + + shap = xgb.dask.predict(client, model=booster, data=dtrain, pred_contribs=True) + assert shap.ndim == 2 + assert shap.shape[0] == kRows + assert shap.shape[1] == kCols + 1 + + class TestWithDask: def run_updater_test(self, client, params, num_rounds, dataset, tree_method): @@ -489,9 +509,9 @@ def run_updater_test(self, client, params, num_rounds, dataset, chunk = 128 X = da.from_array(dataset.X, chunks=(chunk, dataset.X.shape[1])) - y = da.from_array(dataset.y, chunks=(chunk, )) + y = da.from_array(dataset.y, chunks=(chunk,)) if dataset.w is not None: - w = da.from_array(dataset.w, chunks=(chunk, )) + w = da.from_array(dataset.w, chunks=(chunk,)) else: w = None