Skip to content

Commit

Permalink
Ensure occupancy tracking works as expected for long running tasks (#…
Browse files Browse the repository at this point in the history
…6351)

Co-authored-by: Ed Younis <edyounis123@gmail.com>
  • Loading branch information
fjetter and edyounis committed May 19, 2022
1 parent ff94776 commit 4d29246
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 24 deletions.
25 changes: 14 additions & 11 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1847,7 +1847,7 @@ def transition_waiting_processing(self, key, stimulus_id):
return recommendations, client_msgs, worker_msgs
worker = ws.address

self.set_duration_estimate(ts, ws)
self._set_duration_estimate(ts, ws)
ts.processing_on = ws
ts.state = "processing"
self.consume_resources(ts, ws)
Expand Down Expand Up @@ -1986,7 +1986,7 @@ def transition_processing_memory(
steal = self.extensions.get("stealing")
for tts in s:
if tts.processing_on:
self.set_duration_estimate(tts, tts.processing_on)
self._set_duration_estimate(tts, tts.processing_on)
if steal:
steal.recalculate_cost(tts)

Expand Down Expand Up @@ -2509,7 +2509,7 @@ def transition_released_forgotten(self, key, stimulus_id):
# Assigning Tasks to Workers #
##############################

def set_duration_estimate(self, ts: TaskState, ws: WorkerState) -> float:
def _set_duration_estimate(self, ts: TaskState, ws: WorkerState) -> None:
"""Estimate task duration using worker state and task state.
If a task takes longer than twice the current average duration we
Expand All @@ -2518,6 +2518,11 @@ def set_duration_estimate(self, ts: TaskState, ws: WorkerState) -> float:
See also ``_remove_from_processing``
"""
# Long running tasks do not contribute to occupancy calculations and we
# do not set any task duration estimates
if ts in ws.long_running:
return

exec_time: float = ws.executing.get(ts, 0)
duration: float = self.get_task_duration(ts)
total_duration: float
Expand All @@ -2526,14 +2531,11 @@ def set_duration_estimate(self, ts: TaskState, ws: WorkerState) -> float:
else:
comm: float = self.get_comm_cost(ts, ws)
total_duration = duration + comm

old = ws.processing.get(ts, 0)
ws.processing[ts] = total_duration

if ts not in ws.long_running:
self.total_occupancy += total_duration - old
ws.occupancy += total_duration - old

return total_duration
self.total_occupancy += total_duration - old
ws.occupancy += total_duration - old

def check_idle_saturated(self, ws: WorkerState, occ: float = -1.0):
"""Update the status of the idle and saturated state
Expand Down Expand Up @@ -2745,7 +2747,7 @@ def _reevaluate_occupancy_worker(self, ws: WorkerState):
ts: TaskState
old = ws.occupancy
for ts in ws.processing:
self.set_duration_estimate(ts, ws)
self._set_duration_estimate(ts, ws)

self.check_idle_saturated(ws)
steal = self.extensions.get("stealing")
Expand Down Expand Up @@ -7178,7 +7180,7 @@ def _remove_from_processing(state: SchedulerState, ts: TaskState) -> str | None:
See also
--------
Scheduler.set_duration_estimate
Scheduler._set_duration_estimate
"""
ws = ts.processing_on
assert ws
Expand All @@ -7188,6 +7190,7 @@ def _remove_from_processing(state: SchedulerState, ts: TaskState) -> str | None:
return None

duration = ws.processing.pop(ts)
ws.long_running.discard(ts)
if not ws.processing:
state.total_occupancy -= ws.occupancy
ws.occupancy = 0
Expand Down
90 changes: 77 additions & 13 deletions distributed/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import zipfile
from collections import deque
from collections.abc import Generator
from contextlib import contextmanager
from contextlib import contextmanager, nullcontext
from functools import partial
from operator import add
from threading import Semaphore
Expand All @@ -40,6 +40,7 @@
CancelledError,
Event,
LocalCluster,
Lock,
Nanny,
TimeoutError,
Worker,
Expand Down Expand Up @@ -5115,41 +5116,104 @@ def f(x):
assert results == [sum(map(inc, range(10)))] * 10


@pytest.mark.parametrize("raise_exception", [True, False])
@gen_cluster(client=True, nthreads=[("", 1)])
async def test_long_running_not_in_occupancy(c, s, a):
async def test_long_running_not_in_occupancy(c, s, a, raise_exception):
# https://github.com/dask/distributed/issues/5332
from distributed import Lock
# See also test_long_running_removal_clean

l = Lock()
entered = Event()
await l.acquire()

def long_running(lock):
sleep(0.1)
def long_running(lock, entered):
entered.set()
secede()
lock.acquire()
if raise_exception:
raise RuntimeError("Exception in task")

f = c.submit(long_running, l)
while f.key not in s.tasks:
await asyncio.sleep(0.01)
assert s.workers[a.address].occupancy == parse_timedelta(
f = c.submit(long_running, l, entered)
await entered.wait()
ts = s.tasks[f.key]
ws = s.workers[a.address]
assert ws.occupancy == parse_timedelta(
dask.config.get("distributed.scheduler.unknown-task-duration")
)

while s.workers[a.address].occupancy:
while ws.occupancy:
await asyncio.sleep(0.01)
await a.heartbeat()

ts = s.tasks[f.key]
ws = s.workers[a.address]
s.set_duration_estimate(ts, ws)
s._set_duration_estimate(ts, ws)
assert s.workers[a.address].occupancy == 0
assert s.total_occupancy == 0
assert ws.occupancy == 0

s.reevaluate_occupancy(0)
assert s.workers[a.address].occupancy == 0
await l.release()

with (
pytest.raises(RuntimeError, match="Exception in task")
if raise_exception
else nullcontext()
):
await f

assert s.total_occupancy == 0
assert ws.occupancy == 0
assert not ws.long_running


@pytest.mark.parametrize("ordinary_task", [True, False])
@gen_cluster(client=True, nthreads=[("", 1)])
async def test_long_running_removal_clean(c, s, a, ordinary_task):
# https://github.com/dask/distributed/issues/5975 which could reduce
# occupancy to negative values upon finishing long running tasks

# See also test_long_running_not_in_occupancy

l = Lock()
entered = Event()
l2 = Lock()
entered2 = Event()
await l.acquire()
await l2.acquire()

def long_running_secede(lock, entered):
entered.set()
secede()
lock.acquire()

def long_running(lock, entered):
entered.set()
lock.acquire()

f = c.submit(long_running_secede, l, entered)
await entered.wait()

if ordinary_task:
f2 = c.submit(long_running, l2, entered2)
await entered2.wait()
await l.release()
await f

ws = s.workers[a.address]

if ordinary_task:
# Should be exactly 0.5 but if for whatever reason this test runs slow,
# some approximation may kick in increasing this number
assert s.total_occupancy >= 0.5
assert ws.occupancy >= 0.5
await l2.release()
await f2

# In the end, everything should be reset
assert s.total_occupancy == 0
assert ws.occupancy == 0
assert not ws.long_running


@gen_cluster(client=True)
async def test_sub_submit_priority(c, s, a, b):
Expand Down

0 comments on commit 4d29246

Please sign in to comment.