
Commit

Merge branch 'main' into AMM/RetireWorker
crusaderky committed Dec 17, 2021
2 parents 035bf98 + fa326d5 commit 98588d7
Showing 6 changed files with 182 additions and 61 deletions.
51 changes: 30 additions & 21 deletions distributed/scheduler.py
@@ -511,6 +511,7 @@ class WorkerState:
_occupancy: double
_pid: Py_ssize_t
_processing: dict
_long_running: set
_resources: dict
_services: dict
_status: Status
@@ -539,6 +540,7 @@ class WorkerState:
"_occupancy",
"_pid",
"_processing",
"_long_running",
"_resources",
"_services",
"_status",
@@ -588,6 +590,7 @@ def __init__(
self._actors = set()
self._has_what = {}
self._processing = {}
self._long_running = set()
self._executing = {}
self._resources = {}
self._used_resources = {}
@@ -2670,8 +2673,10 @@ def set_duration_estimate(self, ts: TaskState, ws: WorkerState) -> double:
total_duration = duration + comm
old = ws._processing.get(ts, 0)
ws._processing[ts] = total_duration
self._total_occupancy += total_duration - old
ws._occupancy += total_duration - old

if ts not in ws._long_running:
self._total_occupancy += total_duration - old
ws._occupancy += total_duration - old

return total_duration
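
For illustration only, here is a tiny, self-contained model of the bookkeeping rule added above, using plain dicts and sets as stand-ins for WorkerState._processing, WorkerState._long_running and the occupancy counters (this is not dask code):

    processing: dict = {}      # task -> estimated duration (seconds)
    long_running: set = set()  # tasks that have called secede()
    occupancy = 0.0

    def set_duration_estimate(task, duration):
        global occupancy
        old = processing.get(task, 0)
        processing[task] = duration
        if task not in long_running:   # seceded tasks no longer count towards occupancy
            occupancy += duration - old

    set_duration_estimate("x", 5.0)
    assert occupancy == 5.0
    long_running.add("x")              # handle_long_running: the task seceded ...
    occupancy -= processing["x"]       # ... its previous estimate is subtracted ...
    processing["x"] = 0                # ... and the stored estimate is zeroed
    set_duration_estimate("x", 7.0)    # later re-estimates leave occupancy untouched
    assert occupancy == 0.0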

@@ -3543,6 +3548,23 @@ def remove_all_replicas(self, ts: TaskState):
self._replicated_tasks.remove(ts)
ts._who_has.clear()

@ccall
@exceptval(check=False)
def _reevaluate_occupancy_worker(self, ws: WorkerState):
"""See reevaluate_occupancy"""
ts: TaskState
old = ws._occupancy
for ts in ws._processing:
self.set_duration_estimate(ts, ws)

self.check_idle_saturated(ws)
steal = self.extensions.get("stealing")
if steal is None:
return
if ws._occupancy > old * 1.3 or old > ws._occupancy * 1.3:
for ts in ws._processing:
steal.recalculate_cost(ts)
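
A side note on the 1.3 factor above: it acts as a roughly 30% hysteresis band, so stealing costs are only recalculated when a worker's occupancy estimate has drifted noticeably in either direction. A standalone sketch of that check (occupancy_drifted is an illustrative helper, not a dask function):

    def occupancy_drifted(old: float, new: float, factor: float = 1.3) -> bool:
        # True when the new estimate is more than ~30% above or below the old one
        return new > old * factor or old > new * factor

    assert not occupancy_drifted(10.0, 12.0)  # +20%: keep the cached steal costs
    assert occupancy_drifted(10.0, 14.0)      # +40%: recalculate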


class Scheduler(SchedulerState, ServerNode):
"""Dynamic distributed task scheduler
@@ -5521,7 +5543,12 @@ def handle_long_running(self, key=None, worker=None, compute_duration=None):
occ: double = ws._processing[ts]
ws._occupancy -= occ
parent._total_occupancy -= occ
# Cannot remove the task from processing, since we use this dict for things
# like idleness detection. Idle workers are typically targeted for
# downscaling, but we should not downscale workers with long-running tasks.
ws._processing[ts] = 0
ws._long_running.add(ts)
self.check_idle_saturated(ws)

def handle_worker_status_change(self, status: str, worker: str) -> None:
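
For context, a task reaches handle_long_running by calling secede() on the worker (directly, or indirectly via worker_client()). A minimal usage sketch of that user-facing path; long_poll and the sleep are hypothetical stand-ins:

    import time
    from distributed import Client, secede

    def long_poll():
        secede()       # release the execution slot; the scheduler zeroes this task's occupancy
        time.sleep(5)  # stand-in for waiting on an external resource
        return "done"

    client = Client()  # local cluster, for demonstration only
    print(client.submit(long_poll).result())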
@@ -7827,7 +7854,7 @@ def reevaluate_occupancy(self, worker_index: Py_ssize_t = 0):
try:
if ws is None or not ws._processing:
continue
_reevaluate_occupancy_worker(parent, ws)
parent._reevaluate_occupancy_worker(ws)
finally:
del ws # lose ref

@@ -8168,24 +8195,6 @@ def _task_to_client_msgs(state: SchedulerState, ts: TaskState) -> dict:
return {}


@cfunc
@exceptval(check=False)
def _reevaluate_occupancy_worker(state: SchedulerState, ws: WorkerState):
"""See reevaluate_occupancy"""
ts: TaskState
old = ws._occupancy
for ts in ws._processing:
state.set_duration_estimate(ts, ws)

state.check_idle_saturated(ws)
steal = state.extensions.get("stealing")
if not steal:
return
if ws._occupancy > old * 1.3 or old > ws._occupancy * 1.3:
for ts in ws._processing:
steal.recalculate_cost(ts)


@cfunc
@exceptval(check=False)
def decide_worker(
12 changes: 2 additions & 10 deletions distributed/stealing.py
@@ -313,16 +313,8 @@ async def move_task_confirm(self, *, key, state, stimulus_id, worker=None):
_log_msg = [key, state, victim.address, thief.address, stimulus_id]

if ts.state != "processing":
self.log(("not-processing", *_log_msg))
old_thief = thief.occupancy
new_thief = sum(thief.processing.values())
old_victim = victim.occupancy
new_victim = sum(victim.processing.values())
thief.occupancy = new_thief
victim.occupancy = new_victim
self.scheduler.total_occupancy += (
new_thief - old_thief + new_victim - old_victim
)
self.scheduler._reevaluate_occupancy_worker(thief)
self.scheduler._reevaluate_occupancy_worker(victim)
elif (
state in _WORKER_STATE_UNDEFINED
or state in _WORKER_STATE_CONFIRM
38 changes: 37 additions & 1 deletion distributed/tests/test_client.py
@@ -31,7 +31,7 @@
import dask.bag as db
from dask import delayed
from dask.optimization import SubgraphCallable
from dask.utils import stringify, tmpfile
from dask.utils import parse_timedelta, stringify, tmpfile

from distributed import (
CancelledError,
@@ -5156,6 +5156,42 @@ def f(x):
assert results == [sum(map(inc, range(10)))] * 10


@gen_cluster(client=True, nthreads=[("", 1)])
async def test_long_running_not_in_occupancy(c, s, a):
# https://github.com/dask/distributed/issues/5332
from distributed import Lock

l = Lock()
await l.acquire()

def long_running(lock):
sleep(0.1)
secede()
lock.acquire()

f = c.submit(long_running, l)
while f.key not in s.tasks:
await asyncio.sleep(0.01)
assert s.workers[a.address].occupancy == parse_timedelta(
dask.config.get("distributed.scheduler.unknown-task-duration")
)

while s.workers[a.address].occupancy:
await asyncio.sleep(0.01)
await a.heartbeat()

ts = s.tasks[f.key]
ws = s.workers[a.address]
s.set_duration_estimate(ts, ws)
assert s.workers[a.address].occupancy == 0

s.reevaluate_occupancy(0)
assert s.workers[a.address].occupancy == 0
await l.release()

await f
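
The occupancy value asserted at the start of this test is the scheduler's default estimate for tasks it has never timed before. A quick, standalone way to inspect that default (not part of the diff):

    import dask
    import distributed  # noqa: F401  -- importing distributed registers its config defaults
    from dask.utils import parse_timedelta

    default = dask.config.get("distributed.scheduler.unknown-task-duration")
    print(default, "=", parse_timedelta(default), "seconds")  # typically "500ms" = 0.5 seconds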


@gen_cluster(client=True)
async def test_sub_submit_priority(c, s, a, b):
def func():
68 changes: 68 additions & 0 deletions distributed/tests/test_utils_test.py
@@ -7,6 +7,7 @@
from time import sleep

import pytest
import yaml
from tornado import gen

from distributed import Client, Nanny, Scheduler, Worker, config, default_client
@@ -18,6 +19,7 @@
_UnhashableCallable,
assert_worker_story,
cluster,
dump_cluster_state,
gen_cluster,
gen_test,
inc,
@@ -480,3 +482,69 @@ def test_assert_worker_story():
def test_assert_worker_story_malformed_story(story):
with pytest.raises(AssertionError, match="Malformed story event"):
assert_worker_story(story, [])


@gen_cluster()
async def test_dump_cluster_state(s, a, b, tmpdir):
await dump_cluster_state(s, [a, b], str(tmpdir), "dump")
with open(f"{tmpdir}/dump.yaml") as fh:
out = yaml.safe_load(fh)

assert out.keys() == {"scheduler", "workers", "versions"}
assert out["workers"].keys() == {a.address, b.address}


@gen_cluster(nthreads=[])
async def test_dump_cluster_state_no_workers(s, tmpdir):
await dump_cluster_state(s, [], str(tmpdir), "dump")
with open(f"{tmpdir}/dump.yaml") as fh:
out = yaml.safe_load(fh)

assert out.keys() == {"scheduler", "workers", "versions"}
assert out["workers"] == {}


@gen_cluster(Worker=Nanny)
async def test_dump_cluster_state_nannies(s, a, b, tmpdir):
await dump_cluster_state(s, [a, b], str(tmpdir), "dump")
with open(f"{tmpdir}/dump.yaml") as fh:
out = yaml.safe_load(fh)

assert out.keys() == {"scheduler", "workers", "versions"}
assert out["workers"].keys() == s.workers.keys()


@gen_cluster()
async def test_dump_cluster_state_unresponsive_local_worker(s, a, b, tmpdir):
a.stop()
await dump_cluster_state(s, [a, b], str(tmpdir), "dump")
with open(f"{tmpdir}/dump.yaml") as fh:
out = yaml.safe_load(fh)

assert out.keys() == {"scheduler", "workers", "versions"}
assert isinstance(out["workers"][a.address], dict)
assert isinstance(out["workers"][b.address], dict)


@pytest.mark.slow
@gen_cluster(
client=True,
Worker=Nanny,
config={"distributed.comm.timeouts.connect": "200ms"},
)
async def test_dump_cluster_unresponsive_remote_worker(c, s, a, b, tmpdir):
addr1, addr2 = s.workers
clog_fut = asyncio.create_task(
c.run(lambda dask_scheduler: dask_scheduler.stop(), workers=[addr1])
)
await asyncio.sleep(0.2)

await dump_cluster_state(s, [a, b], str(tmpdir), "dump")
with open(f"{tmpdir}/dump.yaml") as fh:
out = yaml.safe_load(fh)

assert out.keys() == {"scheduler", "workers", "versions"}
assert isinstance(out["workers"][addr2], dict)
assert out["workers"][addr1].startswith("OSError('Timed out trying to connect to")

clog_fut.cancel()
27 changes: 26 additions & 1 deletion distributed/tests/test_worker_client.py
@@ -2,6 +2,7 @@
import random
import threading
import warnings
from collections import defaultdict
from time import sleep

import pytest
@@ -18,7 +19,7 @@
worker_client,
)
from distributed.metrics import time
from distributed.utils_test import double, gen_cluster, inc
from distributed.utils_test import double, gen_cluster, inc, slowinc


@gen_cluster(client=True)
@@ -315,3 +316,27 @@ async def test_submit_different_names(s, a, b):
assert fut > 0
finally:
await c.close()


@gen_cluster(client=True)
async def test_secede_does_not_claim_worker(c, s, a, b):
"""A seceded task must not keep claiming the worker it is running on. Tasks
scheduled from within it should be distributed evenly across all workers."""
# https://github.com/dask/distributed/issues/5332
def get_addr(x):
w = get_worker()
slowinc(x)
return w.address

def long_running():
with worker_client() as client:
futs = client.map(get_addr, range(100))
workers = defaultdict(int)
for f in futs:
workers[f.result()] += 1
return dict(workers)

res = await c.submit(long_running)
assert len(res) == 2
assert res[a.address] > 25
assert res[b.address] > 25
47 changes: 19 additions & 28 deletions distributed/utils_test.py
@@ -994,18 +994,12 @@ async def coro():
task.print_stack(file=buffer)

if cluster_dump_directory:
try:
await dump_cluster_state(
s,
ws,
output_dir=cluster_dump_directory,
func_name=func.__name__,
)
except Exception:
print(
f"Exception {sys.exc_info()} while trying to "
"dump cluster state."
)
await dump_cluster_state(
s,
ws,
output_dir=cluster_dump_directory,
func_name=func.__name__,
)

task.cancel()
while not task.cancelled():
@@ -1090,30 +1084,25 @@ def get_unclosed():
async def dump_cluster_state(
s: Scheduler, ws: list[ServerNode], output_dir: str, func_name: str
) -> None:
"""A variant of Client.dump_cluster_state, which does not expect on any of the below
"""A variant of Client.dump_cluster_state, which does not rely on any of the below
to work:
- individual Workers
- Having a client at all
- Client->Scheduler comms
- Scheduler->Worker comms (unless using Nannies)
"""
scheduler_info = s._to_dict()
workers_info: dict[str, Any]
versions_info = version_module.get_versions()

if not ws or isinstance(ws[0], Worker):
workers_info = {w.address: w._to_dict() for w in ws}
else:
# Variant of s.broadcast() that deals with unresponsive workers
async def safe_broadcast(addr: str) -> dict:
try:
return await s.broadcast(msg={"op": "dump_state"}, workers=[addr])
except Exception:
msg = f"Exception {sys.exc_info()} while trying to dump worker state"
return {addr: msg}

workers_info = merge(
await asyncio.gather(safe_broadcast(w.address) for w in ws)
)
workers_info = await s.broadcast(msg={"op": "dump_state"}, on_error="return")
workers_info = {
k: repr(v) if isinstance(v, Exception) else v
for k, v in workers_info.items()
}

state = {
"scheduler": scheduler_info,
@@ -1123,7 +1112,7 @@ async def safe_broadcast(addr: str) -> dict:
os.makedirs(output_dir, exist_ok=True)
fname = os.path.join(output_dir, func_name) + ".yaml"
with open(fname, "w") as fh:
yaml.dump(state, fh)
yaml.safe_dump(state, fh) # Automatically convert tuples to lists
print(f"Dumped cluster state to {fname}")
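
As the new tests above do, the resulting dump can be read back with plain YAML; a small usage sketch, where the directory and test name are placeholders:

    import yaml

    # Load a dump produced by e.g.: await dump_cluster_state(s, ws, "cluster-dumps", "my_test")
    with open("cluster-dumps/my_test.yaml") as fh:
        state = yaml.safe_load(fh)

    print(set(state))                # {'scheduler', 'workers', 'versions'}
    print(sorted(state["workers"]))  # worker addresses; values are dicts, or an error repr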


@@ -1832,8 +1821,10 @@ def assert_worker_story(
assert isinstance(ev, tuple)
assert isinstance(ev[-2], str) and ev[-2] # stimulus_id
assert isinstance(ev[-1], float) # timestamp
assert prev_ts <= ev[-1] # timestamps are monotonic ascending
assert now - 3600 < ev[-1] <= now # timestamps are within the last hour
assert prev_ts <= ev[-1] # Timestamps are monotonic ascending
# Timestamps are within the last hour. It's been observed that a timestamp
# generated in a Nanny process can be a few milliseconds in the future.
assert now - 3600 < ev[-1] <= now + 1
prev_ts = ev[-1]
except AssertionError:
raise AssertionError(
