Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(artifacts): add optional timeout parameter to artifacts wait() #4181

Merged
merged 14 commits into from Aug 31, 2022
Merged
27 changes: 27 additions & 0 deletions tests/unit_tests/test_wandb_artifacts_full.py
Expand Up @@ -4,6 +4,7 @@
import numpy as np
import pytest
import wandb
from wandb.sdk.wandb_run import WaitTimeoutError

sm = wandb.wandb_sdk.internal.sender.SendManager

Expand Down Expand Up @@ -193,3 +194,29 @@ def make_table():
artifact2.add(t1, "t2")
assert artifact2.manifest.entries["t2.table.json"].ref is not None
run.finish()


def test_artifact_wait_success(wandb_init):
    """Check that ``wait()`` with a timeout returns within that timeout.

    Logs an empty artifact, waits on it with ``wait_timeout_secs`` set, and
    asserts that the call came back within the timeout plus a fractional
    leeway.
    """
    timeout = 2
    # Fractional slack on top of ``timeout`` tolerated by the assertion below.
    leeway = 0.50
    run = wandb_init()
    artifact = wandb.Artifact("art", type="dataset")
    # Use a monotonic clock: wall-clock time can jump (NTP, DST) and would
    # make the elapsed-time measurement unreliable.
    start_timestamp = time.monotonic()
    run.log_artifact(artifact).wait(wait_timeout_secs=timeout)
    elapsed_time = time.monotonic() - start_timestamp
    assert elapsed_time < timeout * (1 + leeway)
    run.finish()


@pytest.mark.parametrize("timeout", [0, 1e-1])
def test_artifact_wait_failure(wandb_init, timeout):
    """Expect ``WaitTimeoutError`` when the wait timeout elapses first.

    An artifact holding a 1000x1000 image is logged and waited on with a
    very small ``wait_timeout_secs``; the wait is expected to time out
    before the artifact response is available.
    """
    large_image = wandb.Image(np.zeros((1000, 1000)))
    run = wandb_init()
    artifact = wandb.Artifact("art", type="image")
    artifact.add(large_image, "image")
    # Only the call expected to raise lives inside ``pytest.raises`` so a
    # failure during artifact construction is not mistaken for the timeout.
    with pytest.raises(WaitTimeoutError):
        run.log_artifact(artifact).wait(wait_timeout_secs=timeout)
    # Kept outside the context manager: statements after the raising call
    # inside the ``with`` body would never execute.
    run.finish()
1 change: 1 addition & 0 deletions tests/unit_tests/test_wandb_run.py
Expand Up @@ -2,6 +2,7 @@
import pickle
import platform
import sys
import time
estellazx marked this conversation as resolved.
Show resolved Hide resolved
from unittest import mock

import numpy as np
Expand Down
4 changes: 2 additions & 2 deletions wandb/sdk/wandb_artifacts.py
Expand Up @@ -647,9 +647,9 @@ def delete(self) -> None:
"Cannot call delete on an artifact before it has been logged or in offline mode"
)

def wait(self) -> ArtifactInterface:
def wait(self, wait_timeout_secs: int = None) -> ArtifactInterface:
estellazx marked this conversation as resolved.
Show resolved Hide resolved
if self._logged_artifact:
return self._logged_artifact.wait()
return self._logged_artifact.wait(wait_timeout_secs)

raise ValueError(
"Cannot call wait on an artifact before it has been logged or in offline mode"
Expand Down
15 changes: 13 additions & 2 deletions wandb/sdk/wandb_run.py
Expand Up @@ -135,6 +135,12 @@ class TeardownHook(NamedTuple):
stage: TeardownStage


class WaitTimeoutError(Exception):
    """Error raised when a wait operation exceeds its timeout.

    Args:
        message: Human-readable description of what timed out.
    """

    def __init__(self, message: str) -> None:
        super().__init__(message)
        # Also expose the text as a named attribute, in addition to ``args``.
        self.message = message


class RunStatusChecker:
"""Periodically polls the background process for relevant updates.

Expand Down Expand Up @@ -3575,9 +3581,14 @@ def __getattr__(self, item: str) -> Any:
self._assert_instance()
return getattr(self._instance, item)

def wait(self) -> ArtifactInterface:
def wait(self, wait_timeout_secs: int = None) -> ArtifactInterface:
estellazx marked this conversation as resolved.
Show resolved Hide resolved
if not self._instance:
resp = self._future.get().response.log_artifact_response
future_get = self._future.get(wait_timeout_secs)
if not future_get:
raise WaitTimeoutError(
"Artifact wait timed out, failed to fetch Artifact response"
)
resp = future_get.response.log_artifact_response
if resp.error_message:
raise ValueError(resp.error_message)
self._instance = public.Artifact.from_id(resp.artifact_id, self._api.client)
Expand Down