Avoid workers in paused and closing_gracefully status
crusaderky committed Oct 18, 2021
Parent: 7d2516a · Commit: 05a12e5
Showing 8 changed files with 229 additions and 80 deletions.
16 changes: 14 additions & 2 deletions distributed/active_memory_manager.py
@@ -9,6 +9,7 @@
import dask
from dask.utils import parse_timedelta

from .core import Status
from .metrics import time
from .utils import import_term, log_errors

@@ -234,11 +235,16 @@ def _find_recipient(
if ts.state != "memory":
return None
if candidates is None:
candidates = set(self.scheduler.workers.values())
candidates = self.scheduler.running.copy()
else:
candidates &= self.scheduler.running

candidates -= ts.who_has
candidates -= pending_repl
if not candidates:
return None

# Select candidate with the lowest memory usage
return min(candidates, key=self.workers_memory.__getitem__)

def _find_dropper(
@@ -268,7 +274,13 @@ def _find_dropper(
candidates -= {waiter_ts.processing_on for waiter_ts in ts.waiters}
if not candidates:
return None
return max(candidates, key=self.workers_memory.__getitem__)

# Select candidate with the highest memory usage.
# Drop from workers with status paused or closing_gracefully first.
return max(
candidates,
key=lambda ws: (ws.status != Status.running, self.workers_memory[ws]),
)


class ActiveMemoryManagerPolicy:
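The two selectors above boil down to a min/max call with different sort keys. A minimal standalone sketch (not part of this diff; the Status values, worker names, and memory figures below are made up for illustration) of that selection logic:

from enum import Enum


class Status(Enum):
    running = 1
    paused = 2
    closing_gracefully = 3


workers_memory = {"a": 3_000, "b": 1_000, "c": 5_000}
worker_status = {"a": Status.running, "b": Status.running, "c": Status.paused}

# Recipient: consider running workers only, pick the one with the least memory.
running = {w for w, st in worker_status.items() if st is Status.running}
recipient = min(running, key=workers_memory.__getitem__)
assert recipient == "b"

# Dropper: prefer paused/closing_gracefully workers first, then highest memory.
dropper = max(
    workers_memory,
    key=lambda w: (worker_status[w] is not Status.running, workers_memory[w]),
)
assert dropper == "c"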
105 changes: 79 additions & 26 deletions distributed/scheduler.py
@@ -541,6 +541,7 @@ def __init__(
self,
*,
address: str,
status: Status,
pid: Py_ssize_t,
name: object,
nthreads: Py_ssize_t = 0,
@@ -560,9 +561,9 @@ def __init__(
self._services = services or {}
self._versions = versions or {}
self._nanny = nanny
self._status = status

self._hash = hash(address)
self._status = Status.undefined
self._nbytes = 0
self._occupancy = 0
self._memory_unmanaged_old = 0
@@ -721,6 +722,7 @@ def clean(self):
"""Return a version of this object that is appropriate for serialization"""
ws: WorkerState = WorkerState(
address=self._address,
status=self._status,
pid=self._pid,
name=self._name,
nthreads=self._nthreads,
@@ -736,9 +738,10 @@ def clean(self):
return ws

def __repr__(self):
return "<WorkerState %r, name: %s, memory: %d, processing: %d>" % (
return "<WorkerState %r, name: %s, status: %s, memory: %d, processing: %d>" % (
self._address,
self._name,
self._status.name,
len(self._has_what),
len(self._processing),
)
@@ -747,6 +750,7 @@ def _repr_html_(self):
return get_template("worker_state.html.j2").render(
address=self.address,
name=self.name,
status=self.status.name,
has_what=self._has_what,
processing=self.processing,
)
@@ -1872,6 +1876,8 @@ class SchedulerState:
Set of workers that are not fully utilized
* **saturated:** ``{WorkerState}``:
Set of workers that are not over-utilized
* **running:** ``{WorkerState}``:
Set of workers that are currently in running state
* **clients:** ``{client key: ClientState}``
Clients currently connected to the scheduler
@@ -1890,7 +1896,8 @@ class SchedulerState:
_idle_dv: dict # dict[str, WorkerState]
_n_tasks: Py_ssize_t
_resources: dict
_saturated: set
_saturated: set # set[WorkerState]
_running: set # set[WorkerState]
_tasks: dict
_task_groups: dict
_task_prefixes: dict
@@ -1977,6 +1984,9 @@ def __init__(
self._workers = workers
# Note: cython.cast, not typing.cast!
self._workers_dv = cast(dict, self._workers)
self._running = {
ws for ws in self._workers.values() if ws.status == Status.running
}
self._plugins = {} if not plugins else {_get_plugin_name(p): p for p in plugins}

# Variables from dask.config, cached by __init__ for performance
@@ -2041,9 +2051,13 @@ def resources(self):
return self._resources

@property
def saturated(self):
def saturated(self) -> "set[WorkerState]":
return self._saturated

@property
def running(self) -> "set[WorkerState]":
return self._running

@property
def tasks(self):
return self._tasks
@@ -3339,7 +3353,7 @@ def get_task_duration(self, ts: TaskState) -> double:

@ccall
@exceptval(check=False)
def valid_workers(self, ts: TaskState) -> set:
def valid_workers(self, ts: TaskState) -> set: # set[WorkerState] | None
"""Return set of currently valid workers for key
If all workers are valid then this returns ``None``.
@@ -3352,7 +3366,7 @@ def valid_workers(self, ts: TaskState) -> set:
s: set = None # type: ignore

if ts._worker_restrictions:
s = {w for w in ts._worker_restrictions if w in self._workers_dv}
s = {addr for addr in ts._worker_restrictions if addr in self._workers_dv}

if ts._host_restrictions:
# Resolve the alias here rather than early, for the worker
@@ -3379,9 +3393,9 @@ def valid_workers(self, ts: TaskState) -> set:
self._resources[resource] = dr = {}

sw: set = set()
for w, supplied in dr.items():
for addr, supplied in dr.items():
if supplied >= required:
sw.add(w)
sw.add(addr)

dw[resource] = sw

@@ -3391,8 +3405,13 @@ def valid_workers(self, ts: TaskState) -> set:
else:
s &= ww

if s is not None:
s = {self._workers_dv[w] for w in s}
if s is None:
if len(self._running) < len(self._workers_dv):
return self._running.copy()
else:
s = {self._workers_dv[addr] for addr in s}
if len(self._running) < len(self._workers_dv):
s &= self._running

return s
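
A minimal standalone sketch (not part of this diff; worker addresses are made up) of how the hunk above narrows the result to running workers only when at least one worker is not running:

from __future__ import annotations

running = {"tcp://w1", "tcp://w3"}  # workers currently in Status.running
all_workers = {"tcp://w1", "tcp://w2", "tcp://w3"}


def valid_workers(restrictions: set | None) -> set | None:
    """Return the workers a task may run on; None means "all of them"."""
    s = restrictions
    if s is None:
        # Unrestricted task: only narrow if some workers are paused/closing.
        return running.copy() if len(running) < len(all_workers) else None
    s = s & all_workers
    if len(running) < len(all_workers):
        s &= running
    return s


assert valid_workers(None) == {"tcp://w1", "tcp://w3"}
assert valid_workers({"tcp://w2", "tcp://w3"}) == {"tcp://w3"}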

@@ -4212,7 +4231,9 @@ def heartbeat_worker(
async def add_worker(
self,
comm=None,
address=None,
*,
address: str,
status: str,
keys=(),
nthreads=None,
name=None,
@@ -4238,9 +4259,8 @@ async def add_worker(
address = normalize_address(address)
host = get_address_host(address)

ws: WorkerState = parent._workers_dv.get(address)
if ws is not None:
raise ValueError("Worker already exists %s" % ws)
if address in parent._workers_dv:
raise ValueError("Worker already exists %s" % address)

if name in parent._aliases:
logger.warning(
@@ -4255,8 +4275,10 @@ async def add_worker(
await comm.write(msg)
return

ws: WorkerState
parent._workers[address] = ws = WorkerState(
address=address,
status=Status.lookup[status], # type: ignore
pid=pid,
nthreads=nthreads,
memory_limit=memory_limit or 0,
@@ -4267,12 +4289,14 @@ async def add_worker(
nanny=nanny,
extra=extra,
)
if ws._status == Status.running:
parent._running.add(ws)

dh: dict = parent._host_info.get(host)
dh: dict = parent._host_info.get(host) # type: ignore
if dh is None:
parent._host_info[host] = dh = {}

dh_addresses: set = dh.get("addresses")
dh_addresses: set = dh.get("addresses") # type: ignore
if dh_addresses is None:
dh["addresses"] = dh_addresses = set()
dh["nthreads"] = 0
@@ -4292,7 +4316,8 @@ async def add_worker(
metrics=metrics,
)

# Do not need to adjust parent._total_occupancy as self.occupancy[ws] cannot exist before this.
# Do not need to adjust parent._total_occupancy as self.occupancy[ws] cannot
# exist before this.
self.check_idle_saturated(ws)

# for key in keys: # TODO
@@ -4318,7 +4343,7 @@ async def add_worker(
assert isinstance(nbytes, dict)
already_released_keys = []
for key in nbytes:
ts: TaskState = parent._tasks.get(key)
ts: TaskState = parent._tasks.get(key) # type: ignore
if ts is not None and ts.state != "released":
if ts.state == "memory":
self.add_keys(worker=address, keys=[key])
@@ -4347,14 +4372,15 @@ async def add_worker(
"stimulus_id": f"reconnect-already-released-{time()}",
}
)
for ts in list(parent._unrunnable):
valid: set = self.valid_workers(ts)
if valid is None or ws in valid:
recommendations[ts._key] = "waiting"

if ws._status == Status.running:
for ts in parent._unrunnable:
valid: set = self.valid_workers(ts)
if valid is None or ws in valid:
recommendations[ts._key] = "waiting"

if recommendations:
parent._transitions(recommendations, client_msgs, worker_msgs)
recommendations = {}

self.send_all(client_msgs, worker_msgs)

@@ -4896,6 +4922,7 @@ async def remove_worker(self, comm=None, address=None, safe=False, close=True):
parent._saturated.discard(ws)
del parent._workers[address]
ws.status = Status.closed
parent._running.discard(ws)
parent._total_occupancy -= ws._occupancy

recommendations: dict = {}
@@ -5143,6 +5170,11 @@ def validate_state(self, allow_overlap=False):
if not ws._processing:
assert not ws._occupancy
assert ws._address in parent._idle_dv
assert (ws._status == Status.running) == (ws in parent._running)

for ws in parent._running:
assert ws._status == Status.running
assert ws._address in parent._workers_dv

ts: TaskState
for k, ts in parent._tasks.items():
@@ -5423,20 +5455,41 @@ def handle_long_running(self, key=None, worker=None, compute_duration=None):
ws._processing[ts] = 0
self.check_idle_saturated(ws)

def handle_worker_status_change(self, status: str, worker: str):
def handle_worker_status_change(self, status: str, worker: str) -> None:
parent: SchedulerState = cast(SchedulerState, self)
ws: WorkerState = parent._workers_dv.get(worker) # type: ignore
if not ws:
return
prev_status = ws._status
ws._status = Status.lookup[status] # type: ignore
if ws._status == prev_status:
return

self.log_event(
ws._address,
{
"action": "worker-status-change",
"prev-status": ws._status.name,
"prev-status": prev_status.name,
"status": status,
},
)
ws._status = Status.lookup[status] # type: ignore

if ws._status == Status.running:
parent._running.add(ws)

recs = {}
client_msgs: dict = {}
worker_msgs: dict = {}
for ts in parent._unrunnable:
valid: set = self.valid_workers(ts)
if valid is None or ws in valid:
recs[ts._key] = "waiting"
if recs:
parent._transitions(recs, client_msgs, worker_msgs)
self.send_all(client_msgs, worker_msgs)

else:
parent._running.discard(ws)

async def handle_worker(self, comm=None, worker=None):
"""
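A minimal standalone sketch (not part of this diff; MiniScheduler and every name in it are illustrative) of the bookkeeping handle_worker_status_change performs above: keep the running set in sync with worker status and, when a worker returns to running, recommend previously unrunnable tasks for scheduling. The real method feeds such recommendations into the scheduler's transition engine rather than returning them.

from enum import Enum


class Status(Enum):
    running = 1
    paused = 2
    closing_gracefully = 3


class MiniScheduler:
    def __init__(self) -> None:
        self.worker_status: dict[str, Status] = {}
        self.running: set[str] = set()     # addresses currently in running state
        self.unrunnable: set[str] = set()  # task keys with no valid worker

    def handle_worker_status_change(self, worker: str, status: Status) -> dict:
        prev = self.worker_status.get(worker)
        self.worker_status[worker] = status
        if status == prev:
            return {}
        if status is Status.running:
            self.running.add(worker)
            # Tasks stuck in no-worker may now have somewhere to run.
            return {key: "waiting" for key in self.unrunnable}
        self.running.discard(worker)
        return {}


s = MiniScheduler()
s.unrunnable = {"x"}
assert s.handle_worker_status_change("tcp://w1", Status.paused) == {}
assert s.running == set()
assert s.handle_worker_status_change("tcp://w1", Status.running) == {"x": "waiting"}
assert s.running == {"tcp://w1"}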
34 changes: 34 additions & 0 deletions distributed/tests/test_active_memory_manager.py
@@ -8,6 +8,7 @@
ActiveMemoryManagerExtension,
ActiveMemoryManagerPolicy,
)
from distributed.core import Status
from distributed.utils_test import gen_cluster, inc, slowinc

NO_AMM_START = {"distributed.scheduler.active-memory-manager.start": False}
@@ -328,6 +329,22 @@ async def test_drop_with_bad_candidates(c, s, a, b):
assert s.tasks["x"].who_has == {ws0, ws1}


@gen_cluster(client=True, nthreads=[("", 1)] * 10, config=demo_config("drop", n=1))
async def test_drop_prefers_paused_workers(c, s, *workers):
x = await c.scatter({"x": 1}, broadcast=True)
ts = s.tasks["x"]
assert len(ts.who_has) == 10
ws = s.workers[workers[3].address]
workers[3].memory_pause_fraction = 1e-9
while ws.status != Status.paused:
await asyncio.sleep(0.01)

s.extensions["amm"].run_once()
while len(ts.who_has) != 9:
await asyncio.sleep(0.01)
assert ws not in ts.who_has


@gen_cluster(nthreads=[("", 1)] * 4, client=True, config=demo_config("replicate", n=2))
async def test_replicate(c, s, *workers):
futures = await c.scatter({"x": 123})
@@ -436,6 +453,23 @@ async def test_replicate_to_candidates_with_key(c, s, a, b):
assert s.tasks["x"].who_has == {ws0}


@gen_cluster(
client=True,
nthreads=[("", 1), ("", 1, {"memory_pause_fraction": 1e-15}), ("", 1)],
config=demo_config("replicate"),
)
async def test_replicate_avoids_paused_workers(c, s, w0, w1, w2):
while s.workers[w1.address].status != Status.paused:
await asyncio.sleep(0.01)

futures = await c.scatter({"x": 1}, workers=[w0.address])
s.extensions["amm"].run_once()
while "x" not in w2.data:
await asyncio.sleep(0.01)
await asyncio.sleep(0.2)
assert "x" not in w1.data


@gen_cluster(
nthreads=[("", 1)] * 4,
client=True,
