Merge branch 'main' into AMM/RetireWorker
crusaderky committed Nov 21, 2021
2 parents 02fb4a3 + 7d1401a commit b955440
Showing 10 changed files with 486 additions and 118 deletions.
17 changes: 17 additions & 0 deletions distributed/comm/tests/test_ucx.py
@@ -2,6 +2,8 @@

import pytest

import dask

pytestmark = pytest.mark.gpu

ucp = pytest.importorskip("ucp")
@@ -10,6 +12,7 @@
from distributed.comm import connect, listen, parse_address, ucx
from distributed.comm.registry import backends, get_backend
from distributed.deploy.local import LocalCluster
from distributed.diagnostics.nvml import has_cuda_context
from distributed.protocol import to_serialize
from distributed.utils_test import inc

@@ -326,6 +329,20 @@ async def test_simple():
assert await client.submit(lambda x: x + 1, 10) == 11


@pytest.mark.asyncio
async def test_cuda_context():
with dask.config.set({"distributed.comm.ucx.create_cuda_context": True}):
async with LocalCluster(
protocol="ucx", n_workers=1, asynchronous=True
) as cluster:
async with Client(cluster, asynchronous=True) as client:
assert cluster.scheduler_address.startswith("ucx://")
assert has_cuda_context() == 0
worker_cuda_context = await client.run(has_cuda_context)
assert len(worker_cuda_context) == 1
assert list(worker_cuda_context.values())[0] == 0


@pytest.mark.asyncio
async def test_transpose():
da = pytest.importorskip("dask.array")
13 changes: 13 additions & 0 deletions distributed/comm/tests/test_ucx_config.py
@@ -79,6 +79,19 @@ async def test_ucx_config(cleanup):
assert ucx_config.get("TLS") == "rc,tcp,rdmacm,cuda_copy"
assert ucx_config.get("SOCKADDR_TLS_PRIORITY") == "rdmacm"

ucx = {
"nvlink": None,
"infiniband": None,
"rdmacm": None,
"net-devices": None,
"tcp": None,
"cuda_copy": None,
}

with dask.config.set({"distributed.comm.ucx": ucx}):
ucx_config = _scrub_ucx_config()
assert ucx_config == {}


@pytest.mark.flaky(
reruns=10, reruns_delay=5, condition=ucp.get_ucx_version() < (1, 11, 0)
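For illustration only (not part of the diff): the new assertion relies on options left at None producing no UCX configuration at all. A simplified, hypothetical sketch of that scrubbing behavior (not the real _scrub_ucx_config):

def scrub_ucx_config_sketch(options: dict) -> dict:
    # Hypothetical simplification: options left as None contribute nothing,
    # so UCX falls back to its own transport auto-selection.
    return {key: value for key, value in options.items() if value is not None}

assert scrub_ucx_config_sketch({"nvlink": None, "tcp": None, "rdmacm": None}) == {}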
4 changes: 3 additions & 1 deletion distributed/comm/ucx.py
@@ -70,7 +70,9 @@ def init_once():
# We ensure the CUDA context is created before initializing UCX. This can't
# be safely handled externally because communications in Dask start before
# preload scripts run.
if "TLS" in ucx_config and "cuda_copy" in ucx_config["TLS"]:
if dask.config.get("distributed.comm.ucx.create_cuda_context") is True or (
"TLS" in ucx_config and "cuda_copy" in ucx_config["TLS"]
):
try:
import numba.cuda
except ImportError:
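For illustration only (not part of the diff): a minimal sketch of what eager CUDA context creation amounts to, assuming numba is installed. This is an approximation for clarity, not the verbatim init_once() logic, and device selection in distributed/comm/ucx.py may differ.

import numba.cuda

def ensure_cuda_context_sketch():
    # current_context() lazily creates (or returns) a CUDA context on the
    # currently selected device; device 0 maps to the first entry of
    # CUDA_VISIBLE_DEVICES, so UCX later sees the GPU <-> NIC topology.
    numba.cuda.current_context()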
9 changes: 9 additions & 0 deletions distributed/distributed-schema.yaml
@@ -834,6 +834,15 @@ properties:
introduced to resolve an issue with CUDA IPC that has been fixed in UCX 1.10, but
can cause establishing endpoints to be very slow, this is particularly noticeable in
clusters of more than a few dozen workers.
create-cuda-context:
type: [boolean, 'null']
description: |
Creates a CUDA context before UCX is initialized. This is necessary to enable UCX to
properly identify connectivity of GPUs with specialized networking hardware, such as
InfiniBand. This permits UCX to choose transports automatically, without specifying
additional variables for each transport, while ensuring optimal connectivity. When
``True``, a CUDA context will be created on the first device listed in
``CUDA_VISIBLE_DEVICES``.
websockets:
type: object
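As a usage sketch (not part of the diff, and assuming a UCX-capable environment with ucp and a GPU available), the new option can be enabled programmatically before starting a cluster, mirroring test_cuda_context above:

import dask
from distributed import LocalCluster

with dask.config.set({"distributed.comm.ucx.create_cuda_context": True}):
    cluster = LocalCluster(protocol="ucx", n_workers=1)
    # Scheduler and workers now create a CUDA context on the first device in
    # CUDA_VISIBLE_DEVICES before UCX is initialized.
    cluster.close()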
11 changes: 6 additions & 5 deletions distributed/distributed.yaml
@@ -187,13 +187,14 @@ distributed:
socket-backlog: 2048
recent-messages-log-length: 0 # number of messages to keep for debugging
ucx:
cuda_copy: False # enable cuda-copy
tcp: False # enable tcp
nvlink: False # enable cuda_ipc
infiniband: False # enable Infiniband
rdmacm: False # enable RDMACM
cuda_copy: null # enable cuda-copy
tcp: null # enable tcp
nvlink: null # enable cuda_ipc
infiniband: null # enable Infiniband
rdmacm: null # enable RDMACM
net-devices: null # define what interface to use for UCX comm
reuse-endpoints: null # enable endpoint reuse
create-cuda-context: null # create CUDA context before UCX initialization

zstd:
level: 3 # Compression level, between 1 and 22.
132 changes: 132 additions & 0 deletions distributed/tests/test_cancelled_state.py
@@ -1,7 +1,10 @@
import asyncio
from unittest import mock

import pytest

import distributed
from distributed import Worker
from distributed.core import CommClosedError
from distributed.utils_test import _LockedCommPool, gen_cluster, inc, slowinc

@@ -131,3 +134,132 @@ async def wait_and_raise(*args, **kwargs):
assert any("missing-dep" in msg for msg in b_story)
assert any("cancelled" in msg for msg in b_story)
assert any("resumed" in msg for msg in b_story)


@gen_cluster(client=True, nthreads=[("", 1)])
async def test_executing_cancelled_error(c, s, w):
"""One worker with one thread. We provoke an executing->cancelled transition
and let the task err. This test ensures that there is no residual state
    (e.g. a semaphore) left blocking the thread."""
lock = distributed.Lock()
await lock.acquire()

async def wait_and_raise(*args, **kwargs):
async with lock:
raise RuntimeError()

fut = c.submit(wait_and_raise)
await wait_for_state(fut.key, "executing", w)

fut.release()
await wait_for_state(fut.key, "cancelled", w)
await lock.release()

# At this point we do not fetch the result of the future since the future
# itself would raise a cancelled exception. At this point we're concerned
# about the worker. The task should transition over error to be eventually
# forgotten since we no longer hold a ref.
while fut.key in w.tasks:
await asyncio.sleep(0.01)

# Everything should still be executing as usual after this
    assert await c.submit(sum, c.map(inc, range(10))) == sum(map(inc, range(10)))

# Everything above this line should be generically true, regardless of
# refactoring. Below verifies some implementation specific test assumptions

story = w.story(fut.key)
start_finish = [(msg[1], msg[2], msg[3]) for msg in story if len(msg) == 7]
assert ("executing", "released", "cancelled") in start_finish
assert ("cancelled", "error", "error") in start_finish
assert ("error", "released", "released") in start_finish


@gen_cluster(client=True)
async def test_flight_cancelled_error(c, s, a, b):
"""One worker with one thread. We provoke an flight->cancelled transition
and let the task err."""
lock = asyncio.Lock()
await lock.acquire()

async def wait_and_raise(*args, **kwargs):
async with lock:
raise RuntimeError()

with mock.patch.object(
distributed.worker,
"get_data_from_worker",
side_effect=wait_and_raise,
):
fut1 = c.submit(inc, 1, workers=[a.address], allow_other_workers=True)
fut2 = c.submit(inc, fut1, workers=[b.address])

await wait_for_state(fut1.key, "flight", b)
fut2.release()
fut1.release()
await wait_for_state(fut1.key, "cancelled", b)

lock.release()
# At this point we do not fetch the result of the future since the future
# itself would raise a cancelled exception. At this point we're concerned
# about the worker. The task should transition over error to be eventually
# forgotten since we no longer hold a ref.
while fut1.key in b.tasks:
await asyncio.sleep(0.01)

# Everything should still be executing as usual after this
    assert await c.submit(sum, c.map(inc, range(10))) == sum(map(inc, range(10)))


class LargeButForbiddenSerialization:
def __reduce__(self):
raise RuntimeError("I will never serialize!")

def __sizeof__(self) -> int:
"""Ensure this is immediately tried to spill"""
return 1_000_000_000_000


def test_ensure_spilled_immediately(tmpdir):
"""See also test_value_raises_during_spilling"""
import sys

from distributed.spill import SpillBuffer

mem_target = 1000
buf = SpillBuffer(tmpdir, target=mem_target)
buf["key"] = 1

obj = LargeButForbiddenSerialization()
assert sys.getsizeof(obj) > mem_target
with pytest.raises(
TypeError,
match=f"Could not serialize object of type {LargeButForbiddenSerialization.__name__}",
):
buf["error"] = obj


@gen_cluster(client=True, nthreads=[])
async def test_value_raises_during_spilling(c, s):
"""See also test_ensure_spilled_immediately"""

# Use a worker with a default memory limit
async with Worker(
s.address,
) as w:

def produce_evil_data():
return LargeButForbiddenSerialization()

fut = c.submit(produce_evil_data)

await wait_for_state(fut.key, "error", w)

with pytest.raises(
TypeError,
match=f"Could not serialize object of type {LargeButForbiddenSerialization.__name__}",
):
await fut

# Everything should still be executing as usual after this
        assert await c.submit(sum, c.map(inc, range(10))) == sum(map(inc, range(10)))
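For illustration only (not part of the diff): the two spill tests above rely on sys.getsizeof honoring a custom __sizeof__, so an inflated reported size forces the buffer to attempt spilling on insertion. A small self-contained check of that assumption:

import sys

class Huge:
    def __sizeof__(self) -> int:
        return 1_000_000_000_000

# getsizeof consults __sizeof__ (plus GC overhead), so the reported size
# exceeds any realistic SpillBuffer target.
assert sys.getsizeof(Huge()) >= 1_000_000_000_000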
90 changes: 83 additions & 7 deletions distributed/tests/test_worker.py
@@ -1816,11 +1816,11 @@ async def test_story_with_deps(c, s, a, b):
# This is a simple transition log
expected_story = [
(key, "compute-task"),
(key, "released", "waiting", {dep.key: "fetch"}),
(key, "waiting", "ready", {}),
(key, "ready", "executing", {}),
(key, "released", "waiting", "waiting", {dep.key: "fetch"}),
(key, "waiting", "ready", "ready", {}),
(key, "ready", "executing", "executing", {}),
(key, "put-in-memory"),
(key, "executing", "memory", {}),
(key, "executing", "memory", "memory", {}),
]
assert pruned_story == expected_story

@@ -1842,13 +1842,13 @@ async def test_story_with_deps(c, s, a, b):
assert isinstance(stimulus_id, str)
expected_story = [
(dep_story, "ensure-task-exists", "released"),
(dep_story, "released", "fetch", {}),
(dep_story, "released", "fetch", "fetch", {}),
(
"gather-dependencies",
a.address,
{dep.key},
),
(dep_story, "fetch", "flight", {}),
(dep_story, "fetch", "flight", "flight", {}),
(
"request-dep",
a.address,
Expand All @@ -1860,7 +1860,7 @@ async def test_story_with_deps(c, s, a, b):
{dep.key},
),
(dep_story, "put-in-memory"),
(dep_story, "flight", "memory", {res.key: "ready"}),
(dep_story, "flight", "memory", "memory", {res.key: "ready"}),
]
assert pruned_story == expected_story

@@ -3095,6 +3095,82 @@ async def _wait_for_state(key: str, worker: Worker, state: str):
await asyncio.sleep(0)


@gen_cluster(client=True)
async def test_gather_dep_cancelled_rescheduled(c, s, a, b):
"""At time of writing, the gather_dep implementation filtered tasks again
for in-flight state. The response parser, however, did not distinguish
resulting in unwanted missing-data signals to the scheduler, causing
potential rescheduling or data leaks.
If a cancelled key is rescheduled for fetching while gather_dep waits
internally for get_data, the response parser would misclassify this key and
cause the key to be recommended for a release causing deadlocks and/or lost
keys.
At time of writing, this transition was implemented wrongly and caused a
flight->cancelled transition which should be recoverable but the cancelled
state was corrupted by this transition since ts.done==True. This attribute
setting would cause a cancelled->fetch transition to actually drop the key
instead, causing https://github.com/dask/distributed/issues/5366
See also test_gather_dep_do_not_handle_response_of_not_requested_tasks
"""
import distributed

with mock.patch.object(distributed.worker.Worker, "gather_dep") as mocked_gather:
fut1 = c.submit(inc, 1, workers=[a.address], key="f1")
fut2 = c.submit(inc, fut1, workers=[a.address], key="f2")
await fut2
fut4 = c.submit(sum, fut1, fut2, workers=[b.address], key="f4")
fut3 = c.submit(inc, fut1, workers=[b.address], key="f3")

fut2_key = fut2.key

await _wait_for_state(fut2_key, b, "flight")
while not mocked_gather.call_args:
await asyncio.sleep(0)

fut4.release()
while fut4.key in b.tasks:
await asyncio.sleep(0)

assert b.tasks[fut2.key].state == "cancelled"
args, kwargs = mocked_gather.call_args
assert fut2.key in kwargs["to_gather"]

# The below synchronization and mock structure allows us to intercept the
# state after gather_dep has been scheduled and is waiting for the
# get_data_from_worker to finish. If state transitions happen during this
# time, the response parser needs to handle this properly
lock = asyncio.Lock()
event = asyncio.Event()
async with lock:

async def wait_get_data(*args, **kwargs):
event.set()
async with lock:
return await distributed.worker.get_data_from_worker(*args, **kwargs)

with mock.patch.object(
distributed.worker,
"get_data_from_worker",
side_effect=wait_get_data,
):
gather_dep_fut = asyncio.ensure_future(
Worker.gather_dep(b, *args, **kwargs)
)

await event.wait()

fut4 = c.submit(sum, [fut1, fut2], workers=[b.address], key="f4")
while b.tasks[fut2.key].state != "flight":
await asyncio.sleep(0.1)
await gather_dep_fut
f2_story = b.story(fut2.key)
assert f2_story
await fut3
await fut4
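For illustration only (not part of the diff): the lock/event arrangement above is a general pattern for parking a patched coroutine mid-call while the test mutates state. A standalone, hypothetical sketch of the pattern, independent of distributed:

import asyncio

async def demo():
    lock = asyncio.Lock()
    started = asyncio.Event()

    async def intercepted():
        started.set()
        async with lock:  # parked here until the caller releases the lock
            return "data"

    async with lock:
        task = asyncio.ensure_future(intercepted())
        await started.wait()  # the call is in flight and blocked
        # ... a test would mutate worker state here ...
    assert await task == "data"  # lock released; the call completes

asyncio.run(demo())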


@gen_cluster(client=True)
async def test_gather_dep_do_not_handle_response_of_not_requested_tasks(c, s, a, b):
"""At time of writing, the gather_dep implementation filtered tasks again
