[RFC] Decompose server class #8468

Draft · wants to merge 2 commits into base: main
159 changes: 159 additions & 0 deletions distributed/_async_taskgroup.py
@@ -0,0 +1,159 @@
from __future__ import annotations

import asyncio
import threading
from collections.abc import Callable, Coroutine
from typing import TYPE_CHECKING, Any, TypeVar

if TYPE_CHECKING:
    from typing_extensions import ParamSpec

    P = ParamSpec("P")
    R = TypeVar("R")
    T = TypeVar("T")
    Coro = Coroutine[Any, Any, T]


class _LoopBoundMixin:
    """Backport of the private asyncio.mixins._LoopBoundMixin from 3.11"""

    _global_lock = threading.Lock()

    _loop = None

    def _get_loop(self):
        loop = asyncio.get_running_loop()

        if self._loop is None:
            with self._global_lock:
                if self._loop is None:
                    self._loop = loop
        if loop is not self._loop:
            raise RuntimeError(f"{self!r} is bound to a different event loop")

        return loop


class AsyncTaskGroupClosedError(RuntimeError):
    pass


def _delayed(corofunc: Callable[P, Coro[T]], delay: float) -> Callable[P, Coro[T]]:
    """Decorator to delay the evaluation of a coroutine function by the given delay in seconds."""

    async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
        await asyncio.sleep(delay)
        return await corofunc(*args, **kwargs)

    return wrapper


class AsyncTaskGroup(_LoopBoundMixin):
    """Collection tracking all currently running asynchronous tasks within a group"""

    #: If True, the group is closed and does not allow adding new tasks.
    closed: bool

    def __init__(self) -> None:
        self.closed = False
        self._ongoing_tasks: set[asyncio.Task[None]] = set()

    def call_soon(
        self, afunc: Callable[P, Coro[None]], /, *args: P.args, **kwargs: P.kwargs
    ) -> None:
        """Schedule a coroutine function to be executed as an `asyncio.Task`.

        The coroutine function `afunc` is scheduled with `args` arguments and `kwargs` keyword arguments
        as an `asyncio.Task`.

        Parameters
        ----------
        afunc
            Coroutine function to schedule.
        *args
            Arguments to be passed to `afunc`.
        **kwargs
            Keyword arguments to be passed to `afunc`.

        Returns
        -------
        None

        Raises
        ------
        AsyncTaskGroupClosedError
            If the task group is closed.
        """
        if self.closed:  # Avoid creating a coroutine
            raise AsyncTaskGroupClosedError(
                "Cannot schedule a new coroutine function as the group is already closed."
            )
        task = self._get_loop().create_task(afunc(*args, **kwargs))
        task.add_done_callback(self._ongoing_tasks.remove)
        self._ongoing_tasks.add(task)
        return None

    def call_later(
        self,
        delay: float,
        afunc: Callable[P, Coro[None]],
        /,
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> None:
        """Schedule a coroutine function to be executed after `delay` seconds as an `asyncio.Task`.

        The coroutine function `afunc` is scheduled with `args` arguments and `kwargs` keyword arguments
        as an `asyncio.Task` that is executed after `delay` seconds.

        Parameters
        ----------
        delay
            Delay in seconds.
        afunc
            Coroutine function to schedule.
        *args
            Arguments to be passed to `afunc`.
        **kwargs
            Keyword arguments to be passed to `afunc`.

        Returns
        -------
        None

        Raises
        ------
        AsyncTaskGroupClosedError
            If the task group is closed.
        """
        self.call_soon(_delayed(afunc, delay), *args, **kwargs)

    def close(self) -> None:
        """Close the task group so that no new tasks can be scheduled.

        Existing tasks continue to run.
        """
        self.closed = True

    async def stop(self) -> None:
        """Close the group and stop all currently running tasks.

        Closes the task group and cancels all tasks. Remaining tasks are
        cancelled an additional time for each time this coroutine is itself
        cancelled.
        """
        self.close()

        current_task = asyncio.current_task(self._get_loop())
        err = None
        while tasks_to_stop := (self._ongoing_tasks - {current_task}):
            for task in tasks_to_stop:
                task.cancel()
            try:
                await asyncio.wait(tasks_to_stop)
            except asyncio.CancelledError as e:
                err = e

        if err is not None:
            raise err

    def __len__(self):
        return len(self._ongoing_tasks)
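A minimal usage sketch of the new primitive (not part of the diff; it assumes the module lands as distributed/_async_taskgroup.py and runs on a Python version supporting the walrus operator):

import asyncio

from distributed._async_taskgroup import AsyncTaskGroup, AsyncTaskGroupClosedError


async def main() -> None:
    group = AsyncTaskGroup()

    async def work(name: str) -> None:
        await asyncio.sleep(0.1)
        print(f"{name} finished")

    group.call_soon(work, "immediate")      # scheduled as an asyncio.Task right away
    group.call_later(0.5, work, "delayed")  # wrapped by _delayed, starts after 0.5s
    print(len(group))                       # 2 tasks currently tracked

    group.close()                           # no new tasks may be scheduled
    try:
        group.call_soon(work, "rejected")
    except AsyncTaskGroupClosedError:
        print("group is closed")

    await group.stop()                      # cancels whatever is still running
    print(len(group))                       # 0


asyncio.run(main())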
107 changes: 20 additions & 87 deletions distributed/client.py
@@ -47,7 +47,7 @@
)
from dask.widgets import get_template

from distributed.core import ErrorMessage, OKMessage
from distributed.core import ErrorMessage, OKMessage, Server
from distributed.protocol.serialize import _is_dumpable
from distributed.utils import Deadline, wait_for

@@ -66,11 +66,9 @@
from distributed.compatibility import PeriodicCallback
from distributed.core import (
    CommClosedError,
    ConnectionPool,
    PooledRPCCall,
    Status,
    clean_exception,
    connect,
    rpc,
)
from distributed.diagnostics.plugin import (
@@ -976,7 +974,7 @@ def __init__(
        self._set_config = dask.config.set(scheduler="dask.distributed")
        self._event_handlers = {}

        self._stream_handlers = {
        stream_handlers = {
            "key-in-memory": self._handle_key_in_memory,
            "lost-data": self._handle_lost_data,
            "cancelled-keys": self._handle_cancelled_keys,
@@ -993,15 +991,17 @@ def __init__(
            "erred": self._handle_task_erred,
        }

        self.rpc = ConnectionPool(
            limit=connection_limit,
            serializers=serializers,
            deserializers=deserializers,
        self.server = Server(
            {},
            stream_handlers=stream_handlers,
            connection_limit=connection_limit,
            deserialize=True,
            connection_args=self.connection_args,
            deserializers=deserializers,
            serializers=serializers,
            timeout=timeout,
            server=self,
            connection_args=self.connection_args,
        )
        self.rpc = self.server.rpc

        self.extensions = {
            name: extension(self) for name, extension in extensions.items()
@@ -1247,7 +1247,7 @@ def _send_to_scheduler(self, msg):
    async def _start(self, timeout=no_default, **kwargs):
        self.status = "connecting"

        await self.rpc.start()
        await self.server

        if timeout is no_default:
            timeout = self._timeout
@@ -1289,7 +1289,7 @@ async def _start(self, timeout=no_default, **kwargs):
        self._gather_semaphore = asyncio.Semaphore(5)

        if self.scheduler is None:
            self.scheduler = self.rpc(address)
            self.scheduler = self.server.rpc(address)
        self.scheduler_comm = None

        try:
@@ -1306,7 +1306,9 @@ async def _start(self, timeout=no_default, **kwargs):

        await self.preloads.start()

        self._handle_report_task = asyncio.create_task(self._handle_report())
        self._handle_report_task = asyncio.create_task(
            self.server.handle_stream(self.scheduler_comm.comm)
        )

        return self

@@ -1355,9 +1357,7 @@ async def _ensure_connected(self, timeout=None):
        self._connecting_to_scheduler = True

        try:
            comm = await connect(
                self.scheduler.address, timeout=timeout, **self.connection_args
            )
            comm = await self.server.rpc.connect(self.scheduler.address)
            comm.name = "Client->Scheduler"
            if timeout is not None:
                await wait_for(self._update_scheduler_info(), timeout)
@@ -1543,63 +1543,6 @@ def _release_key(self, key):
            {"op": "client-releases-keys", "keys": [key], "client": self.id}
        )

    @log_errors
    async def _handle_report(self):
        """Listen to scheduler"""
        try:
            while True:
                if self.scheduler_comm is None:
                    break
                try:
                    msgs = await self.scheduler_comm.comm.read()
                except CommClosedError:
                    if is_python_shutting_down():
                        return
                    if self.status == "running":
                        if self.cluster and self.cluster.status in (
                            Status.closed,
                            Status.closing,
                        ):
                            # Don't attempt to reconnect if cluster are already closed.
                            # Instead close down the client.
                            await self._close()
                            return
                        logger.info("Client report stream closed to scheduler")
                        logger.info("Reconnecting...")
                        self.status = "connecting"
                        await self._reconnect()
                        continue
                    else:
                        break
                if not isinstance(msgs, (list, tuple)):
                    msgs = (msgs,)

                breakout = False
                for msg in msgs:
                    logger.debug("Client receives message %s", msg)

                    if "status" in msg and "error" in msg["status"]:
                        typ, exc, tb = clean_exception(**msg)
                        raise exc.with_traceback(tb)

                    op = msg.pop("op")

                    if op == "close" or op == "stream-closed":
                        breakout = True
                        break

                    try:
                        handler = self._stream_handlers[op]
                        result = handler(**msg)
                        if inspect.isawaitable(result):
                            await result
                    except Exception as e:
                        logger.exception(e)
                if breakout:
                    break
        except (CancelledError, asyncio.CancelledError):
            pass

def _handle_key_in_memory(self, key=None, type=None, workers=None):
state = self.futures.get(key)
if state is not None:
@@ -1707,29 +1650,19 @@ async def _close(self, fast=False):
            self._send_to_scheduler({"op": "close-client"})
            self._send_to_scheduler({"op": "close-stream"})
            async with self._wait_for_handle_report_task(fast=fast):
                if (
                    self.scheduler_comm
                    and self.scheduler_comm.comm
                    and not self.scheduler_comm.comm.closed()
                ):
                    await self.scheduler_comm.close()

                for key in list(self.futures):
                    self._release_key(key=key)

                if self._start_arg is None:
                    with suppress(AttributeError):
                        await self.cluster.close()

                await self.rpc.close()

                self.status = "closed"
                await self.server.close()

                if _get_global_client() is self:
                    _set_global_client(None)
                self.status = "closed"

                with suppress(AttributeError):
                    await self.scheduler.close_rpc()
                if _get_global_client() is self:
                    _set_global_client(None)

        self.scheduler = None
        self.status = "closed"
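The `_handle_report` loop deleted above is essentially the message-dispatch loop that `Server.handle_stream` already provides, which is why the client can now delegate to the embedded `Server`. A rough, self-contained sketch of that dispatch shape (the function and its signature here are illustrative, not the actual `Server.handle_stream` API):

import asyncio
import inspect
from collections.abc import Awaitable, Callable
from typing import Any


async def handle_stream(
    read: Callable[[], Awaitable[Any]],
    stream_handlers: dict[str, Callable[..., Any]],
) -> None:
    # Read batches of messages and dispatch each one to its registered
    # handler until a close message arrives.
    while True:
        msgs = await read()
        if not isinstance(msgs, (list, tuple)):
            msgs = (msgs,)
        for msg in msgs:
            op = msg.pop("op")
            if op == "close" or op == "stream-closed":
                return
            result = stream_handlers[op](**msg)
            if inspect.isawaitable(result):
                await result


async def demo() -> None:
    # Stand-in for comm.read(): yields one batch, then a close message.
    messages = [[{"op": "key-in-memory", "key": "x"}], {"op": "close"}]

    async def read() -> Any:
        return messages.pop(0)

    handlers = {"key-in-memory": lambda key: print(f"{key} is in memory")}
    await handle_stream(read, handlers)


asyncio.run(demo())

Note that the deleted client loop also handled reconnection and error unpacking; how those responsibilities are covered after the decomposition is part of what this RFC puts up for discussion.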