Refactor gather_dep #6388

Merged: 14 commits, Jun 10, 2022
5 changes: 2 additions & 3 deletions distributed/tests/test_stories.py
@@ -138,8 +137,7 @@ async def test_worker_story_with_deps(c, s, a, b):
     # Story now includes randomized stimulus_ids and timestamps.
     story = b.story("res")
     stimulus_ids = {ev[-2].rsplit("-", 1)[0] for ev in story}
-    assert stimulus_ids == {"compute-task", "task-finished"}
-
+    assert stimulus_ids == {"compute-task", "gather-dep-success", "task-finished"}
     # This is a simple transition log
     expected = [
         ("res", "compute-task", "released"),
@@ -153,7 +152,7 @@ async def test_worker_story_with_deps(c, s, a, b):

     story = b.story("dep")
     stimulus_ids = {ev[-2].rsplit("-", 1)[0] for ev in story}
-    assert stimulus_ids == {"compute-task"}
+    assert stimulus_ids == {"compute-task", "gather-dep-success"}
     expected = [
         ("dep", "ensure-task-exists", "released"),
         ("dep", "released", "fetch", "fetch", {}),
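Side note on the assertions above: every stimulus ID in this PR is generated as `f"<name>-{time()}"`, so the test strips the randomized timestamp suffix before comparing names. A minimal illustration (hypothetical values, not part of the diff):

```python
from time import time

# Stimulus IDs are built as f"<name>-{time()}", e.g.
# "gather-dep-success-1654854775.1"; rsplit("-", 1) drops the trailing
# timestamp so only the bare stimulus name is compared.
stimulus_id = f"gather-dep-success-{time()}"
assert stimulus_id.rsplit("-", 1)[0] == "gather-dep-success"
```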
36 changes: 36 additions & 0 deletions distributed/tests/test_worker_state_machine.py
@@ -647,3 +647,39 @@ async def test_fetch_to_missing_on_refresh_who_has(c, s, w1, w2, w3):
assert w3.tasks["x"].state == "missing"
assert w3.tasks["y"].state == "flight"
assert w3.tasks["y"].who_has == {w2.address}


@gen_cluster(client=True, nthreads=[("", 1)])
async def test_fetch_to_missing_on_network_failure(c, s, a):
"""
1. Two tasks, x and y, are respectively in flight and fetch state from the same
worker, which holds the only replica of both.
2. gather_dep for x returns GatherDepNetworkFailureEvent
3. The event empties has_what, x.who_has, and y.who_has; it recommends a transition
to missing for both x and y.
5. Before the recommendation can be implemented, the same event invokes
_ensure_communicating, which pops y from data_needed - but y has an empty
who_has, which is an exceptional situation.
6. The fetch->missing transition is executed, but y is no longer in data_needed -
another exceptional situation.
"""
block_get_data = asyncio.Event()

class BlockedBreakingWorker(Worker):
async def get_data(self, comm, *args, **kwargs):
await block_get_data.wait()
raise OSError("fake error")

async with BlockedBreakingWorker(s.address) as b:
x = c.submit(inc, 1, key="x", workers=[b.address])
y = c.submit(inc, 2, key="y", workers=[b.address])
await wait([x, y])
s.request_acquire_replicas(a.address, ["x"], stimulus_id="test_x")
await wait_for_state("x", "flight", a)
s.request_acquire_replicas(a.address, ["y"], stimulus_id="test_y")
await wait_for_state("y", "fetch", a)

block_get_data.set()

await wait_for_state("x", "missing", a)
await wait_for_state("y", "missing", a)
252 changes: 170 additions & 82 deletions distributed/worker.py
@@ -21,6 +21,7 @@
     Collection,
     Container,
     Iterable,
+    Iterator,
     Mapping,
     MutableMapping,
 )
@@ -122,7 +123,11 @@
     FindMissingEvent,
     FreeKeysEvent,
     GatherDep,
+    GatherDepBusyEvent,
     GatherDepDoneEvent,
+    GatherDepFailureEvent,
+    GatherDepNetworkFailureEvent,
+    GatherDepSuccessEvent,
     Instructions,
     InvalidTaskState,
     InvalidTransition,
@@ -3289,13 +3294,6 @@ async def gather_dep(
         if self.status not in WORKER_ANY_RUNNING:
             return None

-        recommendations: Recs = {}
-        instructions: Instructions = []
-        response = {}
-
-        def done_event():
-            return GatherDepDoneEvent(stimulus_id=f"gather-dep-done-{time()}")
-
         try:
             self.log.append(("request-dep", worker, to_gather, stimulus_id, time()))
             logger.debug("Request %d keys from %s", len(to_gather), worker)
Expand All @@ -3306,8 +3304,14 @@ def done_event():
)
stop = time()
if response["status"] == "busy":
return done_event()
self.log.append(("busy-gather", worker, to_gather, stimulus_id, time()))
return GatherDepBusyEvent(
worker=worker,
total_nbytes=total_nbytes,
stimulus_id=f"gather-dep-busy-{time()}",
)

assert response["status"] == "OK"
cause = self._get_cause(to_gather)
self._update_metrics_received_data(
start=start,
@@ -3319,93 +3323,182 @@ async def gather_dep(
             self.log.append(
                 ("receive-dep", worker, set(response["data"]), stimulus_id, time())
             )
-            return done_event()
+            return GatherDepSuccessEvent(
+                worker=worker,
+                total_nbytes=total_nbytes,
+                data=response["data"],
+                stimulus_id=f"gather-dep-success-{time()}",
+            )

         except OSError:
             logger.exception("Worker stream died during communication: %s", worker)
-            has_what = self.has_what.pop(worker)
-            self.data_needed_per_worker.pop(worker)
             self.log.append(
-                ("receive-dep-failed", worker, has_what, stimulus_id, time())
+                ("receive-dep-failed", worker, to_gather, stimulus_id, time())
             )
+            return GatherDepNetworkFailureEvent(
+                worker=worker,
+                total_nbytes=total_nbytes,
+                stimulus_id=f"gather-dep-network-failure-{time()}",
+            )
-            for d in has_what:
-                ts = self.tasks[d]
-                ts.who_has.remove(worker)
-                if not ts.who_has and ts.state in (
-                    "fetch",
-                    "flight",
-                    "resumed",
-                    "cancelled",
-                ):
-                    recommendations[ts] = "missing"
-                    self.log.append(
-                        ("missing-who-has", worker, ts.key, stimulus_id, time())
-                    )
-            return done_event()

         except Exception as e:
             # e.g. data failed to deserialize
             logger.exception(e)
             if self.batched_stream and LOG_PDB:
                 import pdb

                 pdb.set_trace()
-            msg = error_message(e)
-            for k in self.in_flight_workers[worker]:
-                ts = self.tasks[k]
-                recommendations[ts] = tuple(msg.values())
-            return done_event()
-
-        finally:
-            self.comm_nbytes -= total_nbytes
-            busy = response.get("status", "") == "busy"
-            data = response.get("data", {})
+            return GatherDepFailureEvent.from_exception(
+                e,
+                worker=worker,
+                total_nbytes=total_nbytes,
+                stimulus_id=f"gather-dep-failure-{time()}",
+            )

-            if busy:
-                self.log.append(("busy-gather", worker, to_gather, stimulus_id, time()))
-                # Avoid hammering the worker. If there are multiple replicas
-                # available, immediately try fetching from a different worker.
-                self.busy_workers.add(worker)
-                instructions.append(
-                    RetryBusyWorkerLater(worker=worker, stimulus_id=stimulus_id)
-                )
+    def _gather_dep_done_common(self, ev: GatherDepDoneEvent) -> Iterator[TaskState]:
+        """Common code for all subclasses of GatherDepDoneEvent.
+
+        Yields the tasks that need to transition out of flight.
+        """
+        self.comm_nbytes -= ev.total_nbytes
+        keys = self.in_flight_workers.pop(ev.worker)
+        for key in keys:
+            ts = self.tasks[key]
+            ts.done = True
+            yield ts
+
+    def _refetch_missing_data(
+        self, worker: str, tasks: Iterable[TaskState], stimulus_id: str
+    ) -> RecsInstrs:

    [Review comment: fjetter (member), Jun 9, 2022]
    I strongly dislike the idea of reusing this code. This "helper" already
    shows signs of too much complexity: we're dealing with many individual
    ts.state values to make a decision. If we restrict ourselves to the
    specialized invocation, this should be much less complex.

    Reusing code was one of the major reasons why the except/except/finally
    block caused so many problems.

    Particularly with the MissingDataMsg signal in here (#6445), I do not
    trust this to be the correct answer for both success-but-missing and
    network-failure responses; see also #6112 (comment).

"""Helper of GatherDepSuccessEvent and GatherDepNetworkFailureEvent handlers.

Remove tasks that were not returned from the peer worker from has_what and
inform the scheduler with a missing-data message. Then, transition them back to
'fetch' so that they can be fetched from another worker.
"""
recommendations: Recs = {}
instructions: Instructions = []

for ts in tasks:
ts.who_has.discard(worker)
self.has_what[worker].discard(ts.key)
self.log.append((ts.key, "missing-dep", stimulus_id, time()))
instructions.append(
MissingDataMsg(
key=ts.key,
errant_worker=worker,
stimulus_id=stimulus_id,
)
)
if ts.state in ("flight", "resumed", "cancelled"):
# This will actually transition to missing if who_has is empty
recommendations[ts] = "fetch"
elif ts.state == "fetch":
self.data_needed_per_worker[worker].discard(ts)
if not ts.who_has:
recommendations[ts] = "missing"

refresh_who_has = []

for d in self.in_flight_workers.pop(worker):
ts = self.tasks[d]
ts.done = True
if d in data:
recommendations[ts] = ("memory", data[d])
elif busy:
recommendations[ts] = "fetch"
if not ts.who_has - self.busy_workers:
refresh_who_has.append(d)
elif ts not in recommendations:
ts.who_has.discard(worker)
self.has_what[worker].discard(ts.key)
self.data_needed_per_worker[worker].discard(ts)
self.log.append((d, "missing-dep", stimulus_id, time()))
instructions.append(
MissingDataMsg(
key=d,
errant_worker=worker,
stimulus_id=stimulus_id,
)
)
recommendations[ts] = "fetch"

if refresh_who_has:
# All workers that hold known replicas of our tasks are busy.
# Try querying the scheduler for unknown ones.
instructions.append(
RequestRefreshWhoHasMsg(
keys=refresh_who_has,
stimulus_id=f"gather-dep-busy-{time()}",
)
return recommendations, instructions
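For readers unfamiliar with the `@_handle_event.register` decorators below: handlers are registered per event type and dispatched on the type of the event instance. A self-contained sketch of the pattern, assuming a `functools.singledispatchmethod`-style registry (simplified names, not the exact distributed code):

```python
from __future__ import annotations

import functools
from dataclasses import dataclass


@dataclass
class StateMachineEvent:
    stimulus_id: str


@dataclass
class GatherDepBusyEvent(StateMachineEvent):
    worker: str
    total_nbytes: int


class WorkerState:
    @functools.singledispatchmethod
    def _handle_event(self, ev: StateMachineEvent):
        raise TypeError(ev)  # no handler registered for this event type

    @_handle_event.register
    def _handle_gather_dep_busy(self, ev: GatherDepBusyEvent):
        # Dispatch picks this overload because ev is a GatherDepBusyEvent.
        return f"busy worker: {ev.worker}"
```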

+    @_handle_event.register
+    def _handle_gather_dep_success(self, ev: GatherDepSuccessEvent) -> RecsInstrs:
+        """gather_dep terminated successfully.
+        The response may contain fewer keys than the request.
+        """
+        recommendations: Recs = {}
+        refetch = set()
+        for ts in self._gather_dep_done_common(ev):
+            if ts.key in ev.data:
+                recommendations[ts] = ("memory", ev.data[ts.key])
+            else:
+                refetch.add(ts)
+
+        return merge_recs_instructions(
+            (recommendations, []),
+            self._refetch_missing_data(ev.worker, refetch, ev.stimulus_id),
+            self._ensure_communicating(stimulus_id=ev.stimulus_id),
+        )

    [Review comment: a collaborator]
    Do you think we'll eventually be able to refactor all these
    merge_recs_instructions calls out, or is this here to stay because of the
    pattern of using the helper functions and ensure_communicating?

    [Reply: crusaderky (author), Jun 8, 2022]
    It's here to stay. The alternative would be to pass recommendations and
    instructions into the helper functions and let them write back into them.
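Regarding the exchange above: `merge_recs_instructions` combines the `(recommendations, instructions)` pairs returned by the helpers into a single pair. A minimal sketch of the assumed behavior, inferred from the call sites in this diff (the real helper lives in `distributed.worker_state_machine` and may treat conflicts differently):

```python
def merge_recs_instructions(*args):
    """Merge several (recommendations, instructions) tuples into one pair."""
    recs = {}
    instructions = []
    for recs_i, instructions_i in args:
        for ts, rec in recs_i.items():
            # Assumption: callers never issue conflicting recommendations
            # for the same task.
            assert ts not in recs or recs[ts] == rec
            recs[ts] = rec
        instructions += instructions_i
    return recs, instructions
```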

+    @_handle_event.register
+    def _handle_gather_dep_busy(self, ev: GatherDepBusyEvent) -> RecsInstrs:
+        """gather_dep terminated: remote worker is busy"""
+        # Avoid hammering the worker. If there are multiple replicas
+        # available, immediately try fetching from a different worker.
+        self.busy_workers.add(ev.worker)
+
+        recommendations: Recs = {}
+        refresh_who_has = []
+        for ts in self._gather_dep_done_common(ev):
+            recommendations[ts] = "fetch"
+            if not ts.who_has - self.busy_workers:
+                refresh_who_has.append(ts.key)
+
+        instructions: Instructions = [
+            RetryBusyWorkerLater(worker=ev.worker, stimulus_id=ev.stimulus_id),
+        ]
+
+        if refresh_who_has:
+            # All workers that hold known replicas of our tasks are busy.
+            # Try querying the scheduler for unknown ones.
+            instructions.append(
+                RequestRefreshWhoHasMsg(
+                    keys=refresh_who_has, stimulus_id=ev.stimulus_id
+                )
+            )

-            self.transitions(recommendations, stimulus_id=stimulus_id)
-            self._handle_instructions(instructions)
+        return merge_recs_instructions(
+            (recommendations, instructions),
+            self._ensure_communicating(stimulus_id=ev.stimulus_id),
+        )

+    @_handle_event.register
+    def _handle_gather_dep_network_failure(
+        self, ev: GatherDepNetworkFailureEvent
+    ) -> RecsInstrs:
+        """gather_dep terminated: network failure while trying to
+        communicate with remote worker

+        Though the network failure could be transient, we assume it is not, and
+        preemptively act as though the other worker has died (including removing
+        all keys from it, even ones we did not fetch).
+
+        This optimization leads to faster completion of the fetch, since we
+        immediately either retry a different worker, or ask the scheduler to
+        inform us of a new worker if no other worker is available.
+        """
+        refetch = set(self._gather_dep_done_common(ev))
+        refetch |= {self.tasks[key] for key in self.has_what[ev.worker]}

    [Review comment: fjetter (member)]
    I don't think these two sets should be mixed; see my earlier comment about
    the shared "helper" function.

+        recs, instrs = merge_recs_instructions(
+            self._refetch_missing_data(ev.worker, refetch, ev.stimulus_id),
+            self._ensure_communicating(stimulus_id=ev.stimulus_id),
+        )
+        # This cleanup must happen after _refetch_missing_data, which reads
+        # has_what[ev.worker] and data_needed_per_worker[ev.worker]
+        del self.has_what[ev.worker]
+        del self.data_needed_per_worker[ev.worker]
+        return recs, instrs

+    @_handle_event.register
+    def _handle_gather_dep_failure(self, ev: GatherDepFailureEvent) -> RecsInstrs:
+        """gather_dep terminated: generic error raised (not a network failure);
+        e.g. data failed to deserialize.
+        """
+        recommendations: Recs = {
+            ts: (
+                "error",
+                ev.exception,
+                ev.traceback,
+                ev.exception_text,
+                ev.traceback_text,
+            )
+            for ts in self._gather_dep_done_common(ev)
+        }
+
+        return merge_recs_instructions(
+            (recommendations, []),
+            self._ensure_communicating(stimulus_id=ev.stimulus_id),
+        )
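`GatherDepFailureEvent.from_exception` (used in the `except Exception` branch of gather_dep above) packages the raised exception into the event fields consumed here. A hypothetical sketch inferred from those fields; the real event class in `distributed.worker_state_machine` differs in details (for instance, it may serialize the exception):

```python
from __future__ import annotations

import traceback
from dataclasses import dataclass


@dataclass
class GatherDepFailureEvent:
    # Field names inferred from _handle_gather_dep_failure above.
    worker: str
    total_nbytes: int
    stimulus_id: str
    exception: BaseException | None = None
    traceback: object | None = None
    exception_text: str = ""
    traceback_text: str = ""

    @classmethod
    def from_exception(cls, e, *, worker, total_nbytes, stimulus_id):
        # Capture both the live exception objects and their text renderings,
        # so the error can be reported even across serialization boundaries.
        return cls(
            worker=worker,
            total_nbytes=total_nbytes,
            stimulus_id=stimulus_id,
            exception=e,
            traceback=e.__traceback__,
            exception_text="".join(traceback.format_exception_only(type(e), e)),
            traceback_text="".join(traceback.format_tb(e.__traceback__)),
        )
```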

     async def retry_busy_worker_later(self, worker: str) -> StateMachineEvent | None:
         await asyncio.sleep(0.15)
Expand Down Expand Up @@ -3844,11 +3937,6 @@ def _handle_unpause(self, ev: UnpauseEvent) -> RecsInstrs:
self._ensure_communicating(stimulus_id=ev.stimulus_id),
)

@_handle_event.register
def _handle_gather_dep_done(self, ev: GatherDepDoneEvent) -> RecsInstrs:
"""Temporary hack - to be removed"""
return self._ensure_communicating(stimulus_id=ev.stimulus_id)

@_handle_event.register
def _handle_retry_busy_worker(self, ev: RetryBusyWorkerEvent) -> RecsInstrs:
self.busy_workers.discard(ev.worker)
Expand Down