From b246a3800eadf7c4acd8ed8b62e90b9ef4df808e Mon Sep 17 00:00:00 2001 From: Hu Po Date: Mon, 15 Aug 2022 08:33:54 -0500 Subject: [PATCH 01/17] cleanup --- wandb/sdk/launch/deploys/Dockerfile | 1 + .../launch/sweeps/Dockerfile.scheduler.sweep | 18 ++++++++ wandb/sdk/launch/sweeps/scheduler.py | 41 +++++++------------ wandb/sdk/launch/sweeps/scheduler_sweep.py | 1 + 4 files changed, 35 insertions(+), 26 deletions(-) create mode 100644 wandb/sdk/launch/sweeps/Dockerfile.scheduler.sweep diff --git a/wandb/sdk/launch/deploys/Dockerfile b/wandb/sdk/launch/deploys/Dockerfile index df9db592ab9..350ce059695 100644 --- a/wandb/sdk/launch/deploys/Dockerfile +++ b/wandb/sdk/launch/deploys/Dockerfile @@ -1,4 +1,5 @@ FROM python:3.9-slim-bullseye +LABEL maintainer='Weights & Biases ' # install git RUN apt-get update && apt-get upgrade -y \ diff --git a/wandb/sdk/launch/sweeps/Dockerfile.scheduler.sweep b/wandb/sdk/launch/sweeps/Dockerfile.scheduler.sweep new file mode 100644 index 00000000000..f7570b02f34 --- /dev/null +++ b/wandb/sdk/launch/sweeps/Dockerfile.scheduler.sweep @@ -0,0 +1,18 @@ +FROM python:3.9-slim-bullseye +LABEL maintainer='Weights & Biases ' + +# install git +RUN apt-get update && apt-get upgrade -y \ + && apt-get install -y git \ + && apt-get -qy autoremove \ + && apt-get clean && rm -r /var/lib/apt/lists/* + +# required pip packages +RUN pip install --no-cache-dir wandb[launch] +# user set up +RUN useradd -m -s /bin/bash --create-home --no-log-init -u 1000 -g 0 wandb_scheduler +USER wandb_scheduler +WORKDIR /home/wandb_scheduler +RUN chown -R wandb_scheduler /home/wandb_scheduler + +ENTRYPOINT ["wandb", "scheduler"] diff --git a/wandb/sdk/launch/sweeps/scheduler.py b/wandb/sdk/launch/sweeps/scheduler.py index 8acb21d761f..b68e7365be7 100644 --- a/wandb/sdk/launch/sweeps/scheduler.py +++ b/wandb/sdk/launch/sweeps/scheduler.py @@ -1,3 +1,4 @@ +"""Abstract Scheduler class.""" from abc import ABC, abstractmethod from dataclasses import dataclass from enum import Enum @@ -126,9 +127,7 @@ def start(self) -> None: self.run() def run(self) -> None: - _msg = f"{LOG_PREFIX}Scheduler Running." - logger.debug(_msg) - wandb.termlog(_msg) + wandb.termlog(f"{LOG_PREFIX}Scheduler Running.") self.state = SchedulerState.RUNNING try: while True: @@ -138,27 +137,19 @@ def run(self) -> None: self._update_run_states() self._run() except RuntimeError as e: - _msg = f"{LOG_PREFIX}Scheduler encountered Runtime Error. {e} Trying again." - logger.debug(_msg) - wandb.termlog(_msg) + wandb.termlog(f"{LOG_PREFIX}Scheduler encountered Runtime Error. {e} Trying again.") except KeyboardInterrupt: - _msg = f"{LOG_PREFIX}Scheduler received KeyboardInterrupt. Exiting." - logger.debug(_msg) - wandb.termlog(_msg) + wandb.termlog(f"{LOG_PREFIX}Scheduler received KeyboardInterrupt. Exiting.") self.state = SchedulerState.STOPPED self.exit() return except Exception as e: - _msg = f"{LOG_PREFIX}Scheduler failed with exception {e}" - logger.debug(_msg) - wandb.termlog(_msg) + wandb.termlog(f"{LOG_PREFIX}Scheduler failed with exception {e}") self.state = SchedulerState.FAILED self.exit() raise e else: - _msg = "Scheduler completed." - logger.debug(_msg) - wandb.termlog(_msg) + wandb.termlog(f"{LOG_PREFIX}Scheduler completed.") self.exit() def exit(self) -> None: @@ -176,6 +167,12 @@ def _yield_runs(self) -> Iterator[Tuple[str, SweepRun]]: with self._threading_lock: yield from self._runs.items() + def _stop_run(self, run_id: str) -> None: + wandb.termlog(f"{LOG_PREFIX}Stopping run {run_id}.") + run = self._runs.get(run_id, None) + if run is not None: + run.state = SimpleRunState.DEAD + def _update_run_states(self) -> None: for run_id, run in self._yield_runs(): try: @@ -220,16 +217,8 @@ def _add_to_launch_queue( resource_args=self._resource_args, run_id=run_id, ) - self._runs[run_id].queued_run = queued_run - _msg = f"{LOG_PREFIX}Added run to Launch RunQueue: {self._launch_queue} RunID:{run_id}." - logger.debug(_msg) - wandb.termlog(_msg) + self._runs[run_id].queued_run: public.QueuedRun = queued_run + wandb.termlog(f"{LOG_PREFIX}Added run to Launch RunQueue: {self._launch_queue} RunID:{run_id}.") return queued_run - def _stop_run(self, run_id: str) -> None: - _msg = f"{LOG_PREFIX}Stopping run {run_id}." - logger.debug(_msg) - wandb.termlog(_msg) - run = self._runs.get(run_id, None) - if run is not None: - run.state = SimpleRunState.DEAD + diff --git a/wandb/sdk/launch/sweeps/scheduler_sweep.py b/wandb/sdk/launch/sweeps/scheduler_sweep.py index c2663ab9a8c..cafa81e33a8 100644 --- a/wandb/sdk/launch/sweeps/scheduler_sweep.py +++ b/wandb/sdk/launch/sweeps/scheduler_sweep.py @@ -1,3 +1,4 @@ +"""Scheduler for classic wandb Sweeps.""" from dataclasses import dataclass import logging import os From 7aca2f99fc26c7bb1131f5d8d45f90aa642be885 Mon Sep 17 00:00:00 2001 From: Hu Po Date: Mon, 15 Aug 2022 09:19:29 -0500 Subject: [PATCH 02/17] more cleanup --- tests/unit_tests/test_sweep_scheduler.py | 2 +- wandb/sdk/launch/sweeps/scheduler.py | 17 +++-- wandb/sdk/launch/sweeps/scheduler_sweep.py | 79 ++++++++++++---------- 3 files changed, 57 insertions(+), 41 deletions(-) diff --git a/tests/unit_tests/test_sweep_scheduler.py b/tests/unit_tests/test_sweep_scheduler.py index 681e9f6c7d5..fe9ce3d9da4 100644 --- a/tests/unit_tests/test_sweep_scheduler.py +++ b/tests/unit_tests/test_sweep_scheduler.py @@ -246,7 +246,7 @@ def mock_get_run_state(*args, **kwargs): assert _scheduler.state == SchedulerState.PENDING assert _scheduler.is_alive() is True _scheduler.start() - for _heartbeat_agent in _scheduler._heartbeat_agents: + for _heartbeat_agent in _scheduler._workers: assert not _heartbeat_agent.thread.is_alive() assert _scheduler.state == SchedulerState.COMPLETED assert len(_scheduler._runs) == 1 diff --git a/wandb/sdk/launch/sweeps/scheduler.py b/wandb/sdk/launch/sweeps/scheduler.py index b68e7365be7..55b1be50109 100644 --- a/wandb/sdk/launch/sweeps/scheduler.py +++ b/wandb/sdk/launch/sweeps/scheduler.py @@ -41,6 +41,9 @@ class SweepRun: args: Optional[Dict[str, Any]] = None logs: Optional[List[str]] = None program: Optional[str] = None + # Threading can be used to run multiple workers in parallel + worker_id: Optional[int] = None + worker_thread: Optional[threading.Thread] = None class Scheduler(ABC): @@ -159,19 +162,21 @@ def exit(self) -> None: SchedulerState.STOPPED, ]: self.state = SchedulerState.FAILED - for run_id, _ in self._yield_runs(): - self._stop_run(run_id) + self._stop_runs() def _yield_runs(self) -> Iterator[Tuple[str, SweepRun]]: """Thread-safe way to iterate over the runs.""" with self._threading_lock: yield from self._runs.items() + def _stop_runs(self) -> None: + for run_id, _ in self._yield_runs(): + wandb.termlog(f"{LOG_PREFIX}Stopping run {run_id}.") + self._stop_run(run_id) + + @abstractmethod def _stop_run(self, run_id: str) -> None: - wandb.termlog(f"{LOG_PREFIX}Stopping run {run_id}.") - run = self._runs.get(run_id, None) - if run is not None: - run.state = SimpleRunState.DEAD + pass def _update_run_states(self) -> None: for run_id, run in self._yield_runs(): diff --git a/wandb/sdk/launch/sweeps/scheduler_sweep.py b/wandb/sdk/launch/sweeps/scheduler_sweep.py index cafa81e33a8..f7e28c2814c 100644 --- a/wandb/sdk/launch/sweeps/scheduler_sweep.py +++ b/wandb/sdk/launch/sweeps/scheduler_sweep.py @@ -7,7 +7,7 @@ import socket import threading import time -from typing import Any, List, Optional +from typing import Any, Dict, List, Optional import wandb from wandb import wandb_lib # type: ignore @@ -25,10 +25,11 @@ @dataclass -class HeartbeatAgent: - agent: dict - id: str +class _Worker: + agent_config: Dict[str, Any] + agent_id: str thread: threading.Thread + stop: threading.Event class SweepScheduler(Scheduler): @@ -41,9 +42,9 @@ def __init__( *args: Any, sweep_id: Optional[str] = None, num_workers: int = 4, - heartbeat_thread_sleep: float = 0.5, + worker_sleep: float = 0.5, heartbeat_queue_timeout: float = 1, - main_thread_sleep: float = 1, + heartbeat_queue_sleep: float = 1, **kwargs: Any, ): super().__init__(*args, **kwargs) @@ -57,45 +58,52 @@ def __init__( ) self._sweep_id = sweep_id # Threading is used to run multiple workers in parallel + self._workers: List[_Worker] = [] self._num_workers: int = num_workers - self._heartbeat_thread_sleep: float = heartbeat_thread_sleep - self._heartbeat_queue_timeout: float = heartbeat_queue_timeout - self._main_thread_sleep: float = main_thread_sleep + self._worker_sleep: float = worker_sleep # Thread will pop items off the Sweeps RunQueue using AgentHeartbeat # and put them in this internal queue, which will be used to populate # the Launch RunQueue self._heartbeat_queue: "queue.Queue[SweepRun]" = queue.Queue() - # Emulation of N agents in a classic sweeps setup - self._heartbeat_agents: List[HeartbeatAgent] = [] + self._heartbeat_queue_timeout: float = heartbeat_queue_timeout + self._heartbeat_queue_sleep: float = heartbeat_queue_sleep def _start(self) -> None: - for worker_idx in range(self._num_workers): - logger.debug(f"{LOG_PREFIX}Starting AgentHeartbeat worker {worker_idx}\n") - _agent = self._api.register_agent( - f"{socket.gethostname()}-{worker_idx}", # host + for worker_id in range(self._num_workers): + logger.debug(f"{LOG_PREFIX}Starting AgentHeartbeat worker {worker_id}\n") + agent_config = self._api.register_agent( + f"{socket.gethostname()}-{worker_id}", # host sweep_id=self._sweep_id, project_name=self._project, entity=self._entity, ) - _thread = threading.Thread(target=self._heartbeat, args=[worker_idx]) + # Worker threads call heartbeat function + _thread = threading.Thread(target=self._heartbeat, args=[worker_id]) _thread.daemon = True - self._heartbeat_agents.append( - HeartbeatAgent( - agent=_agent, - id=_agent["id"], + self._workers.append( + _Worker( + agent_config=agent_config, + agent_id=agent_config["id"], thread=_thread, + # Worker threads will be killed with an Event + stop = threading.Event(), ) ) _thread.start() - def _heartbeat(self, worker_idx: int) -> None: + def _heartbeat(self, worker_id: int) -> None: while True: + # Make sure Scheduler is alive if not self.is_alive(): return + # Check to see if worker thread has been orderred to stop + if self._workers[worker_id].stop.is_set(): + return # AgentHeartbeat wants dict of runs which are running or queued _run_states = {} for run_id, run in self._yield_runs(): - if run.state == SimpleRunState.ALIVE: + # Filter out runs that are from a different worker thread + if run.worker_id == worker_id and run.state == SimpleRunState.ALIVE: _run_states[run_id] = True _msg = ( f"{LOG_PREFIX}AgentHeartbeat sending: \n{pprint.pformat(_run_states)}\n" @@ -103,7 +111,7 @@ def _heartbeat(self, worker_idx: int) -> None: logger.debug(_msg) # TODO(hupo): Should be sub-set of _run_states specific to worker thread commands = self._api.agent_heartbeat( - self._heartbeat_agents[worker_idx].id, {}, _run_states + self._workers[worker_id].agent_id, {}, _run_states ) if commands: _msg = f"{LOG_PREFIX}AgentHeartbeat received {len(commands)} commands: \n{pprint.pformat(commands)}\n" @@ -131,7 +139,7 @@ def _heartbeat(self, worker_idx: int) -> None: if _type in ["run", "resume"]: self._heartbeat_queue.put(run) continue - time.sleep(self._heartbeat_thread_sleep) + time.sleep(self._worker_sleep) def _run(self) -> None: try: @@ -139,10 +147,8 @@ def _run(self) -> None: timeout=self._heartbeat_queue_timeout ) except queue.Empty: - _msg = f"{LOG_PREFIX}No jobs in Sweeps RunQueue, waiting..." - logger.debug(_msg) - wandb.termlog(_msg) - time.sleep(self._main_thread_sleep) + wandb.termlog(f"{LOG_PREFIX}No jobs in Sweeps RunQueue, waiting...") + time.sleep(self._heartbeat_queue_sleep) return # If run is already stopped just ignore the request if run.state in [ @@ -150,9 +156,7 @@ def _run(self) -> None: SimpleRunState.UNKNOWN, ]: return - _msg = f"{LOG_PREFIX}Converting Sweep Run (RunID:{run.id}) to Launch Job" - logger.debug(_msg) - wandb.termlog(_msg) + wandb.termlog(f"{LOG_PREFIX}Converting Sweep Run (RunID:{run.id}) to Launch Job") # This is actually what populates the wandb config # since it is used in wandb.init() sweep_param_path = os.path.join( @@ -161,9 +165,7 @@ def _run(self) -> None: f"sweep-{self._sweep_id}", f"config-{run.id}.yaml", ) - _msg = f"{LOG_PREFIX}Saving params to {sweep_param_path}" - logger.debug(_msg) - wandb.termlog(_msg) + wandb.termlog(f"{LOG_PREFIX}Saving params to {sweep_param_path}") wandb_lib.config_util.save_config_file_from_dict(sweep_param_path, run.args) # Construct entry point using legacy sweeps utilities command_args = LegacySweepAgent._create_command_args({"args": run.args})["args"] @@ -173,5 +175,14 @@ def _run(self) -> None: run_id=run.id, ) + def _stop_run(self, run_id: str) -> None: + run = self._runs.get(run_id, None) + if run is not None: + # Set threading event to stop the worker thread + if self._workers[run.worker_id].thread.is_alive(): + self._workers[run.worker_id].stop.set() + run.state = SimpleRunState.DEAD + + def _exit(self) -> None: self.state = SchedulerState.COMPLETED From 41682bf230343a87de9fb846cd905b9b7cff5156 Mon Sep 17 00:00:00 2001 From: Hu Po Date: Mon, 15 Aug 2022 09:41:35 -0500 Subject: [PATCH 03/17] remove launch properties from scheduler --- wandb/sdk/launch/sweeps/scheduler.py | 59 +++++++++------------- wandb/sdk/launch/sweeps/scheduler_sweep.py | 9 ++-- 2 files changed, 29 insertions(+), 39 deletions(-) diff --git a/wandb/sdk/launch/sweeps/scheduler.py b/wandb/sdk/launch/sweeps/scheduler.py index 55b1be50109..d586de93514 100644 --- a/wandb/sdk/launch/sweeps/scheduler.py +++ b/wandb/sdk/launch/sweeps/scheduler.py @@ -54,16 +54,10 @@ class Scheduler(ABC): def __init__( self, api: Api, - *args: Any, + *args: Optional[Any], entity: Optional[str] = None, project: Optional[str] = None, - # ------- Begin Launch Options ------- - queue: Optional[str] = None, - job: Optional[str] = None, - resource: Optional[str] = None, - resource_args: Optional[Dict[str, Any]] = None, - # ------- End Launch Options ------- - **kwargs: Any, + **kwargs: Optional[Any], ): self._api = api self._entity = ( @@ -75,18 +69,10 @@ def __init__( self._project = ( project or os.environ.get("WANDB_PROJECT") or api.settings("project") ) - # ------- Begin Launch Options ------- - # TODO(hupo): Validation on these arguments. - self._launch_queue = queue - self._job = job - self._resource = resource - self._resource_args = resource_args - if resource == "kubernetes": - self._resource_args = {"kubernetes": {}} - # ------- End Launch Options ------- self._state: SchedulerState = SchedulerState.PENDING self._threading_lock: threading.Lock = threading.Lock() self._runs: Dict[str, SweepRun] = {} + self._kwargs: Dict[str, Any] = kwargs @abstractmethod def _start(self) -> None: @@ -107,9 +93,7 @@ def state(self) -> SchedulerState: @state.setter def state(self, value: SchedulerState) -> None: - logger.debug( - f"{LOG_PREFIX}Changing Scheduler state from {self.state.name} to {value.name}" - ) + logger.debug(f"{LOG_PREFIX}Scheduler was {self.state.name} is {value.name}") self._state = value def is_alive(self) -> bool: @@ -122,9 +106,7 @@ def is_alive(self) -> bool: return True def start(self) -> None: - _msg = f"{LOG_PREFIX}Scheduler starting." - logger.debug(_msg) - wandb.termlog(_msg) + wandb.termlog(f"{LOG_PREFIX}Scheduler starting.") self._state = SchedulerState.STARTING self._start() self.run() @@ -140,7 +122,9 @@ def run(self) -> None: self._update_run_states() self._run() except RuntimeError as e: - wandb.termlog(f"{LOG_PREFIX}Scheduler encountered Runtime Error. {e} Trying again.") + wandb.termlog( + f"{LOG_PREFIX}Scheduler encountered Runtime Error. {e} Trying again." + ) except KeyboardInterrupt: wandb.termlog(f"{LOG_PREFIX}Scheduler received KeyboardInterrupt. Exiting.") self.state = SchedulerState.STOPPED @@ -210,20 +194,25 @@ def _add_to_launch_queue( ) -> "public.QueuedRun": """Add a launch job to the Launch RunQueue.""" run_id = run_id or generate_id() + # One of Job and URI is required + _job = self._kwargs.get("job", None) + _uri = self._kwargs.get("uri", None) + if _job is None and _uri is None: + # If no Job is specified, use a placeholder URI to prevent Launch failure + _uri = "placeholder-uri-queuedrun-from-scheduler" queued_run = launch_add( - # TODO(hupo): If no Job is specified, use a placeholder URI to prevent Launch failure - uri=None if self._job is not None else "placeholder-uri-queuedrun", - job=self._job, + run_id=run_id, + entry_point=entry_point, + uri=_uri, + job=_job, project=self._project, entity=self._entity, - queue=self._launch_queue, - entry_point=entry_point, - resource=self._resource, - resource_args=self._resource_args, - run_id=run_id, + queue=self._kwargs.get("queue", None), + resource=self._kwargs.get("resource", None), + resource_args=self._kwargs.get("resource_args", None), ) self._runs[run_id].queued_run: public.QueuedRun = queued_run - wandb.termlog(f"{LOG_PREFIX}Added run to Launch RunQueue: {self._launch_queue} RunID:{run_id}.") + wandb.termlog( + f"{LOG_PREFIX}Added run to Launch RunQueue: {self._launch_queue} RunID:{run_id}." + ) return queued_run - - diff --git a/wandb/sdk/launch/sweeps/scheduler_sweep.py b/wandb/sdk/launch/sweeps/scheduler_sweep.py index f7e28c2814c..a7a2944b47b 100644 --- a/wandb/sdk/launch/sweeps/scheduler_sweep.py +++ b/wandb/sdk/launch/sweeps/scheduler_sweep.py @@ -56,7 +56,7 @@ def __init__( raise SchedulerError( f"{LOG_PREFIX}Could not find sweep {self._entity}/{self._project}/{sweep_id}" ) - self._sweep_id = sweep_id + self._sweep_id: str = sweep_id # Threading is used to run multiple workers in parallel self._workers: List[_Worker] = [] self._num_workers: int = num_workers @@ -86,7 +86,7 @@ def _start(self) -> None: agent_id=agent_config["id"], thread=_thread, # Worker threads will be killed with an Event - stop = threading.Event(), + stop=threading.Event(), ) ) _thread.start() @@ -156,7 +156,9 @@ def _run(self) -> None: SimpleRunState.UNKNOWN, ]: return - wandb.termlog(f"{LOG_PREFIX}Converting Sweep Run (RunID:{run.id}) to Launch Job") + wandb.termlog( + f"{LOG_PREFIX}Converting Sweep Run (RunID:{run.id}) to Launch Job" + ) # This is actually what populates the wandb config # since it is used in wandb.init() sweep_param_path = os.path.join( @@ -183,6 +185,5 @@ def _stop_run(self, run_id: str) -> None: self._workers[run.worker_id].stop.set() run.state = SimpleRunState.DEAD - def _exit(self) -> None: self.state = SchedulerState.COMPLETED From 5bfbd3eb8f103e0465af070ff6f722000f482d39 Mon Sep 17 00:00:00 2001 From: Hu Po Date: Mon, 15 Aug 2022 10:27:57 -0500 Subject: [PATCH 04/17] simplified worker heartbeat --- wandb/sdk/launch/sweeps/scheduler_sweep.py | 50 ++++++++++------------ 1 file changed, 23 insertions(+), 27 deletions(-) diff --git a/wandb/sdk/launch/sweeps/scheduler_sweep.py b/wandb/sdk/launch/sweeps/scheduler_sweep.py index a7a2944b47b..c1e9edbe0ed 100644 --- a/wandb/sdk/launch/sweeps/scheduler_sweep.py +++ b/wandb/sdk/launch/sweeps/scheduler_sweep.py @@ -57,7 +57,10 @@ def __init__( f"{LOG_PREFIX}Could not find sweep {self._entity}/{self._project}/{sweep_id}" ) self._sweep_id: str = sweep_id - # Threading is used to run multiple workers in parallel + # Threading is used to run multiple workers in parallel. Workers do not + # actually run training workloads, they simply send heartbeat messages + # (emulating a real agent) and add new runs to the launch queue. The + # launch agent is the one that actually runs the training workloads. self._workers: List[_Worker] = [] self._num_workers: int = num_workers self._worker_sleep: float = worker_sleep @@ -99,44 +102,37 @@ def _heartbeat(self, worker_id: int) -> None: # Check to see if worker thread has been orderred to stop if self._workers[worker_id].stop.is_set(): return - # AgentHeartbeat wants dict of runs which are running or queued - _run_states = {} + # AgentHeartbeat wants a Dict of runs which are running or queued + _run_states: Dict[str, bool] = {} for run_id, run in self._yield_runs(): # Filter out runs that are from a different worker thread if run.worker_id == worker_id and run.state == SimpleRunState.ALIVE: _run_states[run_id] = True - _msg = ( - f"{LOG_PREFIX}AgentHeartbeat sending: \n{pprint.pformat(_run_states)}\n" - ) - logger.debug(_msg) - # TODO(hupo): Should be sub-set of _run_states specific to worker thread - commands = self._api.agent_heartbeat( - self._workers[worker_id].agent_id, {}, _run_states + logger.debug(f"{LOG_PREFIX}AgentHeartbeat sending: \n{pprint.pformat(_run_states)}\n") + commands: List[Dict[str, Any]] = self._api.agent_heartbeat( + self._workers[worker_id].agent_id, # agent_id: str + {}, # metrics: dict + _run_states # run_states: dict ) + logger.debug(f"{LOG_PREFIX}AgentHeartbeat received {len(commands)} commands: \n{pprint.pformat(commands)}\n") if commands: - _msg = f"{LOG_PREFIX}AgentHeartbeat received {len(commands)} commands: \n{pprint.pformat(commands)}\n" - logger.debug(_msg) for command in commands: + # The command "type" can be one of "run", "resume", "stop", "exit" _type = command.get("type") - # type can be one of "run", "resume", "stop", "exit" - if _type == "exit": - self.state = SchedulerState.COMPLETED - self.exit() - return - if _type == "stop": - # TODO(hupo): Debug edge cases while stopping with active runs + if _type in ["exit", "stop"]: + # (virtual) agent should stop running + self._workers[worker_id].stop.set() self.state = SchedulerState.COMPLETED self.exit() return - run = SweepRun( - id=command.get("run_id"), - args=command.get("args"), - logs=command.get("logs"), - program=command.get("program"), - ) - with self._threading_lock: - self._runs[run.id] = run if _type in ["run", "resume"]: + run = SweepRun( + id=command.get("run_id"), + args=command.get("args"), + logs=command.get("logs"), + program=command.get("program"), + ) + self._runs[run.id] = run self._heartbeat_queue.put(run) continue time.sleep(self._worker_sleep) From fc258278223bf3d26de5007440e9f980f0eaeebe Mon Sep 17 00:00:00 2001 From: Hu Po Date: Mon, 15 Aug 2022 10:45:50 -0500 Subject: [PATCH 05/17] get rid of dead runs --- wandb/sdk/launch/sweeps/scheduler.py | 9 +++++++- wandb/sdk/launch/sweeps/scheduler_sweep.py | 26 +++++++++++++--------- 2 files changed, 23 insertions(+), 12 deletions(-) diff --git a/wandb/sdk/launch/sweeps/scheduler.py b/wandb/sdk/launch/sweeps/scheduler.py index d586de93514..f6746ef7c8a 100644 --- a/wandb/sdk/launch/sweeps/scheduler.py +++ b/wandb/sdk/launch/sweeps/scheduler.py @@ -30,7 +30,7 @@ class SchedulerState(Enum): class SimpleRunState(Enum): ALIVE = 0 DEAD = 1 - UNKNOWN = 3 + UNKNOWN = 2 @dataclass @@ -71,6 +71,7 @@ def __init__( ) self._state: SchedulerState = SchedulerState.PENDING self._threading_lock: threading.Lock = threading.Lock() + # List of the runs managed by the scheduler self._runs: Dict[str, SweepRun] = {} self._kwargs: Dict[str, Any] = kwargs @@ -186,6 +187,12 @@ def _update_run_states(self) -> None: wandb.termlog(_msg) run.state = SimpleRunState.UNKNOWN continue + # Remove any runs that are dead + with self._threading_lock: + for run_id, run in self._runs.items(): + if run.state == SimpleRunState.DEAD: + wandb.termlog(f"{LOG_PREFIX}Removing dead run {run_id}.") + del self._runs[run_id] def _add_to_launch_queue( self, diff --git a/wandb/sdk/launch/sweeps/scheduler_sweep.py b/wandb/sdk/launch/sweeps/scheduler_sweep.py index c1e9edbe0ed..7e155f6318b 100644 --- a/wandb/sdk/launch/sweeps/scheduler_sweep.py +++ b/wandb/sdk/launch/sweeps/scheduler_sweep.py @@ -108,21 +108,23 @@ def _heartbeat(self, worker_id: int) -> None: # Filter out runs that are from a different worker thread if run.worker_id == worker_id and run.state == SimpleRunState.ALIVE: _run_states[run_id] = True - logger.debug(f"{LOG_PREFIX}AgentHeartbeat sending: \n{pprint.pformat(_run_states)}\n") + logger.debug( + f"{LOG_PREFIX}AgentHeartbeat sending: \n{pprint.pformat(_run_states)}\n" + ) commands: List[Dict[str, Any]] = self._api.agent_heartbeat( - self._workers[worker_id].agent_id, # agent_id: str - {}, # metrics: dict - _run_states # run_states: dict + self._workers[worker_id].agent_id, # agent_id: str + {}, # metrics: dict + _run_states, # run_states: dict + ) + logger.debug( + f"{LOG_PREFIX}AgentHeartbeat received {len(commands)} commands: \n{pprint.pformat(commands)}\n" ) - logger.debug(f"{LOG_PREFIX}AgentHeartbeat received {len(commands)} commands: \n{pprint.pformat(commands)}\n") if commands: for command in commands: # The command "type" can be one of "run", "resume", "stop", "exit" _type = command.get("type") if _type in ["exit", "stop"]: - # (virtual) agent should stop running - self._workers[worker_id].stop.set() - self.state = SchedulerState.COMPLETED + # Tell (virtual) agent to stop running self.exit() return if _type in ["run", "resume"]: @@ -131,6 +133,7 @@ def _heartbeat(self, worker_id: int) -> None: args=command.get("args"), logs=command.get("logs"), program=command.get("program"), + worker_id=worker_id, ) self._runs[run.id] = run self._heartbeat_queue.put(run) @@ -176,9 +179,10 @@ def _run(self) -> None: def _stop_run(self, run_id: str) -> None: run = self._runs.get(run_id, None) if run is not None: - # Set threading event to stop the worker thread - if self._workers[run.worker_id].thread.is_alive(): - self._workers[run.worker_id].stop.set() + _worker = self._workers.get(run.worker_id, None) + if _worker and _worker.thread.is_alive(): + # Set threading event to stop the worker thread + _worker.stop.set() run.state = SimpleRunState.DEAD def _exit(self) -> None: From a817713596a7ab0052fa99f8bef65e7a3cf13afb Mon Sep 17 00:00:00 2001 From: Hu Po Date: Mon, 15 Aug 2022 11:02:01 -0500 Subject: [PATCH 06/17] formatting --- wandb/sdk/launch/sweeps/scheduler.py | 22 ++++++++++---- wandb/sdk/launch/sweeps/scheduler_sweep.py | 35 +++++++--------------- 2 files changed, 28 insertions(+), 29 deletions(-) diff --git a/wandb/sdk/launch/sweeps/scheduler.py b/wandb/sdk/launch/sweeps/scheduler.py index f6746ef7c8a..fa563daadc4 100644 --- a/wandb/sdk/launch/sweeps/scheduler.py +++ b/wandb/sdk/launch/sweeps/scheduler.py @@ -12,6 +12,7 @@ from wandb.apis.internal import Api import wandb.apis.public as public from wandb.sdk.launch.launch_add import launch_add +from wandb.sdk.launch.sweeps import SchedulerError from wandb.sdk.lib.runid import generate_id logger = logging.getLogger(__name__) @@ -55,6 +56,7 @@ def __init__( self, api: Api, *args: Optional[Any], + sweep_id: str = None, entity: Optional[str] = None, project: Optional[str] = None, **kwargs: Optional[Any], @@ -69,6 +71,15 @@ def __init__( self._project = ( project or os.environ.get("WANDB_PROJECT") or api.settings("project") ) + # Make sure the provided sweep_id corresponds to a valid sweep + found = self._api.sweep( + sweep_id, "{}", entity=self._entity, project=self._project + ) + if not found: + raise SchedulerError( + f"{LOG_PREFIX}Could not find sweep {self._entity}/{self._project}/{sweep_id}" + ) + self._sweep_id: str = sweep_id or "empty-sweep-id" self._state: SchedulerState = SchedulerState.PENDING self._threading_lock: threading.Lock = threading.Lock() # List of the runs managed by the scheduler @@ -207,6 +218,8 @@ def _add_to_launch_queue( if _job is None and _uri is None: # If no Job is specified, use a placeholder URI to prevent Launch failure _uri = "placeholder-uri-queuedrun-from-scheduler" + # Queue is required + _queue = self._kwargs.get("queue", "default") queued_run = launch_add( run_id=run_id, entry_point=entry_point, @@ -214,12 +227,11 @@ def _add_to_launch_queue( job=_job, project=self._project, entity=self._entity, - queue=self._kwargs.get("queue", None), - resource=self._kwargs.get("resource", None), - resource_args=self._kwargs.get("resource_args", None), + queue=_queue, + **self._kwargs, ) - self._runs[run_id].queued_run: public.QueuedRun = queued_run + self._runs[run_id].queued_run = queued_run wandb.termlog( - f"{LOG_PREFIX}Added run to Launch RunQueue: {self._launch_queue} RunID:{run_id}." + f"{LOG_PREFIX}Added run to Launch RunQueue: {_queue} RunID:{run_id}." ) return queued_run diff --git a/wandb/sdk/launch/sweeps/scheduler_sweep.py b/wandb/sdk/launch/sweeps/scheduler_sweep.py index 7e155f6318b..b260cbe2676 100644 --- a/wandb/sdk/launch/sweeps/scheduler_sweep.py +++ b/wandb/sdk/launch/sweeps/scheduler_sweep.py @@ -7,11 +7,10 @@ import socket import threading import time -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List import wandb from wandb import wandb_lib # type: ignore -from wandb.sdk.launch.sweeps import SchedulerError from wandb.sdk.launch.sweeps.scheduler import ( LOG_PREFIX, Scheduler, @@ -40,7 +39,6 @@ class SweepScheduler(Scheduler): def __init__( self, *args: Any, - sweep_id: Optional[str] = None, num_workers: int = 4, worker_sleep: float = 0.5, heartbeat_queue_timeout: float = 1, @@ -48,20 +46,11 @@ def __init__( **kwargs: Any, ): super().__init__(*args, **kwargs) - # Make sure the provided sweep_id corresponds to a valid sweep - found = self._api.sweep( - sweep_id, "{}", entity=self._entity, project=self._project - ) - if not found: - raise SchedulerError( - f"{LOG_PREFIX}Could not find sweep {self._entity}/{self._project}/{sweep_id}" - ) - self._sweep_id: str = sweep_id # Threading is used to run multiple workers in parallel. Workers do not # actually run training workloads, they simply send heartbeat messages # (emulating a real agent) and add new runs to the launch queue. The # launch agent is the one that actually runs the training workloads. - self._workers: List[_Worker] = [] + self._workers: Dict[int, _Worker] = [] self._num_workers: int = num_workers self._worker_sleep: float = worker_sleep # Thread will pop items off the Sweeps RunQueue using AgentHeartbeat @@ -83,14 +72,12 @@ def _start(self) -> None: # Worker threads call heartbeat function _thread = threading.Thread(target=self._heartbeat, args=[worker_id]) _thread.daemon = True - self._workers.append( - _Worker( - agent_config=agent_config, - agent_id=agent_config["id"], - thread=_thread, - # Worker threads will be killed with an Event - stop=threading.Event(), - ) + self._workers[worker_id] = _Worker( + agent_config=agent_config, + agent_id=agent_config["id"], + thread=_thread, + # Worker threads will be killed with an Event + stop=threading.Event(), ) _thread.start() @@ -129,9 +116,9 @@ def _heartbeat(self, worker_id: int) -> None: return if _type in ["run", "resume"]: run = SweepRun( - id=command.get("run_id"), - args=command.get("args"), - logs=command.get("logs"), + id=command.get("run_id", "empty-run-id"), + args=command.get("args", {}), + logs=command.get("logs", []), program=command.get("program"), worker_id=worker_id, ) From 0cf8439c82e7ebcc78aa2806c5382eacae86c04a Mon Sep 17 00:00:00 2001 From: Hu Po Date: Mon, 15 Aug 2022 12:16:42 -0500 Subject: [PATCH 07/17] progress on tests --- tests/unit_tests/test_sweep_scheduler.py | 128 +++++++++++++---------- wandb/sdk/launch/sweeps/scheduler.py | 20 ++-- 2 files changed, 84 insertions(+), 64 deletions(-) diff --git a/tests/unit_tests/test_sweep_scheduler.py b/tests/unit_tests/test_sweep_scheduler.py index fe9ce3d9da4..ba0dbda374a 100644 --- a/tests/unit_tests/test_sweep_scheduler.py +++ b/tests/unit_tests/test_sweep_scheduler.py @@ -2,6 +2,7 @@ from unittest.mock import Mock, patch import pytest +import wandb from wandb.apis import internal, public from wandb.sdk.launch.sweeps import load_scheduler, SchedulerError from wandb.sdk.launch.sweeps.scheduler import ( @@ -12,6 +13,8 @@ ) from wandb.sdk.launch.sweeps.scheduler_sweep import SweepScheduler +from .test_wandb_sweep import VALID_SWEEP_CONFIGS + def test_sweep_scheduler_load(): _scheduler = load_scheduler("sweep") @@ -21,70 +24,89 @@ def test_sweep_scheduler_load(): @patch.multiple(Scheduler, __abstractmethods__=set()) -def test_sweep_scheduler_base_state(monkeypatch): - api = internal.Api() +@pytest.mark.parametrize("sweep_config", VALID_SWEEP_CONFIGS) +def test_sweep_scheduler_entity_project_sweep_id(user, relay_server, sweep_config): + with relay_server(): + _entity = user + _project = "test-project" + api = internal.Api() + # Entity, project, and sweep should be everything you need to create a scheduler + sweep_id = wandb.sweep(sweep_config, entity=_entity, project=_project) + _ = Scheduler(api, sweep_id=sweep_id, entity=_entity, project=_project) + # Bogus sweep id should result in error + with pytest.raises(SchedulerError): + _ = Scheduler( + api, sweep_id="foo-sweep-id", entity=_entity, project=_project + ) - def mock_run_complete_scheduler(self, *args, **kwargs): - self.state = SchedulerState.COMPLETED - monkeypatch.setattr( - "wandb.sdk.launch.sweeps.scheduler.Scheduler._run", - mock_run_complete_scheduler, - ) - - _scheduler = Scheduler(api, entity="foo", project="bar") - assert _scheduler.state == SchedulerState.PENDING - assert _scheduler.is_alive() is True - _scheduler.start() - assert _scheduler.state == SchedulerState.COMPLETED - assert _scheduler.is_alive() is False +@patch.multiple(Scheduler, __abstractmethods__=set()) +@pytest.mark.parametrize("sweep_config", VALID_SWEEP_CONFIGS) +def test_sweep_scheduler_base_scheduler_state( + user, relay_server, sweep_config, monkeypatch +): + + with relay_server(): + _entity = user + _project = "test-project" + api = internal.Api() + sweep_id = wandb.sweep(sweep_config, entity=_entity, project=_project) + + def mock_run_complete_scheduler(self, *args, **kwargs): + self.state = SchedulerState.COMPLETED + + monkeypatch.setattr( + "wandb.sdk.launch.sweeps.scheduler.Scheduler._run", + mock_run_complete_scheduler, + ) + + _scheduler = Scheduler(api, sweep_id=sweep_id, entity=_entity, project=_project) + assert _scheduler.state == SchedulerState.PENDING + assert _scheduler.is_alive() is True + _scheduler.start() + assert _scheduler.state == SchedulerState.COMPLETED + assert _scheduler.is_alive() is False - def mock_run_raise_keyboard_interupt(*args, **kwargs): - raise KeyboardInterrupt + def mock_run_raise_keyboard_interupt(*args, **kwargs): + raise KeyboardInterrupt - monkeypatch.setattr( - "wandb.sdk.launch.sweeps.scheduler.Scheduler._run", - mock_run_raise_keyboard_interupt, - ) + monkeypatch.setattr( + "wandb.sdk.launch.sweeps.scheduler.Scheduler._run", + mock_run_raise_keyboard_interupt, + ) - _scheduler = Scheduler(api, entity="foo", project="bar") - assert _scheduler.state == SchedulerState.PENDING - assert _scheduler.is_alive() is True - _scheduler.start() - assert _scheduler.state == SchedulerState.STOPPED - assert _scheduler.is_alive() is False + _scheduler = Scheduler(api, sweep_id=sweep_id, entity=_entity, project=_project) + _scheduler.start() + assert _scheduler.state == SchedulerState.STOPPED + assert _scheduler.is_alive() is False - def mock_run_raise_exception(*args, **kwargs): - raise Exception("Generic exception") + def mock_run_raise_exception(*args, **kwargs): + raise Exception("Generic exception") - monkeypatch.setattr( - "wandb.sdk.launch.sweeps.scheduler.Scheduler._run", - mock_run_raise_exception, - ) + monkeypatch.setattr( + "wandb.sdk.launch.sweeps.scheduler.Scheduler._run", + mock_run_raise_exception, + ) - _scheduler = Scheduler(api, entity="foo", project="bar") - assert _scheduler.state == SchedulerState.PENDING - assert _scheduler.is_alive() is True - with pytest.raises(Exception) as e: - _scheduler.start() - assert "Generic exception" in str(e.value) - assert _scheduler.state == SchedulerState.FAILED - assert _scheduler.is_alive() is False + _scheduler = Scheduler(api, sweep_id=sweep_id, entity=_entity, project=_project) + with pytest.raises(Exception) as e: + _scheduler.start() + assert "Generic exception" in str(e.value) + assert _scheduler.state == SchedulerState.FAILED + assert _scheduler.is_alive() is False - def mock_run_exit(self, *args, **kwargs): - self.exit() + def mock_run_exit(self, *args, **kwargs): + self.exit() - monkeypatch.setattr( - "wandb.sdk.launch.sweeps.scheduler.Scheduler._run", - mock_run_exit, - ) + monkeypatch.setattr( + "wandb.sdk.launch.sweeps.scheduler.Scheduler._run", + mock_run_exit, + ) - _scheduler = Scheduler(api, entity="foo", project="bar") - assert _scheduler.state == SchedulerState.PENDING - assert _scheduler.is_alive() is True - _scheduler.start() - assert _scheduler.state == SchedulerState.FAILED - assert _scheduler.is_alive() is False + _scheduler = Scheduler(api, sweep_id=sweep_id, entity=_entity, project=_project) + _scheduler.start() + assert _scheduler.state == SchedulerState.FAILED + assert _scheduler.is_alive() is False @patch.multiple(Scheduler, __abstractmethods__=set()) diff --git a/wandb/sdk/launch/sweeps/scheduler.py b/wandb/sdk/launch/sweeps/scheduler.py index fa563daadc4..c81d8baa1fb 100644 --- a/wandb/sdk/launch/sweeps/scheduler.py +++ b/wandb/sdk/launch/sweeps/scheduler.py @@ -3,6 +3,7 @@ from dataclasses import dataclass from enum import Enum import logging +from multiprocessing.sharedctypes import Value import os import threading from typing import Any, Dict, Iterator, List, Optional, Tuple @@ -72,13 +73,10 @@ def __init__( project or os.environ.get("WANDB_PROJECT") or api.settings("project") ) # Make sure the provided sweep_id corresponds to a valid sweep - found = self._api.sweep( - sweep_id, "{}", entity=self._entity, project=self._project - ) - if not found: - raise SchedulerError( - f"{LOG_PREFIX}Could not find sweep {self._entity}/{self._project}/{sweep_id}" - ) + try: + self._api.sweep(sweep_id, "{}", entity=self._entity, project=self._project) + except Exception as e: + raise SchedulerError(f"{LOG_PREFIX}Exception when finding sweep: {e}") self._sweep_id: str = sweep_id or "empty-sweep-id" self._state: SchedulerState = SchedulerState.PENDING self._threading_lock: threading.Lock = threading.Lock() @@ -98,6 +96,10 @@ def _run(self) -> None: def _exit(self) -> None: pass + @abstractmethod + def _stop_run(self, run_id: str) -> None: + pass + @property def state(self) -> SchedulerState: logger.debug(f"{LOG_PREFIX}Scheduler state is {self._state.name}") @@ -170,10 +172,6 @@ def _stop_runs(self) -> None: wandb.termlog(f"{LOG_PREFIX}Stopping run {run_id}.") self._stop_run(run_id) - @abstractmethod - def _stop_run(self, run_id: str) -> None: - pass - def _update_run_states(self) -> None: for run_id, run in self._yield_runs(): try: From 50ba07b464e7052e54b0d38fe0786f01ab08aa3d Mon Sep 17 00:00:00 2001 From: Hu Po Date: Mon, 15 Aug 2022 12:48:43 -0500 Subject: [PATCH 08/17] run state testing --- tests/unit_tests/test_sweep_scheduler.py | 88 ++++++++++++++---------- tests/unit_tests/test_wandb_sweep.py | 13 +++- wandb/sdk/launch/sweeps/scheduler.py | 18 ++--- 3 files changed, 71 insertions(+), 48 deletions(-) diff --git a/tests/unit_tests/test_sweep_scheduler.py b/tests/unit_tests/test_sweep_scheduler.py index ba0dbda374a..2262abb9ccd 100644 --- a/tests/unit_tests/test_sweep_scheduler.py +++ b/tests/unit_tests/test_sweep_scheduler.py @@ -13,7 +13,7 @@ ) from wandb.sdk.launch.sweeps.scheduler_sweep import SweepScheduler -from .test_wandb_sweep import VALID_SWEEP_CONFIGS +from .test_wandb_sweep import VALID_SWEEP_CONFIGS_SMALL def test_sweep_scheduler_load(): @@ -24,7 +24,7 @@ def test_sweep_scheduler_load(): @patch.multiple(Scheduler, __abstractmethods__=set()) -@pytest.mark.parametrize("sweep_config", VALID_SWEEP_CONFIGS) +@pytest.mark.parametrize("sweep_config", VALID_SWEEP_CONFIGS_SMALL) def test_sweep_scheduler_entity_project_sweep_id(user, relay_server, sweep_config): with relay_server(): _entity = user @@ -41,8 +41,8 @@ def test_sweep_scheduler_entity_project_sweep_id(user, relay_server, sweep_confi @patch.multiple(Scheduler, __abstractmethods__=set()) -@pytest.mark.parametrize("sweep_config", VALID_SWEEP_CONFIGS) -def test_sweep_scheduler_base_scheduler_state( +@pytest.mark.parametrize("sweep_config", VALID_SWEEP_CONFIGS_SMALL) +def test_sweep_scheduler_base_scheduler_states( user, relay_server, sweep_config, monkeypatch ): @@ -110,41 +110,55 @@ def mock_run_exit(self, *args, **kwargs): @patch.multiple(Scheduler, __abstractmethods__=set()) -def test_sweep_scheduler_base_run_state(): - api = internal.Api() - # Mock api.get_run_state() to return crashed and running runs - mock_run_states = { - "run1": ("crashed", SimpleRunState.DEAD), - "run2": ("failed", SimpleRunState.DEAD), - "run3": ("killed", SimpleRunState.DEAD), - "run4": ("finished", SimpleRunState.DEAD), - "run5": ("running", SimpleRunState.ALIVE), - "run6": ("pending", SimpleRunState.ALIVE), - "run7": ("preempted", SimpleRunState.ALIVE), - "run8": ("preempting", SimpleRunState.ALIVE), - } - - def mock_get_run_state(entity, project, run_id): - return mock_run_states[run_id][0] - - api.get_run_state = mock_get_run_state - _scheduler = Scheduler(api, entity="foo", project="bar") - for run_id in mock_run_states.keys(): - _scheduler._runs[run_id] = SweepRun(id=run_id, state=SimpleRunState.ALIVE) - _scheduler._update_run_states() - for run_id, _state in mock_run_states.items(): - assert _scheduler._runs[run_id].state == _state[1] +@pytest.mark.parametrize("sweep_config", VALID_SWEEP_CONFIGS_SMALL) +def test_sweep_scheduler_base_run_states(user, relay_server, sweep_config, monkeypatch): + with relay_server(): + _entity = user + _project = "test-project" + api = internal.Api() + sweep_id = wandb.sweep(sweep_config, entity=_entity, project=_project) - def mock_get_run_state_raise_exception(*args, **kwargs): - raise Exception("Generic Exception") + # Mock api.get_run_state() to return crashed and running runs + mock_run_states = { + "run1": ("crashed", SimpleRunState.DEAD), + "run2": ("failed", SimpleRunState.DEAD), + "run3": ("killed", SimpleRunState.DEAD), + "run4": ("finished", SimpleRunState.DEAD), + "run5": ("running", SimpleRunState.ALIVE), + "run6": ("pending", SimpleRunState.ALIVE), + "run7": ("preempted", SimpleRunState.ALIVE), + "run8": ("preempting", SimpleRunState.ALIVE), + } + def mock_get_run_state(entity, project, run_id, *args, **kwargs): + print('GOT HERE') + return mock_run_states[run_id][0] + + monkeypatch.setattr("wandb.apis.internal.Api.get_run_state", mock_get_run_state) - api.get_run_state = mock_get_run_state_raise_exception - _scheduler = Scheduler(api, entity="foo", project="bar") - _scheduler._runs["foo_run_1"] = SweepRun(id="foo_run_1", state=SimpleRunState.ALIVE) - _scheduler._runs["foo_run_2"] = SweepRun(id="foo_run_2", state=SimpleRunState.ALIVE) - _scheduler._update_run_states() - assert _scheduler._runs["foo_run_1"].state == SimpleRunState.UNKNOWN - assert _scheduler._runs["foo_run_2"].state == SimpleRunState.UNKNOWN + _scheduler = Scheduler(api, sweep_id=sweep_id, entity=_entity, project=_project) + # Load up the runs into the Scheduler run dict + for run_id in mock_run_states.keys(): + _scheduler._runs[run_id] = SweepRun(id=run_id, state=SimpleRunState.ALIVE) + _scheduler._update_run_states() + for run_id, _state in mock_run_states.items(): + if _state[1] == SimpleRunState.DEAD: + # Dead runs should be removed from the run dict + assert run_id not in _scheduler._runs.keys() + else: + assert _scheduler._runs[run_id].state == _state[1] + + # ---- If get_run_state errors out, runs should have the state UNKNOWN + def mock_get_run_state_raise_exception(*args, **kwargs): + raise Exception("Generic Exception") + + monkeypatch.setattr("wandb.apis.internal.Api.get_run_state", mock_get_run_state_raise_exception) + + _scheduler = Scheduler(api, sweep_id=sweep_id, entity=_entity, project=_project) + _scheduler._runs["foo_run_1"] = SweepRun(id="foo_run_1", state=SimpleRunState.ALIVE) + _scheduler._runs["foo_run_2"] = SweepRun(id="foo_run_2", state=SimpleRunState.ALIVE) + _scheduler._update_run_states() + assert _scheduler._runs["foo_run_1"].state == SimpleRunState.UNKNOWN + assert _scheduler._runs["foo_run_2"].state == SimpleRunState.UNKNOWN @patch.multiple(Scheduler, __abstractmethods__=set()) diff --git a/tests/unit_tests/test_wandb_sweep.py b/tests/unit_tests/test_wandb_sweep.py index ee5a05a7417..8e344ac49dc 100644 --- a/tests/unit_tests/test_wandb_sweep.py +++ b/tests/unit_tests/test_wandb_sweep.py @@ -40,7 +40,14 @@ } # List of all valid base configurations -VALID_SWEEP_CONFIGS: List[Dict[str, Any]] = [ +VALID_SWEEP_CONFIGS_SMALL: List[Dict[str, Any]] = [ + SWEEP_CONFIG_GRID, + # SWEEP_CONFIG_GRID_HYPERBAND, + # SWEEP_CONFIG_GRID_NESTED, + # SWEEP_CONFIG_BAYES, + # SWEEP_CONFIG_RANDOM, +] +VALID_SWEEP_CONFIGS_FULL: List[Dict[str, Any]] = [ SWEEP_CONFIG_GRID, SWEEP_CONFIG_GRID_HYPERBAND, SWEEP_CONFIG_GRID_NESTED, @@ -49,14 +56,14 @@ ] -@pytest.mark.parametrize("sweep_config", VALID_SWEEP_CONFIGS) +@pytest.mark.parametrize("sweep_config", VALID_SWEEP_CONFIGS_FULL) def test_sweep_create(user, relay_server, sweep_config): with relay_server() as relay: sweep_id = wandb.sweep(sweep_config, entity=user) assert sweep_id in relay.context.entries -@pytest.mark.parametrize("sweep_config", VALID_SWEEP_CONFIGS) +@pytest.mark.parametrize("sweep_config", VALID_SWEEP_CONFIGS_SMALL) def test_sweep_entity_project_callable(user, relay_server, sweep_config): def sweep_callable(): return sweep_config diff --git a/wandb/sdk/launch/sweeps/scheduler.py b/wandb/sdk/launch/sweeps/scheduler.py index c81d8baa1fb..af713126e4d 100644 --- a/wandb/sdk/launch/sweeps/scheduler.py +++ b/wandb/sdk/launch/sweeps/scheduler.py @@ -17,7 +17,7 @@ from wandb.sdk.lib.runid import generate_id logger = logging.getLogger(__name__) -LOG_PREFIX = f"{click.style('sched:', fg='cyan')}: " +LOG_PREFIX = f"{click.style('sched:', fg='cyan')} " class SchedulerState(Enum): @@ -173,9 +173,11 @@ def _stop_runs(self) -> None: self._stop_run(run_id) def _update_run_states(self) -> None: + _runs_to_remove: List[str] = [] for run_id, run in self._yield_runs(): try: _state = self._api.get_run_state(self._entity, self._project, run_id) + print(f'{LOG_PREFIX}Run {run_id} is {_state}') if _state is None or _state in [ "crashed", "failed", @@ -183,6 +185,7 @@ def _update_run_states(self) -> None: "finished", ]: run.state = SimpleRunState.DEAD + _runs_to_remove.append(run_id) elif _state in [ "running", "pending", @@ -191,17 +194,16 @@ def _update_run_states(self) -> None: ]: run.state = SimpleRunState.ALIVE except Exception as e: - _msg = f"{LOG_PREFIX}Issue when getting RunState for Run {run_id}: {e}" - logger.debug(_msg) - wandb.termlog(_msg) + wandb.termlog(f"{LOG_PREFIX}Issue when getting RunState for Run {run_id}: {e}") run.state = SimpleRunState.UNKNOWN continue + print(f'{LOG_PREFIX}Removing {len(_runs_to_remove)} runs.') # Remove any runs that are dead with self._threading_lock: - for run_id, run in self._runs.items(): - if run.state == SimpleRunState.DEAD: - wandb.termlog(f"{LOG_PREFIX}Removing dead run {run_id}.") - del self._runs[run_id] + for run_id in _runs_to_remove: + wandb.termlog(f"{LOG_PREFIX}Removing dead run {run_id}.") + print(f"deleting {run_id}") + del self._runs[run_id] def _add_to_launch_queue( self, From 38d9420f6a6ad13a6fa6a75550945050ad26694e Mon Sep 17 00:00:00 2001 From: Hu Po Date: Mon, 15 Aug 2022 13:14:32 -0500 Subject: [PATCH 09/17] explicit worker killing --- tests/unit_tests/test_sweep_scheduler.py | 31 ++++++++++++++-------- tests/unit_tests/test_wandb_sweep.py | 8 +++--- wandb/sdk/launch/sweeps/scheduler.py | 21 ++++++++++----- wandb/sdk/launch/sweeps/scheduler_sweep.py | 14 +++++----- 4 files changed, 44 insertions(+), 30 deletions(-) diff --git a/tests/unit_tests/test_sweep_scheduler.py b/tests/unit_tests/test_sweep_scheduler.py index 2262abb9ccd..249a7a78deb 100644 --- a/tests/unit_tests/test_sweep_scheduler.py +++ b/tests/unit_tests/test_sweep_scheduler.py @@ -111,7 +111,7 @@ def mock_run_exit(self, *args, **kwargs): @patch.multiple(Scheduler, __abstractmethods__=set()) @pytest.mark.parametrize("sweep_config", VALID_SWEEP_CONFIGS_SMALL) -def test_sweep_scheduler_base_run_states(user, relay_server, sweep_config, monkeypatch): +def test_sweep_scheduler_base_run_states(user, relay_server, sweep_config): with relay_server(): _entity = user _project = "test-project" @@ -129,12 +129,11 @@ def test_sweep_scheduler_base_run_states(user, relay_server, sweep_config, monke "run7": ("preempted", SimpleRunState.ALIVE), "run8": ("preempting", SimpleRunState.ALIVE), } + def mock_get_run_state(entity, project, run_id, *args, **kwargs): - print('GOT HERE') return mock_run_states[run_id][0] - monkeypatch.setattr("wandb.apis.internal.Api.get_run_state", mock_get_run_state) - + api.get_run_state = mock_get_run_state _scheduler = Scheduler(api, sweep_id=sweep_id, entity=_entity, project=_project) # Load up the runs into the Scheduler run dict for run_id in mock_run_states.keys(): @@ -151,19 +150,29 @@ def mock_get_run_state(entity, project, run_id, *args, **kwargs): def mock_get_run_state_raise_exception(*args, **kwargs): raise Exception("Generic Exception") - monkeypatch.setattr("wandb.apis.internal.Api.get_run_state", mock_get_run_state_raise_exception) - + api.get_run_state = mock_get_run_state_raise_exception _scheduler = Scheduler(api, sweep_id=sweep_id, entity=_entity, project=_project) - _scheduler._runs["foo_run_1"] = SweepRun(id="foo_run_1", state=SimpleRunState.ALIVE) - _scheduler._runs["foo_run_2"] = SweepRun(id="foo_run_2", state=SimpleRunState.ALIVE) + _scheduler._runs["foo_run_1"] = SweepRun( + id="foo_run_1", state=SimpleRunState.ALIVE + ) + _scheduler._runs["foo_run_2"] = SweepRun( + id="foo_run_2", state=SimpleRunState.ALIVE + ) _scheduler._update_run_states() assert _scheduler._runs["foo_run_1"].state == SimpleRunState.UNKNOWN assert _scheduler._runs["foo_run_2"].state == SimpleRunState.UNKNOWN @patch.multiple(Scheduler, __abstractmethods__=set()) -def test_sweep_scheduler_base_add_to_launch_queue(monkeypatch): - api = internal.Api() +@pytest.mark.parametrize("sweep_config", VALID_SWEEP_CONFIGS_SMALL) +def test_sweep_scheduler_base_add_to_launch_queue( + user, relay_server, sweep_config, monkeypatch +): + with relay_server(): + _entity = user + _project = "test-project" + api = internal.Api() + sweep_id = wandb.sweep(sweep_config, entity=_entity, project=_project) def mock_launch_add(*args, **kwargs): return Mock(spec=public.QueuedRun) @@ -184,7 +193,7 @@ def mock_run_add_to_launch_queue(self, *args, **kwargs): mock_run_add_to_launch_queue, ) - _scheduler = Scheduler(api, entity="foo", project="bar") + _scheduler = Scheduler(api, sweep_id=sweep_id, entity=_entity, project=_project) assert _scheduler.state == SchedulerState.PENDING assert _scheduler.is_alive() is True _scheduler.start() diff --git a/tests/unit_tests/test_wandb_sweep.py b/tests/unit_tests/test_wandb_sweep.py index 8e344ac49dc..cfab1c44024 100644 --- a/tests/unit_tests/test_wandb_sweep.py +++ b/tests/unit_tests/test_wandb_sweep.py @@ -42,10 +42,10 @@ # List of all valid base configurations VALID_SWEEP_CONFIGS_SMALL: List[Dict[str, Any]] = [ SWEEP_CONFIG_GRID, - # SWEEP_CONFIG_GRID_HYPERBAND, - # SWEEP_CONFIG_GRID_NESTED, - # SWEEP_CONFIG_BAYES, - # SWEEP_CONFIG_RANDOM, + SWEEP_CONFIG_GRID_HYPERBAND, + SWEEP_CONFIG_GRID_NESTED, + SWEEP_CONFIG_BAYES, + SWEEP_CONFIG_RANDOM, ] VALID_SWEEP_CONFIGS_FULL: List[Dict[str, Any]] = [ SWEEP_CONFIG_GRID, diff --git a/wandb/sdk/launch/sweeps/scheduler.py b/wandb/sdk/launch/sweeps/scheduler.py index af713126e4d..0dfa912c5cf 100644 --- a/wandb/sdk/launch/sweeps/scheduler.py +++ b/wandb/sdk/launch/sweeps/scheduler.py @@ -45,7 +45,6 @@ class SweepRun: program: Optional[str] = None # Threading can be used to run multiple workers in parallel worker_id: Optional[int] = None - worker_thread: Optional[threading.Thread] = None class Scheduler(ABC): @@ -97,7 +96,7 @@ def _exit(self) -> None: pass @abstractmethod - def _stop_run(self, run_id: str) -> None: + def _kill_worker(self) -> None: pass @property @@ -172,12 +171,20 @@ def _stop_runs(self) -> None: wandb.termlog(f"{LOG_PREFIX}Stopping run {run_id}.") self._stop_run(run_id) + def _stop_run(self, run_id) -> None: + """Stops a run and removes it from the scheduler""" + if run_id in self._runs: + run = self._runs[run_id] + if run.worker_id is not None: + self._kill_worker(run.worker_id) + wandb.termlog(f"{LOG_PREFIX} Stopped run {run_id}.") + run.state = SimpleRunState.DEAD + def _update_run_states(self) -> None: _runs_to_remove: List[str] = [] for run_id, run in self._yield_runs(): try: _state = self._api.get_run_state(self._entity, self._project, run_id) - print(f'{LOG_PREFIX}Run {run_id} is {_state}') if _state is None or _state in [ "crashed", "failed", @@ -194,15 +201,15 @@ def _update_run_states(self) -> None: ]: run.state = SimpleRunState.ALIVE except Exception as e: - wandb.termlog(f"{LOG_PREFIX}Issue when getting RunState for Run {run_id}: {e}") + wandb.termlog( + f"{LOG_PREFIX}Issue when getting RunState for Run {run_id}: {e}" + ) run.state = SimpleRunState.UNKNOWN continue - print(f'{LOG_PREFIX}Removing {len(_runs_to_remove)} runs.') # Remove any runs that are dead with self._threading_lock: for run_id in _runs_to_remove: - wandb.termlog(f"{LOG_PREFIX}Removing dead run {run_id}.") - print(f"deleting {run_id}") + wandb.termlog(f"{LOG_PREFIX}Cleaning up dead run {run_id}.") del self._runs[run_id] def _add_to_launch_queue( diff --git a/wandb/sdk/launch/sweeps/scheduler_sweep.py b/wandb/sdk/launch/sweeps/scheduler_sweep.py index b260cbe2676..ace7f89cade 100644 --- a/wandb/sdk/launch/sweeps/scheduler_sweep.py +++ b/wandb/sdk/launch/sweeps/scheduler_sweep.py @@ -163,14 +163,12 @@ def _run(self) -> None: run_id=run.id, ) - def _stop_run(self, run_id: str) -> None: - run = self._runs.get(run_id, None) - if run is not None: - _worker = self._workers.get(run.worker_id, None) - if _worker and _worker.thread.is_alive(): - # Set threading event to stop the worker thread - _worker.stop.set() - run.state = SimpleRunState.DEAD + def _kill_worker(self, worker_id: int) -> None: + _worker = self._workers.get(worker_id, None) + if _worker and _worker.thread.is_alive(): + # Set threading event to stop the worker thread + _worker.stop.set() + _worker.thread.join() def _exit(self) -> None: self.state = SchedulerState.COMPLETED From b4cd6c89e9f5e7c24ac11e75bf0d8a48d9b58116 Mon Sep 17 00:00:00 2001 From: Hu Po Date: Mon, 15 Aug 2022 13:50:15 -0500 Subject: [PATCH 10/17] format --- tests/unit_tests/test_sweep_scheduler.py | 155 ++++++++++----------- wandb/sdk/launch/sweeps/scheduler.py | 7 +- wandb/sdk/launch/sweeps/scheduler_sweep.py | 12 +- 3 files changed, 83 insertions(+), 91 deletions(-) diff --git a/tests/unit_tests/test_sweep_scheduler.py b/tests/unit_tests/test_sweep_scheduler.py index 249a7a78deb..7e634069ecf 100644 --- a/tests/unit_tests/test_sweep_scheduler.py +++ b/tests/unit_tests/test_sweep_scheduler.py @@ -174,83 +174,81 @@ def test_sweep_scheduler_base_add_to_launch_queue( api = internal.Api() sweep_id = wandb.sweep(sweep_config, entity=_entity, project=_project) - def mock_launch_add(*args, **kwargs): - return Mock(spec=public.QueuedRun) + def mock_launch_add(*args, **kwargs): + return Mock(spec=public.QueuedRun) - monkeypatch.setattr( - "wandb.sdk.launch.launch_add._launch_add", - mock_launch_add, - ) + monkeypatch.setattr( + "wandb.sdk.launch.launch_add._launch_add", + mock_launch_add, + ) - def mock_run_add_to_launch_queue(self, *args, **kwargs): - self._runs["foo_run"] = SweepRun(id="foo_run", state=SimpleRunState.ALIVE) - self._add_to_launch_queue(run_id="foo_run") - self.state = SchedulerState.COMPLETED - self.exit() + def mock_run_add_to_launch_queue(self, *args, **kwargs): + self._runs["foo_run"] = SweepRun(id="foo_run", state=SimpleRunState.ALIVE) + self._add_to_launch_queue(run_id="foo_run") + self.state = SchedulerState.COMPLETED + self.exit() - monkeypatch.setattr( - "wandb.sdk.launch.sweeps.scheduler.Scheduler._run", - mock_run_add_to_launch_queue, - ) + monkeypatch.setattr( + "wandb.sdk.launch.sweeps.scheduler.Scheduler._run", + mock_run_add_to_launch_queue, + ) - _scheduler = Scheduler(api, sweep_id=sweep_id, entity=_entity, project=_project) - assert _scheduler.state == SchedulerState.PENDING - assert _scheduler.is_alive() is True - _scheduler.start() - assert _scheduler.state == SchedulerState.COMPLETED - assert _scheduler.is_alive() is False - assert len(_scheduler._runs) == 1 - assert isinstance(_scheduler._runs["foo_run"].queued_run, public.QueuedRun) - assert _scheduler._runs["foo_run"].state == SimpleRunState.DEAD - - -@pytest.mark.xfail(reason="TODO(hupo): fix") -def test_sweep_scheduler_sweeps(monkeypatch): - api = internal.Api() - - api.agent_heartbeat = Mock( - side_effect=[ - [ - { - "type": "run", - "run_id": "foo_run_1", - "args": {"foo_arg": {"value": 1}}, - "program": "train.py", - } - ], - [ - { - "type": "stop", - "run_id": "foo_run_1", - } - ], - [ - { - "type": "resume", - "run_id": "foo_run_1", - "args": {"foo_arg": {"value": 1}}, - "program": "train.py", - } - ], - [ - { - "type": "exit", - } - ], - ] - ) + _scheduler = Scheduler(api, sweep_id=sweep_id, entity=_entity, project=_project) + assert _scheduler.state == SchedulerState.PENDING + assert _scheduler.is_alive() is True + _scheduler.start() + assert _scheduler.state == SchedulerState.COMPLETED + assert _scheduler.is_alive() is False + assert len(_scheduler._runs) == 1 + assert isinstance(_scheduler._runs["foo_run"].queued_run, public.QueuedRun) + assert _scheduler._runs["foo_run"].state == SimpleRunState.DEAD - def mock_sweep(self, sweep_id, *args, **kwargs): - if sweep_id == "404sweep": - return False - return True - monkeypatch.setattr("wandb.apis.internal.Api.sweep", mock_sweep) +@patch.multiple(Scheduler, __abstractmethods__=set()) +@pytest.mark.parametrize("sweep_config", VALID_SWEEP_CONFIGS_SMALL) +def test_sweep_scheduler_sweeps(user, relay_server, sweep_config, monkeypatch): + with relay_server(): + _entity = user + _project = "test-project" + api = internal.Api() + sweep_id = wandb.sweep(sweep_config, entity=_entity, project=_project) + + # api.agent_heartbeat = Mock( + # side_effect=[ + # [ + # { + # "type": "run", + # "run_id": "foo_run_1", + # "args": {"foo_arg": {"value": 1}}, + # "program": "train.py", + # } + # ], + # [ + # { + # "type": "stop", + # "run_id": "foo_run_1", + # } + # ], + # [ + # { + # "type": "resume", + # "run_id": "foo_run_1", + # "args": {"foo_arg": {"value": 1}}, + # "program": "train.py", + # } + # ], + # [ + # { + # "type": "exit", + # } + # ], + # ] + # ) - def mock_register_agent(*args, **kwargs): - return {"id": "foo_agent_pid"} + # def mock_register_agent(*args, **kwargs): + # return {"id": "foo_agent_pid"} - monkeypatch.setattr("wandb.apis.internal.Api.register_agent", mock_register_agent) + # monkeypatch.setattr("wandb.apis.internal.Api.register_agent", mock_register_agent) # def mock_add_to_launch_queue(self, *args, **kwargs): # assert "entry_point" in kwargs @@ -265,10 +263,6 @@ def mock_register_agent(*args, **kwargs): # mock_add_to_launch_queue, # ) - with pytest.raises(SchedulerError) as e: - SweepScheduler(api, sweep_id="404sweep") - assert "Could not find sweep" in str(e.value) - def mock_launch_add(*args, **kwargs): return Mock(spec=public.QueuedRun) @@ -280,19 +274,16 @@ def mock_launch_add(*args, **kwargs): def mock_get_run_state(*args, **kwargs): return "finished" - monkeypatch.setattr("wandb.apis.internal.Api.get_run_state", mock_get_run_state) + api.get_run_state = mock_get_run_state _scheduler = SweepScheduler( - api, - entity="mock-entity", - project="mock-project", - sweep_id="mock-sweep", + api, sweep_id=sweep_id, entity=_entity, project=_project ) assert _scheduler.state == SchedulerState.PENDING assert _scheduler.is_alive() is True _scheduler.start() - for _heartbeat_agent in _scheduler._workers: - assert not _heartbeat_agent.thread.is_alive() + # for _heartbeat_agent in _scheduler._workers: + # assert not _heartbeat_agent.thread.is_alive() assert _scheduler.state == SchedulerState.COMPLETED - assert len(_scheduler._runs) == 1 - assert _scheduler._runs["foo_run_1"].state == SimpleRunState.DEAD + # assert len(_scheduler._runs) == 1 + # assert _scheduler._runs["foo_run_1"].state == SimpleRunState.DEAD diff --git a/wandb/sdk/launch/sweeps/scheduler.py b/wandb/sdk/launch/sweeps/scheduler.py index 0dfa912c5cf..1300188d041 100644 --- a/wandb/sdk/launch/sweeps/scheduler.py +++ b/wandb/sdk/launch/sweeps/scheduler.py @@ -3,7 +3,6 @@ from dataclasses import dataclass from enum import Enum import logging -from multiprocessing.sharedctypes import Value import os import threading from typing import Any, Dict, Iterator, List, Optional, Tuple @@ -96,7 +95,7 @@ def _exit(self) -> None: pass @abstractmethod - def _kill_worker(self) -> None: + def _kill_worker(self, worker_id: int) -> None: pass @property @@ -171,10 +170,10 @@ def _stop_runs(self) -> None: wandb.termlog(f"{LOG_PREFIX}Stopping run {run_id}.") self._stop_run(run_id) - def _stop_run(self, run_id) -> None: + def _stop_run(self, run_id: str) -> None: """Stops a run and removes it from the scheduler""" if run_id in self._runs: - run = self._runs[run_id] + run: SweepRun = self._runs[run_id] if run.worker_id is not None: self._kill_worker(run.worker_id) wandb.termlog(f"{LOG_PREFIX} Stopped run {run_id}.") diff --git a/wandb/sdk/launch/sweeps/scheduler_sweep.py b/wandb/sdk/launch/sweeps/scheduler_sweep.py index ace7f89cade..cb13a06f9b9 100644 --- a/wandb/sdk/launch/sweeps/scheduler_sweep.py +++ b/wandb/sdk/launch/sweeps/scheduler_sweep.py @@ -39,10 +39,10 @@ class SweepScheduler(Scheduler): def __init__( self, *args: Any, - num_workers: int = 4, + num_workers: int = 2, worker_sleep: float = 0.5, - heartbeat_queue_timeout: float = 1, - heartbeat_queue_sleep: float = 1, + heartbeat_queue_timeout: float = 0.5, + heartbeat_queue_sleep: float = 0.5, **kwargs: Any, ): super().__init__(*args, **kwargs) @@ -50,7 +50,7 @@ def __init__( # actually run training workloads, they simply send heartbeat messages # (emulating a real agent) and add new runs to the launch queue. The # launch agent is the one that actually runs the training workloads. - self._workers: Dict[int, _Worker] = [] + self._workers: Dict[int, _Worker] = {} self._num_workers: int = num_workers self._worker_sleep: float = worker_sleep # Thread will pop items off the Sweeps RunQueue using AgentHeartbeat @@ -168,7 +168,9 @@ def _kill_worker(self, worker_id: int) -> None: if _worker and _worker.thread.is_alive(): # Set threading event to stop the worker thread _worker.stop.set() - _worker.thread.join() + print(f"{LOG_PREFIX}Killing AgentHeartbeat worker {worker_id}") + _worker.thread.join() + print(f"{LOG_PREFIX}AgentHeartbeat worker {worker_id} killed") def _exit(self) -> None: self.state = SchedulerState.COMPLETED From f4eaf404ba946fccaaf5c74f680dd5d9ae541ef4 Mon Sep 17 00:00:00 2001 From: Hu Po Date: Mon, 15 Aug 2022 15:17:33 -0500 Subject: [PATCH 11/17] shorten sleeps, rename configs --- tests/unit_tests/test_sweep_scheduler.py | 12 +++++----- tests/unit_tests/test_wandb_sweep.py | 22 +++++++++---------- .../launch/sweeps/Dockerfile.scheduler.sweep | 1 + wandb/sdk/launch/sweeps/scheduler_sweep.py | 6 ++--- 4 files changed, 21 insertions(+), 20 deletions(-) diff --git a/tests/unit_tests/test_sweep_scheduler.py b/tests/unit_tests/test_sweep_scheduler.py index 7e634069ecf..fc3403842a3 100644 --- a/tests/unit_tests/test_sweep_scheduler.py +++ b/tests/unit_tests/test_sweep_scheduler.py @@ -13,7 +13,7 @@ ) from wandb.sdk.launch.sweeps.scheduler_sweep import SweepScheduler -from .test_wandb_sweep import VALID_SWEEP_CONFIGS_SMALL +from .test_wandb_sweep import VALID_SWEEP_CONFIGS_MINIMAL def test_sweep_scheduler_load(): @@ -24,7 +24,7 @@ def test_sweep_scheduler_load(): @patch.multiple(Scheduler, __abstractmethods__=set()) -@pytest.mark.parametrize("sweep_config", VALID_SWEEP_CONFIGS_SMALL) +@pytest.mark.parametrize("sweep_config", VALID_SWEEP_CONFIGS_MINIMAL) def test_sweep_scheduler_entity_project_sweep_id(user, relay_server, sweep_config): with relay_server(): _entity = user @@ -41,7 +41,7 @@ def test_sweep_scheduler_entity_project_sweep_id(user, relay_server, sweep_confi @patch.multiple(Scheduler, __abstractmethods__=set()) -@pytest.mark.parametrize("sweep_config", VALID_SWEEP_CONFIGS_SMALL) +@pytest.mark.parametrize("sweep_config", VALID_SWEEP_CONFIGS_MINIMAL) def test_sweep_scheduler_base_scheduler_states( user, relay_server, sweep_config, monkeypatch ): @@ -110,7 +110,7 @@ def mock_run_exit(self, *args, **kwargs): @patch.multiple(Scheduler, __abstractmethods__=set()) -@pytest.mark.parametrize("sweep_config", VALID_SWEEP_CONFIGS_SMALL) +@pytest.mark.parametrize("sweep_config", VALID_SWEEP_CONFIGS_MINIMAL) def test_sweep_scheduler_base_run_states(user, relay_server, sweep_config): with relay_server(): _entity = user @@ -164,7 +164,7 @@ def mock_get_run_state_raise_exception(*args, **kwargs): @patch.multiple(Scheduler, __abstractmethods__=set()) -@pytest.mark.parametrize("sweep_config", VALID_SWEEP_CONFIGS_SMALL) +@pytest.mark.parametrize("sweep_config", VALID_SWEEP_CONFIGS_MINIMAL) def test_sweep_scheduler_base_add_to_launch_queue( user, relay_server, sweep_config, monkeypatch ): @@ -205,7 +205,7 @@ def mock_run_add_to_launch_queue(self, *args, **kwargs): @patch.multiple(Scheduler, __abstractmethods__=set()) -@pytest.mark.parametrize("sweep_config", VALID_SWEEP_CONFIGS_SMALL) +@pytest.mark.parametrize("sweep_config", VALID_SWEEP_CONFIGS_MINIMAL) def test_sweep_scheduler_sweeps(user, relay_server, sweep_config, monkeypatch): with relay_server(): _entity = user diff --git a/tests/unit_tests/test_wandb_sweep.py b/tests/unit_tests/test_wandb_sweep.py index cfab1c44024..1fe79f1537d 100644 --- a/tests/unit_tests/test_wandb_sweep.py +++ b/tests/unit_tests/test_wandb_sweep.py @@ -39,31 +39,31 @@ "parameters": {"param1": {"values": [1, 2, 3]}}, } -# List of all valid base configurations -VALID_SWEEP_CONFIGS_SMALL: List[Dict[str, Any]] = [ - SWEEP_CONFIG_GRID, - SWEEP_CONFIG_GRID_HYPERBAND, - SWEEP_CONFIG_GRID_NESTED, +# Minimal list of valid sweep configs +VALID_SWEEP_CONFIGS_MINIMAL: List[Dict[str, Any]] = [ SWEEP_CONFIG_BAYES, SWEEP_CONFIG_RANDOM, -] -VALID_SWEEP_CONFIGS_FULL: List[Dict[str, Any]] = [ - SWEEP_CONFIG_GRID, SWEEP_CONFIG_GRID_HYPERBAND, SWEEP_CONFIG_GRID_NESTED, - SWEEP_CONFIG_BAYES, +] +# All valid sweep configs, be careful as this will slow down tests +VALID_SWEEP_CONFIGS_ALL: List[Dict[str, Any]] = [ SWEEP_CONFIG_RANDOM, + SWEEP_CONFIG_BAYES, + SWEEP_CONFIG_GRID, + SWEEP_CONFIG_GRID_NESTED, + SWEEP_CONFIG_GRID_HYPERBAND, ] -@pytest.mark.parametrize("sweep_config", VALID_SWEEP_CONFIGS_FULL) +@pytest.mark.parametrize("sweep_config", VALID_SWEEP_CONFIGS_ALL) def test_sweep_create(user, relay_server, sweep_config): with relay_server() as relay: sweep_id = wandb.sweep(sweep_config, entity=user) assert sweep_id in relay.context.entries -@pytest.mark.parametrize("sweep_config", VALID_SWEEP_CONFIGS_SMALL) +@pytest.mark.parametrize("sweep_config", VALID_SWEEP_CONFIGS_MINIMAL) def test_sweep_entity_project_callable(user, relay_server, sweep_config): def sweep_callable(): return sweep_config diff --git a/wandb/sdk/launch/sweeps/Dockerfile.scheduler.sweep b/wandb/sdk/launch/sweeps/Dockerfile.scheduler.sweep index f7570b02f34..c8a1d6f665a 100644 --- a/wandb/sdk/launch/sweeps/Dockerfile.scheduler.sweep +++ b/wandb/sdk/launch/sweeps/Dockerfile.scheduler.sweep @@ -1,5 +1,6 @@ FROM python:3.9-slim-bullseye LABEL maintainer='Weights & Biases ' +LABEL version="0.1" # install git RUN apt-get update && apt-get upgrade -y \ diff --git a/wandb/sdk/launch/sweeps/scheduler_sweep.py b/wandb/sdk/launch/sweeps/scheduler_sweep.py index cb13a06f9b9..c0f8f1154ca 100644 --- a/wandb/sdk/launch/sweeps/scheduler_sweep.py +++ b/wandb/sdk/launch/sweeps/scheduler_sweep.py @@ -40,9 +40,9 @@ def __init__( self, *args: Any, num_workers: int = 2, - worker_sleep: float = 0.5, - heartbeat_queue_timeout: float = 0.5, - heartbeat_queue_sleep: float = 0.5, + worker_sleep: float = 0.1, + heartbeat_queue_timeout: float = 0.1, + heartbeat_queue_sleep: float = 0.1, **kwargs: Any, ): super().__init__(*args, **kwargs) From 0e2e758a04a81fc394ded46c35c8fc858e2b1293 Mon Sep 17 00:00:00 2001 From: Hu Po Date: Mon, 15 Aug 2022 15:24:57 -0500 Subject: [PATCH 12/17] boilerplate --- tests/unit_tests/test_sweep_scheduler.py | 29 ++++++++++++++++++++-- wandb/sdk/launch/sweeps/scheduler_sweep.py | 2 +- 2 files changed, 28 insertions(+), 3 deletions(-) diff --git a/tests/unit_tests/test_sweep_scheduler.py b/tests/unit_tests/test_sweep_scheduler.py index fc3403842a3..72154d44101 100644 --- a/tests/unit_tests/test_sweep_scheduler.py +++ b/tests/unit_tests/test_sweep_scheduler.py @@ -204,15 +204,40 @@ def mock_run_add_to_launch_queue(self, *args, **kwargs): assert _scheduler._runs["foo_run"].state == SimpleRunState.DEAD -@patch.multiple(Scheduler, __abstractmethods__=set()) @pytest.mark.parametrize("sweep_config", VALID_SWEEP_CONFIGS_MINIMAL) -def test_sweep_scheduler_sweeps(user, relay_server, sweep_config, monkeypatch): +def test_sweep_scheduler_sweeps_add_to_launch_queue(user, relay_server, sweep_config, monkeypatch): + with relay_server(): + _entity = user + _project = "test-project" + api = internal.Api() + sweep_id = wandb.sweep(sweep_config, entity=_entity, project=_project) + + +@pytest.mark.parametrize("sweep_config", VALID_SWEEP_CONFIGS_MINIMAL) +def test_sweep_scheduler_sweeps_single_threading(user, relay_server, sweep_config, monkeypatch): with relay_server(): _entity = user _project = "test-project" api = internal.Api() sweep_id = wandb.sweep(sweep_config, entity=_entity, project=_project) + _scheduler = SweepScheduler( + api, sweep_id=sweep_id, entity=_entity, project=_project, num_workers=1 + ) + +@pytest.mark.parametrize("sweep_config", VALID_SWEEP_CONFIGS_MINIMAL) +def test_sweep_scheduler_sweeps_multi_threading(user, relay_server, sweep_config, monkeypatch): + with relay_server(): + _entity = user + _project = "test-project" + api = internal.Api() + sweep_id = wandb.sweep(sweep_config, entity=_entity, project=_project) + + _scheduler = SweepScheduler( + api, sweep_id=sweep_id, entity=_entity, project=_project, num_workers=4 + ) + + # api.agent_heartbeat = Mock( # side_effect=[ # [ diff --git a/wandb/sdk/launch/sweeps/scheduler_sweep.py b/wandb/sdk/launch/sweeps/scheduler_sweep.py index c0f8f1154ca..e4503e823a8 100644 --- a/wandb/sdk/launch/sweeps/scheduler_sweep.py +++ b/wandb/sdk/launch/sweeps/scheduler_sweep.py @@ -39,7 +39,7 @@ class SweepScheduler(Scheduler): def __init__( self, *args: Any, - num_workers: int = 2, + num_workers: int = 4, worker_sleep: float = 0.1, heartbeat_queue_timeout: float = 0.1, heartbeat_queue_sleep: float = 0.1, From f568aee3a47019af1da7e2ed49cdea7a704b4055 Mon Sep 17 00:00:00 2001 From: Hu Po Date: Mon, 15 Aug 2022 15:25:11 -0500 Subject: [PATCH 13/17] boilerplate 2 --- tests/unit_tests/test_sweep_scheduler.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/unit_tests/test_sweep_scheduler.py b/tests/unit_tests/test_sweep_scheduler.py index 72154d44101..d8dc6fe6797 100644 --- a/tests/unit_tests/test_sweep_scheduler.py +++ b/tests/unit_tests/test_sweep_scheduler.py @@ -225,6 +225,12 @@ def test_sweep_scheduler_sweeps_single_threading(user, relay_server, sweep_confi api, sweep_id=sweep_id, entity=_entity, project=_project, num_workers=1 ) + def mock_get_run_state(*args, **kwargs): + return "finished" + + api.get_run_state = mock_get_run_state + + @pytest.mark.parametrize("sweep_config", VALID_SWEEP_CONFIGS_MINIMAL) def test_sweep_scheduler_sweeps_multi_threading(user, relay_server, sweep_config, monkeypatch): with relay_server(): From a1306e6b5319ee68d89ba2ff3254ddf7e71a5b64 Mon Sep 17 00:00:00 2001 From: Hu Po Date: Tue, 16 Aug 2022 09:00:41 -0500 Subject: [PATCH 14/17] progress --- tests/unit_tests/test_sweep_scheduler.py | 100 +++++++++++++++++---- tests/unit_tests/test_wandb_sweep.py | 6 +- wandb/sdk/launch/sweeps/scheduler.py | 6 +- wandb/sdk/launch/sweeps/scheduler_sweep.py | 34 +++++-- 4 files changed, 117 insertions(+), 29 deletions(-) diff --git a/tests/unit_tests/test_sweep_scheduler.py b/tests/unit_tests/test_sweep_scheduler.py index d8dc6fe6797..e8d9d80b313 100644 --- a/tests/unit_tests/test_sweep_scheduler.py +++ b/tests/unit_tests/test_sweep_scheduler.py @@ -205,34 +205,93 @@ def mock_run_add_to_launch_queue(self, *args, **kwargs): @pytest.mark.parametrize("sweep_config", VALID_SWEEP_CONFIGS_MINIMAL) -def test_sweep_scheduler_sweeps_add_to_launch_queue(user, relay_server, sweep_config, monkeypatch): +@pytest.mark.parametrize("num_workers", [1, 3, 8]) +def test_sweep_scheduler_sweeps_kill_threads( + user, relay_server, sweep_config, num_workers, monkeypatch +): with relay_server(): _entity = user _project = "test-project" api = internal.Api() sweep_id = wandb.sweep(sweep_config, entity=_entity, project=_project) + def mock_agent_heartbeat(*args, **kwargs): + return [{"type": "run", "run_id": "mock-run-id"}] + + def mock_get_run_state(*args, **kwargs): + return "running" + + def mock_register_agent(*args, **kwargs): + return {"id": "mock-agent-id"} + + def mock_launch_add(*args, **kwargs): + return Mock(spec=public.QueuedRun) + + api.get_run_state = mock_get_run_state + api.register_agent = mock_register_agent + api.agent_heartbeat = mock_agent_heartbeat + monkeypatch.setattr( + "wandb.sdk.launch.launch_add._launch_add", + mock_launch_add, + ) + + _scheduler = SweepScheduler( + api, sweep_id=sweep_id, entity=_entity, project=_project, num_workers=num_workers + ) + _scheduler.start() + _scheduler.exit() + assert _scheduler.state == SchedulerState.COMPLETED + assert _scheduler.is_alive() is False @pytest.mark.parametrize("sweep_config", VALID_SWEEP_CONFIGS_MINIMAL) -def test_sweep_scheduler_sweeps_single_threading(user, relay_server, sweep_config, monkeypatch): +def test_sweep_scheduler_sweeps_invalid_agent_heartbeat( + user, relay_server, sweep_config +): with relay_server(): _entity = user _project = "test-project" api = internal.Api() sweep_id = wandb.sweep(sweep_config, entity=_entity, project=_project) - _scheduler = SweepScheduler( - api, sweep_id=sweep_id, entity=_entity, project=_project, num_workers=1 - ) + def mock_agent_heartbeat(*args, **kwargs): + return [{"type": "foo"}] - def mock_get_run_state(*args, **kwargs): - return "finished" + def mock_register_agent(*args, **kwargs): + return {"id": "mock-agent-id"} - api.get_run_state = mock_get_run_state + api.register_agent = mock_register_agent + api.agent_heartbeat = mock_agent_heartbeat + with pytest.raises(SchedulerError) as e: + _scheduler = SweepScheduler( + api, sweep_id=sweep_id, entity=_entity, project=_project, num_workers=1 + ) + _scheduler.start() + + assert "Unknown command" in str(e.value) + assert _scheduler.state == SchedulerState.FAILED + assert _scheduler.is_alive() is False + + def mock_agent_heartbeat(*args, **kwargs): + return [{"type": "run"}] # No run_id should throw error + + api.agent_heartbeat = mock_agent_heartbeat + + with pytest.raises(SchedulerError) as e: + _scheduler = SweepScheduler( + api, sweep_id=sweep_id, entity=_entity, project=_project, num_workers=1 + ) + _scheduler.start() + + assert "missing run_id" in str(e.value) + assert _scheduler.state == SchedulerState.FAILED + assert _scheduler.is_alive() is False @pytest.mark.parametrize("sweep_config", VALID_SWEEP_CONFIGS_MINIMAL) -def test_sweep_scheduler_sweeps_multi_threading(user, relay_server, sweep_config, monkeypatch): +@pytest.mark.parametrize("num_workers", [1, 3]) +def test_sweep_scheduler_sweeps_run_and_heartbeat( + user, relay_server, sweep_config, num_workers, monkeypatch +): with relay_server(): _entity = user _project = "test-project" @@ -240,9 +299,11 @@ def test_sweep_scheduler_sweeps_multi_threading(user, relay_server, sweep_config sweep_id = wandb.sweep(sweep_config, entity=_entity, project=_project) _scheduler = SweepScheduler( - api, sweep_id=sweep_id, entity=_entity, project=_project, num_workers=4 + api, sweep_id=sweep_id, entity=_entity, project=_project, num_workers=num_workers ) - + assert _scheduler.state == SchedulerState.PENDING + assert _scheduler.is_alive() is True + _scheduler.start() # api.agent_heartbeat = Mock( # side_effect=[ @@ -276,6 +337,14 @@ def test_sweep_scheduler_sweeps_multi_threading(user, relay_server, sweep_config # ] # ) + # # Mock api.get_run_state() to return running and finally finished runs + # mock_run_state_counter: int = 0 + # def mock_get_run_state(*args, **kwargs): + # if mock_run_state_counter < 4: + # mock_run_state_counter += 1 + # return "running" + # return "finished" + # def mock_register_agent(*args, **kwargs): # return {"id": "foo_agent_pid"} @@ -302,17 +371,10 @@ def mock_launch_add(*args, **kwargs): mock_launch_add, ) - def mock_get_run_state(*args, **kwargs): - return "finished" - - api.get_run_state = mock_get_run_state - _scheduler = SweepScheduler( api, sweep_id=sweep_id, entity=_entity, project=_project ) - assert _scheduler.state == SchedulerState.PENDING - assert _scheduler.is_alive() is True - _scheduler.start() + # for _heartbeat_agent in _scheduler._workers: # assert not _heartbeat_agent.thread.is_alive() assert _scheduler.state == SchedulerState.COMPLETED diff --git a/tests/unit_tests/test_wandb_sweep.py b/tests/unit_tests/test_wandb_sweep.py index 1fe79f1537d..5eb79d98739 100644 --- a/tests/unit_tests/test_wandb_sweep.py +++ b/tests/unit_tests/test_wandb_sweep.py @@ -42,9 +42,9 @@ # Minimal list of valid sweep configs VALID_SWEEP_CONFIGS_MINIMAL: List[Dict[str, Any]] = [ SWEEP_CONFIG_BAYES, - SWEEP_CONFIG_RANDOM, - SWEEP_CONFIG_GRID_HYPERBAND, - SWEEP_CONFIG_GRID_NESTED, + # SWEEP_CONFIG_RANDOM, + # SWEEP_CONFIG_GRID_HYPERBAND, + # SWEEP_CONFIG_GRID_NESTED, ] # All valid sweep configs, be careful as this will slow down tests VALID_SWEEP_CONFIGS_ALL: List[Dict[str, Any]] = [ diff --git a/wandb/sdk/launch/sweeps/scheduler.py b/wandb/sdk/launch/sweeps/scheduler.py index 1300188d041..b150aa3b71b 100644 --- a/wandb/sdk/launch/sweeps/scheduler.py +++ b/wandb/sdk/launch/sweeps/scheduler.py @@ -77,9 +77,11 @@ def __init__( raise SchedulerError(f"{LOG_PREFIX}Exception when finding sweep: {e}") self._sweep_id: str = sweep_id or "empty-sweep-id" self._state: SchedulerState = SchedulerState.PENDING - self._threading_lock: threading.Lock = threading.Lock() - # List of the runs managed by the scheduler + # Dictionary of the runs being managed by the scheduler self._runs: Dict[str, SweepRun] = {} + # Threading lock to ensure thread-safe access to the runs dictionary + self._threading_lock: threading.Lock = threading.Lock() + # Scheduler may receive additional kwargs which will be piped into the launch command self._kwargs: Dict[str, Any] = kwargs @abstractmethod diff --git a/wandb/sdk/launch/sweeps/scheduler_sweep.py b/wandb/sdk/launch/sweeps/scheduler_sweep.py index e4503e823a8..4309be17c67 100644 --- a/wandb/sdk/launch/sweeps/scheduler_sweep.py +++ b/wandb/sdk/launch/sweeps/scheduler_sweep.py @@ -11,6 +11,7 @@ import wandb from wandb import wandb_lib # type: ignore +from wandb.sdk.launch.sweeps import SchedulerError from wandb.sdk.launch.sweeps.scheduler import ( LOG_PREFIX, Scheduler, @@ -69,9 +70,16 @@ def _start(self) -> None: project_name=self._project, entity=self._entity, ) + + def excepthook(args): + print(f"In excepthook {args}") + + threading.excepthook = excepthook + # Worker threads call heartbeat function _thread = threading.Thread(target=self._heartbeat, args=[worker_id]) - _thread.daemon = True + + # _thread.daemon = True self._workers[worker_id] = _Worker( agent_config=agent_config, agent_id=agent_config["id"], @@ -109,25 +117,37 @@ def _heartbeat(self, worker_id: int) -> None: if commands: for command in commands: # The command "type" can be one of "run", "resume", "stop", "exit" - _type = command.get("type") + _type = command.get("type", None) if _type in ["exit", "stop"]: # Tell (virtual) agent to stop running self.exit() return - if _type in ["run", "resume"]: + elif _type in ["run", "resume"]: + _run_id = command.get("run_id", None) + if _run_id is None: + raise SchedulerError( + f"{LOG_PREFIX}AgentHeartbeat command {command} missing run_id" + ) run = SweepRun( - id=command.get("run_id", "empty-run-id"), + id=_run_id, args=command.get("args", {}), logs=command.get("logs", []), - program=command.get("program"), + program=command.get("program", None), worker_id=worker_id, ) self._runs[run.id] = run self._heartbeat_queue.put(run) continue + else: + raise SchedulerError( + f"{LOG_PREFIX}Unknown command type {_type}" + ) time.sleep(self._worker_sleep) def _run(self) -> None: + # # Join worker threads to check for exceptions + # for worker_id in self._workers: + # self._workers[worker_id].thread.join(timeout=0.1) try: run: SweepRun = self._heartbeat_queue.get( timeout=self._heartbeat_queue_timeout @@ -164,6 +184,7 @@ def _run(self) -> None: ) def _kill_worker(self, worker_id: int) -> None: + print(f"{LOG_PREFIX}Killing AgentHeartbeat worker {worker_id}") _worker = self._workers.get(worker_id, None) if _worker and _worker.thread.is_alive(): # Set threading event to stop the worker thread @@ -173,4 +194,7 @@ def _kill_worker(self, worker_id: int) -> None: print(f"{LOG_PREFIX}AgentHeartbeat worker {worker_id} killed") def _exit(self) -> None: + # Kill all the worker threads + for worker_id in self._workers: + self._kill_worker(worker_id) self.state = SchedulerState.COMPLETED From ace45f7356c6055ca53a2b9298e2382a7e1da5bb Mon Sep 17 00:00:00 2001 From: Hu Po Date: Tue, 16 Aug 2022 09:56:51 -0500 Subject: [PATCH 15/17] ripped out threading --- tests/unit_tests/test_sweep_scheduler.py | 281 ++++++++------------- wandb/sdk/launch/sweeps/scheduler.py | 18 +- wandb/sdk/launch/sweeps/scheduler_sweep.py | 127 ++++------ 3 files changed, 158 insertions(+), 268 deletions(-) diff --git a/tests/unit_tests/test_sweep_scheduler.py b/tests/unit_tests/test_sweep_scheduler.py index e8d9d80b313..4edf6d8f691 100644 --- a/tests/unit_tests/test_sweep_scheduler.py +++ b/tests/unit_tests/test_sweep_scheduler.py @@ -165,203 +165,130 @@ def mock_get_run_state_raise_exception(*args, **kwargs): @patch.multiple(Scheduler, __abstractmethods__=set()) @pytest.mark.parametrize("sweep_config", VALID_SWEEP_CONFIGS_MINIMAL) -def test_sweep_scheduler_base_add_to_launch_queue( - user, relay_server, sweep_config, monkeypatch -): - with relay_server(): - _entity = user - _project = "test-project" - api = internal.Api() - sweep_id = wandb.sweep(sweep_config, entity=_entity, project=_project) +def test_sweep_scheduler_base_add_to_launch_queue(user, sweep_config, monkeypatch): + api = internal.Api() - def mock_launch_add(*args, **kwargs): - return Mock(spec=public.QueuedRun) + _project = "test-project" + sweep_id = wandb.sweep(sweep_config, entity=user, project=_project) - monkeypatch.setattr( - "wandb.sdk.launch.launch_add._launch_add", - mock_launch_add, - ) + def mock_launch_add(*args, **kwargs): + return Mock(spec=public.QueuedRun) - def mock_run_add_to_launch_queue(self, *args, **kwargs): - self._runs["foo_run"] = SweepRun(id="foo_run", state=SimpleRunState.ALIVE) - self._add_to_launch_queue(run_id="foo_run") - self.state = SchedulerState.COMPLETED - self.exit() + monkeypatch.setattr( + "wandb.sdk.launch.launch_add._launch_add", + mock_launch_add, + ) - monkeypatch.setattr( - "wandb.sdk.launch.sweeps.scheduler.Scheduler._run", - mock_run_add_to_launch_queue, - ) + def mock_run_add_to_launch_queue(self, *args, **kwargs): + self._runs["foo_run"] = SweepRun(id="foo_run", state=SimpleRunState.ALIVE) + self._add_to_launch_queue(run_id="foo_run") + self.state = SchedulerState.COMPLETED + self.exit() - _scheduler = Scheduler(api, sweep_id=sweep_id, entity=_entity, project=_project) - assert _scheduler.state == SchedulerState.PENDING - assert _scheduler.is_alive() is True - _scheduler.start() - assert _scheduler.state == SchedulerState.COMPLETED - assert _scheduler.is_alive() is False - assert len(_scheduler._runs) == 1 - assert isinstance(_scheduler._runs["foo_run"].queued_run, public.QueuedRun) - assert _scheduler._runs["foo_run"].state == SimpleRunState.DEAD + monkeypatch.setattr( + "wandb.sdk.launch.sweeps.scheduler.Scheduler._run", + mock_run_add_to_launch_queue, + ) + + _scheduler = Scheduler(api, sweep_id=sweep_id, entity=user, project=_project) + assert _scheduler.state == SchedulerState.PENDING + assert _scheduler.is_alive() is True + _scheduler.start() + assert _scheduler.state == SchedulerState.COMPLETED + assert _scheduler.is_alive() is False + assert len(_scheduler._runs) == 1 + assert isinstance(_scheduler._runs["foo_run"].queued_run, public.QueuedRun) + assert _scheduler._runs["foo_run"].state == SimpleRunState.DEAD @pytest.mark.parametrize("sweep_config", VALID_SWEEP_CONFIGS_MINIMAL) @pytest.mark.parametrize("num_workers", [1, 3, 8]) -def test_sweep_scheduler_sweeps_kill_threads( - user, relay_server, sweep_config, num_workers, monkeypatch -): - with relay_server(): - _entity = user - _project = "test-project" - api = internal.Api() - sweep_id = wandb.sweep(sweep_config, entity=_entity, project=_project) +def test_sweep_scheduler_sweeps_stop_agent_hearbeat(user, sweep_config, num_workers): + api = internal.Api() - def mock_agent_heartbeat(*args, **kwargs): - return [{"type": "run", "run_id": "mock-run-id"}] + def mock_agent_heartbeat(*args, **kwargs): + return [{"type": "stop"}] - def mock_get_run_state(*args, **kwargs): - return "running" + api.agent_heartbeat = mock_agent_heartbeat - def mock_register_agent(*args, **kwargs): - return {"id": "mock-agent-id"} + _project = "test-project" + sweep_id = wandb.sweep(sweep_config, entity=user, project=_project) + scheduler = SweepScheduler( + api, sweep_id=sweep_id, entity=user, project=_project, num_workers=num_workers + ) + scheduler.start() + assert scheduler.state == SchedulerState.STOPPED - def mock_launch_add(*args, **kwargs): - return Mock(spec=public.QueuedRun) - - api.get_run_state = mock_get_run_state - api.register_agent = mock_register_agent - api.agent_heartbeat = mock_agent_heartbeat - monkeypatch.setattr( - "wandb.sdk.launch.launch_add._launch_add", - mock_launch_add, - ) - - _scheduler = SweepScheduler( - api, sweep_id=sweep_id, entity=_entity, project=_project, num_workers=num_workers - ) - _scheduler.start() - _scheduler.exit() - assert _scheduler.state == SchedulerState.COMPLETED - assert _scheduler.is_alive() is False @pytest.mark.parametrize("sweep_config", VALID_SWEEP_CONFIGS_MINIMAL) +@pytest.mark.parametrize("num_workers", [1, 3, 8]) def test_sweep_scheduler_sweeps_invalid_agent_heartbeat( - user, relay_server, sweep_config + user, sweep_config, num_workers ): - with relay_server(): - _entity = user - _project = "test-project" - api = internal.Api() - sweep_id = wandb.sweep(sweep_config, entity=_entity, project=_project) + api = internal.Api() + _project = "test-project" + sweep_id = wandb.sweep(sweep_config, entity=user, project=_project) - def mock_agent_heartbeat(*args, **kwargs): - return [{"type": "foo"}] + def mock_agent_heartbeat(*args, **kwargs): + return [{"type": "foo"}] - def mock_register_agent(*args, **kwargs): - return {"id": "mock-agent-id"} + api.agent_heartbeat = mock_agent_heartbeat - api.register_agent = mock_register_agent - api.agent_heartbeat = mock_agent_heartbeat + with pytest.raises(SchedulerError) as e: + _scheduler = SweepScheduler( + api, + sweep_id=sweep_id, + entity=user, + project=_project, + num_workers=num_workers, + ) + _scheduler.start() - with pytest.raises(SchedulerError) as e: - _scheduler = SweepScheduler( - api, sweep_id=sweep_id, entity=_entity, project=_project, num_workers=1 - ) - _scheduler.start() - - assert "Unknown command" in str(e.value) - assert _scheduler.state == SchedulerState.FAILED - assert _scheduler.is_alive() is False + assert "unknown command" in str(e.value) + assert _scheduler.state == SchedulerState.FAILED + assert _scheduler.is_alive() is False - def mock_agent_heartbeat(*args, **kwargs): - return [{"type": "run"}] # No run_id should throw error + def mock_agent_heartbeat(*args, **kwargs): + return [{"type": "run"}] # No run_id should throw error - api.agent_heartbeat = mock_agent_heartbeat + api.agent_heartbeat = mock_agent_heartbeat + + with pytest.raises(SchedulerError) as e: + _scheduler = SweepScheduler( + api, + sweep_id=sweep_id, + entity=user, + project=_project, + num_workers=num_workers, + ) + _scheduler.start() + + assert "missing run_id" in str(e.value) + assert _scheduler.state == SchedulerState.FAILED + assert _scheduler.is_alive() is False - with pytest.raises(SchedulerError) as e: - _scheduler = SweepScheduler( - api, sweep_id=sweep_id, entity=_entity, project=_project, num_workers=1 - ) - _scheduler.start() - - assert "missing run_id" in str(e.value) - assert _scheduler.state == SchedulerState.FAILED - assert _scheduler.is_alive() is False @pytest.mark.parametrize("sweep_config", VALID_SWEEP_CONFIGS_MINIMAL) @pytest.mark.parametrize("num_workers", [1, 3]) def test_sweep_scheduler_sweeps_run_and_heartbeat( - user, relay_server, sweep_config, num_workers, monkeypatch + user, sweep_config, num_workers, monkeypatch ): - with relay_server(): - _entity = user - _project = "test-project" - api = internal.Api() - sweep_id = wandb.sweep(sweep_config, entity=_entity, project=_project) - - _scheduler = SweepScheduler( - api, sweep_id=sweep_id, entity=_entity, project=_project, num_workers=num_workers + api = internal.Api() + # Mock agent heartbeat stops after 10 heartbeats + api.agent_heartbeat = Mock( + side_effect=[ + [ + { + "type": "run", + "run_id": "mock-run-id-1", + "args": {"foo_arg": {"value": 1}}, + "program": "train.py", + } + ] + ] + * 10 + + [[{"type": "stop"}]] ) - assert _scheduler.state == SchedulerState.PENDING - assert _scheduler.is_alive() is True - _scheduler.start() - - # api.agent_heartbeat = Mock( - # side_effect=[ - # [ - # { - # "type": "run", - # "run_id": "foo_run_1", - # "args": {"foo_arg": {"value": 1}}, - # "program": "train.py", - # } - # ], - # [ - # { - # "type": "stop", - # "run_id": "foo_run_1", - # } - # ], - # [ - # { - # "type": "resume", - # "run_id": "foo_run_1", - # "args": {"foo_arg": {"value": 1}}, - # "program": "train.py", - # } - # ], - # [ - # { - # "type": "exit", - # } - # ], - # ] - # ) - - # # Mock api.get_run_state() to return running and finally finished runs - # mock_run_state_counter: int = 0 - # def mock_get_run_state(*args, **kwargs): - # if mock_run_state_counter < 4: - # mock_run_state_counter += 1 - # return "running" - # return "finished" - - # def mock_register_agent(*args, **kwargs): - # return {"id": "foo_agent_pid"} - - # monkeypatch.setattr("wandb.apis.internal.Api.register_agent", mock_register_agent) - - # def mock_add_to_launch_queue(self, *args, **kwargs): - # assert "entry_point" in kwargs - # assert kwargs["entry_point"] == [ - # "python", - # "train.py", - # "--foo_arg=1", - # ] - - # monkeypatch.setattr( - # "wandb.sdk.launch.sweeps.scheduler.Scheduler._add_to_launch_queue", - # mock_add_to_launch_queue, - # ) def mock_launch_add(*args, **kwargs): return Mock(spec=public.QueuedRun) @@ -371,12 +298,18 @@ def mock_launch_add(*args, **kwargs): mock_launch_add, ) + def mock_get_run_state(*args, **kwargs): + return "runnning" + + api.get_run_state = mock_get_run_state + + _project = "test-project" + sweep_id = wandb.sweep(sweep_config, entity=user, project=_project) + _scheduler = SweepScheduler( - api, sweep_id=sweep_id, entity=_entity, project=_project + api, sweep_id=sweep_id, entity=user, project=_project, num_workers=num_workers ) - - # for _heartbeat_agent in _scheduler._workers: - # assert not _heartbeat_agent.thread.is_alive() - assert _scheduler.state == SchedulerState.COMPLETED - # assert len(_scheduler._runs) == 1 - # assert _scheduler._runs["foo_run_1"].state == SimpleRunState.DEAD + assert _scheduler.state == SchedulerState.PENDING + assert _scheduler.is_alive() is True + _scheduler.start() + assert _scheduler._runs["mock-run-id-1"].state == SimpleRunState.DEAD diff --git a/wandb/sdk/launch/sweeps/scheduler.py b/wandb/sdk/launch/sweeps/scheduler.py index b150aa3b71b..614375a0888 100644 --- a/wandb/sdk/launch/sweeps/scheduler.py +++ b/wandb/sdk/launch/sweeps/scheduler.py @@ -96,10 +96,6 @@ def _run(self) -> None: def _exit(self) -> None: pass - @abstractmethod - def _kill_worker(self, worker_id: int) -> None: - pass - @property def state(self) -> SchedulerState: logger.debug(f"{LOG_PREFIX}Scheduler state is {self._state.name}") @@ -132,13 +128,8 @@ def run(self) -> None: while True: if not self.is_alive(): break - try: - self._update_run_states() - self._run() - except RuntimeError as e: - wandb.termlog( - f"{LOG_PREFIX}Scheduler encountered Runtime Error. {e} Trying again." - ) + self._update_run_states() + self._run() except KeyboardInterrupt: wandb.termlog(f"{LOG_PREFIX}Scheduler received KeyboardInterrupt. Exiting.") self.state = SchedulerState.STOPPED @@ -176,10 +167,9 @@ def _stop_run(self, run_id: str) -> None: """Stops a run and removes it from the scheduler""" if run_id in self._runs: run: SweepRun = self._runs[run_id] - if run.worker_id is not None: - self._kill_worker(run.worker_id) - wandb.termlog(f"{LOG_PREFIX} Stopped run {run_id}.") run.state = SimpleRunState.DEAD + # TODO(hupo): Send command to backend to stop run + wandb.termlog(f"{LOG_PREFIX} Stopped run {run_id}.") def _update_run_states(self) -> None: _runs_to_remove: List[str] = [] diff --git a/wandb/sdk/launch/sweeps/scheduler_sweep.py b/wandb/sdk/launch/sweeps/scheduler_sweep.py index 4309be17c67..1ebfd8cde50 100644 --- a/wandb/sdk/launch/sweeps/scheduler_sweep.py +++ b/wandb/sdk/launch/sweeps/scheduler_sweep.py @@ -5,7 +5,6 @@ import pprint import queue import socket -import threading import time from typing import Any, Dict, List @@ -28,8 +27,6 @@ class _Worker: agent_config: Dict[str, Any] agent_id: str - thread: threading.Thread - stop: threading.Event class SweepScheduler(Scheduler): @@ -47,13 +44,12 @@ def __init__( **kwargs: Any, ): super().__init__(*args, **kwargs) - # Threading is used to run multiple workers in parallel. Workers do not + # Optionally run multiple workers in (pseudo-)parallel. Workers do not # actually run training workloads, they simply send heartbeat messages # (emulating a real agent) and add new runs to the launch queue. The # launch agent is the one that actually runs the training workloads. self._workers: Dict[int, _Worker] = {} self._num_workers: int = num_workers - self._worker_sleep: float = worker_sleep # Thread will pop items off the Sweeps RunQueue using AgentHeartbeat # and put them in this internal queue, which will be used to populate # the Launch RunQueue @@ -70,64 +66,51 @@ def _start(self) -> None: project_name=self._project, entity=self._entity, ) - - def excepthook(args): - print(f"In excepthook {args}") - - threading.excepthook = excepthook - - # Worker threads call heartbeat function - _thread = threading.Thread(target=self._heartbeat, args=[worker_id]) - - # _thread.daemon = True self._workers[worker_id] = _Worker( agent_config=agent_config, agent_id=agent_config["id"], - thread=_thread, - # Worker threads will be killed with an Event - stop=threading.Event(), ) - _thread.start() def _heartbeat(self, worker_id: int) -> None: - while True: - # Make sure Scheduler is alive - if not self.is_alive(): - return - # Check to see if worker thread has been orderred to stop - if self._workers[worker_id].stop.is_set(): - return - # AgentHeartbeat wants a Dict of runs which are running or queued - _run_states: Dict[str, bool] = {} - for run_id, run in self._yield_runs(): - # Filter out runs that are from a different worker thread - if run.worker_id == worker_id and run.state == SimpleRunState.ALIVE: - _run_states[run_id] = True - logger.debug( - f"{LOG_PREFIX}AgentHeartbeat sending: \n{pprint.pformat(_run_states)}\n" - ) - commands: List[Dict[str, Any]] = self._api.agent_heartbeat( - self._workers[worker_id].agent_id, # agent_id: str - {}, # metrics: dict - _run_states, # run_states: dict - ) - logger.debug( - f"{LOG_PREFIX}AgentHeartbeat received {len(commands)} commands: \n{pprint.pformat(commands)}\n" - ) - if commands: - for command in commands: - # The command "type" can be one of "run", "resume", "stop", "exit" - _type = command.get("type", None) - if _type in ["exit", "stop"]: - # Tell (virtual) agent to stop running - self.exit() - return - elif _type in ["run", "resume"]: - _run_id = command.get("run_id", None) - if _run_id is None: - raise SchedulerError( - f"{LOG_PREFIX}AgentHeartbeat command {command} missing run_id" - ) + # Make sure Scheduler is alive + if not self.is_alive(): + return + # AgentHeartbeat wants a Dict of runs which are running or queued + _run_states: Dict[str, bool] = {} + for run_id, run in self._yield_runs(): + # Filter out runs that are from a different worker thread + if run.worker_id == worker_id and run.state == SimpleRunState.ALIVE: + _run_states[run_id] = True + logger.debug( + f"{LOG_PREFIX}AgentHeartbeat sending: \n{pprint.pformat(_run_states)}\n" + ) + commands: List[Dict[str, Any]] = self._api.agent_heartbeat( + self._workers[worker_id].agent_id, # agent_id: str + {}, # metrics: dict + _run_states, # run_states: dict + ) + logger.debug( + f"{LOG_PREFIX}AgentHeartbeat received {len(commands)} commands: \n{pprint.pformat(commands)}\n" + ) + if commands: + for command in commands: + # The command "type" can be one of "run", "resume", "stop", "exit" + _type = command.get("type", None) + if _type in ["exit", "stop"]: + # Tell (virtual) agent to stop running + self.state = SchedulerState.STOPPED + self.exit() + return + elif _type in ["run", "resume"]: + _run_id = command.get("run_id", None) + if _run_id is None: + self.state = SchedulerState.FAILED + raise SchedulerError( + f"AgentHeartbeat command {command} missing run_id" + ) + if _run_id in self._runs: + wandb.termlog(f"{LOG_PREFIX} Skipping duplicate run {run_id}") + else: run = SweepRun( id=_run_id, args=command.get("args", {}), @@ -137,17 +120,14 @@ def _heartbeat(self, worker_id: int) -> None: ) self._runs[run.id] = run self._heartbeat_queue.put(run) - continue - else: - raise SchedulerError( - f"{LOG_PREFIX}Unknown command type {_type}" - ) - time.sleep(self._worker_sleep) + else: + self.state = SchedulerState.FAILED + raise SchedulerError(f"AgentHeartbeat unknown command type {_type}") def _run(self) -> None: - # # Join worker threads to check for exceptions - # for worker_id in self._workers: - # self._workers[worker_id].thread.join(timeout=0.1) + # Go through all workers and heartbeat + for worker_id in self._workers.keys(): + self._heartbeat(worker_id) try: run: SweepRun = self._heartbeat_queue.get( timeout=self._heartbeat_queue_timeout @@ -183,18 +163,5 @@ def _run(self) -> None: run_id=run.id, ) - def _kill_worker(self, worker_id: int) -> None: - print(f"{LOG_PREFIX}Killing AgentHeartbeat worker {worker_id}") - _worker = self._workers.get(worker_id, None) - if _worker and _worker.thread.is_alive(): - # Set threading event to stop the worker thread - _worker.stop.set() - print(f"{LOG_PREFIX}Killing AgentHeartbeat worker {worker_id}") - _worker.thread.join() - print(f"{LOG_PREFIX}AgentHeartbeat worker {worker_id} killed") - def _exit(self) -> None: - # Kill all the worker threads - for worker_id in self._workers: - self._kill_worker(worker_id) - self.state = SchedulerState.COMPLETED + pass From 730c7b59e0557241a6d3652c0168335ce4a01780 Mon Sep 17 00:00:00 2001 From: Hu Po Date: Tue, 16 Aug 2022 10:53:53 -0500 Subject: [PATCH 16/17] more sweep configs --- tests/unit_tests/test_sweep_scheduler.py | 6 +-- tests/unit_tests/test_wandb_sweep.py | 49 ++++++++++++++++++++++-- wandb/sdk/launch/sweeps/scheduler.py | 3 +- 3 files changed, 51 insertions(+), 7 deletions(-) diff --git a/tests/unit_tests/test_sweep_scheduler.py b/tests/unit_tests/test_sweep_scheduler.py index 4edf6d8f691..1c90af752ad 100644 --- a/tests/unit_tests/test_sweep_scheduler.py +++ b/tests/unit_tests/test_sweep_scheduler.py @@ -202,7 +202,7 @@ def mock_run_add_to_launch_queue(self, *args, **kwargs): @pytest.mark.parametrize("sweep_config", VALID_SWEEP_CONFIGS_MINIMAL) -@pytest.mark.parametrize("num_workers", [1, 3, 8]) +@pytest.mark.parametrize("num_workers", [1, 8]) def test_sweep_scheduler_sweeps_stop_agent_hearbeat(user, sweep_config, num_workers): api = internal.Api() @@ -221,7 +221,7 @@ def mock_agent_heartbeat(*args, **kwargs): @pytest.mark.parametrize("sweep_config", VALID_SWEEP_CONFIGS_MINIMAL) -@pytest.mark.parametrize("num_workers", [1, 3, 8]) +@pytest.mark.parametrize("num_workers", [1, 8]) def test_sweep_scheduler_sweeps_invalid_agent_heartbeat( user, sweep_config, num_workers ): @@ -269,7 +269,7 @@ def mock_agent_heartbeat(*args, **kwargs): @pytest.mark.parametrize("sweep_config", VALID_SWEEP_CONFIGS_MINIMAL) -@pytest.mark.parametrize("num_workers", [1, 3]) +@pytest.mark.parametrize("num_workers", [1, 8]) def test_sweep_scheduler_sweeps_run_and_heartbeat( user, sweep_config, num_workers, monkeypatch ): diff --git a/tests/unit_tests/test_wandb_sweep.py b/tests/unit_tests/test_wandb_sweep.py index 5eb79d98739..b7cc789b096 100644 --- a/tests/unit_tests/test_wandb_sweep.py +++ b/tests/unit_tests/test_wandb_sweep.py @@ -33,6 +33,44 @@ "metric": {"name": "metric1", "goal": "maximize"}, "parameters": {"param1": {"values": [1, 2, 3]}}, } +SWEEP_CONFIG_BAYES_PROBABILITIES: Dict[str, Any] = { + "name": "mock-sweep-bayes", + "method": "bayes", + "metric": {"name": "metric1", "goal": "maximize"}, + "parameters": { + "param1": {"values": [1, 2, 3]}, + "param2": {"values": [1, 2, 3], "probabilities": [0.1, 0.2, 0.1]}, + }, +} +SWEEP_CONFIG_BAYES_DISTRIBUTION: Dict[str, Any] = { + "name": "mock-sweep-bayes", + "method": "bayes", + "metric": {"name": "metric1", "goal": "maximize"}, + "parameters": { + "param1": {"distribution": "normal", "mu": 100, "sigma": 10}, + }, +} +SWEEP_CONFIG_BAYES_DISTRIBUTION_NESTED: Dict[str, Any] = { + "name": "mock-sweep-bayes", + "method": "bayes", + "metric": {"name": "metric1", "goal": "maximize"}, + "parameters": { + "param1": {"values": [1, 2, 3]}, + "param2": { + "parameters": { + "param3": {"distribution": "q_uniform", "min": 0, "max": 256, "q": 1} + }, + }, + }, +} +SWEEP_CONFIG_BAYES_TARGET: Dict[str, Any] = { + "name": "mock-sweep-bayes", + "method": "bayes", + "metric": {"name": "metric1", "goal": "maximize", "target": 0.99}, + "parameters": { + "param1": {"distribution": "normal", "mu": 100, "sigma": 10}, + }, +} SWEEP_CONFIG_RANDOM: Dict[str, Any] = { "name": "mock-sweep-random", "method": "random", @@ -42,14 +80,19 @@ # Minimal list of valid sweep configs VALID_SWEEP_CONFIGS_MINIMAL: List[Dict[str, Any]] = [ SWEEP_CONFIG_BAYES, - # SWEEP_CONFIG_RANDOM, - # SWEEP_CONFIG_GRID_HYPERBAND, - # SWEEP_CONFIG_GRID_NESTED, + SWEEP_CONFIG_RANDOM, + SWEEP_CONFIG_GRID_HYPERBAND, + SWEEP_CONFIG_GRID_NESTED, ] # All valid sweep configs, be careful as this will slow down tests VALID_SWEEP_CONFIGS_ALL: List[Dict[str, Any]] = [ SWEEP_CONFIG_RANDOM, SWEEP_CONFIG_BAYES, + # TODO: Probabilities seem to error out? + # SWEEP_CONFIG_BAYES_PROBABILITIES, + SWEEP_CONFIG_BAYES_DISTRIBUTION, + SWEEP_CONFIG_BAYES_DISTRIBUTION_NESTED, + SWEEP_CONFIG_BAYES_TARGET, SWEEP_CONFIG_GRID, SWEEP_CONFIG_GRID_NESTED, SWEEP_CONFIG_GRID_HYPERBAND, diff --git a/wandb/sdk/launch/sweeps/scheduler.py b/wandb/sdk/launch/sweeps/scheduler.py index 614375a0888..3ba30a0b715 100644 --- a/wandb/sdk/launch/sweeps/scheduler.py +++ b/wandb/sdk/launch/sweeps/scheduler.py @@ -226,7 +226,8 @@ def _add_to_launch_queue( project=self._project, entity=self._entity, queue=_queue, - **self._kwargs, + resource=self._kwargs.get("resource", None), + resource_args=self._kwargs.get("resource_args", None), ) self._runs[run_id].queued_run = queued_run wandb.termlog( From cbfc4b1de5244466cff5a62ad39ef8432a68052f Mon Sep 17 00:00:00 2001 From: Hu Po Date: Tue, 16 Aug 2022 15:54:21 -0500 Subject: [PATCH 17/17] review suggestions --- tests/unit_tests/test_sweep_scheduler.py | 3 ++- wandb/sdk/launch/sweeps/scheduler.py | 3 ++- wandb/sdk/launch/sweeps/scheduler_sweep.py | 5 ++--- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/unit_tests/test_sweep_scheduler.py b/tests/unit_tests/test_sweep_scheduler.py index 1c90af752ad..53faa64f3ac 100644 --- a/tests/unit_tests/test_sweep_scheduler.py +++ b/tests/unit_tests/test_sweep_scheduler.py @@ -4,6 +4,7 @@ import pytest import wandb from wandb.apis import internal, public +from wandb.errors import CommError from wandb.sdk.launch.sweeps import load_scheduler, SchedulerError from wandb.sdk.launch.sweeps.scheduler import ( Scheduler, @@ -148,7 +149,7 @@ def mock_get_run_state(entity, project, run_id, *args, **kwargs): # ---- If get_run_state errors out, runs should have the state UNKNOWN def mock_get_run_state_raise_exception(*args, **kwargs): - raise Exception("Generic Exception") + raise CommError("Generic Exception") api.get_run_state = mock_get_run_state_raise_exception _scheduler = Scheduler(api, sweep_id=sweep_id, entity=_entity, project=_project) diff --git a/wandb/sdk/launch/sweeps/scheduler.py b/wandb/sdk/launch/sweeps/scheduler.py index 3ba30a0b715..f2cdd2f6efa 100644 --- a/wandb/sdk/launch/sweeps/scheduler.py +++ b/wandb/sdk/launch/sweeps/scheduler.py @@ -11,6 +11,7 @@ import wandb from wandb.apis.internal import Api import wandb.apis.public as public +from wandb.errors import CommError from wandb.sdk.launch.launch_add import launch_add from wandb.sdk.launch.sweeps import SchedulerError from wandb.sdk.lib.runid import generate_id @@ -191,7 +192,7 @@ def _update_run_states(self) -> None: "preempting", ]: run.state = SimpleRunState.ALIVE - except Exception as e: + except CommError as e: wandb.termlog( f"{LOG_PREFIX}Issue when getting RunState for Run {run_id}: {e}" ) diff --git a/wandb/sdk/launch/sweeps/scheduler_sweep.py b/wandb/sdk/launch/sweeps/scheduler_sweep.py index 1ebfd8cde50..c310383c666 100644 --- a/wandb/sdk/launch/sweeps/scheduler_sweep.py +++ b/wandb/sdk/launch/sweeps/scheduler_sweep.py @@ -38,9 +38,8 @@ def __init__( self, *args: Any, num_workers: int = 4, - worker_sleep: float = 0.1, - heartbeat_queue_timeout: float = 0.1, - heartbeat_queue_sleep: float = 0.1, + heartbeat_queue_timeout: float = 1.0, + heartbeat_queue_sleep: float = 1.0, **kwargs: Any, ): super().__init__(*args, **kwargs)