-
Notifications
You must be signed in to change notification settings - Fork 3.3k
/
gradio.py
83 lines (66 loc) · 2.48 KB
/
gradio.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import abc
import os
from functools import partial
from types import ModuleType
from typing import Any, List, Optional
from lightning_app.components.serve.python_server import _PyTorchSpawnRunExecutor, WorkRunExecutor
from lightning_app.core.work import LightningWork
from lightning_app.utilities.imports import _is_gradio_available, requires
if _is_gradio_available():
import gradio
else:
gradio = ModuleType("gradio")
class ServeGradio(LightningWork, abc.ABC):
    """The ServeGradio class enables you to quickly create a ``gradio``-based UI for your LightningApp.

    Subclasses must:

    * set the ``inputs`` and ``outputs`` class attributes, which are forwarded to ``gradio.Interface``;
    * implement :meth:`predict`, which is called with the values coming from the UI;
    * implement :meth:`build_model`, whose result is exposed as :attr:`model`.

    In the example below, the ``ServeGradio`` is subclassed to deploy ``AnimeGANv2``.

    .. literalinclude:: ../../../examples/app_components/serve/gradio/app.py
        :language: python

    The result would be the following:

    .. image:: https://pl-flash-data.s3.amazonaws.com/assets_lightning/anime_gan.gif
        :alt: Animation showing what the AnimeGANv2 UI looks like.
    """

    # Input component(s) forwarded to ``gradio.Interface(inputs=...)``; subclasses must set this.
    inputs: Any
    # Output component(s) forwarded to ``gradio.Interface(outputs=...)``; subclasses must set this.
    outputs: Any
    # Optional example inputs forwarded to ``gradio.Interface(examples=...)``.
    examples: Optional[List] = None
    # Forwarded to ``Interface.launch(enable_queue=...)``.
    enable_queue: bool = False
    # Optional title / description forwarded to ``gradio.Interface``.
    title: Optional[str] = None
    description: Optional[str] = None

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # Decorate the bound parent constructor with the ``requires`` guard and *call* the
        # wrapper, so the gradio-availability check actually runs before initialization.
        # Applying the decorator to the result of ``super().__init__(...)`` (which is ``None``)
        # would build a wrapper that is never invoked, silently skipping the check.
        requires("gradio")(super().__init__)(*args, **kwargs)
        assert self.inputs, f"{type(self).__name__} must define a non-empty `inputs` attribute."
        assert self.outputs, f"{type(self).__name__} must define a non-empty `outputs` attribute."
        # The model is built lazily on the first call to ``run``.
        self._model = None
        # Locally, use the PyTorch spawn executor so inference can run on GPUs. When
        # ``LIGHTNING_CLOUD_APP_ID`` is set (presumably running on the Lightning cloud),
        # fall back to the default ``WorkRunExecutor``.
        self._run_executor_cls = (
            WorkRunExecutor if os.getenv("LIGHTNING_CLOUD_APP_ID", None) else _PyTorchSpawnRunExecutor
        )

    @property
    def model(self) -> Any:
        """The object returned by :meth:`build_model`, or ``None`` until :meth:`run` has built it."""
        return self._model

    @abc.abstractmethod
    def predict(self, *args: Any, **kwargs: Any) -> Any:
        """Override with your logic to make a prediction.

        Any positional/keyword arguments given to :meth:`run` are passed first, followed by the
        values produced by the gradio ``inputs`` components.
        """

    @abc.abstractmethod
    def build_model(self) -> Any:
        """Override to instantiate and return your model.

        The model is built once, on the first call to :meth:`run`, and is accessible
        under ``self.model``.
        """

    def run(self, *args: Any, **kwargs: Any) -> None:
        """Build the model if needed, then launch the gradio UI on this work's host and port.

        Args:
            *args: Extra positional arguments bound in front of the UI inputs for :meth:`predict`.
            **kwargs: Extra keyword arguments bound into every :meth:`predict` call.
        """
        if self._model is None:
            self._model = self.build_model()
        fn = partial(self.predict, *args, **kwargs)
        # ``functools.partial`` objects have no ``__name__``; gradio presumably reads it to
        # label the function, so copy it from ``predict``.
        fn.__name__ = self.predict.__name__
        gradio.Interface(
            fn=fn,
            inputs=self.inputs,
            outputs=self.outputs,
            examples=self.examples,
            title=self.title,
            description=self.description,
        ).launch(
            server_name=self.host,
            server_port=self.port,
            enable_queue=self.enable_queue,
        )

    def configure_layout(self) -> str:
        """Use the URL of the gradio server started by this work as the app layout."""
        return self.url