Skip to content
This repository has been archived by the owner on Feb 15, 2023. It is now read-only.

Commit

Permalink
Improve Components API (#12)
Browse files Browse the repository at this point in the history
  • Loading branch information
tchaton committed Nov 23, 2022
1 parent 263d583 commit 1d0f7b7
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 11 deletions.
13 changes: 7 additions & 6 deletions README.md
Expand Up @@ -6,22 +6,22 @@ Lightning Diffusion provides components to finetune and serve diffusion model on
### Serve ANY Diffusion Models

```python
# !pip install torch diffusers lightning_diffusion@git+https://github.com/Lightning-AI/lightning-diffusion.git
# !pip install lightning_diffusion@git+https://github.com/Lightning-AI/lightning-diffusion.git
import lightning as L
import torch, diffusers
import diffusers
from lightning_diffusion import BaseDiffusion, models


class ServeDiffusion(BaseDiffusion):

def setup(self, *args, **kwargs):
self._model = diffusers.StableDiffusionPipeline.from_pretrained(
self.model = diffusers.StableDiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4",
**models.extras
).to(self.device)

def predict(self, data):
out = self.model(prompt=data.prompt)
out = self.model(prompt=data.prompt, num_inference_steps=23)
return {"image": self.serialize(out[0][0])}


Expand All @@ -36,12 +36,13 @@ Use the DreamBooth fine-tuning methodology from the paper `Fine Tuning Text-to-I
```python
import lightning as L
from lightning_diffusion import BaseDiffusion, DreamBoothTuner, models
import torch, diffusers
from diffusers import StableDiffusionPipeline


class ServeDreamBoothDiffusion(BaseDiffusion):

def setup(self):
self._model = diffusers.StableDiffusionPipeline.from_pretrained(
self.model = StableDiffusionPipeline.from_pretrained(
**models.get_kwargs("CompVis/stable-diffusion-v1-4", self.weights_drive),
).to(self.device)

Expand Down
2 changes: 1 addition & 1 deletion dreambooth_component.py
Expand Up @@ -6,7 +6,7 @@
class ServeDreamBoothDiffusion(BaseDiffusion):

def setup(self):
self._model = StableDiffusionPipeline.from_pretrained(
self.model = StableDiffusionPipeline.from_pretrained(
**models.get_kwargs("CompVis/stable-diffusion-v1-4", self.weights_drive),
).to(self.device)

Expand Down
4 changes: 4 additions & 0 deletions lightning_diffusion/base_diffusion.py
Expand Up @@ -87,6 +87,10 @@ def model(self) -> StableDiffusionPipeline:
assert self._model
return self._model

@model.setter
def model(self, model) -> None:
self._model = model

@property
def device(self):
import torch
Expand Down
3 changes: 1 addition & 2 deletions requirements.txt
Expand Up @@ -3,7 +3,6 @@ jsonargparse[signatures]
pyyaml==5.4.0
protobuf<4.21.0
lightning_hpo
lightning
Pillow>=2.9.0
numpy>=1.9.2
diffusers==0.7.2
Expand All @@ -12,6 +11,6 @@ deepspeed
ftfy
lightning-api-access
clip@git+https://github.com/openai/CLIP.git
lightning==1.8.1
lightning==1.8.3
redis

4 changes: 2 additions & 2 deletions serve_diffusion_component.py
@@ -1,13 +1,13 @@
# !pip install lightning_diffusion@git+https://github.com/Lightning-AI/lightning-diffusion.git
import lightning as L
import torch, diffusers
import diffusers
from lightning_diffusion import BaseDiffusion, models


class ServeDiffusion(BaseDiffusion):

def setup(self, *args, **kwargs):
self._model = diffusers.StableDiffusionPipeline.from_pretrained(
self.model = diffusers.StableDiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4",
**models.extras
).to(self.device)
Expand Down

0 comments on commit 1d0f7b7

Please sign in to comment.