From b17b423b4ed2def85c630e310fba3ca21a9710d0 Mon Sep 17 00:00:00 2001 From: Kyle Nicholson Date: Mon, 24 Aug 2020 17:00:01 -0400 Subject: [PATCH 1/8] add base-margin and test --- python-package/xgboost/dask.py | 59 +++++++++++++++++++++------------- tests/python/test_with_dask.py | 42 ++++++++++++++++++++++++ 2 files changed, 79 insertions(+), 22 deletions(-) diff --git a/python-package/xgboost/dask.py b/python-package/xgboost/dask.py index d6f50bd0f80f..724d0ec15bfe 100644 --- a/python-package/xgboost/dask.py +++ b/python-package/xgboost/dask.py @@ -983,6 +983,7 @@ class DaskScikitLearnBase(XGBModel): # pylint: disable=arguments-differ def fit(self, X, y, sample_weights=None, + base_margin=None, eval_set=None, sample_weight_eval_set=None, verbose=True): @@ -1043,12 +1044,14 @@ async def _fit_async(self, X, y, sample_weights=None, + base_margin=None, eval_set=None, sample_weight_eval_set=None, verbose=True): - dtrain = await DaskDMatrix(client=self.client, - data=X, label=y, weight=sample_weights, - missing=self.missing) + dtrain = await DaskDMatrix( + client=self.client, data=X, label=y, weight=sample_weights, + base_margin=base_margin, missing=self.missing + ) params = self.get_xgb_params() evals = await _evaluation_matrices(self.client, eval_set, sample_weight_eval_set, @@ -1064,17 +1067,20 @@ async def _fit_async(self, # pylint: disable=missing-docstring def fit(self, X, y, sample_weights=None, + base_margin=None, eval_set=None, sample_weight_eval_set=None, verbose=True): _assert_dask_support() - return self.client.sync(self._fit_async, X, y, sample_weights, - eval_set, sample_weight_eval_set, - verbose) - - async def _predict_async(self, data): # pylint: disable=arguments-differ - test_dmatrix = await DaskDMatrix(client=self.client, data=data, - missing=self.missing) + return self.client.sync( + self._fit_async, X, y, sample_weights, base_margin, + eval_set, sample_weight_eval_set, verbose + ) + + async def _predict_async(self, data, base_margin=None): # pylint: disable=arguments-differ + test_dmatrix = await DaskDMatrix( + client=self.client, data=data, base_margin=base_margin, missing=self.missing + ) pred_probs = await predict(client=self.client, model=self.get_booster(), data=test_dmatrix) return pred_probs @@ -1091,11 +1097,13 @@ def predict(self, data): class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase): async def _fit_async(self, X, y, sample_weights=None, + base_margin=None, eval_set=None, sample_weight_eval_set=None, verbose=True): dtrain = await DaskDMatrix(client=self.client, data=X, label=y, weight=sample_weights, + base_margin=base_margin, missing=self.missing) params = self.get_xgb_params() @@ -1125,31 +1133,38 @@ async def _fit_async(self, X, y, def fit(self, X, y, sample_weights=None, + base_margin=None, eval_set=None, sample_weight_eval_set=None, verbose=True): _assert_dask_support() - return self.client.sync(self._fit_async, X, y, sample_weights, - eval_set, sample_weight_eval_set, verbose) + return self.client.sync( + self._fit_async, X, y, sample_weights, base_margin, eval_set, + sample_weight_eval_set, verbose + ) - async def _predict_proba_async(self, data): + async def _predict_proba_async(self, data, base_margin=None): _assert_dask_support() - test_dmatrix = await DaskDMatrix(client=self.client, data=data, - missing=self.missing) + test_dmatrix = await DaskDMatrix( + client=self.client, data=data,base_margin=base_margin, + missing=self.missing + ) pred_probs = await predict(client=self.client, model=self.get_booster(), data=test_dmatrix) return pred_probs - def predict_proba(self, data): # pylint: disable=arguments-differ,missing-docstring + def predict_proba(self, data, base_margin=None): # pylint: disable=arguments-differ,missing-docstring _assert_dask_support() - return self.client.sync(self._predict_proba_async, data) + return self.client.sync(self._predict_proba_async, data, base_margin) - async def _predict_async(self, data): + async def _predict_async(self, data, base_margin=None): _assert_dask_support() - test_dmatrix = await DaskDMatrix(client=self.client, data=data, - missing=self.missing) + test_dmatrix = await DaskDMatrix( + client=self.client, data=data, base_margin=base_margin, + missing=self.missing + ) pred_probs = await predict(client=self.client, model=self.get_booster(), data=test_dmatrix) @@ -1160,6 +1175,6 @@ async def _predict_async(self, data): return preds - def predict(self, data): # pylint: disable=arguments-differ + def predict(self, data, base_margin=None): # pylint: disable=arguments-differ _assert_dask_support() - return self.client.sync(self._predict_async, data) + return self.client.sync(self._predict_async, data, base_margin) diff --git a/tests/python/test_with_dask.py b/tests/python/test_with_dask.py index 15d7f75127a3..63cd8aea5c3c 100644 --- a/tests/python/test_with_dask.py +++ b/tests/python/test_with_dask.py @@ -133,6 +133,48 @@ def test_dask_predict_shape_infer(): assert preds.shape[1] == preds.compute().shape[1] +@pytest.mark.parametrize("tree_method", ["hist", "approx"]) +def test_boost_from_prediction(tree_method): + from sklearn.datasets import load_breast_cancer + + with LocalCluster(n_workers=4) as cluster: + with Client(cluster) as client: + X, y = load_breast_cancer(return_X_y=True) + X_ = dd.from_array(X, chunksize=100) + y_ = dd.from_array(y, chunksize=100) + + from sklearn.datasets import load_breast_cancer + + X, y = load_breast_cancer(return_X_y=True) + model_0 = xgb.dask.DaskXGBClassifier( + learning_rate=0.3, + random_state=0, + n_estimators=4, + tree_method=tree_method, + ) + model_0.fit(X=X_, y=y_) + margin = model_0.predict_proba(X_) + + model_1 = xgb.dask.DaskXGBClassifier( + learning_rate=0.3, + random_state=0, + n_estimators=4, + tree_method=tree_method, + ) + model_1.fit(X=X_, y=y_, base_margin=margin) + predictions_1 = model_1.predict(X_, base_margin=margin) + + cls_2 = xgb.dask.DaskXGBClassifier( + learning_rate=0.3, + random_state=0, + n_estimators=8, + tree_method=tree_method, + ) + cls_2.fit(X=X_, y=y_) + predictions_2 = cls_2.predict(X_) + np.testing.assert_equal(predictions_1.compute(), predictions_2.compute()) + + def test_dask_missing_value_reg(): with LocalCluster(n_workers=kWorkers) as cluster: with Client(cluster) as client: From b89d44eeed35548ffbac0352c3cbf88fa7bc2d91 Mon Sep 17 00:00:00 2001 From: Kyle Nicholson Date: Tue, 25 Aug 2020 16:42:05 -0400 Subject: [PATCH 2/8] format --- python-package/xgboost/dask.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python-package/xgboost/dask.py b/python-package/xgboost/dask.py index 724d0ec15bfe..0e7aa3bb3759 100644 --- a/python-package/xgboost/dask.py +++ b/python-package/xgboost/dask.py @@ -1147,7 +1147,7 @@ async def _predict_proba_async(self, data, base_margin=None): _assert_dask_support() test_dmatrix = await DaskDMatrix( - client=self.client, data=data,base_margin=base_margin, + client=self.client, data=data, base_margin=base_margin, missing=self.missing ) pred_probs = await predict(client=self.client, From 366fae142419c1fab747052bbd30b518c44d8b80 Mon Sep 17 00:00:00 2001 From: Kyle Nicholson Date: Wed, 16 Sep 2020 18:10:38 -0400 Subject: [PATCH 3/8] add _get_worker_parts_ordered --- python-package/xgboost/dask.py | 67 ++++++++++++++++++++++------------ tests/python/test_with_dask.py | 15 ++++++-- 2 files changed, 54 insertions(+), 28 deletions(-) diff --git a/python-package/xgboost/dask.py b/python-package/xgboost/dask.py index 0e7aa3bb3759..e4db16222410 100644 --- a/python-package/xgboost/dask.py +++ b/python-package/xgboost/dask.py @@ -337,14 +337,22 @@ def create_fn_args(self): 'is_quantile': self.is_quantile} -def _get_worker_x_ordered(worker_map, partition_order, worker): +def _get_worker_parts_ordered(has_base_margin, worker_map, + partition_order, worker): list_of_parts = worker_map[worker.address] client = get_client() list_of_parts_value = client.gather(list_of_parts) + result = [] + for i, part in enumerate(list_of_parts): - result.append((list_of_parts_value[i][0], - partition_order[part.key])) + data = list_of_parts_value[i][0] + if has_base_margin: + base_margin = list_of_parts_value[i][1] + else: + base_margin = None + result.append((data, base_margin, partition_order[part.key])) + return result @@ -740,9 +748,8 @@ async def _direct_predict_impl(client, data, predict_fn): # pylint: disable=too-many-statements -async def _predict_async(client: Client, model, data, missing=numpy.nan, - **kwargs): - +async def _predict_async(client: Client, model, data, output_margin, + missing=numpy.nan, **kwargs): if isinstance(model, Booster): booster = model elif isinstance(model, dict): @@ -775,21 +782,29 @@ def mapped_predict(partition, is_df): feature_names = data.feature_names feature_types = data.feature_types missing = data.missing + has_base_margin = data.has_base_margin def dispatched_predict(worker_id): '''Perform prediction on each worker.''' LOGGER.info('Predicting on %d', worker_id) + worker = distributed_get_worker() - list_of_parts = _get_worker_x_ordered(worker_map, partition_order, - worker) + list_of_parts = _get_worker_parts_ordered(has_base_margin, + worker_map, partition_order, worker + ) predictions = [] booster.set_param({'nthread': worker.nthreads}) - for part, order in list_of_parts: - local_x = DMatrix(part, feature_names=feature_names, - feature_types=feature_types, - missing=missing, nthread=worker.nthreads) - predt = booster.predict(data=local_x, - validate_features=local_x.num_row() != 0, + for data, base_margin, order in list_of_parts: + local_part = DMatrix( + data, + base_margin=base_margin, + feature_names=feature_names, + feature_types=feature_types, + missing=missing, + nthread=worker.nthreads + ) + predt = booster.predict(data=local_part, + validate_features=local_part.num_row() != 0, **kwargs) columns = 1 if len(predt.shape) == 1 else predt.shape[1] ret = ((delayed(predt), columns), order) @@ -800,9 +815,9 @@ def dispatched_get_shape(worker_id): '''Get shape of data in each worker.''' LOGGER.info('Get shape on %d', worker_id) worker = distributed_get_worker() - list_of_parts = _get_worker_x_ordered(worker_map, + list_of_parts = _get_worker_parts_ordered(False, worker_map, partition_order, worker) - shapes = [(part.shape, order) for part, order in list_of_parts] + shapes = [(part.shape, order) for part, _, order in list_of_parts] return shapes async def map_function(func): @@ -1143,7 +1158,7 @@ def fit(self, X, y, sample_weight_eval_set, verbose ) - async def _predict_proba_async(self, data, base_margin=None): + async def _predict_proba_async(self, data, output_margin=False, base_margin=None): _assert_dask_support() test_dmatrix = await DaskDMatrix( @@ -1151,14 +1166,16 @@ async def _predict_proba_async(self, data, base_margin=None): missing=self.missing ) pred_probs = await predict(client=self.client, - model=self.get_booster(), data=test_dmatrix) + model=self.get_booster(), + data=test_dmatrix, + output_margin=output_margin) return pred_probs - def predict_proba(self, data, base_margin=None): # pylint: disable=arguments-differ,missing-docstring + def predict_proba(self, data, output_margin=False, base_margin=None): # pylint: disable=arguments-differ,missing-docstring _assert_dask_support() - return self.client.sync(self._predict_proba_async, data, base_margin) + return self.client.sync(self._predict_proba_async, data, output_margin, base_margin) - async def _predict_async(self, data, base_margin=None): + async def _predict_async(self, data, output_margin=False, base_margin=None): _assert_dask_support() test_dmatrix = await DaskDMatrix( @@ -1166,7 +1183,9 @@ async def _predict_async(self, data, base_margin=None): missing=self.missing ) pred_probs = await predict(client=self.client, - model=self.get_booster(), data=test_dmatrix) + model=self.get_booster(), + data=test_dmatrix, + output_margin=output_margin) if self.n_classes_ == 2: preds = (pred_probs > 0.5).astype(int) @@ -1175,6 +1194,6 @@ async def _predict_async(self, data, base_margin=None): return preds - def predict(self, data, base_margin=None): # pylint: disable=arguments-differ + def predict(self, data, output_margin=False, base_margin=None): # pylint: disable=arguments-differ _assert_dask_support() - return self.client.sync(self._predict_async, data, base_margin) + return self.client.sync(self._predict_async, data, output_margin, base_margin) diff --git a/tests/python/test_with_dask.py b/tests/python/test_with_dask.py index 63cd8aea5c3c..d69546089f3d 100644 --- a/tests/python/test_with_dask.py +++ b/tests/python/test_with_dask.py @@ -148,32 +148,39 @@ def test_boost_from_prediction(tree_method): X, y = load_breast_cancer(return_X_y=True) model_0 = xgb.dask.DaskXGBClassifier( learning_rate=0.3, - random_state=0, + random_state=123, n_estimators=4, tree_method=tree_method, ) model_0.fit(X=X_, y=y_) - margin = model_0.predict_proba(X_) + margin = model_0.predict_proba(X_, output_margin=True) model_1 = xgb.dask.DaskXGBClassifier( learning_rate=0.3, - random_state=0, + random_state=123, n_estimators=4, tree_method=tree_method, ) model_1.fit(X=X_, y=y_, base_margin=margin) predictions_1 = model_1.predict(X_, base_margin=margin) + proba_1 = model_1.predict_proba(X_, base_margin=margin) cls_2 = xgb.dask.DaskXGBClassifier( learning_rate=0.3, - random_state=0, + random_state=123, n_estimators=8, tree_method=tree_method, ) cls_2.fit(X=X_, y=y_) predictions_2 = cls_2.predict(X_) + proba_2 = cls_2.predict_proba(X_) + np.testing.assert_equal(predictions_1.compute(), predictions_2.compute()) + # This won't pass for approx + if tree_method != "approx": + np.testing.assert_almost_equal(proba_1.compute(), proba_2.compute()) + def test_dask_missing_value_reg(): with LocalCluster(n_workers=kWorkers) as cluster: From f5fdce1efb91fac1f36a6550ab64d7a0b83ffeb5 Mon Sep 17 00:00:00 2001 From: Kyle Nicholson Date: Thu, 17 Sep 2020 10:18:02 -0400 Subject: [PATCH 4/8] if same model varies, use variance as check --- tests/python/test_with_dask.py | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/tests/python/test_with_dask.py b/tests/python/test_with_dask.py index d69546089f3d..cb1f2804ef53 100644 --- a/tests/python/test_with_dask.py +++ b/tests/python/test_with_dask.py @@ -175,10 +175,25 @@ def test_boost_from_prediction(tree_method): predictions_2 = cls_2.predict(X_) proba_2 = cls_2.predict_proba(X_) - np.testing.assert_equal(predictions_1.compute(), predictions_2.compute()) - - # This won't pass for approx - if tree_method != "approx": + cls_3 = xgb.dask.DaskXGBClassifier( + learning_rate=0.3, + random_state=123, + n_estimators=8, + tree_method=tree_method, + ) + cls_3.fit(X=X_, y=y_) + predictions_3 = cls_3.predict(X_) + proba_3 = cls_3.predict_proba(X_) + + # compute variance between two of the same model, use this to check + # to make sure approx is functioning within normal parameters + variance = np.max(np.abs(proba_3 - proba_2)).compute() + + if variance > 0: + print("variance > 0") + assert np.all(np.abs(proba_2 - proba_1) <= variance) + else: + np.testing.assert_equal(predictions_1.compute(), predictions_2.compute()) np.testing.assert_almost_equal(proba_1.compute(), proba_2.compute()) From 39b4c45bdb017dc694828be474728ea78cfd1fb8 Mon Sep 17 00:00:00 2001 From: Kyle Nicholson Date: Wed, 23 Sep 2020 12:30:33 -0400 Subject: [PATCH 5/8] fix old test with updates from master --- python-package/xgboost/dask.py | 22 ++++++++++++++++------ tests/python/test_with_dask.py | 25 ++++++++++++------------- 2 files changed, 28 insertions(+), 19 deletions(-) diff --git a/python-package/xgboost/dask.py b/python-package/xgboost/dask.py index e4db16222410..d5211bef7b63 100644 --- a/python-package/xgboost/dask.py +++ b/python-package/xgboost/dask.py @@ -748,8 +748,7 @@ async def _direct_predict_impl(client, data, predict_fn): # pylint: disable=too-many-statements -async def _predict_async(client: Client, model, data, output_margin, - missing=numpy.nan, **kwargs): +async def _predict_async(client: Client, model, data, missing=numpy.nan, **kwargs): if isinstance(model, Booster): booster = model elif isinstance(model, dict): @@ -782,14 +781,14 @@ def mapped_predict(partition, is_df): feature_names = data.feature_names feature_types = data.feature_types missing = data.missing - has_base_margin = data.has_base_margin + has_margin = "base_margin" in data.meta_names def dispatched_predict(worker_id): '''Perform prediction on each worker.''' LOGGER.info('Predicting on %d', worker_id) worker = distributed_get_worker() - list_of_parts = _get_worker_parts_ordered(has_base_margin, + list_of_parts = _get_worker_parts_ordered(has_margin, worker_map, partition_order, worker ) predictions = [] @@ -881,6 +880,7 @@ def predict(client, model, data, missing=numpy.nan, **kwargs): ''' _assert_dask_support() client = _xgb_get_client(client) + LOGGER.warning(kwargs) return client.sync(_predict_async, client, model, data, missing=missing, **kwargs) @@ -1173,7 +1173,12 @@ async def _predict_proba_async(self, data, output_margin=False, base_margin=None def predict_proba(self, data, output_margin=False, base_margin=None): # pylint: disable=arguments-differ,missing-docstring _assert_dask_support() - return self.client.sync(self._predict_proba_async, data, output_margin, base_margin) + return self.client.sync( + self._predict_proba_async, + data, + output_margin=output_margin, + base_margin=base_margin + ) async def _predict_async(self, data, output_margin=False, base_margin=None): _assert_dask_support() @@ -1196,4 +1201,9 @@ async def _predict_async(self, data, output_margin=False, base_margin=None): def predict(self, data, output_margin=False, base_margin=None): # pylint: disable=arguments-differ _assert_dask_support() - return self.client.sync(self._predict_async, data, output_margin, base_margin) + return self.client.sync( + self._predict_async, + data, + output_margin=output_margin, + base_margin=base_margin + ) diff --git a/tests/python/test_with_dask.py b/tests/python/test_with_dask.py index cb1f2804ef53..e3c4bdc8c74e 100644 --- a/tests/python/test_with_dask.py +++ b/tests/python/test_with_dask.py @@ -136,16 +136,13 @@ def test_dask_predict_shape_infer(): @pytest.mark.parametrize("tree_method", ["hist", "approx"]) def test_boost_from_prediction(tree_method): from sklearn.datasets import load_breast_cancer + X, y = load_breast_cancer(return_X_y=True) + + X_ = dd.from_array(X, chunksize=100) + y_ = dd.from_array(y, chunksize=100) with LocalCluster(n_workers=4) as cluster: with Client(cluster) as client: - X, y = load_breast_cancer(return_X_y=True) - X_ = dd.from_array(X, chunksize=100) - y_ = dd.from_array(y, chunksize=100) - - from sklearn.datasets import load_breast_cancer - - X, y = load_breast_cancer(return_X_y=True) model_0 = xgb.dask.DaskXGBClassifier( learning_rate=0.3, random_state=123, @@ -185,13 +182,15 @@ def test_boost_from_prediction(tree_method): predictions_3 = cls_3.predict(X_) proba_3 = cls_3.predict_proba(X_) - # compute variance between two of the same model, use this to check - # to make sure approx is functioning within normal parameters - variance = np.max(np.abs(proba_3 - proba_2)).compute() + # compute variance of probability percentages between two of the + # same model, use this to check to make sure approx is functioning + # within normal parameters + expected_variance = np.max(np.abs(proba_3 - proba_2)).compute() - if variance > 0: - print("variance > 0") - assert np.all(np.abs(proba_2 - proba_1) <= variance) + if expected_variance > 0: + margin_variance = np.max(np.abs(proba_1 - proba_2)).compute() + # Ensure the margin variance is less than the expected variance + 10% + assert np.all(margin_variance <= expected_variance + .1) else: np.testing.assert_equal(predictions_1.compute(), predictions_2.compute()) np.testing.assert_almost_equal(proba_1.compute(), proba_2.compute()) From 59e0e2130859e15d3c6e76a67881f7eec0955cc4 Mon Sep 17 00:00:00 2001 From: Kyle Nicholson Date: Wed, 23 Sep 2020 12:33:34 -0400 Subject: [PATCH 6/8] remove errant logger warning --- python-package/xgboost/dask.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python-package/xgboost/dask.py b/python-package/xgboost/dask.py index d5211bef7b63..1408e618fd9a 100644 --- a/python-package/xgboost/dask.py +++ b/python-package/xgboost/dask.py @@ -880,7 +880,6 @@ def predict(client, model, data, missing=numpy.nan, **kwargs): ''' _assert_dask_support() client = _xgb_get_client(client) - LOGGER.warning(kwargs) return client.sync(_predict_async, client, model, data, missing=missing, **kwargs) From e87678edfdbdcfd8693b5cc209d019299a307730 Mon Sep 17 00:00:00 2001 From: Kyle Nicholson Date: Thu, 24 Sep 2020 09:57:40 -0400 Subject: [PATCH 7/8] formatting --- python-package/xgboost/dask.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/python-package/xgboost/dask.py b/python-package/xgboost/dask.py index 1408e618fd9a..49159e903313 100644 --- a/python-package/xgboost/dask.py +++ b/python-package/xgboost/dask.py @@ -814,8 +814,12 @@ def dispatched_get_shape(worker_id): '''Get shape of data in each worker.''' LOGGER.info('Get shape on %d', worker_id) worker = distributed_get_worker() - list_of_parts = _get_worker_parts_ordered(False, worker_map, - partition_order, worker) + list_of_parts = _get_worker_parts_ordered( + False, + worker_map, + partition_order, + worker + ) shapes = [(part.shape, order) for part, _, order in list_of_parts] return shapes From 972cd178bea70625734f59fdf0bad364c6462eaa Mon Sep 17 00:00:00 2001 From: fis Date: Fri, 25 Sep 2020 18:54:45 +0800 Subject: [PATCH 8/8] Add `output_margin` to regressor. * Address format. * Remove unused var. --- python-package/xgboost/dask.py | 39 ++++++++++++++++++---------------- tests/python/test_with_dask.py | 1 - 2 files changed, 21 insertions(+), 19 deletions(-) diff --git a/python-package/xgboost/dask.py b/python-package/xgboost/dask.py index 49159e903313..4b500511d96a 100644 --- a/python-package/xgboost/dask.py +++ b/python-package/xgboost/dask.py @@ -337,8 +337,8 @@ def create_fn_args(self): 'is_quantile': self.is_quantile} -def _get_worker_parts_ordered(has_base_margin, worker_map, - partition_order, worker): +def _get_worker_parts_ordered(has_base_margin, worker_map, partition_order, + worker): list_of_parts = worker_map[worker.address] client = get_client() list_of_parts_value = client.gather(list_of_parts) @@ -788,9 +788,8 @@ def dispatched_predict(worker_id): LOGGER.info('Predicting on %d', worker_id) worker = distributed_get_worker() - list_of_parts = _get_worker_parts_ordered(has_margin, - worker_map, partition_order, worker - ) + list_of_parts = _get_worker_parts_ordered( + has_margin, worker_map, partition_order, worker) predictions = [] booster.set_param({'nthread': worker.nthreads}) for data, base_margin, order in list_of_parts: @@ -802,9 +801,10 @@ def dispatched_predict(worker_id): missing=missing, nthread=worker.nthreads ) - predt = booster.predict(data=local_part, - validate_features=local_part.num_row() != 0, - **kwargs) + predt = booster.predict( + data=local_part, + validate_features=local_part.num_row() != 0, + **kwargs) columns = 1 if len(predt.shape) == 1 else predt.shape[1] ret = ((delayed(predt), columns), order) predictions.append(ret) @@ -1095,17 +1095,23 @@ def fit(self, X, y, eval_set, sample_weight_eval_set, verbose ) - async def _predict_async(self, data, base_margin=None): # pylint: disable=arguments-differ + async def _predict_async( + self, data, output_margin=False, base_margin=None): test_dmatrix = await DaskDMatrix( - client=self.client, data=data, base_margin=base_margin, missing=self.missing + client=self.client, data=data, base_margin=base_margin, + missing=self.missing ) pred_probs = await predict(client=self.client, - model=self.get_booster(), data=test_dmatrix) + model=self.get_booster(), data=test_dmatrix, + output_margin=output_margin) return pred_probs - def predict(self, data): + # pylint: disable=arguments-differ + def predict(self, data, output_margin=False, base_margin=None): _assert_dask_support() - return self.client.sync(self._predict_async, data) + return self.client.sync(self._predict_async, data, + output_margin=output_margin, + base_margin=base_margin) @xgboost_model_doc( @@ -1161,9 +1167,8 @@ def fit(self, X, y, sample_weight_eval_set, verbose ) - async def _predict_proba_async(self, data, output_margin=False, base_margin=None): - _assert_dask_support() - + async def _predict_proba_async(self, data, output_margin=False, + base_margin=None): test_dmatrix = await DaskDMatrix( client=self.client, data=data, base_margin=base_margin, missing=self.missing @@ -1184,8 +1189,6 @@ def predict_proba(self, data, output_margin=False, base_margin=None): # pylint: ) async def _predict_async(self, data, output_margin=False, base_margin=None): - _assert_dask_support() - test_dmatrix = await DaskDMatrix( client=self.client, data=data, base_margin=base_margin, missing=self.missing diff --git a/tests/python/test_with_dask.py b/tests/python/test_with_dask.py index e3c4bdc8c74e..889538285dc4 100644 --- a/tests/python/test_with_dask.py +++ b/tests/python/test_with_dask.py @@ -179,7 +179,6 @@ def test_boost_from_prediction(tree_method): tree_method=tree_method, ) cls_3.fit(X=X_, y=y_) - predictions_3 = cls_3.predict(X_) proba_3 = cls_3.predict_proba(X_) # compute variance of probability percentages between two of the