This repository has been archived by the owner on Feb 15, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 1
/
base_diffusion.py
120 lines (95 loc) · 3.3 KB
/
base_diffusion.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import abc
import base64
import io
import os
import lightning as L
from typing import Optional
from copy import deepcopy
from lightning.app.utilities.app_helpers import is_overridden
from lightning_diffusion.diffusion_serve import DiffusionServe
from lightning_diffusion.lite_finetuner import Finetuner
from lightning.app.storage import Drive
from diffusers import StableDiffusionPipeline
def trimmed_flow(flow: 'L.LightningFlow') -> 'L.LightningFlow':
    """Return a deep copy of ``flow`` with its ``_backend`` references cleared.

    Before copying, ``_backend`` is set to ``None`` on ``flow`` itself, on each
    of its child flows and on each of its works, so the copy does not carry the
    backend along. If ``flow`` originally had a backend, it is re-attached to
    ``flow`` with ``L.LightningFlow._attach_backend`` afterwards, even when the
    copy raises.

    Args:
        flow: The flow to copy. It is mutated temporarily while copying.

    Returns:
        A deep copy of ``flow`` whose own, child-flow and work ``_backend``
        attributes are ``None``.
    """
    backend = flow._backend
    flow._backend = None
    # ``LightningFlow.flows`` returns a dict of name -> child flow; iterating the
    # dict directly would yield the name strings instead of the flows.
    for child_flow in flow.flows.values():
        child_flow._backend = None
    for work in flow.works():
        work._backend = None

    try:
        # Copy while nothing on the tree references the backend.
        return deepcopy(flow)
    finally:
        # Re-attach the backend to the original flow even if the copy failed.
        # NOTE(review): when ``flow`` had no backend, its children keep the
        # ``None`` backend set above — confirm that is intended.
        if backend is not None:
            L.LightningFlow._attach_backend(flow, backend)
class LoadBalancer(L.LightningFlow):
    """Flow that drives a single serving work and publishes its URL.

    NOTE(review): ``num_replicas`` is stored but nothing in this class reads
    it; only the one ``server`` work passed in is ever run.
    """

    def __init__(self, server: L.LightningWork, num_replicas: int = 1):
        super().__init__()
        # The serving work this flow runs.
        self.server = server
        # Requested replica count (not used by this class).
        self.num_replicas = num_replicas
        # Copy of ``server.url``; stays empty until ``run`` has been called.
        self.url = ""

    def run(self):
        """Run the server work, then record its URL on this flow."""
        server = self.server
        server.run()
        self.url = server.url

    def configure_layout(self):
        """Return a layout entry named ``API`` whose content is the server work."""
        return dict(name='API', content=self.server)
class BaseDiffusion(L.LightningFlow, abc.ABC):
    """Abstract flow that optionally fine-tunes and then serves a diffusion model.

    Subclasses must implement ``setup`` and ``predict``. If a subclass also
    overrides ``finetune``, a ``Finetuner`` child is created, and the serving
    load balancer only runs once fine-tuning has succeeded.

    Args:
        finetune_cloud_compute: Compute passed to the ``Finetuner`` (only used
            when ``finetune`` is overridden).
        serve_cloud_compute: Compute passed to the ``DiffusionServe`` work.
        num_replicas: Forwarded to ``LoadBalancer``.

    Raises:
        Exception: If the subclass does not override ``predict``.
    """

    def __init__(
        self,
        finetune_cloud_compute: Optional[L.CloudCompute] = None,
        serve_cloud_compute: Optional[L.CloudCompute] = None,
        num_replicas: int = 1,
    ):
        super().__init__()
        if not is_overridden("predict", instance=self, parent=BaseDiffusion):
            raise Exception("The predict method needs to be overriden.")

        self.weights_drive = Drive("lit://weights")
        self._model = None
        # NOTE(review): not read by the ``device`` property below.
        self._device = None

        # Backend-free deep copy of this flow, handed to the children below.
        # It is taken before ``finetuner``/``load_balancer`` are assigned, so
        # those children are not part of the copy.
        _trimmed_flow = trimmed_flow(self)

        self.finetuner = None
        if is_overridden("finetune", instance=self, parent=BaseDiffusion):
            self.finetuner = Finetuner(
                flow=_trimmed_flow,
                cloud_compute=finetune_cloud_compute,
            )

        self.load_balancer = LoadBalancer(
            DiffusionServe(_trimmed_flow, cloud_compute=serve_cloud_compute, start_with_flow=False),
            num_replicas=num_replicas,
        )

    @staticmethod
    def serialize(image):
        """Encode ``image`` as base64 PNG text.

        Args:
            image: Object exposing a PIL-style ``save(fp, format=...)`` method.

        Returns:
            The PNG bytes, base64-encoded and returned as a ``str``.
        """
        with io.BytesIO() as buffered:
            image.save(buffered, format="PNG")
            return base64.b64encode(buffered.getvalue()).decode("utf-8")

    @property
    def model(self) -> StableDiffusionPipeline:
        """The pipeline previously assigned through the ``model`` setter.

        Raises:
            AssertionError: If no model has been assigned yet.
        """
        # Explicit check rather than ``assert`` so it still fires under
        # ``python -O``; AssertionError is kept so existing handlers still match.
        if self._model is None:
            raise AssertionError("BaseDiffusion.model was accessed before a model was assigned.")
        return self._model

    @model.setter
    def model(self, model) -> None:
        self._model = model

    @property
    def device(self):
        """Return ``"cuda:<LOCAL_RANK>"`` when CUDA is available, else ``"cpu"``.

        ``LOCAL_RANK`` is read from the environment and defaults to ``"0"``.
        ``torch`` is imported inside the property rather than at module level.
        """
        import torch

        local_rank = os.getenv("LOCAL_RANK", "0")
        return f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu"

    @abc.abstractmethod
    def setup(self, *args, **kwargs):
        """Subclass hook; must be implemented."""
        pass

    @abc.abstractmethod
    def predict(self, request):
        """Handle one ``request``; must be implemented by subclasses."""
        pass

    def finetune(self):
        """Optional hook; overriding it makes ``__init__`` create a ``Finetuner``."""
        raise NotImplementedError("Fine tuning is not implemented.")

    def run(self):
        """Run the finetuner (if any), then the serving load balancer.

        When a finetuner exists, the load balancer only runs once
        ``finetuner.has_succeeded`` is true; otherwise it runs every call.
        """
        if self.finetuner:
            self.finetuner.run()
            if not self.finetuner.has_succeeded:
                # Hold off serving until fine-tuning has succeeded.
                return
        self.load_balancer.run()

    def configure_layout(self):
        """Return a layout entry named ``API`` pointing at the load balancer URL."""
        return {'name': 'API', 'content': self.load_balancer.url}