Do not filter tasks before gathering data (#6371)
crusaderky committed May 20, 2022
1 parent f669f06 commit 4420644
Showing 4 changed files with 75 additions and 88 deletions.
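
The heart of the change is in distributed/worker.py: gather_dep no longer pre-filters its keys through _filter_deps_for_fetch before contacting the peer worker; it requests everything it was asked to fetch and only after the data has arrived does it pick a single "cause" task for diagnostics via the new _get_cause helper. Below is a minimal, self-contained sketch of that helper's logic; the TaskState here is a toy stand-in for the real distributed.worker_state_machine.TaskState, and get_cause is a free-function rendering of the new method, not the code as committed.

from __future__ import annotations

from dataclasses import dataclass, field


@dataclass(eq=False)  # identity-based hashing, so instances can live in sets
class TaskState:
    """Toy stand-in for distributed.worker_state_machine.TaskState."""

    key: str
    dependents: set[TaskState] = field(default_factory=set)


def get_cause(tasks: dict[str, TaskState], keys: list[str]) -> TaskState:
    """Pick the single task to attach the transfer's startstops to.

    Prefer a dependent of one of the fetched keys; if no fetched key has a
    known dependent (e.g. a pure acquire-replicas fetch), fall back to the
    last fetched key itself.
    """
    cause = None
    for key in keys:
        ts = tasks[key]
        if ts.dependents:
            return next(iter(ts.dependents))
        cause = ts
    assert cause is not None  # gather_dep is always called with at least one key
    return cause


# y depends on x, so a transfer that fetches x is attributed to y
x, y = TaskState("x"), TaskState("y")
x.dependents.add(y)
assert get_cause({"x": x, "y": y}, ["x"]) is y
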
4 changes: 2 additions & 2 deletions distributed/tests/test_cancelled_state.py
@@ -321,7 +321,7 @@ async def get_data(self, comm, *args, **kwargs):
# The initial free-keys is rejected
("free-keys", (fut1.key,)),
(fut1.key, "resumed", "released", "cancelled", {}),
# After gather_dep receives the data, it tries to transition to memory but the task will release instead
# After gather_dep receives the data, the task is forgotten
(fut1.key, "cancelled", "memory", "released", {fut1.key: "forgotten"}),
],
)
@@ -366,7 +366,7 @@ def block_execution(event, lock):


@gen_cluster(client=True, nthreads=[("", 1, {"resources": {"A": 1}})])
async def test_cancelled_error_with_ressources(c, s, a):
async def test_cancelled_error_with_resources(c, s, a):
executing = Event()
lock_executing = Lock()
await lock_executing.acquire()
48 changes: 47 additions & 1 deletion distributed/tests/test_worker_state_machine.py
@@ -5,7 +5,7 @@

from distributed.protocol.serialize import Serialize
from distributed.utils import recursive_to_dict
from distributed.utils_test import assert_story, gen_cluster, inc
from distributed.utils_test import _LockedCommPool, assert_story, gen_cluster, inc
from distributed.worker_state_machine import (
ExecuteFailureEvent,
ExecuteSuccessEvent,
@@ -302,3 +302,49 @@ async def test_fetch_via_amm_to_compute(c, s, a, b):
b.comm_threshold_bytes = old_comm_threshold

await f1


@gen_cluster(client=True)
async def test_cancelled_while_in_flight(c, s, a, b):
event = asyncio.Event()
a.rpc = _LockedCommPool(a.rpc, write_event=event)

x = c.submit(inc, 1, key="x", workers=[b.address])
y = c.submit(inc, x, key="y", workers=[a.address])
await wait_for_state("x", "flight", a)
y.release()
await wait_for_state("x", "cancelled", a)

# Let the comm from b to a return the result
event.set()
# upon reception, x transitions cancelled->forgotten
while a.tasks:
await asyncio.sleep(0.01)


@gen_cluster(client=True)
async def test_in_memory_while_in_flight(c, s, a, b):
"""
1. A client scatters x to a
2. The scheduler does not know about scattered keys until the three-way round-trip
between client, worker, and scheduler has been completed (see Scheduler.scatter)
3. In the middle of that handshake, a client (not necessarily the same client) calls
``{op: compute-task, key: x}`` on b and then
``{op: compute-task, key: y, who_has: {x: [b]}}`` on a, which triggers a
gather_dep call to copy x key from b to a.
4. while x is in flight from b to a, the scatter finishes, which triggers
update_data, which in turn transitions x from flight to memory.
5. later on, gather_dep finishes, but the key is already in memory.
"""
event = asyncio.Event()
a.rpc = _LockedCommPool(a.rpc, write_event=event)

x = c.submit(inc, 1, key="x", workers=[b.address])
y = c.submit(inc, x, key="y", workers=[a.address])
await wait_for_state("x", "flight", a)
a.update_data({"x": 3})
await wait_for_state("x", "memory", a)

# Let the comm from b to a return the result
event.set()
assert await y == 4 # Data in flight from b has been discarded
2 changes: 1 addition & 1 deletion distributed/utils_test.py
@@ -1923,7 +1923,7 @@ class _LockedCommPool(ConnectionPool):
)
# It might be necessary to remove all existing comms
# if the wrapped pool has been used before
>>> w.remove(remote_address)
>>> w.rpc.remove(remote_address)
>>> async def ping_pong():
return await w.rpc(remote_address).ping()
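
For context, the new tests above lean on this same helper: wrapping a worker's connection pool in _LockedCommPool with a write_event stalls that worker's outgoing comms until the event is set, which lets a test freeze a transfer mid-flight and reorder events deterministically. A condensed sketch of the pattern, modeled on test_cancelled_while_in_flight above (the test name is invented, and the exact blocking semantics of write_event are inferred from how this commit's tests use it):

import asyncio

from distributed.utils_test import _LockedCommPool, gen_cluster, inc


@gen_cluster(client=True)
async def test_stalled_transfer(c, s, a, b):
    # Stall outgoing comms from worker `a` until the event is set
    event = asyncio.Event()
    a.rpc = _LockedCommPool(a.rpc, write_event=event)

    x = c.submit(inc, 1, key="x", workers=[b.address])  # computed on b
    y = c.submit(inc, x, key="y", workers=[a.address])  # makes a fetch x from b

    # ... poke at a.tasks or release futures here while x is stuck in flight ...

    event.set()  # release the comm so gather_dep can complete
    assert await y == 3  # inc(inc(1))
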
109 changes: 25 additions & 84 deletions distributed/worker.py
@@ -3156,61 +3156,32 @@ def total_comm_bytes(self):
)
return self.comm_threshold_bytes

def _filter_deps_for_fetch(
self, to_gather_keys: Iterable[str]
) -> tuple[set[str], set[str], TaskState | None]:
"""Filter a list of keys before scheduling coroutines to fetch data from workers.
def _get_cause(self, keys: Iterable[str]) -> TaskState:
"""For diagnostics, we want to attach a transfer to a single task. This task is
typically the next to be executed but since we're fetching tasks for potentially
many dependents, an exact match is not possible. Additionally, if a key was
fetched through acquire-replicas, dependents may not be known at all.
Returns
-------
in_flight_keys:
The subset of keys in to_gather_keys in state `flight` or `resumed`
cancelled_keys:
The subset of tasks in to_gather_keys in state `cancelled` or `memory`
cause:
The task to attach startstops of this transfer to
The task to attach startstops of this transfer to
"""
in_flight_tasks: set[TaskState] = set()
cancelled_keys: set[str] = set()
for key in to_gather_keys:
ts = self.tasks.get(key)
if ts is None:
continue

# At this point, a task has been transitioned fetch->flight
# flight is only allowed to be transitioned into
# {memory, resumed, cancelled}
# resumed and cancelled will block any further transition until this
# coro has been finished

if ts.state in ("flight", "resumed"):
in_flight_tasks.add(ts)
# If the key is already in memory, the fetch should not happen which
# is signalled by the cancelled_keys
elif ts.state in {"cancelled", "memory"}:
cancelled_keys.add(key)
else:
raise RuntimeError(
f"Task {ts.key} found in illegal state {ts.state}. "
"Only states `flight`, `resumed` and `cancelled` possible."
)

# For diagnostics we want to attach the transfer to a single task. this
# task is typically the next to be executed but since we're fetching
# tasks for potentially many dependents, an exact match is not possible.
# If there are no dependents, this is a pure replica fetch
cause = None
for ts in in_flight_tasks:
for key in keys:
ts = self.tasks[key]
if ts.dependents:
cause = next(iter(ts.dependents))
break
else:
cause = ts
in_flight_keys = {ts.key for ts in in_flight_tasks}
return in_flight_keys, cancelled_keys, cause
return next(iter(ts.dependents))
cause = ts
assert cause # Always at least one key
return cause

def _update_metrics_received_data(
self, start: float, stop: float, data: dict, cause: TaskState, worker: str
self,
start: float,
stop: float,
data: dict[str, Any],
cause: TaskState,
worker: str,
) -> None:

total_bytes = sum(self.tasks[key].get_nbytes() for key in data)
@@ -3259,7 +3230,7 @@ def _update_metrics_received_data(
async def gather_dep(
self,
worker: str,
to_gather: Iterable[str],
to_gather: Collection[str],
total_nbytes: int,
*,
stimulus_id: str,
@@ -3283,46 +3254,23 @@ async def gather_dep(
recommendations: Recs = {}
instructions: Instructions = []
response = {}
to_gather_keys: set[str] = set()
cancelled_keys: set[str] = set()

def done_event():
return GatherDepDoneEvent(stimulus_id=f"gather-dep-done-{time()}")

try:
to_gather_keys, cancelled_keys, cause = self._filter_deps_for_fetch(
to_gather
)

if not to_gather_keys:
self.log.append(
("nothing-to-gather", worker, to_gather, stimulus_id, time())
)
return done_event()

assert cause
# Keep namespace clean since this func is long and has many
# dep*, *ts* variables
del to_gather

self.log.append(
("request-dep", worker, to_gather_keys, stimulus_id, time())
)
logger.debug(
"Request %d keys for task %s from %s",
len(to_gather_keys),
cause,
worker,
)
self.log.append(("request-dep", worker, to_gather, stimulus_id, time()))
logger.debug("Request %d keys from %s", len(to_gather), worker)

start = time()
response = await get_data_from_worker(
self.rpc, to_gather_keys, worker, who=self.address
self.rpc, to_gather, worker, who=self.address
)
stop = time()
if response["status"] == "busy":
return done_event()

cause = self._get_cause(to_gather)
self._update_metrics_received_data(
start=start,
stop=stop,
@@ -3375,9 +3323,7 @@ def done_event():
data = response.get("data", {})

if busy:
self.log.append(
("busy-gather", worker, to_gather_keys, stimulus_id, time())
)
self.log.append(("busy-gather", worker, to_gather, stimulus_id, time()))
# Avoid hammering the worker. If there are multiple replicas
# available, immediately try fetching from a different worker.
self.busy_workers.add(worker)
@@ -3388,12 +3334,7 @@ def done_event():
for d in self.in_flight_workers.pop(worker):
ts = self.tasks[d]
ts.done = True
if d in cancelled_keys:
if ts.state == "cancelled":
recommendations[ts] = "released"
else:
recommendations[ts] = "fetch"
elif d in data:
if d in data:
recommendations[ts] = ("memory", data[d])
elif busy:
recommendations[ts] = "fetch"
