-
Notifications
You must be signed in to change notification settings - Fork 3.3k
/
test_multiprocess.py
98 lines (71 loc) · 2.68 KB
/
test_multiprocess.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import os
from unittest import mock
from unittest.mock import Mock
import pytest
from lightning_app import LightningApp, LightningFlow, LightningWork
from lightning_app.frontend import StaticWebFrontend, StreamlitFrontend
from lightning_app.runners import MultiProcessRuntime
from lightning_app.utilities.component import _get_context
def _streamlit_render_fn():
pass
class StreamlitFlow(LightningFlow):
    """Flow exposing a ``StreamlitFrontend`` whose server lifecycle is replaced by mocks."""

    def run(self):
        # Calls ``_exit`` right away; presumably this ends the app run after the first iteration.
        self._exit()

    def configure_layout(self):
        """Return a Streamlit frontend whose ``start_server``/``stop_server`` are ``Mock`` objects.

        The mocks let the test assert that the runtime invoked them without launching
        a real Streamlit server.
        """
        streamlit_frontend = StreamlitFrontend(render_fn=_streamlit_render_fn)
        streamlit_frontend.start_server, streamlit_frontend.stop_server = Mock(), Mock()
        return streamlit_frontend
class WebFlow(LightningFlow):
    """Flow exposing a ``StaticWebFrontend`` whose server lifecycle is replaced by mocks."""

    def run(self):
        # Calls ``_exit`` right away; presumably this ends the app run after the first iteration.
        self._exit()

    def configure_layout(self):
        """Return a static-web frontend whose ``start_server``/``stop_server`` are ``Mock`` objects.

        ``serve_dir`` is a placeholder path; because the server methods are mocked,
        nothing is actually served from it.
        """
        web_frontend = StaticWebFrontend(serve_dir="a/b/c")
        web_frontend.start_server, web_frontend.stop_server = Mock(), Mock()
        return web_frontend
class StartFrontendServersTestFlow(LightningFlow):
    """Root flow with one Streamlit child (``flow0``) and one static-web child (``flow1``)."""

    def __init__(self):
        super().__init__()
        # Both children configure frontends with mocked start/stop methods.
        self.flow0 = StreamlitFlow()
        self.flow1 = WebFlow()

    def run(self):
        # The root exits immediately; the test only inspects frontend server calls.
        self._exit()
@mock.patch("lightning_app.runners.multiprocess.find_free_network_port")
def test_multiprocess_starts_frontend_servers(*_):
    """Test that the MultiProcessRuntime starts the servers for the frontends in each LightningFlow."""
    # The patched ``find_free_network_port`` mock is passed positionally and ignored via ``*_``.
    root = StartFrontendServersTestFlow()
    app = LightningApp(root)
    MultiProcessRuntime(app).dispatch()

    child_frontends = [app.frontends[child.name] for child in (root.flow0, root.flow1)]
    # Every server must have been started exactly once...
    for frontend in child_frontends:
        frontend.start_server.assert_called_once()
    # ...and stopped exactly once when the app finished.
    for frontend in child_frontends:
        frontend.stop_server.assert_called_once()
class ContextWork(LightningWork):
    """Work that checks the component context is ``"work"`` while its ``run`` executes."""

    def __init__(self):
        super().__init__()

    def run(self):
        current = _get_context()
        assert current.value == "work"
class ContextFlow(LightningFlow):
    """Flow that checks the component context before, during, and after running its work."""

    def __init__(self):
        super().__init__()
        self.work = ContextWork()
        # At construction time no component context has been set yet.
        assert _get_context() is None

    def run(self):
        before = _get_context()
        assert before.value == "flow"
        # Inside the work's run the context is expected to be "work" (checked by ContextWork).
        self.work.run()
        # Control is back in the flow, so the context must read "flow" again.
        after = _get_context()
        assert after.value == "flow"
        self._exit()
def test_multiprocess_runtime_sets_context():
    """Test that the runtime sets the global variable COMPONENT_CONTEXT in Flow and Work."""
    # The real checks are the asserts inside ContextFlow/ContextWork; dispatching runs them.
    application = LightningApp(ContextFlow())
    runtime = MultiProcessRuntime(application)
    runtime.dispatch()
@pytest.mark.parametrize(
    "env,expected_url",
    [
        ({}, "http://127.0.0.1:7501/view"),
        ({"APP_SERVER_HOST": "http://test"}, "http://test"),
    ],
)
def test_get_app_url(env, expected_url):
    """Check ``MultiProcessRuntime._get_app_url`` with and without ``APP_SERVER_HOST`` set.

    ``mock.patch.dict`` only adds the keys in ``env``; it does not remove keys that
    are already present. Without the explicit removal below, an ``APP_SERVER_HOST``
    exported in the ambient environment (e.g. on CI) would leak into the
    default case and make it fail. ``patch.dict`` restores the original
    ``os.environ`` when the ``with`` block exits, so the removal is temporary.
    """
    with mock.patch.dict(os.environ, env):
        if "APP_SERVER_HOST" not in env:
            # Make the "no override" case hermetic regardless of the caller's environment.
            os.environ.pop("APP_SERVER_HOST", None)
        assert MultiProcessRuntime._get_app_url() == expected_url