AMM: Don't schedule tasks to paused workers (#5431)
crusaderky committed Oct 20, 2021
1 parent f3aa9d1 commit 3afc670
Showing 6 changed files with 180 additions and 77 deletions.
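In short: the scheduler now tracks a `running` set of workers alongside `idle` and `saturated`, and `valid_workers` intersects its result with that set, so neither fresh nor previously queued tasks can land on a paused or closing worker. Below is a minimal sketch of the idea, using simplified standalone names rather than the actual scheduler internals (the `Status` enum here is also reused by the later sketches on this page):

from enum import Enum
from typing import Optional

class Status(Enum):
    running = "running"
    paused = "paused"
    closing = "closing"

def valid_workers(workers: dict, running: set, restrictions: Optional[set]) -> Optional[set]:
    """Workers a task may run on; None means "any worker".

    Mirrors the commit's logic: a restriction set is intersected with the
    running workers, and with no restrictions at all, only the running
    subset is returned whenever some workers are not running.
    """
    if restrictions is None:
        if len(running) < len(workers):
            return running.copy()  # copy: callers must not mutate the pool
        return None
    s = {workers[addr] for addr in restrictions if addr in workers}
    if len(running) < len(workers):
        s &= running
    return s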
106 changes: 80 additions & 26 deletions distributed/scheduler.py
@@ -541,6 +541,7 @@ def __init__(
         self,
         *,
         address: str,
+        status: Status,
         pid: Py_ssize_t,
         name: object,
         nthreads: Py_ssize_t = 0,
@@ -560,9 +561,9 @@ def __init__(
         self._services = services or {}
         self._versions = versions or {}
         self._nanny = nanny
+        self._status = status

         self._hash = hash(address)
-        self._status = Status.undefined
         self._nbytes = 0
         self._occupancy = 0
         self._memory_unmanaged_old = 0
@@ -721,6 +722,7 @@ def clean(self):
         """Return a version of this object that is appropriate for serialization"""
         ws: WorkerState = WorkerState(
             address=self._address,
+            status=self._status,
             pid=self._pid,
             name=self._name,
             nthreads=self._nthreads,
@@ -736,9 +738,10 @@ def clean(self):
         return ws

     def __repr__(self):
-        return "<WorkerState %r, name: %s, memory: %d, processing: %d>" % (
+        return "<WorkerState %r, name: %s, status: %s, memory: %d, processing: %d>" % (
             self._address,
             self._name,
+            self._status.name,
             len(self._has_what),
             len(self._processing),
         )
@@ -747,6 +750,7 @@ def _repr_html_(self):
         return get_template("worker_state.html.j2").render(
             address=self.address,
             name=self.name,
+            status=self.status.name,
             has_what=self._has_what,
             processing=self.processing,
         )
@@ -1872,6 +1876,8 @@ class SchedulerState:
       Set of workers that are not fully utilized
     * **saturated:** ``{WorkerState}``:
       Set of workers that are not over-utilized
+    * **running:** ``{WorkerState}``:
+      Set of workers that are currently in running state
     * **clients:** ``{client key: ClientState}``
       Clients currently connected to the scheduler
@@ -1890,7 +1896,8 @@ class SchedulerState:
     _idle_dv: dict  # dict[str, WorkerState]
     _n_tasks: Py_ssize_t
     _resources: dict
-    _saturated: set
+    _saturated: set  # set[WorkerState]
+    _running: set  # set[WorkerState]
     _tasks: dict
     _task_groups: dict
     _task_prefixes: dict
@@ -1977,6 +1984,9 @@ def __init__(
         self._workers = workers
         # Note: cython.cast, not typing.cast!
         self._workers_dv = cast(dict, self._workers)
+        self._running = {
+            ws for ws in self._workers.values() if ws.status == Status.running
+        }
         self._plugins = {} if not plugins else {_get_plugin_name(p): p for p in plugins}

         # Variables from dask.config, cached by __init__ for performance
@@ -2041,9 +2051,13 @@ def resources(self):
         return self._resources

     @property
-    def saturated(self):
+    def saturated(self) -> "set[WorkerState]":
         return self._saturated

+    @property
+    def running(self) -> "set[WorkerState]":
+        return self._running
+
     @property
     def tasks(self):
         return self._tasks
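The new set is maintained incrementally by `add_worker`, `remove_worker`, and `handle_worker_status_change` below, rather than being recomputed on every scheduling decision; `validate_state` further down pins the consistency requirement. Stated on its own terms (a hypothetical check reusing the sketch's `Status` enum, not a function this commit adds):

def check_running_invariant(workers: dict, running: set) -> None:
    # `running` must be exactly the workers whose status is Status.running.
    assert running == {ws for ws in workers.values() if ws.status == Status.running}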
@@ -3339,7 +3353,7 @@ def get_task_duration(self, ts: TaskState) -> double:

     @ccall
     @exceptval(check=False)
-    def valid_workers(self, ts: TaskState) -> set:
+    def valid_workers(self, ts: TaskState) -> set:  # set[WorkerState] | None
         """Return set of currently valid workers for key

         If all workers are valid then this returns ``None``.
@@ -3352,7 +3366,7 @@ def valid_workers(self, ts: TaskState) -> set:
         s: set = None  # type: ignore

         if ts._worker_restrictions:
-            s = {w for w in ts._worker_restrictions if w in self._workers_dv}
+            s = {addr for addr in ts._worker_restrictions if addr in self._workers_dv}

         if ts._host_restrictions:
             # Resolve the alias here rather than early, for the worker
@@ -3379,9 +3393,9 @@ def valid_workers(self, ts: TaskState) -> set:
                     self._resources[resource] = dr = {}

                 sw: set = set()
-                for w, supplied in dr.items():
+                for addr, supplied in dr.items():
                     if supplied >= required:
-                        sw.add(w)
+                        sw.add(addr)

                 dw[resource] = sw

@@ -3391,8 +3405,13 @@ def valid_workers(self, ts: TaskState) -> set:
             else:
                 s &= ww

-        if s is not None:
-            s = {self._workers_dv[w] for w in s}
+        if s is None:
+            if len(self._running) < len(self._workers_dv):
+                return self._running.copy()
+        else:
+            s = {self._workers_dv[addr] for addr in s}
+            if len(self._running) < len(self._workers_dv):
+                s &= self._running

         return s

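Spelled out, the rewritten block has four outcomes. Using the simplified `valid_workers` sketch from the top of the page (stand-in values, not the scheduler's real types):

# Workers a, b, c; c is paused, so it is absent from `running`.
workers = {"a": "ws_a", "b": "ws_b", "c": "ws_c"}
running = {"ws_a", "ws_b"}

# 1. No restrictions, every worker running -> None, i.e. "any worker".
assert valid_workers(workers, {"ws_a", "ws_b", "ws_c"}, None) is None
# 2. No restrictions, someone paused -> a copy of the running set.
assert valid_workers(workers, running, None) == {"ws_a", "ws_b"}
# 3. Restrictions plus a paused worker -> intersection with running.
assert valid_workers(workers, running, {"b", "c"}) == {"ws_b"}
# 4. Restrictions naming only paused workers -> empty set; such tasks sit
#    in `unrunnable` until one of those workers resumes.
assert valid_workers(workers, running, {"c"}) == set()

The `.copy()` in case 2 is presumably defensive: callers may mutate the returned set, and handing out the live `_running` set would corrupt it.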
@@ -4212,7 +4231,9 @@ async def add_worker(
     async def add_worker(
         self,
         comm=None,
-        address=None,
+        *,
+        address: str,
+        status: str,
         keys=(),
         nthreads=None,
         name=None,
@@ -4238,9 +4259,8 @@ async def add_worker(
         address = normalize_address(address)
         host = get_address_host(address)

-        ws: WorkerState = parent._workers_dv.get(address)
-        if ws is not None:
-            raise ValueError("Worker already exists %s" % ws)
+        if address in parent._workers_dv:
+            raise ValueError("Worker already exists %s" % address)

         if name in parent._aliases:
             logger.warning(
@@ -4255,8 +4275,10 @@ async def add_worker(
             await comm.write(msg)
             return

+        ws: WorkerState
         parent._workers[address] = ws = WorkerState(
             address=address,
+            status=Status.lookup[status],  # type: ignore
             pid=pid,
             nthreads=nthreads,
             memory_limit=memory_limit or 0,
@@ -4267,12 +4289,14 @@ async def add_worker(
             nanny=nanny,
             extra=extra,
         )
+        if ws._status == Status.running:
+            parent._running.add(ws)

-        dh: dict = parent._host_info.get(host)
+        dh: dict = parent._host_info.get(host)  # type: ignore
         if dh is None:
             parent._host_info[host] = dh = {}

-        dh_addresses: set = dh.get("addresses")
+        dh_addresses: set = dh.get("addresses")  # type: ignore
         if dh_addresses is None:
             dh["addresses"] = dh_addresses = set()
             dh["nthreads"] = 0
@@ -4292,7 +4316,8 @@ async def add_worker(
             metrics=metrics,
         )

-        # Do not need to adjust parent._total_occupancy as self.occupancy[ws] cannot exist before this.
+        # Do not need to adjust parent._total_occupancy as self.occupancy[ws] cannot
+        # exist before this.
         self.check_idle_saturated(ws)

         # for key in keys:  # TODO
@@ -4318,7 +4343,7 @@ async def add_worker(
             assert isinstance(nbytes, dict)
             already_released_keys = []
             for key in nbytes:
-                ts: TaskState = parent._tasks.get(key)
+                ts: TaskState = parent._tasks.get(key)  # type: ignore
                 if ts is not None and ts.state != "released":
                     if ts.state == "memory":
                         self.add_keys(worker=address, keys=[key])
@@ -4347,14 +4372,15 @@ async def add_worker(
                     "stimulus_id": f"reconnect-already-released-{time()}",
                 }
             )
-        for ts in list(parent._unrunnable):
-            valid: set = self.valid_workers(ts)
-            if valid is None or ws in valid:
-                recommendations[ts._key] = "waiting"

+        if ws._status == Status.running:
+            for ts in parent._unrunnable:
+                valid: set = self.valid_workers(ts)
+                if valid is None or ws in valid:
+                    recommendations[ts._key] = "waiting"
+
         if recommendations:
             parent._transitions(recommendations, client_msgs, worker_msgs)
-            recommendations = {}

         self.send_all(client_msgs, worker_msgs)

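A behavioural subtlety in this hunk: re-evaluating `unrunnable` tasks used to happen for every newly added worker, but is now gated on the worker registering with `status: "running"`, so a worker that connects already paused (for instance, above its memory pause threshold) no longer unblocks queued tasks. Reduced to its shape (a hypothetical helper, reusing the `Status` enum from the first sketch):

def on_worker_added(ws, unrunnable, valid_workers, recommendations):
    """Requeue previously unrunnable tasks only for a worker that can take work."""
    if ws.status != Status.running:
        return  # paused/closing workers must not unblock queued tasks
    for ts in unrunnable:
        valid = valid_workers(ts)
        if valid is None or ws in valid:
            recommendations[ts.key] = "waiting"  # transition back to waiting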
@@ -4896,6 +4922,7 @@ async def remove_worker(self, comm=None, address=None, safe=False, close=True):
         parent._saturated.discard(ws)
         del parent._workers[address]
         ws.status = Status.closed
+        parent._running.discard(ws)
         parent._total_occupancy -= ws._occupancy

         recommendations: dict = {}
@@ -5143,6 +5170,11 @@ def validate_state(self, allow_overlap=False):
             if not ws._processing:
                 assert not ws._occupancy
                 assert ws._address in parent._idle_dv
+            assert (ws._status == Status.running) == (ws in parent._running)
+
+        for ws in parent._running:
+            assert ws._status == Status.running
+            assert ws._address in parent._workers_dv

         ts: TaskState
         for k, ts in parent._tasks.items():
@@ -5423,20 +5455,42 @@ def handle_long_running(self, key=None, worker=None, compute_duration=None):
         ws._processing[ts] = 0
         self.check_idle_saturated(ws)

-    def handle_worker_status_change(self, status: str, worker: str):
+    def handle_worker_status_change(self, status: str, worker: str) -> None:
         parent: SchedulerState = cast(SchedulerState, self)
         ws: WorkerState = parent._workers_dv.get(worker)  # type: ignore
         if not ws:
             return
+        prev_status = ws._status
+        ws._status = Status.lookup[status]  # type: ignore
+        if ws._status == prev_status:
+            return

         self.log_event(
             ws._address,
             {
                 "action": "worker-status-change",
-                "prev-status": ws._status.name,
+                "prev-status": prev_status.name,
                 "status": status,
             },
         )
-        ws._status = Status.lookup[status]  # type: ignore
+
+        if ws._status == Status.running:
+            parent._running.add(ws)
+
+            recs = {}
+            ts: TaskState
+            for ts in parent._unrunnable:
+                valid: set = self.valid_workers(ts)
+                if valid is None or ws in valid:
+                    recs[ts._key] = "waiting"
+            if recs:
+                client_msgs: dict = {}
+                worker_msgs: dict = {}
+                parent._transitions(recs, client_msgs, worker_msgs)
+                self.send_all(client_msgs, worker_msgs)
+
+        else:
+            parent._running.discard(ws)

     async def handle_worker(self, comm=None, worker=None):
         """
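`handle_worker_status_change` is the other half: it reacts to pause/unpause messages from an already connected worker. Per this diff, a worker leaving the running state is merely dropped from the scheduling pool (tasks it is already processing are left alone), while a worker entering it rejoins the pool and triggers a retry of everything parked in `unrunnable`; the early return also turns duplicate status messages into cheap no-ops. Condensed into a sketch (hypothetical names, continuing the ones above):

def on_worker_status_change(ws, new_status: Status, scheduler) -> None:
    prev_status, ws.status = ws.status, new_status
    if new_status == prev_status:
        return  # duplicate message: nothing to log or recompute
    if new_status == Status.running:
        scheduler.running.add(ws)
        scheduler.retry_unrunnable(ws)  # same loop as in on_worker_added
    else:
        scheduler.running.discard(ws)  # no new tasks; in-flight ones finish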
29 changes: 12 additions & 17 deletions distributed/tests/test_client.py
@@ -3650,51 +3650,46 @@ def test_reconnect(loop):
         "9393",
         "--no-dashboard",
     ]
-    with popen(scheduler_cli) as s:
+    with popen(scheduler_cli):
         c = Client("127.0.0.1:9393", loop=loop)
-        start = time()
-        while len(c.nthreads()) != 1:
-            sleep(0.1)
-            assert time() < start + 3
-
+        c.wait_for_workers(1, timeout=10)
         x = c.submit(inc, 1)
-        assert x.result() == 2
+        assert x.result(timeout=10) == 2

         start = time()
         while c.status != "connecting":
-            assert time() < start + 5
+            assert time() < start + 10
             sleep(0.01)

         assert x.status == "cancelled"
         with pytest.raises(CancelledError):
-            x.result()
+            x.result(timeout=10)

-        with popen(scheduler_cli) as s:
+        with popen(scheduler_cli):
             start = time()
             while c.status != "running":
                 sleep(0.1)
-                assert time() < start + 5
+                assert time() < start + 10
             start = time()
             while len(c.nthreads()) != 1:
                 sleep(0.05)
-                assert time() < start + 15
+                assert time() < start + 10

             x = c.submit(inc, 1)
-            assert x.result() == 2
+            assert x.result(timeout=10) == 2

         start = time()
         while True:
+            assert time() < start + 10
             try:
-                x.result()
+                x.result(timeout=10)
                 assert False
             except CommClosedError:
                 continue
             except CancelledError:
                 break
-            assert time() < start + 5
-            sleep(0.1)

-    sync(loop, w.close)
+    sync(loop, w.close, timeout=1)
     c.close()

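The test changes all serve one goal: every wait is bounded. Bare `x.result()` calls gain explicit timeouts, and each hand-rolled polling loop asserts against a single deadline. The recurring pattern, extracted into a generic helper (not something the test suite actually defines):

import time

def poll_until(predicate, timeout: float = 10.0, interval: float = 0.1) -> None:
    """Poll `predicate` until it returns True, or fail after `timeout` seconds."""
    deadline = time.time() + timeout
    while not predicate():
        assert time.time() < deadline, "condition not met before deadline"
        time.sleep(interval)

# e.g. poll_until(lambda: c.status == "running") would replace the repeated
# start = time() / while / sleep / assert blocks in test_reconnect.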