diff --git a/python-package/xgboost/dask.py b/python-package/xgboost/dask.py index b5479fdb9bb5..2c2da82f5cbc 100644 --- a/python-package/xgboost/dask.py +++ b/python-package/xgboost/dask.py @@ -213,9 +213,6 @@ def __init__(self, _expect((dd.DataFrame, da.Array, dd.Series), type(label))) self.worker_map = None - self.has_label = label is not None - self.has_weights = weight is not None - self.is_quantile = False self._init = client.sync(self.map_local_data, @@ -269,14 +266,17 @@ def check_columns(parts): w_parts = w_parts.flatten().tolist() parts = [X_parts] + meta_names = [] if label is not None: assert len(X_parts) == len( y_parts), inconsistent(X_parts, 'X', y_parts, 'labels') parts.append(y_parts) + meta_names.append('labels') if weights is not None: assert len(X_parts) == len( w_parts), inconsistent(X_parts, 'X', w_parts, 'weights') parts.append(w_parts) + meta_names.append('weights') parts = list(map(delayed, zip(*parts))) parts = client.compute(parts) @@ -298,6 +298,7 @@ def check_columns(parts): worker_map[next(iter(workers))].append(key_to_partition[key]) self.worker_map = worker_map + self.meta_names = meta_names return self @@ -308,8 +309,7 @@ def create_fn_args(self): ''' return {'feature_names': self.feature_names, 'feature_types': self.feature_types, - 'has_label': self.has_label, - 'has_weights': self.has_weights, + 'meta_names': self.meta_names, 'missing': self.missing, 'worker_map': self.worker_map, 'is_quantile': self.is_quantile} @@ -326,7 +326,7 @@ def _get_worker_x_ordered(worker_map, partition_order, worker): return result -def _get_worker_parts(has_label, has_weights, worker_map, worker): +def _get_worker_parts(worker_map, meta_names, worker): '''Get mapped parts of data in each worker from DaskDMatrix.''' list_of_parts = worker_map[worker.address] assert list_of_parts, 'data in ' + worker.address + ' was moved.' @@ -336,17 +336,19 @@ def _get_worker_parts(has_label, has_weights, worker_map, worker): # this should be equal to `worker._get_client`. client = get_client() list_of_parts = client.gather(list_of_parts) + data = None + labels = None + weights = None + + local_data = list(zip(*list_of_parts)) + data = local_data[0] + + for i, part in enumerate(local_data[1:]): + if meta_names[i] == 'labels': + labels = part + if meta_names[i] == 'weights': + weights = part - if has_label: - if has_weights: - data, labels, weights = zip(*list_of_parts) - else: - data, labels = zip(*list_of_parts) - weights = None - else: - data = [d[0] for d in list_of_parts] - labels = None - weights = None return data, labels, weights @@ -473,8 +475,7 @@ def create_fn_args(self): def _create_device_quantile_dmatrix(feature_names, feature_types, - has_label, - has_weights, missing, worker_map, + meta_names, missing, worker_map, max_bin): worker = distributed_get_worker() if worker.address not in set(worker_map.keys()): @@ -490,8 +491,7 @@ def _create_device_quantile_dmatrix(feature_names, feature_types, max_bin=max_bin) return d - data, labels, weights = _get_worker_parts(has_label, has_weights, - worker_map, worker) + data, labels, weights = _get_worker_parts(worker_map, meta_names, worker) it = DaskPartitionIter(data=data, label=labels, weight=weights) dmatrix = DeviceQuantileDMatrix(it, @@ -503,8 +503,8 @@ def _create_device_quantile_dmatrix(feature_names, feature_types, return dmatrix -def _create_dmatrix(feature_names, feature_types, has_label, - has_weights, missing, worker_map): +def _create_dmatrix(feature_names, feature_types, meta_names, missing, + worker_map): '''Get data that local to worker from DaskDMatrix. Returns @@ -524,18 +524,13 @@ def _create_dmatrix(feature_names, feature_types, has_label, feature_types=feature_types) return d - data, labels, weights = _get_worker_parts(has_label, has_weights, - worker_map, worker) + data, labels, weights = _get_worker_parts(worker_map, meta_names, worker) data = concat(data) - if has_label: + if labels: labels = concat(labels) - else: - labels = None - if has_weights: + if weights: weights = concat(weights) - else: - weights = None dmatrix = DMatrix(data, labels, weight=weights,