diff --git a/distributed/active_memory_manager.py b/distributed/active_memory_manager.py
index bbbcd8c0e2a..f11aa357123 100644
--- a/distributed/active_memory_manager.py
+++ b/distributed/active_memory_manager.py
@@ -312,8 +312,7 @@ def log_reject(msg: str) -> None:
         # THEN temporarily keep the extra replica on the candidate with status=running.
         #
         # This prevents a ping-pong effect between ReduceReplicas (or any other policy
-        # that yields drop commands with multiple candidates) and RetireWorker
-        # (to be later introduced by https://github.com/dask/distributed/pull/5381):
+        # that yields drop commands with multiple candidates) and RetireWorker:
         # 1. RetireWorker replicates in-memory tasks from worker A (very busy and being
         #    retired) to worker B (idle)
         # 2. on the next AMM iteration 2 seconds later, ReduceReplicas drops the same
@@ -494,3 +493,149 @@ def run(self):
                 ndrop,
                 nkeys,
             )
+
+
+class RetireWorker(ActiveMemoryManagerPolicy):
+    """Replicate somewhere else all unique in-memory tasks on a worker, preparing for
+    its shutdown.
+
+    At any given time, the AMM may have registered multiple instances of this policy,
+    one for each worker currently being retired - meaning that most of the time no
+    instances will be registered at all. For this reason, this policy doesn't figure in
+    the dask config (:file:`distributed.yaml`). Instances are added by
+    :meth:`distributed.Scheduler.retire_workers` and automatically remove themselves
+    once the worker has been retired. If the AMM is disabled in the dask config,
+    :meth:`~distributed.Scheduler.retire_workers` will start a temporary ad-hoc one.
+
+    **Failure condition**
+
+    There may not be any suitable workers to receive the tasks from the retiring worker.
+    This happens in two use cases:
+
+    1. This is the only worker in the cluster, or
+    2. All workers are either paused or being retired at the same time
+
+    In either case, this policy will fail to move out all keys and
+    Scheduler.retire_workers will abort the retirement. The flag ``no_recipients`` will
+    be raised.
+
+    There is a third use case, where a task fails to be replicated away for whatever
+    reason; in this case we'll just wait for the next AMM iteration and try again
+    (possibly with a different receiving worker, e.g. if the receiving worker was
+    hung but not yet declared dead).
+
+    **Retiring a worker with spilled tasks**
+
+    On its very first iteration, this policy suggests that other workers should fetch
+    all unique in-memory tasks. Frequently, this means that in the next few moments the
+    worker to be retired will be bombarded by :meth:`distributed.worker.Worker.get_data`
+    calls from the rest of the cluster. This can be a problem if most of the managed
+    memory of the worker has been spilled out, as it could send the worker above the
+    terminate threshold. Two measures are in place in order to prevent this:
+
+    - At every iteration, this policy drops all tasks that have already been replicated
+      somewhere else. This makes room for further tasks to be moved out of the spill
+      file in order to be replicated onto another worker.
+    - Once a worker passes the ``pause`` threshold,
+      :meth:`~distributed.worker.Worker.get_data` throttles the number of outgoing
+      connections to 1.
+
+    Parameters
+    ==========
+    address: str
+        URI of the worker to be retired
+    """
+
+    address: str
+    no_recipients: bool
+
+    def __init__(self, address: str):
+        self.address = address
+        self.no_recipients = False
+
+    def __repr__(self) -> str:
+        return f"RetireWorker({self.address})"
+
+    def run(self):
+        """"""
+        ws = self.manager.scheduler.workers.get(self.address)
+        if ws is None:
+            logger.debug("Removing policy %s: Worker no longer in cluster", self)
+            self.manager.policies.remove(self)
+            return
+
+        nrepl = 0
+        nno_rec = 0
+
+        logger.debug("Retiring %s", ws)
+        for ts in ws.has_what:
+            if len(ts.who_has) > 1:
+                # There are already replicas of this key on other workers.
+                # Suggest dropping the replica from this worker.
+                # Use cases:
+                # 1. The suggestion is accepted by the AMM and by the Worker.
+                #    The replica on this worker is dropped.
+                # 2. The suggestion is accepted by the AMM, but rejected by the Worker.
+                #    We'll try again at the next AMM iteration.
+                # 3. The suggestion is rejected by the AMM, because another policy
+                #    (e.g. ReduceReplicas) already suggested the same for this worker
+                # 4. The suggestion is rejected by the AMM, because the task has
+                #    dependents queued or running on the same worker.
+                #    We'll try again at the next AMM iteration.
+                # 5. The suggestion is rejected by the AMM, because all replicas of the
+                #    key are on workers being retired and the other RetireWorker
+                #    instances already made the same suggestion. We need to deal with
+                #    this case and create a replica elsewhere.
+                drop_ws = (yield "drop", ts, {ws})
+                if drop_ws:
+                    continue  # Use case 1 or 2
+                if ts.who_has & self.manager.scheduler.running:
+                    continue  # Use case 3 or 4
+                # Use case 5
+
+            # Either the worker holds the only replica or all replicas are being held
+            # by workers that are being retired
+            nrepl += 1
+            # Don't create an unnecessary additional replica if another policy already
+            # asked for one
+            try:
+                has_pending_repl = bool(self.manager.pending[ts][0])
+            except KeyError:
+                has_pending_repl = False
+
+            if not has_pending_repl:
+                rec_ws = (yield "replicate", ts, None)
+                if not rec_ws:
+                    # replication was rejected by the AMM (see _find_recipient)
+                    nno_rec += 1
+
+        if nno_rec:
+            # All workers are paused or closing_gracefully.
+            # Scheduler.retire_workers will read this flag and exit immediately.
+            # TODO after we implement the automatic transition of workers from paused
+            #      to closing_gracefully after a timeout expires, we should revisit this
+            #      code to wait for paused workers and only exit immediately if all
+            #      workers are in closing_gracefully status.
+            self.no_recipients = True
+            logger.warning(
+                f"Tried retiring worker {self.address}, but {nno_rec} tasks could not "
+                "be moved as there are no suitable workers to receive them. "
+                "The worker will not be retired."
+            )
+            self.manager.policies.remove(self)
+        elif nrepl:
+            logger.info(
+                f"Retiring worker {self.address}; {nrepl} keys are being moved away.",
+            )
+        else:
+            logger.info(
+                f"Retiring worker {self.address}; no unique keys need to be moved away."
+            )
+            self.manager.policies.remove(self)
+
+    def done(self) -> bool:
+        """Return True if it is safe to close the worker down; False otherwise"""
+        ws = self.manager.scheduler.workers.get(self.address)
+        if ws is None:
+            return True
+        return all(len(ts.who_has) > 1 for ts in ws.has_what)
diff --git a/distributed/core.py b/distributed/core.py
index ef0b139811a..de7096c6615 100644
--- a/distributed/core.py
+++ b/distributed/core.py
@@ -69,6 +69,11 @@ class Status(Enum):


 Status.lookup = {s.name: s for s in Status}  # type: ignore
+Status.ANY_RUNNING = {  # type: ignore
+    Status.running,
+    Status.paused,
+    Status.closing_gracefully,
+}


 class RPCClosed(IOError):
@@ -257,7 +262,7 @@ def __await__(self):
         async def _():
             timeout = getattr(self, "death_timeout", 0)
             async with self._startup_lock:
-                if self.status in (Status.running, Status.paused):
+                if self.status in Status.ANY_RUNNING:
                     return self
                 if timeout:
                     try:
@@ -519,7 +524,7 @@ async def handle_comm(self, comm):
                         self._ongoing_coroutines.add(result)
                         result = await result
                 except (CommClosedError, asyncio.CancelledError):
-                    if self.status in (Status.running, Status.paused):
+                    if self.status in Status.ANY_RUNNING:
                         logger.info("Lost connection to %r", address, exc_info=True)
                     break
                 except Exception as e:
diff --git a/distributed/scheduler.py b/distributed/scheduler.py
index 9f83be5051f..1fb3363ae61 100644
--- a/distributed/scheduler.py
+++ b/distributed/scheduler.py
@@ -56,7 +56,7 @@

 from . import preloading, profile
 from . import versions as version_module
-from .active_memory_manager import ActiveMemoryManagerExtension
+from .active_memory_manager import ActiveMemoryManagerExtension, RetireWorker
 from .batched import BatchedSend
 from .comm import (
     Comm,
@@ -6805,11 +6805,11 @@ def _key(group):
     async def retire_workers(
         self,
         comm=None,
-        workers=None,
-        remove=True,
-        close_workers=False,
-        names=None,
-        lock=True,
+        *,
+        workers: "list[str] | None" = None,
+        names: "list[str] | None" = None,
+        close_workers: bool = False,
+        remove: bool = True,
         **kwargs,
     ) -> dict:
         """Gracefully retire workers from cluster
@@ -6818,16 +6818,18 @@
         Parameters
         ----------
         workers: list (optional)
             List of worker addresses to retire.
-            If not provided we call ``workers_to_close`` which finds a good set
         names: list (optional)
             List of worker names to retire.
-        remove: bool (defaults to True)
-            Whether or not to remove the worker metadata immediately or else
-            wait for the worker to contact us
+            Mutually exclusive with ``workers``.
+            If neither ``workers`` nor ``names`` are provided, we call
+            ``workers_to_close`` which finds a good set.
         close_workers: bool (defaults to False)
             Whether or not to actually close the worker explicitly from here.
             Otherwise we expect some external job scheduler to finish off the
             worker.
+        remove: bool (defaults to True)
+            Whether or not to remove the worker metadata immediately or else
+            wait for the worker to contact us
         **kwargs: dict
             Extra options to pass to workers_to_close to determine which
             workers we should drop
@@ -6845,78 +6847,123 @@
         ws: WorkerState
         ts: TaskState
         with log_errors():
-            async with self._lock if lock else empty_context:
+            # This lock makes retire_workers, rebalance, and replicate mutually
+            # exclusive and will no longer be necessary once rebalance and replicate are
+            # migrated to the Active Memory Manager.
+            # Note that, incidentally, it also prevents multiple calls to retire_workers
+            # from running in parallel - this is unnecessary.
+            async with self._lock:
                 if names is not None:
                     if workers is not None:
                         raise TypeError("names and workers are mutually exclusive")
                     if names:
                         logger.info("Retire worker names %s", names)
-                    names = set(map(str, names))
-                    workers = {
-                        ws._address
+                    names_set = {str(name) for name in names}
+                    wss = {
+                        ws
                         for ws in parent._workers_dv.values()
-                        if str(ws._name) in names
+                        if str(ws._name) in names_set
                     }
-                elif workers is None:
-                    while True:
-                        try:
-                            workers = self.workers_to_close(**kwargs)
-                            if not workers:
-                                return {}
-                            return await self.retire_workers(
-                                workers=workers,
-                                remove=remove,
+                elif workers is not None:
+                    wss = {
+                        parent._workers_dv[address]
+                        for address in workers
+                        if address in parent._workers_dv
+                    }
+                else:
+                    wss = {
+                        parent._workers_dv[address]
+                        for address in self.workers_to_close(**kwargs)
+                    }
+                if not wss:
+                    return {}
+
+                stop_amm = False
+                amm: ActiveMemoryManagerExtension = self.extensions["amm"]
+                if not amm.running:
+                    amm = ActiveMemoryManagerExtension(
+                        self, policies=set(), register=False, start=True, interval=2.0
+                    )
+                    stop_amm = True
+
+                try:
+                    coros = []
+                    for ws in wss:
+                        logger.info("Retiring worker %s", ws._address)
+
+                        policy = RetireWorker(ws._address)
+                        amm.add_policy(policy)
+
+                        # Change Worker.status to closing_gracefully. Immediately set
+                        # the same on the scheduler to prevent race conditions.
+                        prev_status = ws.status
+                        ws.status = Status.closing_gracefully
+                        self.running.discard(ws)
+                        self.stream_comms[ws.address].send(
+                            {"op": "worker-status-change", "status": ws.status.name}
+                        )
+
+                        coros.append(
+                            self._retire_worker(
+                                ws,
+                                policy,
+                                prev_status=prev_status,
                                 close_workers=close_workers,
-                                lock=False,
+                                remove=remove,
                             )
-                        except KeyError:  # keys left during replicate
-                            pass
+                        )
-                workers = {
-                    parent._workers_dv[w] for w in workers if w in parent._workers_dv
-                }
-                if not workers:
-                    return {}
-                logger.info("Retire workers %s", workers)
-
-                # Keys orphaned by retiring those workers
-                keys = {k for w in workers for k in w.has_what}
-                keys = {ts._key for ts in keys if ts._who_has.issubset(workers)}
-
-                if keys:
-                    other_workers = set(parent._workers_dv.values()) - workers
-                    if not other_workers:
-                        return {}
-                    logger.info("Moving %d keys to other workers", len(keys))
-                    await self.replicate(
-                        keys=keys,
-                        workers=[ws._address for ws in other_workers],
-                        n=1,
-                        delete=False,
-                        lock=False,
-                    )
+                    # Give the AMM a kick, in addition to its periodic running. This is
+                    # to avoid unnecessarily waiting for a potentially arbitrarily long
+                    # time (depending on interval settings)
+                    amm.run_once()
-                worker_keys = {ws._address: ws.identity() for ws in workers}
-                if close_workers:
-                    await asyncio.gather(
-                        *[self.close_worker(worker=w, safe=True) for w in worker_keys]
-                    )
-                if remove:
-                    await asyncio.gather(
-                        *[self.remove_worker(address=w, safe=True) for w in worker_keys]
-                    )
+                    workers_info = dict(await asyncio.gather(*coros))
+                    workers_info.pop(None, None)
+                finally:
+                    if stop_amm:
+                        amm.stop()
-                self.log_event(
-                    "all",
-                    {
-                        "action": "retire-workers",
-                        "workers": worker_keys,
-                        "moved-keys": len(keys),
-                    },
+            self.log_event("all", {"action": "retire-workers", "workers": workers_info})
+            self.log_event(list(workers_info), {"action": "retired"})
+
+            return workers_info
+
+    async def _retire_worker(
+        self,
+        ws: WorkerState,
+        policy: RetireWorker,
+        prev_status: Status,
+        close_workers: bool,
+        remove: bool,
+    ) -> tuple:  # tuple[str | None, dict]
+        parent: SchedulerState = cast(SchedulerState, self)
+
+        while not policy.done():
+            if policy.no_recipients:
+                # Abort retirement. This time we don't need to worry about race
+                # conditions and we can wait for a round-trip.
+                self.stream_comms[ws.address].send(
+                    {"op": "worker-status-change", "status": prev_status.name}
                 )
-                self.log_event(list(worker_keys), {"action": "retired"})
+                return None, {}
+
+            # Sleep 0.01s when there are 4 tasks or less
+            # Sleep 0.5s when there are 200 or more
+            poll_interval = max(0.01, min(0.5, len(ws.has_what) / 400))
+            await asyncio.sleep(poll_interval)
+
+        logger.debug(
+            "All unique keys on worker %s have been replicated elsewhere", ws._address
+        )
+
+        if close_workers and ws._address in parent._workers_dv:
+            await self.close_worker(worker=ws._address, safe=True)
+        if remove:
+            await self.remove_worker(address=ws._address, safe=True)

-        return worker_keys
+        logger.info("Retired worker %s", ws._address)
+        return ws._address, ws.identity()

     def add_keys(self, comm=None, worker=None, keys=(), stimulus_id=None):
         """
diff --git a/distributed/tests/test_active_memory_manager.py b/distributed/tests/test_active_memory_manager.py
index 9f2e9554b17..cc7dcbd88b7 100644
--- a/distributed/tests/test_active_memory_manager.py
+++ b/distributed/tests/test_active_memory_manager.py
@@ -4,10 +4,11 @@
 import logging
 import random
 from contextlib import contextmanager
+from time import sleep

 import pytest

-from distributed import Nanny
+from distributed import Nanny, wait
 from distributed.active_memory_manager import (
     ActiveMemoryManagerExtension,
     ActiveMemoryManagerPolicy,
@@ -722,6 +723,218 @@ async def test_ReduceReplicas(c, s, *workers):
         await asyncio.sleep(0.01)


+@pytest.mark.parametrize("start_amm", [False, True])
+@gen_cluster(client=True)
+async def test_RetireWorker_amm_on_off(c, s, a, b, start_amm):
+    """retire_workers must work both with and without the AMM started"""
+    if start_amm:
+        await c.amm.start()
+    else:
+        await c.amm.stop()
+
+    futures = await c.scatter({"x": 1}, workers=[a.address])
+    await c.retire_workers([a.address])
+    assert a.address not in s.workers
+    assert "x" in b.data
+
+
+@gen_cluster(
+    client=True,
+    config={
+        "distributed.scheduler.active-memory-manager.start": True,
+        "distributed.scheduler.active-memory-manager.interval": 0.1,
+        "distributed.scheduler.active-memory-manager.policies": [],
+    },
+)
+async def test_RetireWorker_no_remove(c, s, a, b):
+    """Test RetireWorker behaviour on retire_workers(..., remove=False)"""
+
+    x = await c.scatter({"x": "x"}, workers=[a.address])
+    await c.retire_workers([a.address], close_workers=False, remove=False)
+    # Wait 2 AMM iterations
+    # retire_workers may return before all keys have been dropped from a
+    while s.tasks["x"].who_has != {s.workers[b.address]}:
+        await asyncio.sleep(0.01)
+    assert a.address in s.workers
+    # Policy has been removed without waiting for worker to disappear from
+    # Scheduler.workers
+    assert not s.extensions["amm"].policies
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize("use_ReduceReplicas", [False, True])
+@gen_cluster(
+    client=True,
+    Worker=Nanny,
+    config={
+        "distributed.scheduler.active-memory-manager.start": True,
+        "distributed.scheduler.active-memory-manager.interval": 0.1,
+        "distributed.scheduler.active-memory-manager.policies": [
+            {"class": "distributed.active_memory_manager.ReduceReplicas"},
+        ],
+    },
+)
+async def test_RetireWorker_with_ReduceReplicas(c, s, *nannies, use_ReduceReplicas):
+    """RetireWorker and ReduceReplicas work well with each other.
+
+    If ReduceReplicas is enabled,
+    1. On the first AMM iteration, either ReduceReplicas or RetireWorker (arbitrarily
+       depending on which comes first in the iteration of
+       ActiveMemoryManagerExtension.policies) deletes non-unique keys, choosing from
+       workers to be retired first. At the same time, RetireWorker replicates unique
+       keys.
+    2. On the second AMM iteration, either ReduceReplicas or RetireWorker deletes the
+       keys replicated at the previous round from the worker to be retired.
+
+    If ReduceReplicas is not enabled, all drops are performed by RetireWorker.
+
+    This test fundamentally relies on workers in the process of being retired to be
+    always picked first by ActiveMemoryManagerExtension._find_dropper.
+    """
+    ws_a, ws_b = s.workers.values()
+    if not use_ReduceReplicas:
+        s.extensions["amm"].policies.clear()
+
+    x = c.submit(lambda: "x" * 2 ** 26, key="x", workers=[ws_a.address])  # 64 MiB
+    y = c.submit(lambda: "y" * 2 ** 26, key="y", workers=[ws_a.address])  # 64 MiB
+    z = c.submit(lambda x: None, x, key="z", workers=[ws_b.address])  # copy x to ws_b
+    # Make sure that the worker NOT being retired has the most RAM usage to test that
+    # it is not being picked first since there's a retiring worker.
+    w = c.submit(lambda: "w" * 2 ** 28, key="w", workers=[ws_b.address])  # 256 MiB
+    await wait([x, y, z, w])
+
+    await c.retire_workers([ws_a.address], remove=False)
+    # retire_workers may return before all keys have been dropped from a
+    while ws_a.has_what:
+        await asyncio.sleep(0.01)
+    assert {ts.key for ts in ws_b.has_what} == {"x", "y", "z", "w"}
+
+
+@gen_cluster(client=True, nthreads=[("", 1)] * 3, config=NO_AMM_START)
+async def test_RetireWorker_all_replicas_are_being_retired(c, s, w1, w2, w3):
+    """There are multiple replicas of a key, but they all reside on workers that are
+    being retired
+    """
+    ws1 = s.workers[w1.address]
+    ws2 = s.workers[w2.address]
+    ws3 = s.workers[w3.address]
+    fut = await c.scatter({"x": "x"}, workers=[w1.address, w2.address], broadcast=True)
+    assert s.tasks["x"].who_has == {ws1, ws2}
+    await c.retire_workers([w1.address, w2.address])
+    assert s.tasks["x"].who_has == {ws3}
+
+
+@gen_cluster(
+    client=True,
+    nthreads=[("", 1)] * 4,
+    config={
+        "distributed.scheduler.active-memory-manager.start": True,
+        # test that we're having a manual amm.run_once() "kick" from retire_workers
+        "distributed.scheduler.active-memory-manager.interval": 999,
+        "distributed.scheduler.active-memory-manager.policies": [],
+    },
+)
+async def test_RetireWorker_no_recipients(c, s, w1, w2, w3, w4):
+    """All workers are retired at once.
+
+    Test use cases:
+    1. (w1) worker contains no data -> it is retired
+    2. (w2) worker contains unique data -> it is not retired
+    3. (w3, w4) worker contains non-unique data, but all replicas are on workers that
+       are being retired -> all but one are retired
+    """
+    x = await c.scatter({"x": "x"}, workers=[w2.address])
+    y = await c.scatter({"y": "y"}, workers=[w3.address, w4.address], broadcast=True)
+
+    out = await c.retire_workers([w1.address, w2.address, w3.address, w4.address])
+
+    assert set(out) in ({w1.address, w3.address}, {w1.address, w4.address})
+    assert not s.extensions["amm"].policies
+    assert set(s.workers) in ({w2.address, w3.address}, {w2.address, w4.address})
+    # After a Scheduler -> Worker -> WorkerState roundtrip, workers that failed to
+    # retire went back from closing_gracefully to running and can run tasks
+    while any(ws.status != Status.running for ws in s.workers.values()):
+        await asyncio.sleep(0.01)
+    assert await c.submit(inc, 1) == 2
+
+
+@gen_cluster(
+    client=True,
+    config={
+        "distributed.scheduler.active-memory-manager.start": True,
+        "distributed.scheduler.active-memory-manager.interval": 999,
+        "distributed.scheduler.active-memory-manager.policies": [],
+    },
+)
+async def test_RetireWorker_all_recipients_are_paused(c, s, a, b):
+    ws_a = s.workers[a.address]
+    ws_b = s.workers[b.address]
+
+    b.memory_pause_fraction = 1e-15
+    while ws_b.status != Status.paused:
+        await asyncio.sleep(0.01)
+
+    x = await c.scatter("x", workers=[a.address])
+    out = await c.retire_workers([a.address])
+    assert out == {}
+    assert not s.extensions["amm"].policies
+    assert set(s.workers) == {a.address, b.address}
+
+    # After a Scheduler -> Worker -> WorkerState roundtrip, workers that failed to
+    # retire went back from closing_gracefully to running and can run tasks
+    while ws_a.status != Status.running:
+        await asyncio.sleep(0.01)
+    assert await c.submit(inc, 1) == 2
+
+
+# FIXME can't drop runtime of this test below 10s; see distributed#5585
+@pytest.mark.slow
+@gen_cluster(
+    client=True,
+    Worker=Nanny,
+    nthreads=[("", 1)] * 3,
+    config={
+        "distributed.scheduler.worker-ttl": "500ms",
+        "distributed.scheduler.active-memory-manager.start": True,
+        "distributed.scheduler.active-memory-manager.interval": 0.1,
+        "distributed.scheduler.active-memory-manager.policies": [],
+    },
+)
+async def test_RetireWorker_faulty_recipient(c, s, *nannies):
+    """RetireWorker requests to replicate a key onto an unresponsive worker.
+    The AMM will iterate multiple times, repeating the command, until eventually the
+    scheduler declares the worker dead and removes it from the pool; at that point the
+    AMM will choose another valid worker and complete the job.
+    """
+    # ws1 is being retired
+    # ws2 has the lowest RAM usage and is chosen as a recipient, but is unresponsive
+    ws1, ws2, ws3 = s.workers.values()
+    f = c.submit(lambda: "x", key="x", workers=[ws1.address])
+    await wait(f)
+    assert s.tasks["x"].who_has == {ws1}
+
+    # Fill ws3 with 200 MB of managed memory
+    # We're using plenty to make sure it's safely more than the unmanaged memory of ws2
+    clutter = c.map(lambda i: "x" * 4_000_000, range(50), workers=[ws3.address])
+    await wait([f] + clutter)
+    while ws3.memory.process < 200_000_000:
+        # Wait for heartbeat
+        await asyncio.sleep(0.01)
+    assert ws2.memory.process < ws3.memory.process
+
+    # Make ws2 unresponsive
+    clog_fut = asyncio.create_task(c.run(sleep, 3600, workers=[ws2.address]))
+    await asyncio.sleep(0.2)
+    assert ws2.address in s.workers
+
+    await c.retire_workers([ws1.address])
+    assert ws1.address not in s.workers
+    # The AMM tried over and over to send the data to ws2, until it was declared dead
+    assert ws2.address not in s.workers
+    assert s.tasks["x"].who_has == {ws3}
+    clog_fut.cancel()
+
+
 class DropEverything(ActiveMemoryManagerPolicy):
     """Inanely suggest to drop every single key in the cluster"""
@@ -796,3 +1009,48 @@ async def test_ReduceReplicas_stress(c, s, *nannies):
     policy must not disrupt the computation too much.
     """
     await tensordot_stress(c)
+
+
+# @pytest.mark.slow
+@pytest.mark.avoid_ci(reason="distributed#5371")
+@pytest.mark.parametrize("use_ReduceReplicas", [False, True])
+@gen_cluster(
+    client=True,
+    nthreads=[("", 1)] * 10,
+    Worker=Nanny,
+    config={
+        "distributed.scheduler.active-memory-manager.start": True,
+        # If interval is too low, then the AMM will rerun while tasks have not yet had
+        # the time to migrate. This is OK if it happens occasionally, but if this
+        # setting is too aggressive the cluster will get flooded with repeated comm
+        # requests.
+        "distributed.scheduler.active-memory-manager.interval": 2.0,
+        "distributed.scheduler.active-memory-manager.policies": [
+            {"class": "distributed.active_memory_manager.ReduceReplicas"},
+        ],
+    },
+)
+async def test_RetireWorker_stress(c, s, *nannies, use_ReduceReplicas):
+    """It is safe to retire the best part of a cluster in the middle of a computation"""
+    if not use_ReduceReplicas:
+        s.extensions["amm"].policies.clear()
+
+    addrs = list(s.workers)
+    random.shuffle(addrs)
+    print(f"Removing all workers except {addrs[-1]}")
+
+    # Note: Scheduler._lock effectively prevents multiple calls to retire_workers from
+    # running at the same time. However, the lock only exists for the benefit of legacy
+    # (non-AMM) rebalance() and replicate() methods. Once the lock is removed, these
+    # calls will become parallel and the test *should* continue working.
+
+    tasks = [asyncio.create_task(tensordot_stress(c))]
+    await asyncio.sleep(1)
+    tasks.append(asyncio.create_task(c.retire_workers(addrs[0:2])))
+    await asyncio.sleep(1)
+    tasks.append(asyncio.create_task(c.retire_workers(addrs[2:5])))
+    await asyncio.sleep(1)
+    tasks.append(asyncio.create_task(c.retire_workers(addrs[5:9])))
+
+    await asyncio.gather(*tasks)
+    assert set(s.workers) == {addrs[9]}
diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py
index 3d316cee2d9..ea395700491 100644
--- a/distributed/tests/test_client.py
+++ b/distributed/tests/test_client.py
@@ -4365,7 +4365,7 @@ async def test_retire_workers_2(c, s, a, b):
     assert a.address not in s.workers


-@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 10)
+@gen_cluster(client=True, nthreads=[("", 1)] * 10)
 async def test_retire_many_workers(c, s, *workers):
     futures = await c.scatter(list(range(100)))

@@ -4381,8 +4381,16 @@ async def test_retire_many_workers(c, s, *workers):
     assert all(future.done() for future in futures)
     assert all(s.tasks[future.key].state == "memory" for future in futures)

-    for w, keys in s.has_what.items():
-        assert 15 < len(keys) < 50
+    assert await c.gather(futures) == list(range(100))
+
+    # Don't count how many tasks landed on each worker.
+    # Normally, tasks would be distributed evenly over the surviving workers. However,
+    # here all workers share the same process memory, so you'll get an unintuitive
+    # distribution of tasks if for any reason one transfer takes longer than 2 seconds
+    # and as a consequence the Active Memory Manager ends up running for two iterations.
+    # This is something that will happen more frequently on low-powered CI machines.
+    # See test_active_memory_manager.py for tests that robustly verify the statistical
+    # distribution of tasks after worker retirement.


 @gen_cluster(
diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py
index 720e3dc0015..eb343994a55 100644
--- a/distributed/tests/test_worker.py
+++ b/distributed/tests/test_worker.py
@@ -1625,6 +1625,37 @@ async def test_worker_listens_on_same_interface_by_default(cleanup, Worker):
     assert s.ip == w.ip


+def assert_amm_transfer_story(key: str, w_from: Worker, w_to: Worker) -> None:
+    """Test that an in-memory key was transferred from worker w_from to worker w_to by
+    the Active Memory Manager and it was not recalculated on w_to
+    """
+    assert_worker_story(
+        w_to.story(key),
+        [
+            (key, "ensure-task-exists", "released"),
+            (key, "released", "fetch", "fetch", {}),
+            (key, "fetch", "missing", "missing", {}),
+            (key, "missing", "fetch", "fetch", {}),
+            ("gather-dependencies", w_from.address, lambda set_: key in set_),
+            (key, "fetch", "flight", "flight", {}),
+            ("request-dep", w_from.address, lambda set_: key in set_),
+            ("receive-dep", w_from.address, lambda set_: key in set_),
+            (key, "put-in-memory"),
+            (key, "flight", "memory", "memory", {}),
+        ],
+        # There may be additional ('missing', 'fetch', 'fetch') events if transfers
+        # are slow enough that the Active Memory Manager ends up requesting them a
+        # second time. Here we're asserting that no matter how slow CI is, all
+        # transfers will be completed within 2 seconds (hardcoded interval in
+        # Scheduler.retire_worker when AMM is not enabled).
+        strict=True,
+    )
+    assert key in w_to.data
+    # The key may or may not still be in w_from.data, depending on whether the AMM had
+    # the chance to run a second time after the copy was successful.
+
+
+@pytest.mark.slow
 @gen_cluster(client=True)
 async def test_close_gracefully(c, s, a, b):
     futures = c.map(slowinc, range(200), delay=0.1, workers=[b.address])
@@ -1649,15 +1680,7 @@ async def test_close_gracefully(c, s, a, b):
     # All tasks that were in memory in b have been copied over to a;
     # they have not been recomputed
     for key in mem:
-        assert_worker_story(
-            a.story(key),
-            [
-                (key, "put-in-memory"),
-                (key, "receive-from-scatter"),
-            ],
-            strict=True,
-        )
-        assert key in a.data
+        assert_amm_transfer_story(key, b, a)


 @pytest.mark.slow
@@ -1680,15 +1703,7 @@ async def test_lifetime(c, s, a):
     # All tasks that were in memory in b have been copied over to a;
     # they have not been recomputed
     for key in mem:
-        assert_worker_story(
-            a.story(key),
-            [
-                (key, "put-in-memory"),
-                (key, "receive-from-scatter"),
-            ],
-            strict=True,
-        )
-        assert key in a.data
+        assert_amm_transfer_story(key, b, a)


 @gen_cluster(worker_kwargs={"lifetime": "10s", "lifetime_stagger": "2s"})
diff --git a/distributed/utils_test.py b/distributed/utils_test.py
index cc628cbb401..9468b0b023b 100644
--- a/distributed/utils_test.py
+++ b/distributed/utils_test.py
@@ -76,7 +76,7 @@
     reset_logger_locks,
     sync,
 )
-from .worker import RUNNING, Worker
+from .worker import Worker

 try:
     import dask.array  # register config
@@ -1658,7 +1658,7 @@ def check_instances():
     for w in Worker._instances:
         with suppress(RuntimeError):  # closed IOLoop
             w.loop.add_callback(w.close, report=False, executor_wait=False)
-            if w.status in RUNNING:
+            if w.status in Status.ANY_RUNNING:
                 w.loop.add_callback(w.close)
     Worker._instances.clear()

diff --git a/distributed/worker.py b/distributed/worker.py
index 7436f204fca..18330eee127 100644
--- a/distributed/worker.py
+++ b/distributed/worker.py
@@ -112,9 +112,6 @@
 READY = {"ready", "constrained"}
 FETCH_INTENDED = {"missing", "fetch", "flight", "cancelled", "resumed"}

-# Worker.status subsets
-RUNNING = {Status.running, Status.paused, Status.closing_gracefully}
-
 DEFAULT_EXTENSIONS: list[type] = [PubSubWorkerExtension]

 DEFAULT_METRICS: dict[str, Callable[[Worker], Any]] = {}
@@ -928,6 +925,7 @@ def __init__(
             "free-keys": self.handle_free_keys,
             "remove-replicas": self.handle_remove_replicas,
             "steal-request": self.handle_steal_request,
+            "worker-status-change": self.handle_worker_status_change,
         }

         super().__init__(
@@ -1288,9 +1286,9 @@ async def heartbeat(self):
                 # If running, wait up to 0.5s and then re-register self.
                 # Otherwise just exit.
                 start = time()
-                while self.status in RUNNING and time() < start + 0.5:
+                while self.status in Status.ANY_RUNNING and time() < start + 0.5:
                     await asyncio.sleep(0.01)
-                if self.status in RUNNING:
+                if self.status in Status.ANY_RUNNING:
                     await self._register_with_scheduler()

                 return
@@ -1322,7 +1320,7 @@ async def handle_scheduler(self, comm):
             logger.exception(e)
             raise
         finally:
-            if self.reconnect and self.status in RUNNING:
+            if self.reconnect and self.status in Status.ANY_RUNNING:
                 logger.info("Connection to scheduler broken.  Reconnecting...")
                 self.loop.add_callback(self.heartbeat)
             else:
@@ -1547,7 +1545,7 @@ async def close(
             logger.info("Stopping worker at %s", self.address)
         except ValueError:  # address not available if already closed
             logger.info("Stopping worker")
-        if self.status not in RUNNING:
+        if self.status not in Status.ANY_RUNNING:
             logger.info("Closed worker has not yet started: %s", self.status)
         self.status = Status.closing

@@ -1575,7 +1573,9 @@ async def close(
             # If this worker is the last one alive, clean up the worker
             # initialized clients
             if not any(
-                w for w in Worker._instances if w != self and w.status in RUNNING
+                w
+                for w in Worker._instances
+                if w != self and w.status in Status.ANY_RUNNING
             ):
                 for c in Worker._initialized_clients:
                     # Regardless of what the client was initialized with
@@ -1656,8 +1656,12 @@ async def close_gracefully(self, restart=None):
             restart = self.lifetime_restart

         logger.info("Closing worker gracefully: %s", self.address)
-        self.status = Status.closing_gracefully
-        await self.scheduler.retire_workers(workers=[self.address], remove=False)
+        # Wait for all tasks to leave the worker and don't accept any new ones.
+        # Scheduler.retire_workers will set the status to closing_gracefully and push it
+        # back to this worker.
+        await self.scheduler.retire_workers(
+            workers=[self.address], close_workers=False, remove=False
+        )
         await self.close(safe=True, nanny=not restart)

     async def terminate(self, comm=None, report=True, **kwargs):
@@ -2973,7 +2977,7 @@ async def gather_dep(
         total_nbytes : int
             Total number of bytes for all the dependencies in to_gather combined
         """
-        if self.status not in RUNNING:
+        if self.status not in Status.ANY_RUNNING:  # type: ignore
             return

         recommendations: dict[TaskState, str | tuple] = {}
@@ -3185,6 +3189,22 @@ def handle_steal_request(self, key, stimulus_id):
             # `transition_constrained_executing`
             self.transition(ts, "released", stimulus_id=stimulus_id)

+    def handle_worker_status_change(self, status: str) -> None:
+        new_status = Status.lookup[status]  # type: ignore
+
+        if (
+            new_status == Status.closing_gracefully
+            and self._status not in Status.ANY_RUNNING  # type: ignore
+        ):
+            logger.error(
+                "Invalid Worker.status transition: %s -> %s", self._status, new_status
+            )
+            # Reiterate the current status to the scheduler to restore sync
+            self._send_worker_status_change()
+        else:
+            # Update status and send confirmation to the Scheduler (see status.setter)
+            self.status = new_status
+
     def release_key(
         self,
         key: str,
@@ -3397,7 +3417,7 @@ async def _maybe_deserialize_task(self, ts, *, stimulus_id):
             raise

     def ensure_computing(self):
-        if self.status == Status.paused:
+        if self.status in (Status.paused, Status.closing_gracefully):
             return
         try:
             stimulus_id = f"ensure-computing-{time()}"
@@ -3434,7 +3454,7 @@ def ensure_computing(self):
             raise

     async def execute(self, key, *, stimulus_id):
-        if self.status in (Status.closing, Status.closed, Status.closing_gracefully):
+        if self.status in {Status.closing, Status.closed, Status.closing_gracefully}:
             return
         if key not in self.tasks:
             return
@@ -3951,7 +3971,7 @@ def validate_task(self, ts):
             ) from e

     def validate_state(self):
-        if self.status not in RUNNING:
+        if self.status not in Status.ANY_RUNNING:
             return
         try:
             assert self.executing_count >= 0
@@ -4118,7 +4138,11 @@ def get_worker() -> Worker:
         return thread_state.execution_state["worker"]
     except AttributeError:
         try:
-            return first(w for w in Worker._instances if w.status in RUNNING)
+            return first(
+                w
+                for w in Worker._instances
+                if w.status in Status.ANY_RUNNING  # type: ignore
+            )
         except StopIteration:
             raise ValueError("No workers found")
diff --git a/docs/source/active_memory_manager.rst b/docs/source/active_memory_manager.rst
index 8f515204f1a..0ed351ee359 100644
--- a/docs/source/active_memory_manager.rst
+++ b/docs/source/active_memory_manager.rst
@@ -257,3 +257,6 @@ API reference
     :undoc-members:

 .. autoclass:: distributed.active_memory_manager.ReduceReplicas
+
+.. autoclass:: distributed.active_memory_manager.RetireWorker
+    :members:
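A minimal usage sketch of the flow introduced by this patch, not part of the diff itself: it drives the new AMM-backed retirement path from a client, mirroring the tests above. The two-worker LocalCluster setup and variable names are assumptions for illustration only.

```python
# Sketch: retire one worker through the AMM-backed Scheduler.retire_workers path.
# Assumes a local two-worker cluster; any unique in-memory keys on the retired
# worker are replicated to the surviving worker before the worker is closed.
import asyncio

from dask.distributed import Client, LocalCluster


async def main() -> None:
    async with LocalCluster(n_workers=2, asynchronous=True) as cluster:
        async with Client(cluster, asynchronous=True) as client:
            addr = list(client.scheduler_info()["workers"])[0]
            # Place a key on the worker that is about to be retired
            await client.scatter({"x": 1}, workers=[addr])
            # RetireWorker replicates "x" elsewhere, then the worker is closed
            info = await client.retire_workers(workers=[addr], close_workers=True)
            assert addr in info  # identity of the retired worker
            assert addr not in client.scheduler_info()["workers"]


asyncio.run(main())
```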