Commit

Fix dask predict (#6412)
trivialfis committed Nov 20, 2020
1 parent 44a9d69 commit a7b42ad
Showing 1 changed file with 21 additions and 33 deletions.
54 changes: 21 additions & 33 deletions python-package/xgboost/dask.py
@@ -344,7 +344,7 @@ def create_fn_args(self, worker_addr: str):
                 'is_quantile': self.is_quantile}
 
 
-def _get_worker_parts_ordered(meta_names, list_of_keys, list_of_parts, partition_order):
+def _get_worker_parts_ordered(meta_names, list_of_parts):
     # List of partitions like: [(x3, y3, w3, m3, ..), ..], order is not preserved.
     assert isinstance(list_of_parts, list)
@@ -372,13 +372,8 @@ def _get_worker_parts_ordered(meta_names, list_of_keys, list_of_parts, partition_order):
                 label_upper_bound = blob
             else:
                 raise ValueError('Unknown metainfo:', meta_names[j])
 
-        if partition_order:
-            result.append((data, labels, weights, base_margin, label_lower_bound,
-                           label_upper_bound, partition_order[list_of_keys[i]]))
-        else:
-            result.append((data, labels, weights, base_margin, label_lower_bound,
-                           label_upper_bound))
+        result.append((data, labels, weights, base_margin, label_lower_bound,
+                       label_upper_bound))
     return result
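With this hunk the helper always returns plain 6-tuples; the partition order that used to ride along in a seventh field is now carried by the caller as a parallel list. A toy illustration of the new pattern (stand-in values, not library code):

```python
# Partition payloads and their global partition indices, kept in parallel lists.
list_of_parts = [('x0', 'y0'), ('x1', 'y1')]
list_of_orders = [3, 1]

# Same shape as the fixed dispatched_predict loop below: index the order
# list rather than unpacking an order field from each tuple.
for i, (data, labels) in enumerate(list_of_parts):
    order = list_of_orders[i]
    print(order, data, labels)
```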


@@ -387,7 +382,7 @@ def _unzip(list_of_parts):


 def _get_worker_parts(list_of_parts: List[tuple], meta_names):
-    partitions = _get_worker_parts_ordered(meta_names, None, list_of_parts, None)
+    partitions = _get_worker_parts_ordered(meta_names, list_of_parts)
     partitions = _unzip(partitions)
     return partitions
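The body of `_unzip` is collapsed above; it is presumably the usual transpose helper, sketched here for reference:

```python
def _unzip(list_of_parts):
    # Transpose [(x0, y0, ...), (x1, y1, ...)] into [(x0, x1), (y0, y1), ...].
    return list(zip(*list_of_parts))

assert _unzip([(1, 'a'), (2, 'b')]) == [(1, 2), ('a', 'b')]
```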

@@ -799,19 +794,17 @@ def mapped_predict(partition, is_df):
     missing = data.missing
     meta_names = data.meta_names
 
-    def dispatched_predict(worker_id, list_of_keys, list_of_parts):
+    def dispatched_predict(worker_id, list_of_orders, list_of_parts):
         '''Perform prediction on each worker.'''
         LOGGER.info('Predicting on %d', worker_id)
-        c = distributed.get_client()
-        list_of_keys = c.compute(list_of_keys).result()
         worker = distributed.get_worker()
-        list_of_parts = _get_worker_parts_ordered(
-            meta_names, list_of_keys, list_of_parts, partition_order)
+        list_of_parts = _get_worker_parts_ordered(meta_names, list_of_parts)
         predictions = []
 
         booster.set_param({'nthread': worker.nthreads})
-        for parts in list_of_parts:
-            (data, _, _, base_margin, _, _, order) = parts
+        for i, parts in enumerate(list_of_parts):
+            (data, _, _, base_margin, _, _) = parts
+            order = list_of_orders[i]
             local_part = DMatrix(
                 data,
                 base_margin=base_margin,
@@ -830,21 +823,14 @@ def dispatched_predict(worker_id, list_of_keys, list_of_parts):

         return predictions

-    def dispatched_get_shape(worker_id, list_of_keys, list_of_parts):
+    def dispatched_get_shape(worker_id, list_of_orders, list_of_parts):
         '''Get shape of data in each worker.'''
         LOGGER.info('Get shape on %d', worker_id)
-        c = distributed.get_client()
-        list_of_keys = c.compute(list_of_keys).result()
-        list_of_parts = _get_worker_parts_ordered(
-            meta_names,
-            list_of_keys,
-            list_of_parts,
-            partition_order,
-        )
+        list_of_parts = _get_worker_parts_ordered(meta_names, list_of_parts)
         shapes = []
-        for parts in list_of_parts:
-            (data, _, _, _, _, _, order) = parts
-            shapes.append((data.shape, order))
+        for i, parts in enumerate(list_of_parts):
+            (data, _, _, _, _, _) = parts
+            shapes.append((data.shape, list_of_orders[i]))
         return shapes

     async def map_function(func):
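Both dispatched functions share one shape: materialize the local partitions, process each one, and tag every result with its global partition order so the client can reassemble outputs. A standalone sketch of that pattern with plain (non-dask) xgboost and toy data:

```python
import numpy as np
import xgboost as xgb

# Train a throwaway model so the partition loop below has something to call.
X, y = np.random.rand(32, 4), np.random.rand(32)
booster = xgb.train({'tree_method': 'hist'}, xgb.DMatrix(X, y), num_boost_round=2)

# Predict partition by partition, tagging each result with its order,
# mirroring dispatched_predict above.
partitions, list_of_orders = [X[:16], X[16:]], [1, 0]
predictions = []
for i, data in enumerate(partitions):
    local_part = xgb.DMatrix(data)
    predictions.append((booster.predict(local_part), list_of_orders[i]))

# Reassemble in original partition order, as the client side would.
predictions.sort(key=lambda p: p[1])
full = np.concatenate([p[0] for p in predictions])
print(full.shape)  # (32,)
```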
@@ -854,11 +840,13 @@ async def map_function(func):
         for wid, worker_addr in enumerate(workers_address):
             worker_addr = workers_address[wid]
             list_of_parts = worker_map[worker_addr]
-            list_of_keys = [part.key for part in list_of_parts]
-            f = await client.submit(func, worker_id=wid,
-                                    list_of_keys=dask.delayed(list_of_keys),
-                                    list_of_parts=list_of_parts,
-                                    pure=False, workers=[worker_addr])
+            list_of_orders = [partition_order[part.key] for part in list_of_parts]
+
+            f = client.submit(func, worker_id=wid,
+                              list_of_orders=list_of_orders,
+                              list_of_parts=list_of_parts,
+                              pure=True, workers=[worker_addr])
+            assert isinstance(f, distributed.client.Future)
             futures.append(f)
         # Get delayed objects
         results = await client.gather(futures)
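For context, a minimal end-to-end sketch of the public API this patch touches; cluster size, data, and hyperparameters are illustrative only:

```python
import dask.array as da
from distributed import Client, LocalCluster
from xgboost import dask as dxgb

if __name__ == '__main__':
    with LocalCluster(n_workers=2) as cluster, Client(cluster) as client:
        X = da.random.random((1000, 10), chunks=(100, 10))
        y = da.random.random(1000, chunks=100)
        dtrain = dxgb.DaskDMatrix(client, X, y)
        output = dxgb.train(client, {'tree_method': 'hist'}, dtrain,
                            num_boost_round=4)
        # predict() is the path touched here: each worker scores its own
        # partitions, and results come back in the original row order.
        pred = dxgb.predict(client, output, dtrain)
        print(pred.compute()[:5])
```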
