Unify the meta handling methods.
trivialfis committed Oct 31, 2020
1 parent 5822cb8 commit 18e23c4
Showing 2 changed files with 61 additions and 51 deletions.
107 changes: 56 additions & 51 deletions python-package/xgboost/dask.py
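The core of the change is a single, fixed per-partition layout shared by every meta-handling path: each partition travels as one tuple in a fixed field order, and meta info that was not supplied stays None. A minimal illustrative sketch of that layout (the variable names here are made up, not part of the diff):

import numpy as np

# Fixed field order after this change: (data, labels, weights, base_margin,
# label_lower_bound, label_upper_bound[, partition_order]).
X_part = np.random.rand(8, 4)    # feature rows of one hypothetical partition
y_part = np.random.rand(8)       # labels for the same rows
margin_part = np.full(8, 0.5)    # base_margin for the same rows

# weights and the survival bounds were not supplied, so they stay None.
partition = (X_part, y_part, None, margin_part, None, None, 3)
data, labels, weights, base_margin, lower, upper, order = partition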
@@ -301,15 +301,21 @@ def append_meta(m_parts, name: str):
         append_meta(margin_parts, 'base_margin')
         append_meta(ll_parts, 'label_lower_bound')
         append_meta(lu_parts, 'label_upper_bound')
+        # At this point, `parts` looks like:
+        # [(x0, x1, ..), (y0, y1, ..), ..] in delayed form
 
+        # delay the zipped result
         parts = list(map(dask.delayed, zip(*parts)))
+        # At this point, the mental model should look like:
+        # [(x0, y0, ..), (x1, y1, ..), ..] in delayed form
 
         parts = client.compute(parts)
         await distributed.wait(parts)  # async wait for parts to be computed
 
         for part in parts:
             assert part.status == 'finished'
 
+        # Preserving the partition order for prediction.
         self.partition_order = {}
         for i, part in enumerate(parts):
             self.partition_order[part.key] = i
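The two added comments describe a transpose: zip(*parts) regroups the per-field lists into per-partition tuples before each tuple is wrapped in dask.delayed. A standalone sketch of just that reshaping step (toy strings instead of real partitions):

# parts arrives grouped by field: one tuple per field, one entry per partition.
parts = [('x0', 'x1', 'x2'),   # data partitions
         ('y0', 'y1', 'y2')]   # label partitions

# zip(*parts) transposes it: each element now holds every field of one partition.
per_partition = list(zip(*parts))
assert per_partition == [('x0', 'y0'), ('x1', 'y1'), ('x2', 'y2')]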
@@ -342,7 +348,7 @@ def create_fn_args(self):

 def _get_worker_parts_ordered(meta_names, worker_map, partition_order, worker):
     list_of_parts: List[tuple] = worker_map[worker.address]
-    # List of partitions like: [(data, label, weight, margin, ...), ...]
+    # List of partitions like: [(x3, y3, w3, m3, ..), ..], order is not preserved.
     assert isinstance(list_of_parts, list)
     with distributed.worker_client() as client:
         list_of_parts_value = client.gather(list_of_parts)
@@ -351,51 +357,44 @@ def _get_worker_parts_ordered(meta_names, worker_map, partition_order, worker):

         for i, part in enumerate(list_of_parts):
             data = list_of_parts_value[i][0]
-            if 'base_margin' in meta_names:
-                for j, name in enumerate(meta_names):
-                    if name == 'base_margin':
-                        base_margin = list_of_parts_value[i][j + 1]
+            labels = None
+            weights = None
+            base_margin = None
+            label_lower_bound = None
+            label_upper_bound = None
+            # Iterate through all possible meta info; this adds only a small
+            # overhead, since xgboost has a constant number of meta info fields.
+            for j, blob in enumerate(list_of_parts_value[i][1:]):
+                if meta_names[j] == 'labels':
+                    labels = blob
+                elif meta_names[j] == 'weights':
+                    weights = blob
+                elif meta_names[j] == 'base_margin':
+                    base_margin = blob
+                elif meta_names[j] == 'label_lower_bound':
+                    label_lower_bound = blob
+                elif meta_names[j] == 'label_upper_bound':
+                    label_upper_bound = blob
+                else:
+                    raise ValueError('Unknown metainfo:', meta_names[j])
+
+            if partition_order:
+                result.append((data, labels, weights, base_margin, label_lower_bound,
+                               label_upper_bound, partition_order[part.key]))
             else:
-                base_margin = None
-            result.append((data, base_margin, partition_order[part.key]))
-        return result
+                result.append((data, labels, weights, base_margin, label_lower_bound,
+                               label_upper_bound))
 
+    return result

+def _unzip(list_of_parts):
+    return list(zip(*list_of_parts))
 
-def _get_worker_parts(worker_map, meta_names, worker):
-    '''Get mapped parts of data in each worker from DaskDMatrix.'''
-    list_of_parts = worker_map[worker.address]
-    assert list_of_parts, 'data in ' + worker.address + ' was moved.'
-    assert isinstance(list_of_parts, list)
-
-    # `_get_worker_parts` is launched inside worker. In dask side
-    # this should be equal to `worker._get_client`.
-    client = distributed.get_client()
-    list_of_parts = client.gather(list_of_parts)
-    data = None
-    labels = None
-    weights = None
-    base_margin = None
-    label_lower_bound = None
-    label_upper_bound = None
-
-    local_data = list(zip(*list_of_parts))
-    data = local_data[0]
-
-    for i, part in enumerate(local_data[1:]):
-        if meta_names[i] == 'labels':
-            labels = part
-        if meta_names[i] == 'weights':
-            weights = part
-        if meta_names[i] == 'base_margin':
-            base_margin = part
-        if meta_names[i] == 'label_lower_bound':
-            label_lower_bound = part
-        if meta_names[i] == 'label_upper_bound':
-            label_upper_bound = part
-
-    return (data, labels, weights, base_margin, label_lower_bound,
-            label_upper_bound)
+def _get_worker_parts(worker_map, meta_names, worker):
+    partitions = _get_worker_parts_ordered(meta_names, worker_map, None, worker)
+    partitions = _unzip(partitions)
+    return partitions
 
 
 class DaskPartitionIter(DataIter):  # pylint: disable=R0902
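`_get_worker_parts` is now just the unordered variant of `_get_worker_parts_ordered` plus a transpose back to per-field sequences, which is the shape `_create_dmatrix` consumes. A quick illustration of `_unzip` with two partitions in the six-field layout (toy values; unsupplied fields are None):

def _unzip(list_of_parts):
    return list(zip(*list_of_parts))

parts = [('x0', 'y0', None, None, None, None),
         ('x1', 'y1', None, None, None, None)]

data, labels, weights, margin, lower, upper = _unzip(parts)
assert data == ('x0', 'x1')
assert weights == (None, None)   # the whole field is absent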
@@ -589,9 +588,9 @@ def _create_dmatrix(feature_names, feature_types, meta_names, missing,
         return d
 
     def concat_or_none(data):
-        if data is not None:
-            return concat(data)
-        return data
+        if all([part is None for part in data]):
+            return None
+        return concat(data)
 
     (data, labels, weights, base_margin,
      label_lower_bound, label_upper_bound) = _get_worker_parts(
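The rewritten check matters because `_get_worker_parts` now yields one tuple per field whose entries are all None whenever that meta info was never supplied; the old `data is not None` test would pass such a tuple straight into `concat`. A minimal sketch of the distinction, with plain list concatenation standing in for xgboost's `concat`:

def concat_or_none(data, concat=lambda parts: sum(parts, [])):
    # A field is absent when every partition contributed None for it.
    if all(part is None for part in data):
        return None
    return concat(data)

assert concat_or_none(([1, 2], [3])) == [1, 2, 3]   # real meta info
assert concat_or_none((None, None)) is None         # field never supplied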
@@ -810,7 +809,8 @@ def dispatched_predict(worker_id):
                 meta_names, worker_map, partition_order, worker)
             predictions = []
             booster.set_param({'nthread': worker.nthreads})
-            for data, base_margin, order in list_of_parts:
+            for parts in list_of_parts:
+                (data, _, _, base_margin, _, _, order) = parts
                 local_part = DMatrix(
                     data,
                     base_margin=base_margin,
@@ -838,7 +838,10 @@ def dispatched_get_shape(worker_id):
                 partition_order,
                 worker
             )
-            shapes = [(part.shape, order) for part, _, order in list_of_parts]
+            shapes = []
+            for parts in list_of_parts:
+                (data, _, _, _, _, _, order) = parts
+                shapes.append((data.shape, order))
             return shapes
 
         async def map_function(func):
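Both dispatch helpers now unpack the same fixed seven-field tuple, keeping only what they need and discarding the rest with `_`. A sketch of the pattern, assuming one ordered partition tuple (made-up values):

import numpy as np

# (data, labels, weights, base_margin, label_lower_bound,
#  label_upper_bound, order)
part = (np.zeros((4, 2)), None, None, np.ones(4), None, None, 0)

# dispatched_predict keeps data, base_margin and the partition order:
(data, _, _, base_margin, _, _, order) = part

# dispatched_get_shape keeps only the shape and the order:
(data, _, _, _, _, _, order) = part
assert (data.shape, order) == ((4, 2), 0)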
@@ -978,8 +981,8 @@ def inplace_predict(client, model, data,
                            missing=missing)
 
 
-async def _evaluation_matrices(client, validation_set,
-                               sample_weight, missing):
+async def _evaluation_matrices(client, validation_set, sample_weight, base_margin,
+                               missing):
     '''
     Parameters
     ----------
@@ -1002,10 +1005,10 @@ async def _evaluation_matrices(client, validation_set,
     if validation_set is not None:
         assert isinstance(validation_set, list)
         for i, e in enumerate(validation_set):
-            w = (sample_weight[i]
-                 if sample_weight is not None else None)
+            w = (sample_weight[i] if sample_weight is not None else None)
+            margin = (base_margin[i] if base_margin is not None else None)
             dmat = await DaskDMatrix(client=client, data=e[0], label=e[1],
-                                     weight=w, missing=missing)
+                                     weight=w, missing=missing, base_margin=margin)
             evals.append((dmat, 'validation_{}'.format(i)))
     else:
         evals = None
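With the new parameter, the i-th validation set picks up its weight and margin by index, so `base_margin` must align with `validation_set` and `sample_weight`. A hedged sketch of just that indexing (plain strings instead of dask collections; the real loop builds a DaskDMatrix per entry):

validation_set = [('X0', 'y0'), ('X1', 'y1')]
sample_weight = ['w0', 'w1']
base_margin = ['m0', 'm1']

for i, e in enumerate(validation_set):
    w = sample_weight[i] if sample_weight is not None else None
    margin = base_margin[i] if base_margin is not None else None
    # DaskDMatrix(client=client, data=e[0], label=e[1],
    #             weight=w, missing=missing, base_margin=margin)
    print('validation_{}'.format(i), w, margin)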
@@ -1090,6 +1093,7 @@ async def _fit_async(self, X, y, sample_weight, base_margin, eval_set,
         params = self.get_xgb_params()
         evals = await _evaluation_matrices(self.client, eval_set,
                                            sample_weight_eval_set,
+                                           base_margin,
                                            self.missing)
         results = await train(client=self.client,
                               params=params,
@@ -1173,6 +1177,7 @@ async def _fit_async(self, X, y, sample_weight, base_margin, eval_set,

         evals = await _evaluation_matrices(self.client, eval_set,
                                            sample_weight_eval_set,
+                                           base_margin,
                                            self.missing)
         results = await train(client=self.client,
                               params=params,
5 changes: 5 additions & 0 deletions tests/python/test_with_dask.py
@@ -582,6 +582,11 @@ def test_predict_with_meta(client):
     prediction = client.compute(prediction).result()
     assert np.all(prediction > 1e3)
 
+    m = xgb.DMatrix(X.compute())
+    m.set_info(label=y.compute(), weight=w.compute(), base_margin=margin.compute())
+    single = booster.predict(m)  # Make sure the ordering is correct.
+    assert np.all(prediction == single)
+
 
 def run_aft_survival(client, dmatrix_t):
     # survival doesn't handle empty dataset well.
