diff --git a/tests/unit_tests/test_wandb_artifacts_full.py b/tests/unit_tests/test_wandb_artifacts_full.py index d160dcecce7..75ea605c679 100644 --- a/tests/unit_tests/test_wandb_artifacts_full.py +++ b/tests/unit_tests/test_wandb_artifacts_full.py @@ -4,6 +4,7 @@ import numpy as np import pytest import wandb +from wandb.errors import WaitTimeoutError sm = wandb.wandb_sdk.internal.sender.SendManager @@ -193,3 +194,29 @@ def make_table(): artifact2.add(t1, "t2") assert artifact2.manifest.entries["t2.table.json"].ref is not None run.finish() + + +def test_artifact_wait_success(wandb_init): + # Test artifact wait() timeout parameter + timeout = 60 + leeway = 0.50 + run = wandb_init() + artifact = wandb.Artifact("art", type="dataset") + start_timestamp = time.time() + run.log_artifact(artifact).wait(timeout=timeout) + elapsed_time = time.time() - start_timestamp + assert elapsed_time < timeout * (1 + leeway) + run.finish() + + +@pytest.mark.parametrize("timeout", [0, 1e-6]) +def test_artifact_wait_failure(wandb_init, timeout): + # Test to expect WaitTimeoutError when wait timeout is reached and large image + # wasn't uploaded yet + image = wandb.Image(np.random.randint(0, 255, (10, 10))) + run = wandb_init() + with pytest.raises(WaitTimeoutError): + artifact = wandb.Artifact("art", type="image") + artifact.add(image, "image") + run.log_artifact(artifact).wait(timeout=timeout) + run.finish() diff --git a/wandb/errors/__init__.py b/wandb/errors/__init__.py index 6659f6a2af3..b50dba7cbe9 100644 --- a/wandb/errors/__init__.py +++ b/wandb/errors/__init__.py @@ -102,6 +102,12 @@ class SweepError(Error): pass +class WaitTimeoutError(Error): + """Raised when wait() timeout occurs before process is finished""" + + pass + + __all__ = [ "Error", "UsageError", @@ -114,4 +120,5 @@ class SweepError(Error): "ExecutionError", "LaunchError", "SweepError", + "WaitTimeoutError", ] diff --git a/wandb/sdk/wandb_artifacts.py b/wandb/sdk/wandb_artifacts.py index 6d8c3f80ee3..096ad17c150 100644 --- a/wandb/sdk/wandb_artifacts.py +++ b/wandb/sdk/wandb_artifacts.py @@ -647,9 +647,13 @@ def delete(self) -> None: "Cannot call delete on an artifact before it has been logged or in offline mode" ) - def wait(self) -> ArtifactInterface: + def wait(self, timeout: Optional[int] = None) -> ArtifactInterface: + """ + Arguments: + timeout: (int, optional) Waits in seconds for artifact to finish logging if needed. + """ if self._logged_artifact: - return self._logged_artifact.wait() + return self._logged_artifact.wait(timeout) # type: ignore [call-arg] raise ValueError( "Cannot call wait on an artifact before it has been logged or in offline mode" diff --git a/wandb/sdk/wandb_run.py b/wandb/sdk/wandb_run.py index e06d85f07ec..a63e176791e 100644 --- a/wandb/sdk/wandb_run.py +++ b/wandb/sdk/wandb_run.py @@ -3584,9 +3584,14 @@ def __getattr__(self, item: str) -> Any: self._assert_instance() return getattr(self._instance, item) - def wait(self) -> ArtifactInterface: + def wait(self, timeout: Optional[int] = None) -> ArtifactInterface: if not self._instance: - resp = self._future.get().response.log_artifact_response + future_get = self._future.get(timeout) + if not future_get: + raise errors.WaitTimeoutError( + "Artifact upload wait timed out, failed to fetch Artifact response" + ) + resp = future_get.response.log_artifact_response if resp.error_message: raise ValueError(resp.error_message) self._instance = public.Artifact.from_id(resp.artifact_id, self._api.client)