Skip to content
This repository has been archived by the owner on Feb 15, 2023. It is now read-only.

Improve Components API #12

Merged
merged 4 commits into from Nov 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
13 changes: 7 additions & 6 deletions README.md
Expand Up @@ -6,22 +6,22 @@ Lightning Diffusion provides components to finetune and serve diffusion model on
### Serve ANY Diffusion Models

```python
# !pip install torch diffusers lightning_diffusion@git+https://github.com/Lightning-AI/lightning-diffusion.git
# !pip install lightning_diffusion@git+https://github.com/Lightning-AI/lightning-diffusion.git
import lightning as L
import torch, diffusers
import diffusers
from lightning_diffusion import BaseDiffusion, models


class ServeDiffusion(BaseDiffusion):

def setup(self, *args, **kwargs):
self._model = diffusers.StableDiffusionPipeline.from_pretrained(
self.model = diffusers.StableDiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4",
**models.extras
).to(self.device)

def predict(self, data):
out = self.model(prompt=data.prompt)
out = self.model(prompt=data.prompt, num_inference_steps=23)
return {"image": self.serialize(out[0][0])}


Expand All @@ -36,12 +36,13 @@ Use the DreamBooth fine-tuning methodology from the paper `Fine Tuning Text-to-I
```python
import lightning as L
from lightning_diffusion import BaseDiffusion, DreamBoothTuner, models
import torch, diffusers
from diffusers import StableDiffusionPipeline


class ServeDreamBoothDiffusion(BaseDiffusion):

def setup(self):
self._model = diffusers.StableDiffusionPipeline.from_pretrained(
self.model = StableDiffusionPipeline.from_pretrained(
**models.get_kwargs("CompVis/stable-diffusion-v1-4", self.weights_drive),
).to(self.device)

Expand Down
2 changes: 1 addition & 1 deletion dreambooth_component.py
Expand Up @@ -6,7 +6,7 @@
class ServeDreamBoothDiffusion(BaseDiffusion):

def setup(self):
self._model = StableDiffusionPipeline.from_pretrained(
self.model = StableDiffusionPipeline.from_pretrained(
**models.get_kwargs("CompVis/stable-diffusion-v1-4", self.weights_drive),
).to(self.device)

Expand Down
4 changes: 4 additions & 0 deletions lightning_diffusion/base_diffusion.py
Expand Up @@ -87,6 +87,10 @@ def model(self) -> StableDiffusionPipeline:
assert self._model
return self._model

@model.setter
def model(self, model) -> None:
self._model = model

@property
def device(self):
import torch
Expand Down
3 changes: 1 addition & 2 deletions requirements.txt
Expand Up @@ -3,7 +3,6 @@ jsonargparse[signatures]
pyyaml==5.4.0
protobuf<4.21.0
lightning_hpo
lightning
Pillow>=2.9.0
numpy>=1.9.2
diffusers==0.7.2
Expand All @@ -12,6 +11,6 @@ deepspeed
ftfy
lightning-api-access
clip@git+https://github.com/openai/CLIP.git
lightning==1.8.1
lightning==1.8.3
redis

4 changes: 2 additions & 2 deletions serve_diffusion_component.py
@@ -1,13 +1,13 @@
# !pip install lightning_diffusion@git+https://github.com/Lightning-AI/lightning-diffusion.git
import lightning as L
import torch, diffusers
import diffusers
from lightning_diffusion import BaseDiffusion, models


class ServeDiffusion(BaseDiffusion):

def setup(self, *args, **kwargs):
self._model = diffusers.StableDiffusionPipeline.from_pretrained(
self.model = diffusers.StableDiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4",
**models.extras
).to(self.device)
Expand Down