diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py
index 34c7c66cd1..5db88dc552 100644
--- a/distributed/tests/test_worker.py
+++ b/distributed/tests/test_worker.py
@@ -3091,3 +3091,74 @@ def clear_leak():
         {"action": "remove-worker", "processing-tasks": {}},
         {"action": "retired"},
     ]
+
+
+async def _wait_for_flight(key, worker):
+    """Yield to the event loop until ``key`` is in state ``flight`` on ``worker``."""
+    while key not in worker.tasks or worker.tasks[key].state != "flight":
+        await asyncio.sleep(0)
+
+
+@gen_cluster(client=True)
+async def test_gather_dep_do_not_handle_response_of_not_requested_tasks(c, s, a, b):
+    """At time of writing, the gather_dep implementation filtered tasks again
+    for in-flight state. The response parser, however, did not distinguish
+    resulting in unwanted missing-data signals to the scheduler, causing
+    potential rescheduling or data leaks.
+    This test may become obsolete if the implementation changes significantly.
+    """
+    import distributed
+
+    with mock.patch.object(distributed.worker.Worker, "gather_dep") as mocked_gather:
+        fut1 = c.submit(inc, 1, workers=[a.address], key="f1")
+        fut2 = c.submit(inc, fut1, workers=[a.address], key="f2")
+        await fut2
+        fut4 = c.submit(sum, fut1, fut2, workers=[b.address], key="f4")
+        fut3 = c.submit(inc, fut1, workers=[b.address], key="f3")
+
+        fut2_key = fut2.key
+
+        await _wait_for_flight(fut2_key, b)
+
+        fut4.release()
+        while fut4.key in b.tasks:
+            await asyncio.sleep(0)
+
+        story_before = b.story(fut2.key)
+        assert fut2.key in mocked_gather.call_args.kwargs["to_gather"]
+        await Worker.gather_dep(b, **mocked_gather.call_args.kwargs)
+        story_after = b.story(fut2.key)
+        assert story_before == story_after
+        await fut3
+
+
+@gen_cluster(
+    client=True,
+    config={
+        "distributed.comm.recent-messages-log-length": 1000,
+    },
+)
+async def test_gather_dep_no_longer_in_flight_tasks(c, s, a, b):
+    """gather_dep must leave the worker log untouched if no key is still in flight.
+
+    ``Worker.gather_dep`` is mocked so the call is only recorded; the real
+    coroutine is then invoked with the recorded kwargs after ``f2`` is released.
+    """
+    import distributed
+
+    with mock.patch.object(distributed.worker.Worker, "gather_dep") as mocked_gather:
+        fut1 = c.submit(inc, 1, workers=[a.address], key="f1")
+        fut2 = c.submit(sum, fut1, fut1, workers=[b.address], key="f2")
+
+        fut1_key = fut1.key
+
+        await _wait_for_flight(fut1_key, b)
+
+        fut2.release()
+        while fut2.key in b.tasks:
+            await asyncio.sleep(0)
+
+        assert b.tasks[fut1.key].state != "flight"
+        log_before = list(b.log)
+        await Worker.gather_dep(b, **mocked_gather.call_args.kwargs)
+        assert log_before == list(b.log)
diff --git a/distributed/worker.py b/distributed/worker.py
index 2acba897d2..02dc15513c 100644
--- a/distributed/worker.py
+++ b/distributed/worker.py
@@ -2514,10 +2514,11 @@ async def gather_dep(
                                 cause = dependent
                                 found_dependent_for_cause = True
                                 break
+
+                if not to_gather_keys:
+                    return
                 # Keep namespace clean since this func is long and has many
                 # dep*, *ts* variables
-
-                assert cause is not None
                 del to_gather, dependency_key, dependency_ts
 
                 self.log.append(
@@ -2618,7 +2619,7 @@ async def gather_dep(
                     )
 
                 recommendations: dict[TaskState, str | tuple] = {}
-                deps_to_iter = self.in_flight_workers.pop(worker)
+                deps_to_iter = set(self.in_flight_workers.pop(worker)) & to_gather_keys
 
                 for d in deps_to_iter:
                     ts = self.tasks.get(d)