Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update base margin dask #6155

Merged
merged 8 commits into from Sep 26, 2020
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
129 changes: 88 additions & 41 deletions python-package/xgboost/dask.py
Expand Up @@ -337,14 +337,22 @@ def create_fn_args(self):
'is_quantile': self.is_quantile}


def _get_worker_parts_ordered(has_base_margin, worker_map,
                              partition_order, worker):
    '''Gather the data partitions assigned to *worker*, preserving the
    original partition order.

    Parameters
    ----------
    has_base_margin : bool
        True when each gathered part carries a base margin as its second
        element; otherwise the margin slot is filled with ``None``.
    worker_map : dict
        Maps a worker address to the list of (delayed) parts assigned to it.
    partition_order : dict
        Maps a part's key to its index in the original collection.
    worker : distributed.Worker
        The worker whose parts are fetched.

    Returns
    -------
    list of ``(data, base_margin, order)`` tuples.
    '''
    list_of_parts = worker_map[worker.address]
    client = get_client()
    list_of_parts_value = client.gather(list_of_parts)

    result = []
    # Pair each future with its gathered value so the partition index can
    # be recovered from the future's key.
    for part, value in zip(list_of_parts, list_of_parts_value):
        data = value[0]
        base_margin = value[1] if has_base_margin else None
        result.append((data, base_margin, partition_order[part.key]))

    return result


Expand Down Expand Up @@ -740,9 +748,7 @@ async def _direct_predict_impl(client, data, predict_fn):


# pylint: disable=too-many-statements
async def _predict_async(client: Client, model, data, missing=numpy.nan,
**kwargs):

async def _predict_async(client: Client, model, data, missing=numpy.nan, **kwargs):
if isinstance(model, Booster):
booster = model
elif isinstance(model, dict):
Expand Down Expand Up @@ -775,21 +781,29 @@ def mapped_predict(partition, is_df):
feature_names = data.feature_names
feature_types = data.feature_types
missing = data.missing
has_margin = "base_margin" in data.meta_names

def dispatched_predict(worker_id):
'''Perform prediction on each worker.'''
LOGGER.info('Predicting on %d', worker_id)

worker = distributed_get_worker()
list_of_parts = _get_worker_x_ordered(worker_map, partition_order,
worker)
list_of_parts = _get_worker_parts_ordered(has_margin,
worker_map, partition_order, worker
)
predictions = []
booster.set_param({'nthread': worker.nthreads})
for part, order in list_of_parts:
local_x = DMatrix(part, feature_names=feature_names,
feature_types=feature_types,
missing=missing, nthread=worker.nthreads)
predt = booster.predict(data=local_x,
validate_features=local_x.num_row() != 0,
for data, base_margin, order in list_of_parts:
local_part = DMatrix(
data,
base_margin=base_margin,
feature_names=feature_names,
feature_types=feature_types,
missing=missing,
nthread=worker.nthreads
)
predt = booster.predict(data=local_part,
validate_features=local_part.num_row() != 0,
**kwargs)
columns = 1 if len(predt.shape) == 1 else predt.shape[1]
ret = ((delayed(predt), columns), order)
def dispatched_get_shape(worker_id):
    '''Get shape of data partitions held by each worker.'''
    LOGGER.info('Get shape on %d', worker_id)
    worker = distributed_get_worker()
    # Base margin is irrelevant for shape inference, so pass False to
    # skip fetching it.
    list_of_parts = _get_worker_parts_ordered(
        False,
        worker_map,
        partition_order,
        worker
    )
    shapes = [(part.shape, order) for part, _, order in list_of_parts]
    return shapes

async def map_function(func):
Expand Down Expand Up @@ -983,6 +1001,7 @@ class DaskScikitLearnBase(XGBModel):
# pylint: disable=arguments-differ
def fit(self, X, y,
sample_weights=None,
base_margin=None,
eval_set=None,
sample_weight_eval_set=None,
verbose=True):
Expand Down Expand Up @@ -1043,12 +1062,14 @@ async def _fit_async(self,
X,
y,
sample_weights=None,
base_margin=None,
eval_set=None,
sample_weight_eval_set=None,
verbose=True):
dtrain = await DaskDMatrix(client=self.client,
data=X, label=y, weight=sample_weights,
missing=self.missing)
dtrain = await DaskDMatrix(
client=self.client, data=X, label=y, weight=sample_weights,
base_margin=base_margin, missing=self.missing
)
params = self.get_xgb_params()
evals = await _evaluation_matrices(self.client,
eval_set, sample_weight_eval_set,
Expand All @@ -1064,17 +1085,20 @@ async def _fit_async(self,
# pylint: disable=missing-docstring
def fit(self, X, y,
        sample_weights=None,
        base_margin=None,
        eval_set=None,
        sample_weight_eval_set=None,
        verbose=True):
    _assert_dask_support()
    # Forward everything, including the optional base margin, to the
    # asynchronous implementation running on the client's event loop.
    return self.client.sync(
        self._fit_async, X, y, sample_weights, base_margin,
        eval_set, sample_weight_eval_set, verbose
    )

async def _predict_async(self, data, base_margin=None):  # pylint: disable=arguments-differ
    '''Asynchronous prediction; *base_margin* is forwarded to the
    DaskDMatrix so boosting can start from an existing prediction.'''
    test_dmatrix = await DaskDMatrix(
        client=self.client, data=data, base_margin=base_margin,
        missing=self.missing
    )
    pred_probs = await predict(client=self.client,
                               model=self.get_booster(), data=test_dmatrix)
    return pred_probs
Expand All @@ -1091,11 +1115,13 @@ def predict(self, data):
class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
async def _fit_async(self, X, y,
sample_weights=None,
base_margin=None,
eval_set=None,
sample_weight_eval_set=None,
verbose=True):
dtrain = await DaskDMatrix(client=self.client,
data=X, label=y, weight=sample_weights,
base_margin=base_margin,
missing=self.missing)
params = self.get_xgb_params()

Expand Down Expand Up @@ -1125,33 +1151,49 @@ async def _fit_async(self, X, y,

def fit(self, X, y,
        sample_weights=None,
        base_margin=None,
        eval_set=None,
        sample_weight_eval_set=None,
        verbose=True):
    '''Fit the classifier on dask collections, optionally boosting from
    *base_margin*.'''
    _assert_dask_support()
    return self.client.sync(
        self._fit_async, X, y, sample_weights, base_margin, eval_set,
        sample_weight_eval_set, verbose
    )

async def _predict_proba_async(self, data, output_margin=False, base_margin=None):
    '''Asynchronous ``predict_proba``; supports returning the raw margin
    (*output_margin*) and boosting from a prior margin (*base_margin*).'''
    _assert_dask_support()

    test_dmatrix = await DaskDMatrix(
        client=self.client, data=data, base_margin=base_margin,
        missing=self.missing
    )
    pred_probs = await predict(client=self.client,
                               model=self.get_booster(),
                               data=test_dmatrix,
                               output_margin=output_margin)
    return pred_probs

def predict_proba(self, data, output_margin=False, base_margin=None):  # pylint: disable=arguments-differ,missing-docstring
    _assert_dask_support()
    # Delegate to the asynchronous implementation on the client's loop.
    return self.client.sync(
        self._predict_proba_async,
        data,
        output_margin=output_margin,
        base_margin=base_margin
    )

async def _predict_async(self, data):
async def _predict_async(self, data, output_margin=False, base_margin=None):
_assert_dask_support()

test_dmatrix = await DaskDMatrix(client=self.client, data=data,
missing=self.missing)
test_dmatrix = await DaskDMatrix(
client=self.client, data=data, base_margin=base_margin,
missing=self.missing
)
pred_probs = await predict(client=self.client,
model=self.get_booster(), data=test_dmatrix)
model=self.get_booster(),
data=test_dmatrix,
output_margin=output_margin)

if self.n_classes_ == 2:
preds = (pred_probs > 0.5).astype(int)
Expand All @@ -1160,6 +1202,11 @@ async def _predict_async(self, data):

return preds

def predict(self, data, output_margin=False, base_margin=None):  # pylint: disable=arguments-differ
    _assert_dask_support()
    # Delegate to the asynchronous implementation on the client's loop.
    return self.client.sync(
        self._predict_async,
        data,
        output_margin=output_margin,
        base_margin=base_margin
    )
63 changes: 63 additions & 0 deletions tests/python/test_with_dask.py
Expand Up @@ -133,6 +133,69 @@ def test_dask_predict_shape_infer():
assert preds.shape[1] == preds.compute().shape[1]


@pytest.mark.parametrize("tree_method", ["hist", "approx"])
def test_boost_from_prediction(tree_method):
    '''Boosting from a supplied base margin (4 + 4 rounds) should closely
    match a model trained for 8 rounds from scratch.'''
    from sklearn.datasets import load_breast_cancer
    X, y = load_breast_cancer(return_X_y=True)

    X_ = dd.from_array(X, chunksize=100)
    y_ = dd.from_array(y, chunksize=100)

    with LocalCluster(n_workers=4) as cluster:
        with Client(cluster) as client:
            # First 4 rounds: produces the margin used for warm starting.
            model_0 = xgb.dask.DaskXGBClassifier(
                learning_rate=0.3,
                random_state=123,
                n_estimators=4,
                tree_method=tree_method,
            )
            model_0.fit(X=X_, y=y_)
            margin = model_0.predict_proba(X_, output_margin=True)

            # Next 4 rounds, boosted from the margin above.
            model_1 = xgb.dask.DaskXGBClassifier(
                learning_rate=0.3,
                random_state=123,
                n_estimators=4,
                tree_method=tree_method,
            )
            model_1.fit(X=X_, y=y_, base_margin=margin)
            predictions_1 = model_1.predict(X_, base_margin=margin)
            proba_1 = model_1.predict_proba(X_, base_margin=margin)

            # Reference model: 8 rounds trained in one go.
            model_2 = xgb.dask.DaskXGBClassifier(
                learning_rate=0.3,
                random_state=123,
                n_estimators=8,
                tree_method=tree_method,
            )
            model_2.fit(X=X_, y=y_)
            predictions_2 = model_2.predict(X_)
            proba_2 = model_2.predict_proba(X_)

            # Second reference run, used to estimate run-to-run variance.
            model_3 = xgb.dask.DaskXGBClassifier(
                learning_rate=0.3,
                random_state=123,
                n_estimators=8,
                tree_method=tree_method,
            )
            model_3.fit(X=X_, y=y_)
            predictions_3 = model_3.predict(X_)
            proba_3 = model_3.predict_proba(X_)

            # compute variance of probability percentages between two of the
            # same model, use this to check to make sure approx is functioning
            # within normal parameters
            expected_variance = np.max(np.abs(proba_3 - proba_2)).compute()

            if expected_variance > 0:
                margin_variance = np.max(np.abs(proba_1 - proba_2)).compute()
                # Ensure the margin variance is less than the expected
                # variance + 10%
                assert np.all(margin_variance <= expected_variance + .1)
            else:
                np.testing.assert_equal(predictions_1.compute(),
                                        predictions_2.compute())
                np.testing.assert_almost_equal(proba_1.compute(),
                                               proba_2.compute())


def test_dask_missing_value_reg():
with LocalCluster(n_workers=kWorkers) as cluster:
with Client(cluster) as client:
Expand Down