Skip to content

Commit

Permalink
scatter and replicate to avoid paused workers
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky committed Oct 19, 2021
1 parent 8bf6b0e commit 253d397
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 23 deletions.
5 changes: 3 additions & 2 deletions distributed/client.py
Expand Up @@ -2048,9 +2048,10 @@ async def _scatter(
await asyncio.sleep(0.1)
if time() > start + timeout:
raise TimeoutError("No valid workers found")
nthreads = await self.scheduler.ncores(workers=workers)
# Exclude paused and closing_gracefully workers
nthreads = await self.scheduler.ncores_running(workers=workers)
if not nthreads:
raise ValueError("No valid workers")
raise ValueError("No valid workers found")

_, who_has, nbytes = await scatter_to_workers(
nthreads, data2, report=False, rpc=self.rpc
Expand Down
46 changes: 32 additions & 14 deletions distributed/scheduler.py
Expand Up @@ -3826,6 +3826,7 @@ def __init__(
"broadcast": self.broadcast,
"proxy": self.proxy,
"ncores": self.get_ncores,
"ncores_running": self.get_ncores_running,
"has_what": self.get_has_what,
"who_has": self.get_who_has,
"processing": self.get_processing,
Expand Down Expand Up @@ -5709,18 +5710,24 @@ async def scatter(
Scheduler.broadcast:
"""
parent: SchedulerState = cast(SchedulerState, self)
ws: WorkerState

start = time()
while not parent._workers_dv:
await asyncio.sleep(0.2)
while True:
if workers is None:
wss = parent._running
else:
workers = [self.coerce_address(w) for w in workers]
wss = {parent._workers_dv[w] for w in workers}
wss = {ws for ws in wss if ws._status == Status.running}

if wss:
break
if time() > start + timeout:
raise TimeoutError("No workers found")
raise TimeoutError("No valid workers found")
await asyncio.sleep(0.1)

if workers is None:
ws: WorkerState
nthreads = {w: ws._nthreads for w, ws in parent._workers_dv.items()}
else:
workers = [self.coerce_address(w) for w in workers]
nthreads = {w: parent._workers_dv[w].nthreads for w in workers}
nthreads = {ws._address: ws.nthreads for ws in wss}

assert isinstance(data, dict)

Expand All @@ -5731,10 +5738,7 @@ async def scatter(
self.update_data(who_has=who_has, nbytes=nbytes, client=client)

if broadcast:
if broadcast == True: # noqa: E712
n = len(nthreads)
else:
n = broadcast
n = len(nthreads) if broadcast is True else broadcast
await self.replicate(keys=keys, workers=workers, n=n)

self.log_event(
Expand Down Expand Up @@ -6450,7 +6454,12 @@ async def replicate(

assert branching_factor > 0
async with self._lock if lock else empty_context:
workers = {parent._workers_dv[w] for w in self.workers_list(workers)}
if workers is not None:
workers = {parent._workers_dv[w] for w in self.workers_list(workers)}
workers = {ws for ws in workers if ws._status == Status.running}
else:
workers = parent._running

if n is None:
n = len(workers)
else:
Expand Down Expand Up @@ -6988,6 +6997,15 @@ def get_ncores(self, comm=None, workers=None):
else:
return {w: ws._nthreads for w, ws in parent._workers_dv.items()}

def get_ncores_running(self, comm=None, workers=None):
parent: SchedulerState = cast(SchedulerState, self)
ncores = self.get_ncores(workers=workers)
return {
w: n
for w, n in ncores.items()
if parent._workers_dv[w].status == Status.running
}

async def get_call_stack(self, comm=None, keys=None):
parent: SchedulerState = cast(SchedulerState, self)
ts: TaskState
Expand Down
27 changes: 27 additions & 0 deletions distributed/tests/test_client.py
Expand Up @@ -5734,6 +5734,33 @@ def bad_fn(x):
assert y.status == "error" # not cancelled


@pytest.mark.parametrize("workers_arg", [False, True])
@pytest.mark.parametrize("direct", [False, True])
@pytest.mark.parametrize("broadcast", [False, True, 10])
@gen_cluster(client=True, nthreads=[("", 1)] * 10)
async def test_scatter_and_replicate_avoid_paused_workers(
c, s, *workers, workers_arg, direct, broadcast
):
paused_workers = [w for i, w in enumerate(workers) if i not in (3, 7)]
for w in paused_workers:
w.memory_pause_fraction = 1e-15
while any(s.workers[w.address].status != Status.paused for w in paused_workers):
await asyncio.sleep(0.01)

f = await c.scatter(
{"x": 1},
workers=[w.address for w in workers[1:-1]] if workers_arg else None,
broadcast=broadcast,
direct=direct,
)
if not broadcast:
await c.replicate(f, n=10)

expect = [i in (3, 7) for i in range(10)]
actual = [("x" in w.data) for w in workers]
assert actual == expect


@pytest.mark.xfail(reason="GH#5409 Dask-Default-Threads are frequently detected")
def test_no_threads_lingering():
if threading.active_count() < 40:
Expand Down
21 changes: 14 additions & 7 deletions distributed/tests/test_scheduler.py
Expand Up @@ -784,20 +784,27 @@ async def test_story(c, s, a, b):
assert s.story(x.key) == s.story(s.tasks[x.key])


@gen_cluster(nthreads=[], client=True)
async def test_scatter_no_workers(c, s):
@pytest.mark.parametrize("direct", [False, True])
@gen_cluster(client=True, nthreads=[])
async def test_scatter_no_workers(c, s, direct):
with pytest.raises(TimeoutError):
await s.scatter(data={"x": 1}, client="alice", timeout=0.1)

start = time()
with pytest.raises(TimeoutError):
await c.scatter(123, timeout=0.1)
await c.scatter(123, timeout=0.1, direct=direct)
assert time() < start + 1.5

w = Worker(s.address, nthreads=3)
await asyncio.gather(c.scatter(data={"y": 2}, timeout=5), w)

assert w.data["y"] == 2
fut = c.scatter({"y": 2}, timeout=5, direct=direct)
await asyncio.sleep(0.1)
async with Worker(s.address) as w:
await fut
assert w.data["y"] == 2

# Test race condition between worker init and scatter
w = Worker(s.address)
await asyncio.gather(c.scatter({"z": 3}, timeout=5, direct=direct), w)
assert w.data["z"] == 3
await w.close()


Expand Down

0 comments on commit 253d397

Please sign in to comment.