[dask] Use nthread in DMatrix construction. (#7337)
This is consistent with the thread overriding behavior.
trivialfis committed Oct 20, 2021
1 parent b8e8f0f commit f999897
Showing 3 changed files with 44 additions and 33 deletions.
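
The "thread overriding behavior" referenced in the commit message is the per-worker thread resolution in `dispatched_train`: an explicit `nthread` or `n_jobs` in the training parameters takes precedence over the dask worker's configured thread count, and with this commit the resolved value is also forwarded to DMatrix construction instead of always using `worker.nthreads`. A minimal standalone sketch of that resolution logic (not the library code itself; `worker_nthreads` stands in for `distributed.get_worker().nthreads`):

from typing import Any, Dict

def resolve_nthread(params: Dict[str, Any], worker_nthreads: int) -> int:
    # An explicit "nthread" or "n_jobs" in the training parameters takes
    # precedence over the thread count configured on the dask worker.
    for key in ("nthread", "n_jobs"):
        if params.get(key, worker_nthreads) != worker_nthreads:
            return params[key]
    # Otherwise fall back to the worker's own thread count.
    return worker_nthreads
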
71 changes: 40 additions & 31 deletions python-package/xgboost/dask.py
@@ -674,6 +674,7 @@ def _create_device_quantile_dmatrix(
     feature_weights: Optional[Any],
     meta_names: List[str],
     missing: float,
+    nthread: int,
     parts: Optional[_DataParts],
     max_bin: int,
     enable_categorical: bool,
@@ -717,7 +718,7 @@ def _create_device_quantile_dmatrix(
         missing=missing,
         feature_names=feature_names,
         feature_types=feature_types,
-        nthread=worker.nthreads,
+        nthread=nthread,
         max_bin=max_bin,
         enable_categorical=enable_categorical,
     )
@@ -731,6 +732,7 @@ def _create_dmatrix(
     feature_weights: Optional[Any],
     meta_names: List[str],
     missing: float,
+    nthread: int,
     enable_categorical: bool,
     parts: Optional[_DataParts]
 ) -> DMatrix:
@@ -778,7 +780,7 @@ def concat_or_none(data: Tuple[Optional[T], ...]) -> Optional[T]:
         missing=missing,
         feature_names=feature_names,
         feature_types=feature_types,
-        nthread=worker.nthreads,
+        nthread=nthread,
         enable_categorical=enable_categorical,
     )
     dmatrix.set_info(
@@ -856,46 +858,53 @@ def dispatched_train(
         rabit_args: List[bytes],
         dtrain_ref: Dict,
         dtrain_idt: int,
-        evals_ref: Dict
+        evals_ref: List[Tuple[Dict, str, int]],
     ) -> Optional[Dict[str, Union[Booster, Dict]]]:
-        '''Perform training on a single worker. A local function prevents pickling.
-        '''
-        LOGGER.debug('Training on %s', str(worker_addr))
+        """Perform training on a single worker. A local function prevents pickling."""
+        LOGGER.debug("Training on %s", str(worker_addr))
         worker = distributed.get_worker()
 
+        n_threads: int = 0
+        local_param = params.copy()
+        for p in ["nthread", "n_jobs"]:
+            if local_param.get(p, worker.nthreads) != worker.nthreads:
+                LOGGER.info("Overriding `nthreads` defined in dask worker.")
+                n_threads = local_param[p]
+                break
+        if n_threads == 0:
+            n_threads = worker.nthreads
+        local_param.update({"nthread": n_threads, "n_jobs": n_threads})
+
         with RabitContext(rabit_args), config.config_context(**global_config):
-            local_dtrain = _dmatrix_from_list_of_parts(**dtrain_ref)
+            local_dtrain = _dmatrix_from_list_of_parts(**dtrain_ref, nthread=n_threads)
             local_evals = []
             if evals_ref:
                 for ref, name, idt in evals_ref:
                     if idt == dtrain_idt:
                         local_evals.append((local_dtrain, name))
                         continue
-                    local_evals.append((_dmatrix_from_list_of_parts(**ref), name))
+                    local_evals.append(
+                        (_dmatrix_from_list_of_parts(**ref, nthread=n_threads), name)
+                    )
 
             local_history: Dict = {}
-            local_param = params.copy()  # just to be consistent
-            msg = 'Overriding `nthreads` defined in dask worker.'
-            override = ['nthread', 'n_jobs']
-            for p in override:
-                val = local_param.get(p, None)
-                if val is not None and val != worker.nthreads:
-                    LOGGER.info(msg)
-                else:
-                    local_param[p] = worker.nthreads
-            bst = worker_train(params=local_param,
-                               dtrain=local_dtrain,
-                               num_boost_round=num_boost_round,
-                               evals_result=local_history,
-                               evals=local_evals,
-                               obj=obj,
-                               feval=feval,
-                               early_stopping_rounds=early_stopping_rounds,
-                               verbose_eval=verbose_eval,
-                               xgb_model=xgb_model,
-                               callbacks=callbacks)
+            bst = worker_train(
+                params=local_param,
+                dtrain=local_dtrain,
+                num_boost_round=num_boost_round,
+                evals_result=local_history,
+                evals=local_evals,
+                obj=obj,
+                feval=feval,
+                early_stopping_rounds=early_stopping_rounds,
+                verbose_eval=verbose_eval,
+                xgb_model=xgb_model,
+                callbacks=callbacks,
+            )
             ret: Optional[Dict[str, Union[Booster, Dict]]] = {
-                'booster': bst, 'history': local_history}
+                "booster": bst,
+                "history": local_history,
+            }
             if local_dtrain.num_row() == 0:
                 ret = None
             return ret
@@ -924,7 +933,7 @@ def dispatched_train(
                 evals_per_worker,
                 pure=False,
                 workers=[worker_addr],
-                allow_other_workers=False
+                allow_other_workers=False,
             )
             futures.append(f)
 
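
For context, a hedged usage sketch of how the change surfaces to callers (hypothetical local cluster and random placeholder data, not taken from this commit): an `n_jobs` value passed in the training parameters is used both for training threads and, after the diff above, as `nthread` when each worker builds its local DMatrix, rather than DMatrix construction always using the worker's configured thread count.

from dask.distributed import Client, LocalCluster
import dask.array as da
import xgboost as xgb

# Hypothetical cluster and data, for illustration only.
cluster = LocalCluster(n_workers=2, threads_per_worker=4)
client = Client(cluster)
X = da.random.random((1000, 10), chunks=(250, 10))
y = da.random.random(1000, chunks=(250,))

dtrain = xgb.dask.DaskDMatrix(client, X, y)
output = xgb.dask.train(
    client,
    # n_jobs=2 overrides the worker's 4 threads; the same resolved value is
    # now also forwarded to the local DMatrix construction on each worker.
    {"tree_method": "hist", "n_jobs": 2},
    dtrain,
    num_boost_round=10,
)
booster = output["booster"]
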
2 changes: 1 addition & 1 deletion tests/python-gpu/test_gpu_with_dask.py
@@ -434,7 +434,7 @@ def test_data_initialization(self, local_cuda_cluster: LocalCUDACluster) -> None
 
         def worker_fn(worker_addr: str, data_ref: Dict) -> None:
             with dxgb.RabitContext(rabit_args):
-                local_dtrain = dxgb._dmatrix_from_list_of_parts(**data_ref)
+                local_dtrain = dxgb._dmatrix_from_list_of_parts(**data_ref, nthread=7)
                 fw_rows = local_dtrain.get_float_info("feature_weights").shape[0]
                 assert fw_rows == local_dtrain.num_col()
 
4 changes: 3 additions & 1 deletion tests/python/test_with_dask.py
@@ -1275,7 +1275,9 @@ def test_no_duplicated_partition(self) -> None:
 
         def worker_fn(worker_addr: str, data_ref: Dict) -> None:
             with xgb.dask.RabitContext(rabit_args):
-                local_dtrain = xgb.dask._dmatrix_from_list_of_parts(**data_ref)
+                local_dtrain = xgb.dask._dmatrix_from_list_of_parts(
+                    **data_ref, nthread=7
+                )
                 total = np.array([local_dtrain.num_row()])
                 total = xgb.rabit.allreduce(total, xgb.rabit.Op.SUM)
                 assert total[0] == kRows
