Skip to content

Commit

Permalink
Merge branch 'main' into AMM/RetireWorker
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky committed Dec 9, 2021
2 parents c391c40 + e01d777 commit 462288f
Showing 1 changed file with 44 additions and 7 deletions.
51 changes: 44 additions & 7 deletions distributed/tests/test_worker.py
@@ -1,3 +1,5 @@
from __future__ import annotations

import asyncio
import importlib
import logging
Expand All @@ -24,6 +26,7 @@
import distributed
from distributed import (
Client,
Future,
Nanny,
Reschedule,
default_client,
Expand Down Expand Up @@ -2679,7 +2682,6 @@ async def test_gather_dep_exception_one_task(c, s, a, b):

event = asyncio.Event()
write_queue = asyncio.Queue()
event.clear()
b.rpc = _LockedCommPool(b.rpc, write_event=event, write_queue=write_queue)
b.rpc.remove(a.address)

Expand Down Expand Up @@ -2737,10 +2739,12 @@ async def test_gather_dep_exception_one_task_2(c, s, a, b):
await fut2


def _acquire_replicas(scheduler, worker, *futures):
def _acquire_replicas(
scheduler: Scheduler, worker: Worker | str, *futures: Future
) -> None:
keys = [f.key for f in futures]

scheduler.stream_comms[worker.address].send(
address = worker if isinstance(worker, str) else worker.address
scheduler.stream_comms[address].send(
{
"op": "acquire-replicas",
"keys": keys,
Expand All @@ -2753,14 +2757,17 @@ def _acquire_replicas(scheduler, worker, *futures):
)


def _remove_replicas(scheduler, worker, *futures):
def _remove_replicas(
scheduler: Scheduler, worker: Worker | str, *futures: Future
) -> None:
keys = [f.key for f in futures]
ws = scheduler.workers[worker.address]
address = worker if isinstance(worker, str) else worker.address
ws = scheduler.workers[address]
for k in keys:
ts = scheduler.tasks[k]
if ws in ts.who_has:
scheduler.remove_replica(ts, ws)
scheduler.stream_comms[ws.address].send(
scheduler.stream_comms[address].send(
{
"op": "remove-replicas",
"keys": keys,
Expand Down Expand Up @@ -2842,6 +2849,36 @@ async def test_acquire_replicas_many(c, s, *workers):
await asyncio.sleep(0.001)


@pytest.mark.slow
@gen_cluster(client=True, Worker=Nanny)
async def test_acquire_replicas_already_in_flight(c, s, *nannies):
    """An acquire-replicas request for a key already in flight is a no-op.

    ``x`` is pinned to the first worker and is deliberately slow to pickle;
    ``y`` depends on ``x`` and is pinned to the second worker. While ``y`` is
    still pending, the second worker is explicitly asked to acquire a replica
    of ``x``. From that moment on, the second worker's story for ``x`` must
    consist of exactly the five events checked at the end of the test.
    """

    class SlowToFly:
        # Pickling this object blocks for 0.9s, leaving a window in which the
        # explicit acquire-replicas request can be sent.
        def __getstate__(self):
            sleep(0.9)
            return {}

    # Iterating the scheduler's worker mapping yields its keys, which are used
    # below both as ``workers=`` restrictions and as keys of ``c.run``'s result.
    src_addr, dst_addr = s.workers

    fut_x = c.submit(SlowToFly, workers=[src_addr], key="x")
    await wait(fut_x)
    fut_y = c.submit(lambda x: 123, fut_x, workers=[dst_addr], key="y")

    # Presumably long enough for the destination worker to have started
    # fetching x, yet shorter than the 0.9s pickling delay.
    await asyncio.sleep(0.3)
    request_time = time()
    _acquire_replicas(s, dst_addr, fut_x)
    assert await fut_y == 123

    stories = await c.run(
        lambda dask_worker: dask_worker.story("x"), workers=[dst_addr]
    )
    # The last field of every story entry is compared against the request
    # time: only what was logged after the explicit request is of interest.
    recent = [entry for entry in stories[dst_addr] if entry[-1] >= request_time]

    # Leading fields each logged event is expected to start with, in order.
    expected_prefixes = [
        ("x", "ensure-task-exists", "flight"),
        ("x", "flight", "fetch", "flight"),
        ("receive-dep",),
        ("x", "put-in-memory"),
        ("x", "flight", "memory", "memory"),
    ]
    assert len(recent) == len(expected_prefixes)
    for entry, prefix in zip(recent, expected_prefixes):
        assert entry[: len(prefix)] == prefix


@gen_cluster(client=True)
async def test_remove_replica_simple(c, s, a, b):
futs = c.map(inc, range(10), workers=[a.address])
Expand Down

0 comments on commit 462288f

Please sign in to comment.