Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix dask predict #6412

Merged
merged 2 commits on Nov 20, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
54 changes: 21 additions & 33 deletions python-package/xgboost/dask.py
Expand Up @@ -344,7 +344,7 @@ def create_fn_args(self, worker_addr: str):
'is_quantile': self.is_quantile}


def _get_worker_parts_ordered(meta_names, list_of_keys, list_of_parts, partition_order):
def _get_worker_parts_ordered(meta_names, list_of_parts):
# List of partitions like: [(x3, y3, w3, m3, ..), ..], order is not preserved.
assert isinstance(list_of_parts, list)

Expand Down Expand Up @@ -372,13 +372,8 @@ def _get_worker_parts_ordered(meta_names, list_of_keys, list_of_parts, partition
label_upper_bound = blob
else:
raise ValueError('Unknown metainfo:', meta_names[j])

if partition_order:
result.append((data, labels, weights, base_margin, label_lower_bound,
label_upper_bound, partition_order[list_of_keys[i]]))
else:
result.append((data, labels, weights, base_margin, label_lower_bound,
label_upper_bound))
result.append((data, labels, weights, base_margin, label_lower_bound,
label_upper_bound))
return result


Expand All @@ -387,7 +382,7 @@ def _unzip(list_of_parts):


def _get_worker_parts(list_of_parts: List[tuple], meta_names):
    '''Extract per-partition columns (data, labels, weights, ...) from a
    worker's list of partitions, in unordered form.

    Parameters
    ----------
    list_of_parts :
        Partition tuples gathered on this worker, e.g. [(x3, y3, w3, ...), ...].
    meta_names :
        Names identifying which meta field each blob in a partition holds.

    Returns
    -------
    The unzipped columns, one sequence per meta field.
    '''
    # NOTE(review): the scraped diff left both the old 4-argument call and
    # the new 2-argument call in the text; only the new signature
    # `_get_worker_parts_ordered(meta_names, list_of_parts)` exists after
    # this PR, so the stale call is dropped here.
    partitions = _get_worker_parts_ordered(meta_names, list_of_parts)
    partitions = _unzip(partitions)
    return partitions

Expand Down Expand Up @@ -799,19 +794,17 @@ def mapped_predict(partition, is_df):
missing = data.missing
meta_names = data.meta_names

def dispatched_predict(worker_id, list_of_keys, list_of_parts):
def dispatched_predict(worker_id, list_of_orders, list_of_parts):
'''Perform prediction on each worker.'''
LOGGER.info('Predicting on %d', worker_id)
c = distributed.get_client()
list_of_keys = c.compute(list_of_keys).result()
worker = distributed.get_worker()
list_of_parts = _get_worker_parts_ordered(
meta_names, list_of_keys, list_of_parts, partition_order)
list_of_parts = _get_worker_parts_ordered(meta_names, list_of_parts)
predictions = []

booster.set_param({'nthread': worker.nthreads})
for parts in list_of_parts:
(data, _, _, base_margin, _, _, order) = parts
for i, parts in enumerate(list_of_parts):
(data, _, _, base_margin, _, _) = parts
order = list_of_orders[i]
local_part = DMatrix(
data,
base_margin=base_margin,
Expand All @@ -830,21 +823,14 @@ def dispatched_predict(worker_id, list_of_keys, list_of_parts):

return predictions

def dispatched_get_shape(worker_id, list_of_orders, list_of_parts):
    '''Get the shape of each data partition on this worker, paired with the
    partition's global order so results can be reassembled in sequence.

    Parameters
    ----------
    worker_id :
        Numeric id of the worker, used only for logging.
    list_of_orders :
        Global partition-order index for each entry of ``list_of_parts``,
        computed by the caller (one per partition, same ordering).
    list_of_parts :
        Partition tuples resident on this worker.

    Returns
    -------
    A list of ``(shape, order)`` pairs, one per partition.
    '''
    # NOTE(review): the diff residue interleaved the removed version (which
    # fetched ``list_of_keys`` via ``client.compute`` and carried the order
    # inside each partition tuple) with the added version; this is the
    # post-merge form, where order travels in ``list_of_orders`` instead.
    LOGGER.info('Get shape on %d', worker_id)
    list_of_parts = _get_worker_parts_ordered(meta_names, list_of_parts)
    shapes = []
    for i, parts in enumerate(list_of_parts):
        (data, _, _, _, _, _) = parts
        shapes.append((data.shape, list_of_orders[i]))
    return shapes

async def map_function(func):
Expand All @@ -854,11 +840,13 @@ async def map_function(func):
for wid, worker_addr in enumerate(workers_address):
worker_addr = workers_address[wid]
list_of_parts = worker_map[worker_addr]
list_of_keys = [part.key for part in list_of_parts]
f = await client.submit(func, worker_id=wid,
list_of_keys=dask.delayed(list_of_keys),
list_of_parts=list_of_parts,
pure=False, workers=[worker_addr])
list_of_orders = [partition_order[part.key] for part in list_of_parts]

f = client.submit(func, worker_id=wid,
list_of_orders=list_of_orders,
list_of_parts=list_of_parts,
pure=True, workers=[worker_addr])
assert isinstance(f, distributed.client.Future)
futures.append(f)
# Get delayed objects
results = await client.gather(futures)
Expand Down