Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Launch gitversion error message #4028

Merged
merged 12 commits into from Aug 12, 2022
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
90 changes: 89 additions & 1 deletion tests/unit_tests_old/tests_launch/test_launch.py
Expand Up @@ -38,13 +38,25 @@ def mocked_fetchable_git_repo():
m = mock.Mock()

def populate_dst_dir(dst_dir):
repo = mock.Mock()
reference = mock.Mock()
reference.name = "master"
repo.references = [reference]

def create_remote(o, r):
origin = mock.Mock()
origin.refs = {"master": mock.Mock()}
return origin

repo.create_remote = create_remote
repo.heads = {"master": mock.Mock()}
with open(os.path.join(dst_dir, "train.py"), "w") as f:
f.write(fixture_open("train.py").read())
with open(os.path.join(dst_dir, "requirements.txt"), "w") as f:
f.write(fixture_open("requirements.txt").read())
with open(os.path.join(dst_dir, "patch.txt"), "w") as f:
f.write("test")
return mock.Mock()
return repo

m.Repo.init = mock.Mock(side_effect=populate_dst_dir)
with mock.patch.dict("sys.modules", git=m):
Expand Down Expand Up @@ -1459,3 +1471,79 @@ def test_launch_no_url_job_or_docker_image(
)
except wandb.errors.LaunchError as e:
assert "Must specify a uri, job or docker image" in str(e)


@pytest.fixture
def mocked_fetchable_git_repo_main():
"""Gross fixture for explicit branch name -- TODO: fix using parameterization (?)"""
m = mock.Mock()

def populate_dst_dir(dst_dir):
repo = mock.Mock()
reference = mock.Mock()
reference.name = "main"
repo.references = [reference]

def create_remote(o, r):
origin = mock.Mock()
origin.refs = {"main": mock.Mock()}
return origin

repo.create_remote = create_remote
repo.heads = {"main": mock.Mock()}

with open(os.path.join(dst_dir, "train.py"), "w") as f:
f.write(fixture_open("train.py").read())
with open(os.path.join(dst_dir, "requirements.txt"), "w") as f:
f.write(fixture_open("requirements.txt").read())
with open(os.path.join(dst_dir, "patch.txt"), "w") as f:
f.write("test")
return repo

m.Repo.init = mock.Mock(side_effect=populate_dst_dir)
with mock.patch.dict("sys.modules", git=m):
yield m


def test_launch_git_version_branch_set(
live_mock_server, test_settings, mocked_fetchable_git_repo, mock_load_backend
):
api = wandb.sdk.internal.internal_api.Api(
default_settings=test_settings, load_settings=False
)
mock_with_run_info = launch.run(
api=api, uri="https://foo:bar@github.com/FooTest/Foo.git", version="foobar"
)

assert "foobar" in str(mock_with_run_info.args[0].git_version)


def test_launch_git_version_default_master(
live_mock_server, test_settings, mocked_fetchable_git_repo, mock_load_backend
):
api = wandb.sdk.internal.internal_api.Api(
default_settings=test_settings, load_settings=False
)
mock_with_run_info = launch.run(
api=api,
uri="https://foo:bar@github.com/FooTest/Foo.git",
)

assert "master" in str(mock_with_run_info.args[0].git_version)


def test_launch_git_version_default_main(
live_mock_server,
test_settings,
mocked_fetchable_git_repo_main,
mock_load_backend,
):
api = wandb.sdk.internal.internal_api.Api(
default_settings=test_settings, load_settings=False
)
mock_with_run_info = launch.run(
api=api,
uri="https://foo:bar@github.com/FooTest/Foo.git",
)

assert "main" in str(mock_with_run_info.args[0].git_version)
10 changes: 8 additions & 2 deletions wandb/sdk/launch/_project_spec.py
Expand Up @@ -287,11 +287,13 @@ def _fetch_project_local(self, internal_api: Api) -> None:
raise LaunchError(
"Reproducing a run requires either an associated git repo or a code artifact logged with `run.log_code()`"
)
utils._fetch_git_repo(
branch_name = utils._fetch_git_repo(
self.project_dir,
run_info["git"]["remote"],
run_info["git"]["commit"],
)
if self.git_version is None:
self.git_version = branch_name
patch = utils.fetch_project_diff(
source_entity, source_project, source_run_name, internal_api
)
Expand Down Expand Up @@ -357,7 +359,11 @@ def _fetch_project_local(self, internal_api: Api) -> None:
"Entry point for repo not specified, defaulting to python main.py"
)
self.add_entry_point(["python", "main.py"])
utils._fetch_git_repo(self.project_dir, self.uri, self.git_version)
branch_name = utils._fetch_git_repo(
self.project_dir, self.uri, self.git_version
)
if self.git_version is None:
self.git_version = branch_name


class EntryPoint:
Expand Down
29 changes: 25 additions & 4 deletions wandb/sdk/launch/utils.py
Expand Up @@ -411,7 +411,7 @@ def apply_patch(patch_string: str, dst_dir: str) -> None:
raise wandb.Error("Failed to apply diff.patch associated with run.")


def _fetch_git_repo(dst_dir: str, uri: str, version: Optional[str]) -> None:
def _fetch_git_repo(dst_dir: str, uri: str, version: Optional[str]) -> str:
"""Clones the git repo at ``uri`` into ``dst_dir``.

checks out commit ``version`` (or defaults to the head commit of the repository's
Expand All @@ -426,19 +426,40 @@ def _fetch_git_repo(dst_dir: str, uri: str, version: Optional[str]) -> None:
repo = git.Repo.init(dst_dir)
origin = repo.create_remote("origin", uri)
origin.fetch()

if version is not None:
try:
repo.git.checkout(version)
except git.exc.GitCommandError as e:
raise ExecutionError(
raise LaunchError(
"Unable to checkout version '%s' of git repo %s"
"- please ensure that the version exists in the repo. "
"Error: %s" % (version, uri, e)
)
else:
repo.create_head("master", origin.refs.master)
repo.heads.master.checkout()
if repo.getattr("references", None) is not None:
branches = [ref.name for ref in repo.references]
else:
branches = []
# Check if main is in origin, else set branch to master
if "main" in branches or "origin/main" in branches:
version = "main"
else:
version = "master"

try:
repo.create_head(version, origin.refs[version])
repo.heads[version].checkout()
wandb.termlog(f"No git branch passed. Defaulted to branch: {version}")
except (AttributeError, IndexError) as e:
raise LaunchError(
"Unable to checkout default version '%s' of git repo %s "
"- to specify a git version use: --git-version \n"
"Error: %s" % (version, uri, e)
)

repo.submodule_update(init=True, recursive=True)
return version


def merge_parameters(
Expand Down