diff --git a/wandb/sdk/launch/utils.py b/wandb/sdk/launch/utils.py index 76c2a577ad9..54bdabb21aa 100644 --- a/wandb/sdk/launch/utils.py +++ b/wandb/sdk/launch/utils.py @@ -414,6 +414,20 @@ def apply_patch(patch_string: str, dst_dir: str) -> None: raise wandb.Error("Failed to apply diff.patch associated with run.") +def _make_refspec_from_version(version: Optional[str]) -> List[str]: + """ + Helper to create a refspec that checks for the existence of origin/main + and the version, if provided. + """ + if version: + return [f"+{version}"] + + return [ + "+refs/heads/main*:refs/remotes/origin/main*", + "+refs/heads/master*:refs/remotes/origin/master*", + ] + + def _fetch_git_repo(dst_dir: str, uri: str, version: Optional[str]) -> str: """Clones the git repo at ``uri`` into ``dst_dir``. @@ -428,7 +442,8 @@ def _fetch_git_repo(dst_dir: str, uri: str, version: Optional[str]) -> str: _logger.info("Fetching git repo") repo = git.Repo.init(dst_dir) origin = repo.create_remote("origin", uri) - origin.fetch() + refspec = _make_refspec_from_version(version) + origin.fetch(refspec=refspec, depth=1) if version is not None: try: