
Merge branch 'main' into AMM/RetireWorker
crusaderky committed Dec 3, 2021
2 parents 3de46eb + e2e2dda commit 51781ab
Showing 14 changed files with 259 additions and 157 deletions.
2 changes: 1 addition & 1 deletion continuous_integration/gpuci/axis.yaml
@@ -8,6 +8,6 @@ LINUX_VER:
- ubuntu18.04

RAPIDS_VER:
- "21.12"
- "22.02"

excludes:
41 changes: 27 additions & 14 deletions distributed/client.py
@@ -16,15 +16,15 @@
import warnings
import weakref
from collections import defaultdict
from collections.abc import Iterator
from collections.abc import Awaitable, Collection, Iterator
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures._base import DoneAndNotDoneFutures
from contextlib import contextmanager, suppress
from contextvars import ContextVar
from functools import partial
from numbers import Number
from queue import Queue as pyQueue
from typing import TYPE_CHECKING, Awaitable, ClassVar, Sequence
from typing import TYPE_CHECKING, ClassVar

from tlz import first, groupby, keymap, merge, partition_all, valmap

@@ -3481,8 +3481,8 @@ def scheduler_info(self, **kwargs):
async def _dump_cluster_state(
self,
filename: str,
exclude: Sequence[str] = None,
format: Literal["msgpack"] | Literal["yaml"] = "msgpack",
exclude: Collection[str],
format: Literal["msgpack", "yaml"],
) -> None:

scheduler_info = self.scheduler.dump_state()
@@ -3503,23 +3503,36 @@ async def _dump_cluster_state
"workers": worker_info,
"versions": versions_info,
}

def tuple_to_list(node):
if isinstance(node, (list, tuple)):
return [tuple_to_list(el) for el in node]
elif isinstance(node, dict):
return {k: tuple_to_list(v) for k, v in node.items()}
else:
return node

# lists are converted to tuples by the RPC
state = tuple_to_list(state)

filename = str(filename)
if format == "msgpack":
suffix = ".msgpack.gz"
if not filename.endswith(suffix):
filename += suffix
import gzip

import msgpack
import yaml

suffix = ".msgpack.gz"
if not filename.endswith(suffix):
filename += suffix

with gzip.open(filename, "wb") as fdg:
msgpack.pack(state, fdg)
elif format == "yaml":
import yaml

suffix = ".yaml"
if not filename.endswith(suffix):
filename += suffix
import yaml

with open(filename, "w") as fd:
yaml.dump(state, fd)
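
A quick illustration of the tuple_to_list helper above, run on hypothetical input (the keys here are made up for the example):

    >>> tuple_to_list({"shape": (2, 3), "chunks": ((1, 1), (3,))})
    {'shape': [2, 3], 'chunks': [[1, 1], [3]]}
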
@@ -3531,8 +3544,8 @@ async def _dump_cluster_state
def dump_cluster_state(
self,
filename: str = "dask-cluster-dump",
exclude: Sequence[str] = None,
format: Literal["msgpack"] | Literal["yaml"] = "msgpack",
exclude: Collection[str] = (),
format: Literal["msgpack", "yaml"] = "msgpack",
) -> Awaitable | None:
"""Extract a dump of the entire cluster state and persist to disk.
This is intended for debugging purposes only.
@@ -3549,13 +3562,13 @@ def dump_cluster_state
}
}
Paramters
---------
Parameters
----------
filename:
The output filename. The appropriate file suffix (`.msgpack.gz` or
`.yaml`) will be appended automatically.
exclude:
A sequence of attribute names which are supposed to be blacklisted
A collection of attribute names which are supposed to be blacklisted
from the dump, e.g. to exclude code, tracebacks, logs, etc.
format:
Either msgpack or yaml. If msgpack is used (default), the output
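
With the new signature, a call looks like the following minimal sketch (the scheduler address and the excluded attribute name are hypothetical; the dump_cluster_state parameters match the diff above):

    import gzip

    import msgpack

    from distributed import Client

    client = Client("tcp://127.0.0.1:8786")  # hypothetical scheduler address

    # Writes dask-cluster-dump.msgpack.gz; the suffix is appended automatically.
    client.dump_cluster_state(
        filename="dask-cluster-dump",
        exclude=("run_spec",),  # hypothetical attribute name to leave out
        format="msgpack",
    )

    # Reading the gzipped msgpack dump back:
    with gzip.open("dask-cluster-dump.msgpack.gz", "rb") as f:
        state = msgpack.unpack(f)
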
9 changes: 9 additions & 0 deletions distributed/comm/tests/test_ucx.py
@@ -361,3 +361,12 @@ async def test_transpose():
async def test_ucx_protocol(cleanup, port):
async with Scheduler(protocol="ucx", port=port, dashboard_address=":0") as s:
assert s.address.startswith("ucx://")


@pytest.mark.skipif(
not hasattr(ucp.exceptions, "UCXUnreachable"),
reason="Requires UCX-Py support for UCXUnreachable exception",
)
def test_ucx_unreachable():
with pytest.raises(OSError, match="Timed out trying to connect to"):
Client("ucx://255.255.255.255:12345", timeout=1)
1 change: 1 addition & 0 deletions distributed/comm/ucx.py
@@ -420,6 +420,7 @@ async def connect(self, address: str, deserialize=True, **connection_args) -> UC
except (ucp.exceptions.UCXCloseError, ucp.exceptions.UCXCanceled,) + (
getattr(ucp.exceptions, "UCXConnectionReset", ()),
getattr(ucp.exceptions, "UCXNotConnected", ()),
getattr(ucp.exceptions, "UCXUnreachable", ()),
):
raise CommClosedError("Connection closed before handshake completed")
return self.comm_class(
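
The getattr(..., ()) pattern in the hunk above tolerates exception classes that only newer UCX-Py releases define: an empty tuple inside an except clause matches nothing, so the handler degrades gracefully on older versions. A standalone sketch of the same trick, using a hypothetical attribute name on a standard-library module:

    import socket

    # Falls back to () when the attribute is absent, e.g. on older releases.
    MaybeUnreachable = getattr(socket, "UnreachableError", ())  # hypothetical name

    try:
        raise ConnectionResetError("peer went away")
    except (ConnectionResetError,) + (MaybeUnreachable,):
        print("handled whether or not UnreachableError exists")
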
5 changes: 2 additions & 3 deletions distributed/core.py
@@ -378,8 +378,8 @@ def identity(self, comm=None) -> dict[str, str]:
return {"type": type(self).__name__, "id": self.id}

def _to_dict(
self, comm: Comm = None, *, exclude: Container[str] = None
) -> dict[str, str]:
self, comm: Comm | None = None, *, exclude: Container[str] = ()
) -> dict:
"""
A very verbose dictionary representation for debugging purposes.
Not type stable and not intended for roundtrips.
@@ -395,7 +395,6 @@ def _to_dict(
Server.identity
Client.dump_cluster_state
"""

info = self.identity()
extra = {
"address": self.address,
41 changes: 23 additions & 18 deletions distributed/scheduler.py
@@ -26,7 +26,7 @@
from datetime import timedelta
from functools import partial
from numbers import Number
from typing import Any, ClassVar, Container
from typing import ClassVar, Container
from typing import cast as pep484_cast

import psutil
@@ -1732,10 +1732,10 @@ def get_nbytes_deps(self):
return nbytes

@ccall
def _to_dict(self, *, exclude: Container[str] = None):
def _to_dict(self, *, exclude: "Container[str]" = ()): # -> dict
"""
A very verbose dictionary representation for debugging purposes.
Not type stable and not inteded for roundtrips.
Not type stable and not intended for roundtrips.
Parameters
----------
@@ -1746,12 +1746,13 @@ def _to_dict(self, *, exclude: Container[str] = None):
--------
Client.dump_cluster_state
"""

if not exclude:
exclude = set()
members = inspect.getmembers(self)
return recursive_to_dict(
{k: v for k, v in members if k not in exclude and not callable(v)},
{
k: v
for k, v in members
if not k.startswith("_") and k not in exclude and not callable(v)
},
exclude=exclude,
)

@@ -3977,8 +3978,8 @@ def identity(self, comm=None):
return d

def _to_dict(
self, comm: Comm = None, *, exclude: Container[str] = None
) -> "dict[str, Any]":
self, comm: "Comm | None" = None, *, exclude: "Container[str]" = ()
) -> dict:
"""
A very verbose dictionary representation for debugging purposes.
Not type stable and not intended for roundtrips.
@@ -3994,20 +3995,16 @@ def _to_dict(
Server.identity
Client.dump_cluster_state
"""

info = super()._to_dict(exclude=exclude)
extra = {
"transition_log": self.transition_log,
"log": self.log,
"tasks": self.tasks,
"events": self.events,
"extensions": self.extensions,
}
info.update(extra)
extensions = {}
for name, ex in self.extensions.items():
if hasattr(ex, "_to_dict"):
extensions[name] = ex._to_dict()
return recursive_to_dict(info, exclude=exclude)
info.update(recursive_to_dict(extra, exclude=exclude))
return info

def get_worker_service_addr(self, worker, service_name, protocol=False):
"""
@@ -5678,7 +5675,9 @@ def remove_plugin(
f"Could not find plugin {name!r} among the current scheduler plugins"
)

async def register_scheduler_plugin(self, comm=None, plugin=None, name=None):
async def register_scheduler_plugin(
self, comm=None, plugin=None, name=None, idempotent=None
):
"""Register a plugin on the scheduler."""
if not dask.config.get("distributed.scheduler.pickle"):
raise ValueError(
@@ -5689,12 +5688,18 @@ async def register_scheduler_plugin(self, comm=None, plugin=None, name=None):
)
plugin = loads(plugin)

if name is None:
name = _get_plugin_name(plugin)

if name in self.plugins and idempotent:
return

if hasattr(plugin, "start"):
result = plugin.start(self)
if inspect.isawaitable(result):
await result

self.add_plugin(plugin, name=name)
self.add_plugin(plugin, name=name, idempotent=idempotent)

def worker_send(self, worker, msg):
"""Send message to worker
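
With idempotent=True, a second registration under an existing plugin name returns before start() runs again. A minimal sketch against a running Scheduler s, mirroring the tests further down (the plugin class itself is a hypothetical example):

    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.protocol.pickle import dumps

    class CounterPlugin(SchedulerPlugin):  # hypothetical example plugin
        name = "counter"

        def start(self, scheduler):
            self.count = 0

    async def demo(s):
        # First registration pickles the plugin, starts it, and records it.
        await s.register_scheduler_plugin(plugin=dumps(CounterPlugin()), idempotent=True)
        # Second registration with the same name is a silent no-op:
        await s.register_scheduler_plugin(plugin=dumps(CounterPlugin()), idempotent=True)
        assert "counter" in s.plugins
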
4 changes: 2 additions & 2 deletions distributed/stealing.py
@@ -4,7 +4,7 @@
from collections import defaultdict, deque
from math import log2
from time import time
from typing import Any, Container
from typing import Container

from tlz import topk
from tornado.ioloop import PeriodicCallback
@@ -82,7 +82,7 @@ def __init__(self, scheduler):

self.scheduler.stream_handlers["steal-response"] = self.move_task_confirm

def _to_dict(self, *, exclude: Container[str] = None) -> dict[str, Any]:
def _to_dict(self, *, exclude: Container[str] = ()) -> dict:
"""
A very verbose dictionary representation for debugging purposes.
Not type stable and not intended for roundtrips.
13 changes: 4 additions & 9 deletions distributed/tests/test_client.py
@@ -23,6 +23,7 @@

import psutil
import pytest
import yaml
from tlz import concat, first, identity, isdistinct, merge, pluck, valmap

import dask
@@ -7165,7 +7166,6 @@ async def test_dump_cluster_state_async(c, s, a, b, tmp_path, _format):

@gen_cluster(client=True)
async def test_dump_cluster_state_exclude(c, s, a, b, tmp_path):

futs = c.map(inc, range(10))
while len(s.tasks) != len(futs):
await asyncio.sleep(0.01)
@@ -7175,15 +7175,10 @@ async def test_dump_cluster_state_exclude(c, s, a, b, tmp_path):
"runspec",
]
filename = tmp_path / "foo"
await c.dump_cluster_state(
filename=filename,
format="yaml",
)

with open(str(filename) + ".yaml") as fd:
import yaml
await c.dump_cluster_state(filename=filename, format="yaml")

state = yaml.load(fd, Loader=yaml.Loader)
with open(f"{filename}.yaml") as fd:
state = yaml.safe_load(fd)

assert "workers" in state
assert len(state["workers"]) == len(s.workers)
55 changes: 51 additions & 4 deletions distributed/tests/test_scheduler.py
@@ -3255,11 +3255,12 @@ async def test__to_dict(c, s, a, b):
futs = c.map(inc, range(100))

await c.gather(futs)
dct = Scheduler._to_dict(s)
assert list(dct.keys()) == [
d = Scheduler._to_dict(s)
assert d.keys() == {
"type",
"id",
"address",
"extensions",
"services",
"started",
"workers",
Expand All @@ -3269,5 +3270,51 @@ async def test__to_dict(c, s, a, b):
"log",
"tasks",
"events",
]
assert dct["tasks"][futs[0].key]
}
assert d["tasks"][futs[0].key]


@gen_cluster(nthreads=[])
async def test_idempotent_plugins(s):

from distributed.diagnostics.plugin import SchedulerPlugin

class IdempotentPlugin(SchedulerPlugin):
def __init__(self, instance=None):
self.name = "idempotentplugin"
self.instance = instance

def start(self, scheduler):
if self.instance != "first":
raise RuntimeError(
"Only the first plugin should be started when idempotent is set"
)

first = IdempotentPlugin(instance="first")
await s.register_scheduler_plugin(plugin=dumps(first), idempotent=True)
assert "idempotentplugin" in s.plugins

second = IdempotentPlugin(instance="second")
await s.register_scheduler_plugin(plugin=dumps(second), idempotent=True)
assert "idempotentplugin" in s.plugins
assert s.plugins["idempotentplugin"].instance == "first"


@gen_cluster(nthreads=[])
async def test_non_idempotent_plugins(s):

from distributed.diagnostics.plugin import SchedulerPlugin

class NonIdempotentPlugin(SchedulerPlugin):
def __init__(self, instance=None):
self.name = "nonidempotentplugin"
self.instance = instance

first = NonIdempotentPlugin(instance="first")
await s.register_scheduler_plugin(plugin=dumps(first), idempotent=False)
assert "nonidempotentplugin" in s.plugins

second = NonIdempotentPlugin(instance="second")
await s.register_scheduler_plugin(plugin=dumps(second), idempotent=False)
assert "nonidempotentplugin" in s.plugins
assert s.plugins["nonidempotentplugin"].instance == "second"
