Skip to content

Commit

Permalink
feat(sweeps): Adds Sweep.expected_run_count to public Api (#4434)
Browse files Browse the repository at this point in the history
  • Loading branch information
gtarpenning committed Nov 10, 2022
1 parent 61bdf1a commit 27a13cc
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 5 deletions.
48 changes: 48 additions & 0 deletions tests/unit_tests/test_public_api.py
Expand Up @@ -9,6 +9,14 @@
import wandb.util
from wandb import Api

from .test_wandb_sweep import (
SWEEP_CONFIG_BAYES,
SWEEP_CONFIG_GRID,
SWEEP_CONFIG_GRID_NESTED,
SWEEP_CONFIG_RANDOM,
VALID_SWEEP_CONFIGS_MINIMAL,
)


def test_api_auto_login_no_tty():
with pytest.raises(wandb.UsageError):
Expand Down Expand Up @@ -204,6 +212,46 @@ def test_artifact_download_logger():
termlog.assert_not_called()


@pytest.mark.parametrize("sweep_config", VALID_SWEEP_CONFIGS_MINIMAL)
def test_sweep_api(user, relay_server, sweep_config):
_project = "test"
with relay_server():
sweep_id = wandb.sweep(sweep_config, entity=user, project=_project)
print(f"sweep_id{sweep_id}")
sweep = Api().sweep(f"{user}/{_project}/sweeps/{sweep_id}")
assert sweep.entity == user
assert f"{user}/{_project}/sweeps/{sweep_id}" in sweep.url
assert sweep.state == "PENDING"
assert str(sweep) == f"<Sweep {user}/test/{sweep_id} (PENDING)>"


@pytest.mark.parametrize(
"sweep_config,expected_run_count",
[
(SWEEP_CONFIG_GRID, 3),
(SWEEP_CONFIG_GRID_NESTED, 9),
(SWEEP_CONFIG_BAYES, None),
(SWEEP_CONFIG_RANDOM, None),
],
ids=["test grid", "test grid nested", "test bayes", "test random"],
)
def test_sweep_api_expected_run_count(
user, relay_server, sweep_config, expected_run_count
):
_project = "test"
with relay_server() as relay:
sweep_id = wandb.sweep(sweep_config, entity=user, project=_project)

for comm in relay.context.raw_data:
q = comm["request"].get("query")
print(q)

print(f"sweep_id{sweep_id}")
sweep = Api().sweep(f"{user}/{_project}/sweeps/{sweep_id}")

assert sweep.expected_run_count == expected_run_count


def test_update_aliases_on_artifact(user, relay_server, wandb_init):
project = "test"
run = wandb_init(entity=user, project=project)
Expand Down
6 changes: 6 additions & 0 deletions tests/unit_tests/test_wandb_artifacts_full.py
Expand Up @@ -222,6 +222,12 @@ def test_artifact_wait_failure(wandb_init, timeout):
run.finish()


@pytest.mark.skip(
reason="TODO(spencerpearson): this test passes locally, but flakes in CI. After much investigation, I still have no clue.",
# examples of flakes:
# https://app.circleci.com/pipelines/github/wandb/wandb/16334/workflows/319d3e58-853e-46ec-8a3f-088cac41351c/jobs/325741/tests#failed-test-0
# https://app.circleci.com/pipelines/github/wandb/wandb/16392/workflows/b26b3e63-c8d8-45f4-b7db-00f84b11f8b8/jobs/327312
)
def test_artifact_metadata_save(wandb_init, relay_server):
# Test artifact metadata sucessfully saved for len(numpy) > 32
dummy_metadata = np.array([0] * 33)
Expand Down
10 changes: 9 additions & 1 deletion tests/unit_tests/test_wandb_sweep.py
Expand Up @@ -25,7 +25,15 @@
SWEEP_CONFIG_GRID_NESTED: Dict[str, Any] = {
"name": "mock-sweep-grid",
"method": "grid",
"parameters": {"param1": {"parameters": {"param2": {"values": [1, 2, 3]}}}},
"parameters": {
"param1": {"values": [1, 2, 3]},
"param2": {
"parameters": {
"param3": {"values": [1, 2, 3]},
"param4": {"value": 1},
}
},
},
}
SWEEP_CONFIG_BAYES: Dict[str, Any] = {
"name": "mock-sweep-bayes",
Expand Down
37 changes: 33 additions & 4 deletions wandb/apis/public.py
Expand Up @@ -2479,10 +2479,28 @@ class Sweep(Attrs):
project: (str) name of project
config: (str) dictionary of sweep configuration
state: (str) the state of the sweep
expected_run_count: (int) number of expected runs for the sweep
"""

QUERY = gql(
"""
query Sweep($project: String, $entity: String, $name: String!) {
project(name: $project, entityName: $entity) {
sweep(sweepName: $name) {
id
name
state
runCountExpected
bestLoss
config
}
}
}
"""
)

LEGACY_QUERY = gql(
"""
query Sweep($project: String, $entity: String, $name: String!) {
project(name: $project, entityName: $entity) {
sweep(sweepName: $name) {
Expand Down Expand Up @@ -2565,6 +2583,11 @@ def best_run(self, order=None):
except IndexError:
return None

@property
def expected_run_count(self) -> Optional[int]:
"Returns the number of expected runs in the sweep or None for infinite runs."
return self._attrs.get("runCountExpected")

@property
def path(self):
return [
Expand Down Expand Up @@ -2605,10 +2628,16 @@ def get(
}
variables.update(kwargs)

response = client.execute(query, variable_values=variables)
if response.get("project") is None:
return None
elif response["project"].get("sweep") is None:
response = None
try:
response = client.execute(query, variable_values=variables)
except Exception:
# Don't handle exception, rely on legacy query
# TODO(gst): Implement updated introspection workaround
query = cls.LEGACY_QUERY
response = client.execute(query, variable_values=variables)

if not response or not response.get("project", {}).get("sweep"):
return None

sweep_response = response["project"]["sweep"]
Expand Down

0 comments on commit 27a13cc

Please sign in to comment.