[dask] Fix prediction on DaskDMatrix with multiple meta data. #6333

Merged 3 commits on Nov 3, 2020
121 changes: 63 additions & 58 deletions python-package/xgboost/dask.py
@@ -18,6 +18,7 @@
 from collections import defaultdict
 from collections.abc import Sequence
 from threading import Thread
+from typing import List
 
 import numpy
 
@@ -300,15 +301,21 @@ def append_meta(m_parts, name: str):
         append_meta(margin_parts, 'base_margin')
         append_meta(ll_parts, 'label_lower_bound')
         append_meta(lu_parts, 'label_upper_bound')
+        # At this point, `parts` looks like:
+        # [(x0, x1, ..), (y0, y1, ..), ..] in delayed form
 
+        # delay the zipped result
         parts = list(map(dask.delayed, zip(*parts)))
+        # At this point, the mental model should look like:
+        # [(x0, y0, ..), (x1, y1, ..), ..] in delayed form
 
         parts = client.compute(parts)
         await distributed.wait(parts)  # async wait for parts to be computed
 
         for part in parts:
             assert part.status == 'finished'
 
+        # Preserving the partition order for prediction.
         self.partition_order = {}
         for i, part in enumerate(parts):
             self.partition_order[part.key] = i
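
The added comments describe the regrouping that makes ordered prediction possible: column-wise lists of chunks are transposed into per-partition tuples, wrapped in delayed objects, and each partition's key is remembered. A minimal, self-contained sketch with toy in-memory chunks (hypothetical values, no real cluster involved):

```python
# Toy sketch of the regrouping described in the comments above: two data
# chunks and their matching label chunks, no cluster required.
import dask

x0, x1 = [1, 2], [3, 4]
y0, y1 = [0, 1], [1, 0]

parts = [(x0, x1), (y0, y1)]      # column-wise: one tuple per field
rows = list(zip(*parts))          # row-wise:    [(x0, y0), (x1, y1)]
delayed_rows = [dask.delayed(r) for r in rows]

# Each delayed partition has a stable key; mapping key -> index is what
# DaskDMatrix keeps in `partition_order` so prediction can restore the
# original ordering later.
partition_order = {d.key: i for i, d in enumerate(delayed_rows)}
assert sorted(partition_order.values()) == [0, 1]
```
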
@@ -339,59 +346,55 @@ def create_fn_args(self):
                 'is_quantile': self.is_quantile}
 
 
-def _get_worker_parts_ordered(has_base_margin, worker_map, partition_order,
-                              worker):
-    list_of_parts = worker_map[worker.address]
-    client = distributed.get_client()
-    list_of_parts_value = client.gather(list_of_parts)
-
-    result = []
-
-    for i, part in enumerate(list_of_parts):
-        data = list_of_parts_value[i][0]
-        if has_base_margin:
-            base_margin = list_of_parts_value[i][1]
-        else:
-            base_margin = None
-        result.append((data, base_margin, partition_order[part.key]))
-
-    return result
-
-
-def _get_worker_parts(worker_map, meta_names, worker):
-    '''Get mapped parts of data in each worker from DaskDMatrix.'''
-    list_of_parts = worker_map[worker.address]
-    assert list_of_parts, 'data in ' + worker.address + ' was moved.'
-    assert isinstance(list_of_parts, list)
-
-    # `_get_worker_parts` is launched inside worker. In dask side
-    # this should be equal to `worker._get_client`.
-    client = distributed.get_client()
-    list_of_parts = client.gather(list_of_parts)
-    data = None
-    labels = None
-    weights = None
-    base_margin = None
-    label_lower_bound = None
-    label_upper_bound = None
-
-    local_data = list(zip(*list_of_parts))
-    data = local_data[0]
-
-    for i, part in enumerate(local_data[1:]):
-        if meta_names[i] == 'labels':
-            labels = part
-        if meta_names[i] == 'weights':
-            weights = part
-        if meta_names[i] == 'base_margin':
-            base_margin = part
-        if meta_names[i] == 'label_lower_bound':
-            label_lower_bound = part
-        if meta_names[i] == 'label_upper_bound':
-            label_upper_bound = part
-
-    return (data, labels, weights, base_margin, label_lower_bound,
-            label_upper_bound)
+def _get_worker_parts_ordered(meta_names, worker_map, partition_order, worker):
+    list_of_parts: List[tuple] = worker_map[worker.address]
+    # List of partitions like: [(x3, y3, w3, m3, ..), ..], order is not preserved.
+    assert isinstance(list_of_parts, list)
+    with distributed.worker_client() as client:
+        list_of_parts_value = client.gather(list_of_parts)
+
+        result = []
+
+        for i, part in enumerate(list_of_parts):
+            data = list_of_parts_value[i][0]
+            labels = None
+            weights = None
+            base_margin = None
+            label_lower_bound = None
+            label_upper_bound = None
+            # Iterate through all possible meta info, brings small overhead as in xgboost
+            # there are constant number of meta info available.
+            for j, blob in enumerate(list_of_parts_value[i][1:]):
+                if meta_names[j] == 'labels':
+                    labels = blob
+                elif meta_names[j] == 'weights':
+                    weights = blob
+                elif meta_names[j] == 'base_margin':
+                    base_margin = blob
+                elif meta_names[j] == 'label_lower_bound':
+                    label_lower_bound = blob
+                elif meta_names[j] == 'label_upper_bound':
+                    label_upper_bound = blob
+                else:
+                    raise ValueError('Unknown metainfo:', meta_names[j])
+
+            if partition_order:
+                result.append((data, labels, weights, base_margin, label_lower_bound,
+                               label_upper_bound, partition_order[part.key]))
+            else:
+                result.append((data, labels, weights, base_margin, label_lower_bound,
+                               label_upper_bound))
+        return result
+
+
+def _unzip(list_of_parts):
+    return list(zip(*list_of_parts))
+
+
+def _get_worker_parts(worker_map, meta_names, worker):
+    partitions = _get_worker_parts_ordered(meta_names, worker_map, None, worker)
+    partitions = _unzip(partitions)
+    return partitions
 
 
 class DaskPartitionIter(DataIter):  # pylint: disable=R0902
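
`_get_worker_parts` is now a thin wrapper: it fetches the local partitions through `_get_worker_parts_ordered` without an order mapping and transposes them with `_unzip`, so each field can later be concatenated on its own. A toy illustration of that transpose, with strings standing in for real data chunks:

```python
# Per-partition tuples (data, labels, weights) become one tuple per field,
# which is the shape _create_dmatrix expects before concatenation.
partitions = [
    ('x0', 'y0', 'w0'),   # partition 0
    ('x1', 'y1', 'w1'),   # partition 1
]

def _unzip(list_of_parts):   # same helper as in the diff above
    return list(zip(*list_of_parts))

assert _unzip(partitions) == [('x0', 'x1'), ('y0', 'y1'), ('w0', 'w1')]
```
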
@@ -585,9 +588,9 @@ def _create_dmatrix(feature_names, feature_types, meta_names, missing,
         return d
 
     def concat_or_none(data):
-        if data is not None:
-            return concat(data)
-        return data
+        if all([part is None for part in data]):
+            return None
+        return concat(data)
 
     (data, labels, weights, base_margin,
      label_lower_bound, label_upper_bound) = _get_worker_parts(
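
The `concat_or_none` change is central to the fix: once every meta field flows through the same path, a field that was never supplied arrives as a sequence of `None` values (one per partition) rather than a single `None`, so the helper has to inspect all parts before concatenating. A small sketch, using `numpy.concatenate` as a stand-in for xgboost's internal `concat` helper:

```python
import numpy as np

def concat_or_none(data):
    # A meta field that was never set shows up as one None per partition.
    if all(part is None for part in data):
        return None
    return np.concatenate(data)

labels = (np.array([0, 1]), np.array([1, 0]))   # present in every partition
margins = (None, None)                          # base_margin was never set

assert concat_or_none(labels).tolist() == [0, 1, 1, 0]
assert concat_or_none(margins) is None
```
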
@@ -795,18 +798,19 @@ def mapped_predict(partition, is_df):
     feature_names = data.feature_names
     feature_types = data.feature_types
     missing = data.missing
-    has_margin = "base_margin" in data.meta_names
+    meta_names = data.meta_names
 
     def dispatched_predict(worker_id):
         '''Perform prediction on each worker.'''
         LOGGER.info('Predicting on %d', worker_id)
 
         worker = distributed.get_worker()
         list_of_parts = _get_worker_parts_ordered(
-            has_margin, worker_map, partition_order, worker)
+            meta_names, worker_map, partition_order, worker)
         predictions = []
         booster.set_param({'nthread': worker.nthreads})
-        for data, base_margin, order in list_of_parts:
+        for parts in list_of_parts:
+            (data, _, _, base_margin, _, _, order) = parts
             local_part = DMatrix(
                 data,
                 base_margin=base_margin,
@@ -829,12 +833,15 @@ def dispatched_get_shape(worker_id):
         LOGGER.info('Get shape on %d', worker_id)
         worker = distributed.get_worker()
         list_of_parts = _get_worker_parts_ordered(
-            False,
+            meta_names,
             worker_map,
             partition_order,
             worker
         )
-        shapes = [(part.shape, order) for part, _, order in list_of_parts]
+        shapes = []
+        for parts in list_of_parts:
+            (data, _, _, _, _, _, order) = parts
+            shapes.append((data.shape, order))
         return shapes
 
     async def map_function(func):
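
Both dispatched helpers return their per-partition results tagged with the index recorded in `DaskDMatrix.partition_order`; sorting on that index is what lets the scattered outputs be reassembled in the input's row order. A simplified stand-in using plain numpy arrays instead of dask futures (hypothetical values):

```python
import numpy as np

# (prediction chunk, original partition index) pairs, arriving out of order.
results = [
    (np.array([0.3, 0.4]), 2),
    (np.array([0.1, 0.2]), 0),
    (np.array([0.5]), 1),
]

# Sort by the recorded partition index, then stitch the chunks back together.
ordered = [chunk for chunk, _ in sorted(results, key=lambda r: r[1])]
prediction = np.concatenate(ordered)
assert prediction.tolist() == [0.1, 0.2, 0.5, 0.3, 0.4]
```
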
@@ -974,8 +981,7 @@ def inplace_predict(client, model, data,
                        missing=missing)
 
 
-async def _evaluation_matrices(client, validation_set,
-                               sample_weight, missing):
+async def _evaluation_matrices(client, validation_set, sample_weight, missing):
     '''
     Parameters
     ----------
@@ -998,8 +1004,7 @@ async def _evaluation_matrices(client, validation_set,
     if validation_set is not None:
         assert isinstance(validation_set, list)
         for i, e in enumerate(validation_set):
-            w = (sample_weight[i]
-                 if sample_weight is not None else None)
+            w = (sample_weight[i] if sample_weight is not None else None)
             dmat = await DaskDMatrix(client=client, data=e[0], label=e[1],
                                      weight=w, missing=missing)
             evals.append((dmat, 'validation_{}'.format(i)))

22 changes: 22 additions & 0 deletions tests/python/test_with_dask.py
@@ -566,6 +566,28 @@ def test_predict():
     assert shap.shape[1] == kCols + 1
 
 
+def test_predict_with_meta(client):
+    X, y, w = generate_array(with_weights=True)
+    partition_size = 20
+    margin = da.random.random(kRows, partition_size) + 1e4
+
+    dtrain = DaskDMatrix(client, X, y, weight=w, base_margin=margin)
+    booster = xgb.dask.train(
+        client, {}, dtrain, num_boost_round=4)['booster']
+
+    prediction = xgb.dask.predict(client, model=booster, data=dtrain)
+    assert prediction.ndim == 1
+    assert prediction.shape[0] == kRows
+
+    prediction = client.compute(prediction).result()
+    assert np.all(prediction > 1e3)
+
+    m = xgb.DMatrix(X.compute())
+    m.set_info(label=y.compute(), weight=w.compute(), base_margin=margin.compute())
+    single = booster.predict(m)  # Make sure the ordering is correct.
+    assert np.all(prediction == single)
+
+
 def run_aft_survival(client, dmatrix_t):
     # survival doesn't handle empty dataset well.
     df = dd.read_csv(os.path.join(tm.PROJECT_ROOT, 'demo', 'data',
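
The `prediction > 1e3` assertion in the new test works because `base_margin` is added to the model's raw output, so a margin around 1e4 shifts every prediction by roughly that amount. A minimal single-node sketch of the same behaviour (toy data, independent of the test above):

```python
import numpy as np
import xgboost as xgb

X = np.random.randn(32, 4)
y = np.random.randn(32)
margin = np.full(32, 1e4)

with_margin = xgb.DMatrix(X, label=y, base_margin=margin)
no_margin = xgb.DMatrix(X, label=y)

booster = xgb.train({'objective': 'reg:squarederror'}, with_margin,
                    num_boost_round=4)

# With the margin attached, predictions sit near 1e4 plus the tree output;
# without it, they fall back to base_score plus the (large, negative) trees.
assert np.all(booster.predict(with_margin) > 1e3)
assert np.all(booster.predict(no_margin) < 1e3)
```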