Summary of this patch (text before the first "diff --git" header is ignored
by patch tools):

* distributed/tests/test_scheduler.py: add test_avoid_paused_workers, which
  waits until the worker started with memory_pause_fraction=1e-15 is reported
  as paused and then waits for 8 tasks to be split 4/0/4 across the three
  workers; add test_unpause_schedules_unrannable_tasks, which submits a task
  to a cluster whose only worker is paused, waits for it to appear in
  s.unrunnable, then raises memory_pause_fraction and awaits the result.
* distributed/utils_test.py and distributed/worker.py: replace the literal
  tuple (Status.running, Status.paused) with the RUNNING name imported from
  .worker. NOTE(review): the definition of RUNNING is not part of this
  excerpt; presumably it contains at least those two statuses -- confirm.

diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py
index 1e1a3e69b32..f93499fb421 100644
--- a/distributed/tests/test_scheduler.py
+++ b/distributed/tests/test_scheduler.py
@@ -3218,3 +3218,29 @@ async def test_set_restrictions(c, s, a, b):
     assert s.tasks[f.key].worker_restrictions == {a.address}
     s.reschedule(f)
     await f
+
+
+@gen_cluster(
+    client=True,
+    nthreads=[("", 1, {}), ("", 1, {"memory_pause_fraction": 1e-15}), ("", 1, {})],
+)
+async def test_avoid_paused_workers(c, s, w1, w2, w3):
+    while s.workers[w2.address].status != Status.paused:
+        await asyncio.sleep(0.01)
+    futures = c.map(slowinc, range(8), delay=0.1)
+    while (len(w1.tasks), len(w2.tasks), len(w3.tasks)) != (4, 0, 4):
+        await asyncio.sleep(0.01)
+
+
+@gen_cluster(client=True, nthreads=[("", 1, {"memory_pause_fraction": 1e-15})])
+async def test_unpause_schedules_unrannable_tasks(c, s, a):
+    while s.workers[a.address].status != Status.paused:
+        await asyncio.sleep(0.01)
+
+    fut = c.submit(inc, 1, key="x")
+    while not s.unrunnable:
+        await asyncio.sleep(0.001)
+    assert next(iter(s.unrunnable)).key == "x"
+
+    a.memory_pause_fraction = 0.8
+    assert await fut == 2
diff --git a/distributed/utils_test.py b/distributed/utils_test.py
index 3a72ed07889..e56e3865c77 100644
--- a/distributed/utils_test.py
+++ b/distributed/utils_test.py
@@ -69,7 +69,7 @@
     sync,
     thread_state,
 )
-from .worker import Worker
+from .worker import RUNNING, Worker
 
 try:
     import dask.array  # register config
@@ -1549,7 +1549,7 @@ def check_instances():
     for w in Worker._instances:
         with suppress(RuntimeError):  # closed IOLoop
             w.loop.add_callback(w.close, report=False, executor_wait=False)
-            if w.status in (Status.running, Status.paused):
+            if w.status in RUNNING:
                 w.loop.add_callback(w.close)
     Worker._instances.clear()
 
diff --git a/distributed/worker.py b/distributed/worker.py
index 3d9c9504054..06d914e3838 100644
--- a/distributed/worker.py
+++ b/distributed/worker.py
@@ -3673,11 +3673,7 @@ def get_worker() -> Worker:
         return thread_state.execution_state["worker"]
     except AttributeError:
         try:
-            return first(
-                w
-                for w in Worker._instances
-                if w.status in (Status.running, Status.paused)
-            )
+            return first(w for w in Worker._instances if w.status in RUNNING)
         except StopIteration:
             raise ValueError("No workers found")
 