diff --git a/distributed/tests/test_cancelled_state.py b/distributed/tests/test_cancelled_state.py index 1abb00db55..a29239db46 100644 --- a/distributed/tests/test_cancelled_state.py +++ b/distributed/tests/test_cancelled_state.py @@ -2,6 +2,7 @@ import distributed from distributed import Event, Lock, Worker +from distributed.client import wait from distributed.utils_test import ( _LockedCommPool, assert_story, @@ -264,11 +265,12 @@ async def test_in_flight_lost_after_resumed(c, s, b): block_get_data = asyncio.Lock() in_get_data = asyncio.Event() + await block_get_data.acquire() lock_executing = Lock() def block_execution(lock): with lock: - return + return 1 class BlockedGetData(Worker): async def get_data(self, comm, *args, **kwargs): @@ -281,15 +283,12 @@ async def get_data(self, comm, *args, **kwargs): block_execution, lock_executing, workers=[a.address], - allow_other_workers=True, key="fut1", ) # Ensure fut1 is in memory but block any further execution afterwards to # ensure we control when the recomputation happens - await fut1 + await wait(fut1) await lock_executing.acquire() - in_get_data.clear() - await block_get_data.acquire() fut2 = c.submit(inc, fut1, workers=[b.address], key="fut2") # This ensures that B already fetches the task, i.e. after this the task @@ -298,6 +297,7 @@ async def get_data(self, comm, *args, **kwargs): assert fut1.key in b.tasks assert b.tasks[fut1.key].state == "flight" + s.set_restrictions({fut1.key: [a.address, b.address]}) # It is removed, i.e. get_data is guaranteed to fail and f1 is scheduled # to be recomputed on B await s.remove_worker(a.address, "foo", close=False, safe=True)