diff --git a/python-package/xgboost/dask.py b/python-package/xgboost/dask.py
index fb0f7f0d97e1..bbd2fe1d2c64 100644
--- a/python-package/xgboost/dask.py
+++ b/python-package/xgboost/dask.py
@@ -344,7 +344,7 @@ def create_fn_args(self, worker_addr: str):
                 'is_quantile': self.is_quantile}
 
 
-def _get_worker_parts_ordered(meta_names, list_of_keys, list_of_parts, partition_order):
+def _get_worker_parts_ordered(meta_names, list_of_parts):
     # List of partitions like: [(x3, y3, w3, m3, ..), ..], order is not preserved.
     assert isinstance(list_of_parts, list)
 
@@ -372,13 +372,8 @@ def _get_worker_parts_ordered(meta_names, list_of_keys, list_of_parts, partition
                 label_upper_bound = blob
             else:
                 raise ValueError('Unknown metainfo:', meta_names[j])
-
-        if partition_order:
-            result.append((data, labels, weights, base_margin, label_lower_bound,
-                           label_upper_bound, partition_order[list_of_keys[i]]))
-        else:
-            result.append((data, labels, weights, base_margin, label_lower_bound,
-                           label_upper_bound))
+        result.append((data, labels, weights, base_margin, label_lower_bound,
+                       label_upper_bound))
     return result
 
 
@@ -387,7 +382,7 @@ def _unzip(list_of_parts):
 
 
 def _get_worker_parts(list_of_parts: List[tuple], meta_names):
-    partitions = _get_worker_parts_ordered(meta_names, None, list_of_parts, None)
+    partitions = _get_worker_parts_ordered(meta_names, list_of_parts)
     partitions = _unzip(partitions)
     return partitions
 
@@ -799,19 +794,17 @@ def mapped_predict(partition, is_df):
     missing = data.missing
     meta_names = data.meta_names
 
-    def dispatched_predict(worker_id, list_of_keys, list_of_parts):
+    def dispatched_predict(worker_id, list_of_orders, list_of_parts):
         '''Perform prediction on each worker.'''
         LOGGER.info('Predicting on %d', worker_id)
-        c = distributed.get_client()
-        list_of_keys = c.compute(list_of_keys).result()
         worker = distributed.get_worker()
-        list_of_parts = _get_worker_parts_ordered(
-            meta_names, list_of_keys, list_of_parts, partition_order)
+        list_of_parts = _get_worker_parts_ordered(meta_names, list_of_parts)
         predictions = []
 
         booster.set_param({'nthread': worker.nthreads})
-        for parts in list_of_parts:
-            (data, _, _, base_margin, _, _, order) = parts
+        for i, parts in enumerate(list_of_parts):
+            (data, _, _, base_margin, _, _) = parts
+            order = list_of_orders[i]
             local_part = DMatrix(
                 data,
                 base_margin=base_margin,
@@ -830,21 +823,14 @@ def dispatched_predict(worker_id, list_of_keys, list_of_parts):
         return predictions
 
 
-    def dispatched_get_shape(worker_id, list_of_keys, list_of_parts):
+    def dispatched_get_shape(worker_id, list_of_orders, list_of_parts):
         '''Get shape of data in each worker.'''
        LOGGER.info('Get shape on %d', worker_id)
-        c = distributed.get_client()
-        list_of_keys = c.compute(list_of_keys).result()
-        list_of_parts = _get_worker_parts_ordered(
-            meta_names,
-            list_of_keys,
-            list_of_parts,
-            partition_order,
-        )
+        list_of_parts = _get_worker_parts_ordered(meta_names, list_of_parts)
         shapes = []
-        for parts in list_of_parts:
-            (data, _, _, _, _, _, order) = parts
-            shapes.append((data.shape, order))
+        for i, parts in enumerate(list_of_parts):
+            (data, _, _, _, _, _) = parts
+            shapes.append((data.shape, list_of_orders[i]))
         return shapes
 
     async def map_function(func):
@@ -854,11 +840,13 @@ async def map_function(func):
         for wid, worker_addr in enumerate(workers_address):
             worker_addr = workers_address[wid]
             list_of_parts = worker_map[worker_addr]
-            list_of_keys = [part.key for part in list_of_parts]
-            f = await client.submit(func, worker_id=wid,
-                                    list_of_keys=dask.delayed(list_of_keys),
-                                    list_of_parts=list_of_parts,
-                                    pure=False, workers=[worker_addr])
+            list_of_orders = [partition_order[part.key] for part in list_of_parts]
+
+            f = client.submit(func, worker_id=wid,
+                              list_of_orders=list_of_orders,
+                              list_of_parts=list_of_parts,
+                              pure=True, workers=[worker_addr])
+            assert isinstance(f, distributed.client.Future)
             futures.append(f)
         # Get delayed objects
         results = await client.gather(futures)
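
The core of the change is the last hunk: each partition's order is looked up in partition_order on the client and passed to the worker task as a plain Python list, so the task no longer needs distributed.get_client() plus a compute() round trip on a dask.delayed key list. Because distributed resolves futures nested inside submit() arguments, the worker function receives concrete partitions directly. Below is a minimal, self-contained sketch of that pattern under stated assumptions: describe_parts is a hypothetical stand-in for dispatched_get_shape (not xgboost API), and the per-worker grouping of futures is skipped for brevity.

import dask.array as da
from distributed import Client


def describe_parts(worker_id, list_of_orders, list_of_parts):
    # Runs on a worker; worker_id only mirrors the signature in the diff.
    # Futures passed in list_of_parts arrive already materialized, so we can
    # pair each partition with its original position and let the caller
    # reassemble results in partition order.
    return [(part.shape, order)
            for part, order in zip(list_of_parts, list_of_orders)]


if __name__ == '__main__':
    client = Client(processes=False)    # in-process cluster, just for the demo
    x = da.ones((8, 2), chunks=(2, 2))  # 4 partitions of shape (2, 2)
    parts = client.compute(x.to_delayed().flatten().tolist())  # one future each
    # Record each partition's original position by future key, as the
    # surrounding predict code does before grouping futures by worker.
    partition_order = {part.key: i for i, part in enumerate(parts)}

    list_of_orders = [partition_order[part.key] for part in parts]
    fut = client.submit(describe_parts, worker_id=0,
                        list_of_orders=list_of_orders,
                        list_of_parts=parts, pure=True)
    print(fut.result())  # [((2, 2), 0), ((2, 2), 1), ((2, 2), 2), ((2, 2), 3)]
    client.close()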