Skip to content

Commit

Permalink
refactor(artifacts): Re-add programmatic alias addition/removal from …
Browse files Browse the repository at this point in the history
…SDK on artifacts (#4429)

* add update_aliases method and relevant gql mutations

* wip test

* added test for public api

* removed relay call

* code nit

* changed to alias updates in pythonic list form rather than separate method

* modify test

* backwards compatibility

* format

* collection name

* format

* modified test to ensure collection-specific alias changes occurred

* nit
  • Loading branch information
vwrj committed Nov 9, 2022
1 parent a49d918 commit 12fc330
Show file tree
Hide file tree
Showing 2 changed files with 175 additions and 8 deletions.
39 changes: 39 additions & 0 deletions tests/unit_tests/test_public_api.py
Expand Up @@ -202,3 +202,42 @@ def test_artifact_download_logger():
assert termlog.call_args == call
else:
termlog.assert_not_called()


def test_update_aliases_on_artifact(user, relay_server, wandb_init):
project = "test"
run = wandb_init(entity=user, project=project)
artifact = wandb.Artifact("test-artifact", "test-type")
with open("boom.txt", "w") as f:
f.write("testing")
artifact.add_file("boom.txt", "test-name")
art = run.log_artifact(artifact, aliases=["sequence"])
run.link_artifact(art, f"{user}/{project}/my-sample-portfolio")
artifact.wait()
run.finish()

# fetch artifact under original parent sequence
artifact = Api().artifact(
name=f"{user}/{project}/test-artifact:v0", type="test-type"
)
aliases = artifact.aliases
assert "sequence" in aliases

# fetch artifact under portfolio
# and change aliases under portfolio only
artifact = Api().artifact(
name=f"{user}/{project}/my-sample-portfolio:v0", type="test-type"
)
aliases = artifact.aliases
assert "sequence" not in aliases
artifact.aliases = ["portfolio"]
artifact.aliases.append("boom")
artifact.save()

artifact = Api().artifact(
name=f"{user}/{project}/my-sample-portfolio:v0", type="test-type"
)
aliases = artifact.aliases
assert "portfolio" in aliases
assert "boom" in aliases
assert "sequence" not in aliases
144 changes: 136 additions & 8 deletions wandb/apis/public.py
Expand Up @@ -4175,19 +4175,23 @@ def __init__(self, client, entity, project, name, attrs=None):
self._entity = entity
self._project = project
self._artifact_name = name
self._artifact_collection_name = name.split(":")[0]
self._attrs = attrs
if self._attrs is None:
self._load()
self._metadata = json.loads(self._attrs.get("metadata") or "{}")
self._description = self._attrs.get("description", None)
self._sequence_name = self._attrs["artifactSequence"]["name"]
self._version_index = self._attrs.get("versionIndex", None)
# We will only show aliases under the Collection this artifact version is fetched from
# _aliases will be a mutable copy on which the user can append or remove aliases
self._aliases = [
a["alias"]
for a in self._attrs["aliases"]
if not re.match(r"^v\d+$", a["alias"])
and a["artifactCollectionName"] == self._sequence_name
and a["artifactCollectionName"] == self._artifact_collection_name
]
self._frozen_aliases = [a for a in self._aliases]
self._manifest = None
self._is_downloaded = False
self._dependent_artifacts = []
Expand Down Expand Up @@ -4693,26 +4697,150 @@ def save(self):
}
"""
)
introspect_query = gql(
"""
query ProbeServerAddAliasesInput {
AddAliasesInputInfoType: __type(name: "AddAliasesInput") {
name
inputFields {
name
}
}
}
"""
)
res = self.client.execute(introspect_query)
valid = res.get("AddAliasesInputInfoType")
aliases = None
if not valid:
# If valid, wandb backend version >= 0.13.0.
# This means we can safely remove aliases from this updateArtifact request since we'll be calling
# the alias endpoints below in _save_alias_changes.
# If not valid, wandb backend version < 0.13.0. This requires aliases to be sent in updateArtifact.
aliases = [
{
"artifactCollectionName": self._artifact_collection_name,
"alias": alias,
}
for alias in self._aliases
]

self.client.execute(
mutation,
variable_values={
"artifactID": self.id,
"description": self.description,
"metadata": json.dumps(util.make_safe_for_json(self.metadata)),
"aliases": [
{
"artifactCollectionName": self._sequence_name,
"alias": alias,
}
for alias in self._aliases
],
"aliases": aliases,
},
)
# Save locally modified aliases
self._save_alias_changes()
return True

def wait(self):
return self

@normalize_exceptions
def _save_alias_changes(self):
"""
Convenience function called by artifact.save() to persist alias changes
on this artifact to the wandb backend.
"""

aliases_to_add = set(self._aliases) - set(self._frozen_aliases)
aliases_to_remove = set(self._frozen_aliases) - set(self._aliases)

# Introspect
introspect_query = gql(
"""
query ProbeServerAddAliasesInput {
AddAliasesInputInfoType: __type(name: "AddAliasesInput") {
name
inputFields {
name
}
}
}
"""
)
res = self.client.execute(introspect_query)
valid = res.get("AddAliasesInputInfoType")
if not valid:
return

if len(aliases_to_add) > 0:
add_mutation = gql(
"""
mutation addAliases(
$artifactID: ID!,
$aliases: [ArtifactCollectionAliasInput!]!,
) {
addAliases(
input: {
artifactID: $artifactID,
aliases: $aliases,
}
) {
success
}
}
"""
)
self.client.execute(
add_mutation,
variable_values={
"artifactID": self.id,
"aliases": [
{
"artifactCollectionName": self._artifact_collection_name,
"alias": alias,
"entityName": self._entity,
"projectName": self._project,
}
for alias in aliases_to_add
],
},
)

if len(aliases_to_remove) > 0:
delete_mutation = gql(
"""
mutation deleteAliases(
$artifactID: ID!,
$aliases: [ArtifactCollectionAliasInput!]!,
) {
deleteAliases(
input: {
artifactID: $artifactID,
aliases: $aliases,
}
) {
success
}
}
"""
)
self.client.execute(
delete_mutation,
variable_values={
"artifactID": self.id,
"aliases": [
{
"artifactCollectionName": self._artifact_collection_name,
"alias": alias,
"entityName": self._entity,
"projectName": self._project,
}
for alias in aliases_to_remove
],
},
)

# reset local state
self._frozen_aliases = self._aliases
return True

# TODO: not yet public, but we probably want something like this.
def _list(self):
manifest = self._load_manifest()
Expand Down

0 comments on commit 12fc330

Please sign in to comment.