diff --git a/src/lightning_app/CHANGELOG.md b/src/lightning_app/CHANGELOG.md
index eb38b4c263fa8..b23b5fafd2bb6 100644
--- a/src/lightning_app/CHANGELOG.md
+++ b/src/lightning_app/CHANGELOG.md
@@ -18,6 +18,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added a CloudMultiProcessBackend which enables running a child App from within the Flow in the cloud ([#15800](https://github.com/Lightning-AI/lightning/pull/15800))
 
+- Added a utility for safely pickling a work object, even from a child process ([#15836](https://github.com/Lightning-AI/lightning/pull/15836))
+
 - Added `AutoScaler` component ([#15769](https://github.com/Lightning-AI/lightning/pull/15769))
 
 - Added the property `ready` of the LightningFlow to inform when the `Open App` should be visible ([#15921](https://github.com/Lightning-AI/lightning/pull/15921))
diff --git a/src/lightning_app/utilities/safe_pickle.py b/src/lightning_app/utilities/safe_pickle.py
new file mode 100644
index 0000000000000..8788ff22a3cb6
--- /dev/null
+++ b/src/lightning_app/utilities/safe_pickle.py
@@ -0,0 +1,133 @@
+import contextlib
+import pickle
+import sys
+import types
+import typing
+from copy import deepcopy
+from pathlib import Path
+
+from lightning_app.core.work import LightningWork
+from lightning_app.utilities.app_helpers import _LightningAppRef
+
+NON_PICKLABLE_WORK_ATTRIBUTES = ["_request_queue", "_response_queue", "_backend", "_setattr_replacement"]
+
+
+@contextlib.contextmanager
+def _trimmed_work(work: LightningWork, to_trim: typing.List[str]) -> typing.Iterator[None]:
+    """Temporarily set the attributes named in ``to_trim`` to ``None`` on ``work``.
+
+    The original values are put back when the ``with`` block exits, including when it exits because of an
+    exception, so a failure inside the block cannot leave the work permanently stripped.
+    """
+    # read every value before changing anything, so a missing attribute raises while ``work`` is untouched
+    holder = {arg: getattr(work, arg) for arg in to_trim}
+    try:
+        for arg in holder:
+            setattr(work, arg, None)
+        yield
+    finally:
+        for arg, value in holder.items():
+            setattr(work, arg, value)
+
+
+def get_picklable_work(work: LightningWork) -> LightningWork:
+    """Return a deep copy of ``work`` that can be pickled.
+
+    Pickling a LightningWork instance fails if done from the work process itself. This function is safe to call
+    from the work process within both MultiprocessRuntime and Cloud.
+
+    Note: This function modifies the module information of the work class. Specifically, when the class is
+    defined in the entrypoint file, it injects the module path relative to the CWD into the ``__module__``
+    attribute. If the module is not importable from the CWD, then the pickle load will fail.
+
+    Example:
+        for a directory structure like below and the work class is defined in the app.py where
+        the app.py is the entrypoint for the app, it will inject `foo.bar.app` into the
+        __module__ attribute
+
+        └── foo
+            ├── __init__.py
+            └── bar
+                └── app.py
+
+    Raises:
+        RuntimeError: If there is no current LightningApp.
+        ValueError: If no work with the same name is registered in the app, or if the file that defines the
+            work class is unknown or is not located under the CWD.
+    """
+
+    # If the work object is not taken from the app ref, there is a thread lock reference
+    # somewhere that's preventing it from being pickled. Investigate it later. We
+    # shouldn't be fetching the work object from the app ref. TODO @sherin
+    app_ref = _LightningAppRef.get_current()
+    if app_ref is None:
+        raise RuntimeError("Cannot pickle LightningWork outside of a LightningApp")
+    for w in app_ref.works:
+        if work.name == w.name:
+            # deep-copying the work object to avoid modifying the original work object
+            with _trimmed_work(w, to_trim=NON_PICKLABLE_WORK_ATTRIBUTES):
+                copied_work = deepcopy(w)
+            break
+    else:
+        raise ValueError(f"Work with name {work.name} not found in the app references")
+
+    # if work is defined in the __main__ or __mp__main__ (the entrypoint file for `lightning run app` command),
+    # pickling/unpickling will fail, hence we need to patch the module information
+    if "_main__" in copied_work.__class__.__module__:
+        work_class_module = sys.modules[copied_work.__class__.__module__]
+        # an interactive `__main__` has no `__file__` attribute at all, hence the `getattr`
+        work_class_file = getattr(work_class_module, "__file__", None)
+        if not work_class_file:
+            raise ValueError(
+                f"Cannot pickle work class {copied_work.__class__.__name__} because we "
+                f"couldn't identify the module file"
+            )
+        try:
+            # resolve both sides: `__file__` of the entrypoint may be relative to the CWD
+            relative_path = Path(work_class_file).resolve().relative_to(Path.cwd().resolve())
+        except ValueError as err:
+            raise ValueError(
+                f"Cannot pickle work class {copied_work.__class__.__name__} because its module file "
+                f"{work_class_file} is not located under the current working directory"
+            ) from err
+        # drop only the file extension; a `.py` elsewhere in the path must be kept
+        expected_module_name = ".".join(relative_path.with_suffix("").parts)
+        # TODO @sherin: also check if the module is importable from the CWD
+        fake_module = types.ModuleType(expected_module_name)
+        fake_module.__dict__.update(work_class_module.__dict__)
+        fake_module.__dict__["__name__"] = expected_module_name
+        sys.modules[expected_module_name] = fake_module
+        for k, v in fake_module.__dict__.items():
+            if k.startswith("__"):
+                continue
+            # `__module__` is not guaranteed to be a string (it can be `None`)
+            module_name = getattr(v, "__module__", None)
+            if isinstance(module_name, str) and "_main__" in module_name:
+                v.__module__ = expected_module_name
+    return copied_work
+
+
+def dump(work: LightningWork, f: typing.BinaryIO) -> None:
+    """Pickle a picklable copy of ``work`` (see :func:`get_picklable_work`) into the binary file ``f``."""
+    picklable_work = get_picklable_work(work)
+    pickle.dump(picklable_work, f)
+
+
+def load(f: typing.BinaryIO) -> typing.Any:
+    """Unpickle and return the object stored in the binary file ``f``.
+
+    The current working directory is put on ``sys.path`` for the duration of the load, so that modules which
+    are importable from the CWD can be resolved.
+
+    Warning: unpickling can execute arbitrary code, only load files that come from a trusted source.
+    """
+    # inject current working directory to sys.path
+    cwd = str(Path.cwd())
+    sys.path.insert(1, cwd)
+    try:
+        return pickle.load(f)
+    finally:
+        # always undo the injection, even when unpickling fails; remove by value because imports that run
+        # during unpickling may have shifted the positions in sys.path
+        with contextlib.suppress(ValueError):
+            sys.path.remove(cwd)
diff --git a/tests/tests_app/utilities/test_safe_pickle.py b/tests/tests_app/utilities/test_safe_pickle.py
new file mode 100644
index 0000000000000..473fe28ed22f7
--- /dev/null
+++ b/tests/tests_app/utilities/test_safe_pickle.py
@@ -0,0 +1,11 @@
+import subprocess
+from pathlib import Path
+
+
+def test_safe_pickle_app():
+    test_dir = Path(__file__).parent / "testdata"
+    proc = subprocess.Popen(
+        ["lightning", "run", "app", "safe_pickle_app.py", "--open-ui", "false"], stdout=subprocess.PIPE, cwd=test_dir
+    )
+    stdout, _ = proc.communicate()
+    assert "Exiting the pickling app successfully" in stdout.decode("UTF-8")
diff --git a/tests/tests_app/utilities/testdata/safe_pickle_app.py b/tests/tests_app/utilities/testdata/safe_pickle_app.py
new file mode 100644
index 0000000000000..f15344360d85f
--- /dev/null
+++ b/tests/tests_app/utilities/testdata/safe_pickle_app.py
@@ -0,0 +1,65 @@
+"""
+This app tests three things
+1. Can a work pickle `self`
+2. Can the pickled work be unpickled in another work
+3. Can the pickled work be unpickled from a script
+"""
+
+import subprocess
+import sys
+from pathlib import Path
+
+from lightning_app import LightningApp, LightningFlow, LightningWork
+from lightning_app.utilities import safe_pickle
+
+
+class SelfPicklingWork(LightningWork):
+    def run(self):
+        with open("work.pkl", "wb") as f:
+            safe_pickle.dump(self, f)
+
+    def get_test_string(self):
+        return f"Hello from {self.__class__.__name__}!"
+
+
+class WorkThatLoadsPickledWork(LightningWork):
+    def run(self):
+        with open("work.pkl", "rb") as f:
+            work = safe_pickle.load(f)
+        assert work.get_test_string() == "Hello from SelfPicklingWork!"
+
+
+script_load_pickled_work = """
+import pickle
+with open("work.pkl", "rb") as f:
+    work = pickle.load(f)
+print(work.get_test_string())
+"""
+
+
+class RootFlow(LightningFlow):
+    def __init__(self):
+        super().__init__()
+        self.self_pickling_work = SelfPicklingWork()
+        self.work_that_loads_pickled_work = WorkThatLoadsPickledWork()
+
+    def run(self):
+        self.self_pickling_work.run()
+        self.work_that_loads_pickled_work.run()
+
+        with open("script_that_loads_pickled_work.py", "w") as f:
+            f.write(script_load_pickled_work)
+
+        # run the script with the current interpreter and wait for it, so that its output is complete
+        proc = subprocess.run([sys.executable, "script_that_loads_pickled_work.py"], stdout=subprocess.PIPE)
+        assert "Hello from SelfPicklingWork" in proc.stdout.decode("UTF-8")
+
+        # deleting the script
+        Path("script_that_loads_pickled_work.py").unlink()
+        # deleting the pkl file
+        Path("work.pkl").unlink()
+
+        self._exit("Exiting the pickling app successfully!!")
+
+
+app = LightningApp(RootFlow())