diff --git a/src/lightning_app/components/python/tracer.py b/src/lightning_app/components/python/tracer.py index 0dc8c70f9a5f4..c476f083258fc 100644 --- a/src/lightning_app/components/python/tracer.py +++ b/src/lightning_app/components/python/tracer.py @@ -117,11 +117,18 @@ def __init__( self.code_name = code.get("name") if code else None self.restart_count = 0 - def run(self, params: Optional[Dict[str, Any]] = None, restart_count: Optional[int] = None, **kwargs): + def run( + self, + params: Optional[Dict[str, Any]] = None, + restart_count: Optional[int] = None, + code_dir: Optional[str] = ".", + **kwargs, + ): """ Arguments: params: A dictionary of arguments to be be added to script_args. restart_count: Passes an incrementing counter to enable the re-execution of LightningWorks. + code_dir: A path string determining where the source is extracted, default is current directory. """ if restart_count: self.restart_count = restart_count @@ -137,7 +144,10 @@ def run(self, params: Optional[Dict[str, Any]] = None, restart_count: Optional[i if self.code_name in self.drive.list(): self.drive.get(self.code_name) - extract_tarfile(self.code_name, ".", "r:gz") + extract_tarfile(self.code_name, code_dir, "r:gz") + + prev_cwd = os.getcwd() + os.chdir(code_dir) if not os.path.exists(self.script_path): raise FileNotFoundError(f"The provided `script_path` {self.script_path}` wasn't found.") @@ -152,6 +162,7 @@ def run(self, params: Optional[Dict[str, Any]] = None, restart_count: Optional[i if self.env: os.environ.update(self.env) res = self._run_tracer(init_globals) + os.chdir(prev_cwd) os.environ = env_copy return self.on_after_run(res) diff --git a/tests/tests_app/components/python/test_python.py b/tests/tests_app/components/python/test_python.py index 57a5b422b9919..e32e63ccf5985 100644 --- a/tests/tests_app/components/python/test_python.py +++ b/tests/tests_app/components/python/test_python.py @@ -122,3 +122,28 @@ def test_tracer_component_with_code(): assert python_script.script_args == ["--b=1", "--a=1"] os.remove("file.py") os.remove("sample.tar.gz") + + +def test_tracer_component_with_code_in_dir(tmp_path): + """This test ensures the Tracer Component gets the latest code from the code object that is provided and + arguments are cleaned.""" + + drive = Drive("lit://code") + drive.component_name = "something" + code = Code(drive=drive, name="sample.tar.gz") + + with open("file.py", "w") as f: + f.write('raise Exception("An error")') + + with tarfile.open("sample.tar.gz", "w:gz") as tar: + tar.add("file.py") + + drive.put("sample.tar.gz") + os.remove("file.py") + os.remove("sample.tar.gz") + + python_script = TracerPythonScript("file.py", script_args=["--b=1"], raise_exception=False, code=code) + run_work_isolated(python_script, params={"--a": "1"}, restart_count=0, code_dir=str(tmp_path)) + assert "An error" in python_script.status.message + + assert os.path.exists(os.path.join(str(tmp_path), "file.py"))