Skip to content

Commit

Permalink
add Server to client
Browse files — browse the repository at this point in the history
  • Loading branch information
fjetter committed Jan 19, 2024
1 parent efb9045 commit 395fcc2
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 89 deletions.
107 changes: 20 additions & 87 deletions distributed/client.py
Expand Up @@ -47,7 +47,7 @@
)
from dask.widgets import get_template

from distributed.core import ErrorMessage, OKMessage
from distributed.core import ErrorMessage, OKMessage, Server
from distributed.protocol.serialize import _is_dumpable
from distributed.utils import Deadline, wait_for

Expand All @@ -66,11 +66,9 @@
from distributed.compatibility import PeriodicCallback
from distributed.core import (
CommClosedError,
ConnectionPool,
PooledRPCCall,
Status,
clean_exception,
connect,
rpc,
)
from distributed.diagnostics.plugin import (
Expand Down Expand Up @@ -976,7 +974,7 @@ def __init__(
self._set_config = dask.config.set(scheduler="dask.distributed")
self._event_handlers = {}

self._stream_handlers = {
stream_handlers = {
"key-in-memory": self._handle_key_in_memory,
"lost-data": self._handle_lost_data,
"cancelled-keys": self._handle_cancelled_keys,
Expand All @@ -993,15 +991,17 @@ def __init__(
"erred": self._handle_task_erred,
}

self.rpc = ConnectionPool(
limit=connection_limit,
serializers=serializers,
deserializers=deserializers,
self.server = Server(
{},
stream_handlers=stream_handlers,
connection_limit=connection_limit,
deserialize=True,
connection_args=self.connection_args,
deserializers=deserializers,
serializers=serializers,
timeout=timeout,
server=self,
connection_args=self.connection_args,
)
self.rpc = self.server.rpc

self.extensions = {
name: extension(self) for name, extension in extensions.items()
Expand Down Expand Up @@ -1247,7 +1247,7 @@ def _send_to_scheduler(self, msg):
async def _start(self, timeout=no_default, **kwargs):
self.status = "connecting"

await self.rpc.start()
await self.server

if timeout is no_default:
timeout = self._timeout
Expand Down Expand Up @@ -1289,7 +1289,7 @@ async def _start(self, timeout=no_default, **kwargs):
self._gather_semaphore = asyncio.Semaphore(5)

if self.scheduler is None:
self.scheduler = self.rpc(address)
self.scheduler = self.server.rpc(address)
self.scheduler_comm = None

try:
Expand All @@ -1306,7 +1306,9 @@ async def _start(self, timeout=no_default, **kwargs):

await self.preloads.start()

self._handle_report_task = asyncio.create_task(self._handle_report())
self._handle_report_task = asyncio.create_task(
self.server.handle_stream(self.scheduler_comm.comm)
)

return self

Expand Down Expand Up @@ -1355,9 +1357,7 @@ async def _ensure_connected(self, timeout=None):
self._connecting_to_scheduler = True

try:
comm = await connect(
self.scheduler.address, timeout=timeout, **self.connection_args
)
comm = await self.server.rpc.connect(self.scheduler.address)
comm.name = "Client->Scheduler"
if timeout is not None:
await wait_for(self._update_scheduler_info(), timeout)
Expand Down Expand Up @@ -1543,63 +1543,6 @@ def _release_key(self, key):
{"op": "client-releases-keys", "keys": [key], "client": self.id}
)

@log_errors
async def _handle_report(self):
"""Listen to scheduler"""
try:
while True:
if self.scheduler_comm is None:
break
try:
msgs = await self.scheduler_comm.comm.read()
except CommClosedError:
if is_python_shutting_down():
return
if self.status == "running":
if self.cluster and self.cluster.status in (
Status.closed,
Status.closing,
):
# Don't attempt to reconnect if cluster are already closed.
# Instead close down the client.
await self._close()
return
logger.info("Client report stream closed to scheduler")
logger.info("Reconnecting...")
self.status = "connecting"
await self._reconnect()
continue
else:
break
if not isinstance(msgs, (list, tuple)):
msgs = (msgs,)

breakout = False
for msg in msgs:
logger.debug("Client receives message %s", msg)

if "status" in msg and "error" in msg["status"]:
typ, exc, tb = clean_exception(**msg)
raise exc.with_traceback(tb)

op = msg.pop("op")

if op == "close" or op == "stream-closed":
breakout = True
break

try:
handler = self._stream_handlers[op]
result = handler(**msg)
if inspect.isawaitable(result):
await result
except Exception as e:
logger.exception(e)
if breakout:
break
except (CancelledError, asyncio.CancelledError):
pass

def _handle_key_in_memory(self, key=None, type=None, workers=None):
state = self.futures.get(key)
if state is not None:
Expand Down Expand Up @@ -1707,29 +1650,19 @@ async def _close(self, fast=False):
self._send_to_scheduler({"op": "close-client"})
self._send_to_scheduler({"op": "close-stream"})
async with self._wait_for_handle_report_task(fast=fast):
if (
self.scheduler_comm
and self.scheduler_comm.comm
and not self.scheduler_comm.comm.closed()
):
await self.scheduler_comm.close()

for key in list(self.futures):
self._release_key(key=key)

if self._start_arg is None:
with suppress(AttributeError):
await self.cluster.close()

await self.rpc.close()

self.status = "closed"
await self.server.close()

if _get_global_client() is self:
_set_global_client(None)
self.status = "closed"

with suppress(AttributeError):
await self.scheduler.close_rpc()
if _get_global_client() is self:
_set_global_client(None)

self.scheduler = None
self.status = "closed"
Expand Down
2 changes: 1 addition & 1 deletion distributed/pubsub.py
Expand Up @@ -175,7 +175,7 @@ class PubSubClientExtension:

def __init__(self, client):
self.client = client
self.client._stream_handlers.update({"pubsub-msg": self.handle_message})
self.client.server.stream_handlers.update({"pubsub-msg": self.handle_message})

self.subscribers = defaultdict(weakref.WeakSet)

Expand Down
2 changes: 1 addition & 1 deletion distributed/tests/test_client.py
Expand Up @@ -4006,7 +4006,7 @@ async def test_get_versions_async(c, s, a, b):

@gen_cluster(client=True, config={"distributed.comm.timeouts.connect": "200ms"})
async def test_get_versions_rpc_error(c, s, a, b):
a.stop()
a.server.stop()
v = await c.get_versions()
assert v.keys() == {"scheduler", "client", "workers"}
assert v["workers"].keys() == {b.address}
Expand Down

0 comments on commit 395fcc2

Please sign in to comment.