Skip to content

Commit

Permalink
Refactor gather_dep
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky committed May 20, 2022
1 parent f895b26 commit d988cbe
Show file tree
Hide file tree
Showing 3 changed files with 260 additions and 105 deletions.
22 changes: 14 additions & 8 deletions distributed/tests/test_worker_state_machine.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations

import asyncio
from collections.abc import Iterator
from contextlib import contextmanager
from itertools import chain

import pytest

Expand All @@ -16,7 +18,6 @@
ReleaseWorkerDataMsg,
RescheduleEvent,
RescheduleMsg,
SendMessageToScheduler,
StateMachineEvent,
TaskState,
TaskStateState,
Expand Down Expand Up @@ -103,14 +104,19 @@ def test_unique_task_heap():
assert repr(heap) == "<UniqueTaskHeap: 0 items>"


def traverse_subclasses(cls: type) -> Iterator[type]:
yield cls
for subcls in cls.__subclasses__():
yield from traverse_subclasses(subcls)


@pytest.mark.parametrize(
"cls",
chain(
[UniqueTaskHeap],
Instruction.__subclasses__(),
SendMessageToScheduler.__subclasses__(),
StateMachineEvent.__subclasses__(),
),
[
UniqueTaskHeap,
*traverse_subclasses(Instruction),
*traverse_subclasses(StateMachineEvent),
],
)
def test_slots(cls):
params = [
Expand Down
284 changes: 188 additions & 96 deletions distributed/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
Collection,
Container,
Iterable,
Iterator,
Mapping,
MutableMapping,
)
Expand Down Expand Up @@ -117,7 +118,11 @@
ExecuteSuccessEvent,
FindMissingEvent,
GatherDep,
GatherDepBusyEvent,
GatherDepDoneEvent,
GatherDepErrorEvent,
GatherDepNetworkFailureEvent,
GatherDepSuccessEvent,
Instructions,
InvalidTransition,
LongRunningMsg,
Expand Down Expand Up @@ -3251,7 +3256,6 @@ def _update_metrics_received_data(
self.counters["transfer-count"].add(len(data))
self.incoming_count += 1

@fail_hard
@log_errors
async def gather_dep(
self,
Expand All @@ -3277,13 +3281,6 @@ async def gather_dep(
if self.status not in WORKER_ANY_RUNNING:
return None

recommendations: Recs = {}
instructions: Instructions = []
response = {}

def done_event():
return GatherDepDoneEvent(stimulus_id=f"gather-dep-done-{time()}")

try:
self.log.append(("request-dep", worker, to_gather, stimulus_id, time()))
logger.debug("Request %d keys from %s", len(to_gather), worker)
Expand All @@ -3294,42 +3291,32 @@ def done_event():
)
stop = time()
if response["status"] == "busy":
return done_event()
return GatherDepBusyEvent(
worker=worker, total_nbytes=total_nbytes, stimulus_id=stimulus_id
)

cause = self._get_cause(to_gather)
self._update_metrics_received_data(
start=start,
stop=stop,
data=response["data"],
cause=cause,
assert response["status"] == "OK"
if response["data"]:
cause = self._get_cause(response["data"])
self._update_metrics_received_data(
start=start,
stop=stop,
data=response["data"],
cause=cause,
worker=worker,
)

return GatherDepSuccessEvent(
worker=worker,
total_nbytes=total_nbytes,
data=response["data"],
stimulus_id=stimulus_id,
)
self.log.append(
("receive-dep", worker, set(response["data"]), stimulus_id, time())
)
return done_event()

except OSError:
logger.exception("Worker stream died during communication: %s", worker)
has_what = self.has_what.pop(worker)
self.data_needed_per_worker.pop(worker)
self.log.append(
("receive-dep-failed", worker, has_what, stimulus_id, time())
return GatherDepNetworkFailureEvent(
worker=worker, total_nbytes=total_nbytes, stimulus_id=stimulus_id
)
for d in has_what:
ts = self.tasks[d]
ts.who_has.remove(worker)
if not ts.who_has and ts.state in (
"fetch",
"flight",
"resumed",
"cancelled",
):
recommendations[ts] = "missing"
self.log.append(
("missing-who-has", worker, ts.key, stimulus_id, time())
)
return done_event()

except Exception as e:
logger.exception(e)
Expand All @@ -3338,61 +3325,15 @@ def done_event():

pdb.set_trace()
msg = error_message(e)
for k in self.in_flight_workers[worker]:
ts = self.tasks[k]
recommendations[ts] = tuple(msg.values())
return done_event()

finally:
self.comm_nbytes -= total_nbytes
busy = response.get("status", "") == "busy"
data = response.get("data", {})

if busy:
self.log.append(("busy-gather", worker, to_gather, stimulus_id, time()))
# Avoid hammering the worker. If there are multiple replicas
# available, immediately try fetching from a different worker.
self.busy_workers.add(worker)
instructions.append(
RetryBusyWorkerLater(worker=worker, stimulus_id=stimulus_id)
)

refresh_who_has = set()

for d in self.in_flight_workers.pop(worker):
ts = self.tasks[d]
ts.done = True
if d in data:
recommendations[ts] = ("memory", data[d])
elif busy:
recommendations[ts] = "fetch"
if not ts.who_has - self.busy_workers:
refresh_who_has.add(ts.key)
elif ts not in recommendations:
ts.who_has.discard(worker)
self.has_what[worker].discard(ts.key)
self.log.append((d, "missing-dep", stimulus_id, time()))
instructions.append(
MissingDataMsg(
key=d,
errant_worker=worker,
stimulus_id=stimulus_id,
)
)
recommendations[ts] = "fetch"

if refresh_who_has:
# All workers that hold known replicas of our tasks are busy.
# Try querying the scheduler for unknown ones.
instructions.append(
RequestRefreshWhoHasMsg(
keys=list(refresh_who_has),
stimulus_id=f"gather-dep-busy-{time()}",
)
)

self.transitions(recommendations, stimulus_id=stimulus_id)
self._handle_instructions(instructions)
return GatherDepErrorEvent(
worker=worker,
total_nbytes=total_nbytes,
exception=msg["exception"],
traceback=msg["traceback"],
exception_text=msg["exception_text"],
traceback_text=msg["traceback_text"],
stimulus_id=stimulus_id,
)

async def retry_busy_worker_later(self, worker: str) -> StateMachineEvent | None:
await asyncio.sleep(0.15)
Expand Down Expand Up @@ -3935,10 +3876,161 @@ def _(self, ev: UnpauseEvent) -> RecsInstrs:
self._ensure_communicating(stimulus_id=ev.stimulus_id),
)

def _gather_dep_done_common(self, ev: GatherDepDoneEvent) -> Iterator[TaskState]:
"""Common code for all subclasses of GatherDepDoneEvent"""
self.comm_nbytes -= ev.total_nbytes
for key in self.in_flight_workers.pop(ev.worker):
ts = self.tasks[key]
ts.done = True
yield ts

def _refetch_missing_data(
self, ev: GatherDepDoneEvent, tasks: Iterable[TaskState]
) -> RecsInstrs:
"""Helper of GatherDepDoneEvent subclass handlers"""
recommendations: Recs = {}
instructions: Instructions = []

for ts in tasks:
ts.who_has.discard(ev.worker)
self.has_what[ev.worker].discard(ts.key)
self.log.append((ts.key, "missing-dep", ev.stimulus_id, time()))
instructions.append(
MissingDataMsg(
key=ts.key,
errant_worker=ev.worker,
stimulus_id=ev.stimulus_id,
)
)
recommendations[ts] = "fetch"
return recommendations, instructions

@handle_event.register
def _(self, ev: GatherDepDoneEvent) -> RecsInstrs:
"""Temporary hack - to be removed"""
return self._ensure_communicating(stimulus_id=ev.stimulus_id)
def _(self, ev: GatherDepSuccessEvent) -> RecsInstrs:
"""gather_dep terminated successfully.
The response may contain less keys than the request.
"""
self.log.append(
("receive-dep", ev.worker, set(ev.data), ev.stimulus_id, time())
)

recommendations: Recs = {}
refetch = set()
for ts in self._gather_dep_done_common(ev):
if ts.key in ev.data:
recommendations[ts] = ("memory", ev.data[ts.key])
else:
refetch.add(ts)

smsg = EnsureCommunicatingAfterTransitions(stimulus_id=ev.stimulus_id)
return merge_recs_instructions(
(recommendations, [smsg]),
self._refetch_missing_data(ev, refetch),
)

@handle_event.register
def _(self, ev: GatherDepBusyEvent) -> RecsInstrs:
"""gather_dep terminated: remote worker is busy"""
self.log.append(
(
"busy-gather",
ev.worker,
set(self.in_flight_workers[ev.worker]),
ev.stimulus_id,
time(),
)
)

# Avoid hammering the worker. If there are multiple replicas
# available, immediately try fetching from a different worker.
self.busy_workers.add(ev.worker)

recommendations: Recs = {}
refresh_who_has = []
for ts in self._gather_dep_done_common(ev):
recommendations[ts] = "fetch"
if not ts.who_has - self.busy_workers:
refresh_who_has.append(ts.key)

instructions: Instructions = [
RetryBusyWorkerLater(worker=ev.worker, stimulus_id=ev.stimulus_id),
EnsureCommunicatingAfterTransitions(stimulus_id=ev.stimulus_id),
]
if refresh_who_has:
# All workers that hold known replicas of our tasks are busy.
# Try querying the scheduler for unknown ones.
instructions.append(
RequestRefreshWhoHasMsg(
keys=refresh_who_has,
stimulus_id=f"gather-dep-busy-{time()}",
)
)

return recommendations, instructions

@handle_event.register
def _(self, ev: GatherDepNetworkFailureEvent) -> RecsInstrs:
"""gather_dep terminated: network failure while trying to
communicate with remote worker
"""
logger.exception("Worker stream died during communication: %s", ev.worker)

# if state in (fetch, flight, resumed, cancelled):
# if ts.who_has is now empty:
# transition to missing; don't send data-missing
# elif ts in GatherDep.keys:
# transition to fetch; send data-missing
# else:
# don't transition
# elif ts in GatherDep.keys:
# transition to fetch; send data-missing
# else:
# don't transition

has_what = self.has_what.pop(ev.worker)
self.data_needed_per_worker.pop(ev.worker)
self.log.append(
("receive-dep-failed", ev.worker, has_what, ev.stimulus_id, time())
)
recommendations: Recs = {}
for d in has_what:
ts = self.tasks[d]
ts.who_has.remove(ev.worker)
if not ts.who_has and ts.state in (
"fetch",
"flight",
"resumed",
"cancelled",
):
recommendations[ts] = "missing"
self.log.append(
("missing-who-has", ev.worker, ts.key, ev.stimulus_id, time())
)

refetch_tasks = set(self._gather_dep_done_common(ev)) - recommendations.keys()
smsg = EnsureCommunicatingAfterTransitions(stimulus_id=ev.stimulus_id)
return merge_recs_instructions(
(recommendations, [smsg]),
self._refetch_missing_data(ev, refetch_tasks),
)

@handle_event.register
def _(self, ev: GatherDepErrorEvent) -> RecsInstrs:
"""gather_dep terminated: generic error raised (not a network failure);
e.g. data failed to deserialize.
"""
recommendations: Recs = {
ts: (
"error",
ev.exception,
ev.traceback,
ev.exception_text,
ev.traceback_text,
)
for ts in self._gather_dep_done_common(ev)
}
smsg = EnsureCommunicatingAfterTransitions(stimulus_id=ev.stimulus_id)
return recommendations, [smsg]

@handle_event.register
def _(self, ev: RetryBusyWorkerEvent) -> RecsInstrs:
Expand Down

0 comments on commit d988cbe

Please sign in to comment.