Simplify preamble of gather_dep (dask#6371)
crusaderky committed May 20, 2022
1 parent 4a7685a commit b660eac
Showing 4 changed files with 63 additions and 88 deletions.
4 changes: 2 additions & 2 deletions distributed/tests/test_cancelled_state.py
@@ -321,7 +321,7 @@ async def get_data(self, comm, *args, **kwargs):
# The initial free-keys is rejected
("free-keys", (fut1.key,)),
(fut1.key, "resumed", "released", "cancelled", {}),
# After gather_dep receives the data, it tries to transition to memory but the task will release instead
# After gather_dep receives the data, the task is forgotten
(fut1.key, "cancelled", "memory", "released", {fut1.key: "forgotten"}),
],
)
@@ -366,7 +366,7 @@ def block_execution(event, lock):


@gen_cluster(client=True, nthreads=[("", 1, {"resources": {"A": 1}})])
async def test_cancelled_error_with_ressources(c, s, a):
async def test_cancelled_error_with_resources(c, s, a):
executing = Event()
lock_executing = Lock()
await lock_executing.acquire()
36 changes: 35 additions & 1 deletion distributed/tests/test_worker_state_machine.py
@@ -8,7 +8,7 @@
from distributed.core import Status
from distributed.protocol.serialize import Serialize
from distributed.utils import recursive_to_dict
from distributed.utils_test import assert_story, gen_cluster, inc
from distributed.utils_test import _LockedCommPool, assert_story, gen_cluster, inc
from distributed.worker_state_machine import (
ExecuteFailureEvent,
ExecuteSuccessEvent,
@@ -473,3 +473,37 @@ async def test_self_denounce_missing_data(c, s, a):
while "x" in a.data:
await asyncio.sleep(0.01)
assert a.tasks["x"].state == "released"


@gen_cluster(client=True)
async def test_cancelled_while_in_flight(c, s, a, b):
event = asyncio.Event()
a.rpc = _LockedCommPool(a.rpc, write_event=event)

x = c.submit(inc, 1, key="x", workers=[b.address])
y = c.submit(inc, x, key="y", workers=[a.address])
await wait_for_state("x", "flight", a)
y.release()
await wait_for_state("x", "cancelled", a)

# Let the comm from b to a return the result
event.set()
# upon reception, x transitions cancelled->forgotten
while a.tasks:
await asyncio.sleep(0.01)


@gen_cluster(client=True)
async def test_in_memory_while_in_flight(c, s, a, b):
event = asyncio.Event()
a.rpc = _LockedCommPool(a.rpc, write_event=event)

x = c.submit(inc, 1, key="x", workers=[b.address])
y = c.submit(inc, x, key="y", workers=[a.address])
await wait_for_state("x", "flight", a)
a.update_data({"x": 3})
await wait_for_state("x", "memory", a)

# Let the comm from b to a return the result
event.set()
assert await y == 4 # Data in flight from b has been discarded
2 changes: 1 addition & 1 deletion distributed/utils_test.py
@@ -1923,7 +1923,7 @@ class _LockedCommPool(ConnectionPool):
)
# It might be necessary to remove all existing comms
# if the wrapped pool has been used before
>>> w.remove(remote_address)
>>> w.rpc.remove(remote_address)
>>> async def ping_pong():
return await w.rpc(remote_address).ping()
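For context, the two new tests above rely on the pattern documented in this docstring: wrap a worker's ConnectionPool in _LockedCommPool so that new outgoing comms block on an asyncio.Event, mutate worker state while the request is stuck, then set the event. A minimal sketch of that wiring follows; the test name and final assertion are illustrative only and not part of this commit.

import asyncio

from distributed.utils_test import _LockedCommPool, gen_cluster, inc


@gen_cluster(client=True)
async def test_blocked_comm_sketch(c, s, a, b):
    # Block new outgoing comms from worker a until `event` is set
    event = asyncio.Event()
    a.rpc = _LockedCommPool(a.rpc, write_event=event)
    # As the docstring above notes, a.rpc.remove(b.address) may be needed
    # first if the wrapped pool already holds comms to b.

    x = c.submit(inc, 1, key="x", workers=[b.address])  # computed on b
    y = c.submit(inc, x, key="y", workers=[a.address])  # needs x on a
    # a's gather_dep request to b is now stuck on the blocked comm; the new
    # tests above use this window to mutate the state of task x on a.
    event.set()  # let the request through and the data arrive
    assert await y == 3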
109 changes: 25 additions & 84 deletions distributed/worker.py
@@ -3178,61 +3178,32 @@ def total_comm_bytes(self):
)
return self.comm_threshold_bytes

def _filter_deps_for_fetch(
self, to_gather_keys: Iterable[str]
) -> tuple[set[str], set[str], TaskState | None]:
"""Filter a list of keys before scheduling coroutines to fetch data from workers.
def _get_cause(self, keys: Iterable[str]) -> TaskState:
"""For diagnostics, we want to attach a transfer to a single task. This task is
typically the next to be executed but since we're fetching tasks for potentially
many dependents, an exact match is not possible. Additionally, if a key was
fetched through acquire-replicas, dependents may not be known at all.
Returns
-------
in_flight_keys:
The subset of keys in to_gather_keys in state `flight` or `resumed`
cancelled_keys:
The subset of tasks in to_gather_keys in state `cancelled` or `memory`
cause:
The task to attach startstops of this transfer to
The task to attach startstops of this transfer to
"""
in_flight_tasks: set[TaskState] = set()
cancelled_keys: set[str] = set()
for key in to_gather_keys:
ts = self.tasks.get(key)
if ts is None:
continue

# At this point, a task has been transitioned fetch->flight
# flight is only allowed to be transitioned into
# {memory, resumed, cancelled}
# resumed and cancelled will block any further transition until this
# coro has been finished

if ts.state in ("flight", "resumed"):
in_flight_tasks.add(ts)
# If the key is already in memory, the fetch should not happen which
# is signalled by the cancelled_keys
elif ts.state in {"cancelled", "memory"}:
cancelled_keys.add(key)
else:
raise RuntimeError(
f"Task {ts.key} found in illegal state {ts.state}. "
"Only states `flight`, `resumed` and `cancelled` possible."
)

# For diagnostics we want to attach the transfer to a single task. this
# task is typically the next to be executed but since we're fetching
# tasks for potentially many dependents, an exact match is not possible.
# If there are no dependents, this is a pure replica fetch
cause = None
for ts in in_flight_tasks:
for key in keys:
ts = self.tasks[key]
if ts.dependents:
cause = next(iter(ts.dependents))
break
else:
cause = ts
in_flight_keys = {ts.key for ts in in_flight_tasks}
return in_flight_keys, cancelled_keys, cause
return next(iter(ts.dependents))
cause = ts
assert cause # Always at least one key
return cause

def _update_metrics_received_data(
self, start: float, stop: float, data: dict, cause: TaskState, worker: str
self,
start: float,
stop: float,
data: dict[str, Any],
cause: TaskState,
worker: str,
) -> None:

total_bytes = sum(self.tasks[key].get_nbytes() for key in data)
@@ -3281,7 +3252,7 @@ def _update_metrics_received_data(
async def gather_dep(
self,
worker: str,
to_gather: Iterable[str],
to_gather: Collection[str],
total_nbytes: int,
*,
stimulus_id: str,
@@ -3305,46 +3276,23 @@ async def gather_dep(
recommendations: Recs = {}
instructions: Instructions = []
response = {}
to_gather_keys: set[str] = set()
cancelled_keys: set[str] = set()

def done_event():
return GatherDepDoneEvent(stimulus_id=f"gather-dep-done-{time()}")

try:
to_gather_keys, cancelled_keys, cause = self._filter_deps_for_fetch(
to_gather
)

if not to_gather_keys:
self.log.append(
("nothing-to-gather", worker, to_gather, stimulus_id, time())
)
return done_event()

assert cause
# Keep namespace clean since this func is long and has many
# dep*, *ts* variables
del to_gather

self.log.append(
("request-dep", worker, to_gather_keys, stimulus_id, time())
)
logger.debug(
"Request %d keys for task %s from %s",
len(to_gather_keys),
cause,
worker,
)
self.log.append(("request-dep", worker, to_gather, stimulus_id, time()))
logger.debug("Request %d keys from %s", len(to_gather), worker)

start = time()
response = await get_data_from_worker(
self.rpc, to_gather_keys, worker, who=self.address
self.rpc, to_gather, worker, who=self.address
)
stop = time()
if response["status"] == "busy":
return done_event()

cause = self._get_cause(to_gather)
self._update_metrics_received_data(
start=start,
stop=stop,
@@ -3397,9 +3345,7 @@ def done_event():
data = response.get("data", {})

if busy:
self.log.append(
("busy-gather", worker, to_gather_keys, stimulus_id, time())
)
self.log.append(("busy-gather", worker, to_gather, stimulus_id, time()))
# Avoid hammering the worker. If there are multiple replicas
# available, immediately try fetching from a different worker.
self.busy_workers.add(worker)
@@ -3410,12 +3356,7 @@ def done_event():
for d in self.in_flight_workers.pop(worker):
ts = self.tasks[d]
ts.done = True
if d in cancelled_keys:
if ts.state == "cancelled":
recommendations[ts] = "released"
else:
recommendations[ts] = "fetch"
elif d in data:
if d in data:
recommendations[ts] = ("memory", data[d])
elif busy:
recommendations[ts] = "fetch"

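Read as a whole, the worker.py change means gather_dep no longer pre-filters to_gather by task state before issuing the request: it fetches everything and lets the transition system reconcile cancelled or already-in-memory keys when the response arrives (as the updated story in test_cancelled_state.py shows), while _get_cause only picks the task to attach transfer diagnostics to. A standalone sketch of that selection rule, with plain dicts standing in for TaskState.dependents; this is an illustration, not the distributed API.

def get_cause(dependents: dict[str, list[str]], keys: list[str]) -> str:
    """Pick the key to attach transfer start/stop diagnostics to."""
    cause = None
    for key in keys:
        deps = dependents.get(key, [])
        if deps:
            # Prefer a dependent: typically the next task to execute
            return deps[0]
        cause = key
    assert cause is not None  # gather_dep always fetches at least one key
    return cause


# x is fetched because y depends on it -> attach the transfer to y
assert get_cause({"x": ["y"]}, ["x"]) == "y"
# pure acquire-replicas fetch: no dependents -> attach it to x itself
assert get_cause({}, ["x"]) == "x"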