diff --git a/distributed/scheduler.py b/distributed/scheduler.py
index 52892adf299..423779b64c9 100644
--- a/distributed/scheduler.py
+++ b/distributed/scheduler.py
@@ -511,6 +511,7 @@ class WorkerState:
     _occupancy: double
     _pid: Py_ssize_t
     _processing: dict
+    _long_running: set
     _resources: dict
     _services: dict
     _status: Status
@@ -539,6 +540,7 @@ class WorkerState:
         "_occupancy",
         "_pid",
         "_processing",
+        "_long_running",
         "_resources",
         "_services",
         "_status",
@@ -588,6 +590,7 @@ def __init__(
         self._actors = set()
         self._has_what = {}
         self._processing = {}
+        self._long_running = set()
         self._executing = {}
         self._resources = {}
         self._used_resources = {}
@@ -2670,8 +2673,10 @@ def set_duration_estimate(self, ts: TaskState, ws: WorkerState) -> double:
         total_duration = duration + comm
         old = ws._processing.get(ts, 0)
         ws._processing[ts] = total_duration
-        self._total_occupancy += total_duration - old
-        ws._occupancy += total_duration - old
+
+        if ts not in ws._long_running:
+            self._total_occupancy += total_duration - old
+            ws._occupancy += total_duration - old
 
         return total_duration
 
@@ -3543,6 +3548,23 @@ def remove_all_replicas(self, ts: TaskState):
         self._replicated_tasks.remove(ts)
         ts._who_has.clear()
 
+    @ccall
+    @exceptval(check=False)
+    def _reevaluate_occupancy_worker(self, ws: WorkerState):
+        """See reevaluate_occupancy"""
+        ts: TaskState
+        old = ws._occupancy
+        for ts in ws._processing:
+            self.set_duration_estimate(ts, ws)
+
+        self.check_idle_saturated(ws)
+        steal = self.extensions.get("stealing")
+        if steal is None:
+            return
+        if ws._occupancy > old * 1.3 or old > ws._occupancy * 1.3:
+            for ts in ws._processing:
+                steal.recalculate_cost(ts)
+
 
 class Scheduler(SchedulerState, ServerNode):
     """Dynamic distributed task scheduler
@@ -5521,7 +5543,12 @@ def handle_long_running(self, key=None, worker=None, compute_duration=None):
         occ: double = ws._processing[ts]
         ws._occupancy -= occ
         parent._total_occupancy -= occ
+        # Cannot remove from processing since we're using this for things like
+        # idleness detection. Idle workers are typically targeted for
+        # downscaling but we should not downscale workers with long running
+        # tasks
         ws._processing[ts] = 0
+        ws._long_running.add(ts)
         self.check_idle_saturated(ws)
 
     def handle_worker_status_change(self, status: str, worker: str) -> None:
@@ -7827,7 +7854,7 @@ def reevaluate_occupancy(self, worker_index: Py_ssize_t = 0):
                 try:
                     if ws is None or not ws._processing:
                         continue
-                    _reevaluate_occupancy_worker(parent, ws)
+                    parent._reevaluate_occupancy_worker(ws)
                 finally:
                     del ws  # lose ref
 
@@ -8168,24 +8195,6 @@ def _task_to_client_msgs(state: SchedulerState, ts: TaskState) -> dict:
     return {}
 
 
-@cfunc
-@exceptval(check=False)
-def _reevaluate_occupancy_worker(state: SchedulerState, ws: WorkerState):
-    """See reevaluate_occupancy"""
-    ts: TaskState
-    old = ws._occupancy
-    for ts in ws._processing:
-        state.set_duration_estimate(ts, ws)
-
-    state.check_idle_saturated(ws)
-    steal = state.extensions.get("stealing")
-    if not steal:
-        return
-    if ws._occupancy > old * 1.3 or old > ws._occupancy * 1.3:
-        for ts in ws._processing:
-            steal.recalculate_cost(ts)
-
-
 @cfunc
 @exceptval(check=False)
 def decide_worker(
diff --git a/distributed/stealing.py b/distributed/stealing.py
index 444ba6ccef2..cc9737796a6 100644
--- a/distributed/stealing.py
+++ b/distributed/stealing.py
@@ -313,16 +313,8 @@ async def move_task_confirm(self, *, key, state, stimulus_id, worker=None):
         _log_msg = [key, state, victim.address, thief.address, stimulus_id]
 
         if ts.state != "processing":
-            self.log(("not-processing", *_log_msg))
-            old_thief = thief.occupancy
-            new_thief = sum(thief.processing.values())
-            old_victim = victim.occupancy
-            new_victim = sum(victim.processing.values())
-            thief.occupancy = new_thief
-            victim.occupancy = new_victim
-            self.scheduler.total_occupancy += (
-                new_thief - old_thief + new_victim - old_victim
-            )
+            self.scheduler._reevaluate_occupancy_worker(thief)
+            self.scheduler._reevaluate_occupancy_worker(victim)
         elif (
             state in _WORKER_STATE_UNDEFINED
             or state in _WORKER_STATE_CONFIRM
diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py
index 62d77590639..a603e747d79 100644
--- a/distributed/tests/test_client.py
+++ b/distributed/tests/test_client.py
@@ -31,7 +31,7 @@
 import dask.bag as db
 from dask import delayed
 from dask.optimization import SubgraphCallable
-from dask.utils import stringify, tmpfile
+from dask.utils import parse_timedelta, stringify, tmpfile
 
 from distributed import (
     CancelledError,
@@ -5156,6 +5156,42 @@ def f(x):
     assert results == [sum(map(inc, range(10)))] * 10
 
 
+@gen_cluster(client=True, nthreads=[("", 1)])
+async def test_long_running_not_in_occupancy(c, s, a):
+    # https://github.com/dask/distributed/issues/5332
+    from distributed import Lock
+
+    l = Lock()
+    await l.acquire()
+
+    def long_running(lock):
+        sleep(0.1)
+        secede()
+        lock.acquire()
+
+    f = c.submit(long_running, l)
+    while f.key not in s.tasks:
+        await asyncio.sleep(0.01)
+    assert s.workers[a.address].occupancy == parse_timedelta(
+        dask.config.get("distributed.scheduler.unknown-task-duration")
+    )
+
+    while s.workers[a.address].occupancy:
+        await asyncio.sleep(0.01)
+    await a.heartbeat()
+
+    ts = s.tasks[f.key]
+    ws = s.workers[a.address]
+    s.set_duration_estimate(ts, ws)
+    assert s.workers[a.address].occupancy == 0
+
+    s.reevaluate_occupancy(0)
+    assert s.workers[a.address].occupancy == 0
+    await l.release()
+
+    await f
+
+
 @gen_cluster(client=True)
 async def test_sub_submit_priority(c, s, a, b):
     def func():
diff --git a/distributed/tests/test_utils_test.py b/distributed/tests/test_utils_test.py
index 70ce308ae9f..754965533cc 100755
--- a/distributed/tests/test_utils_test.py
+++ b/distributed/tests/test_utils_test.py
@@ -7,6 +7,7 @@
 from time import sleep
 
 import pytest
+import yaml
 from tornado import gen
 
 from distributed import Client, Nanny, Scheduler, Worker, config, default_client
@@ -18,6 +19,7 @@
     _UnhashableCallable,
     assert_worker_story,
     cluster,
+    dump_cluster_state,
     gen_cluster,
     gen_test,
     inc,
@@ -480,3 +482,69 @@ def test_assert_worker_story():
 def test_assert_worker_story_malformed_story(story):
     with pytest.raises(AssertionError, match="Malformed story event"):
         assert_worker_story(story, [])
+
+
+@gen_cluster()
+async def test_dump_cluster_state(s, a, b, tmpdir):
+    await dump_cluster_state(s, [a, b], str(tmpdir), "dump")
+    with open(f"{tmpdir}/dump.yaml") as fh:
+        out = yaml.safe_load(fh)
+
+    assert out.keys() == {"scheduler", "workers", "versions"}
+    assert out["workers"].keys() == {a.address, b.address}
+
+
+@gen_cluster(nthreads=[])
+async def test_dump_cluster_state_no_workers(s, tmpdir):
+    await dump_cluster_state(s, [], str(tmpdir), "dump")
+    with open(f"{tmpdir}/dump.yaml") as fh:
+        out = yaml.safe_load(fh)
+
+    assert out.keys() == {"scheduler", "workers", "versions"}
+    assert out["workers"] == {}
+
+
+@gen_cluster(Worker=Nanny)
+async def test_dump_cluster_state_nannies(s, a, b, tmpdir):
+    await dump_cluster_state(s, [a, b], str(tmpdir), "dump")
+    with open(f"{tmpdir}/dump.yaml") as fh:
+        out = yaml.safe_load(fh)
+
+    assert out.keys() == {"scheduler", "workers", "versions"}
+    assert out["workers"].keys() == s.workers.keys()
+
+
+@gen_cluster()
+async def test_dump_cluster_state_unresponsive_local_worker(s, a, b, tmpdir):
+    a.stop()
+    await dump_cluster_state(s, [a, b], str(tmpdir), "dump")
+    with open(f"{tmpdir}/dump.yaml") as fh:
+        out = yaml.safe_load(fh)
+
+    assert out.keys() == {"scheduler", "workers", "versions"}
+    assert isinstance(out["workers"][a.address], dict)
+    assert isinstance(out["workers"][b.address], dict)
+
+
+@pytest.mark.slow
+@gen_cluster(
+    client=True,
+    Worker=Nanny,
+    config={"distributed.comm.timeouts.connect": "200ms"},
+)
+async def test_dump_cluster_unresponsive_remote_worker(c, s, a, b, tmpdir):
+    addr1, addr2 = s.workers
+    clog_fut = asyncio.create_task(
+        c.run(lambda dask_scheduler: dask_scheduler.stop(), workers=[addr1])
+    )
+    await asyncio.sleep(0.2)
+
+    await dump_cluster_state(s, [a, b], str(tmpdir), "dump")
+    with open(f"{tmpdir}/dump.yaml") as fh:
+        out = yaml.safe_load(fh)
+
+    assert out.keys() == {"scheduler", "workers", "versions"}
+    assert isinstance(out["workers"][addr2], dict)
+    assert out["workers"][addr1].startswith("OSError('Timed out trying to connect to")
+
+    clog_fut.cancel()
diff --git a/distributed/tests/test_worker_client.py b/distributed/tests/test_worker_client.py
index 64b5da6d037..4477b365efb 100644
--- a/distributed/tests/test_worker_client.py
+++ b/distributed/tests/test_worker_client.py
@@ -2,6 +2,7 @@
 import random
 import threading
 import warnings
+from collections import defaultdict
 from time import sleep
 
 import pytest
@@ -18,7 +19,7 @@
     worker_client,
 )
 from distributed.metrics import time
-from distributed.utils_test import double, gen_cluster, inc
+from distributed.utils_test import double, gen_cluster, inc, slowinc
 
 
 @gen_cluster(client=True)
@@ -315,3 +316,27 @@ async def test_submit_different_names(s, a, b):
         assert fut > 0
     finally:
         await c.close()
+
+
+@gen_cluster(client=True)
+async def test_secede_does_not_claim_worker(c, s, a, b):
+    """A seceded task must not block the task running it. Tasks scheduled from
+    within should be evenly distributed"""
+    # https://github.com/dask/distributed/issues/5332
+    def get_addr(x):
+        w = get_worker()
+        slowinc(x)
+        return w.address
+
+    def long_running():
+        with worker_client() as client:
+            futs = client.map(get_addr, range(100))
+            workers = defaultdict(int)
+            for f in futs:
+                workers[f.result()] += 1
+            return dict(workers)
+
+    res = await c.submit(long_running)
+    assert len(res) == 2
+    assert res[a.address] > 25
+    assert res[b.address] > 25
diff --git a/distributed/utils_test.py b/distributed/utils_test.py
index 346008e589e..c550cd39ed8 100644
--- a/distributed/utils_test.py
+++ b/distributed/utils_test.py
@@ -994,18 +994,12 @@ async def coro():
                         task.print_stack(file=buffer)
 
                 if cluster_dump_directory:
-                    try:
-                        await dump_cluster_state(
-                            s,
-                            ws,
-                            output_dir=cluster_dump_directory,
-                            func_name=func.__name__,
-                        )
-                    except Exception:
-                        print(
-                            f"Exception {sys.exc_info()} while trying to "
-                            "dump cluster state."
-                        )
+                    await dump_cluster_state(
+                        s,
+                        ws,
+                        output_dir=cluster_dump_directory,
+                        func_name=func.__name__,
+                    )
 
                 task.cancel()
                 while not task.cancelled():
@@ -1090,30 +1084,25 @@ def get_unclosed():
 async def dump_cluster_state(
     s: Scheduler, ws: list[ServerNode], output_dir: str, func_name: str
 ) -> None:
-    """A variant of Client.dump_cluster_state, which does not expect on any of the below
+    """A variant of Client.dump_cluster_state, which does not rely on any of the below
     to work:
 
-    - individual Workers
+    - Having a client at all
     - Client->Scheduler comms
     - Scheduler->Worker comms (unless using Nannies)
     """
     scheduler_info = s._to_dict()
+    workers_info: dict[str, Any]
     versions_info = version_module.get_versions()
 
     if not ws or isinstance(ws[0], Worker):
         workers_info = {w.address: w._to_dict() for w in ws}
     else:
-        # Variant of s.broadcast() that deals with unresponsive workers
-        async def safe_broadcast(addr: str) -> dict:
-            try:
-                return await s.broadcast(msg={"op": "dump_state"}, workers=[addr])
-            except Exception:
-                msg = f"Exception {sys.exc_info()} while trying to dump worker state"
-                return {addr: msg}
-
-        workers_info = merge(
-            await asyncio.gather(safe_broadcast(w.address) for w in ws)
-        )
+        workers_info = await s.broadcast(msg={"op": "dump_state"}, on_error="return")
+        workers_info = {
+            k: repr(v) if isinstance(v, Exception) else v
+            for k, v in workers_info.items()
+        }
 
     state = {
         "scheduler": scheduler_info,
@@ -1123,7 +1112,7 @@ async def safe_broadcast(addr: str) -> dict:
     os.makedirs(output_dir, exist_ok=True)
     fname = os.path.join(output_dir, func_name) + ".yaml"
     with open(fname, "w") as fh:
-        yaml.dump(state, fh)
+        yaml.safe_dump(state, fh)  # Automatically convert tuples to lists
     print(f"Dumped cluster state to {fname}")
 
 
@@ -1832,8 +1821,10 @@ def assert_worker_story(
             assert isinstance(ev, tuple)
            assert isinstance(ev[-2], str) and ev[-2]  # stimulus_id
             assert isinstance(ev[-1], float)  # timestamp
-            assert prev_ts <= ev[-1]  # timestamps are monotonic ascending
-            assert now - 3600 < ev[-1] <= now  # timestamps are within the last hour
+            assert prev_ts <= ev[-1]  # Timestamps are monotonic ascending
+            # Timestamps are within the last hour. It's been observed that a timestamp
+            # generated in a Nanny process can be a few milliseconds in the future.
+            assert now - 3600 < ev[-1] <= now + 1
             prev_ts = ev[-1]
     except AssertionError:
         raise AssertionError(