Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[dask] Refactor metadata handling. #6130

Merged
merged 1 commit into from Sep 18, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
55 changes: 25 additions & 30 deletions python-package/xgboost/dask.py
Expand Up @@ -213,9 +213,6 @@ def __init__(self,
_expect((dd.DataFrame, da.Array, dd.Series), type(label)))

self.worker_map = None
self.has_label = label is not None
self.has_weights = weight is not None

self.is_quantile = False

self._init = client.sync(self.map_local_data,
Expand Down Expand Up @@ -269,14 +266,17 @@ def check_columns(parts):
w_parts = w_parts.flatten().tolist()

parts = [X_parts]
meta_names = []
if label is not None:
assert len(X_parts) == len(
y_parts), inconsistent(X_parts, 'X', y_parts, 'labels')
parts.append(y_parts)
meta_names.append('labels')
if weights is not None:
assert len(X_parts) == len(
w_parts), inconsistent(X_parts, 'X', w_parts, 'weights')
parts.append(w_parts)
meta_names.append('weights')
parts = list(map(delayed, zip(*parts)))

parts = client.compute(parts)
Expand All @@ -298,6 +298,7 @@ def check_columns(parts):
worker_map[next(iter(workers))].append(key_to_partition[key])

self.worker_map = worker_map
self.meta_names = meta_names

return self

Expand All @@ -308,8 +309,7 @@ def create_fn_args(self):
'''
return {'feature_names': self.feature_names,
'feature_types': self.feature_types,
'has_label': self.has_label,
'has_weights': self.has_weights,
'meta_names': self.meta_names,
'missing': self.missing,
'worker_map': self.worker_map,
'is_quantile': self.is_quantile}
Expand All @@ -326,7 +326,7 @@ def _get_worker_x_ordered(worker_map, partition_order, worker):
return result


def _get_worker_parts(has_label, has_weights, worker_map, worker):
def _get_worker_parts(worker_map, meta_names, worker):
'''Get mapped parts of data in each worker from DaskDMatrix.'''
list_of_parts = worker_map[worker.address]
assert list_of_parts, 'data in ' + worker.address + ' was moved.'
Expand All @@ -336,17 +336,19 @@ def _get_worker_parts(has_label, has_weights, worker_map, worker):
# this should be equal to `worker._get_client`.
client = get_client()
list_of_parts = client.gather(list_of_parts)
data = None
labels = None
weights = None

local_data = list(zip(*list_of_parts))
data = local_data[0]

for i, part in enumerate(local_data[1:]):
if meta_names[i] == 'labels':
labels = part
if meta_names[i] == 'weights':
weights = part

if has_label:
if has_weights:
data, labels, weights = zip(*list_of_parts)
else:
data, labels = zip(*list_of_parts)
weights = None
else:
data = [d[0] for d in list_of_parts]
labels = None
weights = None
return data, labels, weights


Expand Down Expand Up @@ -473,8 +475,7 @@ def create_fn_args(self):


def _create_device_quantile_dmatrix(feature_names, feature_types,
has_label,
has_weights, missing, worker_map,
meta_names, missing, worker_map,
max_bin):
worker = distributed_get_worker()
if worker.address not in set(worker_map.keys()):
Expand All @@ -490,8 +491,7 @@ def _create_device_quantile_dmatrix(feature_names, feature_types,
max_bin=max_bin)
return d

data, labels, weights = _get_worker_parts(has_label, has_weights,
worker_map, worker)
data, labels, weights = _get_worker_parts(worker_map, meta_names, worker)
it = DaskPartitionIter(data=data, label=labels, weight=weights)

dmatrix = DeviceQuantileDMatrix(it,
Expand All @@ -503,8 +503,8 @@ def _create_device_quantile_dmatrix(feature_names, feature_types,
return dmatrix


def _create_dmatrix(feature_names, feature_types, has_label,
has_weights, missing, worker_map):
def _create_dmatrix(feature_names, feature_types, meta_names, missing,
worker_map):
'''Get data that is local to the worker from DaskDMatrix.

Returns
Expand All @@ -524,18 +524,13 @@ def _create_dmatrix(feature_names, feature_types, has_label,
feature_types=feature_types)
return d

data, labels, weights = _get_worker_parts(has_label, has_weights,
worker_map, worker)
data, labels, weights = _get_worker_parts(worker_map, meta_names, worker)
data = concat(data)

if has_label:
if labels:
labels = concat(labels)
else:
labels = None
if has_weights:
if weights:
weights = concat(weights)
else:
weights = None
dmatrix = DMatrix(data,
labels,
weight=weights,
Expand Down