Skip to content

Commit

Permalink
Do not attempt to fetch keys which are no longer in flight
Browse files Browse the repository at this point in the history
  • Loading branch information
fjetter committed Oct 4, 2021
1 parent cf018a1 commit 3da4d9e
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 3 deletions.
65 changes: 65 additions & 0 deletions distributed/tests/test_worker.py
Expand Up @@ -3091,3 +3091,68 @@ def clear_leak():
{"action": "remove-worker", "processing-tasks": {}},
{"action": "retired"},
]


async def _wait_for_flight(key, worker):
while key not in worker.tasks or worker.tasks[key].state != "flight":
await asyncio.sleep(0)


@gen_cluster(client=True)
async def test_gather_dep_do_not_handle_response_of_not_requested_tasks(c, s, a, b):
"""At time of writing, the gather_dep implementation filtered tasks again
for in-flight state. The response parser, however, did not distinguish
resulting in unwanted missing-data signals to the scheduler, causing
potential rescheduling or data leaks.
This test may become obsolete if the implementation changes significantly.
"""
import distributed

with mock.patch.object(distributed.worker.Worker, "gather_dep") as mocked_gather:
fut1 = c.submit(inc, 1, workers=[a.address], key="f1")
fut2 = c.submit(inc, fut1, workers=[a.address], key="f2")
await fut2
fut4 = c.submit(sum, fut1, fut2, workers=[b.address], key="f4")
fut3 = c.submit(inc, fut1, workers=[b.address], key="f3")

fut2_key = fut2.key

await _wait_for_flight(fut2_key, b)

fut4.release()
while fut4.key in b.tasks:
await asyncio.sleep(0)

story_before = b.story(fut2.key)
assert fut2.key in mocked_gather.call_args.kwargs["to_gather"]
await Worker.gather_dep(b, **mocked_gather.call_args.kwargs)
story_after = b.story(fut2.key)
assert story_before == story_after
await fut3


@gen_cluster(
client=True,
config={
"distributed.comm.recent-messages-log-length": 1000,
},
)
async def test_gather_dep_no_longer_in_flight_tasks(c, s, a, b):
import distributed

with mock.patch.object(distributed.worker.Worker, "gather_dep") as mocked_gather:
fut1 = c.submit(inc, 1, workers=[a.address], key="f1")
fut2 = c.submit(sum, fut1, fut1, workers=[b.address], key="f2")

fut1_key = fut1.key

await _wait_for_flight(fut1_key, b)

fut2.release()
while fut2.key in b.tasks:
await asyncio.sleep(0)

assert b.tasks[fut1.key] != "flight"
log_before = list(b.log)
await Worker.gather_dep(b, **mocked_gather.call_args.kwargs)
assert log_before == list(b.log)
7 changes: 4 additions & 3 deletions distributed/worker.py
Expand Up @@ -2514,10 +2514,11 @@ async def gather_dep(
cause = dependent
found_dependent_for_cause = True
break

if not to_gather_keys:
return
# Keep namespace clean since this func is long and has many
# dep*, *ts* variables

assert cause is not None
del to_gather, dependency_key, dependency_ts

self.log.append(
Expand Down Expand Up @@ -2618,7 +2619,7 @@ async def gather_dep(
)

recommendations: dict[TaskState, str | tuple] = {}
deps_to_iter = self.in_flight_workers.pop(worker)
deps_to_iter = set(self.in_flight_workers.pop(worker)) & to_gather_keys

for d in deps_to_iter:
ts = self.tasks.get(d)
Expand Down

0 comments on commit 3da4d9e

Please sign in to comment.