Skip to content

Commit

Permalink
Update base margin dask (#6155)
Browse files Browse the repository at this point in the history
* Add `base-margin`
* Add `output_margin` to regressor.

Co-authored-by: fis <jm.yuan@outlook.com>
  • Loading branch information
kylejn27 and trivialfis committed Sep 26, 2020
1 parent 03b8fde commit e6a238c
Show file tree
Hide file tree
Showing 2 changed files with 163 additions and 51 deletions.
152 changes: 101 additions & 51 deletions python-package/xgboost/dask.py
Expand Up @@ -337,14 +337,22 @@ def create_fn_args(self):
'is_quantile': self.is_quantile}


def _get_worker_parts_ordered(has_base_margin, worker_map, partition_order,
                              worker):
    '''Fetch this worker's data partitions, preserving input order.

    Parameters
    ----------
    has_base_margin : bool
        When True, each gathered part carries a base-margin array as its
        second field; otherwise the margin slot is filled with None.
    worker_map : dict
        Maps a worker address to the list of part futures living there.
    partition_order : dict
        Maps a part future's key to its position in the original input.
    worker :
        The distributed worker whose parts are gathered.

    Returns
    -------
    list
        ``(data, base_margin, order)`` tuples, one per local part.
    '''
    list_of_parts = worker_map[worker.address]
    client = get_client()
    # Materialize the futures into concrete values on this worker.
    list_of_parts_value = client.gather(list_of_parts)

    result = []
    for i, part in enumerate(list_of_parts):
        data = list_of_parts_value[i][0]
        # Base margin, when present, is stored right after the data field.
        if has_base_margin:
            base_margin = list_of_parts_value[i][1]
        else:
            base_margin = None
        result.append((data, base_margin, partition_order[part.key]))

    return result


Expand Down Expand Up @@ -740,9 +748,7 @@ async def _direct_predict_impl(client, data, predict_fn):


# pylint: disable=too-many-statements
async def _predict_async(client: Client, model, data, missing=numpy.nan,
**kwargs):

async def _predict_async(client: Client, model, data, missing=numpy.nan, **kwargs):
if isinstance(model, Booster):
booster = model
elif isinstance(model, dict):
Expand Down Expand Up @@ -775,22 +781,30 @@ def mapped_predict(partition, is_df):
feature_names = data.feature_names
feature_types = data.feature_types
missing = data.missing
has_margin = "base_margin" in data.meta_names

def dispatched_predict(worker_id):
'''Perform prediction on each worker.'''
LOGGER.info('Predicting on %d', worker_id)

worker = distributed_get_worker()
list_of_parts = _get_worker_x_ordered(worker_map, partition_order,
worker)
list_of_parts = _get_worker_parts_ordered(
has_margin, worker_map, partition_order, worker)
predictions = []
booster.set_param({'nthread': worker.nthreads})
for part, order in list_of_parts:
local_x = DMatrix(part, feature_names=feature_names,
feature_types=feature_types,
missing=missing, nthread=worker.nthreads)
predt = booster.predict(data=local_x,
validate_features=local_x.num_row() != 0,
**kwargs)
for data, base_margin, order in list_of_parts:
local_part = DMatrix(
data,
base_margin=base_margin,
feature_names=feature_names,
feature_types=feature_types,
missing=missing,
nthread=worker.nthreads
)
predt = booster.predict(
data=local_part,
validate_features=local_part.num_row() != 0,
**kwargs)
columns = 1 if len(predt.shape) == 1 else predt.shape[1]
ret = ((delayed(predt), columns), order)
predictions.append(ret)
Expand All @@ -800,9 +814,13 @@ def dispatched_get_shape(worker_id):
'''Get shape of data in each worker.'''
LOGGER.info('Get shape on %d', worker_id)
worker = distributed_get_worker()
list_of_parts = _get_worker_x_ordered(worker_map,
partition_order, worker)
shapes = [(part.shape, order) for part, order in list_of_parts]
list_of_parts = _get_worker_parts_ordered(
False,
worker_map,
partition_order,
worker
)
shapes = [(part.shape, order) for part, _, order in list_of_parts]
return shapes

async def map_function(func):
Expand Down Expand Up @@ -984,6 +1002,7 @@ class DaskScikitLearnBase(XGBModel):
# pylint: disable=arguments-differ
def fit(self, X, y,
sample_weights=None,
base_margin=None,
eval_set=None,
sample_weight_eval_set=None,
verbose=True):
Expand Down Expand Up @@ -1044,12 +1063,14 @@ async def _fit_async(self,
X,
y,
sample_weights=None,
base_margin=None,
eval_set=None,
sample_weight_eval_set=None,
verbose=True):
dtrain = await DaskDMatrix(client=self.client,
data=X, label=y, weight=sample_weights,
missing=self.missing)
dtrain = await DaskDMatrix(
client=self.client, data=X, label=y, weight=sample_weights,
base_margin=base_margin, missing=self.missing
)
params = self.get_xgb_params()
evals = await _evaluation_matrices(self.client,
eval_set, sample_weight_eval_set,
Expand All @@ -1065,24 +1086,33 @@ async def _fit_async(self,
# pylint: disable=missing-docstring
def fit(self, X, y,
        sample_weights=None,
        base_margin=None,
        eval_set=None,
        sample_weight_eval_set=None,
        verbose=True):
    # Blocking sklearn-style entry point; the real work happens in the
    # async implementation, bridged onto the dask event loop via ``sync``.
    # ``base_margin`` is forwarded so training can boost from an existing
    # model's margin output.
    _assert_dask_support()
    return self.client.sync(
        self._fit_async, X, y, sample_weights, base_margin,
        eval_set, sample_weight_eval_set, verbose
    )

async def _predict_async(
        self, data, output_margin=False, base_margin=None):
    '''Async implementation of ``predict``: wrap ``data`` (and the
    optional ``base_margin``) in a DaskDMatrix, then run distributed
    prediction with the trained booster.  ``output_margin`` requests raw
    untransformed margin values instead of transformed predictions.'''
    test_dmatrix = await DaskDMatrix(
        client=self.client, data=data, base_margin=base_margin,
        missing=self.missing
    )
    pred_probs = await predict(client=self.client,
                               model=self.get_booster(),
                               data=test_dmatrix,
                               output_margin=output_margin)
    return pred_probs

# pylint: disable=arguments-differ
def predict(self, data, output_margin=False, base_margin=None):
    '''Blocking wrapper over ``_predict_async``; ``base_margin`` lets the
    caller supply a starting margin (boost from prediction).'''
    _assert_dask_support()
    return self.client.sync(self._predict_async, data,
                            output_margin=output_margin,
                            base_margin=base_margin)


@xgboost_model_doc(
Expand All @@ -1092,11 +1122,13 @@ def predict(self, data):
class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
async def _fit_async(self, X, y,
sample_weights=None,
base_margin=None,
eval_set=None,
sample_weight_eval_set=None,
verbose=True):
dtrain = await DaskDMatrix(client=self.client,
data=X, label=y, weight=sample_weights,
base_margin=base_margin,
missing=self.missing)
params = self.get_xgb_params()

Expand Down Expand Up @@ -1126,33 +1158,46 @@ async def _fit_async(self, X, y,

def fit(self, X, y,
        sample_weights=None,
        base_margin=None,
        eval_set=None,
        sample_weight_eval_set=None,
        verbose=True):
    # Classifier fit: synchronous facade over the async trainer, with
    # ``base_margin`` forwarded so boosting can continue from a prior
    # model's margin output.
    _assert_dask_support()
    return self.client.sync(
        self._fit_async, X, y, sample_weights, base_margin, eval_set,
        sample_weight_eval_set, verbose
    )

async def _predict_proba_async(self, data, output_margin=False,
                               base_margin=None):
    '''Async implementation of ``predict_proba``: builds a DaskDMatrix
    (optionally carrying ``base_margin``) and runs distributed
    prediction; ``output_margin=True`` returns raw margin scores.'''
    test_dmatrix = await DaskDMatrix(
        client=self.client, data=data, base_margin=base_margin,
        missing=self.missing
    )
    pred_probs = await predict(client=self.client,
                               model=self.get_booster(),
                               data=test_dmatrix,
                               output_margin=output_margin)
    return pred_probs

def predict_proba(self, data, output_margin=False, base_margin=None):  # pylint: disable=arguments-differ,missing-docstring
    # Blocking wrapper over ``_predict_proba_async``; forwards the
    # optional starting margin and the raw-margin output flag.
    _assert_dask_support()
    return self.client.sync(
        self._predict_proba_async,
        data,
        output_margin=output_margin,
        base_margin=base_margin
    )

async def _predict_async(self, data, output_margin=False, base_margin=None):
test_dmatrix = await DaskDMatrix(
client=self.client, data=data, base_margin=base_margin,
missing=self.missing
)
pred_probs = await predict(client=self.client,
model=self.get_booster(), data=test_dmatrix)
model=self.get_booster(),
data=test_dmatrix,
output_margin=output_margin)

if self.n_classes_ == 2:
preds = (pred_probs > 0.5).astype(int)
Expand All @@ -1161,6 +1206,11 @@ async def _predict_async(self, data):

return preds

def predict(self, data, output_margin=False, base_margin=None):  # pylint: disable=arguments-differ
    '''Blocking wrapper over the classifier's ``_predict_async``;
    ``base_margin`` supplies a starting margin and ``output_margin``
    requests raw scores instead of class labels.'''
    _assert_dask_support()
    return self.client.sync(
        self._predict_async,
        data,
        output_margin=output_margin,
        base_margin=base_margin
    )
62 changes: 62 additions & 0 deletions tests/python/test_with_dask.py
Expand Up @@ -133,6 +133,68 @@ def test_dask_predict_shape_infer():
assert preds.shape[1] == preds.compute().shape[1]


@pytest.mark.parametrize("tree_method", ["hist", "approx"])
def test_boost_from_prediction(tree_method):
    # Train 4 rounds, feed that model's margin into a second 4-round model,
    # and check the combined result tracks a single 8-round model.
    from sklearn.datasets import load_breast_cancer
    X, y = load_breast_cancer(return_X_y=True)

    X_ = dd.from_array(X, chunksize=100)
    y_ = dd.from_array(y, chunksize=100)

    with LocalCluster(n_workers=4) as cluster:
        with Client(cluster) as client:
            def make_classifier(n_estimators):
                # All models share hyper-parameters except the round count.
                return xgb.dask.DaskXGBClassifier(
                    learning_rate=0.3,
                    random_state=123,
                    n_estimators=n_estimators,
                    tree_method=tree_method,
                )

            # First half of the boosting rounds.
            booster_a = make_classifier(4)
            booster_a.fit(X=X_, y=y_)
            margin = booster_a.predict_proba(X_, output_margin=True)

            # Second half, boosted from the first model's margin output.
            booster_b = make_classifier(4)
            booster_b.fit(X=X_, y=y_, base_margin=margin)
            predictions_1 = booster_b.predict(X_, base_margin=margin)
            proba_1 = booster_b.predict_proba(X_, base_margin=margin)

            # Reference: one model trained for all eight rounds at once.
            reference = make_classifier(8)
            reference.fit(X=X_, y=y_)
            predictions_2 = reference.predict(X_)
            proba_2 = reference.predict_proba(X_)

            # A repeat 8-round model estimates the run-to-run variance of
            # the (possibly approximate) tree method.
            reference_repeat = make_classifier(8)
            reference_repeat.fit(X=X_, y=y_)
            proba_3 = reference_repeat.predict_proba(X_)

            # compute variance of probability percentages between two of the
            # same model, use this to check to make sure approx is functioning
            # within normal parameters
            expected_variance = np.max(np.abs(proba_3 - proba_2)).compute()

            if expected_variance > 0:
                margin_variance = np.max(np.abs(proba_1 - proba_2)).compute()
                # Ensure the margin variance is less than the expected variance + 10%
                assert np.all(margin_variance <= expected_variance + .1)
            else:
                np.testing.assert_equal(predictions_1.compute(),
                                        predictions_2.compute())
                np.testing.assert_almost_equal(proba_1.compute(),
                                               proba_2.compute())


def test_dask_missing_value_reg():
with LocalCluster(n_workers=kWorkers) as cluster:
with Client(cluster) as client:
Expand Down

0 comments on commit e6a238c

Please sign in to comment.