Skip to content

Commit

Permalink
Refactor gather_dep
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky committed May 20, 2022
1 parent 4488144 commit 8a3d7d9
Show file tree
Hide file tree
Showing 4 changed files with 288 additions and 116 deletions.
35 changes: 24 additions & 11 deletions distributed/tests/test_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,14 @@

from distributed import Client, Nanny, Scheduler, Worker, config, default_client
from distributed.compatibility import WINDOWS
from distributed.core import Server, rpc
from distributed.core import Server, Status, rpc
from distributed.metrics import time
from distributed.utils import mp_context
from distributed.utils_test import (
_LockedCommPool,
_UnhashableCallable,
assert_story,
captured_logger,
check_process_leak,
cluster,
dump_cluster_state,
Expand Down Expand Up @@ -731,15 +732,27 @@ def test_raises_with_cause():
raise RuntimeError("exception") from ValueError("cause")


def test_worker_fail_hard(capsys):
@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)])
async def test_fail_hard(c, s, a):
with pytest.raises(Exception):
await a.gather_dep(
worker="abcd", to_gather=["x"], total_nbytes=0, stimulus_id="foo"
)
@gen_cluster(nthreads=[("", 1)])
async def test_fail_hard(s, a):
with captured_logger("distributed.worker") as logger:
# Asynchronously kick off handle_acquire_replicas on the worker,
# which will fail
s.stream_comms[a.address].send(
{
"op": "acquire-replicas",
"who_has": {"x": ["abcd"]},
"stimulus_id": "foo",
},
)
while a.status != Status.closed:
await asyncio.sleep(0.01)

assert "missing port number in address 'abcd'" in logger.getvalue()

with pytest.raises(Exception) as info:
test_fail_hard()

assert "abcd" in str(info.value)
@gen_cluster(nthreads=[("", 1)])
async def test_fail_hard_reraises(s, a):
with pytest.raises(AttributeError):
a.handle_stimulus(None)
while a.status != Status.closed:
await asyncio.sleep(0.01)
22 changes: 14 additions & 8 deletions distributed/tests/test_worker_state_machine.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations

import asyncio
from collections.abc import Iterator
from contextlib import contextmanager
from itertools import chain

import pytest

Expand All @@ -16,7 +18,6 @@
ReleaseWorkerDataMsg,
RescheduleEvent,
RescheduleMsg,
SendMessageToScheduler,
StateMachineEvent,
TaskState,
TaskStateState,
Expand Down Expand Up @@ -103,14 +104,19 @@ def test_unique_task_heap():
assert repr(heap) == "<UniqueTaskHeap: 0 items>"


def traverse_subclasses(cls: type) -> Iterator[type]:
yield cls
for subcls in cls.__subclasses__():
yield from traverse_subclasses(subcls)


@pytest.mark.parametrize(
"cls",
chain(
[UniqueTaskHeap],
Instruction.__subclasses__(),
SendMessageToScheduler.__subclasses__(),
StateMachineEvent.__subclasses__(),
),
[
UniqueTaskHeap,
*traverse_subclasses(Instruction),
*traverse_subclasses(StateMachineEvent),
],
)
def test_slots(cls):
params = [
Expand Down

0 comments on commit 8a3d7d9

Please sign in to comment.