
Commit

Use submit instead.
trivialfis committed Nov 5, 2020
1 parent f477092 commit a270e38
Showing 1 changed file with 74 additions and 76 deletions.
150 changes: 74 additions & 76 deletions python-package/xgboost/dask.py
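The shape of the change, per the commit message and the hunks below: create_fn_args() now takes a worker address and returns only that worker's slice of the partition map (worker_map.get(worker_addr, None)), and the driver dispatches one client.submit call per worker instead of a single client.map that shipped the full worker_map everywhere. A minimal sketch of that dispatch pattern (the helper name per_worker_submit and the worker_args layout are illustrative, not xgboost's API):

# Sketch only: one submit per worker, each task pinned to the worker whose
# partitions it will read. Hypothetical helper, not part of xgboost.dask.
from distributed import Client

def per_worker_submit(client: Client, fn, worker_args: dict):
    '''worker_args maps a worker address to the arguments meant only for it.'''
    futures = []
    for addr, args in worker_args.items():
        # workers=[addr] pins the task; pure=False mirrors the rationale kept
        # in the training hunk below (training is not guaranteed idempotent).
        futures.append(client.submit(fn, addr, args,
                                     workers=[addr], pure=False))
    return futures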
@@ -330,7 +330,7 @@ def append_meta(m_parts, name: str):

return self

def create_fn_args(self):
def create_fn_args(self, worker_addr: str):
'''Create a dictionary of objects that can be pickled for function
arguments.
@@ -339,57 +339,56 @@ def create_fn_args(self):
'feature_types': self.feature_types,
'meta_names': self.meta_names,
'missing': self.missing,
'worker_map': self.worker_map,
'worker_map': self.worker_map.get(worker_addr, None),
'is_quantile': self.is_quantile}


def _get_worker_parts_ordered(meta_names, worker_map, partition_order, worker):
list_of_parts: List[tuple] = worker_map[worker.address]
def _get_worker_parts_ordered(meta_names, list_of_keys, list_of_parts, partition_order):
# List of partitions like: [(x3, y3, w3, m3, ..), ..], order is not preserved.
assert isinstance(list_of_parts, list)
with distributed.worker_client() as client:
list_of_parts_value = client.gather(list_of_parts)

result = []

for i, part in enumerate(list_of_parts):
data = list_of_parts_value[i][0]
labels = None
weights = None
base_margin = None
label_lower_bound = None
label_upper_bound = None
# Iterate through all possible meta info, brings small overhead as in xgboost
# there are constant number of meta info available.
for j, blob in enumerate(list_of_parts_value[i][1:]):
if meta_names[j] == 'labels':
labels = blob
elif meta_names[j] == 'weights':
weights = blob
elif meta_names[j] == 'base_margin':
base_margin = blob
elif meta_names[j] == 'label_lower_bound':
label_lower_bound = blob
elif meta_names[j] == 'label_upper_bound':
label_upper_bound = blob
            else:
                raise ValueError('Unknown metainfo:', meta_names[j])

        if partition_order:
            result.append((data, labels, weights, base_margin, label_lower_bound,
                           label_upper_bound, partition_order[part.key]))
        else:
            result.append((data, labels, weights, base_margin, label_lower_bound,
                           label_upper_bound))
    return result
    list_of_parts_value = list_of_parts

    result = []

    for i, part in enumerate(list_of_parts):
        data = list_of_parts_value[i][0]
        labels = None
        weights = None
        base_margin = None
        label_lower_bound = None
        label_upper_bound = None
        # Iterate through all possible meta info, brings small overhead as in xgboost
        # there are constant number of meta info available.
        for j, blob in enumerate(list_of_parts_value[i][1:]):
            if meta_names[j] == 'labels':
                labels = blob
            elif meta_names[j] == 'weights':
                weights = blob
            elif meta_names[j] == 'base_margin':
                base_margin = blob
            elif meta_names[j] == 'label_lower_bound':
                label_lower_bound = blob
            elif meta_names[j] == 'label_upper_bound':
                label_upper_bound = blob
            else:
                raise ValueError('Unknown metainfo:', meta_names[j])

        if partition_order:
            result.append((data, labels, weights, base_margin, label_lower_bound,
                           label_upper_bound, partition_order[list_of_keys[i]]))
        else:
            result.append((data, labels, weights, base_margin, label_lower_bound,
                           label_upper_bound))
    return result


def _unzip(list_of_parts):
return list(zip(*list_of_parts))


def _get_worker_parts(worker_map, meta_names, worker):
partitions = _get_worker_parts_ordered(meta_names, worker_map, None, worker)
def _get_worker_parts(worker_map, meta_names):
list_of_parts: List[tuple] = worker_map
partitions = _get_worker_parts_ordered(meta_names, None, list_of_parts, None)
partitions = _unzip(partitions)
return partitions
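The meta matching in _get_worker_parts_ordered is positional: each partition tuple is (data, blob_0, blob_1, ...) and meta_names[j] names blob_j. A compact sketch of the same rule, assuming only the field names handled above are valid (unpack_part is an illustrative helper, not a real xgboost function):

# Sketch of the positional meta unpacking used above.
from typing import List, Tuple

def unpack_part(part: Tuple, meta_names: List[str]) -> dict:
    fields = {'data': part[0], 'labels': None, 'weights': None,
              'base_margin': None, 'label_lower_bound': None,
              'label_upper_bound': None}
    for name, blob in zip(meta_names, part[1:]):
        if name == 'data' or name not in fields:
            raise ValueError('Unknown metainfo:', name)
        fields[name] = blob
    return fields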

@@ -519,8 +518,8 @@ def __init__(self, client,
self.max_bin = max_bin
self.is_quantile = True

def create_fn_args(self):
args = super().create_fn_args()
def create_fn_args(self, worker_addr: str):
args = super().create_fn_args(worker_addr)
args['max_bin'] = self.max_bin
return args

@@ -529,11 +528,9 @@ def _create_device_quantile_dmatrix(feature_names, feature_types,
meta_names, missing, worker_map,
max_bin):
worker = distributed.get_worker()
if worker.address not in set(worker_map.keys()):
msg = 'worker {address} has an empty DMatrix. ' \
'All workers associated with this DMatrix: {workers}'.format(
address=worker.address,
workers=set(worker_map.keys()))
if worker_map is None:
msg = 'worker {address} has an empty DMatrix. '.format(
address=worker.address)
LOGGER.warning(msg)
import cupy # pylint: disable=import-error
d = DeviceQuantileDMatrix(cupy.zeros((0, 0)),
@@ -544,7 +541,7 @@ def _create_device_quantile_dmatrix(feature_names, feature_types,

(data, labels, weights, base_margin,
label_lower_bound, label_upper_bound) = _get_worker_parts(
worker_map, meta_names, worker)
worker_map, meta_names)
it = DaskPartitionIter(data=data, label=labels, weight=weights,
base_margin=base_margin,
label_lower_bound=label_lower_bound,
@@ -569,11 +566,9 @@ def _create_dmatrix(feature_names, feature_types, meta_names, missing,
'''
worker = distributed.get_worker()
if worker.address not in set(worker_map.keys()):
msg = 'worker {address} has an empty DMatrix. ' \
'All workers associated with this DMatrix: {workers}'.format(
address=worker.address,
workers=set(worker_map.keys()))
list_of_parts = worker_map
if list_of_parts is None:
msg = 'worker {address} has an empty DMatrix. '.format(address=worker.address)
LOGGER.warning(msg)
d = DMatrix(numpy.empty((0, 0)),
feature_names=feature_names,
@@ -586,8 +581,7 @@ def concat_or_none(data):
return concat(data)

(data, labels, weights, base_margin,
label_lower_bound, label_upper_bound) = _get_worker_parts(
worker_map, meta_names, worker)
label_lower_bound, label_upper_bound) = _get_worker_parts(list_of_parts, meta_names)

labels = concat_or_none(labels)
weights = concat_or_none(weights)
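Both _create_device_quantile_dmatrix and _create_dmatrix now test the per-worker argument for None and, when a worker received no partitions, log a warning and build a zero-row matrix so that the worker can still take part in the distributed call. A sketch of that fallback with an assumed helper name (empty_dmatrix is not part of xgboost.dask):

# Sketch only: the zero-row fallback mirrored from the branches above.
import numpy
from xgboost import DMatrix

def empty_dmatrix(feature_names=None, feature_types=None) -> DMatrix:
    # No rows; feature metadata is still attached, matching the branch above.
    return DMatrix(numpy.empty((0, 0)),
                   feature_names=feature_names,
                   feature_types=feature_types)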
@@ -640,8 +634,6 @@ def _get_workers_from_data(dtrain: DaskDMatrix, evals=()):

async def _train_async(client, params, dtrain: DaskDMatrix, *args, evals=(),
early_stopping_rounds=None, **kwargs):
_assert_dask_support()
client: distributed.Client = _xgb_get_client(client)
if 'evals_result' in kwargs.keys():
raise ValueError(
'evals_result is not supported in dask interface.',
@@ -700,13 +692,13 @@ def dispatched_train(worker_addr, rabit_args, dtrain_ref, evals_ref):
# XGBoost is deterministic in most of the cases, which means train function is
# supposed to be idempotent. One known exception is gblinear with shotgun updater.
# We haven't been able to do a full verification so here we keep pure to be False.
futures = client.map(dispatched_train,
workers,
[_rabit_args] * len(workers),
[dtrain.create_fn_args()] * len(workers),
[evals] * len(workers),
pure=False,
workers=workers)
futures = []
for i in range(len(workers)):
evals = [(e.create_fn_args(workers[i])) for e, name in evals]
f = client.submit(dispatched_train, workers[i], _rabit_args,
dtrain.create_fn_args(workers[i]), evals)
futures.append(f)

results = await client.gather(futures)
return list(filter(lambda ret: ret is not None, results))[0]
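The loop above is what replaces the old client.map call: a per-worker argument dictionary is built with create_fn_args(workers[i]) and each task is submitted on its own. A sketch of the same loop, with two assumptions not taken from the diff: an intermediate name (evals_per_worker) so the outer evals list of (DaskDMatrix, name) pairs is not overwritten between iterations, and the pure=False / workers= pinning carried over from the removed client.map call:

# Sketch, not the committed code.
futures = []
for addr in workers:
    evals_per_worker = [(e.create_fn_args(addr), name) for e, name in evals]
    fut = client.submit(dispatched_train, addr, _rabit_args,
                        dtrain.create_fn_args(addr), evals_per_worker,
                        pure=False, workers=[addr])
    futures.append(fut)
results = await client.gather(futures)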

@@ -802,14 +794,15 @@ def mapped_predict(partition, is_df):
missing = data.missing
meta_names = data.meta_names

def dispatched_predict(worker_id):
def dispatched_predict(worker_id, list_of_keys, list_of_parts):
'''Perform prediction on each worker.'''
LOGGER.info('Predicting on %d', worker_id)

list_of_keys = list_of_keys.compute()
worker = distributed.get_worker()
list_of_parts = _get_worker_parts_ordered(
meta_names, worker_map, partition_order, worker)
meta_names, list_of_keys, list_of_parts, partition_order)
predictions = []

booster.set_param({'nthread': worker.nthreads})
for parts in list_of_parts:
(data, _, _, base_margin, _, _, order) = parts
@@ -828,17 +821,20 @@ def dispatched_predict(worker_id):
columns = 1 if len(predt.shape) == 1 else predt.shape[1]
ret = ((dask.delayed(predt), columns), order)
predictions.append(ret)

return predictions

def dispatched_get_shape(worker_id):
def dispatched_get_shape(worker_id, list_of_keys, list_of_parts):
'''Get shape of data in each worker.'''
LOGGER.info('Get shape on %d', worker_id)
worker = distributed.get_worker()
list_of_keys = list_of_keys.compute()
# worker = distributed.get_worker()
# list_of_parts = worker_map[worker.address]
list_of_parts = _get_worker_parts_ordered(
meta_names,
worker_map,
list_of_keys,
list_of_parts,
partition_order,
worker
)
shapes = []
for parts in list_of_parts:
@@ -850,12 +846,14 @@ async def map_function(func):
'''Run function for each part of the data.'''
futures = []
for wid in range(len(worker_map)):
list_of_workers = [list(worker_map.keys())[wid]]
f = await client.submit(func, wid,
pure=False,
workers=list_of_workers)
worker_addr = list(worker_map.keys())[wid]
list_of_parts = worker_map[worker_addr]
list_of_keys = [part.key for part in list_of_parts]
f = await client.submit(func, worker_id=wid, list_of_keys=dask.delayed(list_of_keys),
list_of_parts=list_of_parts,
pure=False, workers=[worker_addr])
futures.append(f)
# Get delayed objects
# # Get delayed objects
results = await client.gather(futures)
results = [t for l in results for t in l] # flatten into 1 dim list
# sort by order, l[0] is the delayed object, l[1] is its order
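The flatten and sort steps named in the comments above operate on entries of the form ((delayed_prediction, n_columns), partition_order), as assembled in dispatched_predict. A small sketch of that reordering (reorder is an illustrative name, not a function in this module):

# Sketch of the reordering described above: flatten the per-worker result
# lists, sort by the recorded partition order, keep the delayed objects.
def reorder(results):
    flat = [entry for per_worker in results for entry in per_worker]
    flat.sort(key=lambda entry: entry[1])   # entry[1] is the partition order
    return [entry[0] for entry in flat]     # entry[0] is (delayed predt, columns)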