Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(sweeps): add Sweep.expected_run_count to public Api #4434

Merged
merged 14 commits into from
Nov 10, 2022
48 changes: 48 additions & 0 deletions tests/unit_tests/test_public_api.py
Expand Up @@ -9,6 +9,14 @@
import wandb.util
from wandb import Api

from .test_wandb_sweep import (
SWEEP_CONFIG_BAYES,
SWEEP_CONFIG_GRID,
SWEEP_CONFIG_GRID_NESTED,
SWEEP_CONFIG_RANDOM,
VALID_SWEEP_CONFIGS_MINIMAL,
)


def test_api_auto_login_no_tty():
with pytest.raises(wandb.UsageError):
Expand Down Expand Up @@ -204,6 +212,46 @@ def test_artifact_download_logger():
termlog.assert_not_called()


@pytest.mark.parametrize("sweep_config", VALID_SWEEP_CONFIGS_MINIMAL)
def test_sweep_api(user, relay_server, sweep_config):
    """Create a sweep and verify the public Api surfaces its basic attributes.

    Runs once per minimal sweep config (grid, bayes, random, ...).
    """
    _project = "test"
    with relay_server():
        sweep_id = wandb.sweep(sweep_config, entity=user, project=_project)
    sweep = Api().sweep(f"{user}/{_project}/sweeps/{sweep_id}")
    assert sweep.entity == user
    assert f"{user}/{_project}/sweeps/{sweep_id}" in sweep.url
    # A freshly created sweep has no agents attached yet, so it is PENDING.
    assert sweep.state == "PENDING"
    # Use _project rather than a hardcoded "test" so the assertion stays
    # correct if the project name above ever changes.
    assert str(sweep) == f"<Sweep {user}/{_project}/{sweep_id} (PENDING)>"


@pytest.mark.parametrize(
    "sweep_config,expected_run_count",
    [
        (SWEEP_CONFIG_GRID, 3),
        (SWEEP_CONFIG_GRID_NESTED, 9),
        (SWEEP_CONFIG_BAYES, None),
        (SWEEP_CONFIG_RANDOM, None),
    ],
    ids=["test grid", "test grid nested", "test bayes", "test random"],
)
def test_sweep_api_expected_run_count(
    user, relay_server, sweep_config, expected_run_count
):
    """Verify Sweep.expected_run_count for finite and open-ended methods.

    Grid sweeps have a countable search space (the product of their value
    lists); bayes and random sweeps run indefinitely and report None.
    """
    _project = "test"
    with relay_server():
        sweep_id = wandb.sweep(sweep_config, entity=user, project=_project)

    sweep = Api().sweep(f"{user}/{_project}/sweeps/{sweep_id}")
    assert sweep.expected_run_count == expected_run_count


def test_update_aliases_on_artifact(user, relay_server, wandb_init):
project = "test"
run = wandb_init(entity=user, project=project)
Expand Down
6 changes: 6 additions & 0 deletions tests/unit_tests/test_wandb_artifacts_full.py
Expand Up @@ -222,6 +222,12 @@ def test_artifact_wait_failure(wandb_init, timeout):
run.finish()


@pytest.mark.skip(
reason="TODO(spencerpearson): this test passes locally, but flakes in CI. After much investigation, I still have no clue.",
# examples of flakes:
# https://app.circleci.com/pipelines/github/wandb/wandb/16334/workflows/319d3e58-853e-46ec-8a3f-088cac41351c/jobs/325741/tests#failed-test-0
# https://app.circleci.com/pipelines/github/wandb/wandb/16392/workflows/b26b3e63-c8d8-45f4-b7db-00f84b11f8b8/jobs/327312
)
def test_artifact_metadata_save(wandb_init, relay_server):
# Test artifact metadata sucessfully saved for len(numpy) > 32
dummy_metadata = np.array([0] * 33)
Expand Down
10 changes: 9 additions & 1 deletion tests/unit_tests/test_wandb_sweep.py
Expand Up @@ -25,7 +25,15 @@
# Grid sweep with one flat parameter and one nested parameter group.
# Grid cardinality: 3 (param1) * 3 (param3) * 1 (param4 is fixed) = 9 runs.
SWEEP_CONFIG_GRID_NESTED: Dict[str, Any] = dict(
    name="mock-sweep-grid",
    method="grid",
    parameters=dict(
        param1=dict(values=[1, 2, 3]),
        param2=dict(
            parameters=dict(
                param3=dict(values=[1, 2, 3]),
                param4=dict(value=1),
            )
        ),
    ),
)
SWEEP_CONFIG_BAYES: Dict[str, Any] = {
"name": "mock-sweep-bayes",
Expand Down
37 changes: 33 additions & 4 deletions wandb/apis/public.py
Expand Up @@ -2479,10 +2479,28 @@ class Sweep(Attrs):
project: (str) name of project
config: (str) dictionary of sweep configuration
state: (str) the state of the sweep
expected_run_count: (int) number of expected runs for the sweep
"""

QUERY = gql(
"""
query Sweep($project: String, $entity: String, $name: String!) {
project(name: $project, entityName: $entity) {
sweep(sweepName: $name) {
id
name
state
runCountExpected
bestLoss
config
}
}
}
"""
)

LEGACY_QUERY = gql(
"""
query Sweep($project: String, $entity: String, $name: String!) {
project(name: $project, entityName: $entity) {
sweep(sweepName: $name) {
Expand Down Expand Up @@ -2565,6 +2583,11 @@ def best_run(self, order=None):
except IndexError:
return None

@property
def expected_run_count(self) -> Optional[int]:
    """Return the number of expected runs in the sweep, or None for infinite runs.

    The value comes from the `runCountExpected` field of the sweep query;
    when the sweep was fetched via the legacy query (which lacks that
    field), or the search method is open-ended (e.g. bayes/random),
    this resolves to None.
    """
    return self._attrs.get("runCountExpected")

@property
def path(self):
return [
Expand Down Expand Up @@ -2605,10 +2628,16 @@ def get(
}
variables.update(kwargs)

response = client.execute(query, variable_values=variables)
if response.get("project") is None:
return None
elif response["project"].get("sweep") is None:
response = None
try:
response = client.execute(query, variable_values=variables)
except Exception:
gtarpenning marked this conversation as resolved.
Show resolved Hide resolved
# Don't handle exception, rely on legacy query
# TODO(gst): Implement updated introspection workaround
query = cls.LEGACY_QUERY
response = client.execute(query, variable_values=variables)

if not response or not response.get("project", {}).get("sweep"):
return None

sweep_response = response["project"]["sweep"]
Expand Down