Skip to content

Commit

Permalink
Refactor gather_dep
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky committed Jun 7, 2022
1 parent bde90af commit d985e6b
Show file tree
Hide file tree
Showing 3 changed files with 264 additions and 94 deletions.
5 changes: 2 additions & 3 deletions distributed/tests/test_stories.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,7 @@ async def test_worker_story_with_deps(c, s, a, b):
# Story now includes randomized stimulus_ids and timestamps.
story = b.story("res")
stimulus_ids = {ev[-2].rsplit("-", 1)[0] for ev in story}
assert stimulus_ids == {"compute-task", "task-finished"}

assert stimulus_ids == {"compute-task", "gather-dep-success", "task-finished"}
# This is a simple transition log
expected = [
("res", "compute-task", "released"),
Expand All @@ -153,7 +152,7 @@ async def test_worker_story_with_deps(c, s, a, b):

story = b.story("dep")
stimulus_ids = {ev[-2].rsplit("-", 1)[0] for ev in story}
assert stimulus_ids == {"compute-task"}
assert stimulus_ids == {"compute-task", "gather-dep-success"}
expected = [
("dep", "ensure-task-exists", "released"),
("dep", "released", "fetch", "fetch", {}),
Expand Down
274 changes: 184 additions & 90 deletions distributed/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
Collection,
Container,
Iterable,
Iterator,
Mapping,
MutableMapping,
)
Expand Down Expand Up @@ -122,7 +123,11 @@
FindMissingEvent,
FreeKeysEvent,
GatherDep,
GatherDepBusyEvent,
GatherDepDoneEvent,
GatherDepFailureEvent,
GatherDepNetworkFailureEvent,
GatherDepSuccessEvent,
Instructions,
InvalidTaskState,
InvalidTransition,
Expand Down Expand Up @@ -3257,13 +3262,6 @@ async def gather_dep(
if self.status not in WORKER_ANY_RUNNING:
return None

recommendations: Recs = {}
instructions: Instructions = []
response = {}

def done_event():
return GatherDepDoneEvent(stimulus_id=f"gather-dep-done-{time()}")

try:
self.log.append(("request-dep", worker, to_gather, stimulus_id, time()))
logger.debug("Request %d keys from %s", len(to_gather), worker)
Expand All @@ -3274,8 +3272,13 @@ def done_event():
)
stop = time()
if response["status"] == "busy":
return done_event()
return GatherDepBusyEvent(
worker=worker,
total_nbytes=total_nbytes,
stimulus_id=f"gather-dep-busy-{time()}",
)

assert response["status"] == "OK"
cause = self._get_cause(to_gather)
self._update_metrics_received_data(
start=start,
Expand All @@ -3284,96 +3287,192 @@ def done_event():
cause=cause,
worker=worker,
)
self.log.append(
("receive-dep", worker, set(response["data"]), stimulus_id, time())

return GatherDepSuccessEvent(
worker=worker,
total_nbytes=total_nbytes,
data=response["data"],
stimulus_id=f"gather-dep-success-{time()}",
)
return done_event()

except OSError:
logger.exception("Worker stream died during communication: %s", worker)
has_what = self.has_what.pop(worker)
self.data_needed_per_worker.pop(worker)
self.log.append(
("receive-dep-failed", worker, has_what, stimulus_id, time())
return GatherDepNetworkFailureEvent(
worker=worker,
total_nbytes=total_nbytes,
stimulus_id=f"gather-dep-network-failure-{time()}",
)
for d in has_what:
ts = self.tasks[d]
ts.who_has.remove(worker)
if not ts.who_has and ts.state in (
"fetch",
"flight",
"resumed",
"cancelled",
):
recommendations[ts] = "missing"
self.log.append(
("missing-who-has", worker, ts.key, stimulus_id, time())
)
return done_event()

except Exception as e:
# e.g. data failed to deserialize
logger.exception(e)
if self.batched_stream and LOG_PDB:
import pdb
return GatherDepFailureEvent.from_exception(
e,
worker=worker,
total_nbytes=total_nbytes,
stimulus_id=f"gather-dep-failure-{time()}",
)

pdb.set_trace()
msg = error_message(e)
for k in self.in_flight_workers[worker]:
ts = self.tasks[k]
recommendations[ts] = tuple(msg.values())
return done_event()
def _gather_dep_done_common(self, ev: GatherDepDoneEvent) -> Iterator[TaskState]:
    """Shared bookkeeping for the handlers of all GatherDepDoneEvent subclasses.

    Subtracts ``ev.total_nbytes`` from ``self.comm_nbytes``, removes
    ``ev.worker`` from ``self.in_flight_workers`` and yields the task of every
    key that was in flight from that worker, after flagging it as ``done``.

    This is a generator: none of the above happens until the caller starts
    iterating over the returned object.
    """
    self.comm_nbytes -= ev.total_nbytes
    in_flight_keys = self.in_flight_workers.pop(ev.worker)
    for in_flight_key in in_flight_keys:
        task = self.tasks[in_flight_key]
        task.done = True
        yield task

finally:
self.comm_nbytes -= total_nbytes
busy = response.get("status", "") == "busy"
data = response.get("data", {})

if busy:
self.log.append(("busy-gather", worker, to_gather, stimulus_id, time()))
# Avoid hammering the worker. If there are multiple replicas
# available, immediately try fetching from a different worker.
self.busy_workers.add(worker)
instructions.append(
RetryBusyWorkerLater(worker=worker, stimulus_id=stimulus_id)
def _refetch_missing_data(
    self, worker: str, tasks: Iterable[TaskState], stimulus_id: str
) -> RecsInstrs:
    """Helper of the GatherDepDoneEvent subclass handlers.

    For every task in ``tasks``:

    - remove ``worker`` from the task's ``who_has`` and the task's key from
      ``self.has_what[worker]``;
    - append a ``"missing-dep"`` entry to ``self.log``;
    - queue a ``MissingDataMsg`` that names ``worker`` as the errant worker;
    - recommend a transition to ``"fetch"``.

    Returns
    -------
    Tuple of (recommendations, instructions)
    """
    recs: Recs = {}
    instrs: Instructions = []

    for task in tasks:
        task.who_has.discard(worker)
        self.has_what[worker].discard(task.key)
        self.log.append((task.key, "missing-dep", stimulus_id, time()))
        missing_msg = MissingDataMsg(
            key=task.key,
            errant_worker=worker,
            stimulus_id=stimulus_id,
        )
        instrs.append(missing_msg)
        recs[task] = "fetch"

    return recs, instrs

refresh_who_has = []

for d in self.in_flight_workers.pop(worker):
ts = self.tasks[d]
ts.done = True
if d in data:
recommendations[ts] = ("memory", data[d])
elif busy:
recommendations[ts] = "fetch"
if not ts.who_has - self.busy_workers:
refresh_who_has.append(d)
elif ts not in recommendations:
ts.who_has.discard(worker)
self.has_what[worker].discard(ts.key)
self.data_needed_per_worker[worker].discard(ts)
self.log.append((d, "missing-dep", stimulus_id, time()))
instructions.append(
MissingDataMsg(
key=d,
errant_worker=worker,
stimulus_id=stimulus_id,
)
)
recommendations[ts] = "fetch"

if refresh_who_has:
# All workers that hold known replicas of our tasks are busy.
# Try querying the scheduler for unknown ones.
instructions.append(
RequestRefreshWhoHasMsg(
keys=refresh_who_has,
stimulus_id=f"gather-dep-busy-{time()}",
)
@_handle_event.register
def _handle_gather_dep_success(self, ev: GatherDepSuccessEvent) -> RecsInstrs:
    """gather_dep terminated successfully.

    The response may contain fewer keys than the request. Tasks whose key is
    in ``ev.data`` are recommended to transition to memory with the received
    value; every other in-flight task is passed to ``_refetch_missing_data``.
    """
    received = ev.data
    log_entry = ("receive-dep", ev.worker, set(received), ev.stimulus_id, time())
    self.log.append(log_entry)

    recs: Recs = {}
    to_refetch = set()
    for task in self._gather_dep_done_common(ev):
        if task.key not in received:
            to_refetch.add(task)
            continue
        recs[task] = ("memory", received[task.key])

    return merge_recs_instructions(
        (recs, []),
        self._ensure_communicating(stimulus_id=ev.stimulus_id),
        self._refetch_missing_data(ev.worker, to_refetch, ev.stimulus_id),
    )

@_handle_event.register
def _handle_gather_dep_busy(self, ev: GatherDepBusyEvent) -> RecsInstrs:
    """gather_dep terminated: remote worker is busy.

    Marks ``ev.worker`` as busy, recommends every task that was in flight from
    it to go back to ``"fetch"``, and schedules a later retry of the busy
    worker. If a task has no known replica on a non-busy worker, a refresh of
    its ``who_has`` is additionally requested.

    Returns
    -------
    Tuple of (recommendations, instructions), merged with the output of
    ``_ensure_communicating``. Nothing is applied here; the recommendations
    and instructions are only returned, like in the other gather_dep handlers.
    """
    # Snapshot the in-flight keys for the log *before* iterating over
    # _gather_dep_done_common, which pops ev.worker from in_flight_workers.
    self.log.append(
        (
            "busy-gather",
            ev.worker,
            set(self.in_flight_workers[ev.worker]),
            ev.stimulus_id,
            time(),
        )
    )

    # Avoid hammering the worker. If there are multiple replicas
    # available, immediately try fetching from a different worker.
    self.busy_workers.add(ev.worker)

    recommendations: Recs = {}
    refresh_who_has = []
    for ts in self._gather_dep_done_common(ev):
        recommendations[ts] = "fetch"
        if not ts.who_has - self.busy_workers:
            refresh_who_has.append(ts.key)

    instructions: Instructions = [
        RetryBusyWorkerLater(worker=ev.worker, stimulus_id=ev.stimulus_id),
    ]
    if refresh_who_has:
        # All workers that hold known replicas of our tasks are busy.
        # Try querying the scheduler for unknown ones.
        instructions.append(
            RequestRefreshWhoHasMsg(
                keys=refresh_who_has, stimulus_id=ev.stimulus_id
            )
        )

    return merge_recs_instructions(
        (recommendations, instructions),
        self._ensure_communicating(stimulus_id=ev.stimulus_id),
    )

@_handle_event.register
def _handle_gather_dep_network_failure(
    self, ev: GatherDepNetworkFailureEvent
) -> RecsInstrs:
    """gather_dep terminated: network failure while trying to
    communicate with remote worker.

    Forgets everything known about ``ev.worker`` (``has_what`` and
    ``data_needed_per_worker`` entries, and the worker's presence in each
    task's ``who_has``). Tasks left without any replica while in a
    fetch/flight/resumed/cancelled state are recommended to ``"missing"``;
    the remaining tasks that were in flight from the worker are passed to
    ``_refetch_missing_data``.

    Returns
    -------
    Tuple of (recommendations, instructions), merged with the output of
    ``_ensure_communicating``.
    """
    # This handler does not run inside an ``except`` block, so there is no
    # active exception to attach; logger.exception would append a meaningless
    # "NoneType: None" traceback. Log at the same (ERROR) level without it.
    logger.error("Worker stream died during communication: %s", ev.worker)

    # if state in (fetch, flight, resumed, cancelled):
    #     if ts.who_has is now empty:
    #         transition to missing; don't send data-missing
    #     elif ts in GatherDep.keys:
    #         transition to fetch; send data-missing
    #     else:
    #         don't transition
    # elif ts in GatherDep.keys:
    #     transition to fetch; send data-missing
    # else:
    #     don't transition

    has_what = self.has_what.pop(ev.worker)
    self.data_needed_per_worker.pop(ev.worker)
    self.log.append(
        ("receive-dep-failed", ev.worker, has_what, ev.stimulus_id, time())
    )
    recommendations: Recs = {}
    for d in has_what:
        ts = self.tasks[d]
        ts.who_has.remove(ev.worker)
        if not ts.who_has and ts.state in (
            "fetch",
            "flight",
            "resumed",
            "cancelled",
        ):
            recommendations[ts] = "missing"
            self.log.append(
                ("missing-who-has", ev.worker, ts.key, ev.stimulus_id, time())
            )

    # Tasks already recommended to "missing" must not also be refetched.
    refetch_tasks = set(self._gather_dep_done_common(ev)) - recommendations.keys()
    return merge_recs_instructions(
        (recommendations, []),
        self._ensure_communicating(stimulus_id=ev.stimulus_id),
        self._refetch_missing_data(ev.worker, refetch_tasks, ev.stimulus_id),
    )

@_handle_event.register
def _handle_gather_dep_failure(self, ev: GatherDepFailureEvent) -> RecsInstrs:
    """gather_dep terminated: generic error raised (not a network failure);
    e.g. data failed to deserialize.

    Every task that was in flight from ``ev.worker`` is recommended to
    transition to ``"error"``, carrying the exception and traceback details
    stored on the event.
    """
    error_rec = (
        "error",
        ev.exception,
        ev.traceback,
        ev.exception_text,
        ev.traceback_text,
    )
    # The same immutable tuple is the recommendation for all affected tasks.
    recs: Recs = dict.fromkeys(self._gather_dep_done_common(ev), error_rec)

    return merge_recs_instructions(
        (recs, []),
        self._ensure_communicating(stimulus_id=ev.stimulus_id),
    )

async def retry_busy_worker_later(self, worker: str) -> StateMachineEvent | None:
await asyncio.sleep(0.15)
Expand Down Expand Up @@ -3812,11 +3911,6 @@ def _handle_unpause(self, ev: UnpauseEvent) -> RecsInstrs:
self._ensure_communicating(stimulus_id=ev.stimulus_id),
)

@_handle_event.register
def _handle_gather_dep_done(self, ev: GatherDepDoneEvent) -> RecsInstrs:
    """Temporary hack - to be removed.

    Delegates to ``_ensure_communicating`` using the event's stimulus_id and
    returns its result unchanged.
    """
    recs_instrs = self._ensure_communicating(stimulus_id=ev.stimulus_id)
    return recs_instrs

@_handle_event.register
def _handle_retry_busy_worker(self, ev: RetryBusyWorkerEvent) -> RecsInstrs:
self.busy_workers.discard(ev.worker)
Expand Down

0 comments on commit d985e6b

Please sign in to comment.