Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Type annotations for Worker and gen_cluster #5438

Merged
merged 6 commits on Oct 19, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
5 changes: 4 additions & 1 deletion distributed/profile.py
Expand Up @@ -24,12 +24,15 @@
'children': {...}}}
}
"""
from __future__ import annotations

import bisect
import linecache
import sys
import threading
from collections import defaultdict, deque
from time import sleep
from typing import Any

import tlz as toolz

Expand Down Expand Up @@ -152,7 +155,7 @@ def merge(*args):
}


def create():
def create() -> dict[str, Any]:
return {
"count": 0,
"children": {},
Expand Down
2 changes: 1 addition & 1 deletion distributed/tests/test_scheduler.py
Expand Up @@ -1947,7 +1947,7 @@ class NoSchedulerDelayWorker(Worker):
comparisons using times reported from workers.
"""

@property
@property # type: ignore
def scheduler_delay(self):
return 0

Expand Down
5 changes: 0 additions & 5 deletions distributed/tests/test_stress.py
Expand Up @@ -99,9 +99,6 @@ async def test_stress_creation_and_deletion(c, s):
# Assertions are handled by the validate mechanism in the scheduler
da = pytest.importorskip("dask.array")

def _disable_suspicious_counter(dask_worker):
dask_worker._suspicious_count_limit = None

rng = da.random.RandomState(0)
x = rng.random(size=(2000, 2000), chunks=(100, 100))
y = ((x + 1).T + (x * 2) - x.mean(axis=1)).sum().round(2)
Expand All @@ -111,14 +108,12 @@ async def create_and_destroy_worker(delay):
start = time()
while time() < start + 5:
async with Nanny(s.address, nthreads=2) as n:
await c.run(_disable_suspicious_counter, workers=[n.worker_address])
await asyncio.sleep(delay)
print("Killed nanny")

await asyncio.gather(*(create_and_destroy_worker(0.1 * i) for i in range(20)))

async with Nanny(s.address, nthreads=2):
await c.run(_disable_suspicious_counter)
assert await c.compute(z) == 8000884.93


Expand Down
2 changes: 1 addition & 1 deletion distributed/tests/test_worker.py
Expand Up @@ -2968,7 +2968,7 @@ async def test_who_has_consistent_remove_replica(c, s, *workers):

await f2

assert ("missing-dep", f1.key) in a.story(f1.key)
assert (f1.key, "missing-dep") in a.story(f1.key)
assert a.tasks[f1.key].suspicious_count == 0
assert s.tasks[f1.key].suspicious == 0

Expand Down
58 changes: 34 additions & 24 deletions distributed/utils_test.py
Expand Up @@ -22,6 +22,7 @@
import warnings
import weakref
from collections import defaultdict
from collections.abc import Callable
from contextlib import contextmanager, nullcontext, suppress
from glob import glob
from itertools import count
Expand Down Expand Up @@ -54,6 +55,7 @@
from .diagnostics.plugin import WorkerPlugin
from .metrics import time
from .nanny import Nanny
from .node import ServerNode
from .proctitle import enable_proctitle_on_children
from .security import Security
from .utils import (
Expand Down Expand Up @@ -770,7 +772,7 @@ async def disconnect_all(addresses, timeout=3, rpc_kwargs=None):
await asyncio.gather(*(disconnect(addr, timeout, rpc_kwargs) for addr in addresses))


def gen_test(timeout=_TEST_TIMEOUT):
def gen_test(timeout: float = _TEST_TIMEOUT) -> Callable[[Callable], Callable]:
"""Coroutine test

@gen_test(timeout=5)
Expand All @@ -797,14 +799,14 @@ def test_func():


async def start_cluster(
nthreads,
scheduler_addr,
loop,
security=None,
Worker=Worker,
scheduler_kwargs={},
worker_kwargs={},
):
nthreads: list[tuple[str, int] | tuple[str, int, dict]],
scheduler_addr: str,
loop: IOLoop,
security: Security | dict[str, Any] | None = None,
Worker: type[ServerNode] = Worker,
scheduler_kwargs: dict[str, Any] = {},
worker_kwargs: dict[str, Any] = {},
) -> tuple[Scheduler, list[ServerNode]]:
s = await Scheduler(
loop=loop,
validate=True,
Expand All @@ -813,6 +815,7 @@ async def start_cluster(
host=scheduler_addr,
**scheduler_kwargs,
)

workers = [
Worker(
s.address,
Expand All @@ -822,7 +825,11 @@ async def start_cluster(
loop=loop,
validate=True,
host=ncore[0],
**(merge(worker_kwargs, ncore[2]) if len(ncore) > 2 else worker_kwargs),
**(
merge(worker_kwargs, ncore[2]) # type: ignore
if len(ncore) > 2
else worker_kwargs
),
)
for i, ncore in enumerate(nthreads)
]
Expand Down Expand Up @@ -854,21 +861,24 @@ async def end_worker(w):


def gen_cluster(
nthreads=[("127.0.0.1", 1), ("127.0.0.1", 2)],
ncores=None,
nthreads: list[tuple[str, int] | tuple[str, int, dict]] = [
("127.0.0.1", 1),
("127.0.0.1", 2),
],
ncores: None = None, # deprecated
scheduler="127.0.0.1",
timeout=_TEST_TIMEOUT,
security=None,
Worker=Worker,
client=False,
scheduler_kwargs={},
worker_kwargs={},
client_kwargs={},
active_rpc_timeout=1,
config={},
clean_kwargs={},
allow_unclosed=False,
):
timeout: float = _TEST_TIMEOUT,
security: Security | dict[str, Any] | None = None,
Worker: type[ServerNode] = Worker,
client: bool = False,
scheduler_kwargs: dict[str, Any] = {},
worker_kwargs: dict[str, Any] = {},
client_kwargs: dict[str, Any] = {},
active_rpc_timeout: float = 1,
config: dict[str, Any] = {},
clean_kwargs: dict[str, Any] = {},
allow_unclosed: bool = False,
) -> Callable[[Callable], Callable]:
from distributed import Client

""" Coroutine test with small cluster
Expand Down