Allow kwargs in dask predict #6117

Merged · 1 commit · Sep 15, 2020
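This change lets `xgb.dask.predict` (and the internal `_predict_async`) accept arbitrary keyword arguments and forward them to `Booster.predict` on each worker, replacing the previous positional `*args` pass-through. A minimal usage sketch of what this enables, with illustrative data shapes and cluster size (the `output_margin` and `pred_contribs` calls mirror the new `test_predict` added below):

import dask.array as da
from dask.distributed import Client, LocalCluster
import xgboost as xgb

with LocalCluster(n_workers=2) as cluster:
    with Client(cluster) as client:
        # Illustrative random data; any dask array/dataframe works.
        X = da.random.random((1000, 10), chunks=(100, 10))
        y = da.random.random(1000, chunks=(100,))
        dtrain = xgb.dask.DaskDMatrix(client, X, y)
        booster = xgb.dask.train(client, {}, dtrain, num_boost_round=2)['booster']

        # Plain predictions work as before.
        pred = xgb.dask.predict(client, model=booster, data=dtrain)
        # New: keyword arguments are forwarded to Booster.predict on each worker.
        margin = xgb.dask.predict(client, model=booster, data=dtrain, output_margin=True)
        contribs = xgb.dask.predict(client, model=booster, data=dtrain, pred_contribs=True)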
14 changes: 7 additions & 7 deletions python-package/xgboost/dask.py
@@ -688,8 +688,8 @@ async def _direct_predict_impl(client, data, predict_fn):


# pylint: disable=too-many-statements
async def _predict_async(client: Client, model, data, *args,
missing=numpy.nan):
async def _predict_async(client: Client, model, data, missing=numpy.nan, **kwargs):

if isinstance(model, Booster):
booster = model
elif isinstance(model, dict):
@@ -704,7 +704,7 @@ def mapped_predict(partition, is_df):
worker = distributed_get_worker()
booster.set_param({'nthread': worker.nthreads})
m = DMatrix(partition, missing=missing, nthread=worker.nthreads)
predt = booster.predict(m, *args, validate_features=False)
predt = booster.predict(m, validate_features=False, **kwargs)
if is_df:
if lazy_isinstance(partition, 'cudf', 'core.dataframe.DataFrame'):
import cudf # pylint: disable=import-error
@@ -737,7 +737,7 @@ def dispatched_predict(worker_id):
missing=missing, nthread=worker.nthreads)
predt = booster.predict(data=local_x,
validate_features=local_x.num_row() != 0,
*args)
**kwargs)
columns = 1 if len(predt.shape) == 1 else predt.shape[1]
ret = ((delayed(predt), columns), order)
predictions.append(ret)
@@ -784,7 +784,7 @@ async def map_function(func):
return predictions


def predict(client, model, data, *args, missing=numpy.nan):
def predict(client, model, data, missing=numpy.nan, **kwargs):
'''Run prediction with a trained booster.

.. note::
@@ -813,8 +813,8 @@ def predict(client, model, data, *args, missing=numpy.nan):
'''
_assert_dask_support()
client = _xgb_get_client(client)
return client.sync(_predict_async, client, model, data, *args,
missing=missing)
return client.sync(_predict_async, client, model, data,
missing=missing, **kwargs)


async def _inplace_predict_async(client, model, data,
32 changes: 26 additions & 6 deletions tests/python/test_with_dask.py
@@ -215,7 +215,7 @@ def test_dask_classifier():
classifier = xgb.dask.DaskXGBClassifier(
verbosity=1, n_estimators=2)
classifier.client = client
classifier.fit(X, y, eval_set=[(X, y)])
classifier.fit(X, y, eval_set=[(X, y)])
prediction = classifier.predict(X)

assert prediction.ndim == 1
@@ -276,7 +276,6 @@ def test_sklearn_grid_search():


def run_empty_dmatrix_reg(client, parameters):

def _check_outputs(out, predictions):
assert isinstance(out['booster'], xgb.dask.Booster)
assert len(out['history']['validation']['rmse']) == 2
@@ -424,7 +423,7 @@ async def run_dask_classifier_asyncio(scheduler_address):
classifier = await xgb.dask.DaskXGBClassifier(
verbosity=1, n_estimators=2)
classifier.client = client
await classifier.fit(X, y, eval_set=[(X, y)])
await classifier.fit(X, y, eval_set=[(X, y)])
prediction = await classifier.predict(X)

assert prediction.ndim == 1
@@ -447,7 +446,6 @@ async def run_dask_classifier_asyncio(scheduler_address):
assert probas.shape[0] == kRows
assert probas.shape[1] == 10


# Test with dataframe.
X_d = dd.from_dask_array(X)
y_d = dd.from_dask_array(y)
@@ -472,6 +470,28 @@ def test_with_asyncio():
asyncio.run(run_dask_classifier_asyncio(address))


def test_predict():
with LocalCluster(n_workers=kWorkers) as cluster:
with Client(cluster) as client:
X, y = generate_array()
dtrain = DaskDMatrix(client, X, y)
booster = xgb.dask.train(
client, {}, dtrain, num_boost_round=2)['booster']

pred = xgb.dask.predict(client, model=booster, data=dtrain)
assert pred.ndim == 1
assert pred.shape[0] == kRows

margin = xgb.dask.predict(client, model=booster, data=dtrain, output_margin=True)
assert margin.ndim == 1
assert margin.shape[0] == kRows

shap = xgb.dask.predict(client, model=booster, data=dtrain, pred_contribs=True)
assert shap.ndim == 2
assert shap.shape[0] == kRows
assert shap.shape[1] == kCols + 1


class TestWithDask:
def run_updater_test(self, client, params, num_rounds, dataset,
tree_method):
@@ -489,9 +509,9 @@ def run_updater_test(self, client, params, num_rounds, dataset,
chunk = 128
X = da.from_array(dataset.X,
chunks=(chunk, dataset.X.shape[1]))
y = da.from_array(dataset.y, chunks=(chunk, ))
y = da.from_array(dataset.y, chunks=(chunk,))
if dataset.w is not None:
w = da.from_array(dataset.w, chunks=(chunk, ))
w = da.from_array(dataset.w, chunks=(chunk,))
else:
w = None
