Merge branch 'main' into AMM/RetireWorker
crusaderky committed Dec 17, 2021
2 parents: 8cd64f2 + ef5b088 · commit 035bf98
Showing 19 changed files with 612 additions and 282 deletions.
2 changes: 1 addition & 1 deletion continuous_integration/gpuci/axis.yaml
@@ -2,7 +2,7 @@ PYTHON_VER:
 - "3.8"
 
 CUDA_VER:
-- "11.2"
+- "11.5"
 
 LINUX_VER:
 - ubuntu18.04
4 changes: 2 additions & 2 deletions continuous_integration/gpuci/build.sh
@@ -41,7 +41,7 @@ gpuci_logger "Install dask"
 python -m pip install git+https://github.com/dask/dask
 
 gpuci_logger "Install distributed"
-python setup.py install
+python -m pip install -e .
 
 gpuci_logger "Check Python versions"
 python --version
@@ -52,4 +52,4 @@ conda config --show-sources
 conda list --show-channel-urls
 
 gpuci_logger "Python py.test for distributed"
-py.test $WORKSPACE/distributed -v -m gpu --runslow --junitxml="$WORKSPACE/junit-distributed.xml"
+py.test distributed -v -m gpu --runslow --junitxml="$WORKSPACE/junit-distributed.xml"
105 changes: 79 additions & 26 deletions distributed/client.py
@@ -16,7 +16,7 @@
 import warnings
 import weakref
 from collections import defaultdict
-from collections.abc import Awaitable, Collection, Iterator
+from collections.abc import Collection, Iterator
 from concurrent.futures import ThreadPoolExecutor
 from concurrent.futures._base import DoneAndNotDoneFutures
 from contextlib import contextmanager, suppress
@@ -2387,7 +2387,14 @@ def run_on_scheduler(self, function, *args, **kwargs):
         return self.sync(self._run_on_scheduler, function, *args, **kwargs)
 
     async def _run(
-        self, function, *args, nanny=False, workers=None, wait=True, **kwargs
+        self,
+        function,
+        *args,
+        nanny: bool = False,
+        workers: list[str] | None = None,
+        wait: bool = True,
+        on_error: Literal["raise", "return", "ignore"] = "raise",
+        **kwargs,
     ):
         responses = await self.scheduler.broadcast(
             msg=dict(
@@ -2399,18 +2406,46 @@ async def _run(
             ),
             workers=workers,
             nanny=nanny,
+            on_error="return_pickle",
         )
         results = {}
         for key, resp in responses.items():
-            if resp["status"] == "OK":
-                results[key] = resp["result"]
+            if isinstance(resp, bytes):
+                # Pickled RPC exception
+                exc = loads(resp)
+                assert isinstance(exc, Exception)
             elif resp["status"] == "error":
-                typ, exc, tb = clean_exception(**resp)
-                raise exc.with_traceback(tb)
+                # Exception raised by the remote function
+                _, exc, tb = clean_exception(**resp)
+                exc = exc.with_traceback(tb)
+            else:
+                assert resp["status"] == "OK"
+                results[key] = resp["result"]
+                continue
+
+            if on_error == "raise":
+                raise exc
+            elif on_error == "return":
+                results[key] = exc
+            elif on_error != "ignore":
+                raise ValueError(
+                    "on_error must be 'raise', 'return', or 'ignore'; "
+                    f"got {on_error!r}"
+                )
 
         if wait:
             return results
 
-    def run(self, function, *args, **kwargs):
+    def run(
+        self,
+        function,
+        *args,
+        workers: list[str] | None = None,
+        wait: bool = True,
+        nanny: bool = False,
+        on_error: Literal["raise", "return", "ignore"] = "raise",
+        **kwargs,
+    ):
         """
         Run a function on all workers outside of task scheduling system
@@ -2437,6 +2472,17 @@ def run(self, function, *args, **kwargs):
             Whether to run ``function`` on the nanny. By default, the function
             is run on the worker process. If specified, the addresses in
             ``workers`` should still be the worker addresses, not the nanny addresses.
+        on_error: "raise" | "return" | "ignore"
+            If the function raises an error on a worker:
+
+            raise
+                (default) Re-raise the exception on the client.
+                The output from other workers will be lost.
+            return
+                Return the Exception object instead of the function output for
+                the worker
+            ignore
+                Ignore the exception and remove the worker from the result dict
 
         Examples
         --------
@@ -2469,7 +2515,16 @@ def run(self, function, *args, **kwargs):
         >>> c.run(print_state, wait=False)  # doctest: +SKIP
         """
-        return self.sync(self._run, function, *args, **kwargs)
+        return self.sync(
+            self._run,
+            function,
+            *args,
+            workers=workers,
+            wait=wait,
+            nanny=nanny,
+            on_error=on_error,
+            **kwargs,
+        )
 
     @_deprecated(use_instead="Client.run which detects async functions automatically")
     def run_coroutine(self, function, *args, **kwargs):
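
The three on_error modes are easiest to see in use. A minimal sketch, assuming a running Client named ``client`` on a local test cluster; ``maybe_fail`` is a hypothetical function that errors on some of the workers:

    import os

    from distributed import Client

    client = Client()  # assumed local test cluster

    def maybe_fail():
        # Hypothetical failure mode: workers whose PID is odd raise.
        if os.getpid() % 2:
            raise RuntimeError("boom")
        return "ok"

    # "return": failing workers map to the Exception object instead of a result.
    mixed = client.run(maybe_fail, on_error="return")

    # "ignore": failing workers are dropped from the result dict entirely.
    partial = client.run(maybe_fail, on_error="ignore")

    # "raise" (the default) re-raises the first failure on the client,
    # discarding the output of the workers that succeeded.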
@@ -3487,20 +3542,23 @@ async def _dump_cluster_state(
 
         scheduler_info = self.scheduler.dump_state()
 
-        worker_info = self.scheduler.broadcast(
-            msg=dict(
-                op="dump_state",
-                exclude=exclude,
-            ),
+        workers_info = self.scheduler.broadcast(
+            msg={"op": "dump_state", "exclude": exclude},
+            on_error="return_pickle",
         )
-        versions = self._get_versions()
-        scheduler_info, worker_info, versions_info = await asyncio.gather(
-            scheduler_info, worker_info, versions
+        versions_info = self._get_versions()
+        scheduler_info, workers_info, versions_info = await asyncio.gather(
+            scheduler_info, workers_info, versions_info
         )
+        # Unpickle RPC errors and convert them to string
+        workers_info = {
+            k: repr(loads(v)) if isinstance(v, bytes) else v
+            for k, v in workers_info.items()
+        }
 
         state = {
             "scheduler": scheduler_info,
-            "workers": worker_info,
+            "workers": workers_info,
             "versions": versions_info,
         }
 
@@ -3546,7 +3604,7 @@ def dump_cluster_state(
         filename: str = "dask-cluster-dump",
         exclude: Collection[str] = (),
         format: Literal["msgpack", "yaml"] = "msgpack",
-    ) -> Awaitable | None:
+    ):
         """Extract a dump of the entire cluster state and persist to disk.
 
         This is intended for debugging purposes only.
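
For reference, a minimal usage sketch of the public method, assuming a running Client named ``client``; the ``.msgpack.gz`` suffix on the output file is an assumption based on the msgpack format here, not something this diff states:

    import gzip

    import msgpack

    client.dump_cluster_state(filename="dask-cluster-dump", format="msgpack")

    # Inspect the dump offline. The top-level keys mirror the state dict
    # assembled in _dump_cluster_state: "scheduler", "workers", "versions".
    with gzip.open("dask-cluster-dump.msgpack.gz", "rb") as f:
        state = msgpack.unpack(f)
    print(state["versions"]["client"])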
@@ -3864,15 +3922,10 @@ def get_versions(self, check=False, packages=[]):
 
     async def _get_versions(self, check=False, packages=[]):
         client = version_module.get_versions(packages=packages)
-        try:
-            scheduler = await self.scheduler.versions(packages=packages)
-        except KeyError:
-            scheduler = None
-        except TypeError:  # packages keyword not supported
-            scheduler = await self.scheduler.versions()  # this raises
-
+        scheduler = await self.scheduler.versions(packages=packages)
         workers = await self.scheduler.broadcast(
-            msg={"op": "versions", "packages": packages}
+            msg={"op": "versions", "packages": packages},
+            on_error="ignore",
         )
         result = {"scheduler": scheduler, "workers": workers, "client": client}
 
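
Because the broadcast above now uses on_error="ignore", a worker that fails or disconnects mid-call silently drops out of the returned mapping rather than failing the whole call. A quick way to spot such gaps, assuming a running Client named ``client``:

    versions = client.get_versions()
    reported = set(versions["workers"])
    known = set(client.scheduler_info()["workers"])
    missing = known - reported  # workers whose versions call errored out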
15 changes: 8 additions & 7 deletions distributed/comm/tests/test_ucx_config.py
@@ -1,3 +1,4 @@
+import os
 from time import sleep
 
 import pytest
@@ -96,17 +97,16 @@ async def test_ucx_config(cleanup):
 @pytest.mark.flaky(
     reruns=10, reruns_delay=5, condition=ucp.get_ucx_version() < (1, 11, 0)
 )
-def test_ucx_config_w_env_var(cleanup, loop, monkeypatch):
-    size = "1000.00 MB"
-    monkeypatch.setenv("DASK_RMM__POOL_SIZE", size)
-
-    dask.config.refresh()
+def test_ucx_config_w_env_var(cleanup, loop):
+    env = os.environ.copy()
+    env["DASK_RMM__POOL_SIZE"] = "1000.00 MB"
 
     port = "13339"
     sched_addr = f"ucx://{HOST}:{port}"
 
     with popen(
-        ["dask-scheduler", "--no-dashboard", "--protocol", "ucx", "--port", port]
+        ["dask-scheduler", "--no-dashboard", "--protocol", "ucx", "--port", port],
+        env=env,
     ) as sched:
         with popen(
             [
@@ -116,7 +116,8 @@ def test_ucx_config_w_env_var(cleanup, loop, monkeypatch):
                 "--protocol",
                 "ucx",
                 "--no-nanny",
-            ]
+            ],
+            env=env,
         ):
             with Client(sched_addr, loop=loop, timeout=10) as c:
                 while not c.scheduler_info()["workers"]:
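The test now hands configuration to the scheduler and worker subprocesses through their environment rather than monkeypatching the parent process, since Dask reads ``DASK_*`` variables at process startup. The same pattern with only the standard library (a sketch; the test itself uses distributed's ``popen`` test utility):

    import os
    import subprocess

    env = os.environ.copy()
    # The double underscore nests the key: Dask reads this as rmm.pool-size.
    env["DASK_RMM__POOL_SIZE"] = "1000.00 MB"

    # The child process picks the setting up from its own environment.
    proc = subprocess.Popen(["dask-scheduler", "--no-dashboard"], env=env)
    proc.terminate()
    proc.wait()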
