From 12fc3305879c4a686b05d232e871dd4929c38960 Mon Sep 17 00:00:00 2001 From: Vish Rajiv <8609620+vwrj@users.noreply.github.com> Date: Wed, 9 Nov 2022 15:41:54 -0800 Subject: [PATCH] refactor(artifacts): Re-add programmatic alias addition/removal from SDK on artifacts (#4429) * add update_aliases method and relevant gql mutations * wip test * added test for public api * removed relay call * code nit * changed to alias updates in pythonic list form rather than separate method * modify test * backwards compatibility * format * collection name * format * modified test to ensure collection-specific alias changes occurred * nit --- tests/unit_tests/test_public_api.py | 39 ++++++++ wandb/apis/public.py | 144 ++++++++++++++++++++++++++-- 2 files changed, 175 insertions(+), 8 deletions(-) diff --git a/tests/unit_tests/test_public_api.py b/tests/unit_tests/test_public_api.py index fedd66b914d..5e58fd90e5c 100644 --- a/tests/unit_tests/test_public_api.py +++ b/tests/unit_tests/test_public_api.py @@ -202,3 +202,42 @@ def test_artifact_download_logger(): assert termlog.call_args == call else: termlog.assert_not_called() + + +def test_update_aliases_on_artifact(user, relay_server, wandb_init): + project = "test" + run = wandb_init(entity=user, project=project) + artifact = wandb.Artifact("test-artifact", "test-type") + with open("boom.txt", "w") as f: + f.write("testing") + artifact.add_file("boom.txt", "test-name") + art = run.log_artifact(artifact, aliases=["sequence"]) + run.link_artifact(art, f"{user}/{project}/my-sample-portfolio") + artifact.wait() + run.finish() + + # fetch artifact under original parent sequence + artifact = Api().artifact( + name=f"{user}/{project}/test-artifact:v0", type="test-type" + ) + aliases = artifact.aliases + assert "sequence" in aliases + + # fetch artifact under portfolio + # and change aliases under portfolio only + artifact = Api().artifact( + name=f"{user}/{project}/my-sample-portfolio:v0", type="test-type" + ) + aliases = artifact.aliases + assert "sequence" not in aliases + artifact.aliases = ["portfolio"] + artifact.aliases.append("boom") + artifact.save() + + artifact = Api().artifact( + name=f"{user}/{project}/my-sample-portfolio:v0", type="test-type" + ) + aliases = artifact.aliases + assert "portfolio" in aliases + assert "boom" in aliases + assert "sequence" not in aliases diff --git a/wandb/apis/public.py b/wandb/apis/public.py index 90fb8fe1c21..3d86566555f 100644 --- a/wandb/apis/public.py +++ b/wandb/apis/public.py @@ -4175,6 +4175,7 @@ def __init__(self, client, entity, project, name, attrs=None): self._entity = entity self._project = project self._artifact_name = name + self._artifact_collection_name = name.split(":")[0] self._attrs = attrs if self._attrs is None: self._load() @@ -4182,12 +4183,15 @@ def __init__(self, client, entity, project, name, attrs=None): self._description = self._attrs.get("description", None) self._sequence_name = self._attrs["artifactSequence"]["name"] self._version_index = self._attrs.get("versionIndex", None) + # We will only show aliases under the Collection this artifact version is fetched from + # _aliases will be a mutable copy on which the user can append or remove aliases self._aliases = [ a["alias"] for a in self._attrs["aliases"] if not re.match(r"^v\d+$", a["alias"]) - and a["artifactCollectionName"] == self._sequence_name + and a["artifactCollectionName"] == self._artifact_collection_name ] + self._frozen_aliases = [a for a in self._aliases] self._manifest = None self._is_downloaded = False self._dependent_artifacts = [] @@ -4693,26 +4697,150 @@ def save(self): } """ ) + introspect_query = gql( + """ + query ProbeServerAddAliasesInput { + AddAliasesInputInfoType: __type(name: "AddAliasesInput") { + name + inputFields { + name + } + } + } + """ + ) + res = self.client.execute(introspect_query) + valid = res.get("AddAliasesInputInfoType") + aliases = None + if not valid: + # If valid, wandb backend version >= 0.13.0. + # This means we can safely remove aliases from this updateArtifact request since we'll be calling + # the alias endpoints below in _save_alias_changes. + # If not valid, wandb backend version < 0.13.0. This requires aliases to be sent in updateArtifact. + aliases = [ + { + "artifactCollectionName": self._artifact_collection_name, + "alias": alias, + } + for alias in self._aliases + ] + self.client.execute( mutation, variable_values={ "artifactID": self.id, "description": self.description, "metadata": json.dumps(util.make_safe_for_json(self.metadata)), - "aliases": [ - { - "artifactCollectionName": self._sequence_name, - "alias": alias, - } - for alias in self._aliases - ], + "aliases": aliases, }, ) + # Save locally modified aliases + self._save_alias_changes() return True def wait(self): return self + @normalize_exceptions + def _save_alias_changes(self): + """ + Convenience function called by artifact.save() to persist alias changes + on this artifact to the wandb backend. + """ + + aliases_to_add = set(self._aliases) - set(self._frozen_aliases) + aliases_to_remove = set(self._frozen_aliases) - set(self._aliases) + + # Introspect + introspect_query = gql( + """ + query ProbeServerAddAliasesInput { + AddAliasesInputInfoType: __type(name: "AddAliasesInput") { + name + inputFields { + name + } + } + } + """ + ) + res = self.client.execute(introspect_query) + valid = res.get("AddAliasesInputInfoType") + if not valid: + return + + if len(aliases_to_add) > 0: + add_mutation = gql( + """ + mutation addAliases( + $artifactID: ID!, + $aliases: [ArtifactCollectionAliasInput!]!, + ) { + addAliases( + input: { + artifactID: $artifactID, + aliases: $aliases, + } + ) { + success + } + } + """ + ) + self.client.execute( + add_mutation, + variable_values={ + "artifactID": self.id, + "aliases": [ + { + "artifactCollectionName": self._artifact_collection_name, + "alias": alias, + "entityName": self._entity, + "projectName": self._project, + } + for alias in aliases_to_add + ], + }, + ) + + if len(aliases_to_remove) > 0: + delete_mutation = gql( + """ + mutation deleteAliases( + $artifactID: ID!, + $aliases: [ArtifactCollectionAliasInput!]!, + ) { + deleteAliases( + input: { + artifactID: $artifactID, + aliases: $aliases, + } + ) { + success + } + } + """ + ) + self.client.execute( + delete_mutation, + variable_values={ + "artifactID": self.id, + "aliases": [ + { + "artifactCollectionName": self._artifact_collection_name, + "alias": alias, + "entityName": self._entity, + "projectName": self._project, + } + for alias in aliases_to_remove + ], + }, + ) + + # reset local state + self._frozen_aliases = self._aliases + return True + # TODO: not yet public, but we probably want something like this. def _list(self): manifest = self._load_manifest()