AMM: scatter and replicate to avoid paused workers #5441

Merged
merged 3 commits on Oct 26, 2021
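In short: Client.scatter, Scheduler.scatter and Scheduler.replicate now place data only on workers whose status is running, and wait (up to the timeout) for at least one running worker rather than using paused or closing_gracefully ones. A rough client-level sketch of the resulting behaviour; the scheduler address and replica count are placeholders, not part of this PR:

# Rough sketch, not part of this PR: with some workers paused (for example,
# sitting above their memory pause threshold), scatter and replicate now
# place data on running workers only.
from distributed import Client

client = Client("tcp://scheduler:8786")        # placeholder address
futures = client.scatter({"x": 1})             # data lands only on running workers
client.replicate(list(futures.values()), n=3)  # extra replicas also skip paused workers
# If no worker is running at all, scatter waits and eventually raises
# TimeoutError("No valid workers found") instead of using a paused worker.
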
5 changes: 3 additions & 2 deletions distributed/client.py
@@ -2048,9 +2048,10 @@ async def _scatter(
                         await asyncio.sleep(0.1)
                     if time() > start + timeout:
                         raise TimeoutError("No valid workers found")
-                    nthreads = await self.scheduler.ncores(workers=workers)
+                    # Exclude paused and closing_gracefully workers
+                    nthreads = await self.scheduler.ncores_running(workers=workers)
                 if not nthreads:
-                    raise ValueError("No valid workers")
+                    raise ValueError("No valid workers found")

                 _, who_has, nbytes = await scatter_to_workers(
                     nthreads, data2, report=False, rpc=self.rpc
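The direct scatter path above now asks the scheduler for ncores_running rather than ncores, so paused and closing_gracefully workers are never chosen as scatter targets. The new handler is reachable over the same scheduler RPC; a small sketch, where the helper name and example address are illustrative only:

# Illustrative helper, not part of the PR: thread counts of running workers only.
async def running_thread_counts(client, workers=None):
    # Same call as in Client._scatter above; paused workers are omitted from
    # the result, e.g. {"tcp://127.0.0.1:40331": 1} (placeholder address).
    return await client.scheduler.ncores_running(workers=workers)
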
46 changes: 32 additions & 14 deletions distributed/scheduler.py
@@ -3826,6 +3826,7 @@ def __init__(
             "broadcast": self.broadcast,
             "proxy": self.proxy,
             "ncores": self.get_ncores,
+            "ncores_running": self.get_ncores_running,
             "has_what": self.get_has_what,
             "who_has": self.get_who_has,
             "processing": self.get_processing,
@@ -5710,18 +5711,24 @@ async def scatter(
         Scheduler.broadcast:
         """
         parent: SchedulerState = cast(SchedulerState, self)
+        ws: WorkerState
+
         start = time()
-        while not parent._workers_dv:
-            await asyncio.sleep(0.2)
+        while True:
+            if workers is None:
+                wss = parent._running
+            else:
+                workers = [self.coerce_address(w) for w in workers]
+                wss = {parent._workers_dv[w] for w in workers}
+                wss = {ws for ws in wss if ws._status == Status.running}
+
+            if wss:
+                break
             if time() > start + timeout:
-                raise TimeoutError("No workers found")
+                raise TimeoutError("No valid workers found")
+            await asyncio.sleep(0.1)

-        if workers is None:
-            ws: WorkerState
-            nthreads = {w: ws._nthreads for w, ws in parent._workers_dv.items()}
-        else:
-            workers = [self.coerce_address(w) for w in workers]
-            nthreads = {w: parent._workers_dv[w].nthreads for w in workers}
+        nthreads = {ws._address: ws.nthreads for ws in wss}

         assert isinstance(data, dict)

@@ -5732,10 +5739,7 @@
         self.update_data(who_has=who_has, nbytes=nbytes, client=client)

         if broadcast:
-            if broadcast == True:  # noqa: E712
-                n = len(nthreads)
-            else:
-                n = broadcast
+            n = len(nthreads) if broadcast is True else broadcast
             await self.replicate(keys=keys, workers=workers, n=n)

         self.log_event(
@@ -6451,7 +6455,12 @@ async def replicate(

         assert branching_factor > 0
         async with self._lock if lock else empty_context:
-            workers = {parent._workers_dv[w] for w in self.workers_list(workers)}
+            if workers is not None:
+                workers = {parent._workers_dv[w] for w in self.workers_list(workers)}
+                workers = {ws for ws in workers if ws._status == Status.running}
+            else:
+                workers = parent._running
+
             if n is None:
                 n = len(workers)
             else:
@@ -6989,6 +6998,15 @@ def get_ncores(self, comm=None, workers=None):
         else:
             return {w: ws._nthreads for w, ws in parent._workers_dv.items()}

+    def get_ncores_running(self, comm=None, workers=None):
+        parent: SchedulerState = cast(SchedulerState, self)
+        ncores = self.get_ncores(workers=workers)
+        return {
+            w: n
+            for w, n in ncores.items()
+            if parent._workers_dv[w].status == Status.running
+        }
+
     async def get_call_stack(self, comm=None, keys=None):
         parent: SchedulerState = cast(SchedulerState, self)
         ts: TaskState
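Scheduler.scatter and Scheduler.replicate now apply the same status filter to WorkerState objects; conceptually it reduces to a helper like the one below, which is hypothetical and only summarises the pattern used in both methods:

# Hypothetical summary of the filtering added in scatter() and replicate():
# keep only workers that are currently running.
from distributed.core import Status


def running_only(worker_states):
    """Drop paused and closing_gracefully workers from a set of WorkerState."""
    return {ws for ws in worker_states if ws.status == Status.running}
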
27 changes: 27 additions & 0 deletions distributed/tests/test_client.py
@@ -5734,6 +5734,33 @@ def bad_fn(x):
     assert y.status == "error"  # not cancelled


+@pytest.mark.parametrize("workers_arg", [False, True])
+@pytest.mark.parametrize("direct", [False, True])
+@pytest.mark.parametrize("broadcast", [False, True, 10])
+@gen_cluster(client=True, nthreads=[("", 1)] * 10)
+async def test_scatter_and_replicate_avoid_paused_workers(
+    c, s, *workers, workers_arg, direct, broadcast
+):
+    paused_workers = [w for i, w in enumerate(workers) if i not in (3, 7)]
+    for w in paused_workers:
+        w.memory_pause_fraction = 1e-15
+    while any(s.workers[w.address].status != Status.paused for w in paused_workers):
+        await asyncio.sleep(0.01)
+
+    f = await c.scatter(
+        {"x": 1},
+        workers=[w.address for w in workers[1:-1]] if workers_arg else None,
+        broadcast=broadcast,
+        direct=direct,
+    )
+    if not broadcast:
+        await c.replicate(f, n=10)
+
+    expect = [i in (3, 7) for i in range(10)]
+    actual = [("x" in w.data) for w in workers]
+    assert actual == expect
+
+
 @pytest.mark.xfail(reason="GH#5409 Dask-Default-Threads are frequently detected")
 def test_no_threads_lingering():
     if threading.active_count() < 40:
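For reference, the test forces workers into the paused state by setting memory_pause_fraction directly on the Worker objects. Outside the test suite the same threshold is normally controlled through dask configuration; a sketch under that assumption, for experimentation only:

# Sketch, not part of this PR: lower the pause threshold via dask config so
# that workers pause almost as soon as they hold any data.
import dask

dask.config.set({"distributed.worker.memory.pause": 1e-15})
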
21 changes: 14 additions & 7 deletions distributed/tests/test_scheduler.py
@@ -784,20 +784,27 @@ async def test_story(c, s, a, b):
     assert s.story(x.key) == s.story(s.tasks[x.key])


-@gen_cluster(nthreads=[], client=True)
-async def test_scatter_no_workers(c, s):
+@pytest.mark.parametrize("direct", [False, True])
+@gen_cluster(client=True, nthreads=[])
+async def test_scatter_no_workers(c, s, direct):
     with pytest.raises(TimeoutError):
         await s.scatter(data={"x": 1}, client="alice", timeout=0.1)

     start = time()
     with pytest.raises(TimeoutError):
-        await c.scatter(123, timeout=0.1)
+        await c.scatter(123, timeout=0.1, direct=direct)
     assert time() < start + 1.5

-    w = Worker(s.address, nthreads=3)
-    await asyncio.gather(c.scatter(data={"y": 2}, timeout=5), w)
-
-    assert w.data["y"] == 2
+    fut = c.scatter({"y": 2}, timeout=5, direct=direct)
+    await asyncio.sleep(0.1)
+    async with Worker(s.address) as w:
+        await fut
+        assert w.data["y"] == 2
+
+    # Test race condition between worker init and scatter
+    w = Worker(s.address)
+    await asyncio.gather(c.scatter({"z": 3}, timeout=5, direct=direct), w)
+    assert w.data["z"] == 3
     await w.close()
