From df7c9b00e49b9caeb0b46ae7cb60cd2243372ac6 Mon Sep 17 00:00:00 2001
From: crusaderky
Date: Tue, 19 Oct 2021 12:02:14 +0100
Subject: [PATCH] Scheduler to avoid paused workers

---
 distributed/scheduler.py                      | 106 +++++++++++-----
 distributed/tests/test_client.py              |  29 ++---
 distributed/tests/test_scheduler.py           |  67 +++++++++--
 distributed/utils_test.py                     |  10 +-
 .../widgets/templates/worker_state.html.j2    |   1 +
 distributed/worker.py                         |  46 ++++----
 6 files changed, 181 insertions(+), 78 deletions(-)

diff --git a/distributed/scheduler.py b/distributed/scheduler.py
index fd873469af..a44ba45276 100644
--- a/distributed/scheduler.py
+++ b/distributed/scheduler.py
@@ -541,6 +541,7 @@ def __init__(
         self,
         *,
         address: str,
+        status: Status,
         pid: Py_ssize_t,
         name: object,
         nthreads: Py_ssize_t = 0,
@@ -560,9 +561,9 @@
         self._services = services or {}
         self._versions = versions or {}
         self._nanny = nanny
+        self._status = status
 
         self._hash = hash(address)
-        self._status = Status.undefined
         self._nbytes = 0
         self._occupancy = 0
         self._memory_unmanaged_old = 0
@@ -721,6 +722,7 @@ def clean(self):
         """Return a version of this object that is appropriate for serialization"""
         ws: WorkerState = WorkerState(
             address=self._address,
+            status=self._status,
             pid=self._pid,
             name=self._name,
             nthreads=self._nthreads,
@@ -736,9 +738,10 @@ def clean(self):
         return ws
 
     def __repr__(self):
-        return "<WorkerState %r, name: %s, memory: %d, processing: %d>" % (
+        return "<WorkerState %r, name: %s, status: %s, memory: %d, processing: %d>" % (
             self._address,
             self._name,
+            self._status.name,
             len(self._has_what),
             len(self._processing),
         )
@@ -747,6 +750,7 @@ def _repr_html_(self):
         return get_template("worker_state.html.j2").render(
             address=self.address,
             name=self.name,
+            status=self.status.name,
             has_what=self._has_what,
             processing=self.processing,
         )
@@ -1872,6 +1876,8 @@ class SchedulerState:
        Set of workers that are not fully utilized
     * **saturated:** ``{WorkerState}``:
        Set of workers that are not over-utilized
+    * **running:** ``{WorkerState}``:
+       Set of workers that are currently in running state
 
     * **clients:** ``{client key: ClientState}``
        Clients currently connected to the scheduler
@@ -1890,7 +1896,8 @@ class SchedulerState:
     _idle_dv: dict  # dict[str, WorkerState]
     _n_tasks: Py_ssize_t
     _resources: dict
-    _saturated: set
+    _saturated: set  # set[WorkerState]
+    _running: set  # set[WorkerState]
     _tasks: dict
     _task_groups: dict
     _task_prefixes: dict
@@ -1977,6 +1984,9 @@ def __init__(
         self._workers = workers
         # Note: cython.cast, not typing.cast!
         self._workers_dv = cast(dict, self._workers)
+        self._running = {
+            ws for ws in self._workers.values() if ws.status == Status.running
+        }
         self._plugins = {} if not plugins else {_get_plugin_name(p): p for p in plugins}
 
         # Variables from dask.config, cached by __init__ for performance
@@ -2041,9 +2051,13 @@ def resources(self):
         return self._resources
 
     @property
-    def saturated(self):
+    def saturated(self) -> "set[WorkerState]":
         return self._saturated
 
+    @property
+    def running(self) -> "set[WorkerState]":
+        return self._running
+
     @property
     def tasks(self):
         return self._tasks
@@ -3339,7 +3353,7 @@ def get_task_duration(self, ts: TaskState) -> double:
 
     @ccall
     @exceptval(check=False)
-    def valid_workers(self, ts: TaskState) -> set:
+    def valid_workers(self, ts: TaskState) -> set:  # set[WorkerState] | None
        """Return set of currently valid workers for key

        If all workers are valid then this returns ``None``.
@@ -3352,7 +3366,7 @@ def valid_workers(self, ts: TaskState) -> set: s: set = None # type: ignore if ts._worker_restrictions: - s = {w for w in ts._worker_restrictions if w in self._workers_dv} + s = {addr for addr in ts._worker_restrictions if addr in self._workers_dv} if ts._host_restrictions: # Resolve the alias here rather than early, for the worker @@ -3379,9 +3393,9 @@ def valid_workers(self, ts: TaskState) -> set: self._resources[resource] = dr = {} sw: set = set() - for w, supplied in dr.items(): + for addr, supplied in dr.items(): if supplied >= required: - sw.add(w) + sw.add(addr) dw[resource] = sw @@ -3391,8 +3405,13 @@ def valid_workers(self, ts: TaskState) -> set: else: s &= ww - if s is not None: - s = {self._workers_dv[w] for w in s} + if s is None: + if len(self._running) < len(self._workers_dv): + return self._running.copy() + else: + s = {self._workers_dv[addr] for addr in s} + if len(self._running) < len(self._workers_dv): + s &= self._running return s @@ -4212,7 +4231,9 @@ def heartbeat_worker( async def add_worker( self, comm=None, - address=None, + *, + address: str, + status: str, keys=(), nthreads=None, name=None, @@ -4238,9 +4259,8 @@ async def add_worker( address = normalize_address(address) host = get_address_host(address) - ws: WorkerState = parent._workers_dv.get(address) - if ws is not None: - raise ValueError("Worker already exists %s" % ws) + if address in parent._workers_dv: + raise ValueError("Worker already exists %s" % address) if name in parent._aliases: logger.warning( @@ -4255,8 +4275,10 @@ async def add_worker( await comm.write(msg) return + ws: WorkerState parent._workers[address] = ws = WorkerState( address=address, + status=Status.lookup[status], # type: ignore pid=pid, nthreads=nthreads, memory_limit=memory_limit or 0, @@ -4267,12 +4289,14 @@ async def add_worker( nanny=nanny, extra=extra, ) + if ws._status == Status.running: + parent._running.add(ws) - dh: dict = parent._host_info.get(host) + dh: dict = parent._host_info.get(host) # type: ignore if dh is None: parent._host_info[host] = dh = {} - dh_addresses: set = dh.get("addresses") + dh_addresses: set = dh.get("addresses") # type: ignore if dh_addresses is None: dh["addresses"] = dh_addresses = set() dh["nthreads"] = 0 @@ -4292,7 +4316,8 @@ async def add_worker( metrics=metrics, ) - # Do not need to adjust parent._total_occupancy as self.occupancy[ws] cannot exist before this. + # Do not need to adjust parent._total_occupancy as self.occupancy[ws] cannot + # exist before this. 
self.check_idle_saturated(ws) # for key in keys: # TODO @@ -4318,7 +4343,7 @@ async def add_worker( assert isinstance(nbytes, dict) already_released_keys = [] for key in nbytes: - ts: TaskState = parent._tasks.get(key) + ts: TaskState = parent._tasks.get(key) # type: ignore if ts is not None and ts.state != "released": if ts.state == "memory": self.add_keys(worker=address, keys=[key]) @@ -4347,14 +4372,15 @@ async def add_worker( "stimulus_id": f"reconnect-already-released-{time()}", } ) - for ts in list(parent._unrunnable): - valid: set = self.valid_workers(ts) - if valid is None or ws in valid: - recommendations[ts._key] = "waiting" + + if ws._status == Status.running: + for ts in parent._unrunnable: + valid: set = self.valid_workers(ts) + if valid is None or ws in valid: + recommendations[ts._key] = "waiting" if recommendations: parent._transitions(recommendations, client_msgs, worker_msgs) - recommendations = {} self.send_all(client_msgs, worker_msgs) @@ -4896,6 +4922,7 @@ async def remove_worker(self, comm=None, address=None, safe=False, close=True): parent._saturated.discard(ws) del parent._workers[address] ws.status = Status.closed + parent._running.discard(ws) parent._total_occupancy -= ws._occupancy recommendations: dict = {} @@ -5143,6 +5170,11 @@ def validate_state(self, allow_overlap=False): if not ws._processing: assert not ws._occupancy assert ws._address in parent._idle_dv + assert (ws._status == Status.running) == (ws in parent._running) + + for ws in parent._running: + assert ws._status == Status.running + assert ws._address in parent._workers_dv ts: TaskState for k, ts in parent._tasks.items(): @@ -5423,20 +5455,42 @@ def handle_long_running(self, key=None, worker=None, compute_duration=None): ws._processing[ts] = 0 self.check_idle_saturated(ws) - def handle_worker_status_change(self, status: str, worker: str): + def handle_worker_status_change(self, status: str, worker: str) -> None: parent: SchedulerState = cast(SchedulerState, self) ws: WorkerState = parent._workers_dv.get(worker) # type: ignore if not ws: return + prev_status = ws._status + ws._status = Status.lookup[status] # type: ignore + if ws._status == prev_status: + return + self.log_event( ws._address, { "action": "worker-status-change", - "prev-status": ws._status.name, + "prev-status": prev_status.name, "status": status, }, ) - ws._status = Status.lookup[status] # type: ignore + + if ws._status == Status.running: + parent._running.add(ws) + + recs = {} + ts: TaskState + for ts in parent._unrunnable: + valid: set = self.valid_workers(ts) + if valid is None or ws in valid: + recs[ts._key] = "waiting" + if recs: + client_msgs: dict = {} + worker_msgs: dict = {} + parent._transitions(recs, client_msgs, worker_msgs) + self.send_all(client_msgs, worker_msgs) + + else: + parent._running.discard(ws) async def handle_worker(self, comm=None, worker=None): """ diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 982212bf74..2c4a3130f3 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -3650,51 +3650,46 @@ def test_reconnect(loop): "9393", "--no-dashboard", ] - with popen(scheduler_cli) as s: + with popen(scheduler_cli): c = Client("127.0.0.1:9393", loop=loop) - start = time() - while len(c.nthreads()) != 1: - sleep(0.1) - assert time() < start + 3 - + c.wait_for_workers(1, timeout=10) x = c.submit(inc, 1) - assert x.result() == 2 + assert x.result(timeout=10) == 2 start = time() while c.status != "connecting": - assert time() < start + 5 + assert 
time() < start + 10 sleep(0.01) assert x.status == "cancelled" with pytest.raises(CancelledError): - x.result() + x.result(timeout=10) - with popen(scheduler_cli) as s: + with popen(scheduler_cli): start = time() while c.status != "running": sleep(0.1) - assert time() < start + 5 + assert time() < start + 10 start = time() while len(c.nthreads()) != 1: sleep(0.05) - assert time() < start + 15 + assert time() < start + 10 x = c.submit(inc, 1) - assert x.result() == 2 + assert x.result(timeout=10) == 2 start = time() while True: + assert time() < start + 10 try: - x.result() + x.result(timeout=10) assert False except CommClosedError: continue except CancelledError: break - assert time() < start + 5 - sleep(0.1) - sync(loop, w.close) + sync(loop, w.close, timeout=1) c.close() diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index e08720ec36..affa58702c 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -1101,24 +1101,23 @@ async def test_worker_breaks_and_returns(c, s, a): @gen_cluster(client=True, nthreads=[]) async def test_no_workers_to_memory(c, s): - x = delayed(slowinc)(1, delay=0.4) + x = delayed(slowinc)(1, delay=10.0) y = delayed(slowinc)(x, delay=0.4) z = delayed(slowinc)(y, delay=0.4) yy, zz = c.persist([y, z]) - while not s.tasks: + while len(s.tasks) < 3: await asyncio.sleep(0.01) w = Worker(s.address, nthreads=1) w.update_data(data={y.key: 3}) - await w - start = time() - - while not s.workers: + await w + while not s.workers or s.workers[w.address].status != Status.running: await asyncio.sleep(0.01) + assert time() < start + 9 # Did not wait for x assert s.get_task_status(keys={x.key, y.key, z.key}) == { x.key: "released", @@ -1289,7 +1288,11 @@ async def test_scheduler_file(): async def test_non_existent_worker(c, s): with dask.config.set({"distributed.comm.timeouts.connect": "100ms"}): await s.add_worker( - address="127.0.0.1:5738", nthreads=2, nbytes={}, host_info={} + address="127.0.0.1:5738", + status="running", + nthreads=2, + nbytes={}, + host_info={}, ) futures = c.map(inc, range(10)) await asyncio.sleep(0.300) @@ -1929,7 +1932,7 @@ async def test_default_task_duration_splits(c, s, a, b): @gen_test() -async def test_no_danglng_asyncio_tasks(): +async def test_no_dangling_asyncio_tasks(): start = asyncio.all_tasks() async with Scheduler(dashboard_address=":0") as s: async with Worker(s.address, name="0"): @@ -3160,9 +3163,32 @@ async def test_worker_heartbeat_after_cancel(c, s, *workers): await asyncio.gather(*(w.heartbeat() for w in workers)) +@gen_cluster(client=True, nthreads=[("", 1)]) +async def test_worker_reconnect_task_memory(c, s, a): + a.periodic_callbacks["heartbeat"].stop() + + futs = c.map(inc, range(10)) + res = c.submit(sum, futs) + + while not a.executing_count and not a.data: + await asyncio.sleep(0.001) + + await s.remove_worker(address=a.address, close=False) + while not res.done(): + await a.heartbeat() + + await res + assert ("no-worker", "memory") in { + (start, finish) for (_, start, finish, _, _) in s.transition_log + } + + @gen_cluster(client=True, nthreads=[("", 1)]) async def test_worker_reconnect_task_memory_with_resources(c, s, a): async with Worker(s.address, resources={"A": 1}) as b: + while s.workers[b.address].status != Status.running: + await asyncio.sleep(0.001) + b.periodic_callbacks["heartbeat"].stop() futs = c.map(inc, range(10), resources={"A": 1}) @@ -3190,3 +3216,28 @@ async def test_set_restrictions(c, s, a, b): assert 
s.tasks[f.key].worker_restrictions == {a.address} s.reschedule(f) await f + + +@gen_cluster(client=True, nthreads=[("", 1)] * 3) +async def test_avoid_paused_workers(c, s, w1, w2, w3): + w2.memory_pause_fraction = 1e-15 + while s.workers[w2.address].status != Status.paused: + await asyncio.sleep(0.01) + futures = c.map(slowinc, range(8), delay=0.1) + while (len(w1.tasks), len(w2.tasks), len(w3.tasks)) != (4, 0, 4): + await asyncio.sleep(0.01) + + +@gen_cluster(client=True, nthreads=[("", 1)]) +async def test_unpause_schedules_unrannable_tasks(c, s, a): + a.memory_pause_fraction = 1e-15 + while s.workers[a.address].status != Status.paused: + await asyncio.sleep(0.01) + + fut = c.submit(inc, 1, key="x") + while not s.unrunnable: + await asyncio.sleep(0.001) + assert next(iter(s.unrunnable)).key == "x" + + a.memory_pause_fraction = 0.8 + assert await fut == 2 diff --git a/distributed/utils_test.py b/distributed/utils_test.py index 3d4a55cbef..c622a2faff 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -71,7 +71,7 @@ sync, thread_state, ) -from .worker import Worker +from .worker import RUNNING, Worker try: import dask.array # register config @@ -837,8 +837,10 @@ async def start_cluster( await asyncio.gather(*workers) start = time() - while len(s.workers) < len(nthreads) or any( - comm.comm is None for comm in s.stream_comms.values() + while ( + len(s.workers) < len(nthreads) + or any(ws.status != Status.running for ws in s.workers.values()) + or any(comm.comm is None for comm in s.stream_comms.values()) ): await asyncio.sleep(0.01) if time() > start + 30: @@ -1557,7 +1559,7 @@ def check_instances(): for w in Worker._instances: with suppress(RuntimeError): # closed IOLoop w.loop.add_callback(w.close, report=False, executor_wait=False) - if w.status in (Status.running, Status.paused): + if w.status in RUNNING: w.loop.add_callback(w.close) Worker._instances.clear() diff --git a/distributed/widgets/templates/worker_state.html.j2 b/distributed/widgets/templates/worker_state.html.j2 index 2646d0fa26..cd152080bf 100644 --- a/distributed/widgets/templates/worker_state.html.j2 +++ b/distributed/widgets/templates/worker_state.html.j2 @@ -1,4 +1,5 @@ WorkerState: {{ address | html_escape }} name: {{ name }} + status: {{ status }} memory: {{ has_what | length }} processing: {{ processing | length }} diff --git a/distributed/worker.py b/distributed/worker.py index 9fca11e9f8..8e511d6ce2 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -97,6 +97,7 @@ no_value = "--no-value-sentinel--" +# TaskState.status subsets PROCESSING = { "waiting", "ready", @@ -108,6 +109,8 @@ } READY = {"ready", "constrained"} +# Worker.status subsets +RUNNING = {Status.running, Status.paused, Status.closing_gracefully} DEFAULT_EXTENSIONS: list[type] = [PubSubWorkerExtension] @@ -1024,6 +1027,9 @@ def executor(self): def status(self, value): """Override Server.status to notify the Scheduler of status changes""" ServerNode.status.__set__(self, value) + self._send_worker_status_change() + + def _send_worker_status_change(self) -> None: if ( self.batched_stream and self.batched_stream.comm @@ -1032,6 +1038,8 @@ def status(self, value): self.batched_stream.send( {"op": "worker-status-change", "status": self._status.name} ) + elif self._status != Status.closed: + self.loop.call_later(0.05, self._send_worker_status_change) async def get_metrics(self): out = dict( @@ -1105,6 +1113,7 @@ async def _register_with_scheduler(self): op="register-worker", reply=False, address=self.contact_address, + 
status=self.status.name, keys=list(self.data), nthreads=self.nthreads, name=self.name, @@ -1195,14 +1204,15 @@ async def heartbeat(self): self._update_latency(end - start) if response["status"] == "missing": - for i in range(10): - if self.status not in (Status.running, Status.paused): - break - else: - await asyncio.sleep(0.05) - else: - await self._register_with_scheduler() + # If running, wait 0.5s and then re-register self. Otherwise just exit. + start = time() + while self.status in RUNNING: + if time() >= start + 0.5: + await self._register_with_scheduler() + return + await asyncio.sleep(0.01) return + self.scheduler_delay = response["time"] - middle self.periodic_callbacks["heartbeat"].callback_time = ( response["heartbeat-interval"] * 1000 @@ -1231,7 +1241,7 @@ async def handle_scheduler(self, comm): logger.exception(e) raise finally: - if self.reconnect and self.status in (Status.running, Status.paused): + if self.reconnect and self.status in RUNNING: logger.info("Connection to scheduler broken. Reconnecting...") self.loop.add_callback(self.heartbeat) else: @@ -1443,11 +1453,7 @@ async def close( logger.info("Stopping worker at %s", self.address) except ValueError: # address not available if already closed logger.info("Stopping worker") - if self.status not in ( - Status.running, - Status.paused, - Status.closing_gracefully, - ): + if self.status not in RUNNING: logger.info("Closed worker has not yet started: %s", self.status) self.status = Status.closing @@ -1475,9 +1481,7 @@ async def close( # If this worker is the last one alive, clean up the worker # initialized clients if not any( - w - for w in Worker._instances - if w != self and w.status in (Status.running, Status.paused) + w for w in Worker._instances if w != self and w.status in RUNNING ): for c in Worker._initialized_clients: # Regardless of what the client was initialized with @@ -2607,7 +2611,7 @@ async def gather_dep( Total number of bytes for all the dependencies in to_gather combined """ cause: TaskState | None = None - if self.status not in (Status.running, Status.paused): + if self.status not in RUNNING: return with log_errors(): @@ -3619,7 +3623,7 @@ def validate_task(self, ts): raise def validate_state(self): - if self.status not in (Status.running, Status.paused): + if self.status not in RUNNING: return try: assert self.executing_count >= 0 @@ -3786,11 +3790,7 @@ def get_worker() -> Worker: return thread_state.execution_state["worker"] except AttributeError: try: - return first( - w - for w in Worker._instances - if w.status in (Status.running, Status.paused) - ) + return first(w for w in Worker._instances if w.status in RUNNING) except StopIteration: raise ValueError("No workers found")