diff --git a/python-package/xgboost/dask.py b/python-package/xgboost/dask.py index d6f50bd0f80f..4b500511d96a 100644 --- a/python-package/xgboost/dask.py +++ b/python-package/xgboost/dask.py @@ -337,14 +337,22 @@ def create_fn_args(self): 'is_quantile': self.is_quantile} -def _get_worker_x_ordered(worker_map, partition_order, worker): +def _get_worker_parts_ordered(has_base_margin, worker_map, partition_order, + worker): list_of_parts = worker_map[worker.address] client = get_client() list_of_parts_value = client.gather(list_of_parts) + result = [] + for i, part in enumerate(list_of_parts): - result.append((list_of_parts_value[i][0], - partition_order[part.key])) + data = list_of_parts_value[i][0] + if has_base_margin: + base_margin = list_of_parts_value[i][1] + else: + base_margin = None + result.append((data, base_margin, partition_order[part.key])) + return result @@ -740,9 +748,7 @@ async def _direct_predict_impl(client, data, predict_fn): # pylint: disable=too-many-statements -async def _predict_async(client: Client, model, data, missing=numpy.nan, - **kwargs): - +async def _predict_async(client: Client, model, data, missing=numpy.nan, **kwargs): if isinstance(model, Booster): booster = model elif isinstance(model, dict): @@ -775,22 +781,30 @@ def mapped_predict(partition, is_df): feature_names = data.feature_names feature_types = data.feature_types missing = data.missing + has_margin = "base_margin" in data.meta_names def dispatched_predict(worker_id): '''Perform prediction on each worker.''' LOGGER.info('Predicting on %d', worker_id) + worker = distributed_get_worker() - list_of_parts = _get_worker_x_ordered(worker_map, partition_order, - worker) + list_of_parts = _get_worker_parts_ordered( + has_margin, worker_map, partition_order, worker) predictions = [] booster.set_param({'nthread': worker.nthreads}) - for part, order in list_of_parts: - local_x = DMatrix(part, feature_names=feature_names, - feature_types=feature_types, - missing=missing, 
nthread=worker.nthreads) - predt = booster.predict(data=local_x, - validate_features=local_x.num_row() != 0, - **kwargs) + for data, base_margin, order in list_of_parts: + local_part = DMatrix( + data, + base_margin=base_margin, + feature_names=feature_names, + feature_types=feature_types, + missing=missing, + nthread=worker.nthreads + ) + predt = booster.predict( + data=local_part, + validate_features=local_part.num_row() != 0, + **kwargs) columns = 1 if len(predt.shape) == 1 else predt.shape[1] ret = ((delayed(predt), columns), order) predictions.append(ret) @@ -800,9 +814,13 @@ def dispatched_get_shape(worker_id): '''Get shape of data in each worker.''' LOGGER.info('Get shape on %d', worker_id) worker = distributed_get_worker() - list_of_parts = _get_worker_x_ordered(worker_map, - partition_order, worker) - shapes = [(part.shape, order) for part, order in list_of_parts] + list_of_parts = _get_worker_parts_ordered( + False, + worker_map, + partition_order, + worker + ) + shapes = [(part.shape, order) for part, _, order in list_of_parts] return shapes async def map_function(func): @@ -983,6 +1001,7 @@ class DaskScikitLearnBase(XGBModel): # pylint: disable=arguments-differ def fit(self, X, y, sample_weights=None, + base_margin=None, eval_set=None, sample_weight_eval_set=None, verbose=True): @@ -1043,12 +1062,14 @@ async def _fit_async(self, X, y, sample_weights=None, + base_margin=None, eval_set=None, sample_weight_eval_set=None, verbose=True): - dtrain = await DaskDMatrix(client=self.client, - data=X, label=y, weight=sample_weights, - missing=self.missing) + dtrain = await DaskDMatrix( + client=self.client, data=X, label=y, weight=sample_weights, + base_margin=base_margin, missing=self.missing + ) params = self.get_xgb_params() evals = await _evaluation_matrices(self.client, eval_set, sample_weight_eval_set, @@ -1064,24 +1085,33 @@ async def _fit_async(self, # pylint: disable=missing-docstring def fit(self, X, y, sample_weights=None, + base_margin=None, 
eval_set=None, sample_weight_eval_set=None, verbose=True): _assert_dask_support() - return self.client.sync(self._fit_async, X, y, sample_weights, - eval_set, sample_weight_eval_set, - verbose) - - async def _predict_async(self, data): # pylint: disable=arguments-differ - test_dmatrix = await DaskDMatrix(client=self.client, data=data, - missing=self.missing) + return self.client.sync( + self._fit_async, X, y, sample_weights, base_margin, + eval_set, sample_weight_eval_set, verbose + ) + + async def _predict_async( + self, data, output_margin=False, base_margin=None): + test_dmatrix = await DaskDMatrix( + client=self.client, data=data, base_margin=base_margin, + missing=self.missing + ) pred_probs = await predict(client=self.client, - model=self.get_booster(), data=test_dmatrix) + model=self.get_booster(), data=test_dmatrix, + output_margin=output_margin) return pred_probs - def predict(self, data): + # pylint: disable=arguments-differ + def predict(self, data, output_margin=False, base_margin=None): _assert_dask_support() - return self.client.sync(self._predict_async, data) + return self.client.sync(self._predict_async, data, + output_margin=output_margin, + base_margin=base_margin) @xgboost_model_doc( @@ -1091,11 +1121,13 @@ def predict(self, data): class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase): async def _fit_async(self, X, y, sample_weights=None, + base_margin=None, eval_set=None, sample_weight_eval_set=None, verbose=True): dtrain = await DaskDMatrix(client=self.client, data=X, label=y, weight=sample_weights, + base_margin=base_margin, missing=self.missing) params = self.get_xgb_params() @@ -1125,33 +1157,46 @@ async def _fit_async(self, X, y, def fit(self, X, y, sample_weights=None, + base_margin=None, eval_set=None, sample_weight_eval_set=None, verbose=True): _assert_dask_support() - return self.client.sync(self._fit_async, X, y, sample_weights, - eval_set, sample_weight_eval_set, verbose) - - async def _predict_proba_async(self, data): - 
_assert_dask_support() - - test_dmatrix = await DaskDMatrix(client=self.client, data=data, - missing=self.missing) + return self.client.sync( + self._fit_async, X, y, sample_weights, base_margin, eval_set, + sample_weight_eval_set, verbose + ) + + async def _predict_proba_async(self, data, output_margin=False, + base_margin=None): + test_dmatrix = await DaskDMatrix( + client=self.client, data=data, base_margin=base_margin, + missing=self.missing + ) pred_probs = await predict(client=self.client, - model=self.get_booster(), data=test_dmatrix) + model=self.get_booster(), + data=test_dmatrix, + output_margin=output_margin) return pred_probs - def predict_proba(self, data): # pylint: disable=arguments-differ,missing-docstring - _assert_dask_support() - return self.client.sync(self._predict_proba_async, data) - - async def _predict_async(self, data): + def predict_proba(self, data, output_margin=False, base_margin=None): # pylint: disable=arguments-differ,missing-docstring _assert_dask_support() - - test_dmatrix = await DaskDMatrix(client=self.client, data=data, - missing=self.missing) + return self.client.sync( + self._predict_proba_async, + data, + output_margin=output_margin, + base_margin=base_margin + ) + + async def _predict_async(self, data, output_margin=False, base_margin=None): + test_dmatrix = await DaskDMatrix( + client=self.client, data=data, base_margin=base_margin, + missing=self.missing + ) pred_probs = await predict(client=self.client, - model=self.get_booster(), data=test_dmatrix) + model=self.get_booster(), + data=test_dmatrix, + output_margin=output_margin) if self.n_classes_ == 2: preds = (pred_probs > 0.5).astype(int) @@ -1160,6 +1205,11 @@ async def _predict_async(self, data): return preds - def predict(self, data): # pylint: disable=arguments-differ + def predict(self, data, output_margin=False, base_margin=None): # pylint: disable=arguments-differ _assert_dask_support() - return self.client.sync(self._predict_async, data) + return 
self.client.sync( + self._predict_async, + data, + output_margin=output_margin, + base_margin=base_margin + ) diff --git a/tests/python/test_with_dask.py b/tests/python/test_with_dask.py index 15d7f75127a3..889538285dc4 100644 --- a/tests/python/test_with_dask.py +++ b/tests/python/test_with_dask.py @@ -133,6 +133,68 @@ def test_dask_predict_shape_infer(): assert preds.shape[1] == preds.compute().shape[1] +@pytest.mark.parametrize("tree_method", ["hist", "approx"]) +def test_boost_from_prediction(tree_method): + from sklearn.datasets import load_breast_cancer + X, y = load_breast_cancer(return_X_y=True) + + X_ = dd.from_array(X, chunksize=100) + y_ = dd.from_array(y, chunksize=100) + + with LocalCluster(n_workers=4) as cluster: + with Client(cluster) as client: + model_0 = xgb.dask.DaskXGBClassifier( + learning_rate=0.3, + random_state=123, + n_estimators=4, + tree_method=tree_method, + ) + model_0.fit(X=X_, y=y_) + margin = model_0.predict_proba(X_, output_margin=True) + + model_1 = xgb.dask.DaskXGBClassifier( + learning_rate=0.3, + random_state=123, + n_estimators=4, + tree_method=tree_method, + ) + model_1.fit(X=X_, y=y_, base_margin=margin) + predictions_1 = model_1.predict(X_, base_margin=margin) + proba_1 = model_1.predict_proba(X_, base_margin=margin) + + cls_2 = xgb.dask.DaskXGBClassifier( + learning_rate=0.3, + random_state=123, + n_estimators=8, + tree_method=tree_method, + ) + cls_2.fit(X=X_, y=y_) + predictions_2 = cls_2.predict(X_) + proba_2 = cls_2.predict_proba(X_) + + cls_3 = xgb.dask.DaskXGBClassifier( + learning_rate=0.3, + random_state=123, + n_estimators=8, + tree_method=tree_method, + ) + cls_3.fit(X=X_, y=y_) + proba_3 = cls_3.predict_proba(X_) + + # compute variance of probability percentages between two of the + # same model, use this to check to make sure approx is functioning + # within normal parameters + expected_variance = np.max(np.abs(proba_3 - proba_2)).compute() + + if expected_variance > 0: + margin_variance = np.max(np.abs(proba_1 
- proba_2)).compute() +            # Ensure the margin variance is less than the expected variance plus a 0.1 absolute tolerance +            assert np.all(margin_variance <= expected_variance + .1) +        else: +            np.testing.assert_equal(predictions_1.compute(), predictions_2.compute()) +            np.testing.assert_almost_equal(proba_1.compute(), proba_2.compute()) + + def test_dask_missing_value_reg(): with LocalCluster(n_workers=kWorkers) as cluster: with Client(cluster) as client: