Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[App] Enable running with spawn context #15923

Merged
merged 19 commits into from Dec 7, 2022
4 changes: 4 additions & 0 deletions examples/app_installation_commands/app.py
Expand Up @@ -13,6 +13,10 @@ def run(self):
print("lmdb successfully installed")
print("accessing a module in a Work or Flow body works!")

@property
def ready(self) -> bool:
return True


print(f"accessing an object in main code body works!: version={lmdb.version()}")

Expand Down
2 changes: 2 additions & 0 deletions src/lightning_app/CHANGELOG.md
Expand Up @@ -19,6 +19,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added the property `ready` of the LightningFlow to inform when the `Open App` should be visible ([#15921](https://github.com/Lightning-AI/lightning/pull/15921))

- Added private work attributed `_start_method` to customize how to start the works ([#15923](https://github.com/Lightning-AI/lightning/pull/15923))


### Changed

Expand Down
7 changes: 7 additions & 0 deletions src/lightning_app/core/flow.py
Expand Up @@ -763,6 +763,13 @@ def __init__(self, work):
super().__init__()
self.work = work

@property
def ready(self) -> bool:
ready = getattr(self.work, "ready", None)
if ready:
return ready
return self.work.url != ""

def run(self):
if self.work.has_succeeded:
self.work.stop()
Expand Down
3 changes: 2 additions & 1 deletion src/lightning_app/core/queues.py
Expand Up @@ -198,7 +198,8 @@ class MultiProcessQueue(BaseQueue):
def __init__(self, name: str, default_timeout: float):
self.name = name
self.default_timeout = default_timeout
self.queue = multiprocessing.Queue()
context = multiprocessing.get_context("spawn")
self.queue = context.Queue()

def put(self, item):
self.queue.put(item)
Expand Down
3 changes: 3 additions & 0 deletions src/lightning_app/core/work.py
@@ -1,3 +1,4 @@
import sys
import time
import warnings
from copy import deepcopy
Expand Down Expand Up @@ -46,6 +47,8 @@ class LightningWork:
)

_run_executor_cls: Type[WorkRunExecutor] = WorkRunExecutor
# TODO: Move to spawn for all Operating System.
_start_method = "spawn" if sys.platform == "win32" else "fork"

def __init__(
self,
Expand Down
5 changes: 4 additions & 1 deletion src/lightning_app/runners/backends/mp_process.py
Expand Up @@ -31,7 +31,10 @@ def start(self):
flow_to_work_delta_queue=self.app.flow_to_work_delta_queues[self.work.name],
run_executor_cls=self.work._run_executor_cls,
)
self._process = multiprocessing.Process(target=self._work_runner)

start_method = self.work._start_method
context = multiprocessing.get_context(start_method)
self._process = context.Process(target=self._work_runner)
self._process.start()

def kill(self):
Expand Down
8 changes: 5 additions & 3 deletions tests/tests_app/core/test_queues.py
Expand Up @@ -5,7 +5,6 @@
from unittest import mock

import pytest
import redis
import requests_mock

from lightning_app import LightningFlow
Expand All @@ -23,6 +22,7 @@ def test_queue_api(queue_type, monkeypatch):

This test run all the Queue implementation but we monkeypatch the Redis Queues to avoid external interaction
"""
import redis

blpop_out = (b"entry-id", pickle.dumps("test_entry"))

Expand Down Expand Up @@ -104,12 +104,14 @@ def test_redis_queue_read_timeout(redis_mock):

@pytest.mark.parametrize(
"queue_type, queue_process_mock",
[(QueuingSystem.SINGLEPROCESS, queue), (QueuingSystem.MULTIPROCESS, multiprocessing)],
[(QueuingSystem.MULTIPROCESS, multiprocessing)],
)
def test_process_queue_read_timeout(queue_type, queue_process_mock, monkeypatch):

context = mock.MagicMock()
queue_mocked = mock.MagicMock()
monkeypatch.setattr(queue_process_mock, "Queue", queue_mocked)
context.Queue = queue_mocked
monkeypatch.setattr(queue_process_mock, "get_context", mock.MagicMock(return_value=context))
my_queue = queue_type.get_readiness_queue()

# default timeout
Expand Down