Merge branch 'AMM/avoid_paused' into AMM/RetireWorker
crusaderky committed Oct 15, 2021
2 parents d05e8aa + d544113 commit 84c08ef
Showing 10 changed files with 426 additions and 167 deletions.
20 changes: 20 additions & 0 deletions distributed/comm/tests/test_ucx.py
@@ -19,6 +19,24 @@
HOST = "127.0.0.1"


def handle_exception(loop, context):
    msg = context.get("exception", context["message"])
    print(msg)


# Let's make sure that UCX gets time to cancel
# progress tasks before closing the event loop.
@pytest.fixture()
def event_loop(scope="function"):
    loop = asyncio.new_event_loop()
    loop.set_exception_handler(handle_exception)
    ucp.reset()
    yield loop
    ucp.reset()
    loop.run_until_complete(asyncio.sleep(0))
    loop.close()


def test_registered():
    assert "ucx" in backends
    backend = get_backend("ucx")
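A brief usage note, assuming the standard pytest-asyncio override mechanism: because the fixture above is named `event_loop`, pytest-asyncio hands that loop to the async tests in this module, so each test gets the `ucp.reset()` bracketing and the delayed close. A minimal, hypothetical illustration (the test name is invented):

    import asyncio

    import pytest


    @pytest.mark.asyncio
    async def test_runs_on_patched_loop():  # hypothetical test, for illustration only
        # Executes on the loop produced by the event_loop fixture above, so pending
        # UCX progress tasks get a chance to cancel before the loop closes.
        await asyncio.sleep(0)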
@@ -123,6 +141,8 @@ async def client_communicate(key, delay=0):
        await asyncio.gather(*futures)
        assert set(l) == {1234} | set(range(N))

        listener.stop()

    asyncio.run(f())


1 change: 1 addition & 0 deletions distributed/deploy/local.py
@@ -198,6 +198,7 @@ def __init__(

        worker_kwargs.update(
            {
                "host": host,
                "nthreads": threads_per_worker,
                "services": worker_services,
                "dashboard_address": worker_dashboard_address,
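For context, a hedged usage sketch of what this one-line forwarding enables: with `host` included in `worker_kwargs`, the workers should bind to the same interface as the scheduler (ports below are chosen by the OS and shown as placeholders):

    from distributed import LocalCluster

    # A minimal sketch, assuming an in-process (threaded) cluster.
    cluster = LocalCluster(host="127.0.0.1", n_workers=1, processes=False)
    print(cluster.scheduler_address)                      # tcp://127.0.0.1:<port>
    print(next(iter(cluster.workers.values())).address)   # tcp://127.0.0.1:<port>
    cluster.close()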
18 changes: 18 additions & 0 deletions distributed/deploy/tests/test_local.py
@@ -7,6 +7,7 @@
from distutils.version import LooseVersion
from threading import Lock
from time import sleep
from urllib.parse import urlparse

import pytest
import tornado
@@ -1100,3 +1101,20 @@ async def test_cluster_info_sync():

    info = cluster.scheduler.get_metadata(keys=["cluster-manager-info"])
    assert info["foo"] == "bar"


@pytest.mark.asyncio
@pytest.mark.parametrize("host", [None, "127.0.0.1"])
@pytest.mark.parametrize("use_nanny", [True, False])
async def test_cluster_host_used_throughout_cluster(host, use_nanny):
    """Ensure that the `host` kwarg is propagated through scheduler, nanny, and workers"""
    async with LocalCluster(host=host, asynchronous=True) as cluster:
        url = urlparse(cluster.scheduler_address)
        assert url.hostname == "127.0.0.1"
        for worker in cluster.workers.values():
            url = urlparse(worker.address)
            assert url.hostname == "127.0.0.1"

            if use_nanny:
                url = urlparse(worker.process.worker_address)
                assert url.hostname == "127.0.0.1"
2 changes: 2 additions & 0 deletions distributed/http/static/css/base.css
@@ -124,6 +124,8 @@ body {
    min-width: 160px;
    box-shadow: 0px 8px 16px 0px rgba(0,0,0,0.2);
    z-index: 2;
    max-height: 90%;
    overflow-y: scroll;
}

.dropdown-content ul li {
154 changes: 78 additions & 76 deletions distributed/scheduler.py
@@ -399,10 +399,15 @@ class WorkerState:
    .. attribute:: processing: {TaskState: cost}

       A dictionary of tasks that have been submitted to this worker.
       Each task state is asssociated with the expected cost in seconds
       Each task state is associated with the expected cost in seconds
       of running that task, summing both the task's expected computation
       time and the expected communication time of its result.

       If a task is already executing on the worker and the execution time is
       twice the learned average TaskGroup duration, this will be set to twice
       the current executing time. If the task is unknown, the default task
       duration is used instead of the TaskGroup average.

       Multiple tasks may be submitted to a worker in advance and the worker
       will run them eventually, depending on its execution resources
       (but see :doc:`work-stealing`).
@@ -904,6 +909,18 @@ def name(self) -> str:
    def all_durations(self) -> "defaultdict[str, float]":
        return self._all_durations

    @ccall
    @exceptval(check=False)
    def add_duration(self, action: str, start: double, stop: double):
        duration = stop - start
        self._all_durations[action] += duration
        if action == "compute":
            old = self._duration_average
            if old < 0:
                self._duration_average = duration
            else:
                self._duration_average = 0.5 * duration + 0.5 * old

    @property
    def duration_average(self) -> double:
        return self._duration_average
@@ -1066,6 +1083,18 @@ def nbytes_total(self):
    def duration(self) -> double:
        return self._duration

    @ccall
    @exceptval(check=False)
    def add_duration(self, action: str, start: double, stop: double):
        duration = stop - start
        self._all_durations[action] += duration
        if action == "compute":
            if self._stop < stop:
                self._stop = stop
            self._start = self._start or start
            self._duration += duration
        self._prefix.add_duration(action, start, stop)

    @property
    def types(self) -> set:
        return self._types
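The two `add_duration` methods above share one rule for the learned per-prefix average: the first observed compute duration seeds the average, and every later one is folded in with a 0.5/0.5 exponential moving average, while the group additionally tracks its start/stop window and total duration and forwards every record to its prefix. A standalone sketch of that averaging rule (the function name is illustrative, not part of the scheduler classes):

    def update_duration_average(old_average: float, duration: float) -> float:
        """Fold a newly observed compute duration into the learned average.

        A negative old_average means "no data yet", mirroring the -1 sentinel
        used by the scheduler state above.
        """
        if old_average < 0:
            return duration
        return 0.5 * duration + 0.5 * old_average


    # Example: the average reacts quickly to the most recent measurements.
    avg = -1.0
    for observed in (2.0, 4.0, 4.0):
        avg = update_duration_average(avg, observed)
    # avg == 0.5 * 4.0 + 0.5 * (0.5 * 4.0 + 0.5 * 2.0) == 3.5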
@@ -2582,6 +2611,8 @@ def set_duration_estimate(self, ts: TaskState, ws: WorkerState) -> double:
        If a task takes longer than twice the current average duration, we
        estimate the task duration to be 2x current-runtime, otherwise we set it
        to be the average duration.

        See also ``_remove_from_processing``
        """
        exec_time: double = ws._executing.get(ts, 0)
        duration: double = self.get_task_duration(ts)
@@ -2591,7 +2622,11 @@ def set_duration_estimate(self, ts: TaskState, ws: WorkerState) -> double:
        else:
            comm: double = self.get_comm_cost(ts, ws)
            total_duration = duration + comm
        old = ws._processing.get(ts, 0)
        ws._processing[ts] = total_duration
        self._total_occupancy += total_duration - old
        ws._occupancy += total_duration - old

        return total_duration

    def transition_waiting_processing(self, key):
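Together with its docstring, `set_duration_estimate` above now describes the full rule: a task that has already been executing for more than twice the learned average is estimated at twice its current runtime, otherwise at the learned average plus the expected communication cost; the method then stores the estimate in `ws._processing[ts]` and applies only the delta against the previous estimate to worker and total occupancy, which is why the transitions below no longer adjust occupancy themselves. A simplified, self-contained sketch of that logic (plain floats and dicts, not the Cython-annotated scheduler state):

    def estimate_duration(avg: float, comm: float, exec_time: float) -> float:
        # Long-running outlier: assume it needs about as long again as it has
        # already been running.
        if exec_time > 2 * avg:
            return 2 * exec_time
        return avg + comm


    def apply_estimate(processing: dict, occupancy: float, task: str, estimate: float) -> float:
        # Only the change relative to the previous estimate is added, so repeated
        # re-estimation never double-counts a task in the occupancy totals.
        old = processing.get(task, 0.0)
        processing[task] = estimate
        return occupancy + (estimate - old)


    processing: dict = {}
    occupancy = apply_estimate(processing, 0.0, "x", estimate_duration(1.0, 0.2, 0.0))        # 1.2
    occupancy = apply_estimate(processing, occupancy, "x", estimate_duration(1.0, 0.2, 3.0))  # 6.0, not 7.2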
@@ -2616,10 +2651,8 @@ def transition_waiting_processing(self, key):
                return recommendations, client_msgs, worker_msgs
            worker = ws._address

            duration_estimate = self.set_duration_estimate(ts, ws)
            self.set_duration_estimate(ts, ws)
            ts._processing_on = ws
            ws._occupancy += duration_estimate
            self._total_occupancy += duration_estimate
            ts.state = "processing"
            self.consume_resources(ts, ws)
            self.check_idle_saturated(ws)
@@ -2698,7 +2731,6 @@ def transition_processing_memory(
        worker_msgs: dict = {}
        try:
            ts: TaskState = self._tasks[key]
            tg: TaskGroup = ts._group

            assert worker
            assert isinstance(worker, str)
@@ -2733,57 +2765,26 @@
                    }
                ]

            has_compute_startstop: bool = False
            compute_start: double
            compute_stop: double
            #############################
            # Update Timing Information #
            #############################
            if startstops:
                startstop: dict
                for startstop in startstops:
                    stop = startstop["stop"]
                    start = startstop["start"]
                    action = startstop["action"]
                    if not has_compute_startstop and action == "compute":
                        compute_start = start
                        compute_stop = stop
                        has_compute_startstop = True

                    # record timings of all actions -- a cheaper way of
                    # getting timing info compared with get_task_stream()
                    ts._prefix._all_durations[action] += stop - start
                    tg._all_durations[action] += stop - start
                    ts._group.add_duration(
                        stop=startstop["stop"],
                        start=startstop["start"],
                        action=startstop["action"],
                    )

            #############################
            # Update Timing Information #
            #############################
            if has_compute_startstop and ws._processing.get(ts, True):
                # Update average task duration for worker
                old_duration: double = ts._prefix._duration_average
                new_duration: double = compute_stop - compute_start
                avg_duration: double
                if old_duration < 0:
                    avg_duration = new_duration
                else:
                    avg_duration = 0.5 * old_duration + 0.5 * new_duration

                ts._prefix._duration_average = avg_duration
                tg._duration += new_duration
                tg._start = tg._start or compute_start
                if tg._stop < compute_stop:
                    tg._stop = compute_stop

                s: set = self._unknown_durations.pop(ts._prefix._name, None)
                tts: TaskState
                if s:
                    for tts in s:
                        if tts._processing_on is not None:
                            wws = tts._processing_on
                            comm: double = self.get_comm_cost(tts, wws)
                            old: double = wws._processing[tts]
                            new: double = avg_duration + comm
                            diff: double = new - old
                            wws._processing[tts] = new
                            wws._occupancy += diff
                            self._total_occupancy += diff
            s: set = self._unknown_durations.pop(ts._prefix._name, set())
            tts: TaskState
            steal = self.extensions.get("stealing")
            for tts in s:
                if tts._processing_on:
                    self.set_duration_estimate(tts, tts._processing_on)
                    if steal:
                        steal.put_key_in_stealable(tts)

            ############################
            # Update State Information #
@@ -3328,10 +3329,14 @@ def get_comm_cost(self, ts: TaskState, ws: WorkerState) -> double:
        return nbytes / self._bandwidth

    @ccall
    def get_task_duration(self, ts: TaskState, default: double = -1) -> double:
        """
        Get the estimated computation cost of the given task
        (not including any communication cost).
    def get_task_duration(self, ts: TaskState) -> double:
        """Get the estimated computation cost of the given task (not including
        any communication cost).

        If no data has been observed, the value of
        `distributed.scheduler.default-task-durations` is used. If none is set
        for this task, `distributed.scheduler.unknown-task-duration` is used
        instead.
        """
        duration: double = ts._prefix._duration_average
        if duration >= 0:
@@ -3341,7 +3346,7 @@ def get_task_duration(self, ts: TaskState, default: double = -1) -> double:
        if s is None:
            self._unknown_durations[ts._prefix._name] = s = set()
        s.add(ts)
        return default if default >= 0 else self.UNKNOWN_TASK_DURATION
        return self.UNKNOWN_TASK_DURATION

    @ccall
    @exceptval(check=False)
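A hedged sketch of the fallback order the new docstring above describes, using the dask config keys it names (`distributed.scheduler.default-task-durations` and `distributed.scheduler.unknown-task-duration`); this illustrates the lookup, not the scheduler's exact code path:

    import dask
    from dask.utils import parse_timedelta


    def fallback_task_duration(prefix_name: str) -> float:
        """Return a duration estimate (in seconds) for a prefix with no observed data."""
        defaults = dask.config.get("distributed.scheduler.default-task-durations")
        if prefix_name in defaults:
            return parse_timedelta(defaults[prefix_name])
        return parse_timedelta(dask.config.get("distributed.scheduler.unknown-task-duration"))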
@@ -7091,8 +7096,12 @@ def get_metadata(self, comm=None, keys=None, default=no_default):
                raise

    def set_restrictions(self, comm=None, worker=None):
        ts: TaskState
        for key, restrictions in worker.items():
            self.tasks[key]._worker_restrictions = set(restrictions)
            ts = self.tasks[key]
            if isinstance(restrictions, str):
                restrictions = {restrictions}
            ts._worker_restrictions = set(restrictions)

    def get_task_status(self, comm=None, keys=None):
        parent: SchedulerState = cast(SchedulerState, self)
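A short usage sketch for the `set_restrictions` change above: either a single address string or a collection of addresses should now be accepted per key. The helper, task keys, and worker addresses below are hypothetical; in practice this handler is reached through the client/RPC layer:

    from distributed.scheduler import Scheduler


    def pin_tasks(scheduler: Scheduler) -> None:
        # A bare address string per key is now accepted ...
        scheduler.set_restrictions(worker={"task-1": "tcp://10.0.0.1:40000"})
        # ... as well as a collection of addresses.
        scheduler.set_restrictions(
            worker={"task-2": ["tcp://10.0.0.1:40000", "tcp://10.0.0.2:40000"]}
        )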
@@ -7653,7 +7662,6 @@ def reevaluate_occupancy(self, worker_index: Py_ssize_t = 0):
        try:
            if self.status == Status.closed:
                return

            last = time()
            next_time = timedelta(seconds=0.1)

@@ -7780,6 +7788,8 @@ def _remove_from_processing(
) -> str:  # -> str | None
    """
    Remove *ts* from the set of processing tasks.

    See also ``Scheduler.set_duration_estimate``
    """
    ws: WorkerState = ts._processing_on
    ts._processing_on = None  # type: ignore
@@ -7937,6 +7947,7 @@ def _task_to_msg(state: SchedulerState, ts: TaskState, duration: double = -1) ->
    ws: WorkerState
    dts: TaskState

    # FIXME: The duration attribute is not used on worker. We could save ourselves the time to compute and submit this
    if duration < 0:
        duration = state.get_task_duration(ts)

@@ -8009,27 +8020,18 @@ def _task_to_client_msgs(state: SchedulerState, ts: TaskState) -> dict:
@exceptval(check=False)
def _reevaluate_occupancy_worker(state: SchedulerState, ws: WorkerState):
    """See reevaluate_occupancy"""
    old: double = ws._occupancy
    new: double = 0
    diff: double
    ts: TaskState
    est: double
    old = ws._occupancy
    for ts in ws._processing:
        est = state.set_duration_estimate(ts, ws)
        new += est
        state.set_duration_estimate(ts, ws)

    ws._occupancy = new
    diff = new - old
    state._total_occupancy += diff
    state.check_idle_saturated(ws)

    # significant increase in duration
    if new > old * 1.3:
        steal = state._extensions.get("stealing")
        if steal is not None:
            for ts in ws._processing:
                steal.remove_key_from_stealable(ts)
                steal.put_key_in_stealable(ts)
    steal = state.extensions.get("stealing")
    if not steal:
        return
    if ws._occupancy > old * 1.3 or old > ws._occupancy * 1.3:
        for ts in ws._processing:
            steal.recalculate_cost(ts)


@cfunc
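A minimal sketch of the rewritten trigger in `_reevaluate_occupancy_worker` above, under the assumption that "significant" means old and new occupancy differ by more than a factor of 1.3 in either direction (plain floats, no scheduler state):

    def occupancy_changed_significantly(old: float, new: float, factor: float = 1.3) -> bool:
        """True when occupancy grew or shrank by more than the given factor."""
        return new > old * factor or old > new * factor


    assert occupancy_changed_significantly(10.0, 14.0)       # grew by a factor of 1.4
    assert occupancy_changed_significantly(14.0, 10.0)       # shrank by a factor of 1.4
    assert not occupancy_changed_significantly(10.0, 12.0)   # within the 1.3 band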
