Skip to content

Commit

Permalink
[Breaking] Fix .predict() method and add .predict_proba() in xgboost.…
Browse files Browse the repository at this point in the history
…dask.DaskXGBClassifier (#5986)
  • Loading branch information
jameskrach committed Aug 11, 2020
1 parent 6f7112a commit bd6b7f4
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 9 deletions.
23 changes: 22 additions & 1 deletion python-package/xgboost/dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -1079,13 +1079,34 @@ def fit(self, X, y,
return self.client.sync(self._fit_async, X, y, sample_weights,
eval_set, sample_weight_eval_set, verbose)

async def _predict_async(self, data):
async def _predict_proba_async(self, data):
_assert_dask_support()

test_dmatrix = await DaskDMatrix(client=self.client, data=data,
missing=self.missing)
pred_probs = await predict(client=self.client,
model=self.get_booster(), data=test_dmatrix)
return pred_probs

def predict_proba(self, data): # pylint: disable=arguments-differ,missing-docstring
_assert_dask_support()
return self.client.sync(self._predict_proba_async, data)

async def _predict_async(self, data):
_assert_dask_support()

test_dmatrix = await DaskDMatrix(client=self.client, data=data,
missing=self.missing)
pred_probs = await predict(client=self.client,
model=self.get_booster(), data=test_dmatrix)

if self.n_classes_ == 2:
preds = (pred_probs > 0.5).astype(int)
else:
preds = da.argmax(pred_probs, axis=1)

return preds

def predict(self, data): # pylint: disable=arguments-differ
_assert_dask_support()
return self.client.sync(self._predict_async, data)
32 changes: 24 additions & 8 deletions tests/python/test_with_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,12 +165,12 @@ def test_dask_missing_value_cls():
missing=0.0)
cls.client = client
cls.fit(X, y, eval_set=[(X, y)])
dd_predt = cls.predict(X).compute()
dd_pred_proba = cls.predict_proba(X).compute()

np_X = X.compute()
np_predt = cls.get_booster().predict(
np_pred_proba = cls.get_booster().predict(
xgb.DMatrix(np_X, missing=0.0))
np.testing.assert_allclose(np_predt, dd_predt)
np.testing.assert_allclose(np_pred_proba, dd_pred_proba)

cls = xgb.dask.DaskXGBClassifier()
assert hasattr(cls, 'missing')
Expand Down Expand Up @@ -209,7 +209,7 @@ def test_dask_classifier():
classifier.fit(X, y, eval_set=[(X, y)])
prediction = classifier.predict(X)

assert prediction.ndim == 2
assert prediction.ndim == 1
assert prediction.shape[0] == kRows

history = classifier.evals_result()
Expand All @@ -222,7 +222,18 @@ def test_dask_classifier():
assert len(list(history['validation_0'])) == 1
assert len(history['validation_0']['merror']) == 2

# Test .predict_proba()
probas = classifier.predict_proba(X)
assert classifier.n_classes_ == 10
assert probas.ndim == 2
assert probas.shape[0] == kRows
assert probas.shape[1] == 10

cls_booster = classifier.get_booster()
single_node_proba = cls_booster.inplace_predict(X.compute())

np.testing.assert_allclose(single_node_proba,
probas.compute())

# Test with dataframe.
X_d = dd.from_dask_array(X)
Expand All @@ -232,7 +243,7 @@ def test_dask_classifier():
assert classifier.n_classes_ == 10
prediction = classifier.predict(X_d)

assert prediction.ndim == 2
assert prediction.ndim == 1
assert prediction.shape[0] == kRows


Expand Down Expand Up @@ -407,7 +418,7 @@ async def run_dask_classifier_asyncio(scheduler_address):
await classifier.fit(X, y, eval_set=[(X, y)])
prediction = await classifier.predict(X)

assert prediction.ndim == 2
assert prediction.ndim == 1
assert prediction.shape[0] == kRows

history = classifier.evals_result()
Expand All @@ -420,7 +431,13 @@ async def run_dask_classifier_asyncio(scheduler_address):
assert len(list(history['validation_0'])) == 1
assert len(history['validation_0']['merror']) == 2

# Test .predict_proba()
probas = await classifier.predict_proba(X)
assert classifier.n_classes_ == 10
assert probas.ndim == 2
assert probas.shape[0] == kRows
assert probas.shape[1] == 10


# Test with dataframe.
X_d = dd.from_dask_array(X)
Expand All @@ -430,9 +447,8 @@ async def run_dask_classifier_asyncio(scheduler_address):
assert classifier.n_classes_ == 10
prediction = await classifier.predict(X_d)

assert prediction.ndim == 2
assert prediction.ndim == 1
assert prediction.shape[0] == kRows
assert prediction.shape[1] == 10


def test_with_asyncio():
Expand Down

0 comments on commit bd6b7f4

Please sign in to comment.