[dask] Refactor meta data handling. (#6130)
trivialfis committed Sep 18, 2020
1 parent 5384ed8 commit cc82ca1
Showing 1 changed file with 25 additions and 30 deletions.
55 changes: 25 additions & 30 deletions python-package/xgboost/dask.py
@@ -213,9 +213,6 @@ def __init__(self,
_expect((dd.DataFrame, da.Array, dd.Series), type(label)))

self.worker_map = None
- self.has_label = label is not None
- self.has_weights = weight is not None
-
self.is_quantile = False

self._init = client.sync(self.map_local_data,
@@ -269,14 +266,17 @@ def check_columns(parts):
w_parts = w_parts.flatten().tolist()

parts = [X_parts]
+ meta_names = []
if label is not None:
assert len(X_parts) == len(
y_parts), inconsistent(X_parts, 'X', y_parts, 'labels')
parts.append(y_parts)
+ meta_names.append('labels')
if weights is not None:
assert len(X_parts) == len(
w_parts), inconsistent(X_parts, 'X', w_parts, 'weights')
parts.append(w_parts)
+ meta_names.append('weights')
parts = list(map(delayed, zip(*parts)))

parts = client.compute(parts)
@@ -298,6 +298,7 @@ def check_columns(parts):
worker_map[next(iter(workers))].append(key_to_partition[key])

self.worker_map = worker_map
+ self.meta_names = meta_names

return self
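The hunks above replace the two booleans with a parallel list: `parts` stacks the optional meta partitions next to `X_parts`, and `meta_names` records, in the same order, which fields are present. A minimal standalone sketch of that pairing, with toy lists standing in for the delayed Dask partitions (the values are illustrative only):

```python
# Toy stand-ins for the per-partition objects handled by map_local_data.
X_parts = ['X0', 'X1', 'X2']   # feature partitions
y_parts = ['y0', 'y1', 'y2']   # label partitions
w_parts = None                 # no sample weights supplied in this example

parts = [X_parts]
meta_names = []
if y_parts is not None:
    parts.append(y_parts)
    meta_names.append('labels')
if w_parts is not None:
    parts.append(w_parts)
    meta_names.append('weights')

# zip(*parts) bundles the i-th partition of every field into one tuple, so a
# worker later receives tuples like (X_i, y_i, ...) plus meta_names to label them.
parts = list(zip(*parts))
print(parts)       # [('X0', 'y0'), ('X1', 'y1'), ('X2', 'y2')]
print(meta_names)  # ['labels']
```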

@@ -308,8 +309,7 @@ def create_fn_args(self):
'''
return {'feature_names': self.feature_names,
'feature_types': self.feature_types,
- 'has_label': self.has_label,
- 'has_weights': self.has_weights,
+ 'meta_names': self.meta_names,
'missing': self.missing,
'worker_map': self.worker_map,
'is_quantile': self.is_quantile}
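With this change, `create_fn_args` hands the worker-side constructors a single `meta_names` entry instead of two flags. A hypothetical example of the resulting argument bundle (field values are illustrative, not taken from the diff):

```python
# Illustrative only: the shape of the dict returned by create_fn_args()
# for a DaskDMatrix that was built with both labels and weights.
fn_args = {
    'feature_names': None,
    'feature_types': None,
    'meta_names': ['labels', 'weights'],
    'missing': None,
    'worker_map': {'tcp://192.168.0.1:46575': ['<partition future>', '...']},
    'is_quantile': False,
}
```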
@@ -326,7 +326,7 @@ def _get_worker_x_ordered(worker_map, partition_order, worker):
return result


- def _get_worker_parts(has_label, has_weights, worker_map, worker):
+ def _get_worker_parts(worker_map, meta_names, worker):
'''Get mapped parts of data in each worker from DaskDMatrix.'''
list_of_parts = worker_map[worker.address]
assert list_of_parts, 'data in ' + worker.address + ' was moved.'
@@ -336,17 +336,19 @@ def _get_worker_parts(has_label, has_weights, worker_map, worker):
# this should be equal to `worker._get_client`.
client = get_client()
list_of_parts = client.gather(list_of_parts)
+ data = None
+ labels = None
+ weights = None
+
+ local_data = list(zip(*list_of_parts))
+ data = local_data[0]
+
+ for i, part in enumerate(local_data[1:]):
+     if meta_names[i] == 'labels':
+         labels = part
+     if meta_names[i] == 'weights':
+         weights = part

- if has_label:
-     if has_weights:
-         data, labels, weights = zip(*list_of_parts)
-     else:
-         data, labels = zip(*list_of_parts)
-         weights = None
- else:
-     data = [d[0] for d in list_of_parts]
-     labels = None
-     weights = None
return data, labels, weights
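The new body of `_get_worker_parts` transposes the gathered partition tuples and assigns each column to `labels` or `weights` by its position in `meta_names`; any field that was never supplied simply stays `None`. A sketch of that unpacking, with plain tuples standing in for the gathered futures:

```python
# Toy stand-ins for the partitions gathered on one worker.
list_of_parts = [('X0', 'y0', 'w0'), ('X1', 'y1', 'w1')]
meta_names = ['labels', 'weights']

data, labels, weights = None, None, None
local_data = list(zip(*list_of_parts))  # transpose: one tuple per field
data = local_data[0]
for i, part in enumerate(local_data[1:]):
    if meta_names[i] == 'labels':
        labels = part
    if meta_names[i] == 'weights':
        weights = part

print(data)     # ('X0', 'X1')
print(labels)   # ('y0', 'y1')
print(weights)  # ('w0', 'w1')
```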


@@ -473,8 +475,7 @@ def create_fn_args(self):


def _create_device_quantile_dmatrix(feature_names, feature_types,
- has_label,
- has_weights, missing, worker_map,
+ meta_names, missing, worker_map,
max_bin):
worker = distributed_get_worker()
if worker.address not in set(worker_map.keys()):
@@ -490,8 +491,7 @@ def _create_device_quantile_dmatrix(feature_names, feature_types,
max_bin=max_bin)
return d

- data, labels, weights = _get_worker_parts(has_label, has_weights,
- worker_map, worker)
+ data, labels, weights = _get_worker_parts(worker_map, meta_names, worker)
it = DaskPartitionIter(data=data, label=labels, weight=weights)

dmatrix = DeviceQuantileDMatrix(it,
@@ -503,8 +503,8 @@ def _create_device_quantile_dmatrix(feature_names, feature_types,
return dmatrix


- def _create_dmatrix(feature_names, feature_types, has_label,
- has_weights, missing, worker_map):
+ def _create_dmatrix(feature_names, feature_types, meta_names, missing,
+ worker_map):
'''Get data that local to worker from DaskDMatrix.
Returns
@@ -524,18 +524,13 @@ def _create_dmatrix(feature_names, feature_types, has_label,
feature_types=feature_types)
return d

- data, labels, weights = _get_worker_parts(has_label, has_weights,
- worker_map, worker)
+ data, labels, weights = _get_worker_parts(worker_map, meta_names, worker)
data = concat(data)

- if has_label:
+ if labels:
labels = concat(labels)
- else:
-     labels = None
- if has_weights:
+ if weights:
weights = concat(weights)
- else:
-     weights = None
dmatrix = DMatrix(data,
labels,
weight=weights,
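Because `_get_worker_parts` now returns `None` for any meta field absent from `meta_names`, the `if labels:` / `if weights:` truthiness checks above replace the old flag-based branching: a present field is a non-empty tuple of partitions and gets concatenated, while a missing one stays `None` and is passed through to `DMatrix` unchanged. A small sketch of that guard, using string joining as a stand-in for the library's `concat` helper:

```python
from typing import Any, Optional, Sequence

def concat_if_present(parts: Optional[Sequence[Any]]) -> Optional[str]:
    # Stand-in for the concat step in _create_dmatrix: concatenate the
    # gathered partitions when the field exists, otherwise keep None.
    if parts:
        return '+'.join(parts)
    return parts

print(concat_if_present(('y0', 'y1')))  # 'y0+y1'
print(concat_if_present(None))          # None
```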
