-
Notifications
You must be signed in to change notification settings - Fork 3.3k
/
streamlit.py
157 lines (120 loc) · 4.9 KB
/
streamlit.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import abc
import inspect
import os
import pydoc
import subprocess
import sys
from typing import Any, Callable, Type
from lightning_app.core.work import LightningWork
from lightning_app.utilities.app_helpers import StreamLitStatePlugin
from lightning_app.utilities.state import AppState
class ServeStreamlit(LightningWork, abc.ABC):
    """Serve a streamlit UI from inside a Lightning work.

    Subclasses implement :meth:`render` with their streamlit code. They may also override :meth:`build_model`;
    the resulting model is built at most once per streamlit session and is exposed as ``self.model``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.ready = False
        # Handle of the ``streamlit run`` child process; ``None`` until ``run`` starts it.
        self._process = None

    @property
    def model(self) -> Any:
        """The model produced by ``build_model``, or ``None`` when none has been attached."""
        return getattr(self, "_model", None)

    @abc.abstractmethod
    def render(self) -> None:
        """Override with your streamlit render function."""

    def build_model(self) -> Any:
        """Optionally override to create and return a model.

        Whatever is returned becomes available as ``self.model``. The default builds nothing.
        """
        return None

    def run(self) -> None:
        """Launch ``streamlit run`` on this module as a child process and block until it exits.

        The child process receives the component name, the work class name and the path of the module that
        defines the work class through environment variables, so it can rebuild the work on its side.
        """
        child_env = os.environ.copy()
        child_env.update(
            LIGHTNING_COMPONENT_NAME=self.name,
            LIGHTNING_WORK=self.__class__.__name__,
            LIGHTNING_WORK_MODULE_FILE=inspect.getmodule(self).__file__,
        )
        command = [
            sys.executable,
            "-m",
            "streamlit",
            "run",
            __file__,
            "--server.address",
            str(self.host),
            "--server.port",
            str(self.port),
            "--server.headless",
            "true",  # do not open the browser window when running locally
        ]
        self._process = subprocess.Popen(command, env=child_env)
        self.ready = True
        self._process.wait()

    def on_exit(self) -> None:
        """Kill the streamlit child process if ``run`` started one."""
        process = self._process
        if process is not None:
            process.kill()

    def configure_layout(self) -> str:
        """Embed the streamlit server by pointing the layout at this work's URL."""
        return self.url
class _PatchedWork:
"""The ``_PatchedWork`` is used to emulate a work instance from a subprocess. This is acheived by patching the
self reference in methods an properties to point to the AppState.
Args:
state: The work state to patch
work_class: The work class to emulate
"""
def __init__(self, state: AppState, work_class: Type):
super().__init__()
self._state = state
self._work_class = work_class
def __getattr__(self, name: str) -> Any:
try:
return getattr(self._state, name)
except AttributeError:
# The name isn't in the state, so check if it's a callable or a property
attribute = inspect.getattr_static(self._work_class, name)
if callable(attribute):
attribute = attribute.__get__(self, self._work_class)
return attribute
elif isinstance(attribute, (staticmethod, property)):
return attribute.__get__(self, self._work_class)
# Look for the name in the instance (e.g. for private variables)
return object.__getattribute__(self, name)
def __setattr__(self, name: str, value: Any) -> None:
if name in ["_state", "_work_class"]:
return object.__setattr__(self, name, value)
if hasattr(self._state, name):
return setattr(self._state, name, value)
return object.__setattr__(self, name, value)
def _reduce_to_component_scope(state: AppState, component_name: str) -> AppState:
    """Descend from the app state to the sub-state of ``component_name``.

    ``component_name`` is a dot-separated path. Its first segment names the root and is skipped; every
    following segment is looked up as an attribute of the state reached so far.
    """
    _root, *path = component_name.split(".")
    scoped = state
    for segment in path:
        scoped = getattr(scoped, segment)
    return scoped
def _get_work_class() -> Callable:
"""Import the work class specified in the environment."""
work_name = os.environ["LIGHTNING_WORK"]
work_module_file = os.environ["LIGHTNING_WORK_MODULE_FILE"]
module = pydoc.importfile(work_module_file)
return getattr(module, work_name)
def _build_model(work: ServeStreamlit) -> None:
    """Attach the session's model to ``work`` as ``work._model``.

    ``work.build_model()`` is only called when the current streamlit session has no ``"_model"`` entry yet,
    so the model is built at most once per session. This is the equivalent of gradio with ``enable_queue=False``.
    """
    import streamlit as st

    session = st.session_state
    if "_model" not in session:
        # Show a spinner while the (possibly slow) user hook runs.
        with st.spinner("Building model..."):
            session["_model"] = work.build_model()
    work._model = session["_model"]
def _main() -> None:
    """Rebuild the work inside the streamlit process and render it.

    The app state is narrowed to the component named by ``LIGHTNING_COMPONENT_NAME``. That state is then
    wrapped in a ``_PatchedWork`` that emulates the user's work class, and the model is attached before
    ``render`` is called.
    """
    app_state = AppState(plugin=StreamLitStatePlugin())
    scoped_state = _reduce_to_component_scope(app_state, os.environ["LIGHTNING_COMPONENT_NAME"])
    work = _PatchedWork(scoped_state, _get_work_class())
    _build_model(work)
    work.render()
# ``ServeStreamlit.run`` executes this very file through ``streamlit run``, so this guard is the entry
# point of that child process.
if __name__ == "__main__":
    _main()