Skip to content

Commit

Permalink
Refactor find_missing and refresh_who_has (dask#6348)
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky committed May 20, 2022
1 parent 4e9385b commit 4a7685a
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 45 deletions.
24 changes: 20 additions & 4 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3024,6 +3024,7 @@ def __init__(
"keep-alive": lambda *args, **kwargs: None,
"log-event": self.log_worker_event,
"worker-status-change": self.handle_worker_status_change,
"request-refresh-who-has": self.handle_request_refresh_who_has,
}

client_handlers = {
Expand Down Expand Up @@ -4810,6 +4811,21 @@ def handle_worker_status_change(
else:
self.running.discard(ws)

async def handle_request_refresh_who_has(
self, keys: Iterable[str], worker: str, stimulus_id: str
) -> None:
"""Asynchronous request (through bulk comms) from a Worker to refresh the
who_has for some keys. Not to be confused with scheduler.who_has, which is a
synchronous RPC request from a Client.
"""
self.stream_comms[worker].send(
{
"op": "refresh-who-has",
"who_has": self.get_who_has(keys),
"stimulus_id": stimulus_id,
},
)

async def handle_worker(self, comm=None, worker=None, stimulus_id=None):
"""
Listen to responses from a single worker
Expand Down Expand Up @@ -6282,13 +6298,13 @@ def get_processing(self, workers=None):
w: [ts.key for ts in ws.processing] for w, ws in self.workers.items()
}

def get_who_has(self, keys=None):
def get_who_has(self, keys: Iterable[str] | None = None) -> dict[str, list[str]]:
if keys is not None:
return {
k: [ws.address for ws in self.tasks[k].who_has]
if k in self.tasks
key: [ws.address for ws in self.tasks[key].who_has]
if key in self.tasks
else []
for k in keys
for key in keys
}
else:
return {
Expand Down
91 changes: 50 additions & 41 deletions distributed/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@
Execute,
ExecuteFailureEvent,
ExecuteSuccessEvent,
FindMissingEvent,
GatherDep,
GatherDepDoneEvent,
Instructions,
Expand All @@ -123,7 +124,9 @@
MissingDataMsg,
Recs,
RecsInstrs,
RefreshWhoHasEvent,
ReleaseWorkerDataMsg,
RequestRefreshWhoHasMsg,
RescheduleEvent,
RescheduleMsg,
SendMessageToScheduler,
Expand Down Expand Up @@ -797,6 +800,7 @@ def __init__(
"compute-task": self.handle_compute_task,
"free-keys": self.handle_free_keys,
"remove-replicas": self.handle_remove_replicas,
"refresh-who-has": self.handle_refresh_who_has,
"steal-request": self.handle_steal_request,
"worker-status-change": self.handle_worker_status_change,
}
Expand Down Expand Up @@ -825,8 +829,7 @@ def __init__(
)
self.periodic_callbacks["keep-alive"] = pc

# FIXME annotations: https://github.com/tornadoweb/tornado/issues/3117
pc = PeriodicCallback(self.find_missing, 1000) # type: ignore
pc = PeriodicCallback(self.find_missing, 1000)
self.periodic_callbacks["find-missing"] = pc

self._address = contact_address
Expand Down Expand Up @@ -1821,6 +1824,13 @@ def handle_remove_replicas(self, keys: list[str], stimulus_id: str) -> str:

return "OK"

def handle_refresh_who_has(
self, who_has: dict[str, list[str]], stimulus_id: str
) -> None:
self.handle_stimulus(
RefreshWhoHasEvent(who_has=who_has, stimulus_id=stimulus_id)
)

async def set_resources(self, **resources) -> None:
for r, quantity in resources.items():
if r in self.total_resources:
Expand Down Expand Up @@ -2844,7 +2854,8 @@ def transitions(self, recommendations: Recs, *, stimulus_id: str) -> None:

@log_errors
def handle_stimulus(self, stim: StateMachineEvent) -> None:
self.stimulus_log.append(stim.to_loggable(handled=time()))
if not isinstance(stim, FindMissingEvent):
self.stimulus_log.append(stim.to_loggable(handled=time()))
recs, instructions = self.handle_event(stim)
self.transitions(recs, stimulus_id=stim.stimulus_id)
self._handle_instructions(instructions)
Expand Down Expand Up @@ -3422,22 +3433,19 @@ def done_event():
)
)
recommendations[ts] = "fetch"
del data, response
self.transitions(recommendations, stimulus_id=stimulus_id)
self._handle_instructions(instructions)

if refresh_who_has:
# All workers that hold known replicas of our tasks are busy.
# Try querying the scheduler for unknown ones.
who_has = await retry_operation(
self.scheduler.who_has, keys=refresh_who_has
)
refresh_stimulus_id = f"refresh-who-has-{time()}"
recommendations, instructions = self._update_who_has(
who_has, stimulus_id=refresh_stimulus_id
instructions.append(
RequestRefreshWhoHasMsg(
keys=list(refresh_who_has),
stimulus_id=f"gather-dep-busy-{time()}",
)
)
self.transitions(recommendations, stimulus_id=refresh_stimulus_id)
self._handle_instructions(instructions)

self.transitions(recommendations, stimulus_id=stimulus_id)
self._handle_instructions(instructions)

@log_errors
def _readd_busy_worker(self, worker: str) -> None:
Expand All @@ -3447,34 +3455,13 @@ def _readd_busy_worker(self, worker: str) -> None:
)

@log_errors
async def find_missing(self) -> None:
if not self._missing_dep_flight:
return
try:
if self.validate:
for ts in self._missing_dep_flight:
assert not ts.who_has
def find_missing(self) -> None:
self.handle_stimulus(FindMissingEvent(stimulus_id=f"find-missing-{time()}"))

stimulus_id = f"find-missing-{time()}"
who_has = await retry_operation(
self.scheduler.who_has,
keys=[ts.key for ts in self._missing_dep_flight],
)
recommendations, instructions = self._update_who_has(
who_has, stimulus_id=stimulus_id
)
for ts in self._missing_dep_flight:
if ts.who_has:
assert ts not in recommendations
recommendations[ts] = "fetch"
self.transitions(recommendations, stimulus_id=stimulus_id)
self._handle_instructions(instructions)

finally:
# This is quite arbitrary but the heartbeat has scaling implemented
self.periodic_callbacks[
"find-missing"
].callback_time = self.periodic_callbacks["heartbeat"].callback_time
# This is quite arbitrary but the heartbeat has scaling implemented
self.periodic_callbacks["find-missing"].callback_time = self.periodic_callbacks[
"heartbeat"
].callback_time

def _update_who_has(
self, who_has: Mapping[str, Collection[str]], *, stimulus_id: str
Expand Down Expand Up @@ -3543,12 +3530,15 @@ def _update_who_has(

ts.who_has = workers
# currently fetching -> can no longer be fetched -> transition to missing
# currently missing -> opportunity to be fetched -> transition to fetch
# any other state -> eventually, possibly, the task may transition to fetch
# or missing, at which point the relevant transitions will test who_has that
# we just updated. e.g. see the various transitions to fetch, which
# instead recommend transitioning to missing if who_has is empty.
if not workers and ts.state == "fetch":
recs[ts] = "missing"
elif workers and ts.state == "missing":
recs[ts] = "fetch"

return recs, instructions

Expand Down Expand Up @@ -4074,6 +4064,25 @@ def _(self, ev: RescheduleEvent) -> RecsInstrs:
assert ts, self.story(ev.key)
return {ts: "rescheduled"}, []

@handle_event.register
def _(self, ev: FindMissingEvent) -> RecsInstrs:
if not self._missing_dep_flight:
return {}, []

if self.validate:
for ts in self._missing_dep_flight:
assert not ts.who_has

smsg = RequestRefreshWhoHasMsg(
keys=[ts.key for ts in self._missing_dep_flight],
stimulus_id=ev.stimulus_id,
)
return {}, [smsg]

@handle_event.register
def _(self, ev: RefreshWhoHasEvent) -> RecsInstrs:
return self._update_who_has(ev.who_has, stimulus_id=ev.stimulus_id)

def _prepare_args_for_execution(
self, ts: TaskState, args: tuple, kwargs: dict[str, Any]
) -> tuple[tuple, dict[str, Any]]:
Expand Down
39 changes: 39 additions & 0 deletions distributed/worker_state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,26 @@ class AddKeysMsg(SendMessageToScheduler):
keys: list[str]


@dataclass
class RequestRefreshWhoHasMsg(SendMessageToScheduler):
"""Worker -> Scheduler asynchronous request for updated who_has information.
Not to be confused with the scheduler.who_has synchronous RPC call, which is used
by the Client.
See also
--------
RefreshWhoHasEvent
distributed.scheduler.Scheduler.request_refresh_who_has
distributed.client.Client.who_has
distributed.scheduler.Scheduler.get_who_has
"""

op = "request-refresh-who-has"

__slots__ = ("keys",)
keys: list[str]


@dataclass
class StateMachineEvent:
__slots__ = ("stimulus_id", "handled")
Expand Down Expand Up @@ -508,6 +528,25 @@ class RescheduleEvent(StateMachineEvent):
key: str


@dataclass
class FindMissingEvent(StateMachineEvent):
__slots__ = ()


@dataclass
class RefreshWhoHasEvent(StateMachineEvent):
"""Scheduler -> Worker message containing updated who_has information.
See also
--------
RequestRefreshWhoHasMsg
"""

__slots__ = ("who_has",)
# {key: [worker address, ...]}
who_has: dict[str, list[str]]


if TYPE_CHECKING:
# TODO remove quotes (requires Python >=3.9)
# TODO get out of TYPE_CHECKING (requires Python >=3.10)
Expand Down

0 comments on commit 4a7685a

Please sign in to comment.