diff --git a/tests/unit_tests_old/tests_launch/test_launch.py b/tests/unit_tests_old/tests_launch/test_launch.py index e47df6216e5..12af3753808 100644 --- a/tests/unit_tests_old/tests_launch/test_launch.py +++ b/tests/unit_tests_old/tests_launch/test_launch.py @@ -38,13 +38,25 @@ def mocked_fetchable_git_repo(): m = mock.Mock() def populate_dst_dir(dst_dir): + repo = mock.Mock() + reference = mock.Mock() + reference.name = "master" + repo.references = [reference] + + def create_remote(o, r): + origin = mock.Mock() + origin.refs = {"master": mock.Mock()} + return origin + + repo.create_remote = create_remote + repo.heads = {"master": mock.Mock()} with open(os.path.join(dst_dir, "train.py"), "w") as f: f.write(fixture_open("train.py").read()) with open(os.path.join(dst_dir, "requirements.txt"), "w") as f: f.write(fixture_open("requirements.txt").read()) with open(os.path.join(dst_dir, "patch.txt"), "w") as f: f.write("test") - return mock.Mock() + return repo m.Repo.init = mock.Mock(side_effect=populate_dst_dir) with mock.patch.dict("sys.modules", git=m): @@ -1459,3 +1471,79 @@ def test_launch_no_url_job_or_docker_image( ) except wandb.errors.LaunchError as e: assert "Must specify a uri, job or docker image" in str(e) + + +@pytest.fixture +def mocked_fetchable_git_repo_main(): + """Gross fixture for explicit branch name -- TODO: fix using parameterization (?)""" + m = mock.Mock() + + def populate_dst_dir(dst_dir): + repo = mock.Mock() + reference = mock.Mock() + reference.name = "main" + repo.references = [reference] + + def create_remote(o, r): + origin = mock.Mock() + origin.refs = {"main": mock.Mock()} + return origin + + repo.create_remote = create_remote + repo.heads = {"main": mock.Mock()} + + with open(os.path.join(dst_dir, "train.py"), "w") as f: + f.write(fixture_open("train.py").read()) + with open(os.path.join(dst_dir, "requirements.txt"), "w") as f: + f.write(fixture_open("requirements.txt").read()) + with open(os.path.join(dst_dir, "patch.txt"), "w") as f: + f.write("test") + return repo + + m.Repo.init = mock.Mock(side_effect=populate_dst_dir) + with mock.patch.dict("sys.modules", git=m): + yield m + + +def test_launch_git_version_branch_set( + live_mock_server, test_settings, mocked_fetchable_git_repo, mock_load_backend +): + api = wandb.sdk.internal.internal_api.Api( + default_settings=test_settings, load_settings=False + ) + mock_with_run_info = launch.run( + api=api, uri="https://foo:bar@github.com/FooTest/Foo.git", version="foobar" + ) + + assert "foobar" in str(mock_with_run_info.args[0].git_version) + + +def test_launch_git_version_default_master( + live_mock_server, test_settings, mocked_fetchable_git_repo, mock_load_backend +): + api = wandb.sdk.internal.internal_api.Api( + default_settings=test_settings, load_settings=False + ) + mock_with_run_info = launch.run( + api=api, + uri="https://foo:bar@github.com/FooTest/Foo.git", + ) + + assert "master" in str(mock_with_run_info.args[0].git_version) + + +def test_launch_git_version_default_main( + live_mock_server, + test_settings, + mocked_fetchable_git_repo_main, + mock_load_backend, +): + api = wandb.sdk.internal.internal_api.Api( + default_settings=test_settings, load_settings=False + ) + mock_with_run_info = launch.run( + api=api, + uri="https://foo:bar@github.com/FooTest/Foo.git", + ) + + assert "main" in str(mock_with_run_info.args[0].git_version) diff --git a/wandb/sdk/launch/_project_spec.py b/wandb/sdk/launch/_project_spec.py index 7af9d03decf..0a17967d25d 100644 --- a/wandb/sdk/launch/_project_spec.py +++ b/wandb/sdk/launch/_project_spec.py @@ -295,11 +295,13 @@ def _fetch_project_local(self, internal_api: Api) -> None: raise LaunchError( "Reproducing a run requires either an associated git repo or a code artifact logged with `run.log_code()`" ) - utils._fetch_git_repo( + branch_name = utils._fetch_git_repo( self.project_dir, run_info["git"]["remote"], run_info["git"]["commit"], ) + if self.git_version is None: + self.git_version = branch_name patch = utils.fetch_project_diff( source_entity, source_project, source_run_name, internal_api ) @@ -365,7 +367,11 @@ def _fetch_project_local(self, internal_api: Api) -> None: f"{LOG_PREFIX}Entry point for repo not specified, defaulting to python main.py" ) self.add_entry_point(["python", "main.py"]) - utils._fetch_git_repo(self.project_dir, self.uri, self.git_version) + branch_name = utils._fetch_git_repo( + self.project_dir, self.uri, self.git_version + ) + if self.git_version is None: + self.git_version = branch_name class EntryPoint: diff --git a/wandb/sdk/launch/utils.py b/wandb/sdk/launch/utils.py index d698cab32ae..e7521934893 100644 --- a/wandb/sdk/launch/utils.py +++ b/wandb/sdk/launch/utils.py @@ -11,7 +11,7 @@ import wandb from wandb import util from wandb.apis.internal import Api -from wandb.errors import CommError, ExecutionError, LaunchError +from wandb.errors import CommError, LaunchError if TYPE_CHECKING: # pragma: no cover from wandb.apis.public import Artifact as PublicArtifact @@ -413,7 +413,7 @@ def apply_patch(patch_string: str, dst_dir: str) -> None: raise wandb.Error("Failed to apply diff.patch associated with run.") -def _fetch_git_repo(dst_dir: str, uri: str, version: Optional[str]) -> None: +def _fetch_git_repo(dst_dir: str, uri: str, version: Optional[str]) -> str: """Clones the git repo at ``uri`` into ``dst_dir``. checks out commit ``version`` (or defaults to the head commit of the repository's @@ -428,19 +428,40 @@ def _fetch_git_repo(dst_dir: str, uri: str, version: Optional[str]) -> None: repo = git.Repo.init(dst_dir) origin = repo.create_remote("origin", uri) origin.fetch() + if version is not None: try: repo.git.checkout(version) except git.exc.GitCommandError as e: - raise ExecutionError( + raise LaunchError( "Unable to checkout version '%s' of git repo %s" "- please ensure that the version exists in the repo. " "Error: %s" % (version, uri, e) ) else: - repo.create_head("master", origin.refs.master) - repo.heads.master.checkout() + if repo.getattr("references", None) is not None: + branches = [ref.name for ref in repo.references] + else: + branches = [] + # Check if main is in origin, else set branch to master + if "main" in branches or "origin/main" in branches: + version = "main" + else: + version = "master" + + try: + repo.create_head(version, origin.refs[version]) + repo.heads[version].checkout() + wandb.termlog(f"No git branch passed. Defaulted to branch: {version}") + except (AttributeError, IndexError) as e: + raise LaunchError( + "Unable to checkout default version '%s' of git repo %s " + "- to specify a git version use: --git-version \n" + "Error: %s" % (version, uri, e) + ) + repo.submodule_update(init=True, recursive=True) + return version def merge_parameters(