Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deadlock - Ensure resumed flight tasks are still fetched #5426

Merged
merged 1 commit into from Oct 22, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 2 additions & 2 deletions distributed/tests/test_client.py
Expand Up @@ -102,7 +102,7 @@

@gen_cluster(client=True)
async def test_submit(c, s, a, b):
x = c.submit(inc, 10)
x = c.submit(inc, 10, key="x")
assert not x.done()

assert isinstance(x, Future)
Expand All @@ -112,7 +112,7 @@ async def test_submit(c, s, a, b):
assert result == 11
assert x.done()

y = c.submit(inc, 20)
y = c.submit(inc, 20, key="y")
z = c.submit(add, x, y)

result = await z
Expand Down
2 changes: 2 additions & 0 deletions distributed/tests/test_steal.py
Expand Up @@ -987,6 +987,8 @@ async def test_reschedule_concurrent_requests_deadlock(c, s, *workers):
slowinc,
range(10),
key=[f"f1-{ix}" for ix in range(10)],
workers=[w0.address],
allow_other_workers=True,
)
while not w0.active_keys:
await asyncio.sleep(0.01)
Expand Down
89 changes: 69 additions & 20 deletions distributed/tests/test_worker.py
Expand Up @@ -1804,14 +1804,14 @@ async def test_story_with_deps(c, s, a, b):
stimulus_ids.add(msg[-2])
pruned_story.append(tuple(pruned_msg[:-2]))

assert len(stimulus_ids) == 3
assert len(stimulus_ids) == 3, stimulus_ids
stimulus_id = pruned_story[0][-1]
assert isinstance(stimulus_id, str)
assert stimulus_id.startswith("compute-task")
# This is a simple transition log
expected_story = [
(key, "compute-task"),
(key, "released", "waiting", {}),
(key, "released", "waiting", {dep.key: "fetch"}),
(key, "waiting", "ready", {}),
(key, "ready", "executing", {}),
(key, "put-in-memory"),
Expand All @@ -1832,11 +1832,11 @@ async def test_story_with_deps(c, s, a, b):
stimulus_ids.add(msg[-2])
pruned_story.append(tuple(pruned_msg[:-2]))

assert len(stimulus_ids) == 3
assert len(stimulus_ids) == 2, stimulus_ids
stimulus_id = pruned_story[0][-1]
assert isinstance(stimulus_id, str)
expected_story = [
(dep_story, "register-replica", "released"),
(dep_story, "ensure-task-exists", "released"),
(dep_story, "released", "fetch", {}),
(
"gather-dependencies",
Expand Down Expand Up @@ -2794,7 +2794,7 @@ async def test_acquire_replicas_same_channel(c, s, a, b):
_acquire_replicas(s, b, fut)

await futC
while fut.key not in b.tasks:
while fut.key not in b.tasks or not b.tasks[fut.key].state == "memory":
await asyncio.sleep(0.005)
assert len(s.who_has[fut.key]) == 2

Expand Down Expand Up @@ -3082,12 +3082,14 @@ def clear_leak():
]


async def _wait_for_flight(key, worker):
while key not in worker.tasks or worker.tasks[key].state != "flight":
async def _wait_for_state(key: str, worker: Worker, state: str):
# Keep the sleep interval at 0 since the tests using this are very sensitive
# about timing. they intend to capture loop cycles after this specific
# condition was set
while key not in worker.tasks or worker.tasks[key].state != state:
await asyncio.sleep(0)


@pytest.mark.xfail(reason="#5406")
@gen_cluster(client=True)
async def test_gather_dep_do_not_handle_response_of_not_requested_tasks(c, s, a, b):
"""At time of writing, the gather_dep implementation filtered tasks again
Expand All @@ -3107,21 +3109,26 @@ async def test_gather_dep_do_not_handle_response_of_not_requested_tasks(c, s, a,

fut2_key = fut2.key

await _wait_for_flight(fut2_key, b)
await _wait_for_state(fut2_key, b, "flight")
while not mocked_gather.call_args:
await asyncio.sleep(0)

fut4.release()
while fut4.key in b.tasks:
await asyncio.sleep(0)

story_before = b.story(fut2.key)
assert fut2.key in mocked_gather.call_args.kwargs["to_gather"]
await Worker.gather_dep(b, **mocked_gather.call_args.kwargs)
story_after = b.story(fut2.key)
assert story_before == story_after
assert b.tasks[fut2.key].state == "cancelled"
args, kwargs = mocked_gather.call_args
assert fut2.key in kwargs["to_gather"]

await Worker.gather_dep(b, *args, **kwargs)
assert fut2.key not in b.tasks
f2_story = b.story(fut2.key)
assert f2_story
assert not any("missing-dep" in msg for msg in b.story(fut2.key))
await fut3


@pytest.mark.xfail(reason="#5406")
@gen_cluster(
client=True,
config={
Expand All @@ -3137,13 +3144,55 @@ async def test_gather_dep_no_longer_in_flight_tasks(c, s, a, b):

fut1_key = fut1.key

await _wait_for_flight(fut1_key, b)
await _wait_for_state(fut1_key, b, "flight")
while not mocked_gather.call_args:
await asyncio.sleep(0)

fut2.release()
while fut2.key in b.tasks:
await asyncio.sleep(0)

assert b.tasks[fut1.key] != "flight"
log_before = list(b.log)
await Worker.gather_dep(b, **mocked_gather.call_args.kwargs)
assert log_before == list(b.log)
assert b.tasks[fut1.key].state == "cancelled"

args, kwargs = mocked_gather.call_args
await Worker.gather_dep(b, *args, **kwargs)

assert fut2.key not in b.tasks
f1_story = b.story(fut1.key)
assert f1_story
assert not any("missing-dep" in msg for msg in b.story(fut2.key))


@pytest.mark.parametrize("intermediate_state", ["resumed", "cancelled"])
@pytest.mark.parametrize("close_worker", [False, True])
@gen_cluster(client=True, nthreads=[("", 1)] * 3)
async def test_deadlock_cancelled_after_inflight_before_gather_from_worker(
c, s, a, b, x, intermediate_state, close_worker
):
"""If a task was transitioned to in-flight, the gather-dep coroutine was
scheduled but a cancel request came in before gather_data_from_worker was
issued this might corrupt the state machine if the cancelled key is not
properly handled"""

fut1 = c.submit(slowinc, 1, workers=[a.address], key="f1")
fut1B = c.submit(slowinc, 2, workers=[x.address], key="f1B")
fut2 = c.submit(sum, [fut1, fut1B], workers=[x.address], key="f2")
await fut2
with mock.patch.object(distributed.worker.Worker, "gather_dep") as mocked_gather:
fut3 = c.submit(inc, fut2, workers=[b.address], key="f3")

fut2_key = fut2.key

await _wait_for_state(fut2_key, b, "flight")

s.set_restrictions(worker={fut1B.key: a.address, fut2.key: b.address})
while not mocked_gather.call_args:
await asyncio.sleep(0)

await s.remove_worker(address=x.address, safe=True, close=close_worker)

await _wait_for_state(fut2_key, b, intermediate_state)

args, kwargs = mocked_gather.call_args
await Worker.gather_dep(b, *args, **kwargs)
await fut3