Skip to content

Commit

Permalink
[dask] Fix asyncio. (#7508) (#7561)
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Jan 13, 2022
1 parent afb9dfd commit 3e2d751
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
5 changes: 3 additions & 2 deletions python-package/xgboost/dask.py
Expand Up @@ -1606,8 +1606,9 @@ def _client_sync(self, func: Callable, **kwargs: Any) -> Any:
should use `worker_client' instead of default client.
"""
asynchronous = getattr(self, "_asynchronous", False)

if self._client is None:
asynchronous = getattr(self, "_asynchronous", False)
try:
distributed.get_worker()
in_worker = True
Expand All @@ -1620,7 +1621,7 @@ def _client_sync(self, func: Callable, **kwargs: Any) -> Any:
return ret
return ret

return self.client.sync(func, **kwargs, asynchronous=asynchronous)
return self.client.sync(func, **kwargs, asynchronous=self.client.asynchronous)


@xgboost_model_doc(
Expand Down
3 changes: 1 addition & 2 deletions tests/python/test_with_dask.py
Expand Up @@ -705,8 +705,7 @@ async def run_from_dask_array_asyncio(scheduler_address: str) -> xgb.dask.TrainR
async def run_dask_regressor_asyncio(scheduler_address: str) -> None:
async with Client(scheduler_address, asynchronous=True) as client:
X, y, _ = generate_array()
regressor = await xgb.dask.DaskXGBRegressor(verbosity=1,
n_estimators=2)
regressor = await xgb.dask.DaskXGBRegressor(verbosity=1, n_estimators=2)
regressor.set_params(tree_method='hist')
regressor.client = client
await regressor.fit(X, y, eval_set=[(X, y)])
Expand Down

0 comments on commit 3e2d751

Please sign in to comment.