Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(artifacts): add programmatic alias addition/removal from SDK on artifacts #4429

Merged
merged 22 commits into from Nov 9, 2022
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
25 changes: 25 additions & 0 deletions tests/unit_tests/test_public_api.py
Expand Up @@ -212,6 +212,31 @@ def test_artifact_download_logger():
termlog.assert_not_called()


def test_update_aliases_on_artifact(user, relay_server, wandb_init):
project = "test"
run = wandb_init(entity=user, project=project)
artifact = wandb.Artifact("test-artifact", "test-type")
with open("boom.txt", "w") as f:
f.write("testing")
artifact.add_file("boom.txt", "test-name")
run.log_artifact(artifact, aliases=["best"])
artifact.wait()
run.finish()

artifact = Api().artifact(
name=f"{user}/{project}/test-artifact:v0", type="test-type"
)
artifact.update_aliases(add=["staging"], remove=["best"])
tssweeney marked this conversation as resolved.
Show resolved Hide resolved
artifact.save()

artifact = Api().artifact(
vwrj marked this conversation as resolved.
Show resolved Hide resolved
name=f"{user}/{project}/test-artifact:v0", type="test-type"
)
aliases = artifact.aliases
assert "staging" in aliases
assert "best" not in aliases


@pytest.mark.parametrize("sweep_config", VALID_SWEEP_CONFIGS_MINIMAL)
def test_sweep_api(user, relay_server, sweep_config):
_project = "test"
Expand Down
119 changes: 118 additions & 1 deletion wandb/apis/public.py
Expand Up @@ -4182,19 +4182,23 @@ def __init__(self, client, entity, project, name, attrs=None):
self._entity = entity
self._project = project
self._artifact_name = name
self._artifact_collection_name = name.split(":")[0] if ":" in name else name
vwrj marked this conversation as resolved.
Show resolved Hide resolved
self._attrs = attrs
if self._attrs is None:
self._load()
self._metadata = json.loads(self._attrs.get("metadata") or "{}")
self._description = self._attrs.get("description", None)
self._sequence_name = self._attrs["artifactSequence"]["name"]
self._version_index = self._attrs.get("versionIndex", None)
# We will only show aliases under the Collection this artifact version is fetched from
self._aliases = [
a["alias"]
for a in self._attrs["aliases"]
if not re.match(r"^v\d+$", a["alias"])
and a["artifactCollectionName"] == self._sequence_name
and a["artifactCollectionName"] == self._artifact_collection_name
]
self._aliases_to_add = []
self._aliases_to_remove = []
self._manifest = None
self._is_downloaded = False
self._dependent_artifacts = []
Expand Down Expand Up @@ -4354,6 +4358,20 @@ def _use_as(self, use_as):
self._attrs["_use_as"] = use_as
return use_as

@normalize_exceptions
def update_aliases(self, add: List[str], remove: List[str]):
vwrj marked this conversation as resolved.
Show resolved Hide resolved
# edit aliases locally
# .save() will persist these changes to backend
# TODO: Print out a message in the SDK on finish if the user has
# unfinished changes in their artifact (such as updating the aliases).
for alias in add:
if alias not in self._aliases_to_add:
self._aliases_to_add.append(alias)

for alias in remove:
if alias not in self._aliases_to_remove:
self._aliases_to_remove.append(alias)

@normalize_exceptions
def link(self, target_path: str, aliases=None):
if ":" in target_path:
Expand Down Expand Up @@ -4715,11 +4733,110 @@ def save(self):
],
},
)
# Save locally modified aliases
vwrj marked this conversation as resolved.
Show resolved Hide resolved
self._save_alias_changes()
return True

def wait(self):
return self

@normalize_exceptions
def _save_alias_changes(self):
vwrj marked this conversation as resolved.
Show resolved Hide resolved
"""
Convenience function called by artifact.save() to persist alias changes
on this artifact to the wandb backend.
"""

# Introspect
introspect_query = gql(
vwrj marked this conversation as resolved.
Show resolved Hide resolved
"""
query ProbeServerAddAliasesInput {
AddAliasesInputInfoType: __type(name: "AddAliasesInput") {
name
inputFields {
name
}
}
}
"""
)
res = self.client.execute(introspect_query)
valid = res.get("AddAliasesInputInfoType") or None
if not valid:
return

if len(self._aliases_to_add) > 0:
add_mutation = gql(
"""
mutation addAliases(
$artifactID: ID!,
$aliases: [ArtifactCollectionAliasInput!]!,
) {
addAliases(
input: {
artifactID: $artifactID,
aliases: $aliases,
}
) {
success
}
}
"""
)
self.client.execute(
add_mutation,
variable_values={
"artifactID": self.id,
"aliases": [
{
"artifactCollectionName": self._artifact_collection_name,
"alias": alias,
"entityName": self._entity,
"projectName": self._project,
}
for alias in self._aliases_to_add
],
},
)

if len(self._aliases_to_remove) > 0:
delete_mutation = gql(
"""
mutation deleteAliases(
$artifactID: ID!,
$aliases: [ArtifactCollectionAliasInput!]!,
) {
deleteAliases(
input: {
artifactID: $artifactID,
aliases: $aliases,
}
) {
success
}
}
"""
)
self.client.execute(
delete_mutation,
variable_values={
"artifactID": self.id,
"aliases": [
{
"artifactCollectionName": self._artifact_collection_name,
"alias": alias,
"entityName": self._entity,
"projectName": self._project,
}
for alias in self._aliases_to_remove
],
},
)
# reset local state
self._aliases_to_add = []
self._aliases_to_remove = []
return True

# TODO: not yet public, but we probably want something like this.
def _list(self):
manifest = self._load_manifest()
Expand Down