diff --git a/tests/unit_tests/test_sweep_scheduler.py b/tests/unit_tests/test_sweep_scheduler.py index 681e9f6c7d5..53faa64f3ac 100644 --- a/tests/unit_tests/test_sweep_scheduler.py +++ b/tests/unit_tests/test_sweep_scheduler.py @@ -2,7 +2,9 @@ from unittest.mock import Mock, patch import pytest +import wandb from wandb.apis import internal, public +from wandb.errors import CommError from wandb.sdk.launch.sweeps import load_scheduler, SchedulerError from wandb.sdk.launch.sweeps.scheduler import ( Scheduler, @@ -12,6 +14,8 @@ ) from wandb.sdk.launch.sweeps.scheduler_sweep import SweepScheduler +from .test_wandb_sweep import VALID_SWEEP_CONFIGS_MINIMAL + def test_sweep_scheduler_load(): _scheduler = load_scheduler("sweep") @@ -21,114 +25,153 @@ def test_sweep_scheduler_load(): @patch.multiple(Scheduler, __abstractmethods__=set()) -def test_sweep_scheduler_base_state(monkeypatch): - api = internal.Api() - - def mock_run_complete_scheduler(self, *args, **kwargs): - self.state = SchedulerState.COMPLETED - - monkeypatch.setattr( - "wandb.sdk.launch.sweeps.scheduler.Scheduler._run", - mock_run_complete_scheduler, - ) - - _scheduler = Scheduler(api, entity="foo", project="bar") - assert _scheduler.state == SchedulerState.PENDING - assert _scheduler.is_alive() is True - _scheduler.start() - assert _scheduler.state == SchedulerState.COMPLETED - assert _scheduler.is_alive() is False +@pytest.mark.parametrize("sweep_config", VALID_SWEEP_CONFIGS_MINIMAL) +def test_sweep_scheduler_entity_project_sweep_id(user, relay_server, sweep_config): + with relay_server(): + _entity = user + _project = "test-project" + api = internal.Api() + # Entity, project, and sweep should be everything you need to create a scheduler + sweep_id = wandb.sweep(sweep_config, entity=_entity, project=_project) + _ = Scheduler(api, sweep_id=sweep_id, entity=_entity, project=_project) + # Bogus sweep id should result in error + with pytest.raises(SchedulerError): + _ = Scheduler( + api, sweep_id="foo-sweep-id", entity=_entity, project=_project + ) - def mock_run_raise_keyboard_interupt(*args, **kwargs): - raise KeyboardInterrupt - monkeypatch.setattr( - "wandb.sdk.launch.sweeps.scheduler.Scheduler._run", - mock_run_raise_keyboard_interupt, - ) +@patch.multiple(Scheduler, __abstractmethods__=set()) +@pytest.mark.parametrize("sweep_config", VALID_SWEEP_CONFIGS_MINIMAL) +def test_sweep_scheduler_base_scheduler_states( + user, relay_server, sweep_config, monkeypatch +): + + with relay_server(): + _entity = user + _project = "test-project" + api = internal.Api() + sweep_id = wandb.sweep(sweep_config, entity=_entity, project=_project) + + def mock_run_complete_scheduler(self, *args, **kwargs): + self.state = SchedulerState.COMPLETED + + monkeypatch.setattr( + "wandb.sdk.launch.sweeps.scheduler.Scheduler._run", + mock_run_complete_scheduler, + ) + + _scheduler = Scheduler(api, sweep_id=sweep_id, entity=_entity, project=_project) + assert _scheduler.state == SchedulerState.PENDING + assert _scheduler.is_alive() is True + _scheduler.start() + assert _scheduler.state == SchedulerState.COMPLETED + assert _scheduler.is_alive() is False - _scheduler = Scheduler(api, entity="foo", project="bar") - assert _scheduler.state == SchedulerState.PENDING - assert _scheduler.is_alive() is True - _scheduler.start() - assert _scheduler.state == SchedulerState.STOPPED - assert _scheduler.is_alive() is False + def mock_run_raise_keyboard_interupt(*args, **kwargs): + raise KeyboardInterrupt - def mock_run_raise_exception(*args, **kwargs): - raise Exception("Generic exception") + monkeypatch.setattr( + "wandb.sdk.launch.sweeps.scheduler.Scheduler._run", + mock_run_raise_keyboard_interupt, + ) - monkeypatch.setattr( - "wandb.sdk.launch.sweeps.scheduler.Scheduler._run", - mock_run_raise_exception, - ) - - _scheduler = Scheduler(api, entity="foo", project="bar") - assert _scheduler.state == SchedulerState.PENDING - assert _scheduler.is_alive() is True - with pytest.raises(Exception) as e: + _scheduler = Scheduler(api, sweep_id=sweep_id, entity=_entity, project=_project) _scheduler.start() - assert "Generic exception" in str(e.value) - assert _scheduler.state == SchedulerState.FAILED - assert _scheduler.is_alive() is False + assert _scheduler.state == SchedulerState.STOPPED + assert _scheduler.is_alive() is False - def mock_run_exit(self, *args, **kwargs): - self.exit() + def mock_run_raise_exception(*args, **kwargs): + raise Exception("Generic exception") - monkeypatch.setattr( - "wandb.sdk.launch.sweeps.scheduler.Scheduler._run", - mock_run_exit, - ) + monkeypatch.setattr( + "wandb.sdk.launch.sweeps.scheduler.Scheduler._run", + mock_run_raise_exception, + ) - _scheduler = Scheduler(api, entity="foo", project="bar") - assert _scheduler.state == SchedulerState.PENDING - assert _scheduler.is_alive() is True - _scheduler.start() - assert _scheduler.state == SchedulerState.FAILED - assert _scheduler.is_alive() is False + _scheduler = Scheduler(api, sweep_id=sweep_id, entity=_entity, project=_project) + with pytest.raises(Exception) as e: + _scheduler.start() + assert "Generic exception" in str(e.value) + assert _scheduler.state == SchedulerState.FAILED + assert _scheduler.is_alive() is False + def mock_run_exit(self, *args, **kwargs): + self.exit() -@patch.multiple(Scheduler, __abstractmethods__=set()) -def test_sweep_scheduler_base_run_state(): - api = internal.Api() - # Mock api.get_run_state() to return crashed and running runs - mock_run_states = { - "run1": ("crashed", SimpleRunState.DEAD), - "run2": ("failed", SimpleRunState.DEAD), - "run3": ("killed", SimpleRunState.DEAD), - "run4": ("finished", SimpleRunState.DEAD), - "run5": ("running", SimpleRunState.ALIVE), - "run6": ("pending", SimpleRunState.ALIVE), - "run7": ("preempted", SimpleRunState.ALIVE), - "run8": ("preempting", SimpleRunState.ALIVE), - } - - def mock_get_run_state(entity, project, run_id): - return mock_run_states[run_id][0] + monkeypatch.setattr( + "wandb.sdk.launch.sweeps.scheduler.Scheduler._run", + mock_run_exit, + ) - api.get_run_state = mock_get_run_state - _scheduler = Scheduler(api, entity="foo", project="bar") - for run_id in mock_run_states.keys(): - _scheduler._runs[run_id] = SweepRun(id=run_id, state=SimpleRunState.ALIVE) - _scheduler._update_run_states() - for run_id, _state in mock_run_states.items(): - assert _scheduler._runs[run_id].state == _state[1] + _scheduler = Scheduler(api, sweep_id=sweep_id, entity=_entity, project=_project) + _scheduler.start() + assert _scheduler.state == SchedulerState.FAILED + assert _scheduler.is_alive() is False - def mock_get_run_state_raise_exception(*args, **kwargs): - raise Exception("Generic Exception") - api.get_run_state = mock_get_run_state_raise_exception - _scheduler = Scheduler(api, entity="foo", project="bar") - _scheduler._runs["foo_run_1"] = SweepRun(id="foo_run_1", state=SimpleRunState.ALIVE) - _scheduler._runs["foo_run_2"] = SweepRun(id="foo_run_2", state=SimpleRunState.ALIVE) - _scheduler._update_run_states() - assert _scheduler._runs["foo_run_1"].state == SimpleRunState.UNKNOWN - assert _scheduler._runs["foo_run_2"].state == SimpleRunState.UNKNOWN +@patch.multiple(Scheduler, __abstractmethods__=set()) +@pytest.mark.parametrize("sweep_config", VALID_SWEEP_CONFIGS_MINIMAL) +def test_sweep_scheduler_base_run_states(user, relay_server, sweep_config): + with relay_server(): + _entity = user + _project = "test-project" + api = internal.Api() + sweep_id = wandb.sweep(sweep_config, entity=_entity, project=_project) + + # Mock api.get_run_state() to return crashed and running runs + mock_run_states = { + "run1": ("crashed", SimpleRunState.DEAD), + "run2": ("failed", SimpleRunState.DEAD), + "run3": ("killed", SimpleRunState.DEAD), + "run4": ("finished", SimpleRunState.DEAD), + "run5": ("running", SimpleRunState.ALIVE), + "run6": ("pending", SimpleRunState.ALIVE), + "run7": ("preempted", SimpleRunState.ALIVE), + "run8": ("preempting", SimpleRunState.ALIVE), + } + + def mock_get_run_state(entity, project, run_id, *args, **kwargs): + return mock_run_states[run_id][0] + + api.get_run_state = mock_get_run_state + _scheduler = Scheduler(api, sweep_id=sweep_id, entity=_entity, project=_project) + # Load up the runs into the Scheduler run dict + for run_id in mock_run_states.keys(): + _scheduler._runs[run_id] = SweepRun(id=run_id, state=SimpleRunState.ALIVE) + _scheduler._update_run_states() + for run_id, _state in mock_run_states.items(): + if _state[1] == SimpleRunState.DEAD: + # Dead runs should be removed from the run dict + assert run_id not in _scheduler._runs.keys() + else: + assert _scheduler._runs[run_id].state == _state[1] + + # ---- If get_run_state errors out, runs should have the state UNKNOWN + def mock_get_run_state_raise_exception(*args, **kwargs): + raise CommError("Generic Exception") + + api.get_run_state = mock_get_run_state_raise_exception + _scheduler = Scheduler(api, sweep_id=sweep_id, entity=_entity, project=_project) + _scheduler._runs["foo_run_1"] = SweepRun( + id="foo_run_1", state=SimpleRunState.ALIVE + ) + _scheduler._runs["foo_run_2"] = SweepRun( + id="foo_run_2", state=SimpleRunState.ALIVE + ) + _scheduler._update_run_states() + assert _scheduler._runs["foo_run_1"].state == SimpleRunState.UNKNOWN + assert _scheduler._runs["foo_run_2"].state == SimpleRunState.UNKNOWN @patch.multiple(Scheduler, __abstractmethods__=set()) -def test_sweep_scheduler_base_add_to_launch_queue(monkeypatch): +@pytest.mark.parametrize("sweep_config", VALID_SWEEP_CONFIGS_MINIMAL) +def test_sweep_scheduler_base_add_to_launch_queue(user, sweep_config, monkeypatch): api = internal.Api() + _project = "test-project" + sweep_id = wandb.sweep(sweep_config, entity=user, project=_project) + def mock_launch_add(*args, **kwargs): return Mock(spec=public.QueuedRun) @@ -148,7 +191,7 @@ def mock_run_add_to_launch_queue(self, *args, **kwargs): mock_run_add_to_launch_queue, ) - _scheduler = Scheduler(api, entity="foo", project="bar") + _scheduler = Scheduler(api, sweep_id=sweep_id, entity=user, project=_project) assert _scheduler.state == SchedulerState.PENDING assert _scheduler.is_alive() is True _scheduler.start() @@ -159,71 +202,95 @@ def mock_run_add_to_launch_queue(self, *args, **kwargs): assert _scheduler._runs["foo_run"].state == SimpleRunState.DEAD -@pytest.mark.xfail(reason="TODO(hupo): fix") -def test_sweep_scheduler_sweeps(monkeypatch): +@pytest.mark.parametrize("sweep_config", VALID_SWEEP_CONFIGS_MINIMAL) +@pytest.mark.parametrize("num_workers", [1, 8]) +def test_sweep_scheduler_sweeps_stop_agent_hearbeat(user, sweep_config, num_workers): + api = internal.Api() + + def mock_agent_heartbeat(*args, **kwargs): + return [{"type": "stop"}] + + api.agent_heartbeat = mock_agent_heartbeat + + _project = "test-project" + sweep_id = wandb.sweep(sweep_config, entity=user, project=_project) + scheduler = SweepScheduler( + api, sweep_id=sweep_id, entity=user, project=_project, num_workers=num_workers + ) + scheduler.start() + assert scheduler.state == SchedulerState.STOPPED + + +@pytest.mark.parametrize("sweep_config", VALID_SWEEP_CONFIGS_MINIMAL) +@pytest.mark.parametrize("num_workers", [1, 8]) +def test_sweep_scheduler_sweeps_invalid_agent_heartbeat( + user, sweep_config, num_workers +): api = internal.Api() + _project = "test-project" + sweep_id = wandb.sweep(sweep_config, entity=user, project=_project) + + def mock_agent_heartbeat(*args, **kwargs): + return [{"type": "foo"}] + + api.agent_heartbeat = mock_agent_heartbeat + + with pytest.raises(SchedulerError) as e: + _scheduler = SweepScheduler( + api, + sweep_id=sweep_id, + entity=user, + project=_project, + num_workers=num_workers, + ) + _scheduler.start() + + assert "unknown command" in str(e.value) + assert _scheduler.state == SchedulerState.FAILED + assert _scheduler.is_alive() is False + + def mock_agent_heartbeat(*args, **kwargs): + return [{"type": "run"}] # No run_id should throw error + + api.agent_heartbeat = mock_agent_heartbeat + + with pytest.raises(SchedulerError) as e: + _scheduler = SweepScheduler( + api, + sweep_id=sweep_id, + entity=user, + project=_project, + num_workers=num_workers, + ) + _scheduler.start() + + assert "missing run_id" in str(e.value) + assert _scheduler.state == SchedulerState.FAILED + assert _scheduler.is_alive() is False + +@pytest.mark.parametrize("sweep_config", VALID_SWEEP_CONFIGS_MINIMAL) +@pytest.mark.parametrize("num_workers", [1, 8]) +def test_sweep_scheduler_sweeps_run_and_heartbeat( + user, sweep_config, num_workers, monkeypatch +): + api = internal.Api() + # Mock agent heartbeat stops after 10 heartbeats api.agent_heartbeat = Mock( side_effect=[ [ { "type": "run", - "run_id": "foo_run_1", + "run_id": "mock-run-id-1", "args": {"foo_arg": {"value": 1}}, "program": "train.py", } - ], - [ - { - "type": "stop", - "run_id": "foo_run_1", - } - ], - [ - { - "type": "resume", - "run_id": "foo_run_1", - "args": {"foo_arg": {"value": 1}}, - "program": "train.py", - } - ], - [ - { - "type": "exit", - } - ], + ] ] + * 10 + + [[{"type": "stop"}]] ) - def mock_sweep(self, sweep_id, *args, **kwargs): - if sweep_id == "404sweep": - return False - return True - - monkeypatch.setattr("wandb.apis.internal.Api.sweep", mock_sweep) - - def mock_register_agent(*args, **kwargs): - return {"id": "foo_agent_pid"} - - monkeypatch.setattr("wandb.apis.internal.Api.register_agent", mock_register_agent) - - # def mock_add_to_launch_queue(self, *args, **kwargs): - # assert "entry_point" in kwargs - # assert kwargs["entry_point"] == [ - # "python", - # "train.py", - # "--foo_arg=1", - # ] - - # monkeypatch.setattr( - # "wandb.sdk.launch.sweeps.scheduler.Scheduler._add_to_launch_queue", - # mock_add_to_launch_queue, - # ) - - with pytest.raises(SchedulerError) as e: - SweepScheduler(api, sweep_id="404sweep") - assert "Could not find sweep" in str(e.value) - def mock_launch_add(*args, **kwargs): return Mock(spec=public.QueuedRun) @@ -233,21 +300,17 @@ def mock_launch_add(*args, **kwargs): ) def mock_get_run_state(*args, **kwargs): - return "finished" + return "runnning" + + api.get_run_state = mock_get_run_state - monkeypatch.setattr("wandb.apis.internal.Api.get_run_state", mock_get_run_state) + _project = "test-project" + sweep_id = wandb.sweep(sweep_config, entity=user, project=_project) _scheduler = SweepScheduler( - api, - entity="mock-entity", - project="mock-project", - sweep_id="mock-sweep", + api, sweep_id=sweep_id, entity=user, project=_project, num_workers=num_workers ) assert _scheduler.state == SchedulerState.PENDING assert _scheduler.is_alive() is True _scheduler.start() - for _heartbeat_agent in _scheduler._heartbeat_agents: - assert not _heartbeat_agent.thread.is_alive() - assert _scheduler.state == SchedulerState.COMPLETED - assert len(_scheduler._runs) == 1 - assert _scheduler._runs["foo_run_1"].state == SimpleRunState.DEAD + assert _scheduler._runs["mock-run-id-1"].state == SimpleRunState.DEAD diff --git a/tests/unit_tests/test_wandb_sweep.py b/tests/unit_tests/test_wandb_sweep.py index ee5a05a7417..b7cc789b096 100644 --- a/tests/unit_tests/test_wandb_sweep.py +++ b/tests/unit_tests/test_wandb_sweep.py @@ -33,30 +33,80 @@ "metric": {"name": "metric1", "goal": "maximize"}, "parameters": {"param1": {"values": [1, 2, 3]}}, } +SWEEP_CONFIG_BAYES_PROBABILITIES: Dict[str, Any] = { + "name": "mock-sweep-bayes", + "method": "bayes", + "metric": {"name": "metric1", "goal": "maximize"}, + "parameters": { + "param1": {"values": [1, 2, 3]}, + "param2": {"values": [1, 2, 3], "probabilities": [0.1, 0.2, 0.1]}, + }, +} +SWEEP_CONFIG_BAYES_DISTRIBUTION: Dict[str, Any] = { + "name": "mock-sweep-bayes", + "method": "bayes", + "metric": {"name": "metric1", "goal": "maximize"}, + "parameters": { + "param1": {"distribution": "normal", "mu": 100, "sigma": 10}, + }, +} +SWEEP_CONFIG_BAYES_DISTRIBUTION_NESTED: Dict[str, Any] = { + "name": "mock-sweep-bayes", + "method": "bayes", + "metric": {"name": "metric1", "goal": "maximize"}, + "parameters": { + "param1": {"values": [1, 2, 3]}, + "param2": { + "parameters": { + "param3": {"distribution": "q_uniform", "min": 0, "max": 256, "q": 1} + }, + }, + }, +} +SWEEP_CONFIG_BAYES_TARGET: Dict[str, Any] = { + "name": "mock-sweep-bayes", + "method": "bayes", + "metric": {"name": "metric1", "goal": "maximize", "target": 0.99}, + "parameters": { + "param1": {"distribution": "normal", "mu": 100, "sigma": 10}, + }, +} SWEEP_CONFIG_RANDOM: Dict[str, Any] = { "name": "mock-sweep-random", "method": "random", "parameters": {"param1": {"values": [1, 2, 3]}}, } -# List of all valid base configurations -VALID_SWEEP_CONFIGS: List[Dict[str, Any]] = [ - SWEEP_CONFIG_GRID, +# Minimal list of valid sweep configs +VALID_SWEEP_CONFIGS_MINIMAL: List[Dict[str, Any]] = [ + SWEEP_CONFIG_BAYES, + SWEEP_CONFIG_RANDOM, SWEEP_CONFIG_GRID_HYPERBAND, SWEEP_CONFIG_GRID_NESTED, - SWEEP_CONFIG_BAYES, +] +# All valid sweep configs, be careful as this will slow down tests +VALID_SWEEP_CONFIGS_ALL: List[Dict[str, Any]] = [ SWEEP_CONFIG_RANDOM, + SWEEP_CONFIG_BAYES, + # TODO: Probabilities seem to error out? + # SWEEP_CONFIG_BAYES_PROBABILITIES, + SWEEP_CONFIG_BAYES_DISTRIBUTION, + SWEEP_CONFIG_BAYES_DISTRIBUTION_NESTED, + SWEEP_CONFIG_BAYES_TARGET, + SWEEP_CONFIG_GRID, + SWEEP_CONFIG_GRID_NESTED, + SWEEP_CONFIG_GRID_HYPERBAND, ] -@pytest.mark.parametrize("sweep_config", VALID_SWEEP_CONFIGS) +@pytest.mark.parametrize("sweep_config", VALID_SWEEP_CONFIGS_ALL) def test_sweep_create(user, relay_server, sweep_config): with relay_server() as relay: sweep_id = wandb.sweep(sweep_config, entity=user) assert sweep_id in relay.context.entries -@pytest.mark.parametrize("sweep_config", VALID_SWEEP_CONFIGS) +@pytest.mark.parametrize("sweep_config", VALID_SWEEP_CONFIGS_MINIMAL) def test_sweep_entity_project_callable(user, relay_server, sweep_config): def sweep_callable(): return sweep_config diff --git a/wandb/sdk/launch/deploys/Dockerfile b/wandb/sdk/launch/deploys/Dockerfile index df9db592ab9..350ce059695 100644 --- a/wandb/sdk/launch/deploys/Dockerfile +++ b/wandb/sdk/launch/deploys/Dockerfile @@ -1,4 +1,5 @@ FROM python:3.9-slim-bullseye +LABEL maintainer='Weights & Biases ' # install git RUN apt-get update && apt-get upgrade -y \ diff --git a/wandb/sdk/launch/sweeps/Dockerfile.scheduler.sweep b/wandb/sdk/launch/sweeps/Dockerfile.scheduler.sweep new file mode 100644 index 00000000000..c8a1d6f665a --- /dev/null +++ b/wandb/sdk/launch/sweeps/Dockerfile.scheduler.sweep @@ -0,0 +1,19 @@ +FROM python:3.9-slim-bullseye +LABEL maintainer='Weights & Biases ' +LABEL version="0.1" + +# install git +RUN apt-get update && apt-get upgrade -y \ + && apt-get install -y git \ + && apt-get -qy autoremove \ + && apt-get clean && rm -r /var/lib/apt/lists/* + +# required pip packages +RUN pip install --no-cache-dir wandb[launch] +# user set up +RUN useradd -m -s /bin/bash --create-home --no-log-init -u 1000 -g 0 wandb_scheduler +USER wandb_scheduler +WORKDIR /home/wandb_scheduler +RUN chown -R wandb_scheduler /home/wandb_scheduler + +ENTRYPOINT ["wandb", "scheduler"] diff --git a/wandb/sdk/launch/sweeps/scheduler.py b/wandb/sdk/launch/sweeps/scheduler.py index 8acb21d761f..f2cdd2f6efa 100644 --- a/wandb/sdk/launch/sweeps/scheduler.py +++ b/wandb/sdk/launch/sweeps/scheduler.py @@ -1,3 +1,4 @@ +"""Abstract Scheduler class.""" from abc import ABC, abstractmethod from dataclasses import dataclass from enum import Enum @@ -10,11 +11,13 @@ import wandb from wandb.apis.internal import Api import wandb.apis.public as public +from wandb.errors import CommError from wandb.sdk.launch.launch_add import launch_add +from wandb.sdk.launch.sweeps import SchedulerError from wandb.sdk.lib.runid import generate_id logger = logging.getLogger(__name__) -LOG_PREFIX = f"{click.style('sched:', fg='cyan')}: " +LOG_PREFIX = f"{click.style('sched:', fg='cyan')} " class SchedulerState(Enum): @@ -29,7 +32,7 @@ class SchedulerState(Enum): class SimpleRunState(Enum): ALIVE = 0 DEAD = 1 - UNKNOWN = 3 + UNKNOWN = 2 @dataclass @@ -40,6 +43,8 @@ class SweepRun: args: Optional[Dict[str, Any]] = None logs: Optional[List[str]] = None program: Optional[str] = None + # Threading can be used to run multiple workers in parallel + worker_id: Optional[int] = None class Scheduler(ABC): @@ -50,16 +55,11 @@ class Scheduler(ABC): def __init__( self, api: Api, - *args: Any, + *args: Optional[Any], + sweep_id: str = None, entity: Optional[str] = None, project: Optional[str] = None, - # ------- Begin Launch Options ------- - queue: Optional[str] = None, - job: Optional[str] = None, - resource: Optional[str] = None, - resource_args: Optional[Dict[str, Any]] = None, - # ------- End Launch Options ------- - **kwargs: Any, + **kwargs: Optional[Any], ): self._api = api self._entity = ( @@ -71,18 +71,19 @@ def __init__( self._project = ( project or os.environ.get("WANDB_PROJECT") or api.settings("project") ) - # ------- Begin Launch Options ------- - # TODO(hupo): Validation on these arguments. - self._launch_queue = queue - self._job = job - self._resource = resource - self._resource_args = resource_args - if resource == "kubernetes": - self._resource_args = {"kubernetes": {}} - # ------- End Launch Options ------- + # Make sure the provided sweep_id corresponds to a valid sweep + try: + self._api.sweep(sweep_id, "{}", entity=self._entity, project=self._project) + except Exception as e: + raise SchedulerError(f"{LOG_PREFIX}Exception when finding sweep: {e}") + self._sweep_id: str = sweep_id or "empty-sweep-id" self._state: SchedulerState = SchedulerState.PENDING - self._threading_lock: threading.Lock = threading.Lock() + # Dictionary of the runs being managed by the scheduler self._runs: Dict[str, SweepRun] = {} + # Threading lock to ensure thread-safe access to the runs dictionary + self._threading_lock: threading.Lock = threading.Lock() + # Scheduler may receive additional kwargs which will be piped into the launch command + self._kwargs: Dict[str, Any] = kwargs @abstractmethod def _start(self) -> None: @@ -103,9 +104,7 @@ def state(self) -> SchedulerState: @state.setter def state(self, value: SchedulerState) -> None: - logger.debug( - f"{LOG_PREFIX}Changing Scheduler state from {self.state.name} to {value.name}" - ) + logger.debug(f"{LOG_PREFIX}Scheduler was {self.state.name} is {value.name}") self._state = value def is_alive(self) -> bool: @@ -118,47 +117,32 @@ def is_alive(self) -> bool: return True def start(self) -> None: - _msg = f"{LOG_PREFIX}Scheduler starting." - logger.debug(_msg) - wandb.termlog(_msg) + wandb.termlog(f"{LOG_PREFIX}Scheduler starting.") self._state = SchedulerState.STARTING self._start() self.run() def run(self) -> None: - _msg = f"{LOG_PREFIX}Scheduler Running." - logger.debug(_msg) - wandb.termlog(_msg) + wandb.termlog(f"{LOG_PREFIX}Scheduler Running.") self.state = SchedulerState.RUNNING try: while True: if not self.is_alive(): break - try: - self._update_run_states() - self._run() - except RuntimeError as e: - _msg = f"{LOG_PREFIX}Scheduler encountered Runtime Error. {e} Trying again." - logger.debug(_msg) - wandb.termlog(_msg) + self._update_run_states() + self._run() except KeyboardInterrupt: - _msg = f"{LOG_PREFIX}Scheduler received KeyboardInterrupt. Exiting." - logger.debug(_msg) - wandb.termlog(_msg) + wandb.termlog(f"{LOG_PREFIX}Scheduler received KeyboardInterrupt. Exiting.") self.state = SchedulerState.STOPPED self.exit() return except Exception as e: - _msg = f"{LOG_PREFIX}Scheduler failed with exception {e}" - logger.debug(_msg) - wandb.termlog(_msg) + wandb.termlog(f"{LOG_PREFIX}Scheduler failed with exception {e}") self.state = SchedulerState.FAILED self.exit() raise e else: - _msg = "Scheduler completed." - logger.debug(_msg) - wandb.termlog(_msg) + wandb.termlog(f"{LOG_PREFIX}Scheduler completed.") self.exit() def exit(self) -> None: @@ -168,15 +152,28 @@ def exit(self) -> None: SchedulerState.STOPPED, ]: self.state = SchedulerState.FAILED - for run_id, _ in self._yield_runs(): - self._stop_run(run_id) + self._stop_runs() def _yield_runs(self) -> Iterator[Tuple[str, SweepRun]]: """Thread-safe way to iterate over the runs.""" with self._threading_lock: yield from self._runs.items() + def _stop_runs(self) -> None: + for run_id, _ in self._yield_runs(): + wandb.termlog(f"{LOG_PREFIX}Stopping run {run_id}.") + self._stop_run(run_id) + + def _stop_run(self, run_id: str) -> None: + """Stops a run and removes it from the scheduler""" + if run_id in self._runs: + run: SweepRun = self._runs[run_id] + run.state = SimpleRunState.DEAD + # TODO(hupo): Send command to backend to stop run + wandb.termlog(f"{LOG_PREFIX} Stopped run {run_id}.") + def _update_run_states(self) -> None: + _runs_to_remove: List[str] = [] for run_id, run in self._yield_runs(): try: _state = self._api.get_run_state(self._entity, self._project, run_id) @@ -187,6 +184,7 @@ def _update_run_states(self) -> None: "finished", ]: run.state = SimpleRunState.DEAD + _runs_to_remove.append(run_id) elif _state in [ "running", "pending", @@ -194,12 +192,17 @@ def _update_run_states(self) -> None: "preempting", ]: run.state = SimpleRunState.ALIVE - except Exception as e: - _msg = f"{LOG_PREFIX}Issue when getting RunState for Run {run_id}: {e}" - logger.debug(_msg) - wandb.termlog(_msg) + except CommError as e: + wandb.termlog( + f"{LOG_PREFIX}Issue when getting RunState for Run {run_id}: {e}" + ) run.state = SimpleRunState.UNKNOWN continue + # Remove any runs that are dead + with self._threading_lock: + for run_id in _runs_to_remove: + wandb.termlog(f"{LOG_PREFIX}Cleaning up dead run {run_id}.") + del self._runs[run_id] def _add_to_launch_queue( self, @@ -208,28 +211,27 @@ def _add_to_launch_queue( ) -> "public.QueuedRun": """Add a launch job to the Launch RunQueue.""" run_id = run_id or generate_id() + # One of Job and URI is required + _job = self._kwargs.get("job", None) + _uri = self._kwargs.get("uri", None) + if _job is None and _uri is None: + # If no Job is specified, use a placeholder URI to prevent Launch failure + _uri = "placeholder-uri-queuedrun-from-scheduler" + # Queue is required + _queue = self._kwargs.get("queue", "default") queued_run = launch_add( - # TODO(hupo): If no Job is specified, use a placeholder URI to prevent Launch failure - uri=None if self._job is not None else "placeholder-uri-queuedrun", - job=self._job, + run_id=run_id, + entry_point=entry_point, + uri=_uri, + job=_job, project=self._project, entity=self._entity, - queue=self._launch_queue, - entry_point=entry_point, - resource=self._resource, - resource_args=self._resource_args, - run_id=run_id, + queue=_queue, + resource=self._kwargs.get("resource", None), + resource_args=self._kwargs.get("resource_args", None), ) self._runs[run_id].queued_run = queued_run - _msg = f"{LOG_PREFIX}Added run to Launch RunQueue: {self._launch_queue} RunID:{run_id}." - logger.debug(_msg) - wandb.termlog(_msg) + wandb.termlog( + f"{LOG_PREFIX}Added run to Launch RunQueue: {_queue} RunID:{run_id}." + ) return queued_run - - def _stop_run(self, run_id: str) -> None: - _msg = f"{LOG_PREFIX}Stopping run {run_id}." - logger.debug(_msg) - wandb.termlog(_msg) - run = self._runs.get(run_id, None) - if run is not None: - run.state = SimpleRunState.DEAD diff --git a/wandb/sdk/launch/sweeps/scheduler_sweep.py b/wandb/sdk/launch/sweeps/scheduler_sweep.py index c2663ab9a8c..c310383c666 100644 --- a/wandb/sdk/launch/sweeps/scheduler_sweep.py +++ b/wandb/sdk/launch/sweeps/scheduler_sweep.py @@ -1,12 +1,12 @@ +"""Scheduler for classic wandb Sweeps.""" from dataclasses import dataclass import logging import os import pprint import queue import socket -import threading import time -from typing import Any, List, Optional +from typing import Any, Dict, List import wandb from wandb import wandb_lib # type: ignore @@ -24,10 +24,9 @@ @dataclass -class HeartbeatAgent: - agent: dict - id: str - thread: threading.Thread +class _Worker: + agent_config: Dict[str, Any] + agent_id: str class SweepScheduler(Scheduler): @@ -38,110 +37,103 @@ class SweepScheduler(Scheduler): def __init__( self, *args: Any, - sweep_id: Optional[str] = None, num_workers: int = 4, - heartbeat_thread_sleep: float = 0.5, - heartbeat_queue_timeout: float = 1, - main_thread_sleep: float = 1, + heartbeat_queue_timeout: float = 1.0, + heartbeat_queue_sleep: float = 1.0, **kwargs: Any, ): super().__init__(*args, **kwargs) - # Make sure the provided sweep_id corresponds to a valid sweep - found = self._api.sweep( - sweep_id, "{}", entity=self._entity, project=self._project - ) - if not found: - raise SchedulerError( - f"{LOG_PREFIX}Could not find sweep {self._entity}/{self._project}/{sweep_id}" - ) - self._sweep_id = sweep_id - # Threading is used to run multiple workers in parallel + # Optionally run multiple workers in (pseudo-)parallel. Workers do not + # actually run training workloads, they simply send heartbeat messages + # (emulating a real agent) and add new runs to the launch queue. The + # launch agent is the one that actually runs the training workloads. + self._workers: Dict[int, _Worker] = {} self._num_workers: int = num_workers - self._heartbeat_thread_sleep: float = heartbeat_thread_sleep - self._heartbeat_queue_timeout: float = heartbeat_queue_timeout - self._main_thread_sleep: float = main_thread_sleep # Thread will pop items off the Sweeps RunQueue using AgentHeartbeat # and put them in this internal queue, which will be used to populate # the Launch RunQueue self._heartbeat_queue: "queue.Queue[SweepRun]" = queue.Queue() - # Emulation of N agents in a classic sweeps setup - self._heartbeat_agents: List[HeartbeatAgent] = [] + self._heartbeat_queue_timeout: float = heartbeat_queue_timeout + self._heartbeat_queue_sleep: float = heartbeat_queue_sleep def _start(self) -> None: - for worker_idx in range(self._num_workers): - logger.debug(f"{LOG_PREFIX}Starting AgentHeartbeat worker {worker_idx}\n") - _agent = self._api.register_agent( - f"{socket.gethostname()}-{worker_idx}", # host + for worker_id in range(self._num_workers): + logger.debug(f"{LOG_PREFIX}Starting AgentHeartbeat worker {worker_id}\n") + agent_config = self._api.register_agent( + f"{socket.gethostname()}-{worker_id}", # host sweep_id=self._sweep_id, project_name=self._project, entity=self._entity, ) - _thread = threading.Thread(target=self._heartbeat, args=[worker_idx]) - _thread.daemon = True - self._heartbeat_agents.append( - HeartbeatAgent( - agent=_agent, - id=_agent["id"], - thread=_thread, - ) + self._workers[worker_id] = _Worker( + agent_config=agent_config, + agent_id=agent_config["id"], ) - _thread.start() - def _heartbeat(self, worker_idx: int) -> None: - while True: - if not self.is_alive(): - return - # AgentHeartbeat wants dict of runs which are running or queued - _run_states = {} - for run_id, run in self._yield_runs(): - if run.state == SimpleRunState.ALIVE: - _run_states[run_id] = True - _msg = ( - f"{LOG_PREFIX}AgentHeartbeat sending: \n{pprint.pformat(_run_states)}\n" - ) - logger.debug(_msg) - # TODO(hupo): Should be sub-set of _run_states specific to worker thread - commands = self._api.agent_heartbeat( - self._heartbeat_agents[worker_idx].id, {}, _run_states - ) - if commands: - _msg = f"{LOG_PREFIX}AgentHeartbeat received {len(commands)} commands: \n{pprint.pformat(commands)}\n" - logger.debug(_msg) - for command in commands: - _type = command.get("type") - # type can be one of "run", "resume", "stop", "exit" - if _type == "exit": - self.state = SchedulerState.COMPLETED - self.exit() - return - if _type == "stop": - # TODO(hupo): Debug edge cases while stopping with active runs - self.state = SchedulerState.COMPLETED - self.exit() - return - run = SweepRun( - id=command.get("run_id"), - args=command.get("args"), - logs=command.get("logs"), - program=command.get("program"), - ) - with self._threading_lock: + def _heartbeat(self, worker_id: int) -> None: + # Make sure Scheduler is alive + if not self.is_alive(): + return + # AgentHeartbeat wants a Dict of runs which are running or queued + _run_states: Dict[str, bool] = {} + for run_id, run in self._yield_runs(): + # Filter out runs that are from a different worker thread + if run.worker_id == worker_id and run.state == SimpleRunState.ALIVE: + _run_states[run_id] = True + logger.debug( + f"{LOG_PREFIX}AgentHeartbeat sending: \n{pprint.pformat(_run_states)}\n" + ) + commands: List[Dict[str, Any]] = self._api.agent_heartbeat( + self._workers[worker_id].agent_id, # agent_id: str + {}, # metrics: dict + _run_states, # run_states: dict + ) + logger.debug( + f"{LOG_PREFIX}AgentHeartbeat received {len(commands)} commands: \n{pprint.pformat(commands)}\n" + ) + if commands: + for command in commands: + # The command "type" can be one of "run", "resume", "stop", "exit" + _type = command.get("type", None) + if _type in ["exit", "stop"]: + # Tell (virtual) agent to stop running + self.state = SchedulerState.STOPPED + self.exit() + return + elif _type in ["run", "resume"]: + _run_id = command.get("run_id", None) + if _run_id is None: + self.state = SchedulerState.FAILED + raise SchedulerError( + f"AgentHeartbeat command {command} missing run_id" + ) + if _run_id in self._runs: + wandb.termlog(f"{LOG_PREFIX} Skipping duplicate run {run_id}") + else: + run = SweepRun( + id=_run_id, + args=command.get("args", {}), + logs=command.get("logs", []), + program=command.get("program", None), + worker_id=worker_id, + ) self._runs[run.id] = run - if _type in ["run", "resume"]: self._heartbeat_queue.put(run) - continue - time.sleep(self._heartbeat_thread_sleep) + else: + self.state = SchedulerState.FAILED + raise SchedulerError(f"AgentHeartbeat unknown command type {_type}") def _run(self) -> None: + # Go through all workers and heartbeat + for worker_id in self._workers.keys(): + self._heartbeat(worker_id) try: run: SweepRun = self._heartbeat_queue.get( timeout=self._heartbeat_queue_timeout ) except queue.Empty: - _msg = f"{LOG_PREFIX}No jobs in Sweeps RunQueue, waiting..." - logger.debug(_msg) - wandb.termlog(_msg) - time.sleep(self._main_thread_sleep) + wandb.termlog(f"{LOG_PREFIX}No jobs in Sweeps RunQueue, waiting...") + time.sleep(self._heartbeat_queue_sleep) return # If run is already stopped just ignore the request if run.state in [ @@ -149,9 +141,9 @@ def _run(self) -> None: SimpleRunState.UNKNOWN, ]: return - _msg = f"{LOG_PREFIX}Converting Sweep Run (RunID:{run.id}) to Launch Job" - logger.debug(_msg) - wandb.termlog(_msg) + wandb.termlog( + f"{LOG_PREFIX}Converting Sweep Run (RunID:{run.id}) to Launch Job" + ) # This is actually what populates the wandb config # since it is used in wandb.init() sweep_param_path = os.path.join( @@ -160,9 +152,7 @@ def _run(self) -> None: f"sweep-{self._sweep_id}", f"config-{run.id}.yaml", ) - _msg = f"{LOG_PREFIX}Saving params to {sweep_param_path}" - logger.debug(_msg) - wandb.termlog(_msg) + wandb.termlog(f"{LOG_PREFIX}Saving params to {sweep_param_path}") wandb_lib.config_util.save_config_file_from_dict(sweep_param_path, run.args) # Construct entry point using legacy sweeps utilities command_args = LegacySweepAgent._create_command_args({"args": run.args})["args"] @@ -173,4 +163,4 @@ def _run(self) -> None: ) def _exit(self) -> None: - self.state = SchedulerState.COMPLETED + pass