Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP/POC] ordered RPC #8430

Draft
wants to merge 7 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
95 changes: 94 additions & 1 deletion distributed/core.py
Expand Up @@ -434,7 +434,10 @@
"distributed.%s.blocked-handlers" % type(self).__name__.lower(), []
)
self.blocked_handlers = blocked_handlers
self.stream_handlers = {}
self.stream_handlers = {
"__ordered_send": self._handle_ordered_send,
"__ordered_rcv": self._handle_ordered_rcv,
}
self.stream_handlers.update(stream_handlers or {})

self.id = type(self).__name__ + "-" + str(uuid.uuid4())
Expand Down Expand Up @@ -532,7 +535,15 @@
timeout=timeout,
server=self,
)
import itertools

self._counter = itertools.count()
self._responses = {}
self._waiting_for = deque()
self._ensure_order = asyncio.Condition()

self._batched_comms = {}
self._batched_comms_locks = defaultdict(asyncio.Lock)
self.__stopped = False

async def upload_file(
Expand Down Expand Up @@ -1063,6 +1074,88 @@
await comm.close()
assert comm.closed()

async def _handle_ordered_send(self, sig, user_op, origin, user_kwargs, **extra):
    """Run *user_op* locally and stream its outcome back to *origin*.

    Receiving side of the ordered-RPC protocol: looks up ``user_op`` in
    ``self.handlers``, invokes it with the caller's kwargs (caller kwargs
    override any server-injected ``extra``), and replies on the batched
    backchannel with either ``result`` or ``exc_info`` tagged with the
    request signature *sig* so the origin can match it to its future.
    """
    # NOTE: the backchannel is currently unique; it is unclear whether we
    # need more control here.
    reply_channel = await self._get_bcomm(origin)
    try:
        # Handler lookup stays inside the try so an unknown op is reported
        # back to the caller instead of crashing this server.
        outcome = self.handlers[user_op](**{**extra, **user_kwargs})
        if inspect.isawaitable(outcome):
            outcome = await outcome
        reply_channel.send({"op": "__ordered_rcv", "sig": sig, "result": outcome})
    except Exception as exc:
        reply_channel.send(
            {"op": "__ordered_rcv", "sig": sig, "exc_info": error_message(exc)}
        )

Check warning on line 1088 in distributed/core.py

View check run for this annotation

Codecov / codecov/patch

distributed/core.py#L1086-L1088

Added lines #L1086 - L1088 were not covered by tests

async def _handle_ordered_rcv(self, sig, result=None, exc_info=None):
    """Resolve the pending future for request *sig* with a reply payload.

    The sender (``_handle_ordered_send``) transmits exactly one of
    ``result`` or ``exc_info``, so we dispatch on ``exc_info`` alone.

    Fixes over the previous revision:
    - A handler that legitimately returns ``None`` used to fall through to
      ``RuntimeError("Unreachable")``; now it resolves the future with
      ``None``.
    - The future is popped from ``self._responses`` so completed requests
      no longer accumulate there.
    - An unknown or already-cancelled signature is ignored instead of
      raising ``KeyError``/``InvalidStateError``.
    """
    fut = self._responses.pop(sig, None)
    if fut is None or fut.cancelled():
        # Reply for a request that was abandoned in the meantime.
        return
    if exc_info is not None:
        _, exc, tb = clean_exception(**exc_info)
        fut.set_exception(exc.with_traceback(tb))
    else:
        fut.set_result(result)

Check warning on line 1100 in distributed/core.py

View check run for this annotation

Codecov / codecov/patch

distributed/core.py#L1100

Added line #L1100 was not covered by tests

async def ordered_rpc(self, addr=None, bcomm=None):
    """Return an RPC proxy whose responses are delivered in call order.

    Parameters
    ----------
    addr:
        Address of the peer server; mutually exclusive with ``bcomm``.
    bcomm:
        An already-established batched comm to use as the send channel.

    Each proxied call is tagged with a monotonically increasing signature
    from ``self._counter``; ``self._waiting_for`` (a deque of signatures)
    plus the ``self._ensure_order`` condition guarantee that callers
    consume their responses in signature order.

    Fixes over the previous revision:
    - The response future is registered in ``_responses`` *before* the
      request is sent, so a reply can never race an unknown signature.
    - ``notify_all`` is called after a caller finishes; previously nothing
      ever notified the condition, so any caller that actually blocked in
      ``wait_for`` would sleep forever (deadlock).
    - The ``_responses`` entry is dropped in a ``finally`` so abandoned
      requests do not leak futures.
    """
    # TODO: Allow different channels?
    if addr is not None:
        assert bcomm is None
        bcomm = await self._get_bcomm(addr)
    else:
        assert bcomm is not None

    server = self

    class OrderedRPC:
        def __init__(self, bcomm):
            self._bcomm = bcomm

        def __getattr__(self, key):
            async def send_recv_from_rpc(**kwargs):
                sig = next(server._counter)
                # Register before sending: the reply handler must always
                # find this signature in ``_responses``.
                fut = asyncio.Future()
                server._responses[sig] = fut
                server._waiting_for.append(sig)
                self._bcomm.send(
                    {
                        "op": "__ordered_send",
                        "sig": sig,
                        "user_op": key,
                        "user_kwargs": kwargs,
                        "origin": server.address,
                    }
                )

                def is_next():
                    # Our turn only when we are at the head of the queue.
                    return server._waiting_for[0] == sig

                # NOTE(review): cancellation while blocked in ``wait_for``
                # would leave ``sig`` in ``_waiting_for`` — TODO harden.
                try:
                    async with server._ensure_order:
                        await server._ensure_order.wait_for(is_next)
                        try:
                            return await fut
                        finally:
                            server._waiting_for.popleft()
                            # Wake blocked callers so the next head of the
                            # queue re-checks its predicate.
                            server._ensure_order.notify_all()
                finally:
                    server._responses.pop(sig, None)

            return send_recv_from_rpc

    return OrderedRPC(bcomm)

async def _get_bcomm(self, addr):
    """Return a cached ``BatchedSend`` to *addr*, creating one if needed.

    A per-address lock serializes concurrent callers so only one comm is
    opened per peer.  A cached channel whose underlying comm has closed is
    replaced by a fresh connection.
    """
    async with self._batched_comms_locks[addr]:
        cached = self._batched_comms.get(addr)
        if cached is not None and not cached.comm.closed():
            return cached

        from distributed.batched import BatchedSend

        fresh = BatchedSend(interval=0.01)
        self._batched_comms[addr] = fresh
        comm = await self.rpc.connect(addr)
        await comm.write({"op": "connection_stream"})
        fresh.start(comm)
        return fresh

async def close(self, timeout: float | None = None, reason: str = "") -> None:
try:
for pc in self.periodic_callbacks.values():
Expand Down
6 changes: 3 additions & 3 deletions distributed/scheduler.py
Expand Up @@ -4173,7 +4173,7 @@ async def log_errors(func):
def heartbeat_worker(
self,
*,
address: str,
worker: str,
resolve_address: bool = True,
now: float | None = None,
resources: dict[str, float] | None = None,
Expand All @@ -4182,7 +4182,7 @@ def heartbeat_worker(
executing: dict[Key, float] | None = None,
extensions: dict | None = None,
) -> dict[str, Any]:
address = self.coerce_address(address, resolve_address)
address = self.coerce_address(worker, resolve_address)
address = normalize_address(address)
ws = self.workers.get(address)
if ws is None:
Expand Down Expand Up @@ -4361,7 +4361,7 @@ async def add_worker(
self.aliases[name] = address

self.heartbeat_worker(
address=address,
worker=address,
resolve_address=resolve_address,
now=now,
resources=resources,
Expand Down
57 changes: 57 additions & 0 deletions distributed/tests/test_core.py
Expand Up @@ -1481,3 +1481,60 @@ def sync_handler(val):
assert ledger == list(range(n))
finally:
await comm.close()


@gen_test()
async def test_ordered_rpc():
    """Responses via ``ordered_rpc`` arrive in call order; plain rpc's don't."""
    entered_sleep = asyncio.Event()
    i = 0

    async def sleep(duration):
        # Returns the current call counter and bumps it afterwards, so the
        # returned values expose the order in which responses were produced.
        nonlocal i
        entered_sleep.set()
        await asyncio.sleep(duration)
        try:
            return i
        finally:
            i += 1

    class MyServer(Server):
        def __init__(self, *args, **kwargs):
            handlers = {
                "sleep": sleep,
                "do_work": self.do_work,
            }
            super().__init__(handlers, *args, **kwargs)

        async def do_work(self, other_addr, ordered=False):
            if ordered:
                r = await self.ordered_rpc(other_addr)
            else:
                r = self.rpc(other_addr)

            t1 = asyncio.create_task(r.sleep(duration=1))

            async def wait_to_unblock(error=False):
                await entered_sleep.wait()
                if error:
                    raise RuntimeError("error")
                return await r.sleep(duration=0)

            t2 = asyncio.create_task(wait_to_unblock(error=True))
            t3 = asyncio.create_task(wait_to_unblock())

            await asyncio.wait([t1, t2, t3])
            # BUGFIX: ``t2.exception`` is a bound method and therefore
            # always truthy; call it to actually verify that t2 failed.
            assert t2.exception() is not None
            r1, r3 = await asyncio.gather(t1, t3)
            try:
                # Ordered: the slow call's value (0) is delivered before the
                # fast call's (1).  Unordered: the fast call wins the race.
                return r1 == 0 and r3 == 1
            finally:
                nonlocal i
                entered_sleep.clear()
                i = 0

    async with MyServer() as s1, MyServer() as s2:
        await s1.listen()
        await s2.listen()
        async with rpc(s2.address) as r:
            assert not await r.do_work(other_addr=s1.address)
            assert await r.do_work(other_addr=s1.address, ordered=True)
33 changes: 0 additions & 33 deletions distributed/tests/test_worker.py
Expand Up @@ -1751,39 +1751,6 @@ async def test_shutdown_on_scheduler_comm_closed(s, a):
assert f"Connection to {s.address} has been closed" in logger.getvalue()


@gen_cluster(nthreads=[])
async def test_heartbeat_comm_closed(s, monkeypatch):
    # A CommClosedError raised while heartbeating must be logged as a
    # warning but must NOT take the worker down.
    with captured_logger("distributed.worker", level=logging.WARNING) as logger:

        def bad_heartbeat_worker(*args, **kwargs):
            raise CommClosedError()

        async with Worker(s.address) as w:
            # Trigger CommClosedError during worker heartbeat
            monkeypatch.setattr(w.scheduler, "heartbeat_worker", bad_heartbeat_worker)

            await w.heartbeat()
            # Worker survives the failed heartbeat.
            assert w.status == Status.running
    logs = logger.getvalue()
    assert "Failed to communicate with scheduler during heartbeat" in logs
    assert "Traceback" in logs


@gen_cluster(nthreads=[("", 1)], worker_kwargs={"heartbeat_interval": "100s"})
async def test_heartbeat_missing(s, a, monkeypatch):
    # If the scheduler replies that it does not know this worker, the
    # worker shuts itself down and the scheduler forgets it.
    async def missing_heartbeat_worker(*args, **kwargs):
        return {"status": "missing"}

    with captured_logger("distributed.worker", level=logging.WARNING) as wlogger:
        monkeypatch.setattr(a.scheduler, "heartbeat_worker", missing_heartbeat_worker)
        await a.heartbeat()
        assert a.status == Status.closed
        assert "Scheduler was unaware of this worker" in wlogger.getvalue()

    # Wait for the scheduler to drop the worker from its registry.
    while s.workers:
        await asyncio.sleep(0.01)


@gen_cluster(nthreads=[("", 1)], worker_kwargs={"heartbeat_interval": "100s"})
async def test_heartbeat_missing_real_cluster(s, a):
# The idea here is to create a situation where `s.workers[a.address]`,
Expand Down
2 changes: 1 addition & 1 deletion distributed/tests/test_worker_metrics.py
Expand Up @@ -597,7 +597,7 @@ async def test_new_metrics_during_heartbeat(c, s, a):
a.digest_metric(("execute", span.id, "x", "test", "test"), 1)
await asyncio.sleep(0)
await hb_task
assert n > 9
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This 9 feels very magical. since the heartbeat doesn't have to open a new comm, we need fewer ticks until it completes and this is not guaranteed to be past 9.

I'm not entirely convinced this test still works, makes sense, or is relevant.

assert n > 1
await a.heartbeat()

assert a.digests_total["execute", span.id, "x", "test", "test"] == n
Expand Down
8 changes: 3 additions & 5 deletions distributed/worker.py
Expand Up @@ -786,6 +786,7 @@ def __init__(
BaseWorker.__init__(self, state)

self.scheduler = self.rpc(scheduler_addr)
# BUGFIX: was misspelled ``scheduler_orderd``.  The attribute assigned in
# ``_register_with_scheduler`` and read by ``heartbeat`` is
# ``scheduler_ordered``; with the typo it would not exist before
# registration completes.
self.scheduler_ordered = None
self.execution_state = {
"scheduler": self.scheduler.address,
"ioloop": self.loop,
Expand Down Expand Up @@ -1225,6 +1226,7 @@ async def _register_with_scheduler(self) -> None:
raise ValueError(f"Unexpected response from register: {response!r}")

self.batched_stream.start(comm)
self.scheduler_ordered = await self.ordered_rpc(bcomm=self.batched_stream)
self.status = Status.running

await asyncio.gather(
Expand All @@ -1249,9 +1251,7 @@ async def heartbeat(self) -> None:
logger.debug("Heartbeat: %s", self.address)
try:
start = time()
response = await retry_operation(
self.scheduler.heartbeat_worker,
address=self.contact_address,
response = await self.scheduler_ordered.heartbeat_worker(
now=start,
metrics=await self.get_metrics(),
executing={
Expand Down Expand Up @@ -1286,8 +1286,6 @@ async def heartbeat(self) -> None:
)
self.bandwidth_workers.clear()
self.bandwidth_types.clear()
except OSError:
logger.exception("Failed to communicate with scheduler during heartbeat.")
except Exception:
logger.exception("Unexpected exception during heartbeat. Closing worker.")
await self.close()
Expand Down