diff --git a/continuous_integration/gpuci/axis.yaml b/continuous_integration/gpuci/axis.yaml index da4a725bd2e..a29aef17731 100644 --- a/continuous_integration/gpuci/axis.yaml +++ b/continuous_integration/gpuci/axis.yaml @@ -8,6 +8,6 @@ LINUX_VER: - ubuntu18.04 RAPIDS_VER: -- "21.12" +- "22.02" excludes: diff --git a/distributed/client.py b/distributed/client.py index 10dc3497273..c78f1dd20ba 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -16,7 +16,7 @@ import warnings import weakref from collections import defaultdict -from collections.abc import Iterator +from collections.abc import Awaitable, Collection, Iterator from concurrent.futures import ThreadPoolExecutor from concurrent.futures._base import DoneAndNotDoneFutures from contextlib import contextmanager, suppress @@ -24,7 +24,7 @@ from functools import partial from numbers import Number from queue import Queue as pyQueue -from typing import TYPE_CHECKING, Awaitable, ClassVar, Sequence +from typing import TYPE_CHECKING, ClassVar from tlz import first, groupby, keymap, merge, partition_all, valmap @@ -3481,8 +3481,8 @@ def scheduler_info(self, **kwargs): async def _dump_cluster_state( self, filename: str, - exclude: Sequence[str] = None, - format: Literal["msgpack"] | Literal["yaml"] = "msgpack", + exclude: Collection[str], + format: Literal["msgpack", "yaml"], ) -> None: scheduler_info = self.scheduler.dump_state() @@ -3503,23 +3503,36 @@ async def _dump_cluster_state( "workers": worker_info, "versions": versions_info, } + + def tuple_to_list(node): + if isinstance(node, (list, tuple)): + return [tuple_to_list(el) for el in node] + elif isinstance(node, dict): + return {k: tuple_to_list(v) for k, v in node.items()} + else: + return node + + # lists are converted to tuples by the RPC + state = tuple_to_list(state) + filename = str(filename) if format == "msgpack": - suffix = ".msgpack.gz" - if not filename.endswith(suffix): - filename += suffix import gzip import msgpack - import yaml + + suffix = ".msgpack.gz" + if not filename.endswith(suffix): + filename += suffix with gzip.open(filename, "wb") as fdg: msgpack.pack(state, fdg) elif format == "yaml": + import yaml + suffix = ".yaml" if not filename.endswith(suffix): filename += suffix - import yaml with open(filename, "w") as fd: yaml.dump(state, fd) @@ -3531,8 +3544,8 @@ async def _dump_cluster_state( def dump_cluster_state( self, filename: str = "dask-cluster-dump", - exclude: Sequence[str] = None, - format: Literal["msgpack"] | Literal["yaml"] = "msgpack", + exclude: Collection[str] = (), + format: Literal["msgpack", "yaml"] = "msgpack", ) -> Awaitable | None: """Extract a dump of the entire cluster state and persist to disk. This is intended for debugging purposes only. @@ -3549,13 +3562,13 @@ def dump_cluster_state( } } - Paramters - --------- + Parameters + ---------- filename: The output filename. The appropriate file suffix (`.msgpack.gz` or `.yaml`) will be appended automatically. exclude: - A sequence of attribute names which are supposed to be blacklisted + A collection of attribute names which are supposed to be blacklisted from the dump, e.g. to exclude code, tracebacks, logs, etc. format: Either msgpack or yaml. 
If msgpack is used (default), the output diff --git a/distributed/comm/tests/test_ucx.py b/distributed/comm/tests/test_ucx.py index 20b69297285..365a9821520 100644 --- a/distributed/comm/tests/test_ucx.py +++ b/distributed/comm/tests/test_ucx.py @@ -361,3 +361,12 @@ async def test_transpose(): async def test_ucx_protocol(cleanup, port): async with Scheduler(protocol="ucx", port=port, dashboard_address=":0") as s: assert s.address.startswith("ucx://") + + +@pytest.mark.skipif( + not hasattr(ucp.exceptions, "UCXUnreachable"), + reason="Requires UCX-Py support for UCXUnreachable exception", +) +def test_ucx_unreachable(): + with pytest.raises(OSError, match="Timed out trying to connect to"): + Client("ucx://255.255.255.255:12345", timeout=1) diff --git a/distributed/comm/ucx.py b/distributed/comm/ucx.py index a1fcbb5b8b5..c529d27e6a9 100644 --- a/distributed/comm/ucx.py +++ b/distributed/comm/ucx.py @@ -420,6 +420,7 @@ async def connect(self, address: str, deserialize=True, **connection_args) -> UC except (ucp.exceptions.UCXCloseError, ucp.exceptions.UCXCanceled,) + ( getattr(ucp.exceptions, "UCXConnectionReset", ()), getattr(ucp.exceptions, "UCXNotConnected", ()), + getattr(ucp.exceptions, "UCXUnreachable", ()), ): raise CommClosedError("Connection closed before handshake completed") return self.comm_class( diff --git a/distributed/core.py b/distributed/core.py index d589dcd176f..046f0801c09 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -378,8 +378,8 @@ def identity(self, comm=None) -> dict[str, str]: return {"type": type(self).__name__, "id": self.id} def _to_dict( - self, comm: Comm = None, *, exclude: Container[str] = None - ) -> dict[str, str]: + self, comm: Comm | None = None, *, exclude: Container[str] = () + ) -> dict: """ A very verbose dictionary representation for debugging purposes. Not type stable and not inteded for roundtrips. @@ -395,7 +395,6 @@ def _to_dict( Server.identity Client.dump_cluster_state """ - info = self.identity() extra = { "address": self.address, diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 44a11aab48c..05ab0e8f908 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -26,7 +26,7 @@ from datetime import timedelta from functools import partial from numbers import Number -from typing import Any, ClassVar, Container +from typing import ClassVar, Container from typing import cast as pep484_cast import psutil @@ -1732,10 +1732,10 @@ def get_nbytes_deps(self): return nbytes @ccall - def _to_dict(self, *, exclude: Container[str] = None): + def _to_dict(self, *, exclude: "Container[str]" = ()): # -> dict """ A very verbose dictionary representation for debugging purposes. - Not type stable and not inteded for roundtrips. + Not type stable and not intended for roundtrips. Parameters ---------- @@ -1746,12 +1746,13 @@ def _to_dict(self, *, exclude: Container[str] = None): -------- Client.dump_cluster_state """ - - if not exclude: - exclude = set() members = inspect.getmembers(self) return recursive_to_dict( - {k: v for k, v in members if k not in exclude and not callable(v)}, + { + k: v + for k, v in members + if not k.startswith("_") and k not in exclude and not callable(v) + }, exclude=exclude, ) @@ -3977,8 +3978,8 @@ def identity(self, comm=None): return d def _to_dict( - self, comm: Comm = None, *, exclude: Container[str] = None - ) -> "dict[str, Any]": + self, comm: "Comm | None" = None, *, exclude: "Container[str]" = () + ) -> dict: """ A very verbose dictionary representation for debugging purposes. 
Not type stable and not inteded for roundtrips. @@ -3994,20 +3995,16 @@ def _to_dict( Server.identity Client.dump_cluster_state """ - info = super()._to_dict(exclude=exclude) extra = { "transition_log": self.transition_log, "log": self.log, "tasks": self.tasks, "events": self.events, + "extensions": self.extensions, } - info.update(extra) - extensions = {} - for name, ex in self.extensions.items(): - if hasattr(ex, "_to_dict"): - extensions[name] = ex._to_dict() - return recursive_to_dict(info, exclude=exclude) + info.update(recursive_to_dict(extra, exclude=exclude)) + return info def get_worker_service_addr(self, worker, service_name, protocol=False): """ @@ -5678,7 +5675,9 @@ def remove_plugin( f"Could not find plugin {name!r} among the current scheduler plugins" ) - async def register_scheduler_plugin(self, comm=None, plugin=None, name=None): + async def register_scheduler_plugin( + self, comm=None, plugin=None, name=None, idempotent=None + ): """Register a plugin on the scheduler.""" if not dask.config.get("distributed.scheduler.pickle"): raise ValueError( @@ -5689,12 +5688,18 @@ async def register_scheduler_plugin(self, comm=None, plugin=None, name=None): ) plugin = loads(plugin) + if name is None: + name = _get_plugin_name(plugin) + + if name in self.plugins and idempotent: + return + if hasattr(plugin, "start"): result = plugin.start(self) if inspect.isawaitable(result): await result - self.add_plugin(plugin, name=name) + self.add_plugin(plugin, name=name, idempotent=idempotent) def worker_send(self, worker, msg): """Send message to worker diff --git a/distributed/stealing.py b/distributed/stealing.py index 3f1d04697aa..e5fa43d72ee 100644 --- a/distributed/stealing.py +++ b/distributed/stealing.py @@ -4,7 +4,7 @@ from collections import defaultdict, deque from math import log2 from time import time -from typing import Any, Container +from typing import Container from tlz import topk from tornado.ioloop import PeriodicCallback @@ -82,7 +82,7 @@ def __init__(self, scheduler): self.scheduler.stream_handlers["steal-response"] = self.move_task_confirm - def _to_dict(self, *, exclude: Container[str] = None) -> dict[str, Any]: + def _to_dict(self, *, exclude: Container[str] = ()) -> dict: """ A very verbose dictionary representation for debugging purposes. Not type stable and not inteded for roundtrips. 
diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 1f1d27147f6..db58e82beb2 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -23,6 +23,7 @@ import psutil import pytest +import yaml from tlz import concat, first, identity, isdistinct, merge, pluck, valmap import dask @@ -7165,7 +7166,6 @@ async def test_dump_cluster_state_async(c, s, a, b, tmp_path, _format): @gen_cluster(client=True) async def test_dump_cluster_state_exclude(c, s, a, b, tmp_path): - futs = c.map(inc, range(10)) while len(s.tasks) != len(futs): await asyncio.sleep(0.01) @@ -7175,15 +7175,10 @@ async def test_dump_cluster_state_exclude(c, s, a, b, tmp_path): "runspec", ] filename = tmp_path / "foo" - await c.dump_cluster_state( - filename=filename, - format="yaml", - ) - - with open(str(filename) + ".yaml") as fd: - import yaml + await c.dump_cluster_state(filename=filename, format="yaml") - state = yaml.load(fd, Loader=yaml.Loader) + with open(f"{filename}.yaml") as fd: + state = yaml.safe_load(fd) assert "workers" in state assert len(state["workers"]) == len(s.workers) diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index f83146cde07..288e009f078 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -3255,11 +3255,12 @@ async def test__to_dict(c, s, a, b): futs = c.map(inc, range(100)) await c.gather(futs) - dct = Scheduler._to_dict(s) - assert list(dct.keys()) == [ + d = Scheduler._to_dict(s) + assert d.keys() == { "type", "id", "address", + "extensions", "services", "started", "workers", @@ -3269,5 +3270,51 @@ async def test__to_dict(c, s, a, b): "log", "tasks", "events", - ] - assert dct["tasks"][futs[0].key] + } + assert d["tasks"][futs[0].key] + + +@gen_cluster(nthreads=[]) +async def test_idempotent_plugins(s): + + from distributed.diagnostics.plugin import SchedulerPlugin + + class IdempotentPlugin(SchedulerPlugin): + def __init__(self, instance=None): + self.name = "idempotentplugin" + self.instance = instance + + def start(self, scheduler): + if self.instance != "first": + raise RuntimeError( + "Only the first plugin should be started when idempotent is set" + ) + + first = IdempotentPlugin(instance="first") + await s.register_scheduler_plugin(plugin=dumps(first), idempotent=True) + assert "idempotentplugin" in s.plugins + + second = IdempotentPlugin(instance="second") + await s.register_scheduler_plugin(plugin=dumps(second), idempotent=True) + assert "idempotentplugin" in s.plugins + assert s.plugins["idempotentplugin"].instance == "first" + + +@gen_cluster(nthreads=[]) +async def test_non_idempotent_plugins(s): + + from distributed.diagnostics.plugin import SchedulerPlugin + + class NonIdempotentPlugin(SchedulerPlugin): + def __init__(self, instance=None): + self.name = "nonidempotentplugin" + self.instance = instance + + first = NonIdempotentPlugin(instance="first") + await s.register_scheduler_plugin(plugin=dumps(first), idempotent=False) + assert "nonidempotentplugin" in s.plugins + + second = NonIdempotentPlugin(instance="second") + await s.register_scheduler_plugin(plugin=dumps(second), idempotent=False) + assert "nonidempotentplugin" in s.plugins + assert s.plugins["nonidempotentplugin"].instance == "second" diff --git a/distributed/tests/test_utils.py b/distributed/tests/test_utils.py index f07975de8e7..4d3b786bbce 100644 --- a/distributed/tests/test_utils.py +++ b/distributed/tests/test_utils.py @@ -7,6 +7,7 @@ import queue import socket import 
traceback +from collections import deque from time import sleep import pytest @@ -37,6 +38,7 @@ open_port, parse_ports, read_block, + recursive_to_dict, seek_delimiter, set_thread_state, sync, @@ -633,3 +635,62 @@ async def my_async_callable(x, y, z): assert iscoroutinefunction( functools.partial(functools.partial(my_async_callable, 1), 2) ) + + +def test_recursive_to_dict(): + class C: + def __init__(self, x): + self.x = x + + def __repr__(self): + return "" + + def _to_dict(self, *, exclude): + assert exclude == ["foo"] + return ["C:", recursive_to_dict(self.x, exclude=exclude)] + + class D: + def __repr__(self): + return "" + + inp = [ + 1, + 1.1, + True, + False, + None, + "foo", + b"bar", + C, + C(1), + D(), + (1, 2), + [3, 4], + {5, 6}, + frozenset([7, 8]), + deque([9, 10]), + ] + expect = [ + 1, + 1.1, + True, + False, + None, + "foo", + "b'bar'", + ".C'>", + ["C:", 1], + "", + [1, 2], + [3, 4], + list({5, 6}), + list(frozenset([7, 8])), + [9, 10], + ] + assert recursive_to_dict(inp, exclude=["foo"]) == expect + + # Test recursion + a = [] + c = C(a) + a.append(c) + assert recursive_to_dict(a, exclude=["foo"]) == [["C:", "[]"]] diff --git a/distributed/tests/test_utils_test.py b/distributed/tests/test_utils_test.py index 4b7a1c273c4..e1fa5ff69dd 100755 --- a/distributed/tests/test_utils_test.py +++ b/distributed/tests/test_utils_test.py @@ -152,20 +152,35 @@ async def test_gen_cluster_tls(e, s, a, b): assert s.nthreads == {w.address: w.nthreads for w in [a, b]} +@pytest.mark.xfail( + reason="Test should always fail to ensure the body of the test function was run", + strict=True, +) @gen_test() async def test_gen_test(): await asyncio.sleep(0.01) + assert False +@pytest.mark.xfail( + reason="Test should always fail to ensure the body of the test function was run", + strict=True, +) @gen_test() def test_gen_test_legacy_implicit(): yield asyncio.sleep(0.01) + assert False +@pytest.mark.xfail( + reason="Test should always fail to ensure the body of the test function was run", + strict=True, +) @gen_test() @gen.coroutine def test_gen_test_legacy_explicit(): yield asyncio.sleep(0.01) + assert False @contextmanager diff --git a/distributed/utils.py b/distributed/utils.py index e57bfeb98c0..b4bcfaf6f8e 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -22,19 +22,16 @@ from collections import OrderedDict, UserDict, deque from concurrent.futures import CancelledError, ThreadPoolExecutor # noqa: F401 from contextlib import contextmanager, suppress +from contextvars import ContextVar from hashlib import md5 from importlib.util import cache_from_source from time import sleep -from typing import TYPE_CHECKING from typing import Any as AnyType -from typing import ClassVar, Container, Sequence, overload +from typing import ClassVar, Container import click import tblib.pickling_support -if TYPE_CHECKING: - from typing_extensions import Protocol - try: import resource except ImportError: @@ -1439,95 +1436,57 @@ def __getattr__(name): raise AttributeError(f"module {__name__} has no attribute {name}") -if TYPE_CHECKING: - - class SupportsToDict(Protocol): - def _to_dict( - self, *, exclude: Container[str] | None = None, **kwargs - ) -> dict[str, AnyType]: - ... - - -@overload -def recursive_to_dict( - obj: SupportsToDict, exclude: Container[str] = None, seen: set[AnyType] = None -) -> dict[str, AnyType]: - ... - - -@overload -def recursive_to_dict( - obj: Sequence, exclude: Container[str] = None, seen: set[AnyType] = None -) -> Sequence: - ... 
- - -@overload -def recursive_to_dict( - obj: dict, exclude: Container[str] = None, seen: set[AnyType] = None -) -> dict: - ... - - -@overload -def recursive_to_dict( - obj: None, exclude: Container[str] = None, seen: set[AnyType] = None -) -> None: - ... +# Used internally by recursive_to_dict to let the YAML exporter catch infinite +# recursion. If an object has already been encountered, a string representan will be +# returned instead. This is necessary since we have multiple cyclic referencing data +# structures. +_recursive_to_dict_seen: ContextVar[set[int]] = ContextVar("_recursive_to_dict_seen") -def recursive_to_dict(obj, exclude=None, seen=None): - """ - This is for debugging purposes only and calls ``_to_dict`` methods on ``obj`` or - it's elements recursively, if available. The output of this function is - intended to be json serializable. +def recursive_to_dict(obj: AnyType, *, exclude: Container[str] = ()) -> AnyType: + """Recursively convert arbitrary Python objects to a JSON-serializable + representation. This is intended for debugging purposes only and calls ``_to_dict`` + methods on encountered objects, if available. Parameters ---------- exclude: A list of attribute names to be excluded from the dump. This will be forwarded to the objects ``_to_dict`` methods and these methods - are required to ensure this. - seen: - Used internally to avoid infinite recursion. If an object has already - been encountered, it's representation will be generated instead of its - ``_to_dict``. This is necessary since we have multiple cyclic referencing - data structures. + are required to accept this parameter. """ - if obj is None: - return None - if isinstance(obj, str): + if isinstance(obj, (int, float, bool, str)) or obj is None: return obj - if seen is None: - seen = set() + if isinstance(obj, (type, bytes)): + return repr(obj) + + # Prevent infinite recursion + try: + seen = _recursive_to_dict_seen.get() + except LookupError: + tok = _recursive_to_dict_seen.set(set()) + try: + return recursive_to_dict(obj, exclude=exclude) + finally: + _recursive_to_dict_seen.reset(tok) + if id(obj) in seen: return repr(obj) seen.add(id(obj)) - if isinstance(obj, type): - return repr(obj) + if hasattr(obj, "_to_dict"): return obj._to_dict(exclude=exclude) - if isinstance(obj, (deque, set)): - obj = tuple(obj) - if isinstance(obj, (list, tuple)): - return tuple( - recursive_to_dict( - el, - exclude=exclude, - seen=seen, - ) - for el in obj - ) - elif isinstance(obj, dict): + if isinstance(obj, (list, tuple, set, frozenset, deque)): + return [recursive_to_dict(el, exclude=exclude) for el in obj] + if isinstance(obj, dict): res = {} for k, v in obj.items(): - k = recursive_to_dict(k, exclude=exclude, seen=seen) + k = recursive_to_dict(k, exclude=exclude) + v = recursive_to_dict(v, exclude=exclude) try: - hash(k) + res[k] = v except TypeError: - k = str(k) - v = recursive_to_dict(v, exclude=exclude, seen=seen) - res[k] = v + res[str(k)] = v return res - else: - return repr(obj) + + return repr(obj) diff --git a/distributed/worker.py b/distributed/worker.py index 4e43727e1dc..e1c1446570c 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -225,7 +225,7 @@ def get_nbytes(self) -> int: nbytes = self.nbytes return nbytes if nbytes is not None else DEFAULT_DATA_SIZE - def _to_dict(self, *, exclude: Container[str] = None) -> dict[str, Any]: + def _to_dict(self, *, exclude: Container[str] = ()) -> dict: """ A very verbose dictionary representation for debugging purposes. 
Not type stable and not inteded for roundtrips. @@ -240,16 +240,8 @@ def _to_dict(self, *, exclude: Container[str] = None) -> dict[str, Any]: -------- Client.dump_cluster_state """ - - if exclude is None: - exclude = set() - return recursive_to_dict( - { - attr: getattr(self, attr) - for attr in self.__dict__.keys() - if attr not in exclude - }, + {k: v for k, v in self.__dict__.items() if k not in exclude}, exclude=exclude, ) @@ -1123,8 +1115,8 @@ def identity(self, comm=None): } def _to_dict( - self, comm: Comm = None, *, exclude: Container[str] = None - ) -> dict[str, Any]: + self, comm: Comm | None = None, *, exclude: Container[str] = () + ) -> dict: """ A very verbose dictionary representation for debugging purposes. Not type stable and not inteded for roundtrips. @@ -1156,9 +1148,9 @@ def _to_dict( "memory_spill_fraction": self.memory_spill_fraction, "memory_pause_fraction": self.memory_pause_fraction, "logs": self.get_logs(), - "config": dict(dask.config.config), - "incoming_transfer_log": list(self.incoming_transfer_log), - "outgoing_transfer_log": list(self.outgoing_transfer_log), + "config": dask.config.config, + "incoming_transfer_log": self.incoming_transfer_log, + "outgoing_transfer_log": self.outgoing_transfer_log, } info.update(extra) return recursive_to_dict(info, exclude=exclude) diff --git a/docs/source/api.rst b/docs/source/api.rst index b5d66b759c4..309c1c9a1b3 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -40,7 +40,7 @@ API .. autoautosummary:: distributed.Future :methods: -**Client Coordination** +**Synchronization** .. currentmodule:: distributed @@ -48,6 +48,7 @@ API Event Lock MultiLock + Semaphore Queue Variable @@ -122,6 +123,24 @@ Future .. autoclass:: Future :members: + +Synchronization +--------------- + +.. autoclass:: Event + :members: +.. autoclass:: Lock + :members: +.. autoclass:: MultiLock + :members: +.. autoclass:: Semaphore + :members: +.. autoclass:: Queue + :members: +.. autoclass:: Variable + :members: + + Cluster ------- @@ -168,19 +187,6 @@ Other .. autoclass:: get_task_metadata .. autoclass:: performance_report -.. autoclass:: Event - :members: -.. autoclass:: Lock - :members: -.. autoclass:: MultiLock - :members: -.. autoclass:: Semaphore - :members: -.. autoclass:: Queue - :members: -.. autoclass:: Variable - :members: - Utilities ---------
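
For context (not part of the patch itself): a minimal sketch of how the reworked `Client.dump_cluster_state` signature above might be exercised. The keyword names, defaults, and suffix handling come from the diff; the `LocalCluster` setup and the `"runspec"` exclude value are illustrative assumptions borrowed from the tests.

```python
from distributed import Client, LocalCluster


def inc(x):
    return x + 1


if __name__ == "__main__":
    # Small throwaway cluster purely for illustration.
    with LocalCluster(n_workers=2, threads_per_worker=1, dashboard_address=":0") as cluster:
        with Client(cluster) as client:
            client.gather(client.map(inc, range(10)))

            # `exclude` is now a Collection (default: empty tuple) and `format`
            # accepts "msgpack" (default, gzipped) or "yaml". The matching
            # suffix (".msgpack.gz" / ".yaml") is appended automatically, so
            # this writes "my-dump.yaml".
            client.dump_cluster_state(
                filename="my-dump",
                exclude=("runspec",),
                format="yaml",
            )
```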
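
Similarly, a sketch of the new `idempotent` flag on the scheduler's `register_scheduler_plugin` handler, mirroring the tests above. The plugin class is made up, and serializing it with `distributed.protocol.pickle.dumps` before calling the handler directly is an assumption based on how the tests do it.

```python
import asyncio

from distributed import Scheduler
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.protocol.pickle import dumps


class MyPlugin(SchedulerPlugin):
    name = "myplugin"

    def __init__(self, instance):
        self.instance = instance


async def main():
    async with Scheduler(port=0, dashboard_address=":0") as s:
        await s.register_scheduler_plugin(plugin=dumps(MyPlugin("first")), idempotent=True)

        # With idempotent=True, registering again under the same name is a
        # no-op, so the first instance stays in place.
        await s.register_scheduler_plugin(plugin=dumps(MyPlugin("second")), idempotent=True)
        assert s.plugins["myplugin"].instance == "first"


asyncio.run(main())
```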
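
Finally, a small illustration (again an assumption, not taken from the patch) of the simplified `recursive_to_dict` in `distributed/utils.py`, which now guards against cyclic references with a `ContextVar` instead of a `seen` argument and falls back to `repr` for bytes, classes, and anything else it does not recognize.

```python
from collections import deque

from distributed.utils import recursive_to_dict


class State:
    def __init__(self):
        self.nbytes = (1, 2)
        self.log = deque(["a", "b"])

    def _to_dict(self, *, exclude):
        # Objects opt in by defining _to_dict; `exclude` must be honoured here.
        return recursive_to_dict(
            {k: v for k, v in self.__dict__.items() if k not in exclude},
            exclude=exclude,
        )


print(recursive_to_dict([State(), {3, 4}, b"raw"], exclude=["log"]))
# [{'nbytes': [1, 2]}, [3, 4], "b'raw'"]
```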