ensure workers do not kill on restart
fjetter committed Apr 12, 2024
1 parent 66ced13 commit 6fa046b
Showing 13 changed files with 381 additions and 373 deletions.
4 changes: 3 additions & 1 deletion distributed/comm/tcp.py
```diff
@@ -665,7 +665,9 @@ async def _handle_stream(self, stream, address):
         try:
             await self.on_connection(comm)
         except CommClosedError:
-            logger.info("Connection from %s closed before handshake completed", address)
+            logger.debug(
+                "Connection from %s closed before handshake completed", address
+            )
             return
 
         await self.comm_handler(comm)
```
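This hunk only demotes a routine event (a peer closing the connection before the handshake completes, as happens during restarts) from INFO to DEBUG. If you still want to see these messages, raise that logger's verbosity; a minimal sketch, assuming tcp.py uses the usual `logging.getLogger(__name__)` convention, i.e. a logger named `distributed.comm.tcp`:

```python
import logging

# Opt back in to the early-close messages this commit demoted to DEBUG.
# Assumption: the module logger follows the getLogger(__name__) convention.
logging.basicConfig()  # ensure a handler exists in a bare script
logging.getLogger("distributed.comm.tcp").setLevel(logging.DEBUG)
```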
49 changes: 43 additions & 6 deletions distributed/deploy/spec.py
```diff
@@ -10,6 +10,7 @@
 from collections.abc import Awaitable, Generator
 from contextlib import suppress
 from inspect import isawaitable
+from time import time
 from typing import TYPE_CHECKING, Any, ClassVar, TypeVar
 
 from tornado import gen
```
```diff
@@ -389,28 +390,64 @@ async def _correct_state_internal(self) -> None:
         # proper teardown.
         await asyncio.gather(*worker_futs)
 
-    def _update_worker_status(self, op, msg):
+    def _update_worker_status(self, op, worker_addr):
         if op == "remove":
-            name = self.scheduler_info["workers"][msg]["name"]
+            worker_info = self.scheduler_info["workers"][worker_addr].copy()
+            name = worker_info["name"]
 
+            from distributed import Nanny, Worker
+
             def f():
+                # FIXME: SpecCluster tracks workers by `name`, which is not
+                # necessarily unique.
+                # Clusters with Nannies (the default) are susceptible to
+                # falsely removing the Nannies on restart, since the restart
+                # emits an op==remove signal for the worker address while
+                # SpecCluster only tracks names, i.e. after
+                # `lost-worker-timeout` the Nanny is still around and this
+                # logic could trigger a false close. The code below handles
+                # this, but it would be cleaner if the cluster tracked workers
+                # by address instead of name, just like the scheduler does.
                 if (
                     name in self.workers
-                    and msg not in self.scheduler_info["workers"]
+                    and worker_addr not in self.scheduler_info["workers"]
                     and not any(
                         d["name"] == name
                         for d in self.scheduler_info["workers"].values()
                     )
                 ):
-                    self._futures.add(asyncio.ensure_future(self.workers[name].close()))
-                    del self.workers[name]
+                    w = self.workers[name]
+
+                    async def remove_worker():
+                        await w.close(reason=f"lost-worker-timeout-{time()}")
+                        self.workers.pop(name, None)
+
+                    if (
+                        worker_info["type"] == "Worker"
+                        and (isinstance(w, Nanny) and w.worker_address == worker_addr)
+                        or (isinstance(w, Worker) and w.address == worker_addr)
+                    ):
+                        self._futures.add(
+                            asyncio.create_task(
+                                remove_worker(),
+                                name="remove-worker-lost-worker-timeout",
+                            )
+                        )
+                    elif worker_info["type"] == "Nanny":
+                        # This should never happen
+                        logger.critical(
+                            "Unexpected signal encountered. WorkerStatusPlugin "
+                            "emitted an op==remove signal for a Nanny, which "
+                            "should not happen. This might cause a lingering "
+                            "Nanny process."
+                        )
 
             delay = parse_timedelta(
                 dask.config.get("distributed.deploy.lost-worker-timeout")
             )
 
             asyncio.get_running_loop().call_later(delay, f)
-        super()._update_worker_status(op, msg)
+        super()._update_worker_status(op, worker_addr)
 
     def __await__(self: Self) -> Generator[Any, Any, Self]:
         async def _() -> Self:
```
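The heart of the fix is the address check guarding the deferred close. Below is a standalone sketch of that guard, assuming only what the patch shows (a `Nanny` exposes the address of the worker process it supervises via `worker_address`, a plain `Worker` its own via `address`); the helper name is illustrative, not part of the commit:

```python
from distributed import Nanny, Worker

def matches_removed_address(w, worker_addr: str) -> bool:
    # Only close `w` if the address the scheduler reported as removed
    # actually belongs to it, not merely to something sharing its name.
    if isinstance(w, Nanny):
        return w.worker_address == worker_addr  # address of the child worker
    if isinstance(w, Worker):
        return w.address == worker_addr
    return False
```

After a restart the Nanny survives but spawns a worker process with a new address, so the old address reported by op==remove no longer matches and the Nanny is left alone.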
15 changes: 15 additions & 0 deletions distributed/deploy/tests/test_local.py
```diff
@@ -11,6 +11,7 @@
 import pytest
 from tornado.httpclient import AsyncHTTPClient
 
+import dask
 from dask.system import CPU_COUNT
 
 from distributed import Client, LocalCluster, Nanny, Worker, get_client
```
```diff
@@ -1285,3 +1286,17 @@ def test_localcluster_get_client(loop):
            with Client(cluster) as client2:
                assert client1 != client2
                assert client2 == cluster.get_client()
+
+
+@pytest.mark.slow()
+def test_localcluster_restart(loop):
+    with (
+        dask.config.set({"distributed.deploy.lost-worker-timeout": "0.5s"}),
+        LocalCluster(asynchronous=False, dashboard_address=":0", loop=loop) as cluster,
+        cluster.get_client() as client,
+    ):
+        nworkers = len(client.run(lambda: None))
+        for _ in range(10):
+            assert len(client.run(lambda: None)) == nworkers
+            client.restart()
+            assert len(client.run(lambda: None)) == nworkers
```
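Note that the removal in `_update_worker_status` is deliberately not immediate: `call_later` runs a synchronous callback once `lost-worker-timeout` elapses, and only that callback spawns the async close task. A standalone sketch of this deferral idiom, with illustrative names and timings rather than code from the commit:

```python
import asyncio

async def main() -> None:
    async def remove_worker() -> None:
        print("closing worker after the grace period")

    def check_and_remove() -> None:
        # In the patch, the name/address guards run here; only a worker that
        # is still missing after the grace period gets a close task spawned.
        asyncio.ensure_future(remove_worker())

    # call_later takes a plain (non-async) callback and a delay in seconds.
    asyncio.get_running_loop().call_later(0.5, check_and_remove)
    await asyncio.sleep(1)  # keep the loop alive so the callback can fire

asyncio.run(main())
```

The test sets the timeout to 0.5s so that these deferred callbacks fire during the ten restart cycles; before this commit they could falsely close a healthy Nanny whose name matched a removed worker.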
