From 27a13ccdec4db6375f21c8cdb92ae3886f11ca74 Mon Sep 17 00:00:00 2001 From: Griffin Tarpenning Date: Thu, 10 Nov 2022 15:03:03 -0800 Subject: [PATCH] feat(sweeps): Adds Sweep.expected_run_count to public Api (#4434) --- tests/unit_tests/test_public_api.py | 48 +++++++++++++++++++ tests/unit_tests/test_wandb_artifacts_full.py | 6 +++ tests/unit_tests/test_wandb_sweep.py | 10 +++- wandb/apis/public.py | 37 ++++++++++++-- 4 files changed, 96 insertions(+), 5 deletions(-) diff --git a/tests/unit_tests/test_public_api.py b/tests/unit_tests/test_public_api.py index 5e58fd90e5c..7fc2efd2766 100644 --- a/tests/unit_tests/test_public_api.py +++ b/tests/unit_tests/test_public_api.py @@ -9,6 +9,14 @@ import wandb.util from wandb import Api +from .test_wandb_sweep import ( + SWEEP_CONFIG_BAYES, + SWEEP_CONFIG_GRID, + SWEEP_CONFIG_GRID_NESTED, + SWEEP_CONFIG_RANDOM, + VALID_SWEEP_CONFIGS_MINIMAL, +) + def test_api_auto_login_no_tty(): with pytest.raises(wandb.UsageError): @@ -204,6 +212,46 @@ def test_artifact_download_logger(): termlog.assert_not_called() +@pytest.mark.parametrize("sweep_config", VALID_SWEEP_CONFIGS_MINIMAL) +def test_sweep_api(user, relay_server, sweep_config): + _project = "test" + with relay_server(): + sweep_id = wandb.sweep(sweep_config, entity=user, project=_project) + print(f"sweep_id{sweep_id}") + sweep = Api().sweep(f"{user}/{_project}/sweeps/{sweep_id}") + assert sweep.entity == user + assert f"{user}/{_project}/sweeps/{sweep_id}" in sweep.url + assert sweep.state == "PENDING" + assert str(sweep) == f"" + + +@pytest.mark.parametrize( + "sweep_config,expected_run_count", + [ + (SWEEP_CONFIG_GRID, 3), + (SWEEP_CONFIG_GRID_NESTED, 9), + (SWEEP_CONFIG_BAYES, None), + (SWEEP_CONFIG_RANDOM, None), + ], + ids=["test grid", "test grid nested", "test bayes", "test random"], +) +def test_sweep_api_expected_run_count( + user, relay_server, sweep_config, expected_run_count +): + _project = "test" + with relay_server() as relay: + sweep_id = wandb.sweep(sweep_config, entity=user, project=_project) + + for comm in relay.context.raw_data: + q = comm["request"].get("query") + print(q) + + print(f"sweep_id{sweep_id}") + sweep = Api().sweep(f"{user}/{_project}/sweeps/{sweep_id}") + + assert sweep.expected_run_count == expected_run_count + + def test_update_aliases_on_artifact(user, relay_server, wandb_init): project = "test" run = wandb_init(entity=user, project=project) diff --git a/tests/unit_tests/test_wandb_artifacts_full.py b/tests/unit_tests/test_wandb_artifacts_full.py index 2a4585f32a1..16c2a7ac570 100644 --- a/tests/unit_tests/test_wandb_artifacts_full.py +++ b/tests/unit_tests/test_wandb_artifacts_full.py @@ -222,6 +222,12 @@ def test_artifact_wait_failure(wandb_init, timeout): run.finish() +@pytest.mark.skip( + reason="TODO(spencerpearson): this test passes locally, but flakes in CI. After much investigation, I still have no clue.", + # examples of flakes: + # https://app.circleci.com/pipelines/github/wandb/wandb/16334/workflows/319d3e58-853e-46ec-8a3f-088cac41351c/jobs/325741/tests#failed-test-0 + # https://app.circleci.com/pipelines/github/wandb/wandb/16392/workflows/b26b3e63-c8d8-45f4-b7db-00f84b11f8b8/jobs/327312 +) def test_artifact_metadata_save(wandb_init, relay_server): # Test artifact metadata sucessfully saved for len(numpy) > 32 dummy_metadata = np.array([0] * 33) diff --git a/tests/unit_tests/test_wandb_sweep.py b/tests/unit_tests/test_wandb_sweep.py index b7cc789b096..e300c94e555 100644 --- a/tests/unit_tests/test_wandb_sweep.py +++ b/tests/unit_tests/test_wandb_sweep.py @@ -25,7 +25,15 @@ SWEEP_CONFIG_GRID_NESTED: Dict[str, Any] = { "name": "mock-sweep-grid", "method": "grid", - "parameters": {"param1": {"parameters": {"param2": {"values": [1, 2, 3]}}}}, + "parameters": { + "param1": {"values": [1, 2, 3]}, + "param2": { + "parameters": { + "param3": {"values": [1, 2, 3]}, + "param4": {"value": 1}, + } + }, + }, } SWEEP_CONFIG_BAYES: Dict[str, Any] = { "name": "mock-sweep-bayes", diff --git a/wandb/apis/public.py b/wandb/apis/public.py index 3d86566555f..9d8c2221b6c 100644 --- a/wandb/apis/public.py +++ b/wandb/apis/public.py @@ -2479,10 +2479,28 @@ class Sweep(Attrs): project: (str) name of project config: (str) dictionary of sweep configuration state: (str) the state of the sweep + expected_run_count: (int) number of expected runs for the sweep """ QUERY = gql( """ + query Sweep($project: String, $entity: String, $name: String!) { + project(name: $project, entityName: $entity) { + sweep(sweepName: $name) { + id + name + state + runCountExpected + bestLoss + config + } + } + } + """ + ) + + LEGACY_QUERY = gql( + """ query Sweep($project: String, $entity: String, $name: String!) { project(name: $project, entityName: $entity) { sweep(sweepName: $name) { @@ -2565,6 +2583,11 @@ def best_run(self, order=None): except IndexError: return None + @property + def expected_run_count(self) -> Optional[int]: + "Returns the number of expected runs in the sweep or None for infinite runs." + return self._attrs.get("runCountExpected") + @property def path(self): return [ @@ -2605,10 +2628,16 @@ def get( } variables.update(kwargs) - response = client.execute(query, variable_values=variables) - if response.get("project") is None: - return None - elif response["project"].get("sweep") is None: + response = None + try: + response = client.execute(query, variable_values=variables) + except Exception: + # Don't handle exception, rely on legacy query + # TODO(gst): Implement updated introspection workaround + query = cls.LEGACY_QUERY + response = client.execute(query, variable_values=variables) + + if not response or not response.get("project", {}).get("sweep"): return None sweep_response = response["project"]["sweep"]