Skip to content

Commit

Permalink
[App] Multiprocessing-safe work pickling (#15836)
Browse files Browse the repository at this point in the history
  • Loading branch information
Sherin Thomas committed Dec 8, 2022
1 parent ca5ca0e commit df67833
Show file tree
Hide file tree
Showing 4 changed files with 171 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/lightning_app/CHANGELOG.md
Expand Up @@ -18,6 +18,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added a CloudMultiProcessBackend which enables running a child App from within the Flow in the cloud ([#15800](https://github.com/Lightning-AI/lightning/pull/15800))

- Added a utility for safely pickling a work object, even from within a child process ([#15836](https://github.com/Lightning-AI/lightning/pull/15836))

- Added `AutoScaler` component ([#15769](https://github.com/Lightning-AI/lightning/pull/15769))

- Added the property `ready` of the LightningFlow to inform when the `Open App` should be visible ([#15921](https://github.com/Lightning-AI/lightning/pull/15921))
Expand Down
95 changes: 95 additions & 0 deletions src/lightning_app/utilities/safe_pickle.py
@@ -0,0 +1,95 @@
import contextlib
import pickle
import sys
import types
import typing
from copy import deepcopy
from pathlib import Path

from lightning_app.core.work import LightningWork
from lightning_app.utilities.app_helpers import _LightningAppRef

# Names of the work attributes that `get_picklable_work` temporarily sets to None
# (via `_trimmed_work`) while it deep-copies a work for pickling.
NON_PICKLABLE_WORK_ATTRIBUTES = ["_request_queue", "_response_queue", "_backend", "_setattr_replacement"]


@contextlib.contextmanager
def _trimmed_work(work: LightningWork, to_trim: typing.List[str]) -> typing.Iterator[None]:
    """Context manager that temporarily sets the given attributes of ``work`` to ``None``.

    It is used to strip attributes that are not picklable while the work is being copied.
    The original values are always put back on exit, including when the body of the
    ``with`` statement raises or when reading one of the attributes fails midway.

    Args:
        work: The object whose attributes are temporarily replaced.
        to_trim: Names of the attributes to set to ``None`` for the duration of the block.
    """
    holder = {}
    try:
        for arg in to_trim:
            holder[arg] = getattr(work, arg)
            setattr(work, arg, None)
        yield
    finally:
        # Only restore what was actually saved, so a failure while trimming
        # still leaves the work in its original state.
        for arg, value in holder.items():
            setattr(work, arg, value)


def get_picklable_work(work: LightningWork) -> LightningWork:
    """Return a deep copy of ``work`` that can be pickled, even from the work process itself.

    Pickling a LightningWork instance fails if done from the work process itself. This function is
    safe to call from the work process within both MultiprocessRuntime and Cloud.

    Note:
        This function modifies the module information of the work object. Specifically, it injects
        the relative module path into the ``__module__`` attribute of the work class (and of every
        other public top-level object of the entrypoint module that was defined there), and registers
        a copy of the entrypoint module under that name in ``sys.modules``. If the object is not
        importable from the CWD, then the pickle load will fail.

    Example:
        For a directory structure like below, where the work class is defined in ``app.py`` and
        ``app.py`` is the entrypoint for the app, it will inject ``foo.bar.app`` into the
        ``__module__`` attribute::

            └── foo
                ├── __init__.py
                └── bar
                    └── app.py

    Args:
        work: The work to pickle. Only its ``name`` is used, to look up the matching work
            in the current app reference.

    Returns:
        A deep copy of the matching work, taken while the attributes listed in
        ``NON_PICKLABLE_WORK_ATTRIBUTES`` are set to ``None``.

    Raises:
        RuntimeError: If there is no current app reference.
        ValueError: If no work with the same name exists in the app, if the file of the
            entrypoint module cannot be identified, or if that file is not located under the
            current working directory.
    """

    # If the work object not taken from the app ref, there is a thread lock reference
    # somewhere thats preventing it from being pickled. Investigate it later. We
    # shouldn't be fetching the work object from the app ref. TODO @sherin
    app_ref = _LightningAppRef.get_current()
    if app_ref is None:
        raise RuntimeError("Cannot pickle LightningWork outside of a LightningApp")
    for w in app_ref.works:
        if work.name == w.name:
            # deep-copying the work object to avoid modifying the original work object
            with _trimmed_work(w, to_trim=NON_PICKLABLE_WORK_ATTRIBUTES):
                copied_work = deepcopy(w)
            break
    else:
        raise ValueError(f"Work with name {work.name} not found in the app references")

    # if work is defined in the __main__ or __mp_main__ (the entrypoint file for `lightning run app` command),
    # pickling/unpickling will fail, hence we need patch the module information
    work_module_name = copied_work.__class__.__module__
    if "_main__" in work_module_name:
        work_class_module = sys.modules[work_module_name]
        # `__file__` can be missing altogether (not just None), so do not assume the attribute exists
        work_class_file = getattr(work_class_module, "__file__", None)
        if not work_class_file:
            raise ValueError(
                f"Cannot pickle work class {copied_work.__class__.__name__} because we "
                f"couldn't identify the module file"
            )
        relative_path = Path(work_class_file).relative_to(Path.cwd())
        # Drop only the trailing `.py` suffix; a blanket `.replace(".py", "")` would also
        # mangle directory or file names that merely contain ".py".
        if relative_path.suffix == ".py":
            relative_path = relative_path.with_suffix("")
        expected_module_name = ".".join(relative_path.parts)
        # TODO @sherin: also check if the module is importable from the CWD
        fake_module = types.ModuleType(expected_module_name)
        fake_module.__dict__.update(work_class_module.__dict__)
        fake_module.__dict__["__name__"] = expected_module_name
        sys.modules[expected_module_name] = fake_module
        for k, v in fake_module.__dict__.items():
            if k.startswith("__"):
                continue
            # `__module__` may exist but be None (or another non-string); skip those objects
            owner_module = getattr(v, "__module__", None)
            if isinstance(owner_module, str) and "_main__" in owner_module:
                v.__module__ = expected_module_name
    return copied_work


def dump(work: LightningWork, f: typing.BinaryIO) -> None:
    """Pickle a picklable copy of ``work`` (see ``get_picklable_work``) into the binary file ``f``."""
    pickle.dump(get_picklable_work(work), f)


def load(f: typing.BinaryIO) -> typing.Any:
    """Unpickle an object from ``f`` with the current working directory importable.

    The current working directory is temporarily added to ``sys.path`` so that modules
    referenced by the pickle can be imported relative to it. The entry is removed again
    afterwards, also when unpickling fails.

    NOTE(security): ``pickle.load`` can execute arbitrary code; only call this on files
    that come from a trusted source.

    Args:
        f: Binary file object to read the pickle from.

    Returns:
        The unpickled object.
    """
    # inject current working directory to sys.path
    cwd = str(Path.cwd())
    sys.path.insert(1, cwd)
    try:
        return pickle.load(f)
    finally:
        # Remove by value rather than by position: importing modules while unpickling
        # may itself have changed sys.path, so index 1 is not guaranteed to be ours.
        with contextlib.suppress(ValueError):
            sys.path.remove(cwd)
11 changes: 11 additions & 0 deletions tests/tests_app/utilities/test_safe_pickle.py
@@ -0,0 +1,11 @@
import subprocess
from pathlib import Path


def test_safe_pickle_app():
    """Run the self-pickling test app through the CLI and check that it reports a clean exit."""
    testdata_dir = Path(__file__).parent / "testdata"
    command = ["lightning", "run", "app", "safe_pickle_app.py", "--open-ui", "false"]
    process = subprocess.Popen(command, stdout=subprocess.PIPE, cwd=testdata_dir)
    captured_stdout = process.communicate()[0]
    output_text = captured_stdout.decode("UTF-8")
    assert "Exiting the pickling app successfully" in output_text
63 changes: 63 additions & 0 deletions tests/tests_app/utilities/testdata/safe_pickle_app.py
@@ -0,0 +1,63 @@
"""
This app tests three things
1. Can a work pickle `self`
2. Can the pickled work be unpickled in another work
3. Can the pickled work be unpickled from a script
"""

import subprocess
from pathlib import Path

from lightning_app import LightningApp, LightningFlow, LightningWork
from lightning_app.utilities import safe_pickle


class SelfPicklingWork(LightningWork):
    """Work that pickles itself into ``work.pkl`` when it runs."""

    def run(self):
        """Dump this work into ``work.pkl`` using ``safe_pickle.dump``."""
        with open("work.pkl", "wb") as pickle_file:
            safe_pickle.dump(self, pickle_file)

    def get_test_string(self):
        """Return a greeting that contains the name of this work's class."""
        class_name = self.__class__.__name__
        return f"Hello from {class_name}!"


class WorkThatLoadsPickledWork(LightningWork):
    """Work that loads ``work.pkl`` and checks that the restored object is usable."""

    def run(self):
        """Unpickle ``work.pkl`` with ``safe_pickle.load`` and verify its test string."""
        with open("work.pkl", "rb") as pickle_file:
            restored_work = safe_pickle.load(pickle_file)
        expected_greeting = "Hello from SelfPicklingWork!"
        assert restored_work.get_test_string() == expected_greeting


# Source of a standalone script that loads `work.pkl` with plain `pickle` and prints the
# work's test string. `RootFlow.run` writes it to disk and executes it in a subprocess.
script_load_pickled_work = """
import pickle
work = pickle.load(open("work.pkl", "rb"))
print(work.get_test_string())
"""


class RootFlow(LightningFlow):
    """Flow that pickles a work, loads it back in another work and from a plain script."""

    def __init__(self):
        super().__init__()
        self.self_pickling_work = SelfPicklingWork()
        self.work_that_loads_pickled_work = WorkThatLoadsPickledWork()

    def run(self):
        """Run both works, load the pickle from a standalone script, clean up and exit."""
        import sys  # local import: only needed to locate the running interpreter

        self.self_pickling_work.run()
        self.work_that_loads_pickled_work.run()

        script_path = Path("script_that_loads_pickled_work.py")
        script_path.write_text(script_load_pickled_work)
        try:
            # Use the interpreter that runs this app so the script sees the same environment,
            # and `subprocess.run` so the child is waited on and its pipe is closed.
            completed = subprocess.run([sys.executable, str(script_path)], stdout=subprocess.PIPE)
            assert "Hello from SelfPicklingWork" in completed.stdout.decode("UTF-8")
        finally:
            # delete the generated files even when the check above fails
            # deleting the script
            script_path.unlink()
            # deleting the pkl file
            Path("work.pkl").unlink()

        self._exit("Exiting the pickling app successfully!!")


# Module-level app object wrapping the root flow (presumably what `lightning run app` picks up).
app = LightningApp(RootFlow())

0 comments on commit df67833

Please sign in to comment.