Remove report and safe from Worker.close (#6363)
fjetter committed May 20, 2022
1 parent 41a54ee commit 9bb999d
Showing 14 changed files with 134 additions and 83 deletions.
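For orientation, here is a minimal sketch (not part of the diff) of what call sites look like after this commit: Worker.close loses its report= and safe= keywords, and Scheduler.close loses close_workers= because the scheduler now always tells any remaining workers to terminate over their stream comms. The local Scheduler-plus-Worker setup below is an illustrative assumption, not code from this repository.

    import asyncio

    from distributed import Scheduler, Worker


    async def main() -> None:
        # Illustrative only: start a local scheduler and one worker in-process.
        async with Scheduler(dashboard_address=":0") as scheduler:
            worker = await Worker(scheduler.address)
            # Before #6363 this call also accepted report= and safe=; both are gone now.
            await worker.close(nanny=False, executor_wait=True)
            # Before #6363 shutdown was close(close_workers=True); close() now always
            # asks any remaining workers to terminate via their stream comms.
            await scheduler.close()


    if __name__ == "__main__":
        asyncio.run(main())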
4 changes: 1 addition & 3 deletions distributed/chaos.py
@@ -56,9 +56,7 @@ async def setup(self, worker):
)

def graceful(self):
asyncio.create_task(
self.worker.close(report=False, nanny=False, executor_wait=False)
)
asyncio.create_task(self.worker.close(nanny=False, executor_wait=False))

def sys_exit(self):
sys.exit(0)
37 changes: 37 additions & 0 deletions distributed/cli/tests/test_dask_worker.py
@@ -18,6 +18,7 @@
from distributed.compatibility import LINUX, WINDOWS
from distributed.deploy.utils import nprocesses_nthreads
from distributed.metrics import time
from distributed.utils import open_port
from distributed.utils_test import gen_cluster, popen, requires_ipv6


@@ -713,3 +714,39 @@ async def test_signal_handling(c, s, nanny, sig):
assert "timed out" not in logs
assert "error" not in logs
assert "exception" not in logs


@pytest.mark.parametrize("nanny", ["--nanny", "--no-nanny"])
def test_error_during_startup(monkeypatch, nanny):
# see https://github.com/dask/distributed/issues/6320
scheduler_port = str(open_port())
scheduler_addr = f"tcp://127.0.0.1:{scheduler_port}"

monkeypatch.setenv("DASK_SCHEDULER_ADDRESS", scheduler_addr)
with popen(
[
"dask-scheduler",
"--port",
scheduler_port,
],
flush_output=False,
) as scheduler:
start = time()
# Wait for the scheduler to be up
while line := scheduler.stdout.readline():
if b"Scheduler at" in line:
break
# Ensure this is not killed by pytest-timeout
if time() - start > 5:
raise TimeoutError("Scheduler failed to start in time.")

with popen(
[
"dask-worker",
scheduler_addr,
nanny,
"--worker-port",
scheduler_port,
],
) as worker:
assert worker.wait(5) == 1
2 changes: 1 addition & 1 deletion distributed/client.py
@@ -1630,7 +1630,7 @@ async def _shutdown(self):
else:
with suppress(CommClosedError):
self.status = "closing"
await self.scheduler.terminate(close_workers=True)
await self.scheduler.terminate()

def shutdown(self):
"""Shut down the connected scheduler and workers
2 changes: 1 addition & 1 deletion distributed/deploy/spec.py
@@ -411,7 +411,7 @@ async def _close(self):
if self.scheduler_comm:
async with self._lock:
with suppress(OSError):
await self.scheduler_comm.terminate(close_workers=True)
await self.scheduler_comm.terminate()
await self.scheduler_comm.close_rpc()
else:
logger.warning("Cluster closed without starting up")
24 changes: 22 additions & 2 deletions distributed/deploy/tests/test_local.py
@@ -12,10 +12,9 @@

from dask.system import CPU_COUNT

from distributed import Client, Nanny, Worker, get_client
from distributed import Client, LocalCluster, Nanny, Worker, get_client
from distributed.compatibility import LINUX
from distributed.core import Status
from distributed.deploy.local import LocalCluster
from distributed.deploy.utils_test import ClusterTest
from distributed.metrics import time
from distributed.system import MEMORY_LIMIT
@@ -29,6 +28,7 @@
clean,
gen_test,
inc,
raises_with_cause,
slowinc,
tls_only_security,
xfail_ssl_issue5601,
@@ -1155,3 +1155,23 @@ async def test_connect_to_closed_cluster():
# Raises during init without actually connecting since we're not
# awaiting anything
Client(cluster, asynchronous=True)


class MyPlugin:
def setup(self, worker=None):
import my_nonexistent_library # noqa


@pytest.mark.slow
@gen_test(
clean_kwargs={
# FIXME: This doesn't close the LoopRunner properly, leaving a thread around
"threads": False
}
)
async def test_localcluster_start_exception():
with raises_with_cause(RuntimeError, None, ImportError, "my_nonexistent_library"):
async with LocalCluster(
plugins={MyPlugin()},
):
return
2 changes: 1 addition & 1 deletion distributed/diagnostics/tests/test_cluster_dump_plugin.py
@@ -14,7 +14,7 @@ async def test_cluster_dump_plugin(c, s, *workers, tmp_path):
f2 = c.submit(inc, f1)

assert (await f2) == 3
await s.close(close_workers=True)
await s.close()

dump = DumpArtefact.from_url(str(dump_file))
assert {f1.key, f2.key} == set(dump.scheduler_story(f1.key, f2.key).keys())
12 changes: 4 additions & 8 deletions distributed/nanny.py
@@ -470,7 +470,7 @@ async def plugin_remove(self, name=None):

return {"status": "OK"}

async def restart(self, timeout=30, executor_wait=True):
async def restart(self, timeout=30):
async def _():
if self.process is not None:
await self.kill()
@@ -556,7 +556,7 @@ def close_gracefully(self):
"""
self.status = Status.closing_gracefully

async def close(self, comm=None, timeout=5, report=None):
async def close(self, timeout=5):
"""
Close the worker process, stop all comms.
"""
@@ -569,9 +569,8 @@ async def close(self, comm=None, timeout=5, report=None):

self.status = Status.closing
logger.info(
"Closing Nanny at %r. Report closure to scheduler: %s",
"Closing Nanny at %r.",
self.address_safe,
report,
)

for preload in self.preloads:
@@ -594,9 +593,8 @@ async def close(self, comm=None, timeout=5, report=None):
self.process = None
await self.rpc.close()
self.status = Status.closed
if comm:
await comm.write("OK")
await super().close()
return "OK"

async def _log_event(self, topic, msg):
await self.scheduler.log_event(
@@ -837,9 +835,7 @@ def _run(
async def do_stop(timeout=5, executor_wait=True):
try:
await worker.close(
report=True,
nanny=False,
safe=True, # TODO: Graceful or not?
executor_wait=executor_wait,
timeout=timeout,
)
49 changes: 18 additions & 31 deletions distributed/scheduler.py
@@ -3347,7 +3347,7 @@ def del_scheduler_file():
setproctitle(f"dask-scheduler [{self.address}]")
return self

async def close(self, fast=False, close_workers=False):
async def close(self):
"""Send cleanup signal to all coroutines then wait until finished
See Also
@@ -3370,19 +3370,6 @@ async def close(self, fast=False, close_workers=False):
for preload in self.preloads:
await preload.teardown()

if close_workers:
await self.broadcast(msg={"op": "close_gracefully"}, nanny=True)
for worker in self.workers:
# Report would require the worker to unregister with the
# currently closing scheduler. This is not necessary and might
# delay shutdown of the worker unnecessarily
self.worker_send(worker, {"op": "close", "report": False})
for i in range(20): # wait a second for send signals to clear
if self.workers:
await asyncio.sleep(0.05)
else:
break

await asyncio.gather(
*[plugin.close() for plugin in list(self.plugins.values())]
)
@@ -3399,15 +3386,16 @@ async def close(self, fast=False, close_workers=False):
logger.info("Scheduler closing all comms")

futures = []
for w, comm in list(self.stream_comms.items()):
for _, comm in list(self.stream_comms.items()):
if not comm.closed():
comm.send({"op": "close", "report": False})
# This closes the Worker and ensures that if a Nanny is around,
# it is closed as well
comm.send({"op": "terminate"})
comm.send({"op": "close-stream"})
with suppress(AttributeError):
futures.append(comm.close())

for future in futures: # TODO: do all at once
await future
await asyncio.gather(*futures)

for comm in self.client_comms.values():
comm.abort()
@@ -3431,8 +3419,8 @@ async def close_worker(self, worker: str, stimulus_id: str, safe: bool = False):
"""
logger.info("Closing worker %s", worker)
self.log_event(worker, {"action": "close-worker"})
# FIXME: This does not handle nannies
self.worker_send(worker, {"op": "close", "report": False})
ws = self.workers[worker]
self.worker_send(worker, {"op": "close", "nanny": bool(ws.nanny)})
await self.remove_worker(address=worker, safe=safe, stimulus_id=stimulus_id)

###########
@@ -4183,7 +4171,7 @@ async def remove_worker(self, address, stimulus_id, safe=False, close=True):
logger.info("Remove worker %s", ws)
if close:
with suppress(AttributeError, CommClosedError):
self.stream_comms[address].send({"op": "close", "report": False})
self.stream_comms[address].send({"op": "close"})

self.remove_resources(address)

@@ -4744,7 +4732,7 @@ def handle_long_running(
ws.long_running.add(ts)
self.check_idle_saturated(ws)

def handle_worker_status_change(
async def handle_worker_status_change(
self, status: str, worker: str, stimulus_id: str
) -> None:
ws = self.workers.get(worker)
@@ -4772,9 +4760,12 @@ def handle_worker_status_change(
worker_msgs: dict = {}
self._transitions(recs, client_msgs, worker_msgs, stimulus_id)
self.send_all(client_msgs, worker_msgs)

else:
self.running.discard(ws)
elif ws.status == Status.paused:
self.running.remove(ws)
elif ws.status == Status.closing:
await self.remove_worker(
address=ws.address, stimulus_id=stimulus_id, close=False
)

async def handle_worker(self, comm=None, worker=None, stimulus_id=None):
"""
@@ -5102,12 +5093,7 @@ async def restart(self, client=None, timeout=30):
]

resps = All(
[
nanny.restart(
close=True, timeout=timeout * 0.8, executor_wait=False
)
for nanny in nannies
]
[nanny.restart(close=True, timeout=timeout * 0.8) for nanny in nannies]
)
try:
resps = await asyncio.wait_for(resps, timeout)
@@ -6000,6 +5986,7 @@ async def retire_workers(
prev_status = ws.status
ws.status = Status.closing_gracefully
self.running.discard(ws)
# FIXME: We should send a message to the nanny first.
self.stream_comms[ws.address].send(
{
"op": "worker-status-change",
2 changes: 1 addition & 1 deletion distributed/tests/test_client.py
@@ -3655,7 +3655,7 @@ async def hard_stop(s):
except CancelledError:
break

await w.close(report=False)
await w.close()
await c._close(fast=True)


2 changes: 1 addition & 1 deletion distributed/tests/test_nanny.py
@@ -404,7 +404,7 @@ def remove_worker(self, **kwargs):


@gen_cluster(client=True, nthreads=[])
async def test_nanny_closes_cleanly_2(c, s):
async def test_nanny_closes_cleanly_if_worker_is_terminated(c, s):
async with Nanny(s.address) as n:
async with c.rpc(n.worker_address) as w:
IOLoop.current().add_callback(w.terminate)
22 changes: 16 additions & 6 deletions distributed/tests/test_scheduler.py
@@ -1758,10 +1758,12 @@ async def test_result_type(c, s, a, b):


@gen_cluster()
async def test_close_workers(s, a, b):
await s.close(close_workers=True)
assert a.status == Status.closed
assert b.status == Status.closed
async def test_close_workers(s, *workers):
await s.close()

for w in workers:
if not w.status == Status.closed:
await asyncio.sleep(0.1)


@pytest.mark.skipif(not LINUX, reason="Need 127.0.0.2 to mean localhost")
@@ -2591,7 +2593,7 @@ async def test_memory_is_none(c, s):
@gen_cluster()
async def test_close_scheduler__close_workers_Worker(s, a, b):
with captured_logger("distributed.comm", level=logging.DEBUG) as log:
await s.close(close_workers=True)
await s.close()
while not a.status == Status.closed:
await asyncio.sleep(0.05)
log = log.getvalue()
@@ -2601,7 +2603,7 @@ async def test_close_scheduler__close_workers_Worker(s, a, b):
@gen_cluster(Worker=Nanny)
async def test_close_scheduler__close_workers_Nanny(s, a, b):
with captured_logger("distributed.comm", level=logging.DEBUG) as log:
await s.close(close_workers=True)
await s.close()
while not a.status == Status.closed:
await asyncio.sleep(0.05)
log = log.getvalue()
@@ -2729,6 +2731,14 @@ async def test_rebalance_raises_missing_data3(c, s, a, b, explicit):
futures = await c.scatter(range(100), workers=[a.address])

if explicit:
pytest.xfail(
reason="""Freeing keys and gathering data is using different
channels (stream vs explicit RPC). Therefore, the
partial-fail is very timing sensitive and subject to a race
condition. This test assumes that the data is freed before
the rebalance get_data requests come in but merely deleting
the futures is not sufficient to guarantee this"""
)
keys = [f.key for f in futures]
del futures
out = await s.rebalance(keys=keys)
4 changes: 2 additions & 2 deletions distributed/tests/test_worker.py
@@ -192,7 +192,7 @@ def g():
assert result == 123

await c.close()
await s.close(close_workers=True)
await s.close()
assert not os.path.exists(os.path.join(a.local_directory, "foobar.py"))


@@ -2962,7 +2962,7 @@ async def test_missing_released_zombie_tasks(c, s, a, b):
while key not in b.tasks or b.tasks[key].state != "fetch":
await asyncio.sleep(0.01)

await a.close(report=False)
await a.close()

del f1, f2

