Resolve KeyError-related deadlock #5525

Merged
merged 2 commits on Nov 18, 2021
90 changes: 83 additions & 7 deletions distributed/tests/test_worker.py
@@ -1811,11 +1811,11 @@ async def test_story_with_deps(c, s, a, b):
# This is a simple transition log
expected_story = [
(key, "compute-task"),
(key, "released", "waiting", {dep.key: "fetch"}),
(key, "waiting", "ready", {}),
(key, "ready", "executing", {}),
(key, "released", "waiting", "waiting", {dep.key: "fetch"}),
(key, "waiting", "ready", "ready", {}),
(key, "ready", "executing", "executing", {}),
(key, "put-in-memory"),
(key, "executing", "memory", {}),
(key, "executing", "memory", "memory", {}),
]
assert pruned_story == expected_story
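
The change above reflects the extended transition log: each pruned story entry
now records both the state a transition was asked to reach and the state the
task actually ended up in (the two can differ, e.g. when a transition is
intercepted by a cancellation). A minimal sketch of how a transition entry of
this shape (as opposed to an instruction entry such as "compute-task") could be
unpacked; the helper and field names are illustrative, not part of this PR:

def unpack_pruned_story_entry(entry):
    """Unpack one pruned transition-log entry of the shape asserted above.

    The full log entries also carry a stimulus_id and a timestamp, which the
    test strips off before comparing.
    """
    key, start, recommended, reached, recommendations = entry
    return {
        "key": key,                  # task key, e.g. "f1"
        "start": start,              # state before the transition
        "recommended": recommended,  # state the transition was asked for
        "reached": reached,          # state actually reached (ts.state)
        "recommendations": recommendations,  # follow-up transitions, e.g. {dep: "fetch"}
    }

# e.g. unpack_pruned_story_entry((key, "released", "waiting", "waiting", {dep.key: "fetch"}))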

@@ -1837,13 +1837,13 @@ async def test_story_with_deps(c, s, a, b):
assert isinstance(stimulus_id, str)
expected_story = [
(dep_story, "ensure-task-exists", "released"),
(dep_story, "released", "fetch", {}),
(dep_story, "released", "fetch", "fetch", {}),
(
"gather-dependencies",
a.address,
{dep.key},
),
(dep_story, "fetch", "flight", {}),
(dep_story, "fetch", "flight", "flight", {}),
(
"request-dep",
a.address,
@@ -1855,7 +1855,7 @@ async def test_story_with_deps(c, s, a, b):
{dep.key},
),
(dep_story, "put-in-memory"),
(dep_story, "flight", "memory", {res.key: "ready"}),
(dep_story, "flight", "memory", "memory", {res.key: "ready"}),
]
assert pruned_story == expected_story

@@ -3090,6 +3090,82 @@ async def _wait_for_state(key: str, worker: Worker, state: str):
await asyncio.sleep(0)


@gen_cluster(client=True)
async def test_gather_dep_cancelled_rescheduled(c, s, a, b):
"""At time of writing, the gather_dep implementation filtered tasks again
for in-flight state. The response parser, however, did not distinguish
resulting in unwanted missing-data signals to the scheduler, causing
potential rescheduling or data leaks.

If a cancelled key is rescheduled for fetching while gather_dep waits
internally for get_data, the response parser would misclassify this key and
cause the key to be recommended for a release causing deadlocks and/or lost
keys.
At time of writing, this transition was implemented wrongly and caused a
flight->cancelled transition which should be recoverable but the cancelled
state was corrupted by this transition since ts.done==True. This attribute
setting would cause a cancelled->fetch transition to actually drop the key
instead, causing https://github.com/dask/distributed/issues/5366

See also test_gather_dep_do_not_handle_response_of_not_requested_tasks
"""
import distributed

with mock.patch.object(distributed.worker.Worker, "gather_dep") as mocked_gather:
fut1 = c.submit(inc, 1, workers=[a.address], key="f1")
fut2 = c.submit(inc, fut1, workers=[a.address], key="f2")
await fut2
fut4 = c.submit(sum, fut1, fut2, workers=[b.address], key="f4")
fut3 = c.submit(inc, fut1, workers=[b.address], key="f3")

fut2_key = fut2.key

await _wait_for_state(fut2_key, b, "flight")
while not mocked_gather.call_args:
await asyncio.sleep(0)

fut4.release()
while fut4.key in b.tasks:
await asyncio.sleep(0)

assert b.tasks[fut2.key].state == "cancelled"
args, kwargs = mocked_gather.call_args
assert fut2.key in kwargs["to_gather"]

# The synchronization and mock structure below allow us to intercept the
# state after gather_dep has been scheduled and is waiting for
# get_data_from_worker to finish. If state transitions happen during this
# time, the response parser needs to handle them properly.
lock = asyncio.Lock()
event = asyncio.Event()
async with lock:

async def wait_get_data(*args, **kwargs):
event.set()
async with lock:
return await distributed.worker.get_data_from_worker(*args, **kwargs)

with mock.patch.object(
distributed.worker,
"get_data_from_worker",
side_effect=wait_get_data,
):
gather_dep_fut = asyncio.ensure_future(
Worker.gather_dep(b, *args, **kwargs)
)

await event.wait()

fut4 = c.submit(sum, [fut1, fut2], workers=[b.address], key="f4")
while b.tasks[fut2.key].state != "flight":
await asyncio.sleep(0.1)
await gather_dep_fut
f2_story = b.story(fut2.key)
assert f2_story
await fut3
await fut4
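
The lock/event construction used in the test above is a general pattern for
freezing a coroutine at a known await point so that the test can mutate worker
state before letting it continue. A stripped-down, self-contained sketch of the
same idea, independent of distributed (all names here are illustrative):

import asyncio


async def pause_pattern_demo():
    lock = asyncio.Lock()
    event = asyncio.Event()

    async def patched_call():
        event.set()  # signal the test that the call has been reached
        async with lock:  # park here until the test releases the lock
            return "payload"

    async with lock:  # hold the lock before the call is started
        task = asyncio.ensure_future(patched_call())
        await event.wait()  # the call is now frozen inside its lock acquisition
        # ... mutate state here while the call is suspended ...
    return await task  # leaving the block releases the lock; the call finishes


assert asyncio.run(pause_pattern_demo()) == "payload"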


@gen_cluster(client=True)
async def test_gather_dep_do_not_handle_response_of_not_requested_tasks(c, s, a, b):
"""At time of writing, the gather_dep implementation filtered tasks again
25 changes: 20 additions & 5 deletions distributed/worker.py
@@ -2306,10 +2306,15 @@ def transition_flight_error(
)

def transition_flight_released(self, ts, *, stimulus_id):
ts._previous = "flight"
# See https://github.com/dask/distributed/pull/5046#discussion_r685093940
ts.state = "cancelled"
return {}, []
if ts.done:
# FIXME: Is this even possible? Would an assert instead be more
# sensible?
return self.transition_generic_released(ts, stimulus_id=stimulus_id)
else:
ts._previous = "flight"
# See https://github.com/dask/distributed/pull/5046#discussion_r685093940
ts.state = "cancelled"
return {}, []

def transition_cancelled_memory(self, ts, value, *, stimulus_id):
return {ts: ts._next}, []
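
For readability, here is a condensed restatement of the transition_flight_released
branch added above as a standalone function; this is illustrative only, not the
real API (the actual method lives on Worker and returns a (recommendations,
scheduler messages) pair):

def release_task_in_flight(ts):
    """Sketch of releasing a task whose current state is "flight"."""
    if ts.done:
        # The gather coroutine has already completed for this key, so nothing
        # is still using the TaskState and the generic release path is safe.
        return "generic_released"
    # A gather is still running: park the task as "cancelled" and remember the
    # previous state so that a later fetch/compute request can resume it.
    ts._previous = "flight"
    ts.state = "cancelled"
    return "cancelled"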
@@ -2399,9 +2404,15 @@ def _transition(self, ts, finish, *args, stimulus_id, **kwargs):

self.log.append(
(
# key
ts.key,
# initial
start,
# recommended
finish,
# final
ts.state,
# new recommendations
{ts.key: new for ts, new in recs.items()},
stimulus_id,
time(),
@@ -2444,6 +2455,7 @@ def transitions(self, recommendations: dict, *, stimulus_id):
ts, finish = remaining_recs.popitem()
tasks.add(ts)
a_recs, a_smsgs = self._transition(ts, finish, stimulus_id=stimulus_id)

remaining_recs.update(a_recs)
smsgs += a_smsgs

@@ -2867,7 +2879,10 @@ async def gather_dep(
ts = self.tasks[d]
ts.done = True
if d in cancelled_keys:
recommendations[ts] = "released"
if ts.state == "cancelled":
recommendations[ts] = "released"
else:
recommendations[ts] = "fetch"
elif d in data:
recommendations[ts] = ("memory", data[d])
elif busy:
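
The hunk above is the core of the fix: a response for a key that was cancelled
when the request went out must not be released unconditionally, because the key
may have been rescheduled while gather_dep was waiting. A condensed, hedged
restatement of that per-key decision as a standalone function (names are
simplified; the real logic operates on TaskState objects inside
Worker.gather_dep):

def recommend_for_gathered_key(current_state, *, was_cancelled, has_data):
    """Sketch of the per-key decision in gather_dep's response handling."""
    if was_cancelled:
        if current_state == "cancelled":
            # Still cancelled: releasing the key is safe.
            return "released"
        # Rescheduled while the request was in flight (e.g. back to "fetch"):
        # ask for a fresh fetch instead of dropping the key, which previously
        # corrupted state and could deadlock (dask/distributed#5366).
        return "fetch"
    if has_data:
        return "memory"
    # The remaining branches (busy remote worker, truly missing data, ...) are
    # unchanged by this PR and omitted here.
    ...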