Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure workers do not kill on restart #8611

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 5 additions & 1 deletion distributed/cli/dask_worker.py
Expand Up @@ -417,7 +417,11 @@ async def run():

async def wait_for_nannies_to_finish():
    """Wait for all nannies to initialize and finish.

    Startup failures are re-raised unless a shutdown signal has already
    fired (``signal_fired`` is truthy); in that case the failure is
    swallowed and we still wait for every nanny to report finished.
    """
    try:
        await asyncio.gather(*nannies)
    except Exception:
        # Only propagate when the failure was not preceded by a signal.
        if not signal_fired:
            raise
    await asyncio.gather(*(n.finished() for n in nannies))

async def wait_for_signals_and_close():
Expand Down
4 changes: 3 additions & 1 deletion distributed/comm/tcp.py
Expand Up @@ -665,7 +665,9 @@ async def _handle_stream(self, stream, address):
try:
await self.on_connection(comm)
except CommClosedError:
logger.info("Connection from %s closed before handshake completed", address)
logger.debug(
"Connection from %s closed before handshake completed", address
)
return

await self.comm_handler(comm)
Expand Down
49 changes: 43 additions & 6 deletions distributed/deploy/spec.py
Expand Up @@ -10,6 +10,7 @@
from collections.abc import Awaitable, Generator
from contextlib import suppress
from inspect import isawaitable
from time import time
from typing import TYPE_CHECKING, Any, ClassVar, TypeVar

from tornado import gen
Expand Down Expand Up @@ -389,28 +390,64 @@
# proper teardown.
await asyncio.gather(*worker_futs)

def _update_worker_status(self, op, worker_addr):
    """Handle a worker status update and forward it to the parent class.

    For ``op == "remove"`` a delayed check is scheduled on the running
    event loop. The delay is read from the
    ``distributed.deploy.lost-worker-timeout`` config value. When the
    check runs, the worker tracked under the removed worker's name is
    closed if neither its address nor its name is present in
    ``self.scheduler_info["workers"]`` anymore.

    Parameters
    ----------
    op
        The operation being reported; only ``"remove"`` is handled here.
    worker_addr
        Key into ``self.scheduler_info["workers"]`` identifying the worker.
    """
    if op == "remove":
        # Copy the entry now; it is read again later inside the delayed
        # callback.
        worker_info = self.scheduler_info["workers"][worker_addr].copy()
        name = worker_info["name"]

        from distributed import Nanny, Worker

        def close_if_still_missing():
            # FIXME: SpecCluster is tracking workers by `name` which are
            # not necessarily unique.
            # Clusters with Nannies (default) are susceptible to falsely
            # removing the Nannies on restart due to this logic since the
            # restart emits a op==remove signal on the worker address but
            # the SpecCluster only tracks the names, i.e. after
            # `lost-worker-timeout` the Nanny is still around and this logic
            # could trigger a false close. The below code should handle this
            # but it would be cleaner if the cluster tracked by address
            # instead of name just like the scheduler does
            if (
                name in self.workers
                and worker_addr not in self.scheduler_info["workers"]
                and not any(
                    d["name"] == name
                    for d in self.scheduler_info["workers"].values()
                )
            ):
                w = self.workers[name]

                async def remove_worker():
                    await w.close(reason=f"lost-worker-timeout-{time()}")
                    self.workers.pop(name, None)

                # Grouping is spelled out explicitly; it matches how the
                # unparenthesised `a and b or c` form is evaluated.
                if (
                    worker_info["type"] == "Worker"
                    and isinstance(w, Nanny)
                    and w.worker_address == worker_addr
                ) or (isinstance(w, Worker) and w.address == worker_addr):
                    self._futures.add(
                        asyncio.create_task(
                            remove_worker(),
                            name="remove-worker-lost-worker-timeout",
                        )
                    )
                elif worker_info["type"] == "Nanny":
                    # This should never happen
                    logger.critical(
                        "Unexpected signal encountered. WorkerStatusPlugin "
                        "emitted a op==remove signal for a Nanny which "
                        "should not happen. This might cause a lingering "
                        "Nanny process."
                    )

        delay = parse_timedelta(
            dask.config.get("distributed.deploy.lost-worker-timeout")
        )

        asyncio.get_running_loop().call_later(delay, close_if_still_missing)
    super()._update_worker_status(op, worker_addr)

def __await__(self: Self) -> Generator[Any, Any, Self]:
async def _() -> Self:
Expand Down
15 changes: 15 additions & 0 deletions distributed/deploy/tests/test_local.py
Expand Up @@ -11,6 +11,7 @@
import pytest
from tornado.httpclient import AsyncHTTPClient

import dask
from dask.system import CPU_COUNT

from distributed import Client, LocalCluster, Nanny, Worker, get_client
Expand Down Expand Up @@ -1285,3 +1286,17 @@ def test_localcluster_get_client(loop):
with Client(cluster) as client2:
assert client1 != client2
assert client2 == cluster.get_client()


@pytest.mark.slow()
def test_localcluster_restart(loop):
    """Repeated ``client.restart()`` calls keep the worker count unchanged.

    Runs with ``distributed.deploy.lost-worker-timeout`` set to ``0.5s``.
    """
    timeout_override = {"distributed.deploy.lost-worker-timeout": "0.5s"}
    with (
        dask.config.set(timeout_override),
        LocalCluster(asynchronous=False, dashboard_address=":0", loop=loop) as cluster,
        cluster.get_client() as client,
    ):

        def responding_workers():
            # One entry is returned per worker that ran the no-op.
            return len(client.run(lambda: None))

        expected = responding_workers()
        for _ in range(10):
            assert responding_workers() == expected
            client.restart()
            assert responding_workers() == expected