From 1d0f7b7240213e0c8ed5f74d5e7d8356ed020ea0 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Wed, 23 Nov 2022 14:18:15 +0000 Subject: [PATCH] Improve Components API (#12) --- README.md | 13 +++++++------ dreambooth_component.py | 2 +- lightning_diffusion/base_diffusion.py | 4 ++++ requirements.txt | 3 +-- serve_diffusion_component.py | 4 ++-- 5 files changed, 15 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index ac65dfd..dc5154d 100644 --- a/README.md +++ b/README.md @@ -6,22 +6,22 @@ Lightning Diffusion provides components to finetune and serve diffusion model on ### Serve ANY Diffusion Models ```python -# !pip install torch diffusers lightning_diffusion@git+https://github.com/Lightning-AI/lightning-diffusion.git +# !pip install lightning_diffusion@git+https://github.com/Lightning-AI/lightning-diffusion.git import lightning as L -import torch, diffusers +import diffusers from lightning_diffusion import BaseDiffusion, models class ServeDiffusion(BaseDiffusion): def setup(self, *args, **kwargs): - self._model = diffusers.StableDiffusionPipeline.from_pretrained( + self.model = diffusers.StableDiffusionPipeline.from_pretrained( "CompVis/stable-diffusion-v1-4", **models.extras ).to(self.device) def predict(self, data): - out = self.model(prompt=data.prompt) + out = self.model(prompt=data.prompt, num_inference_steps=23) return {"image": self.serialize(out[0][0])} @@ -36,12 +36,13 @@ Use the DreamBooth fine-tuning methodology from the paper `Fine Tuning Text-to-I ```python import lightning as L from lightning_diffusion import BaseDiffusion, DreamBoothTuner, models -import torch, diffusers +from diffusers import StableDiffusionPipeline + class ServeDreamBoothDiffusion(BaseDiffusion): def setup(self): - self._model = diffusers.StableDiffusionPipeline.from_pretrained( + self.model = StableDiffusionPipeline.from_pretrained( **models.get_kwargs("CompVis/stable-diffusion-v1-4", self.weights_drive), ).to(self.device) diff --git a/dreambooth_component.py b/dreambooth_component.py index c5eb312..b41800c 100644 --- a/dreambooth_component.py +++ b/dreambooth_component.py @@ -6,7 +6,7 @@ class ServeDreamBoothDiffusion(BaseDiffusion): def setup(self): - self._model = StableDiffusionPipeline.from_pretrained( + self.model = StableDiffusionPipeline.from_pretrained( **models.get_kwargs("CompVis/stable-diffusion-v1-4", self.weights_drive), ).to(self.device) diff --git a/lightning_diffusion/base_diffusion.py b/lightning_diffusion/base_diffusion.py index 3955477..b72553c 100644 --- a/lightning_diffusion/base_diffusion.py +++ b/lightning_diffusion/base_diffusion.py @@ -87,6 +87,10 @@ def model(self) -> StableDiffusionPipeline: assert self._model return self._model + @model.setter + def model(self, model) -> None: + self._model = model + @property def device(self): import torch diff --git a/requirements.txt b/requirements.txt index 7d75728..dfd78e2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,7 +3,6 @@ jsonargparse[signatures] pyyaml==5.4.0 protobuf<4.21.0 lightning_hpo -lightning Pillow>=2.9.0 numpy>=1.9.2 diffusers==0.7.2 @@ -12,6 +11,6 @@ deepspeed ftfy lightning-api-access clip@git+https://github.com/openai/CLIP.git -lightning==1.8.1 +lightning==1.8.3 redis diff --git a/serve_diffusion_component.py b/serve_diffusion_component.py index 7a1efa1..f81bbad 100644 --- a/serve_diffusion_component.py +++ b/serve_diffusion_component.py @@ -1,13 +1,13 @@ # !pip install lightning_diffusion@git+https://github.com/Lightning-AI/lightning-diffusion.git import lightning as L -import torch, diffusers +import diffusers from lightning_diffusion import BaseDiffusion, models class ServeDiffusion(BaseDiffusion): def setup(self, *args, **kwargs): - self._model = diffusers.StableDiffusionPipeline.from_pretrained( + self.model = diffusers.StableDiffusionPipeline.from_pretrained( "CompVis/stable-diffusion-v1-4", **models.extras ).to(self.device)