[WIP/POC] ordered RPC #8430

Draft · wants to merge 7 commits into main
155 changes: 153 additions & 2 deletions distributed/core.py
@@ -31,6 +31,7 @@
from tornado.ioloop import IOLoop

import dask
from dask.typing import NoDefault, no_default
from dask.utils import parse_timedelta

from distributed import profile, protocol
@@ -55,6 +56,7 @@
has_keyword,
import_file,
iscoroutinefunction,
log_errors,
offload,
recursive_to_dict,
truncate_exception,
@@ -65,6 +67,7 @@
if TYPE_CHECKING:
from typing_extensions import ParamSpec, Self

from distributed.batched import BatchedSend
from distributed.counter import Digest

P = ParamSpec("P")
@@ -99,6 +102,11 @@
Status.lookup = {s.name: s for s in Status} # type: ignore


class RPCCall:
def __getattr__(self, key: str) -> Callable[..., Awaitable]:
raise NotImplementedError()


class RPCClosed(IOError):
pass

@@ -427,14 +435,20 @@
"echo": self.echo,
"connection_stream": self.handle_stream,
"dump_state": self._to_dict,
"_ordered_send_payload": self._handle_ordered_send_payload,
}
self.handlers.update(handlers)
if blocked_handlers is None:
blocked_handlers = dask.config.get(
"distributed.%s.blocked-handlers" % type(self).__name__.lower(), []
)
self.blocked_handlers = blocked_handlers
self.stream_handlers = {}
self.stream_handlers = {
"__ordered_send": self._handle_ordered_send,
"__ordered_rcv": self._handle_ordered_rcv,
}
self._side_channel_payload = {}
self._side_channel_arrived = defaultdict(asyncio.Event)
self.stream_handlers.update(stream_handlers or {})

self.id = type(self).__name__ + "-" + str(uuid.uuid4())
@@ -532,7 +546,15 @@
timeout=timeout,
server=self,
)
import itertools

self._counter = itertools.count()
self._responses = {}
self._waiting_for = deque()
self._ensure_order = asyncio.Condition()

self._batched_comms = {}
self._batched_comms_locks = defaultdict(asyncio.Lock)
self.__stopped = False

async def upload_file(
@@ -1063,6 +1085,135 @@
await comm.close()
assert comm.closed()

async def _handle_ordered_send_payload(self, sig, payload, origin):
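# Store the kwargs shipped over a plain pooled RPC connection for an ordered
# call issued with use_side_channel=True; _handle_ordered_send waits on the
# matching Event below and pops the payload once it has arrived.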
# FIXME: If something goes wrong, this can leak memory
# We'd need a callback for when the incoming connection is closed to
# clean this up
key = (origin, sig)
self._side_channel_payload[key] = payload
self._side_channel_arrived[key].set()

async def _handle_ordered_send(
self, sig, user_op, origin, user_kwargs, use_side_channel, **extra
):
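# Runs on the receiving server: execute the requested handler and push the
# result (or a serialized exception) back to `origin` over the batched stream.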
# Note: The backchannel is currently unique per peer address. It's unclear
# whether we need more control here.
bcomm = await self._get_bcomm(origin)
try:
if use_side_channel:
assert user_kwargs is None
key = (origin, sig)
await self._side_channel_arrived[key].wait()
user_kwargs = self._side_channel_payload.pop(key)
result = self.handlers[user_op](**merge(extra, user_kwargs))
if inspect.isawaitable(result):
result = await result
bcomm.send({"op": "__ordered_rcv", "sig": sig, "result": result})
except Exception as e:
exc_info = error_message(e)
bcomm.send({"op": "__ordered_rcv", "sig": sig, "exc_info": exc_info})

async def _handle_ordered_rcv(self, sig, result=no_default, exc_info=no_default):
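# Runs back on the calling server: hand the remote result (or the
# reconstructed remote exception) to the Future registered for this sig.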
fut = self._responses[sig]
if result is not no_default:
assert exc_info is no_default
fut.set_result(result)
elif exc_info is not no_default:
assert result is no_default
_, exc, tb = clean_exception(**exc_info)
fut.set_exception(exc.with_traceback(tb))
else:
raise RuntimeError("Unreachable")

@log_errors
async def ordered_rpc(
self,
addr: str | NoDefault = no_default,
bcomm: BatchedSend | NoDefault = no_default,
use_side_channel: bool = False,
) -> RPCCall:
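# Return a proxy whose attribute calls are sent over a single BatchedSend
# stream to the peer, with responses handed back strictly in call order.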
# TODO: Allow different channels?
if addr is not no_default:
assert bcomm is no_default
bcomm = await self._get_bcomm(addr)
else:
assert bcomm is not no_default
addr = bcomm.comm.peer_address

server = self

class OrderedRPC(RPCCall):
def __init__(self, bcomm):
self._bcomm = bcomm

def __getattr__(self, key):
async def send_recv_from_rpc(**kwargs):
sig = next(server._counter)
msg = {
"op": "__ordered_send",
"sig": sig,
"user_op": key,
"origin": server.address,
"use_side_channel": use_side_channel,
}
if not use_side_channel:
msg["user_kwargs"] = kwargs
else:
msg["user_kwargs"] = None
self._bcomm.send(msg)
fut = asyncio.Future()
server._responses[sig] = fut
server._waiting_for.append(sig)
if use_side_channel:
# Note: We may even want to consider moving this to a
# background task
async def _():
await server.rpc(addr)._ordered_send_payload(
sig=sig,
payload=kwargs,
origin=server.address,
)

server._ongoing_background_tasks.call_soon(_)

async def watch_comm():
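# Poll the underlying comm and fail the pending call if it closes, so a
# dead stream does not leave the caller waiting forever.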
while True:
if self._bcomm.comm.closed():
fut.set_exception(CommClosedError)
break
await asyncio.sleep(0.1)

t = asyncio.create_task(watch_comm())

def is_next():
return server._waiting_for[0] == sig

try:
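# Responses may arrive in any order; block until this sig is at the head
# of the queue so callers observe results in the order they called.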
async with server._ensure_order:
await server._ensure_order.wait_for(is_next)
return await fut
finally:
t.cancel()
server._waiting_for.popleft()

return send_recv_from_rpc

return OrderedRPC(bcomm)

async def _get_bcomm(self, addr):
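# Cache one BatchedSend stream per peer address; the per-address lock keeps
# concurrent callers from opening duplicate streams.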
async with self._batched_comms_locks[addr]:
if addr in self._batched_comms:
bcomm = self._batched_comms[addr]
if not bcomm.comm.closed():
return bcomm
from distributed.batched import BatchedSend

comm = await self.rpc.connect(addr)
await comm.write({"op": "connection_stream"})
self._batched_comms[addr] = bcomm = BatchedSend(interval=0.01)
bcomm.start(comm)
return bcomm

async def close(self, timeout: float | None = None, reason: str = "") -> None:
try:
for pc in self.periodic_callbacks.values():
@@ -1365,7 +1516,7 @@
return "<rpc to %r, %d comms>" % (self.address, len(self.comms))


class PooledRPCCall:
class PooledRPCCall(RPCCall):
"""The result of ConnectionPool()('host:port')

See Also:
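
A rough usage sketch of the new API (illustrative only, not part of the diff; the address and payloads are made up, and `server` is assumed to be a running Server instance). Calls issued through the same proxy resolve strictly in the order they were made, even when awaited concurrently:

ordered = await server.ordered_rpc(addr="tcp://10.0.0.1:8788")
a = ordered.echo(data=b"first")   # plain coroutines; nothing is sent until awaited
b = ordered.echo(data=b"second")
# Each call is tagged with a monotonically increasing signature; even if the
# peer replies out of order, the awaits complete in call order.
results = await asyncio.gather(a, b)

# With use_side_channel=True only the small ordering message travels over the
# BatchedSend stream; the kwargs are shipped via a separate pooled RPC call.
ordered_sc = await server.ordered_rpc(addr="tcp://10.0.0.1:8788", use_side_channel=True)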
6 changes: 3 additions & 3 deletions distributed/scheduler.py
@@ -4173,7 +4173,7 @@ async def log_errors(func):
def heartbeat_worker(
self,
*,
address: str,
worker: str,
resolve_address: bool = True,
now: float | None = None,
resources: dict[str, float] | None = None,
@@ -4182,7 +4182,7 @@ def heartbeat_worker(
executing: dict[Key, float] | None = None,
extensions: dict | None = None,
) -> dict[str, Any]:
address = self.coerce_address(address, resolve_address)
address = self.coerce_address(worker, resolve_address)
address = normalize_address(address)
ws = self.workers.get(address)
if ws is None:
@@ -4361,7 +4361,7 @@ async def add_worker(
self.aliases[name] = address

self.heartbeat_worker(
address=address,
worker=address,
resolve_address=resolve_address,
now=now,
resources=resources,
2 changes: 1 addition & 1 deletion distributed/shuffle/_core.py
@@ -313,7 +313,7 @@ async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:

if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
id=self.id, run_id=self.run_id, key=key, assigned_worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
2 changes: 1 addition & 1 deletion distributed/shuffle/_rechunk.py
@@ -774,7 +774,7 @@ def create_run_on_worker(
local_address=plugin.worker.address,
rpc=plugin.worker.rpc,
digest_metric=plugin.worker.digest_metric,
scheduler=plugin.worker.scheduler,
scheduler=plugin.worker.scheduler_ordered, # type: ignore
memory_limiter_disk=plugin.memory_limiter_disk,
memory_limiter_comms=plugin.memory_limiter_comms,
disk=self.disk,
10 changes: 7 additions & 3 deletions distributed/shuffle/_scheduler_plugin.py
@@ -78,7 +78,9 @@ async def start(self, scheduler: Scheduler) -> None:
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)

async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
async def barrier(
self, id: ShuffleId, run_id: int, consistent: bool, worker: None
) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
@@ -98,7 +100,9 @@ async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
workers=list(shuffle.participating_workers),
)

def restrict_task(self, id: ShuffleId, run_id: int, key: Key, worker: str) -> dict:
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, assigned_worker: str, worker: str
) -> dict:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
return {
@@ -111,7 +115,7 @@ def restrict_task(self, id: ShuffleId, run_id: int, key: Key, worker: str) -> dict:
"message": f"Request invalid, expected {run_id=} for {shuffle}",
}
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
self._set_restriction(ts, assigned_worker)
return {"status": "OK"}

def heartbeat(self, ws: WorkerState, data: dict) -> None:
2 changes: 1 addition & 1 deletion distributed/shuffle/_shuffle.py
@@ -574,7 +574,7 @@ def create_run_on_worker(
local_address=plugin.worker.address,
rpc=plugin.worker.rpc,
digest_metric=plugin.worker.digest_metric,
scheduler=plugin.worker.scheduler,
scheduler=plugin.worker.scheduler_ordered, # type: ignore
memory_limiter_disk=plugin.memory_limiter_disk
if self.disk
else ResourceLimiter(None),
6 changes: 4 additions & 2 deletions distributed/shuffle/tests/test_shuffle.py
@@ -2486,10 +2486,12 @@ def __init__(self, scheduler: Scheduler):
self.in_barrier = asyncio.Event()
self.block_barrier = asyncio.Event()

async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
async def barrier(
self, id: ShuffleId, run_id: int, consistent: bool, worker: None
) -> None:
self.in_barrier.set()
await self.block_barrier.wait()
return await super().barrier(id, run_id, consistent)
return await super().barrier(id, run_id, consistent, worker)


@gen_cluster(client=True)