Skip to content

Commit

Permalink
Launch gitversion error message (#4028)
Browse files Browse the repository at this point in the history
* improved error handling for git-version param in launch

* added fallback from main -> master when no git version found

* temp working on tests

* added tests and slighly changed return-path of git fetch function, possibly breaking (!)

* fixed broken tests LOL

* added one little launchError check
  • Loading branch information
gtarpenning committed Aug 12, 2022
1 parent b1270e0 commit 283fd8d
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 8 deletions.
90 changes: 89 additions & 1 deletion tests/unit_tests_old/tests_launch/test_launch.py
Expand Up @@ -38,13 +38,25 @@ def mocked_fetchable_git_repo():
m = mock.Mock()

def populate_dst_dir(dst_dir):
repo = mock.Mock()
reference = mock.Mock()
reference.name = "master"
repo.references = [reference]

def create_remote(o, r):
origin = mock.Mock()
origin.refs = {"master": mock.Mock()}
return origin

repo.create_remote = create_remote
repo.heads = {"master": mock.Mock()}
with open(os.path.join(dst_dir, "train.py"), "w") as f:
f.write(fixture_open("train.py").read())
with open(os.path.join(dst_dir, "requirements.txt"), "w") as f:
f.write(fixture_open("requirements.txt").read())
with open(os.path.join(dst_dir, "patch.txt"), "w") as f:
f.write("test")
return mock.Mock()
return repo

m.Repo.init = mock.Mock(side_effect=populate_dst_dir)
with mock.patch.dict("sys.modules", git=m):
Expand Down Expand Up @@ -1459,3 +1471,79 @@ def test_launch_no_url_job_or_docker_image(
)
except wandb.errors.LaunchError as e:
assert "Must specify a uri, job or docker image" in str(e)


@pytest.fixture
def mocked_fetchable_git_repo_main():
"""Gross fixture for explicit branch name -- TODO: fix using parameterization (?)"""
m = mock.Mock()

def populate_dst_dir(dst_dir):
repo = mock.Mock()
reference = mock.Mock()
reference.name = "main"
repo.references = [reference]

def create_remote(o, r):
origin = mock.Mock()
origin.refs = {"main": mock.Mock()}
return origin

repo.create_remote = create_remote
repo.heads = {"main": mock.Mock()}

with open(os.path.join(dst_dir, "train.py"), "w") as f:
f.write(fixture_open("train.py").read())
with open(os.path.join(dst_dir, "requirements.txt"), "w") as f:
f.write(fixture_open("requirements.txt").read())
with open(os.path.join(dst_dir, "patch.txt"), "w") as f:
f.write("test")
return repo

m.Repo.init = mock.Mock(side_effect=populate_dst_dir)
with mock.patch.dict("sys.modules", git=m):
yield m


def test_launch_git_version_branch_set(
live_mock_server, test_settings, mocked_fetchable_git_repo, mock_load_backend
):
api = wandb.sdk.internal.internal_api.Api(
default_settings=test_settings, load_settings=False
)
mock_with_run_info = launch.run(
api=api, uri="https://foo:bar@github.com/FooTest/Foo.git", version="foobar"
)

assert "foobar" in str(mock_with_run_info.args[0].git_version)


def test_launch_git_version_default_master(
live_mock_server, test_settings, mocked_fetchable_git_repo, mock_load_backend
):
api = wandb.sdk.internal.internal_api.Api(
default_settings=test_settings, load_settings=False
)
mock_with_run_info = launch.run(
api=api,
uri="https://foo:bar@github.com/FooTest/Foo.git",
)

assert "master" in str(mock_with_run_info.args[0].git_version)


def test_launch_git_version_default_main(
live_mock_server,
test_settings,
mocked_fetchable_git_repo_main,
mock_load_backend,
):
api = wandb.sdk.internal.internal_api.Api(
default_settings=test_settings, load_settings=False
)
mock_with_run_info = launch.run(
api=api,
uri="https://foo:bar@github.com/FooTest/Foo.git",
)

assert "main" in str(mock_with_run_info.args[0].git_version)
10 changes: 8 additions & 2 deletions wandb/sdk/launch/_project_spec.py
Expand Up @@ -295,11 +295,13 @@ def _fetch_project_local(self, internal_api: Api) -> None:
raise LaunchError(
"Reproducing a run requires either an associated git repo or a code artifact logged with `run.log_code()`"
)
utils._fetch_git_repo(
branch_name = utils._fetch_git_repo(
self.project_dir,
run_info["git"]["remote"],
run_info["git"]["commit"],
)
if self.git_version is None:
self.git_version = branch_name
patch = utils.fetch_project_diff(
source_entity, source_project, source_run_name, internal_api
)
Expand Down Expand Up @@ -365,7 +367,11 @@ def _fetch_project_local(self, internal_api: Api) -> None:
f"{LOG_PREFIX}Entry point for repo not specified, defaulting to python main.py"
)
self.add_entry_point(["python", "main.py"])
utils._fetch_git_repo(self.project_dir, self.uri, self.git_version)
branch_name = utils._fetch_git_repo(
self.project_dir, self.uri, self.git_version
)
if self.git_version is None:
self.git_version = branch_name


class EntryPoint:
Expand Down
31 changes: 26 additions & 5 deletions wandb/sdk/launch/utils.py
Expand Up @@ -11,7 +11,7 @@
import wandb
from wandb import util
from wandb.apis.internal import Api
from wandb.errors import CommError, ExecutionError, LaunchError
from wandb.errors import CommError, LaunchError

if TYPE_CHECKING: # pragma: no cover
from wandb.apis.public import Artifact as PublicArtifact
Expand Down Expand Up @@ -413,7 +413,7 @@ def apply_patch(patch_string: str, dst_dir: str) -> None:
raise wandb.Error("Failed to apply diff.patch associated with run.")


def _fetch_git_repo(dst_dir: str, uri: str, version: Optional[str]) -> None:
def _fetch_git_repo(dst_dir: str, uri: str, version: Optional[str]) -> str:
"""Clones the git repo at ``uri`` into ``dst_dir``.
checks out commit ``version`` (or defaults to the head commit of the repository's
Expand All @@ -428,19 +428,40 @@ def _fetch_git_repo(dst_dir: str, uri: str, version: Optional[str]) -> None:
repo = git.Repo.init(dst_dir)
origin = repo.create_remote("origin", uri)
origin.fetch()

if version is not None:
try:
repo.git.checkout(version)
except git.exc.GitCommandError as e:
raise ExecutionError(
raise LaunchError(
"Unable to checkout version '%s' of git repo %s"
"- please ensure that the version exists in the repo. "
"Error: %s" % (version, uri, e)
)
else:
repo.create_head("master", origin.refs.master)
repo.heads.master.checkout()
if repo.getattr("references", None) is not None:
branches = [ref.name for ref in repo.references]
else:
branches = []
# Check if main is in origin, else set branch to master
if "main" in branches or "origin/main" in branches:
version = "main"
else:
version = "master"

try:
repo.create_head(version, origin.refs[version])
repo.heads[version].checkout()
wandb.termlog(f"No git branch passed. Defaulted to branch: {version}")
except (AttributeError, IndexError) as e:
raise LaunchError(
"Unable to checkout default version '%s' of git repo %s "
"- to specify a git version use: --git-version \n"
"Error: %s" % (version, uri, e)
)

repo.submodule_update(init=True, recursive=True)
return version


def merge_parameters(
Expand Down

0 comments on commit 283fd8d

Please sign in to comment.