From de1bc5735040668ea4c01e905c6706d0b0dde9e2 Mon Sep 17 00:00:00 2001 From: Luca Antiga Date: Tue, 22 Nov 2022 16:56:20 +0100 Subject: [PATCH 1/3] Add code_dir argument to tracer run --- src/lightning_app/components/python/tracer.py | 8 ++++-- .../components/python/test_python.py | 25 +++++++++++++++++++ 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/src/lightning_app/components/python/tracer.py b/src/lightning_app/components/python/tracer.py index 0dc8c70f9a5f4..0fe02cf4b4374 100644 --- a/src/lightning_app/components/python/tracer.py +++ b/src/lightning_app/components/python/tracer.py @@ -117,7 +117,7 @@ def __init__( self.code_name = code.get("name") if code else None self.restart_count = 0 - def run(self, params: Optional[Dict[str, Any]] = None, restart_count: Optional[int] = None, **kwargs): + def run(self, params: Optional[Dict[str, Any]] = None, restart_count: Optional[int] = None, code_dir=".", **kwargs): """ Arguments: params: A dictionary of arguments to be be added to script_args. @@ -137,7 +137,10 @@ def run(self, params: Optional[Dict[str, Any]] = None, restart_count: Optional[i if self.code_name in self.drive.list(): self.drive.get(self.code_name) - extract_tarfile(self.code_name, ".", "r:gz") + extract_tarfile(self.code_name, code_dir, "r:gz") + + prev_cwd = os.getcwd() + os.chdir(code_dir) if not os.path.exists(self.script_path): raise FileNotFoundError(f"The provided `script_path` {self.script_path}` wasn't found.") @@ -152,6 +155,7 @@ def run(self, params: Optional[Dict[str, Any]] = None, restart_count: Optional[i if self.env: os.environ.update(self.env) res = self._run_tracer(init_globals) + os.chdir(prev_cwd) os.environ = env_copy return self.on_after_run(res) diff --git a/tests/tests_app/components/python/test_python.py b/tests/tests_app/components/python/test_python.py index 57a5b422b9919..18a0745df18e4 100644 --- a/tests/tests_app/components/python/test_python.py +++ b/tests/tests_app/components/python/test_python.py @@ -122,3 +122,28 @@ def test_tracer_component_with_code(): assert python_script.script_args == ["--b=1", "--a=1"] os.remove("file.py") os.remove("sample.tar.gz") + + +def test_tracer_component_with_code_in_dir(tmp_path): + """This test ensures the Tracer Component gets the latest code from the code object that is provided and + arguments are cleaned.""" + + drive = Drive("lit://code") + drive.component_name = "something" + code = Code(drive=drive, name="sample.tar.gz") + + with open("file.py", "w") as f: + f.write('raise Exception("An error")') + + with tarfile.open("sample.tar.gz", "w:gz") as tar: + tar.add("file.py") + + drive.put("sample.tar.gz") + os.remove("file.py") + os.remove("sample.tar.gz") + + python_script = TracerPythonScript("file.py", script_args=["--b=1"], raise_exception=False, code=code) + run_work_isolated(python_script, params={"--a": "1"}, restart_count=0, code_dir=tmp_path) + assert "An error" in python_script.status.message + + assert os.path.exists(os.path.join(tmp_path, "file.py")) From 9555ee2817325caa3bcc15b5d3c6757c858030d1 Mon Sep 17 00:00:00 2001 From: Luca Antiga Date: Tue, 22 Nov 2022 16:59:50 +0100 Subject: [PATCH 2/3] Add docstring for added argument --- src/lightning_app/components/python/tracer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/lightning_app/components/python/tracer.py b/src/lightning_app/components/python/tracer.py index 0fe02cf4b4374..d305bbc7c3135 100644 --- a/src/lightning_app/components/python/tracer.py +++ b/src/lightning_app/components/python/tracer.py @@ -122,6 +122,8 @@ def run(self, params: Optional[Dict[str, Any]] = None, restart_count: Optional[i Arguments: params: A dictionary of arguments to be be added to script_args. restart_count: Passes an incrementing counter to enable the re-execution of LightningWorks. + code_dir: A path string or Path object determining where the source is extracted, + default is current directory. """ if restart_count: self.restart_count = restart_count From ca959e6c7589f3cea2eb353993f02437134ecf6b Mon Sep 17 00:00:00 2001 From: Luca Antiga Date: Tue, 22 Nov 2022 19:36:45 +0100 Subject: [PATCH 3/3] Make code dir a string --- src/lightning_app/components/python/tracer.py | 11 ++++++++--- tests/tests_app/components/python/test_python.py | 4 ++-- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/src/lightning_app/components/python/tracer.py b/src/lightning_app/components/python/tracer.py index d305bbc7c3135..c476f083258fc 100644 --- a/src/lightning_app/components/python/tracer.py +++ b/src/lightning_app/components/python/tracer.py @@ -117,13 +117,18 @@ def __init__( self.code_name = code.get("name") if code else None self.restart_count = 0 - def run(self, params: Optional[Dict[str, Any]] = None, restart_count: Optional[int] = None, code_dir=".", **kwargs): + def run( + self, + params: Optional[Dict[str, Any]] = None, + restart_count: Optional[int] = None, + code_dir: Optional[str] = ".", + **kwargs, + ): """ Arguments: params: A dictionary of arguments to be be added to script_args. restart_count: Passes an incrementing counter to enable the re-execution of LightningWorks. - code_dir: A path string or Path object determining where the source is extracted, - default is current directory. + code_dir: A path string determining where the source is extracted, default is current directory. """ if restart_count: self.restart_count = restart_count diff --git a/tests/tests_app/components/python/test_python.py b/tests/tests_app/components/python/test_python.py index 18a0745df18e4..e32e63ccf5985 100644 --- a/tests/tests_app/components/python/test_python.py +++ b/tests/tests_app/components/python/test_python.py @@ -143,7 +143,7 @@ def test_tracer_component_with_code_in_dir(tmp_path): os.remove("sample.tar.gz") python_script = TracerPythonScript("file.py", script_args=["--b=1"], raise_exception=False, code=code) - run_work_isolated(python_script, params={"--a": "1"}, restart_count=0, code_dir=tmp_path) + run_work_isolated(python_script, params={"--a": "1"}, restart_count=0, code_dir=str(tmp_path)) assert "An error" in python_script.status.message - assert os.path.exists(os.path.join(tmp_path, "file.py")) + assert os.path.exists(os.path.join(str(tmp_path), "file.py"))