From 05497a91419e35ee915f0811c475c853882c0a8f Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Mon, 13 Dec 2021 01:48:25 +0800 Subject: [PATCH] [dask] Fix asyncio. (#7508) --- python-package/xgboost/dask.py | 5 +++-- tests/python/test_with_dask.py | 3 +-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/python-package/xgboost/dask.py b/python-package/xgboost/dask.py index 6ae6d7bd959a..2f5b732bcf3b 100644 --- a/python-package/xgboost/dask.py +++ b/python-package/xgboost/dask.py @@ -1622,8 +1622,9 @@ def _client_sync(self, func: Callable, **kwargs: Any) -> Any: should use `worker_client' instead of default client. """ - asynchronous = getattr(self, "_asynchronous", False) + if self._client is None: + asynchronous = getattr(self, "_asynchronous", False) try: distributed.get_worker() in_worker = True @@ -1636,7 +1637,7 @@ def _client_sync(self, func: Callable, **kwargs: Any) -> Any: return ret return ret - return self.client.sync(func, **kwargs, asynchronous=asynchronous) + return self.client.sync(func, **kwargs, asynchronous=self.client.asynchronous) @xgboost_model_doc( diff --git a/tests/python/test_with_dask.py b/tests/python/test_with_dask.py index dfca5b206511..2602529cec85 100644 --- a/tests/python/test_with_dask.py +++ b/tests/python/test_with_dask.py @@ -751,8 +751,7 @@ async def run_from_dask_array_asyncio(scheduler_address: str) -> xgb.dask.TrainR async def run_dask_regressor_asyncio(scheduler_address: str) -> None: async with Client(scheduler_address, asynchronous=True) as client: X, y, _ = generate_array() - regressor = await xgb.dask.DaskXGBRegressor(verbosity=1, - n_estimators=2) + regressor = await xgb.dask.DaskXGBRegressor(verbosity=1, n_estimators=2) regressor.set_params(tree_method='hist') regressor.client = client await regressor.fit(X, y, eval_set=[(X, y)])