Skip to content

Commit

Permalink
Fix race condition for published futures with annotations (#8577)
Browse files Browse the repository at this point in the history
  • Loading branch information
fjetter committed Apr 15, 2024
1 parent 42c479f commit 0e2be4a
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 3 deletions.
11 changes: 8 additions & 3 deletions distributed/client.py
Expand Up @@ -2650,18 +2650,23 @@ def retry(self, futures, asynchronous=None):
@log_errors
async def _publish_dataset(self, *args, name=None, override=False, **kwargs):
coroutines = []
uid = uuid.uuid4().hex
self._send_to_scheduler({"op": "publish_flush_batched_send", "uid": uid})

def add_coro(name, data):
keys = [f.key for f in futures_of(data)]
coroutines.append(
self.scheduler.publish_put(

async def _():
await self.scheduler.publish_wait_flush(uid=uid)
await self.scheduler.publish_put(
keys=keys,
name=name,
data=to_serialize(data),
override=override,
client=self.id,
)
)

coroutines.append(_())

if name:
if len(args) == 0:
Expand Down
14 changes: 14 additions & 0 deletions distributed/publish.py
@@ -1,5 +1,7 @@
from __future__ import annotations

import asyncio
from collections import defaultdict
from collections.abc import MutableMapping

from dask.utils import stringify
Expand All @@ -25,9 +27,21 @@ def __init__(self, scheduler):
"publish_put": self.put,
"publish_get": self.get,
"publish_delete": self.delete,
"publish_wait_flush": self.flush_wait,
}
stream_handlers = {
"publish_flush_batched_send": self.flush_receive,
}

self.scheduler.handlers.update(handlers)
self.scheduler.stream_handlers.update(stream_handlers)
self._flush_received = defaultdict(asyncio.Event)

def flush_receive(self, uid, **kwargs):
self._flush_received[uid].set()

async def flush_wait(self, uid):
await self._flush_received[uid].wait()

@log_errors
def put(self, keys=None, data=None, name=None, override=False, client=None):
Expand Down
28 changes: 28 additions & 0 deletions distributed/tests/test_publish.py
Expand Up @@ -11,6 +11,7 @@
from distributed.metrics import time
from distributed.protocol import Serialized
from distributed.utils_test import gen_cluster, inc
from distributed.worker import get_worker


@gen_cluster()
Expand Down Expand Up @@ -301,3 +302,30 @@ async def test_deserialize_client(c, s, a, b):
from distributed.client import _current_client

assert _current_client.get() is c


@gen_cluster(client=True, worker_kwargs={"resources": {"A": 1}})
async def test_publish_submit_ordering(c, s, a, b):
RESOURCES = {"A": 1}

def _retrieve_annotations():
worker = get_worker()
task = worker.state.tasks.get(worker.get_current_task())
return task.annotations

# If publish does not take the same comm channel as the submit, it can
# happen that the publish message reaches the scheduler before the submit
# such that the state of the published future is not the one that has been
# requested from the submit. Particularly, this lets us drop annotations
# The current implementation does in fact not use the same channel due to
# serialization issue (including Futures in BatchedSend appends them to the
# "recent messages" log which screws with the refcounting) but ensure that
# all queued up messages are flushed and received by the schduler befure
# publishing
future = c.submit(_retrieve_annotations, resources=RESOURCES, pure=False)

await c.publish_dataset(future, name="foo")
assert await c.list_datasets() == ("foo",)

result = await future.result()
assert result == {"resources": RESOURCES}

0 comments on commit 0e2be4a

Please sign in to comment.