Skip to content

Commit

Permalink
Add code_dir argument to tracer run (#15771)
Browse files Browse the repository at this point in the history
(cherry picked from commit 0a12731)
  • Loading branch information
lantiga authored and Borda committed Nov 30, 2022
1 parent 0fe095b commit 3fea2d5
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 2 deletions.
15 changes: 13 additions & 2 deletions src/lightning_app/components/python/tracer.py
Expand Up @@ -117,11 +117,18 @@ def __init__(
self.code_name = code.get("name") if code else None
self.restart_count = 0

def run(self, params: Optional[Dict[str, Any]] = None, restart_count: Optional[int] = None, **kwargs):
def run(
self,
params: Optional[Dict[str, Any]] = None,
restart_count: Optional[int] = None,
code_dir: Optional[str] = ".",
**kwargs,
):
"""
Arguments:
params: A dictionary of arguments to be be added to script_args.
restart_count: Passes an incrementing counter to enable the re-execution of LightningWorks.
code_dir: A path string determining where the source is extracted, default is current directory.
"""
if restart_count:
self.restart_count = restart_count
Expand All @@ -137,7 +144,10 @@ def run(self, params: Optional[Dict[str, Any]] = None, restart_count: Optional[i

if self.code_name in self.drive.list():
self.drive.get(self.code_name)
extract_tarfile(self.code_name, ".", "r:gz")
extract_tarfile(self.code_name, code_dir, "r:gz")

prev_cwd = os.getcwd()
os.chdir(code_dir)

if not os.path.exists(self.script_path):
raise FileNotFoundError(f"The provided `script_path` {self.script_path}` wasn't found.")
Expand All @@ -152,6 +162,7 @@ def run(self, params: Optional[Dict[str, Any]] = None, restart_count: Optional[i
if self.env:
os.environ.update(self.env)
res = self._run_tracer(init_globals)
os.chdir(prev_cwd)
os.environ = env_copy
return self.on_after_run(res)

Expand Down
25 changes: 25 additions & 0 deletions tests/tests_app/components/python/test_python.py
Expand Up @@ -122,3 +122,28 @@ def test_tracer_component_with_code():
assert python_script.script_args == ["--b=1", "--a=1"]
os.remove("file.py")
os.remove("sample.tar.gz")


def test_tracer_component_with_code_in_dir(tmp_path):
"""This test ensures the Tracer Component gets the latest code from the code object that is provided and
arguments are cleaned."""

drive = Drive("lit://code")
drive.component_name = "something"
code = Code(drive=drive, name="sample.tar.gz")

with open("file.py", "w") as f:
f.write('raise Exception("An error")')

with tarfile.open("sample.tar.gz", "w:gz") as tar:
tar.add("file.py")

drive.put("sample.tar.gz")
os.remove("file.py")
os.remove("sample.tar.gz")

python_script = TracerPythonScript("file.py", script_args=["--b=1"], raise_exception=False, code=code)
run_work_isolated(python_script, params={"--a": "1"}, restart_count=0, code_dir=str(tmp_path))
assert "An error" in python_script.status.message

assert os.path.exists(os.path.join(str(tmp_path), "file.py"))

0 comments on commit 3fea2d5

Please sign in to comment.