[dask] Fix prediction on DaskDMatrix with multiple meta data. (#6333)
* Unify the meta handling methods.
trivialfis committed Nov 3, 2020
1 parent 5a7b359 commit 7756192
Showing 2 changed files with 85 additions and 58 deletions.
121 changes: 63 additions & 58 deletions python-package/xgboost/dask.py
@@ -18,6 +18,7 @@
from collections import defaultdict
from collections.abc import Sequence
from threading import Thread
from typing import List

import numpy

@@ -300,15 +301,21 @@ def append_meta(m_parts, name: str):
        append_meta(margin_parts, 'base_margin')
        append_meta(ll_parts, 'label_lower_bound')
        append_meta(lu_parts, 'label_upper_bound')
        # At this point, `parts` looks like:
        # [(x0, x1, ..), (y0, y1, ..), ..] in delayed form

        # delay the zipped result
        parts = list(map(dask.delayed, zip(*parts)))
        # At this point, the mental model should look like:
        # [(x0, y0, ..), (x1, y1, ..), ..] in delayed form

        parts = client.compute(parts)
        await distributed.wait(parts)  # async wait for parts to be computed

        for part in parts:
            assert part.status == 'finished'

        # Preserving the partition order for prediction.
        self.partition_order = {}
        for i, part in enumerate(parts):
            self.partition_order[part.key] = i
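The two comment blocks above describe a transpose: `parts` starts as one delayed sequence per field and ends as one delayed tuple per partition, and `partition_order` then records where each computed partition belongs. A minimal sketch of that reshaping on plain lists (no delayed objects, two hypothetical partitions of data `x` and labels `y`):

# Field-major layout, as collected by append_meta: one sequence per field.
parts = [('x0', 'x1'), ('y0', 'y1')]

# Partition-major layout, as handed to the workers: one tuple per partition.
partitions = list(zip(*parts))
print(partitions)  # [('x0', 'y0'), ('x1', 'y1')]

# In the real code each tuple is wrapped in dask.delayed and computed; the
# resulting futures' keys are then mapped back to their positional index,
# roughly: partition_order = {part.key: i for i, part in enumerate(parts)}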
@@ -339,59 +346,55 @@ def create_fn_args(self):
                'is_quantile': self.is_quantile}


-def _get_worker_parts_ordered(has_base_margin, worker_map, partition_order,
-                              worker):
-    list_of_parts = worker_map[worker.address]
-    client = distributed.get_client()
-    list_of_parts_value = client.gather(list_of_parts)
-
-    result = []
-
-    for i, part in enumerate(list_of_parts):
-        data = list_of_parts_value[i][0]
-        if has_base_margin:
-            base_margin = list_of_parts_value[i][1]
-        else:
-            base_margin = None
-        result.append((data, base_margin, partition_order[part.key]))
-
-    return result
+def _get_worker_parts_ordered(meta_names, worker_map, partition_order, worker):
+    list_of_parts: List[tuple] = worker_map[worker.address]
+    # List of partitions like: [(x3, y3, w3, m3, ..), ..], order is not preserved.
+    assert isinstance(list_of_parts, list)
+    with distributed.worker_client() as client:
+        list_of_parts_value = client.gather(list_of_parts)
+
+        result = []
+
+        for i, part in enumerate(list_of_parts):
+            data = list_of_parts_value[i][0]
+            labels = None
+            weights = None
+            base_margin = None
+            label_lower_bound = None
+            label_upper_bound = None
+            # Iterate through all possible meta info, brings small overhead as in xgboost
+            # there are constant number of meta info available.
+            for j, blob in enumerate(list_of_parts_value[i][1:]):
+                if meta_names[j] == 'labels':
+                    labels = blob
+                elif meta_names[j] == 'weights':
+                    weights = blob
+                elif meta_names[j] == 'base_margin':
+                    base_margin = blob
+                elif meta_names[j] == 'label_lower_bound':
+                    label_lower_bound = blob
+                elif meta_names[j] == 'label_upper_bound':
+                    label_upper_bound = blob
+                else:
+                    raise ValueError('Unknown metainfo:', meta_names[j])
+
+            if partition_order:
+                result.append((data, labels, weights, base_margin, label_lower_bound,
+                               label_upper_bound, partition_order[part.key]))
+            else:
+                result.append((data, labels, weights, base_margin, label_lower_bound,
+                               label_upper_bound))
+
+        return result
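For intuition, this is roughly what one entry of `result` ends up holding. A standalone sketch of the dispatch applied to a single toy partition tuple (plain numpy, no cluster; the variable names are illustrative only, not the module's code):

import numpy as np

meta_names = ['labels', 'weights']            # order in which meta parts were appended
partition = (np.zeros((4, 2)),                # data
             np.array([0, 1, 0, 1]),          # labels
             np.array([1.0, 1.0, 2.0, 2.0]))  # weights

data, blobs = partition[0], partition[1:]
slots = {'labels': None, 'weights': None, 'base_margin': None,
         'label_lower_bound': None, 'label_upper_bound': None}
for name, blob in zip(meta_names, blobs):
    slots[name] = blob

# Mirrors the tuple appended to `result` (minus the partition-order index).
entry = (data, slots['labels'], slots['weights'], slots['base_margin'],
         slots['label_lower_bound'], slots['label_upper_bound'])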

+def _unzip(list_of_parts):
+    return list(zip(*list_of_parts))
+
+
-def _get_worker_parts(worker_map, meta_names, worker):
-    '''Get mapped parts of data in each worker from DaskDMatrix.'''
-    list_of_parts = worker_map[worker.address]
-    assert list_of_parts, 'data in ' + worker.address + ' was moved.'
-    assert isinstance(list_of_parts, list)
-
-    # `_get_worker_parts` is launched inside worker. In dask side
-    # this should be equal to `worker._get_client`.
-    client = distributed.get_client()
-    list_of_parts = client.gather(list_of_parts)
-    data = None
-    labels = None
-    weights = None
-    base_margin = None
-    label_lower_bound = None
-    label_upper_bound = None
-
-    local_data = list(zip(*list_of_parts))
-    data = local_data[0]
-
-    for i, part in enumerate(local_data[1:]):
-        if meta_names[i] == 'labels':
-            labels = part
-        if meta_names[i] == 'weights':
-            weights = part
-        if meta_names[i] == 'base_margin':
-            base_margin = part
-        if meta_names[i] == 'label_lower_bound':
-            label_lower_bound = part
-        if meta_names[i] == 'label_upper_bound':
-            label_upper_bound = part
-
-    return (data, labels, weights, base_margin, label_lower_bound,
-            label_upper_bound)
+def _get_worker_parts(worker_map, meta_names, worker):
+    partitions = _get_worker_parts_ordered(meta_names, worker_map, None, worker)
+    partitions = _unzip(partitions)
+    return partitions
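`_unzip` is simply a transpose: it turns the per-partition tuples returned by `_get_worker_parts_ordered` into per-field tuples, ready for concatenation. A quick illustration:

def _unzip(list_of_parts):
    return list(zip(*list_of_parts))

# Two partitions, each carrying (data, labels, weights):
parts = [('x0', 'y0', 'w0'), ('x1', 'y1', 'w1')]
print(_unzip(parts))
# [('x0', 'x1'), ('y0', 'y1'), ('w0', 'w1')]  -> one tuple per field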


class DaskPartitionIter(DataIter): # pylint: disable=R0902
@@ -585,9 +588,9 @@ def _create_dmatrix(feature_names, feature_types, meta_names, missing,
        return d

    def concat_or_none(data):
-        if data is not None:
-            return concat(data)
-        return data
+        if all([part is None for part in data]):
+            return None
+        return concat(data)

    (data, labels, weights, base_margin,
     label_lower_bound, label_upper_bound) = _get_worker_parts(
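The reasoning behind the new `concat_or_none`: once every partition tuple carries all meta slots, an absent field arrives here as one `None` per partition rather than a single `None`, so the old `is not None` check would hand a tuple of `None`s to `concat`. A minimal sketch of the new behaviour, using `numpy.concatenate` as a stand-in for the module's `concat` helper:

import numpy as np

def concat_or_none(data):
    # An absent meta field shows up as one None per partition.
    if all(part is None for part in data):
        return None
    return np.concatenate(data)

print(concat_or_none((np.array([1.0, 2.0]), np.array([3.0]))))  # [1. 2. 3.]
print(concat_or_none((None, None)))                             # None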
@@ -795,18 +798,19 @@ def mapped_predict(partition, is_df):
    feature_names = data.feature_names
    feature_types = data.feature_types
    missing = data.missing
-    has_margin = "base_margin" in data.meta_names
+    meta_names = data.meta_names

    def dispatched_predict(worker_id):
        '''Perform prediction on each worker.'''
        LOGGER.info('Predicting on %d', worker_id)

        worker = distributed.get_worker()
        list_of_parts = _get_worker_parts_ordered(
-            has_margin, worker_map, partition_order, worker)
+            meta_names, worker_map, partition_order, worker)
        predictions = []
        booster.set_param({'nthread': worker.nthreads})
-        for data, base_margin, order in list_of_parts:
+        for parts in list_of_parts:
+            (data, _, _, base_margin, _, _, order) = parts
            local_part = DMatrix(
                data,
                base_margin=base_margin,
@@ -829,12 +833,15 @@ def dispatched_get_shape(worker_id):
        LOGGER.info('Get shape on %d', worker_id)
        worker = distributed.get_worker()
        list_of_parts = _get_worker_parts_ordered(
-            False,
+            meta_names,
            worker_map,
            partition_order,
            worker
        )
-        shapes = [(part.shape, order) for part, _, order in list_of_parts]
+        shapes = []
+        for parts in list_of_parts:
+            (data, _, _, _, _, _, order) = parts
+            shapes.append((data.shape, order))
        return shapes
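Both dispatched helpers tag each local result with the partition index recorded in `partition_order`, so the driver can restore the global row order after gathering from the workers. A hedged sketch of that reassembly step (illustrative names, not the module's exact code):

import numpy as np

def reassemble(tagged_results):
    # tagged_results: (array, order) pairs gathered from all workers, in
    # arbitrary order.  Sorting on the partition index restores the row
    # order of the original DaskDMatrix.
    ordered = sorted(tagged_results, key=lambda pair: pair[1])
    return np.concatenate([arr for arr, _ in ordered])

parts = [(np.array([2.0, 3.0]), 1), (np.array([0.0, 1.0]), 0)]
print(reassemble(parts))  # [0. 1. 2. 3.]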

    async def map_function(func):
@@ -974,8 +981,7 @@ def inplace_predict(client, model, data,
                       missing=missing)


-async def _evaluation_matrices(client, validation_set,
-                               sample_weight, missing):
+async def _evaluation_matrices(client, validation_set, sample_weight, missing):
    '''
    Parameters
    ----------
@@ -998,8 +1004,7 @@ async def _evaluation_matrices(client, validation_set,
    if validation_set is not None:
        assert isinstance(validation_set, list)
        for i, e in enumerate(validation_set):
-            w = (sample_weight[i]
-                 if sample_weight is not None else None)
+            w = (sample_weight[i] if sample_weight is not None else None)
            dmat = await DaskDMatrix(client=client, data=e[0], label=e[1],
                                     weight=w, missing=missing)
            evals.append((dmat, 'validation_{}'.format(i)))
22 changes: 22 additions & 0 deletions tests/python/test_with_dask.py
@@ -566,6 +566,28 @@ def test_predict():
    assert shap.shape[1] == kCols + 1


def test_predict_with_meta(client):
    X, y, w = generate_array(with_weights=True)
    partition_size = 20
    margin = da.random.random(kRows, partition_size) + 1e4

    dtrain = DaskDMatrix(client, X, y, weight=w, base_margin=margin)
    booster = xgb.dask.train(
        client, {}, dtrain, num_boost_round=4)['booster']

    prediction = xgb.dask.predict(client, model=booster, data=dtrain)
    assert prediction.ndim == 1
    assert prediction.shape[0] == kRows

    prediction = client.compute(prediction).result()
    assert np.all(prediction > 1e3)

    m = xgb.DMatrix(X.compute())
    m.set_info(label=y.compute(), weight=w.compute(), base_margin=margin.compute())
    single = booster.predict(m)  # Make sure the ordering is correct.
    assert np.all(prediction == single)


def run_aft_survival(client, dmatrix_t):
    # survival doesn't handle empty dataset well.
    df = dd.read_csv(os.path.join(tm.PROJECT_ROOT, 'demo', 'data',
