diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index e2d7210a3c4..9d6077dfb94 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import asyncio import importlib import logging @@ -24,6 +26,7 @@ import distributed from distributed import ( Client, + Future, Nanny, Reschedule, default_client, @@ -2679,7 +2682,6 @@ async def test_gather_dep_exception_one_task(c, s, a, b): event = asyncio.Event() write_queue = asyncio.Queue() - event.clear() b.rpc = _LockedCommPool(b.rpc, write_event=event, write_queue=write_queue) b.rpc.remove(a.address) @@ -2737,10 +2739,12 @@ async def test_gather_dep_exception_one_task_2(c, s, a, b): await fut2 -def _acquire_replicas(scheduler, worker, *futures): +def _acquire_replicas( + scheduler: Scheduler, worker: Worker | str, *futures: Future +) -> None: keys = [f.key for f in futures] - - scheduler.stream_comms[worker.address].send( + address = worker if isinstance(worker, str) else worker.address + scheduler.stream_comms[address].send( { "op": "acquire-replicas", "keys": keys, @@ -2753,14 +2757,17 @@ def _acquire_replicas(scheduler, worker, *futures): ) -def _remove_replicas(scheduler, worker, *futures): +def _remove_replicas( + scheduler: Scheduler, worker: Worker | str, *futures: Future +) -> None: keys = [f.key for f in futures] - ws = scheduler.workers[worker.address] + address = worker if isinstance(worker, str) else worker.address + ws = scheduler.workers[address] for k in keys: ts = scheduler.tasks[k] if ws in ts.who_has: scheduler.remove_replica(ts, ws) - scheduler.stream_comms[ws.address].send( + scheduler.stream_comms[address].send( { "op": "remove-replicas", "keys": keys, @@ -2842,6 +2849,36 @@ async def test_acquire_replicas_many(c, s, *workers): await asyncio.sleep(0.001) +@pytest.mark.slow +@gen_cluster(client=True, Worker=Nanny) +async def test_acquire_replicas_already_in_flight(c, s, *nannies): + """Trying to acquire a replica that is already in flight is a no-op""" + + class SlowToFly: + def __getstate__(self): + sleep(0.9) + return {} + + a, b = s.workers + x = c.submit(SlowToFly, workers=[a], key="x") + await wait(x) + y = c.submit(lambda x: 123, x, workers=[b], key="y") + await asyncio.sleep(0.3) + start = time() + _acquire_replicas(s, b, x) + assert await y == 123 + + story = await c.run(lambda dask_worker: dask_worker.story("x"), workers=[b]) + events = [ev for ev in story[b] if ev[-1] >= start] + + assert len(events) == 5 + assert events[0][:3] == ("x", "ensure-task-exists", "flight") + assert events[1][:4] == ("x", "flight", "fetch", "flight") + assert events[2][:1] == ("receive-dep",) + assert events[3][:2] == ("x", "put-in-memory") + assert events[4][:4] == ("x", "flight", "memory", "memory") + + @gen_cluster(client=True) async def test_remove_replica_simple(c, s, a, b): futs = c.map(inc, range(10), workers=[a.address])