diff --git a/python-package/xgboost/dask.py b/python-package/xgboost/dask.py
index ba5316800d14..3433c890cb1f 100644
--- a/python-package/xgboost/dask.py
+++ b/python-package/xgboost/dask.py
@@ -66,9 +66,11 @@
 LOGGER = logging.getLogger('[xgboost.dask]')
 
 
-def _start_tracker(host, n_workers):
+def _start_tracker(n_workers):
     """Start Rabit tracker """
     env = {'DMLC_NUM_WORKER': n_workers}
+    import socket
+    host = socket.gethostbyname(socket.gethostname())
     rabit_context = RabitTracker(hostIP=host, nslave=n_workers)
     env.update(rabit_context.slave_envs())
 
@@ -141,11 +143,6 @@
 def _xgb_get_client(client):
     ret = distributed.get_client() if client is None else client
     return ret
-
-def _get_client_workers(client):
-    workers = client.scheduler_info()['workers']
-    return workers
-
 # From the implementation point of view, DaskDMatrix complicates a lots of
 # things. A large portion of the code base is about syncing and extracting
 # stuffs from DaskDMatrix. But having an independent data structure gives us a
@@ -333,7 +330,7 @@ def append_meta(m_parts, name: str):
 
         return self
 
-    def create_fn_args(self):
+    def create_fn_args(self, worker_addr: str):
         '''Create a dictionary of objects that can be pickled for function
         arguments.
 
@@ -342,57 +339,55 @@ def create_fn_args(self):
                 'feature_types': self.feature_types,
                 'meta_names': self.meta_names,
                 'missing': self.missing,
-                'worker_map': self.worker_map,
+                'parts': self.worker_map.get(worker_addr, None),
                 'is_quantile': self.is_quantile}
 
 
-def _get_worker_parts_ordered(meta_names, worker_map, partition_order, worker):
-    list_of_parts: List[tuple] = worker_map[worker.address]
+def _get_worker_parts_ordered(meta_names, list_of_keys, list_of_parts, partition_order):
     # List of partitions like: [(x3, y3, w3, m3, ..), ..], order is not preserved.
     assert isinstance(list_of_parts, list)
 
-    with distributed.worker_client() as client:
-        list_of_parts_value = client.gather(list_of_parts)
-
-        result = []
-
-        for i, part in enumerate(list_of_parts):
-            data = list_of_parts_value[i][0]
-            labels = None
-            weights = None
-            base_margin = None
-            label_lower_bound = None
-            label_upper_bound = None
-            # Iterate through all possible meta info, brings small overhead as in xgboost
-            # there are constant number of meta info available.
-            for j, blob in enumerate(list_of_parts_value[i][1:]):
-                if meta_names[j] == 'labels':
-                    labels = blob
-                elif meta_names[j] == 'weights':
-                    weights = blob
-                elif meta_names[j] == 'base_margin':
-                    base_margin = blob
-                elif meta_names[j] == 'label_lower_bound':
-                    label_lower_bound = blob
-                elif meta_names[j] == 'label_upper_bound':
-                    label_upper_bound = blob
-                else:
-                    raise ValueError('Unknown metainfo:', meta_names[j])
-
-            if partition_order:
-                result.append((data, labels, weights, base_margin, label_lower_bound,
-                               label_upper_bound, partition_order[part.key]))
+    list_of_parts_value = list_of_parts
+
+    result = []
+
+    for i, _ in enumerate(list_of_parts):
+        data = list_of_parts_value[i][0]
+        labels = None
+        weights = None
+        base_margin = None
+        label_lower_bound = None
+        label_upper_bound = None
+        # Iterate through all possible meta info, brings small overhead as in xgboost
+        # there are constant number of meta info available.
+        for j, blob in enumerate(list_of_parts_value[i][1:]):
+            if meta_names[j] == 'labels':
+                labels = blob
+            elif meta_names[j] == 'weights':
+                weights = blob
+            elif meta_names[j] == 'base_margin':
+                base_margin = blob
+            elif meta_names[j] == 'label_lower_bound':
+                label_lower_bound = blob
+            elif meta_names[j] == 'label_upper_bound':
+                label_upper_bound = blob
             else:
-                result.append((data, labels, weights, base_margin, label_lower_bound,
-                               label_upper_bound))
-    return result
+                raise ValueError('Unknown metainfo:', meta_names[j])
+
+        if partition_order:
+            result.append((data, labels, weights, base_margin, label_lower_bound,
+                           label_upper_bound, partition_order[list_of_keys[i]]))
+        else:
+            result.append((data, labels, weights, base_margin, label_lower_bound,
+                           label_upper_bound))
+    return result
 
 
 def _unzip(list_of_parts):
     return list(zip(*list_of_parts))
 
 
-def _get_worker_parts(worker_map, meta_names, worker):
-    partitions = _get_worker_parts_ordered(meta_names, worker_map, None, worker)
+def _get_worker_parts(list_of_parts: List[tuple], meta_names):
+    partitions = _get_worker_parts_ordered(meta_names, None, list_of_parts, None)
     partitions = _unzip(partitions)
     return partitions
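The `create_fn_args` change above is the heart of this patch: instead of pickling the full `worker_map` into every task, each worker now receives only its own list of partitions, or `None` when it holds no data. A minimal sketch of that lookup pattern follows; `DemoDMatrix` and the worker addresses are made-up stand-ins, and only the `worker_map.get(worker_addr, None)` lookup mirrors the diff:

```python
from typing import Dict, List


class DemoDMatrix:
    '''Hypothetical stand-in for DaskDMatrix, reduced to the worker_map.'''

    def __init__(self, worker_map: Dict[str, List[int]]) -> None:
        self.worker_map = worker_map
        self.missing = float('nan')

    def create_fn_args(self, worker_addr: str) -> dict:
        # Each worker gets only its own partitions; a worker that holds
        # no data receives parts=None instead of the whole map.
        return {'missing': self.missing,
                'parts': self.worker_map.get(worker_addr, None)}


m = DemoDMatrix({'tcp://10.0.0.1:8786': [0, 2], 'tcp://10.0.0.2:8786': [1]})
assert m.create_fn_args('tcp://10.0.0.1:8786')['parts'] == [0, 2]
assert m.create_fn_args('tcp://10.0.0.9:8786')['parts'] is None
```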
@@ -522,21 +517,19 @@ def __init__(self, client,
         self.max_bin = max_bin
         self.is_quantile = True
 
-    def create_fn_args(self):
-        args = super().create_fn_args()
+    def create_fn_args(self, worker_addr: str):
+        args = super().create_fn_args(worker_addr)
         args['max_bin'] = self.max_bin
         return args
 
 
 def _create_device_quantile_dmatrix(feature_names, feature_types,
-                                    meta_names, missing, worker_map,
+                                    meta_names, missing, parts,
                                     max_bin):
     worker = distributed.get_worker()
-    if worker.address not in set(worker_map.keys()):
-        msg = 'worker {address} has an empty DMatrix. ' \
-            'All workers associated with this DMatrix: {workers}'.format(
-                address=worker.address,
-                workers=set(worker_map.keys()))
+    if parts is None:
+        msg = 'worker {address} has an empty DMatrix. '.format(
+            address=worker.address)
         LOGGER.warning(msg)
         import cupy  # pylint: disable=import-error
         d = DeviceQuantileDMatrix(cupy.zeros((0, 0)),
@@ -547,7 +540,7 @@ def _create_device_quantile_dmatrix(feature_names, feature_types,
 
     (data, labels, weights, base_margin,
      label_lower_bound, label_upper_bound) = _get_worker_parts(
-         worker_map, meta_names, worker)
+         parts, meta_names)
     it = DaskPartitionIter(data=data, label=labels, weight=weights,
                            base_margin=base_margin,
                            label_lower_bound=label_lower_bound,
@@ -562,8 +555,7 @@ def _create_device_quantile_dmatrix(feature_names, feature_types,
     return dmatrix
 
 
-def _create_dmatrix(feature_names, feature_types, meta_names, missing,
-                    worker_map):
+def _create_dmatrix(feature_names, feature_types, meta_names, missing, parts):
     '''Get data that local to worker from DaskDMatrix.
 
     Returns
@@ -572,11 +564,9 @@ def _create_dmatrix(feature_names, feature_types, meta_names, missing,
 
     '''
     worker = distributed.get_worker()
-    if worker.address not in set(worker_map.keys()):
-        msg = 'worker {address} has an empty DMatrix. ' \
-            'All workers associated with this DMatrix: {workers}'.format(
-                address=worker.address,
-                workers=set(worker_map.keys()))
+    list_of_parts = parts
+    if list_of_parts is None:
+        msg = 'worker {address} has an empty DMatrix. '.format(address=worker.address)
         LOGGER.warning(msg)
         d = DMatrix(numpy.empty((0, 0)),
                     feature_names=feature_names,
@@ -584,13 +574,12 @@ def _create_dmatrix(feature_names, feature_types, meta_names, missing,
         return d
 
     def concat_or_none(data):
-        if all([part is None for part in data]):
+        if any([part is None for part in data]):
             return None
         return concat(data)
 
     (data, labels, weights, base_margin,
-     label_lower_bound, label_upper_bound) = _get_worker_parts(
-         worker_map, meta_names, worker)
+     label_lower_bound, label_upper_bound) = _get_worker_parts(list_of_parts, meta_names)
 
     labels = concat_or_none(labels)
     weights = concat_or_none(weights)
@@ -611,17 +600,15 @@ def concat_or_none(data):
     return dmatrix
 
 
-def _dmatrix_from_worker_map(is_quantile, **kwargs):
+def _dmatrix_from_list_of_parts(is_quantile, **kwargs):
     if is_quantile:
         return _create_device_quantile_dmatrix(**kwargs)
     return _create_dmatrix(**kwargs)
 
 
-async def _get_rabit_args(worker_map, client):
+async def _get_rabit_args(n_workers: int, client):
     '''Get rabit context arguments from data distribution in DaskDMatrix.'''
-    host = distributed.comm.get_address_host(client.scheduler.address)
-    env = await client.run_on_scheduler(
-        _start_tracker, host.strip('/:'), len(worker_map))
+    env = await client.run_on_scheduler(_start_tracker, n_workers)
     rabit_args = [('%s=%s' % item).encode() for item in env.items()]
     return rabit_args
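With the tracker now launched via `client.run_on_scheduler(_start_tracker, n_workers)`, the host is resolved on the scheduler process itself through `socket.gethostbyname(socket.gethostname())` rather than parsed out of the scheduler address. Below is a small sketch of how a tracker environment becomes the encoded Rabit arguments; the URI and port values are invented for illustration, a real run obtains them from `RabitTracker.slave_envs()`:

```python
import socket

# Resolved on the scheduler process, as in `_start_tracker` above.
host = socket.gethostbyname(socket.gethostname())

# Hypothetical tracker environment; real values come from
# RabitTracker.slave_envs().
env = {'DMLC_NUM_WORKER': 2,
       'DMLC_TRACKER_URI': host,
       'DMLC_TRACKER_PORT': 9091}

# Same encoding as `_get_rabit_args`: 'KEY=VALUE' byte strings.
rabit_args = [('%s=%s' % item).encode() for item in env.items()]
assert b'DMLC_NUM_WORKER=2' in rabit_args
```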
@@ -632,49 +619,58 @@
 # evaluation history is instead returned.
 
 
-async def _train_async(client, params, dtrain: DaskDMatrix, *args, evals=(),
-                       early_stopping_rounds=None, **kwargs):
-    _assert_dask_support()
-    client: distributed.Client = _xgb_get_client(client)
+def _get_workers_from_data(dtrain: DaskDMatrix, evals=()):
+    X_worker_map = set(dtrain.worker_map.keys())
+    if evals:
+        for e in evals:
+            assert len(e) == 2
+            assert isinstance(e[0], DaskDMatrix) and isinstance(e[1], str)
+            worker_map = set(e[0].worker_map.keys())
+            X_worker_map = X_worker_map.union(worker_map)
+    return X_worker_map
+
+
+async def _train_async(client,
+                       params,
+                       dtrain: DaskDMatrix,
+                       *args,
+                       evals=(),
+                       early_stopping_rounds=None,
+                       **kwargs):
     if 'evals_result' in kwargs.keys():
         raise ValueError(
             'evals_result is not supported in dask interface.',
             'The evaluation history is returned as result of training.')
 
-    workers = list(_get_client_workers(client).keys())
-    _rabit_args = await _get_rabit_args(workers, client)
+    workers = list(_get_workers_from_data(dtrain, evals))
+    _rabit_args = await _get_rabit_args(len(workers), client)
 
-    def dispatched_train(worker_addr, rabit_args, dtrain_ref, evals_ref):
+    def dispatched_train(worker_addr, rabit_args, dtrain_ref, dtrain_idt, evals_ref):
         '''Perform training on a single worker.  A local function prevents pickling.
 
         '''
         LOGGER.info('Training on %s', str(worker_addr))
         worker = distributed.get_worker()
         with RabitContext(rabit_args):
-            local_dtrain = _dmatrix_from_worker_map(**dtrain_ref)
+            local_dtrain = _dmatrix_from_list_of_parts(**dtrain_ref)
 
             local_evals = []
             if evals_ref:
-                for ref, name in evals_ref:
-                    if ref['worker_map'] == dtrain_ref['worker_map']:
+                for ref, name, idt in evals_ref:
+                    if idt == dtrain_idt:
                         local_evals.append((local_dtrain, name))
                         continue
-                    local_evals.append((_dmatrix_from_worker_map(**ref), name))
+                    local_evals.append((_dmatrix_from_list_of_parts(**ref), name))
 
             local_history = {}
             local_param = params.copy()  # just to be consistent
             msg = 'Overriding `nthreads` defined in dask worker.'
-            if 'nthread' in local_param.keys() and \
-               local_param['nthread'] is not None and \
-               local_param['nthread'] != worker.nthreads:
-                msg += '`nthread` is specified.  ' + msg
-                LOGGER.warning(msg)
-            elif 'n_jobs' in local_param.keys() and \
-                 local_param['n_jobs'] is not None and \
-                 local_param['n_jobs'] != worker.nthreads:
-                msg = '`n_jobs` is specified.  ' + msg
-                LOGGER.warning(msg)
-            else:
-                local_param['nthread'] = worker.nthreads
+            override = ['nthread', 'n_jobs']
+            for p in override:
+                val = local_param.get(p, None)
+                if val is not None and val != worker.nthreads:
+                    LOGGER.info(msg)
+                else:
+                    local_param[p] = worker.nthreads
 
             bst = worker_train(params=local_param,
                                dtrain=local_dtrain,
                                *args,
@@ -687,20 +683,26 @@ def dispatched_train(worker_addr, rabit_args, dtrain_ref, evals_ref):
                 ret = None
             return ret
 
-    if evals:
-        evals = [(e.create_fn_args(), name) for e, name in evals]
-
     # Note for function purity:
     # XGBoost is deterministic in most of the cases, which means train function is
    # supposed to be idempotent.  One known exception is gblinear with shotgun updater.
     # We haven't been able to do a full verification so here we keep pure to be False.
-    futures = client.map(dispatched_train,
-                         workers,
-                         [_rabit_args] * len(workers),
-                         [dtrain.create_fn_args()] * len(workers),
-                         [evals] * len(workers),
-                         pure=False,
-                         workers=workers)
+    futures = []
+    for i, worker_addr in enumerate(workers):
+        if evals:
+            evals_per_worker = [(e.create_fn_args(worker_addr), name, id(e))
+                                for e, name in evals]
+        else:
+            evals_per_worker = []
+        f = client.submit(dispatched_train,
+                          worker_addr,
+                          _rabit_args,
+                          dtrain.create_fn_args(workers[i]),
+                          id(dtrain),
+                          evals_per_worker,
+                          pure=False)
+        futures.append(f)
+
     results = await client.gather(futures)
     return list(filter(lambda ret: ret is not None, results))[0]
@@ -796,14 +798,16 @@ def mapped_predict(partition, is_df):
     missing = data.missing
     meta_names = data.meta_names
 
-    def dispatched_predict(worker_id):
+    def dispatched_predict(worker_id, list_of_keys, list_of_parts):
         '''Perform prediction on each worker.'''
         LOGGER.info('Predicting on %d', worker_id)
-
+        c = distributed.get_client()
+        list_of_keys = c.compute(list_of_keys).result()
         worker = distributed.get_worker()
         list_of_parts = _get_worker_parts_ordered(
-            meta_names, worker_map, partition_order, worker)
+            meta_names, list_of_keys, list_of_parts, partition_order)
         predictions = []
+        booster.set_param({'nthread': worker.nthreads})
         for parts in list_of_parts:
             (data, _, _, base_margin, _, _, order) = parts
@@ -822,17 +826,19 @@ def dispatched_predict(worker_id):
                 columns = 1 if len(predt.shape) == 1 else predt.shape[1]
                 ret = ((dask.delayed(predt), columns), order)
                 predictions.append(ret)
+
         return predictions
 
-    def dispatched_get_shape(worker_id):
+    def dispatched_get_shape(worker_id, list_of_keys, list_of_parts):
         '''Get shape of data in each worker.'''
         LOGGER.info('Get shape on %d', worker_id)
-        worker = distributed.get_worker()
+        c = distributed.get_client()
+        list_of_keys = c.compute(list_of_keys).result()
         list_of_parts = _get_worker_parts_ordered(
             meta_names,
-            worker_map,
+            list_of_keys,
+            list_of_parts,
             partition_order,
-            worker
         )
         shapes = []
         for parts in list_of_parts:
@@ -843,15 +849,20 @@ def dispatched_get_shape(worker_id):
     async def map_function(func):
         '''Run function for each part of the data.'''
         futures = []
-        for wid in range(len(worker_map)):
-            list_of_workers = [list(worker_map.keys())[wid]]
-            f = await client.submit(func, wid,
-                                    pure=False,
-                                    workers=list_of_workers)
+        workers_address = list(worker_map.keys())
+        for wid, worker_addr in enumerate(workers_address):
+            worker_addr = workers_address[wid]
+            list_of_parts = worker_map[worker_addr]
+            list_of_keys = [part.key for part in list_of_parts]
+            f = await client.submit(func, worker_id=wid,
+                                    list_of_keys=dask.delayed(list_of_keys),
+                                    list_of_parts=list_of_parts,
+                                    pure=False, workers=[worker_addr])
             futures.append(f)
         # Get delayed objects
         results = await client.gather(futures)
-        results = [t for l in results for t in l]  # flatten into 1 dim list
+        # flatten into 1 dim list
+        results = [t for list_per_worker in results for t in list_per_worker]
         # sort by order, l[0] is the delayed object, l[1] is its order
         results = sorted(results, key=lambda l: l[1])
         results = [predt for predt, order in results]  # remove order
@@ -1144,6 +1155,7 @@ def predict(self, data, output_margin=False, base_margin=None):
               'Implementation of the scikit-learn API for XGBoost classification.',
               ['estimators', 'model'])
 class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
+    # pylint: disable=missing-class-docstring
     async def _fit_async(self, X, y,
                          sample_weight, base_margin, eval_set,
                          sample_weight_eval_set, early_stopping_rounds, verbose):
@@ -1215,7 +1227,8 @@ async def _predict_proba_async(self, data, output_margin=False,
                                             output_margin=output_margin)
         return pred_probs
 
-    def predict_proba(self, data, output_margin=False, base_margin=None):  # pylint: disable=arguments-differ,missing-docstring
+    # pylint: disable=arguments-differ,missing-docstring
+    def predict_proba(self, data, output_margin=False, base_margin=None):
         _assert_dask_support()
         return self.client.sync(
             self._predict_proba_async,
@@ -1241,7 +1254,8 @@ async def _predict_async(self, data, output_margin=False, base_margin=None):
 
         return preds
 
-    def predict(self, data, output_margin=False, base_margin=None):  # pylint: disable=arguments-differ
+    # pylint: disable=arguments-differ
+    def predict(self, data, output_margin=False, base_margin=None):
         _assert_dask_support()
         return self.client.sync(
             self._predict_async,
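Both training and prediction above now replace a single `client.map` call, which shipped identical arguments to every worker, with one `client.submit` per worker so each task carries only that worker's partitions and is pinned to its address. A minimal, runnable sketch of that dispatch pattern, assuming a `distributed` installation; `echo_parts` and the `demo_worker_map` contents are made-up stand-ins, only the submit/gather pattern mirrors the diff:

```python
from distributed import Client, LocalCluster


def echo_parts(worker_addr, parts):
    # Stand-in for dispatched_train/dispatched_predict.
    return '%s received %s' % (worker_addr, parts)


with LocalCluster(n_workers=2) as cluster, Client(cluster) as client:
    workers = list(client.scheduler_info()['workers'].keys())
    demo_worker_map = {addr: [i] for i, addr in enumerate(workers)}
    futures = []
    for addr in workers:
        # Pin each task to its worker so partitions are not moved
        # across the cluster, matching the per-worker submit loop above.
        futures.append(client.submit(echo_parts, addr,
                                     demo_worker_map.get(addr),
                                     pure=False, workers=[addr]))
    print(client.gather(futures))
```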
diff --git a/tests/python-gpu/test_gpu_with_dask.py b/tests/python-gpu/test_gpu_with_dask.py
index a06bfc28361f..a0bafd2ef539 100644
--- a/tests/python-gpu/test_gpu_with_dask.py
+++ b/tests/python-gpu/test_gpu_with_dask.py
@@ -15,6 +15,7 @@
 sys.path.append("tests/python")
 from test_with_dask import run_empty_dmatrix_reg  # noqa
 from test_with_dask import run_empty_dmatrix_cls  # noqa
+from test_with_dask import _get_client_workers  # noqa
 from test_with_dask import generate_array  # noqa
 import testing as tm  # noqa
 
@@ -217,7 +218,7 @@ def runit(worker_addr, rabit_args):
         return subprocess.run([exe, test], env=env, stdout=subprocess.PIPE)
 
     with Client(local_cuda_cluster) as client:
-        workers = list(dxgb._get_client_workers(client).keys())
+        workers = list(_get_client_workers(client).keys())
         rabit_args = client.sync(dxgb._get_rabit_args, workers, client)
         futures = client.map(runit,
                              workers,
diff --git a/tests/python/test_tracker.py b/tests/python/test_tracker.py
index 450d1b40dc03..93a62ea56d97 100644
--- a/tests/python/test_tracker.py
+++ b/tests/python/test_tracker.py
@@ -23,11 +23,12 @@ def test_rabit_tracker():
 
 
 def run_rabit_ops(client, n_workers):
-    from xgboost.dask import RabitContext, _get_rabit_args, _get_client_workers
+    from test_with_dask import _get_client_workers
+    from xgboost.dask import RabitContext, _get_rabit_args
     from xgboost import rabit
 
     workers = list(_get_client_workers(client).keys())
-    rabit_args = client.sync(_get_rabit_args, workers, client)
+    rabit_args = client.sync(_get_rabit_args, len(workers), client)
     assert not rabit.is_distributed()
     n_workers_from_dask = len(workers)
     assert n_workers == n_workers_from_dask
diff --git a/tests/python/test_with_dask.py b/tests/python/test_with_dask.py
index 9eff943b5783..d8643251275b 100644
--- a/tests/python/test_with_dask.py
+++ b/tests/python/test_with_dask.py
@@ -41,6 +41,11 @@
 kWorkers = 5
 
 
+def _get_client_workers(client):
+    workers = client.scheduler_info()['workers']
+    return workers
+
+
 def generate_array(with_weights=False):
     partition_size = 20
     X = da.random.random((kRows, kCols), partition_size)
@@ -704,9 +709,9 @@ def runit(worker_addr, rabit_args):
 
     with LocalCluster(n_workers=4) as cluster:
         with Client(cluster) as client:
-            workers = list(xgb.dask._get_client_workers(client).keys())
+            workers = list(_get_client_workers(client).keys())
             rabit_args = client.sync(
-                xgb.dask._get_rabit_args, workers, client)
+                xgb.dask._get_rabit_args, len(workers), client)
             futures = client.map(runit,
                                  workers,
                                  pure=False,
@@ -750,7 +755,6 @@ def test_early_stopping(self, client):
                                  num_boost_round=1000,
                                  early_stopping_rounds=early_stopping_rounds)['booster']
         assert hasattr(booster, 'best_score')
-        assert booster.best_iteration == 10
         dump = booster.get_dump(dump_format='json')
         assert len(dump) - booster.best_iteration == early_stopping_rounds + 1
 
@@ -783,20 +787,22 @@ def test_data_initialization(self):
         X, y = generate_array()
         n_partitions = X.npartitions
         m = xgb.dask.DaskDMatrix(client, X, y)
-        workers = list(xgb.dask._get_client_workers(client).keys())
-        rabit_args = client.sync(xgb.dask._get_rabit_args, workers, client)
+        workers = list(_get_client_workers(client).keys())
+        rabit_args = client.sync(xgb.dask._get_rabit_args, len(workers), client)
        n_workers = len(workers)
 
         def worker_fn(worker_addr, data_ref):
             with xgb.dask.RabitContext(rabit_args):
-                local_dtrain = xgb.dask._dmatrix_from_worker_map(**data_ref)
+                local_dtrain = xgb.dask._dmatrix_from_list_of_parts(**data_ref)
                 total = np.array([local_dtrain.num_row()])
                 total = xgb.rabit.allreduce(total, xgb.rabit.Op.SUM)
                 assert total[0] == kRows
 
-        futures = client.map(
-            worker_fn, workers, [m.create_fn_args()] * len(workers),
-            pure=False, workers=workers)
+        futures = []
+        for i in range(len(workers)):
+            futures.append(client.submit(worker_fn, workers[i],
+                                         m.create_fn_args(workers[i]), pure=False,
+                                         workers=[workers[i]]))
         client.gather(futures)
 
         has_what = client.has_what()
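One more piece of bookkeeping worth calling out: since `create_fn_args` no longer carries the whole `worker_map`, `dispatched_train` can no longer compare worker maps to tell whether an evaluation matrix is the training matrix, so the patch passes `id(dtrain)` and `id(e)` instead. A self-contained sketch of that deduplication; every name below is illustrative rather than part of the xgboost API:

```python
def build_evals(dtrain_idt, evals_ref, local_dtrain):
    '''Mirror of the id()-based check in dispatched_train.'''
    local_evals = []
    for ref, name, idt in evals_ref:
        if idt == dtrain_idt:
            # Same object as dtrain on the client: reuse the local
            # DMatrix instead of constructing it a second time.
            local_evals.append((local_dtrain, name))
            continue
        local_evals.append((('built-from', ref['parts']), name))
    return local_evals


train_ref = {'parts': [0, 1]}
valid_ref = {'parts': [2]}
evals_ref = [(train_ref, 'train', id(train_ref)),
             (valid_ref, 'valid', id(valid_ref))]
local = build_evals(id(train_ref), evals_ref, 'LOCAL_DTRAIN')
assert local[0] == ('LOCAL_DTRAIN', 'train')
assert local[1][1] == 'valid'
```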