Skip to content

Commit

Permalink
Type annotations for Worker and gen_cluster (dask#5438)
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky authored and zanieb committed Oct 28, 2021
1 parent e8b47eb commit 584db2c
Show file tree
Hide file tree
Showing 6 changed files with 254 additions and 129 deletions.
5 changes: 4 additions & 1 deletion distributed/profile.py
Expand Up @@ -24,12 +24,15 @@
'children': {...}}}
}
"""
from __future__ import annotations

import bisect
import linecache
import sys
import threading
from collections import defaultdict, deque
from time import sleep
from typing import Any

import tlz as toolz

Expand Down Expand Up @@ -152,7 +155,7 @@ def merge(*args):
}


def create():
def create() -> dict[str, Any]:
return {
"count": 0,
"children": {},
Expand Down
2 changes: 1 addition & 1 deletion distributed/tests/test_scheduler.py
Expand Up @@ -1947,7 +1947,7 @@ class NoSchedulerDelayWorker(Worker):
comparisons using times reported from workers.
"""

@property
@property # type: ignore
def scheduler_delay(self):
return 0

Expand Down
5 changes: 0 additions & 5 deletions distributed/tests/test_stress.py
Expand Up @@ -99,9 +99,6 @@ async def test_stress_creation_and_deletion(c, s):
# Assertions are handled by the validate mechanism in the scheduler
da = pytest.importorskip("dask.array")

def _disable_suspicious_counter(dask_worker):
dask_worker._suspicious_count_limit = None

rng = da.random.RandomState(0)
x = rng.random(size=(2000, 2000), chunks=(100, 100))
y = ((x + 1).T + (x * 2) - x.mean(axis=1)).sum().round(2)
Expand All @@ -111,14 +108,12 @@ async def create_and_destroy_worker(delay):
start = time()
while time() < start + 5:
async with Nanny(s.address, nthreads=2) as n:
await c.run(_disable_suspicious_counter, workers=[n.worker_address])
await asyncio.sleep(delay)
print("Killed nanny")

await asyncio.gather(*(create_and_destroy_worker(0.1 * i) for i in range(20)))

async with Nanny(s.address, nthreads=2):
await c.run(_disable_suspicious_counter)
assert await c.compute(z) == 8000884.93


Expand Down
2 changes: 1 addition & 1 deletion distributed/tests/test_worker.py
Expand Up @@ -2968,7 +2968,7 @@ async def test_who_has_consistent_remove_replica(c, s, *workers):

await f2

assert ("missing-dep", f1.key) in a.story(f1.key)
assert (f1.key, "missing-dep") in a.story(f1.key)
assert a.tasks[f1.key].suspicious_count == 0
assert s.tasks[f1.key].suspicious == 0

Expand Down
58 changes: 34 additions & 24 deletions distributed/utils_test.py
Expand Up @@ -22,6 +22,7 @@
import warnings
import weakref
from collections import defaultdict
from collections.abc import Callable
from contextlib import contextmanager, nullcontext, suppress
from glob import glob
from itertools import count
Expand Down Expand Up @@ -54,6 +55,7 @@
from .diagnostics.plugin import WorkerPlugin
from .metrics import time
from .nanny import Nanny
from .node import ServerNode
from .proctitle import enable_proctitle_on_children
from .security import Security
from .utils import (
Expand Down Expand Up @@ -770,7 +772,7 @@ async def disconnect_all(addresses, timeout=3, rpc_kwargs=None):
await asyncio.gather(*(disconnect(addr, timeout, rpc_kwargs) for addr in addresses))


def gen_test(timeout=_TEST_TIMEOUT):
def gen_test(timeout: float = _TEST_TIMEOUT) -> Callable[[Callable], Callable]:
"""Coroutine test
@gen_test(timeout=5)
Expand All @@ -797,14 +799,14 @@ def test_func():


async def start_cluster(
nthreads,
scheduler_addr,
loop,
security=None,
Worker=Worker,
scheduler_kwargs={},
worker_kwargs={},
):
nthreads: list[tuple[str, int] | tuple[str, int, dict]],
scheduler_addr: str,
loop: IOLoop,
security: Security | dict[str, Any] | None = None,
Worker: type[ServerNode] = Worker,
scheduler_kwargs: dict[str, Any] = {},
worker_kwargs: dict[str, Any] = {},
) -> tuple[Scheduler, list[ServerNode]]:
s = await Scheduler(
loop=loop,
validate=True,
Expand All @@ -813,6 +815,7 @@ async def start_cluster(
host=scheduler_addr,
**scheduler_kwargs,
)

workers = [
Worker(
s.address,
Expand All @@ -822,7 +825,11 @@ async def start_cluster(
loop=loop,
validate=True,
host=ncore[0],
**(merge(worker_kwargs, ncore[2]) if len(ncore) > 2 else worker_kwargs),
**(
merge(worker_kwargs, ncore[2]) # type: ignore
if len(ncore) > 2
else worker_kwargs
),
)
for i, ncore in enumerate(nthreads)
]
Expand Down Expand Up @@ -854,21 +861,24 @@ async def end_worker(w):


def gen_cluster(
nthreads=[("127.0.0.1", 1), ("127.0.0.1", 2)],
ncores=None,
nthreads: list[tuple[str, int] | tuple[str, int, dict]] = [
("127.0.0.1", 1),
("127.0.0.1", 2),
],
ncores: None = None, # deprecated
scheduler="127.0.0.1",
timeout=_TEST_TIMEOUT,
security=None,
Worker=Worker,
client=False,
scheduler_kwargs={},
worker_kwargs={},
client_kwargs={},
active_rpc_timeout=1,
config={},
clean_kwargs={},
allow_unclosed=False,
):
timeout: float = _TEST_TIMEOUT,
security: Security | dict[str, Any] | None = None,
Worker: type[ServerNode] = Worker,
client: bool = False,
scheduler_kwargs: dict[str, Any] = {},
worker_kwargs: dict[str, Any] = {},
client_kwargs: dict[str, Any] = {},
active_rpc_timeout: float = 1,
config: dict[str, Any] = {},
clean_kwargs: dict[str, Any] = {},
allow_unclosed: bool = False,
) -> Callable[[Callable], Callable]:
from distributed import Client

""" Coroutine test with small cluster
Expand Down

0 comments on commit 584db2c

Please sign in to comment.