AMM: Don't schedule tasks to paused workers #5431

Merged (4 commits, Oct 20, 2021)
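
At a high level, the diff below makes the scheduler track which workers are in the running status and stop handing new tasks to workers that are paused or shutting down. As orientation before the diff, here is a minimal, self-contained sketch of that bookkeeping. It is a simplified model, not the scheduler's real API: WorkerModel, SchedulerModel, and scheduling_candidates() are made-up names standing in for WorkerState, SchedulerState, and the candidate filtering done in valid_workers().

    from enum import Enum


    class Status(Enum):
        running = "running"
        paused = "paused"
        closed = "closed"


    class WorkerModel:
        def __init__(self, address: str, status: Status):
            self.address = address
            self.status = status


    class SchedulerModel:
        def __init__(self):
            self.workers = {}   # address -> WorkerModel, every registered worker
            self.running = set()  # only the workers currently in Status.running

        def add_worker(self, address: str, status: Status) -> None:
            ws = WorkerModel(address, status)
            self.workers[address] = ws
            if status is Status.running:
                self.running.add(ws)

        def worker_status_change(self, address: str, status: Status) -> None:
            ws = self.workers[address]
            ws.status = status
            if status is Status.running:
                self.running.add(ws)
            else:
                self.running.discard(ws)

        def scheduling_candidates(self) -> set:
            # Paused or closing workers stay registered (their data is still
            # tracked) but are never picked as a target for new tasks.
            return set(self.running)

In this model, a worker that pauses itself (for example under memory pressure) drops out of scheduling_candidates() but remains in workers, and rejoins the candidate set as soon as it reports running again.
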
106 changes: 80 additions & 26 deletions distributed/scheduler.py
@@ -541,6 +541,7 @@ def __init__(
self,
*,
address: str,
status: Status,
pid: Py_ssize_t,
name: object,
nthreads: Py_ssize_t = 0,
@@ -560,9 +561,9 @@ def __init__(
self._services = services or {}
self._versions = versions or {}
self._nanny = nanny
self._status = status

self._hash = hash(address)
self._status = Status.undefined
self._nbytes = 0
self._occupancy = 0
self._memory_unmanaged_old = 0
@@ -721,6 +722,7 @@ def clean(self):
"""Return a version of this object that is appropriate for serialization"""
ws: WorkerState = WorkerState(
address=self._address,
status=self._status,
pid=self._pid,
name=self._name,
nthreads=self._nthreads,
@@ -736,9 +738,10 @@ def clean(self):
return ws

def __repr__(self):
return "<WorkerState %r, name: %s, memory: %d, processing: %d>" % (
return "<WorkerState %r, name: %s, status: %s, memory: %d, processing: %d>" % (
self._address,
self._name,
self._status.name,
len(self._has_what),
len(self._processing),
)
@@ -747,6 +750,7 @@ def _repr_html_(self):
return get_template("worker_state.html.j2").render(
address=self.address,
name=self.name,
status=self.status.name,
has_what=self._has_what,
processing=self.processing,
)
@@ -1872,6 +1876,8 @@ class SchedulerState:
Set of workers that are not fully utilized
* **saturated:** ``{WorkerState}``:
Set of workers that are not over-utilized
* **running:** ``{WorkerState}``:
Set of workers that are currently in running state

* **clients:** ``{client key: ClientState}``
Clients currently connected to the scheduler
@@ -1890,7 +1896,8 @@ class SchedulerState:
_idle_dv: dict # dict[str, WorkerState]
_n_tasks: Py_ssize_t
_resources: dict
_saturated: set
_saturated: set # set[WorkerState]
_running: set # set[WorkerState]
_tasks: dict
_task_groups: dict
_task_prefixes: dict
@@ -1977,6 +1984,9 @@ def __init__(
self._workers = workers
# Note: cython.cast, not typing.cast!
self._workers_dv = cast(dict, self._workers)
self._running = {
ws for ws in self._workers.values() if ws.status == Status.running
}
self._plugins = {} if not plugins else {_get_plugin_name(p): p for p in plugins}

# Variables from dask.config, cached by __init__ for performance
@@ -2041,9 +2051,13 @@ def resources(self):
return self._resources

@property
def saturated(self):
def saturated(self) -> "set[WorkerState]":
return self._saturated

@property
def running(self) -> "set[WorkerState]":
return self._running

@property
def tasks(self):
return self._tasks
@@ -3339,7 +3353,7 @@ def get_task_duration(self, ts: TaskState) -> double:

@ccall
@exceptval(check=False)
def valid_workers(self, ts: TaskState) -> set:
def valid_workers(self, ts: TaskState) -> set: # set[WorkerState] | None
"""Return set of currently valid workers for key

If all workers are valid then this returns ``None``.
@@ -3352,7 +3366,7 @@ def valid_workers(self, ts: TaskState) -> set:
s: set = None # type: ignore

if ts._worker_restrictions:
s = {w for w in ts._worker_restrictions if w in self._workers_dv}
s = {addr for addr in ts._worker_restrictions if addr in self._workers_dv}

if ts._host_restrictions:
# Resolve the alias here rather than early, for the worker
@@ -3379,9 +3393,9 @@ def valid_workers(self, ts: TaskState) -> set:
self._resources[resource] = dr = {}

sw: set = set()
for w, supplied in dr.items():
for addr, supplied in dr.items():
if supplied >= required:
sw.add(w)
sw.add(addr)

dw[resource] = sw

@@ -3391,8 +3405,13 @@ def valid_workers(self, ts: TaskState) -> set:
else:
s &= ww

if s is not None:
s = {self._workers_dv[w] for w in s}
if s is None:
if len(self._running) < len(self._workers_dv):
return self._running.copy()
else:
s = {self._workers_dv[addr] for addr in s}
if len(self._running) < len(self._workers_dv):
s &= self._running

return s
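
To make the new contract concrete, here is a rough standalone restatement of what valid_workers() returns after this change. It is a sketch under simplifying assumptions: plain addresses stand in for WorkerState objects and all of the worker, host, and resource restrictions are collapsed into a single optional set, so only the shape of the logic is preserved.

    def valid_workers_sketch(restrictions, workers, running):
        """restrictions: candidate addresses, or None if the task is unrestricted;
        workers: all known worker addresses; running: addresses of running workers.
        """
        if restrictions is None:
            # None still means "any worker", but only while every worker is
            # running; otherwise the candidates collapse to the running set.
            if len(running) < len(workers):
                return set(running)
            return None
        s = {addr for addr in restrictions if addr in workers}
        if len(running) < len(workers):
            s &= running
        return s


    # With worker "b" paused:
    # valid_workers_sketch(None, {"a", "b"}, {"a"})   -> {"a"}
    # valid_workers_sketch({"b"}, {"a", "b"}, {"a"})  -> set()  (task stays unrunnable)
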

@@ -4212,7 +4231,9 @@ def heartbeat_worker(
async def add_worker(
self,
comm=None,
address=None,
*,
address: str,
status: str,
keys=(),
nthreads=None,
name=None,
@@ -4238,9 +4259,8 @@ async def add_worker(
address = normalize_address(address)
host = get_address_host(address)

ws: WorkerState = parent._workers_dv.get(address)
if ws is not None:
raise ValueError("Worker already exists %s" % ws)
if address in parent._workers_dv:
raise ValueError("Worker already exists %s" % address)

if name in parent._aliases:
logger.warning(
@@ -4255,8 +4275,10 @@ async def add_worker(
await comm.write(msg)
return

ws: WorkerState
parent._workers[address] = ws = WorkerState(
address=address,
status=Status.lookup[status], # type: ignore
pid=pid,
nthreads=nthreads,
memory_limit=memory_limit or 0,
@@ -4267,12 +4289,14 @@ async def add_worker(
nanny=nanny,
extra=extra,
)
if ws._status == Status.running:
parent._running.add(ws)

dh: dict = parent._host_info.get(host)
dh: dict = parent._host_info.get(host) # type: ignore
if dh is None:
parent._host_info[host] = dh = {}

dh_addresses: set = dh.get("addresses")
dh_addresses: set = dh.get("addresses") # type: ignore
if dh_addresses is None:
dh["addresses"] = dh_addresses = set()
dh["nthreads"] = 0
@@ -4292,7 +4316,8 @@ async def add_worker(
metrics=metrics,
)

# Do not need to adjust parent._total_occupancy as self.occupancy[ws] cannot exist before this.
# Do not need to adjust parent._total_occupancy as self.occupancy[ws] cannot
# exist before this.
self.check_idle_saturated(ws)

# for key in keys: # TODO
@@ -4318,7 +4343,7 @@ async def add_worker(
assert isinstance(nbytes, dict)
already_released_keys = []
for key in nbytes:
ts: TaskState = parent._tasks.get(key)
ts: TaskState = parent._tasks.get(key) # type: ignore
if ts is not None and ts.state != "released":
if ts.state == "memory":
self.add_keys(worker=address, keys=[key])
@@ -4347,14 +4372,15 @@ async def add_worker(
"stimulus_id": f"reconnect-already-released-{time()}",
}
)
for ts in list(parent._unrunnable):
valid: set = self.valid_workers(ts)
if valid is None or ws in valid:
recommendations[ts._key] = "waiting"

if ws._status == Status.running:
for ts in parent._unrunnable:
valid: set = self.valid_workers(ts)
if valid is None or ws in valid:
recommendations[ts._key] = "waiting"

if recommendations:
parent._transitions(recommendations, client_msgs, worker_msgs)
recommendations = {}

self.send_all(client_msgs, worker_msgs)

@@ -4896,6 +4922,7 @@ async def remove_worker(self, comm=None, address=None, safe=False, close=True):
parent._saturated.discard(ws)
del parent._workers[address]
ws.status = Status.closed
parent._running.discard(ws)
parent._total_occupancy -= ws._occupancy

recommendations: dict = {}
@@ -5143,6 +5170,11 @@ def validate_state(self, allow_overlap=False):
if not ws._processing:
assert not ws._occupancy
assert ws._address in parent._idle_dv
assert (ws._status == Status.running) == (ws in parent._running)

for ws in parent._running:
assert ws._status == Status.running
assert ws._address in parent._workers_dv

ts: TaskState
for k, ts in parent._tasks.items():
@@ -5423,20 +5455,42 @@ def handle_long_running(self, key=None, worker=None, compute_duration=None):
ws._processing[ts] = 0
self.check_idle_saturated(ws)

def handle_worker_status_change(self, status: str, worker: str):
def handle_worker_status_change(self, status: str, worker: str) -> None:
parent: SchedulerState = cast(SchedulerState, self)
ws: WorkerState = parent._workers_dv.get(worker) # type: ignore
if not ws:
return
prev_status = ws._status
ws._status = Status.lookup[status] # type: ignore
if ws._status == prev_status:
return

self.log_event(
ws._address,
{
"action": "worker-status-change",
"prev-status": ws._status.name,
"prev-status": prev_status.name,
"status": status,
},
)
ws._status = Status.lookup[status] # type: ignore

if ws._status == Status.running:
parent._running.add(ws)

recs = {}
ts: TaskState
for ts in parent._unrunnable:
valid: set = self.valid_workers(ts)
if valid is None or ws in valid:
recs[ts._key] = "waiting"
if recs:
client_msgs: dict = {}
worker_msgs: dict = {}
parent._transitions(recs, client_msgs, worker_msgs)
self.send_all(client_msgs, worker_msgs)
Review comment on lines +5487 to +5490 (Member):

Suggested change:

    -            client_msgs: dict = {}
    -            worker_msgs: dict = {}
    -            parent._transitions(recs, client_msgs, worker_msgs)
    -            self.send_all(client_msgs, worker_msgs)
    +            self.transitions(recs)

Reply (Collaborator, Author):

Aren't client_msgs and worker_msgs filled in place by _transitions()? I fell for this already before.

Reply (Member):

Yes, if you're using the underscored method, you'll need to call the send yourself. If you use the non-underscored one, this is done for you. You basically just copied what the non-underscored method does:

    def transitions(self, recommendations: dict):
        """Process transitions until none are left

        This includes feedback from previous transitions and continues until we
        reach a steady state
        """
        parent: SchedulerState = cast(SchedulerState, self)
        client_msgs: dict = {}
        worker_msgs: dict = {}
        parent._transitions(recommendations, client_msgs, worker_msgs)
        self.send_all(client_msgs, worker_msgs)


else:
parent._running.discard(ws)

async def handle_worker(self, comm=None, worker=None):
"""
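Condensing the handle_worker_status_change hunk above: when a worker pauses it simply leaves the running set, and when it resumes the scheduler re-checks every previously unrunnable task to see whether it can now be scheduled. The compressed sketch below assumes a Status enum like the one in the simplified model near the top of this page and a scheduler object exposing the attributes named in the diff (running, unrunnable, valid_workers, transitions, log_event); it also follows the reviewer's suggestion above of going through the public transitions() wrapper rather than _transitions() plus send_all().

    def handle_worker_status_change_sketch(scheduler, ws, new_status) -> None:
        prev_status = ws.status
        ws.status = new_status
        if ws.status == prev_status:
            return

        scheduler.log_event(
            ws.address,
            {
                "action": "worker-status-change",
                "prev-status": prev_status.name,
                "status": new_status.name,
            },
        )

        if new_status == Status.running:
            scheduler.running.add(ws)
            recommendations = {}
            for ts in scheduler.unrunnable:
                valid = scheduler.valid_workers(ts)
                if valid is None or ws in valid:
                    # The task has somewhere to run again; put it back in "waiting".
                    recommendations[ts.key] = "waiting"
            if recommendations:
                scheduler.transitions(recommendations)
        else:
            # Paused or closing workers are no longer scheduling targets.
            scheduler.running.discard(ws)
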
29 changes: 12 additions & 17 deletions distributed/tests/test_client.py
@@ -3650,51 +3650,46 @@ def test_reconnect(loop):
"9393",
"--no-dashboard",
]
with popen(scheduler_cli) as s:
with popen(scheduler_cli):
c = Client("127.0.0.1:9393", loop=loop)
start = time()
while len(c.nthreads()) != 1:
sleep(0.1)
assert time() < start + 3

c.wait_for_workers(1, timeout=10)
x = c.submit(inc, 1)
assert x.result() == 2
assert x.result(timeout=10) == 2

start = time()
while c.status != "connecting":
assert time() < start + 5
assert time() < start + 10
sleep(0.01)

assert x.status == "cancelled"
with pytest.raises(CancelledError):
x.result()
x.result(timeout=10)

with popen(scheduler_cli) as s:
with popen(scheduler_cli):
start = time()
while c.status != "running":
sleep(0.1)
assert time() < start + 5
assert time() < start + 10
start = time()
while len(c.nthreads()) != 1:
sleep(0.05)
assert time() < start + 15
assert time() < start + 10

x = c.submit(inc, 1)
assert x.result() == 2
assert x.result(timeout=10) == 2

start = time()
while True:
assert time() < start + 10
try:
x.result()
x.result(timeout=10)
assert False
except CommClosedError:
continue
except CancelledError:
break
assert time() < start + 5
sleep(0.1)

sync(loop, w.close)
sync(loop, w.close, timeout=1)
c.close()


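The test changes above only make test_reconnect more tolerant of slow starts and add explicit timeouts; they do not exercise the paused-worker behavior directly. Purely as an illustration, a tiny check against the simplified SchedulerModel from the top of this page (not a distributed test, and not part of this PR) could look like:

    def test_paused_worker_not_a_candidate_sketch():
        s = SchedulerModel()
        s.add_worker("tcp://a:1234", Status.running)
        s.add_worker("tcp://b:1234", Status.running)

        # Pausing "b" removes it from the candidate set but keeps it registered.
        s.worker_status_change("tcp://b:1234", Status.paused)
        assert {ws.address for ws in s.scheduling_candidates()} == {"tcp://a:1234"}
        assert "tcp://b:1234" in s.workers

        # Resuming "b" makes it a candidate again.
        s.worker_status_change("tcp://b:1234", Status.running)
        assert len(s.scheduling_candidates()) == 2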