Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(artifacts): when artifact-commit 409s, retry entire artifact-creation, not just commit #4272

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
e7b566f
retry 409s at higher level
speezepearson Sep 13, 2022
dd7e37e
refactor out _retry_conflicts
speezepearson Sep 14, 2022
8cb6ab5
introduce ad-hoc Future class to avoid awkward `nonlocal` var
speezepearson Sep 14, 2022
436c178
add explanatory comment
speezepearson Sep 15, 2022
b5c74ec
make functionally-public function actually-public
speezepearson Sep 15, 2022
25b9693
fix retry logic by introducing class ArtifactCommitFailed
speezepearson Sep 15, 2022
8749102
lint
speezepearson Sep 15, 2022
fcc93ab
Merge branch 'master' into spencerpearson/retry-conflict-higher-level
speezepearson Sep 16, 2022
cdd2401
Merge remote-tracking branch 'origin/master' into spencerpearson/retr…
speezepearson Sep 19, 2022
153932c
Merge branch 'spencerpearson/no-retry-conflict' into spencerpearson/r…
kptkin Sep 19, 2022
2ad2ca3
wip
speezepearson Sep 21, 2022
72e8efa
test(artifacts): add test for artifact retries, and GraphQL-injection…
speezepearson Sep 22, 2022
86258be
remove dead code
speezepearson Sep 22, 2022
95f7cf1
undo failed attempt to make make RelayServer error reporting better
speezepearson Sep 22, 2022
a0d01ba
add test for retry on 409
speezepearson Sep 22, 2022
394c448
Merge branch 'spencerpearson/no-retry-conflict' into spencerpearson/r…
speezepearson Sep 22, 2022
a85d0a6
lint
speezepearson Sep 22, 2022
7021af8
Merge branch 'master' into spencerpearson/relay-graphql
speezepearson Sep 22, 2022
db47ded
Merge branch 'spencerpearson/relay-graphql' into spencerpearson/retry…
speezepearson Sep 22, 2022
4cb2d1b
lint
speezepearson Sep 22, 2022
d338460
undo unnecessary isort changes
speezepearson Sep 22, 2022
f34121d
lint
speezepearson Sep 22, 2022
b980fea
fix Resolver methods, and make them stylistically more-closely-resemb…
speezepearson Sep 22, 2022
0a6b526
fix Resolver methods, and make them stylistically more-closely-resemb…
speezepearson Sep 22, 2022
75f4a1e
improve typing, add assertion
speezepearson Sep 22, 2022
4882721
undo unnecessary changes
speezepearson Sep 22, 2022
8ef08a7
make test more precise
speezepearson Sep 22, 2022
26a41fb
fix test
speezepearson Sep 23, 2022
cc526df
fix more tests by not assuming every context-entry has a `config` key
speezepearson Sep 23, 2022
1c719e0
Merge branch 'spencerpearson/relay-graphql' into spencerpearson/retry…
speezepearson Sep 23, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
126 changes: 105 additions & 21 deletions tests/unit_tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
TYPE_CHECKING,
Any,
Callable,
ContextManager,
Dict,
Generator,
Iterable,
Expand Down Expand Up @@ -49,9 +50,9 @@
from wandb.sdk.lib.git import GitRepo

try:
from typing import Literal, TypedDict
from typing import Literal, Protocol, TypedDict
except ImportError:
from typing_extensions import Literal, TypedDict
from typing_extensions import Literal, Protocol, TypedDict

if TYPE_CHECKING:

Expand All @@ -74,6 +75,12 @@ class Resolver(TypedDict):
resolver: Callable[[Any], Optional[Dict[str, Any]]]


if TYPE_CHECKING:
import wandb_gql
else:
wandb_gql = wandb.util.vendor_import("wandb_gql")


class ConsoleFormatter:
BOLD = "\033[1m"
CODE = "\033[2m"
Expand Down Expand Up @@ -1054,7 +1061,7 @@ def config(self) -> Dict[str, Any]:
if self._config is not None:
return deepcopy(self._config)

self._config = {k: v["config"] for (k, v) in self._entries.items()}
self._config = {k: v.get("config") for (k, v) in self._entries.items()}
return deepcopy(self._config)

# @property
Expand Down Expand Up @@ -1142,9 +1149,14 @@ def __init__(self):
"name": "upsert_sweep",
"resolver": self.resolve_upsert_sweep,
},
# { "name": "create_artifact",
# "resolver": self.resolve_create_artifact,
# },
{
"name": "create_artifact",
"resolver": self.resolve_create_artifact,
},
{
"name": "commit_artifact",
"resolver": self.resolve_commit_artifact,
},
]

@staticmethod
Expand Down Expand Up @@ -1242,28 +1254,42 @@ def resolve_upsert_sweep(
return data
return None

@staticmethod
def resolve_create_artifact(
self, request_data: Dict[str, Any], response_data: Dict[str, Any], **kwargs: Any
request_data: Dict[str, Any], response_data: Dict[str, Any], **kwargs: Any
) -> Optional[Dict[str, Any]]:
if not isinstance(request_data, dict):
if not isinstance(response_data, dict):
return None
query = (
"createArtifact(" in request_data.get("query", "")
and request_data.get("variables") is not None
and response_data is not None
)
query = response_data.get("data", {}).get("createArtifact") is not None
if query:
name = request_data["variables"]["runName"]
post_processed_data = {
"name": name,
return {
"name": response_data["data"]["createArtifact"]["artifact"]["id"],
"create_artifact": [
{
"variables": request_data["variables"],
"response": response_data["data"]["createArtifact"]["artifact"],
}
],
}
return post_processed_data
return None

@staticmethod
def resolve_commit_artifact(
request_data: Dict[str, Any], response_data: Dict[str, Any], **kwargs: Any
) -> Optional[Dict[str, Any]]:
if not isinstance(response_data, dict):
return None
if "query" not in request_data:
return None
query = response_data.get("data", {}).get("commitArtifact") is not None
if query:
return {
"name": request_data["variables"]["artifactID"],
"commit_artifact": [
{
"variables": request_data["variables"],
}
],
}
return None

def resolve(
Expand All @@ -1289,6 +1315,7 @@ class InjectedResponse:
content_type: str = "text/plain"
# todo: add more fields for other types of responses?
counter: int = -1
predicate: Optional[Callable[[requests.PreparedRequest], bool]] = None

def __eq__(
self,
Expand All @@ -1307,14 +1334,21 @@ def __eq__(
return False
if self.counter == 0:
return False
# todo: add more fields for other types of responses?
return self.method == other.method and self.url == other.url
if self.method != other.method or self.url != other.url:
return False
if (
isinstance(other, requests.PreparedRequest)
and self.predicate is not None
and not self.predicate(other)
):
return False
return True

def to_dict(self):
return {
k: self.__getattribute__(k)
for k in self.__dict__
if (not k.startswith("_") and k != "counter")
if (not k.startswith("_") and k not in {"counter", "predicate"})
}


Expand Down Expand Up @@ -1534,6 +1568,13 @@ def storage_file(self, path) -> Mapping[str, str]:
return relayed_response.json()


class RelayServerFixture(Protocol):
def __call__(
self, inject: Optional[List[InjectedResponse]] = None
) -> ContextManager[RelayServer]:
...


@pytest.fixture(scope="function")
def relay_server(base_url):
"""
Expand Down Expand Up @@ -1677,3 +1718,46 @@ def helper(
)

yield helper


class InjectedGraphQLRequestCreator(Protocol):
def __call__(
self,
body: Union[str, Exception] = "{}",
status: int = 200,
counter: int = -1,
) -> InjectedResponse:
...


# Injected responses
@pytest.fixture(scope="function")
def inject_graphql_response(base_url: str) -> InjectedGraphQLRequestCreator:
def helper(
operation_name: str,
body: Union[str, Exception] = "{}",
status: int = 200,
counter: int = -1,
) -> InjectedResponse:

if status > 299:
message = body if isinstance(body, str) else "::".join(body.args)
body = DeliberateHTTPError(status_code=status, message=message)

def predicate(request: requests.PreparedRequest) -> bool:
query = wandb_gql.gql(json.loads(request.body)["query"])
return query.definitions[0].name.value == operation_name

return InjectedResponse(
method="POST",
url=urllib.parse.urljoin(
base_url,
"/graphql",
),
body=body,
status=status,
counter=counter,
predicate=predicate,
)

return helper
30 changes: 30 additions & 0 deletions tests/unit_tests/test_artifacts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from typing import TYPE_CHECKING, Callable

import pytest
import wandb

if TYPE_CHECKING:
from .conftest import InjectedGraphQLRequestCreator, RelayServerFixture


@pytest.mark.parametrize("status", [409, 500])
def test_commit_retries_on_right_statuses(
relay_server: "RelayServerFixture",
wandb_init: Callable[[], wandb.wandb_sdk.wandb_run.Run],
inject_graphql_response: "InjectedGraphQLRequestCreator",
status: int,
):
art = wandb.Artifact("test", "dataset")

injected_resp = inject_graphql_response(
operation_name="CommitArtifact",
status=status,
counter=1,
)
with relay_server(inject=[injected_resp]) as relay:
run = wandb_init()
logged = run.log_artifact(art).wait()
run.finish()

# even though we made two requests, empirically only the successful one goes into the Context
assert relay.context.entries[logged.id]["commit_artifact"]
16 changes: 10 additions & 6 deletions wandb/filesync/step_upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class ArtifactStatus(TypedDict):


PreCommitFn = Callable[[], None]
PostCommitFn = Callable[[], None]
PostCommitFn = Callable[[Optional[Exception]], None]
OnRequestFinishFn = Callable[[], None]
SaveFn = Callable[["progress.ProgressFn"], Any]

Expand Down Expand Up @@ -216,12 +216,16 @@ def _maybe_commit_artifact(self, artifact_id: str) -> None:
artifact_status["pending_count"] == 0
and artifact_status["commit_requested"]
):
for callback in artifact_status["pre_commit_callbacks"]:
callback()
for pre_callback in artifact_status["pre_commit_callbacks"]:
pre_callback()
exc = None
if artifact_status["finalize"]:
self._api.commit_artifact(artifact_id)
for callback in artifact_status["post_commit_callbacks"]:
callback()
try:
self._api.commit_artifact(artifact_id)
except Exception as e:
exc = e
for post_callback in artifact_status["post_commit_callbacks"]:
post_callback(exc)

def start(self) -> None:
self._thread.start()
Expand Down