This repository has been archived by the owner on Nov 1, 2023. It is now read-only.

Colocate tasks #402

Merged
20 commits merged on Jan 6, 2021
Changes from 14 commits
8 changes: 8 additions & 0 deletions docs/webhook_events.md
@@ -405,6 +405,10 @@ Each event will be submitted via HTTP POST to the user provided URL.
"items": {
"$ref": "#/definitions/TaskDebugFlag"
}
},
"colocate": {
"title": "Colocate",
"type": "boolean"
}
},
"required": [
@@ -1028,6 +1032,10 @@ Each event will be submitted via HTTP POST to the user provided URL.
"items": {
"$ref": "#/definitions/TaskDebugFlag"
}
},
"colocate": {
"title": "Colocate",
"type": "boolean"
}
},
"required": [
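For context, here is a hypothetical slice of a task config as it might appear in a webhook event payload once the new field is present; only "colocate" comes from this change, every other key and value below is illustrative:

# Hypothetical fragment of a task config inside a webhook event payload;
# values are made up, only "colocate" is the field added by this PR.
task_config_fragment = {
    "job_id": "00000000-0000-0000-0000-000000000000",
    "debug": ["keep_node_on_failure"],
    # new: opt this task into sharing a node with compatible tasks
    "colocate": True,
}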
2 changes: 1 addition & 1 deletion src/api-service/__app__/onefuzzlib/agent_events.py
@@ -130,6 +130,7 @@ def on_state_update(
# if tasks are running on the node when it reports as Done
# those are stopped early
node.mark_tasks_stopped_early()
node.to_reimage(done=True)

# Model-validated.
#
@@ -242,7 +243,6 @@ def on_worker_event_done(machine_id: UUID, event: WorkerDoneEvent) -> Result[Non
node.debug_keep_node = True
node.save()

node.to_reimage(done=True)
task_event = TaskEvent(
task_id=task.task_id, machine_id=machine_id, event_data=WorkerEvent(done=event)
)
2 changes: 1 addition & 1 deletion src/api-service/__app__/onefuzzlib/tasks/main.py
@@ -28,7 +28,7 @@


class Task(BASE_TASK, ORMMixin):
def ready_to_schedule(self) -> bool:
def check_prereq_tasks(self) -> bool:
if self.config.prereq_tasks:
for task_id in self.config.prereq_tasks:
task = Task.get_by_task_id(task_id)
280 changes: 198 additions & 82 deletions src/api-service/__app__/onefuzzlib/tasks/scheduler.py
@@ -4,24 +4,30 @@
# Licensed under the MIT License.

import logging
from typing import Dict, List
from uuid import UUID
from typing import Dict, Generator, List, Optional, Tuple, TypeVar
from uuid import UUID, uuid4

from onefuzztypes.enums import OS, PoolState, TaskState
from onefuzztypes.models import WorkSet, WorkUnit

from ..azure.containers import (
StorageType,
blob_exists,
get_container_sas_url,
save_blob,
)
from onefuzztypes.models import TaskPool, TaskVm, WorkSet, WorkUnit
from pydantic import BaseModel

from ..azure.containers import StorageType, blob_exists, get_container_sas_url
from ..pools import Pool
from .config import build_task_config, get_setup_container
from .main import Task

HOURS = 60 * 60

# TODO: eventually, this should be tied to the pool.
MAX_TASKS_PER_SET = 10


A = TypeVar("A")


def chunks(items: List[A], size: int) -> Generator[List[A], None, None]:
return (items[x : x + size] for x in range(0, len(items), size))
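# example (illustrative): chunks([1, 2, 3, 4, 5], 2) yields [1, 2], [3, 4], [5]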


def schedule_workset(workset: WorkSet, pool: Pool, count: int) -> bool:
if pool.state not in PoolState.available():
@@ -39,88 +45,198 @@ def schedule_workset(workset: WorkSet, pool: Pool, count: int) -> bool:
return True


def schedule_tasks() -> None:
to_schedule: Dict[UUID, List[Task]] = {}
# TODO - Once Pydantic supports hashable models, the Tuple should be replaced
# with a model.
#
# For info: https://github.com/samuelcolvin/pydantic/pull/1881


def bucket_tasks(tasks: List[Task]) -> Dict[Tuple, List[Task]]:
# buckets are hashed by:
# OS, JOB ID, vm sku & image (if available), pool name (if available),
# if the setup script requires rebooting, and a 'unique' value
#
# The unique value is set based on the following conditions:
# * if the task is set to run on more than one VM, then we assume it can't be shared
# * if the task is missing the 'colocate' flag or it's set to False
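#
# Example (illustrative): two single-VM tasks from the same job and pool, with
# the same setup container and colocate set to True, hash to the same key and
# can share a work set; a task with colocate unset, or with a VM/pool count
# above one, gets a uuid4() in its key and always lands in a bucket of its own.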

buckets: Dict[Tuple, List[Task]] = {}

for task in tasks:
vm: Optional[Tuple[str, str]] = None
pool: Optional[str] = None
unique: Optional[UUID] = None

# check for multiple VMs for pre-1.0.0 tasks
if task.config.vm:
vm = (task.config.vm.sku, task.config.vm.image)
if task.config.vm.count > 1:
unique = uuid4()

# check for multiple VMs for 1.0.0 and later tasks
if task.config.pool:
pool = task.config.pool.pool_name
if task.config.pool.count > 1:
unique = uuid4()

if not task.config.colocate:
unique = uuid4()

key = (
task.os,
task.job_id,
vm,
pool,
get_setup_container(task.config),
task.config.task.reboot_after_setup,
unique,
)
if key not in buckets:
buckets[key] = []
buckets[key].append(task)

return buckets


class BucketConfig(BaseModel):
count: int
reboot: bool
setup_url: str
setup_script: Optional[str]
pool: Pool


def build_work_unit(task: Task) -> Optional[Tuple[BucketConfig, WorkUnit]]:
pool = task.get_pool()
if not pool:
logging.info("unable to find pool for task: %s", task.task_id)
return None

logging.info("scheduling task: %s", task.task_id)

task_config = build_task_config(task.job_id, task.task_id, task.config)

setup_container = get_setup_container(task.config)
setup_url = get_container_sas_url(
setup_container, StorageType.corpus, read=True, list=True
)

setup_script = None

if task.os == OS.windows and blob_exists(
setup_container, "setup.ps1", StorageType.corpus
):
setup_script = "setup.ps1"
if task.os == OS.linux and blob_exists(
setup_container, "setup.sh", StorageType.corpus
):
setup_script = "setup.sh"

reboot = False
count = 1
if task.config.pool:
count = task.config.pool.count

# NOTE: "is True" is required to handle Optional[bool]
reboot = task.config.task.reboot_after_setup is True
elif task.config.vm:
# this branch should go away when we stop letting people specify
# VM configs directly.
count = task.config.vm.count

# NOTE: "is True" is required to handle Optional[bool]
reboot = (
task.config.vm.reboot_after_setup is True
or task.config.task.reboot_after_setup is True
)
else:
raise TypeError

work_unit = WorkUnit(
job_id=task_config.job_id,
task_id=task_config.task_id,
task_type=task_config.task_type,
config=task_config.json(),
)

bucket_config = BucketConfig(
pool=pool,
count=count,
reboot=reboot,
setup_script=setup_script,
setup_url=setup_url,
)

return bucket_config, work_unit


def build_work_set(tasks: List[Task]) -> Optional[Tuple[BucketConfig, WorkSet]]:
task_ids = [x.task_id for x in tasks]

bucket_config: Optional[BucketConfig] = None
work_units = []

for task in tasks:
if task.config.prereq_tasks:
# if all of the prereqs are in this bucket, they will be
# scheduled together
if not all([task_id in task_ids for task_id in task.config.prereq_tasks]):
if not task.check_prereq_tasks():
continue

result = build_work_unit(task)
if not result:
continue

not_ready_count = 0
new_bucket_config, work_unit = result
if bucket_config is None:
bucket_config = new_bucket_config
else:
if bucket_config != new_bucket_config:
raise Exception(
f"bucket configs differ: {bucket_config} VS {new_bucket_config}"
)

for task in Task.search_states(states=[TaskState.waiting]):
if not task.ready_to_schedule():
not_ready_count += 1
continue
work_units.append(work_unit)

if task.job_id not in to_schedule:
to_schedule[task.job_id] = []
to_schedule[task.job_id].append(task)
if bucket_config:
work_set = WorkSet(
reboot=bucket_config.reboot,
script=(bucket_config.setup_script is not None),
setup_url=bucket_config.setup_url,
work_units=work_units,
)
return (bucket_config, work_set)

if not to_schedule and not_ready_count > 0:
logging.info("tasks not ready: %d", not_ready_count)
return None

for tasks in to_schedule.values():
# TODO: for now, we're only scheduling one task per VM.

for task in tasks:
logging.info("scheduling task: %s", task.task_id)
agent_config = build_task_config(task.job_id, task.task_id, task.config)
def schedule_tasks() -> None:
tasks: List[Task] = []

setup_container = get_setup_container(task.config)
setup_url = get_container_sas_url(
setup_container, StorageType.corpus, read=True, list=True
)
tasks = Task.search_states(states=[TaskState.waiting])

setup_script = None

if task.os == OS.windows and blob_exists(
setup_container, "setup.ps1", StorageType.corpus
):
setup_script = "setup.ps1"
if task.os == OS.linux and blob_exists(
setup_container, "setup.sh", StorageType.corpus
):
setup_script = "setup.sh"

save_blob(
"task-configs",
"%s/config.json" % task.task_id,
agent_config.json(exclude_none=True),
StorageType.config,
)
reboot = False
count = 1
if task.config.pool:
count = task.config.pool.count
reboot = task.config.task.reboot_after_setup is True
elif task.config.vm:
# this branch should go away when we stop letting people specify
# VM configs directly.
count = task.config.vm.count
reboot = (
task.config.vm.reboot_after_setup is True
or task.config.task.reboot_after_setup is True
)
tasks_by_id = {x.task_id: x for x in tasks}
seen = set()

task_config = agent_config
task_config_json = task_config.json()
work_unit = WorkUnit(
job_id=task_config.job_id,
task_id=task_config.task_id,
task_type=task_config.task_type,
config=task_config_json,
)
not_ready_count = 0

# For now, only offer singleton work sets.
workset = WorkSet(
reboot=reboot,
script=(setup_script is not None),
setup_url=setup_url,
work_units=[work_unit],
)
buckets = bucket_tasks(tasks)

pool = task.get_pool()
if not pool:
logging.info("unable to find pool for task: %s", task.task_id)
for bucketed_tasks in buckets.values():
for chunk in chunks(bucketed_tasks, MAX_TASKS_PER_SET):
result = build_work_set(chunk)
if result is None:
continue
bucket_config, work_set = result

if schedule_workset(work_set, bucket_config.pool, bucket_config.count):
for work_unit in work_set.work_units:
task = tasks_by_id[work_unit.task_id]
task.state = TaskState.scheduled
task.save()
seen.add(task.task_id)

if schedule_workset(workset, pool, count):
task.state = TaskState.scheduled
task.save()
not_ready_count = len(tasks) - len(seen)
if not_ready_count > 0:
logging.info("tasks not ready: %d", not_ready_count)
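Taken together, the new flow is: bucket the waiting tasks by compatibility, split each bucket into chunks of at most MAX_TASKS_PER_SET, build one work set per chunk, and offer it to the bucket's pool. The standalone sketch below (not OneFuzz code; all names and values are made up) isolates the effect of the colocate flag on bucketing:

from typing import Optional, Tuple
from uuid import UUID, uuid4

# Simplified stand-in for the key built in bucket_tasks(); only the fields
# relevant to colocation are modeled.
def bucket_key(
    os: str, job_id: str, pool: str, setup: str, reboot: bool, colocate: bool
) -> Tuple[str, str, str, str, bool, Optional[UUID]]:
    # a non-colocated task gets a unique component, so it never shares a bucket
    unique = None if colocate else uuid4()
    return (os, job_id, pool, setup, reboot, unique)

a = bucket_key("linux", "job-1", "pool-1", "fuzz-setup", False, colocate=True)
b = bucket_key("linux", "job-1", "pool-1", "fuzz-setup", False, colocate=True)
c = bucket_key("linux", "job-1", "pool-1", "fuzz-setup", False, colocate=False)

assert a == b  # colocated tasks from the same job share a bucket, hence a work set
assert a != c  # a task without colocate is always scheduled on its own

In the real scheduler, a shared bucket is then split into work sets of at most MAX_TASKS_PER_SET (currently 10) tasks before being handed to schedule_workset.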