diff --git a/distributed/client.py b/distributed/client.py
index cf370ec4692..9616ae43401 100644
--- a/distributed/client.py
+++ b/distributed/client.py
@@ -2048,9 +2048,10 @@ async def _scatter(
                         await asyncio.sleep(0.1)
                     if time() > start + timeout:
                         raise TimeoutError("No valid workers found")
-                    nthreads = await self.scheduler.ncores(workers=workers)
+                    # Exclude paused and closing_gracefully workers
+                    nthreads = await self.scheduler.ncores_running(workers=workers)
                 if not nthreads:
-                    raise ValueError("No valid workers")
+                    raise ValueError("No valid workers found")
 
                 _, who_has, nbytes = await scatter_to_workers(
                     nthreads, data2, report=False, rpc=self.rpc
diff --git a/distributed/scheduler.py b/distributed/scheduler.py
index acc8a4ac1d9..70fd5852e98 100644
--- a/distributed/scheduler.py
+++ b/distributed/scheduler.py
@@ -3826,6 +3826,7 @@ def __init__(
             "broadcast": self.broadcast,
             "proxy": self.proxy,
             "ncores": self.get_ncores,
+            "ncores_running": self.get_ncores_running,
             "has_what": self.get_has_what,
             "who_has": self.get_who_has,
             "processing": self.get_processing,
@@ -5709,18 +5710,24 @@ async def scatter(
         Scheduler.broadcast:
         """
         parent: SchedulerState = cast(SchedulerState, self)
+        ws: WorkerState
+
         start = time()
-        while not parent._workers_dv:
-            await asyncio.sleep(0.2)
+        while True:
+            if workers is None:
+                wss = parent._running
+            else:
+                workers = [self.coerce_address(w) for w in workers]
+                wss = {parent._workers_dv[w] for w in workers}
+                wss = {ws for ws in wss if ws._status == Status.running}
+
+            if wss:
+                break
             if time() > start + timeout:
-                raise TimeoutError("No workers found")
+                raise TimeoutError("No valid workers found")
+            await asyncio.sleep(0.1)
 
-        if workers is None:
-            ws: WorkerState
-            nthreads = {w: ws._nthreads for w, ws in parent._workers_dv.items()}
-        else:
-            workers = [self.coerce_address(w) for w in workers]
-            nthreads = {w: parent._workers_dv[w].nthreads for w in workers}
+        nthreads = {ws._address: ws.nthreads for ws in wss}
 
         assert isinstance(data, dict)
 
@@ -5731,10 +5738,7 @@ async def scatter(
         self.update_data(who_has=who_has, nbytes=nbytes, client=client)
 
         if broadcast:
-            if broadcast == True:  # noqa: E712
-                n = len(nthreads)
-            else:
-                n = broadcast
+            n = len(nthreads) if broadcast is True else broadcast
             await self.replicate(keys=keys, workers=workers, n=n)
 
         self.log_event(
@@ -6450,7 +6454,12 @@ async def replicate(
 
         assert branching_factor > 0
        async with self._lock if lock else empty_context:
-            workers = {parent._workers_dv[w] for w in self.workers_list(workers)}
+            if workers is not None:
+                workers = {parent._workers_dv[w] for w in self.workers_list(workers)}
+                workers = {ws for ws in workers if ws._status == Status.running}
+            else:
+                workers = parent._running
+
             if n is None:
                 n = len(workers)
             else:
@@ -6988,6 +6997,15 @@ def get_ncores(self, comm=None, workers=None):
         else:
             return {w: ws._nthreads for w, ws in parent._workers_dv.items()}
 
+    def get_ncores_running(self, comm=None, workers=None):
+        parent: SchedulerState = cast(SchedulerState, self)
+        ncores = self.get_ncores(workers=workers)
+        return {
+            w: n
+            for w, n in ncores.items()
+            if parent._workers_dv[w].status == Status.running
+        }
+
     async def get_call_stack(self, comm=None, keys=None):
         parent: SchedulerState = cast(SchedulerState, self)
         ts: TaskState
diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py
index 2c4a3130f38..6d9f366232d 100644
--- a/distributed/tests/test_client.py
+++ b/distributed/tests/test_client.py
@@ -5734,6 +5734,33 @@ def bad_fn(x):
     assert y.status == "error"  # not cancelled
 
 
+@pytest.mark.parametrize("workers_arg", [False, True])
+@pytest.mark.parametrize("direct", [False, True])
+@pytest.mark.parametrize("broadcast", [False, True, 10])
+@gen_cluster(client=True, nthreads=[("", 1)] * 10)
+async def test_scatter_and_replicate_avoid_paused_workers(
+    c, s, *workers, workers_arg, direct, broadcast
+):
+    paused_workers = [w for i, w in enumerate(workers) if i not in (3, 7)]
+    for w in paused_workers:
+        w.memory_pause_fraction = 1e-15
+    while any(s.workers[w.address].status != Status.paused for w in paused_workers):
+        await asyncio.sleep(0.01)
+
+    f = await c.scatter(
+        {"x": 1},
+        workers=[w.address for w in workers[1:-1]] if workers_arg else None,
+        broadcast=broadcast,
+        direct=direct,
+    )
+    if not broadcast:
+        await c.replicate(f, n=10)
+
+    expect = [i in (3, 7) for i in range(10)]
+    actual = [("x" in w.data) for w in workers]
+    assert actual == expect
+
+
 @pytest.mark.xfail(reason="GH#5409 Dask-Default-Threads are frequently detected")
 def test_no_threads_lingering():
     if threading.active_count() < 40:
diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py
index 9f89b9ba2b5..dcd4233649c 100644
--- a/distributed/tests/test_scheduler.py
+++ b/distributed/tests/test_scheduler.py
@@ -784,20 +784,27 @@ async def test_story(c, s, a, b):
     assert s.story(x.key) == s.story(s.tasks[x.key])
 
 
-@gen_cluster(nthreads=[], client=True)
-async def test_scatter_no_workers(c, s):
+@pytest.mark.parametrize("direct", [False, True])
+@gen_cluster(client=True, nthreads=[])
+async def test_scatter_no_workers(c, s, direct):
     with pytest.raises(TimeoutError):
         await s.scatter(data={"x": 1}, client="alice", timeout=0.1)
 
     start = time()
     with pytest.raises(TimeoutError):
-        await c.scatter(123, timeout=0.1)
+        await c.scatter(123, timeout=0.1, direct=direct)
     assert time() < start + 1.5
 
-    w = Worker(s.address, nthreads=3)
-    await asyncio.gather(c.scatter(data={"y": 2}, timeout=5), w)
-
-    assert w.data["y"] == 2
+    fut = c.scatter({"y": 2}, timeout=5, direct=direct)
+    await asyncio.sleep(0.1)
+    async with Worker(s.address) as w:
+        await fut
+    assert w.data["y"] == 2
+
+    # Test race condition between worker init and scatter
+    w = Worker(s.address)
+    await asyncio.gather(c.scatter({"z": 3}, timeout=5, direct=direct), w)
+    assert w.data["z"] == 3
     await w.close()
 
 