diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py
index d3aa7bcb9d..b6b5154f33 100644
--- a/distributed/tests/test_client.py
+++ b/distributed/tests/test_client.py
@@ -102,7 +102,7 @@
 
 @gen_cluster(client=True)
 async def test_submit(c, s, a, b):
-    x = c.submit(inc, 10)
+    x = c.submit(inc, 10, key="x")
     assert not x.done()
 
     assert isinstance(x, Future)
@@ -112,7 +112,7 @@ async def test_submit(c, s, a, b):
     assert result == 11
     assert x.done()
 
-    y = c.submit(inc, 20)
+    y = c.submit(inc, 20, key="y")
     z = c.submit(add, x, y)
 
     result = await z
diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py
index 4b4c4448ec..70c935cb47 100644
--- a/distributed/tests/test_steal.py
+++ b/distributed/tests/test_steal.py
@@ -987,6 +987,8 @@ async def test_reschedule_concurrent_requests_deadlock(c, s, *workers):
         slowinc,
         range(10),
         key=[f"f1-{ix}" for ix in range(10)],
+        workers=[w0.address],
+        allow_other_workers=True,
     )
     while not w0.active_keys:
         await asyncio.sleep(0.01)
diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py
index 47233bc135..63d3401591 100644
--- a/distributed/tests/test_worker.py
+++ b/distributed/tests/test_worker.py
@@ -1804,14 +1804,14 @@ async def test_story_with_deps(c, s, a, b):
         stimulus_ids.add(msg[-2])
         pruned_story.append(tuple(pruned_msg[:-2]))
 
-    assert len(stimulus_ids) == 3
+    assert len(stimulus_ids) == 3, stimulus_ids
     stimulus_id = pruned_story[0][-1]
     assert isinstance(stimulus_id, str)
     assert stimulus_id.startswith("compute-task")
     # This is a simple transition log
     expected_story = [
         (key, "compute-task"),
-        (key, "released", "waiting", {}),
+        (key, "released", "waiting", {dep.key: "fetch"}),
         (key, "waiting", "ready", {}),
         (key, "ready", "executing", {}),
         (key, "put-in-memory"),
@@ -1832,11 +1832,11 @@ async def test_story_with_deps(c, s, a, b):
         stimulus_ids.add(msg[-2])
         pruned_story.append(tuple(pruned_msg[:-2]))
 
-    assert len(stimulus_ids) == 3
+    assert len(stimulus_ids) == 2, stimulus_ids
     stimulus_id = pruned_story[0][-1]
     assert isinstance(stimulus_id, str)
     expected_story = [
-        (dep_story, "register-replica", "released"),
+        (dep_story, "ensure-task-exists", "released"),
         (dep_story, "released", "fetch", {}),
         (
             "gather-dependencies",
@@ -2794,7 +2794,7 @@ async def test_acquire_replicas_same_channel(c, s, a, b):
     _acquire_replicas(s, b, fut)
 
     await futC
-    while fut.key not in b.tasks:
+    while fut.key not in b.tasks or not b.tasks[fut.key].state == "memory":
         await asyncio.sleep(0.005)
 
     assert len(s.who_has[fut.key]) == 2
@@ -3082,12 +3082,14 @@ def clear_leak():
     ]
 
 
-async def _wait_for_flight(key, worker):
-    while key not in worker.tasks or worker.tasks[key].state != "flight":
+async def _wait_for_state(key: str, worker: Worker, state: str):
+    # Keep the sleep interval at 0 since the tests using this are very sensitive
+    # about timing. they intend to capture loop cycles after this specific
+    # condition was set
+    while key not in worker.tasks or worker.tasks[key].state != state:
         await asyncio.sleep(0)
 
 
-@pytest.mark.xfail(reason="#5406")
 @gen_cluster(client=True)
 async def test_gather_dep_do_not_handle_response_of_not_requested_tasks(c, s, a, b):
     """At time of writing, the gather_dep implementation filtered tasks again
@@ -3107,21 +3109,26 @@ async def test_gather_dep_do_not_handle_response_of_not_requested_tasks(c, s, a,
 
         fut2_key = fut2.key
 
-        await _wait_for_flight(fut2_key, b)
+        await _wait_for_state(fut2_key, b, "flight")
+        while not mocked_gather.call_args:
+            await asyncio.sleep(0)
 
         fut4.release()
         while fut4.key in b.tasks:
             await asyncio.sleep(0)
 
-        story_before = b.story(fut2.key)
-        assert fut2.key in mocked_gather.call_args.kwargs["to_gather"]
-        await Worker.gather_dep(b, **mocked_gather.call_args.kwargs)
-        story_after = b.story(fut2.key)
-        assert story_before == story_after
+        assert b.tasks[fut2.key].state == "cancelled"
+        args, kwargs = mocked_gather.call_args
+        assert fut2.key in kwargs["to_gather"]
+
+        await Worker.gather_dep(b, *args, **kwargs)
+        assert fut2.key not in b.tasks
+        f2_story = b.story(fut2.key)
+        assert f2_story
+        assert not any("missing-dep" in msg for msg in b.story(fut2.key))
     await fut3
 
 
-@pytest.mark.xfail(reason="#5406")
 @gen_cluster(
     client=True,
     config={
@@ -3137,13 +3144,55 @@ async def test_gather_dep_no_longer_in_flight_tasks(c, s, a, b):
 
         fut1_key = fut1.key
 
-        await _wait_for_flight(fut1_key, b)
+        await _wait_for_state(fut1_key, b, "flight")
+        while not mocked_gather.call_args:
+            await asyncio.sleep(0)
 
         fut2.release()
         while fut2.key in b.tasks:
             await asyncio.sleep(0)
 
-        assert b.tasks[fut1.key] != "flight"
-        log_before = list(b.log)
-        await Worker.gather_dep(b, **mocked_gather.call_args.kwargs)
-        assert log_before == list(b.log)
+        assert b.tasks[fut1.key].state == "cancelled"
+
+        args, kwargs = mocked_gather.call_args
+        await Worker.gather_dep(b, *args, **kwargs)
+
+        assert fut2.key not in b.tasks
+        f1_story = b.story(fut1.key)
+        assert f1_story
+        assert not any("missing-dep" in msg for msg in b.story(fut2.key))
+
+
+@pytest.mark.parametrize("intermediate_state", ["resumed", "cancelled"])
+@pytest.mark.parametrize("close_worker", [False, True])
+@gen_cluster(client=True, nthreads=[("", 1)] * 3)
+async def test_deadlock_cancelled_after_inflight_before_gather_from_worker(
+    c, s, a, b, x, intermediate_state, close_worker
+):
+    """If a task was transitioned to in-flight, the gather-dep coroutine was
+    scheduled but a cancel request came in before gather_data_from_worker was
+    issued this might corrupt the state machine if the cancelled key is not
+    properly handled"""
+
+    fut1 = c.submit(slowinc, 1, workers=[a.address], key="f1")
+    fut1B = c.submit(slowinc, 2, workers=[x.address], key="f1B")
+    fut2 = c.submit(sum, [fut1, fut1B], workers=[x.address], key="f2")
+    await fut2
+    with mock.patch.object(distributed.worker.Worker, "gather_dep") as mocked_gather:
+        fut3 = c.submit(inc, fut2, workers=[b.address], key="f3")
+
+        fut2_key = fut2.key
+
+        await _wait_for_state(fut2_key, b, "flight")
+
+        s.set_restrictions(worker={fut1B.key: a.address, fut2.key: b.address})
+        while not mocked_gather.call_args:
+            await asyncio.sleep(0)
+
+        await s.remove_worker(address=x.address, safe=True, close=close_worker)
+
+        await _wait_for_state(fut2_key, b, intermediate_state)
+
+        args, kwargs = mocked_gather.call_args
+        await Worker.gather_dep(b, *args, **kwargs)
+    await fut3
diff --git a/distributed/worker.py b/distributed/worker.py
index 17e514b206..af44f1f660 100644
--- a/distributed/worker.py
+++ b/distributed/worker.py
@@ -19,7 +19,7 @@
 from datetime import timedelta
 from inspect import isawaitable
 from pickle import PicklingError
-from typing import TYPE_CHECKING, Any, ClassVar, cast
+from typing import TYPE_CHECKING, Any, ClassVar
 
 if TYPE_CHECKING:
     from typing_extensions import Literal
@@ -1805,7 +1805,7 @@ def handle_cancel_compute(self, key, reason):
         Nothing will happen otherwise.
         """
         ts = self.tasks.get(key)
-        if ts and ts.state in ("waiting", "ready"):
+        if ts and ts.state in READY | {"waiting"}:
             self.log.append((key, "cancel-compute", reason))
             ts.scheduler_holds_ref = False
             # All possible dependents of TS should not be in state Processing on
@@ -1820,13 +1820,13 @@ def handle_acquire_replicas(
         recommendations = {}
         scheduler_msgs = []
         for k in keys:
-            recs, smsgs = self.register_acquire_internal(
+            ts = self.ensure_task_exists(
                 k,
                 stimulus_id=stimulus_id,
                 priority=priorities[k],
             )
-            recommendations.update(recs)
-            scheduler_msgs += smsgs
+            if ts.state != "memory":
+                recommendations[ts] = "fetch"
 
         self.update_who_has(who_has, stimulus_id=stimulus_id)
 
@@ -1834,7 +1834,9 @@ def handle_acquire_replicas(
             self.batched_stream.send(msg)
         self.transitions(recommendations, stimulus_id=stimulus_id)
 
-    def register_acquire_internal(self, key, priority, stimulus_id):
+    def ensure_task_exists(
+        self, key: str, priority: tuple, stimulus_id: str
+    ) -> TaskState:
         try:
             ts = self.tasks[key]
             logger.debug(
@@ -1843,21 +1845,14 @@ def register_acquire_internal(self, key, priority, stimulus_id):
         except KeyError:
             self.tasks[key] = ts = TaskState(key)
 
-        self.log.append((key, "register-replica", ts.state, stimulus_id, time()))
+        self.log.append((key, "ensure-task-exists", ts.state, stimulus_id, time()))
         ts.priority = ts.priority or priority
-
-        recommendations = {}
-        scheduler_msgs = []
-        if ts.state in ("released", "cancelled", "error"):
-            recommendations[ts] = "fetch"
-
-        return recommendations, scheduler_msgs
+        return ts
 
     def handle_compute_task(
         self,
         *,
         key,
-        # FIXME: This will break protocol
         function=None,
         args=None,
         kwargs=None,
@@ -1903,25 +1898,29 @@ def handle_compute_task(
         recommendations = {}
         scheduler_msgs = []
         for dependency in who_has:
-            recs, smsgs = self.register_acquire_internal(
+            dep_ts = self.ensure_task_exists(
                 key=dependency,
                 stimulus_id=stimulus_id,
                 priority=priority,
             )
-            recommendations.update(recs)
-            scheduler_msgs += smsgs
-            dep_ts = self.tasks[dependency]
 
             # link up to child / parents
             ts.dependencies.add(dep_ts)
             dep_ts.dependents.add(ts)
 
-        if ts.state in {"ready", "executing", "waiting", "resumed"}:
+        if ts.state in READY | {"executing", "waiting", "resumed"}:
             pass
         elif ts.state == "memory":
             recommendations[ts] = "memory"
             scheduler_msgs.append(self.get_task_state_for_scheduler(ts))
-        elif ts.state in {"released", "fetch", "flight", "missing", "cancelled"}:
+        elif ts.state in {
+            "released",
+            "fetch",
+            "flight",
+            "missing",
+            "cancelled",
+            "error",
+        }:
             recommendations[ts] = "waiting"
         else:
             raise RuntimeError(f"Unexpected task state encountered {ts} {stimulus_id}")
@@ -1940,6 +1939,7 @@ def handle_compute_task(
     def transition_missing_fetch(self, ts, *, stimulus_id):
         self._missing_dep_flight.discard(ts)
         ts.state = "fetch"
+        ts.done = False
         heapq.heappush(self.data_needed, (ts.priority, ts.key))
         return {}, []
 
@@ -1959,6 +1959,7 @@ def transition_released_fetch(self, ts, *, stimulus_id):
         for w in ts.who_has:
             self.pending_data_per_worker[w].append(ts.key)
         ts.state = "fetch"
+        ts.done = False
         heapq.heappush(self.data_needed, (ts.priority, ts.key))
         return {}, []
 
@@ -1977,6 +1978,8 @@ def transition_released_waiting(self, ts, *, stimulus_id):
             if not dep_ts.state == "memory":
                 ts.waiting_for_data.add(dep_ts)
                 dep_ts.waiters.add(ts)
+                if dep_ts.state not in {"fetch", "flight"}:
+                    recommendations[dep_ts] = "fetch"
 
         if ts.waiting_for_data:
             self.waiting_for_data_count += 1
@@ -1994,6 +1997,7 @@ def transition_fetch_flight(self, ts, worker, *, stimulus_id):
             assert ts.who_has
             assert ts.key not in self.data_needed
 
+        ts.done = False
         ts.state = "flight"
         ts.coming_from = worker
         self._in_flight_tasks.add(ts)
@@ -2205,6 +2209,7 @@ def transition_flight_fetch(self, ts, *, stimulus_id):
         for w in ts.who_has:
             self.pending_data_per_worker[w].append(ts.key)
         ts.state = "fetch"
+        ts.done = False
         heapq.heappush(self.data_needed, (ts.priority, ts.key))
         return {}, []
 
@@ -2457,7 +2462,7 @@ def ensure_communicating(self):
             to_gather, total_nbytes = self.select_keys_for_gather(worker, ts.key)
 
             self.log.append(
-                ("gather-dependencies", worker, to_gather, "stimulus", time())
+                ("gather-dependencies", worker, to_gather, stimulus_id, time())
             )
 
             self.comm_nbytes += total_nbytes
@@ -2589,6 +2594,95 @@ def total_comm_bytes(self):
         )
         return self.comm_threshold_bytes
 
+    def _filter_deps_for_fetch(
+        self, to_gather_keys: Iterable[str]
+    ) -> tuple[set[str], set[str], TaskState | None]:
+        """Filter a list of keys before scheduling coroutines to fetch data from workers.
+
+        Returns
+        -------
+        in_flight_keys:
+            The subset of keys in to_gather_keys in state `flight`
+        cancelled_keys:
+            The subset of tasks in to_gather_keys in state `cancelled`
+        cause:
+            The task to attach startstops of this transfer to
+        """
+        in_flight_tasks: set[TaskState] = set()
+        cancelled_keys: set[str] = set()
+        for key in to_gather_keys:
+            ts = self.tasks.get(key)
+            if ts is None:
+                continue
+            if ts.state in ("flight", "resumed"):
+                in_flight_tasks.add(ts)
+            elif ts.state == "cancelled":
+                cancelled_keys.add(key)
+            else:
+                raise RuntimeError(
+                    f"Task {ts.key} found in illegal state {ts.state}. "
+                    "Only states `flight`, `resumed` and `cancelled` possible."
+                )
+
+        # For diagnostics we want to attach the transfer to a single task. this
+        # task is typically the next to be executed but since we're fetching
+        # tasks for potentially many dependents, an exact match is not possible.
+        # If there are no dependents, this is a pure replica fetch
+        cause = None
+        for ts in in_flight_tasks:
+            if ts.dependents:
+                cause = next(iter(ts.dependents))
+                break
+            else:
+                cause = ts
+        in_flight_keys = {ts.key for ts in in_flight_tasks}
+        return in_flight_keys, cancelled_keys, cause
+
+    def _update_metrics_received_data(
+        self, start: float, stop: float, data: dict, cause: TaskState, worker: str
+    ) -> None:
+
+        total_bytes = sum(self.tasks[key].get_nbytes() for key in data)
+
+        cause.startstops.append(
+            {
+                "action": "transfer",
+                "start": start + self.scheduler_delay,
+                "stop": stop + self.scheduler_delay,
+                "source": worker,
+            }
+        )
+        duration = (stop - start) or 0.010
+        bandwidth = total_bytes / duration
+        self.incoming_transfer_log.append(
+            {
+                "start": start + self.scheduler_delay,
+                "stop": stop + self.scheduler_delay,
+                "middle": (start + stop) / 2.0 + self.scheduler_delay,
+                "duration": duration,
+                "keys": {key: self.tasks[key].nbytes for key in data},
+                "total": total_bytes,
+                "bandwidth": bandwidth,
+                "who": worker,
+            }
+        )
+        if total_bytes > 1_000_000:
+            self.bandwidth = self.bandwidth * 0.95 + bandwidth * 0.05
+            bw, cnt = self.bandwidth_workers[worker]
+            self.bandwidth_workers[worker] = (bw + bandwidth, cnt + 1)
+
+            types = set(map(type, data.values()))
+            if len(types) == 1:
+                [typ] = types
+                bw, cnt = self.bandwidth_types[typ]
+                self.bandwidth_types[typ] = (bw + bandwidth, cnt + 1)
+
+        if self.digests is not None:
+            self.digests["transfer-bandwidth"].add(total_bytes / duration)
+            self.digests["transfer-duration"].add(duration)
+        self.counters["transfer-count"].add(len(data))
+        self.incoming_count += 1
+
     async def gather_dep(
         self,
         worker: str,
@@ -2610,39 +2704,28 @@ async def gather_dep(
         total_nbytes : int
             Total number of bytes for all the dependencies in to_gather combined
         """
-        cause: TaskState | None = None
         if self.status not in RUNNING:
             return
 
         with log_errors():
             response = {}
-            to_gather_keys = set()
+            to_gather_keys: set[str] = set()
+            cancelled_keys: set[str] = set()
             try:
-                found_dependent_for_cause = False
-                for dependency_key in to_gather:
-                    dependency_ts = self.tasks.get(dependency_key)
-                    if dependency_ts and dependency_ts.state == "flight":
-                        to_gather_keys.add(dependency_key)
-                        if not found_dependent_for_cause:
-                            cause = dependency_ts
-                            # For diagnostics we want to attach the transfer to
-                            # a single task. this task is typically the next to
-                            # be executed but since we're fetching tasks for
-                            # potentially many dependents, an exact match is not
-                            # possible. If there are no dependents, this is a
-                            # pure replica fetch
-                            for dependent in dependency_ts.dependents:
-                                cause = dependent
-                                found_dependent_for_cause = True
-                                break
+                to_gather_keys, cancelled_keys, cause = self._filter_deps_for_fetch(
+                    to_gather
+                )
 
                 if not to_gather_keys:
+                    self.log.append(
+                        ("nothing-to-gather", worker, to_gather, stimulus_id)
+                    )
                     return
-                assert cause
 
+                assert cause
                 # Keep namespace clean since this func is long and has many
                 # dep*, *ts* variables
-                del to_gather, dependency_key, dependency_ts
+                del to_gather
                 self.log.append(
                     ("request-dep", worker, to_gather_keys, stimulus_id, time())
                 )
@@ -2662,53 +2745,13 @@ async def gather_dep(
                 if response["status"] == "busy":
                     return
 
-                data = {k: v for k, v in response["data"].items() if k in self.tasks}
-                lost_keys = response["data"].keys() - data.keys()
-
-                if lost_keys:
-                    self.log.append(("lost-during-gather", lost_keys, stimulus_id))
-
-                total_bytes = sum(self.tasks[key].get_nbytes() for key in data)
-
-                cause.startstops.append(
-                    {
-                        "action": "transfer",
-                        "start": start + self.scheduler_delay,
-                        "stop": stop + self.scheduler_delay,
-                        "source": worker,
-                    }
+                self._update_metrics_received_data(
+                    start=start,
+                    stop=stop,
+                    data=response["data"],
+                    cause=cause,
+                    worker=worker,
                 )
-                duration = (stop - start) or 0.010
-                bandwidth = total_bytes / duration
-                self.incoming_transfer_log.append(
-                    {
-                        "start": start + self.scheduler_delay,
-                        "stop": stop + self.scheduler_delay,
-                        "middle": (start + stop) / 2.0 + self.scheduler_delay,
-                        "duration": duration,
-                        "keys": {key: self.tasks[key].nbytes for key in data},
-                        "total": total_bytes,
-                        "bandwidth": bandwidth,
-                        "who": worker,
-                    }
-                )
-                if total_bytes > 1000000:
-                    self.bandwidth = self.bandwidth * 0.95 + bandwidth * 0.05
-                    bw, cnt = self.bandwidth_workers[worker]
-                    self.bandwidth_workers[worker] = (bw + bandwidth, cnt + 1)
-
-                    types = set(map(type, response["data"].values()))
-                    if len(types) == 1:
-                        [typ] = types
-                        bw, cnt = self.bandwidth_types[typ]
-                        self.bandwidth_types[typ] = (bw + bandwidth, cnt + 1)
-
-                if self.digests is not None:
-                    self.digests["transfer-bandwidth"].add(total_bytes / duration)
-                    self.digests["transfer-duration"].add(duration)
-                self.counters["transfer-count"].add(len(response["data"]))
-                self.incoming_count += 1
-
                 self.log.append(
                     ("receive-dep", worker, set(response["data"]), stimulus_id, time())
                 )
@@ -2742,25 +2785,24 @@ async def gather_dep(
                 )
 
                 recommendations: dict[TaskState, str | tuple] = {}
-                deps_to_iter = set(self.in_flight_workers.pop(worker)) & to_gather_keys
-                for d in deps_to_iter:
-                    ts = cast(TaskState, self.tasks.get(d))
-                    assert ts, (d, self.story(d))
+                for d in self.in_flight_workers.pop(worker):
+                    ts = self.tasks[d]
                     ts.done = True
-                    if d in data:
+                    if d in cancelled_keys:
+                        recommendations[ts] = "released"
+                    elif d in data:
                         recommendations[ts] = ("memory", data[d])
-                    elif not busy:
+                    elif busy:
+                        recommendations[ts] = "fetch"
+                    else:
                         ts.who_has.discard(worker)
                         self.has_what[worker].discard(ts.key)
                         self.log.append((d, "missing-dep"))
                         self.batched_stream.send(
                             {"op": "missing-data", "errant_worker": worker, "key": d}
                         )
-
-                    if ts.state != "memory" and ts not in recommendations:
                         recommendations[ts] = "fetch"
 
-                del data, response
                 self.transitions(
                     recommendations=recommendations, stimulus_id=stimulus_id
                 )
@@ -2860,7 +2902,7 @@ def handle_steal_request(self, key, stimulus_id):
         }
         self.batched_stream.send(response)
 
-        if state in {"ready", "waiting", "constrained"}:
+        if state in READY | {"waiting"}:
            # If task is marked as "constrained" we haven't yet assigned it an
            # `available_resources` to run on, that happens in
            # `transition_constrained_executing`
@@ -3534,6 +3576,7 @@ def validate_task_ready(self, ts):
         assert ts.key in pluck(1, self.ready)
         assert ts.key not in self.data
         assert ts.state != "executing"
+        assert not ts.done
         assert not ts.waiting_for_data
         assert all(
             dep.key in self.data or dep.key in self.actors for dep in ts.dependencies
@@ -3542,6 +3585,7 @@ def validate_task_waiting(self, ts):
         assert ts.key not in self.data
         assert ts.state == "waiting"
+        assert not ts.done
         if ts.dependencies and ts.runspec:
             assert not all(dep.key in self.data for dep in ts.dependencies)
 
@@ -3556,6 +3600,7 @@ def validate_task_flight(self, ts):
     def validate_task_fetch(self, ts):
         assert ts.key not in self.data
         assert self.address not in ts.who_has
+        assert not ts.done
         for w in ts.who_has:
             assert ts.key in self.has_what[w]
 
@@ -3563,6 +3608,7 @@ def validate_task_missing(self, ts):
         assert ts.key not in self.data
         assert not ts.who_has
+        assert not ts.done
         assert not any(ts.key in has_what for has_what in self.has_what.values())
         assert ts.key in self._missing_dep_flight
 
@@ -3620,7 +3666,10 @@ def validate_task(self, ts):
                 import pdb
 
                 pdb.set_trace()
-            raise
+
+            raise AssertionError(
+                f"Invalid TaskState encountered for {ts!r}.\nStory:\n{self.story(ts)}\n"
+            ) from e
 
     def validate_state(self):
         if self.status not in RUNNING:
@@ -3649,7 +3698,7 @@ def validate_state(self):
                     assert ts_wait.key in self.tasks
                     assert (
                         ts_wait.state
-                        in {"ready", "executing", "flight", "fetch", "missing"}
+                        in READY | {"executing", "flight", "fetch", "missing"}
                         or ts_wait.key in self._missing_dep_flight
                         or ts_wait.who_has.issubset(self.in_flight_workers)
                     ), (ts, ts_wait, self.story(ts), self.story(ts_wait))
@@ -3947,7 +3996,7 @@ async def get_data_from_worker(
     See Also
     --------
     Worker.get_data
-    Worker.gather_deps
+    Worker.gather_dep
     utils_comm.gather_data_from_workers
     """
     if serializers is None:
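
For readers of the patch, the behavioural core of the gather_dep rework above is the up-front classification of the requested keys, which lets cancelled keys be recommended "released" instead of being reported to the scheduler as missing. Below is a minimal, self-contained sketch of that classification step; the free-standing filter_deps_for_fetch function and the simplified TaskState are illustrative stand-ins, not the actual distributed.worker API.

from __future__ import annotations

from dataclasses import dataclass, field


@dataclass(eq=False)
class TaskState:
    # Illustrative stand-in for distributed.worker.TaskState; only the fields
    # used by the classification below are modelled.
    key: str
    state: str
    dependents: set[TaskState] = field(default_factory=set)


def filter_deps_for_fetch(
    tasks: dict[str, TaskState], to_gather_keys: list[str]
) -> tuple[set[str], set[str], TaskState | None]:
    """Sketch of the classification performed by _filter_deps_for_fetch.

    Keys whose task has meanwhile disappeared are skipped, cancelled keys are
    reported separately so the caller can release them instead of flagging
    them as missing, and any other state is treated as a logic error.
    """
    in_flight_tasks: set[TaskState] = set()
    cancelled_keys: set[str] = set()
    for key in to_gather_keys:
        ts = tasks.get(key)
        if ts is None:
            continue
        if ts.state in ("flight", "resumed"):
            in_flight_tasks.add(ts)
        elif ts.state == "cancelled":
            cancelled_keys.add(key)
        else:
            raise RuntimeError(f"Task {key} found in illegal state {ts.state}")

    # Attach the transfer to a single task for diagnostics: prefer a dependent,
    # otherwise (pure replica fetch) fall back to the fetched task itself.
    cause = None
    for ts in in_flight_tasks:
        if ts.dependents:
            cause = next(iter(ts.dependents))
            break
        cause = ts
    return {ts.key for ts in in_flight_tasks}, cancelled_keys, cause


if __name__ == "__main__":
    # Hypothetical mini-example: one key still in flight, one cancelled, one gone.
    t1 = TaskState("f1", "flight")
    t2 = TaskState("f2", "cancelled")
    print(filter_deps_for_fetch({"f1": t1, "f2": t2}, ["f1", "f2", "gone"]))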