From 442064475ec7f6ce92597e9f5d40b1031d73286c Mon Sep 17 00:00:00 2001 From: crusaderky Date: Fri, 20 May 2022 14:11:52 +0100 Subject: [PATCH] Do not filter tasks before gathering data (#6371) --- distributed/tests/test_cancelled_state.py | 4 +- .../tests/test_worker_state_machine.py | 48 +++++++- distributed/utils_test.py | 2 +- distributed/worker.py | 109 ++++-------------- 4 files changed, 75 insertions(+), 88 deletions(-) diff --git a/distributed/tests/test_cancelled_state.py b/distributed/tests/test_cancelled_state.py index 5878120370..a4dca9e287 100644 --- a/distributed/tests/test_cancelled_state.py +++ b/distributed/tests/test_cancelled_state.py @@ -321,7 +321,7 @@ async def get_data(self, comm, *args, **kwargs): # The initial free-keys is rejected ("free-keys", (fut1.key,)), (fut1.key, "resumed", "released", "cancelled", {}), - # After gather_dep receives the data, it tries to transition to memory but the task will release instead + # After gather_dep receives the data, the task is forgotten (fut1.key, "cancelled", "memory", "released", {fut1.key: "forgotten"}), ], ) @@ -366,7 +366,7 @@ def block_execution(event, lock): @gen_cluster(client=True, nthreads=[("", 1, {"resources": {"A": 1}})]) -async def test_cancelled_error_with_ressources(c, s, a): +async def test_cancelled_error_with_resources(c, s, a): executing = Event() lock_executing = Lock() await lock_executing.acquire() diff --git a/distributed/tests/test_worker_state_machine.py b/distributed/tests/test_worker_state_machine.py index 9295b9987d..1b6f2e3f80 100644 --- a/distributed/tests/test_worker_state_machine.py +++ b/distributed/tests/test_worker_state_machine.py @@ -5,7 +5,7 @@ from distributed.protocol.serialize import Serialize from distributed.utils import recursive_to_dict -from distributed.utils_test import assert_story, gen_cluster, inc +from distributed.utils_test import _LockedCommPool, assert_story, gen_cluster, inc from distributed.worker_state_machine import ( ExecuteFailureEvent, ExecuteSuccessEvent, @@ -302,3 +302,49 @@ async def test_fetch_via_amm_to_compute(c, s, a, b): b.comm_threshold_bytes = old_comm_threshold await f1 + + +@gen_cluster(client=True) +async def test_cancelled_while_in_flight(c, s, a, b): + event = asyncio.Event() + a.rpc = _LockedCommPool(a.rpc, write_event=event) + + x = c.submit(inc, 1, key="x", workers=[b.address]) + y = c.submit(inc, x, key="y", workers=[a.address]) + await wait_for_state("x", "flight", a) + y.release() + await wait_for_state("x", "cancelled", a) + + # Let the comm from b to a return the result + event.set() + # upon reception, x transitions cancelled->forgotten + while a.tasks: + await asyncio.sleep(0.01) + + +@gen_cluster(client=True) +async def test_in_memory_while_in_flight(c, s, a, b): + """ + 1. A client scatters x to a + 2. The scheduler does not know about scattered keys until the three-way round-trip + between client, worker, and scheduler has been completed (see Scheduler.scatter) + 3. In the middle of that handshake, a client (not necessarily the same client) calls + ``{op: compute-task, key: x}`` on b and then + ``{op: compute-task, key: y, who_has: {x: [b]}`` on a, which triggers a + gather_dep call to copy x key from b to a. + 4. while x is in flight from b to a, the scatter finishes, which triggers + update_data, which in turn transitions x from flight to memory. + 5. later on, gather_dep finishes, but the key is already in memory. 
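+    6. The data received from b is discarded, and y is computed from the value set
+       by update_data, as asserted at the end of the test.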
+ """ + event = asyncio.Event() + a.rpc = _LockedCommPool(a.rpc, write_event=event) + + x = c.submit(inc, 1, key="x", workers=[b.address]) + y = c.submit(inc, x, key="y", workers=[a.address]) + await wait_for_state("x", "flight", a) + a.update_data({"x": 3}) + await wait_for_state("x", "memory", a) + + # Let the comm from b to a return the result + event.set() + assert await y == 4 # Data in flight from b has been discarded diff --git a/distributed/utils_test.py b/distributed/utils_test.py index fc75ea1817..8cce6d4449 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -1923,7 +1923,7 @@ class _LockedCommPool(ConnectionPool): ) # It might be necessary to remove all existing comms # if the wrapped pool has been used before - >>> w.remove(remote_address) + >>> w.rpc.remove(remote_address) >>> async def ping_pong(): return await w.rpc(remote_address).ping() diff --git a/distributed/worker.py b/distributed/worker.py index dc5068835d..9d09814047 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -3156,61 +3156,32 @@ def total_comm_bytes(self): ) return self.comm_threshold_bytes - def _filter_deps_for_fetch( - self, to_gather_keys: Iterable[str] - ) -> tuple[set[str], set[str], TaskState | None]: - """Filter a list of keys before scheduling coroutines to fetch data from workers. + def _get_cause(self, keys: Iterable[str]) -> TaskState: + """For diagnostics, we want to attach a transfer to a single task. This task is + typically the next to be executed but since we're fetching tasks for potentially + many dependents, an exact match is not possible. Additionally, if a key was + fetched through acquire-replicas, dependents may not be known at all. Returns ------- - in_flight_keys: - The subset of keys in to_gather_keys in state `flight` or `resumed` - cancelled_keys: - The subset of tasks in to_gather_keys in state `cancelled` or `memory` - cause: - The task to attach startstops of this transfer to + The task to attach startstops of this transfer to """ - in_flight_tasks: set[TaskState] = set() - cancelled_keys: set[str] = set() - for key in to_gather_keys: - ts = self.tasks.get(key) - if ts is None: - continue - - # At this point, a task has been transitioned fetch->flight - # flight is only allowed to be transitioned into - # {memory, resumed, cancelled} - # resumed and cancelled will block any further transition until this - # coro has been finished - - if ts.state in ("flight", "resumed"): - in_flight_tasks.add(ts) - # If the key is already in memory, the fetch should not happen which - # is signalled by the cancelled_keys - elif ts.state in {"cancelled", "memory"}: - cancelled_keys.add(key) - else: - raise RuntimeError( - f"Task {ts.key} found in illegal state {ts.state}. " - "Only states `flight`, `resumed` and `cancelled` possible." - ) - - # For diagnostics we want to attach the transfer to a single task. this - # task is typically the next to be executed but since we're fetching - # tasks for potentially many dependents, an exact match is not possible. 
- # If there are no dependents, this is a pure replica fetch cause = None - for ts in in_flight_tasks: + for key in keys: + ts = self.tasks[key] if ts.dependents: - cause = next(iter(ts.dependents)) - break - else: - cause = ts - in_flight_keys = {ts.key for ts in in_flight_tasks} - return in_flight_keys, cancelled_keys, cause + return next(iter(ts.dependents)) + cause = ts + assert cause # Always at least one key + return cause def _update_metrics_received_data( - self, start: float, stop: float, data: dict, cause: TaskState, worker: str + self, + start: float, + stop: float, + data: dict[str, Any], + cause: TaskState, + worker: str, ) -> None: total_bytes = sum(self.tasks[key].get_nbytes() for key in data) @@ -3259,7 +3230,7 @@ def _update_metrics_received_data( async def gather_dep( self, worker: str, - to_gather: Iterable[str], + to_gather: Collection[str], total_nbytes: int, *, stimulus_id: str, @@ -3283,46 +3254,23 @@ async def gather_dep( recommendations: Recs = {} instructions: Instructions = [] response = {} - to_gather_keys: set[str] = set() - cancelled_keys: set[str] = set() def done_event(): return GatherDepDoneEvent(stimulus_id=f"gather-dep-done-{time()}") try: - to_gather_keys, cancelled_keys, cause = self._filter_deps_for_fetch( - to_gather - ) - - if not to_gather_keys: - self.log.append( - ("nothing-to-gather", worker, to_gather, stimulus_id, time()) - ) - return done_event() - - assert cause - # Keep namespace clean since this func is long and has many - # dep*, *ts* variables - del to_gather - - self.log.append( - ("request-dep", worker, to_gather_keys, stimulus_id, time()) - ) - logger.debug( - "Request %d keys for task %s from %s", - len(to_gather_keys), - cause, - worker, - ) + self.log.append(("request-dep", worker, to_gather, stimulus_id, time())) + logger.debug("Request %d keys from %s", len(to_gather), worker) start = time() response = await get_data_from_worker( - self.rpc, to_gather_keys, worker, who=self.address + self.rpc, to_gather, worker, who=self.address ) stop = time() if response["status"] == "busy": return done_event() + cause = self._get_cause(to_gather) self._update_metrics_received_data( start=start, stop=stop, @@ -3375,9 +3323,7 @@ def done_event(): data = response.get("data", {}) if busy: - self.log.append( - ("busy-gather", worker, to_gather_keys, stimulus_id, time()) - ) + self.log.append(("busy-gather", worker, to_gather, stimulus_id, time())) # Avoid hammering the worker. If there are multiple replicas # available, immediately try fetching from a different worker. self.busy_workers.add(worker) @@ -3388,12 +3334,7 @@ def done_event(): for d in self.in_flight_workers.pop(worker): ts = self.tasks[d] ts.done = True - if d in cancelled_keys: - if ts.state == "cancelled": - recommendations[ts] = "released" - else: - recommendations[ts] = "fetch" - elif d in data: + if d in data: recommendations[ts] = ("memory", data[d]) elif busy: recommendations[ts] = "fetch"
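
The core behavioural change in this patch is that gather_dep no longer runs a
pre-fetch filter (the removed _filter_deps_for_fetch): every key in to_gather is
requested, and tasks that were cancelled or landed in memory while the transfer
was in flight are reconciled by the regular transition machinery once the
response arrives. _get_cause, split out of the old filter, now only selects the
task to attach the transfer's startstops to. The sketch below is an illustration,
not the Worker code itself: it condenses the reconciliation loop at the end of
the hunk above into a standalone function and omits the branches that are not
visible above.

    # Minimal sketch, assuming a plain mapping of keys to transition
    # recommendations rather than the Worker's TaskState objects.
    from __future__ import annotations

    from typing import Any


    def reconcile_after_gather(
        in_flight: list[str],
        data: dict[str, Any],
        busy: bool,
    ) -> dict[str, Any]:
        """Map each previously in-flight key to a transition recommendation."""
        recommendations: dict[str, Any] = {}
        for key in in_flight:
            if key in data:
                # The payload arrived: recommend "memory" unconditionally and let
                # the transition system resolve cancelled or already-in-memory
                # tasks, instead of filtering them out before the fetch.
                recommendations[key] = ("memory", data[key])
            elif busy:
                # The remote worker was busy: retry, possibly from another replica.
                recommendations[key] = "fetch"
            # Keys that neither arrived nor hit a busy worker are handled further
            # down in the real method (not shown in this hunk).
        return recommendations


    print(reconcile_after_gather(["x", "y"], {"x": 123}, busy=False))
    # {'x': ('memory', 123)}

This is the behaviour exercised by test_cancelled_while_in_flight and
test_in_memory_while_in_flight above: the fetched copy always produces a "memory"
recommendation, and the state machine, not gather_dep, decides whether to keep or
discard it.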