Skip to content

Commit

Permalink
POC ordered RPC
Browse files Browse the repository at this point in the history
  • Loading branch information
fjetter committed Dec 21, 2023
1 parent 1c74474 commit bfdf3e0
Showing 1 changed file with 128 additions and 0 deletions.
128 changes: 128 additions & 0 deletions distributed/tests/test_core.py
Expand Up @@ -1481,3 +1481,131 @@ def sync_handler(val):
assert ledger == list(range(n))
finally:
await comm.close()


import itertools
from collections import deque


@gen_test()
async def test_ordered_rpc():
entered_sleep = asyncio.Event()
i = 0

async def sleep(duration):
nonlocal i
entered_sleep.set()
await asyncio.sleep(duration)
try:
return i
finally:
i += 1

class MyServer(Server):
def __init__(self, handlers, *args, stream_handlers=None, **kwargs):
self._counter = itertools.count()
self._responses = {}
self._waiting_for = deque()
self._ensure_order = asyncio.Condition()
handlers = handlers or {}
handlers["do_work"] = self.do_work
handlers["sleep"] = sleep
stream_handlers = stream_handlers or {}
stream_handlers["__ordered_send"] = self._handle_ordered_send
stream_handlers["__ordered_rcv"] = self._handle_ordered_rcv
super().__init__(handlers, *args, stream_handlers=stream_handlers, **kwargs)
self._batched_comms = {}

async def _handle_ordered_rcv(self, sig, result):
fut = self._responses[sig]
fut.set_result(result)

async def _get_bcomm(self, addr):
if addr in self._batched_comms:
bcomm = self._batched_comms[addr]
# Another ordered_rpc came first and is still connecting. Wait
# until that is finished before proceeding
# TODO: This should be an event or smth similar
while not bcomm.comm:
await asyncio.sleep(0.01)
else:
self._batched_comms[addr] = bcomm = BatchedSend(interval=0.01)
comm = await self.rpc.connect(addr)
await comm.write({"op": "connection_stream"})
bcomm.start(comm)
return bcomm

async def _handle_ordered_send(self, sig, user_op, origin, user_kwargs):
try:
result = await self.handlers[user_op](**user_kwargs)
bcomm = await self._get_bcomm(origin)
bcomm.send({"op": "__ordered_rcv", "sig": sig, "result": result})
except Exception as e:
raise NotImplementedError()

# This is overriding the rpc property of Server
async def ordered_rpc(self, addr):
bcomm = await self._get_bcomm(addr)
server = self

class OrderedRPC:
def __init__(self, bcomm):
self._bcomm = bcomm

def __getattr__(self, key):
async def send_recv_from_rpc(**kwargs):
msg = {
"op": "__ordered_send",
"sig": next(server._counter),
"user_op": key,
"user_kwargs": kwargs,
"origin": server.address,
}
sig = msg["sig"]
self._bcomm.send(msg)
fut = asyncio.Future()
server._responses[sig] = fut
server._waiting_for.append(sig)

def is_next():
return server._waiting_for[0] == sig

async with server._ensure_order:
await server._ensure_order.wait_for(is_next)
try:
return await fut
finally:
server._waiting_for.popleft()

return send_recv_from_rpc

return OrderedRPC(bcomm)

async def do_work(self, other_addr, ordered=False):
if ordered:
r = await self.ordered_rpc(other_addr)
else:
r = self.rpc(other_addr)

t1 = asyncio.create_task(r.sleep(duration=1))

async def wait_to_unblock():
await entered_sleep.wait()
return await r.sleep(duration=0)

t2 = asyncio.create_task(wait_to_unblock())

r1, r2 = await asyncio.gather(t1, t2)
try:
return r1 == 0 and r2 == 1
finally:
nonlocal i
entered_sleep.clear()
i = 0

async with MyServer({}) as s1, MyServer({}) as s2:
await s1.listen()
await s2.listen()
async with rpc(s2.address) as r:
assert not await r.do_work(other_addr=s1.address)
assert await r.do_work(other_addr=s1.address, ordered=True)

0 comments on commit bfdf3e0

Please sign in to comment.