Playground for experimenting with diffusion models 🌀

Code style: black

This repository includes the following:

  • diffusion package that provides a clean, modular, and minimalistic implementation of different components and algorithms used in diffusion-based generative modeling (with references to key papers), and
  • playground folder that contains a collection of examples that demonstrate diffusion-based generative modeling on different kinds of data (2D points, MNIST, CIFAR10, 3D point clouds, etc.)

Diffusion

The package consists of three core modules:

  1. denoisers module provides:
    • KarrasDenoiser: A thin wrapper around arbitrary neural nets that enables preconditioning of inputs and outputs, as proposed by Karras et al. (2022). The wrapper is agnostic to model architecture and only expects the shapes of the input and output tensors to match (see the preconditioning sketch after this list).
    • KarrasOptimalDenoiser: The optimal denoiser that corresponds to the analytical minimum of the denoising loss for a given training dataset.
  2. training module provides functionality for training diffusion models (see the training-step sketch after this list):
    • Loss functions (code): denoising MSE loss functions, including the original simple denoising loss of Ho et al. (2020) and the preconditioned MSE loss of Karras et al. (2022).
    • Loss weighting schemes (code): a collection of weighting schemes that assign different weights to losses computed at different noise levels, including the SNR-based weighting proposed by Hang et al. (2023).
    • Noise level samplers (code): determine how noise levels are sampled at each training step; the denoising loss is computed for the sampled noise levels, averaged, and optimized w.r.t. the model parameters.
    • Lightning model (code): a LightningModule class that puts all the pieces together and enables training denoising models using PyTorch Lightning.
  3. inference module provides functionality for sampling from trained diffusion models (see the sampler sketch after this list).
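
To make the preconditioning concrete, below is a minimal sketch of an EDM-style wrapper in the spirit of KarrasDenoiser. The class name, constructor arguments, and the assumption that the wrapped net takes (input, noise_labels) are illustrative rather than the package's actual API; only the c_skip, c_out, c_in, and c_noise formulas come from Karras et al. (2022).

```python
import torch
import torch.nn as nn


class PreconditionedDenoiser(nn.Module):
    """Illustrative EDM-style preconditioning wrapper (names are hypothetical).

    Computes D(x, sigma) = c_skip * x + c_out * F(c_in * x, c_noise),
    following Eq. (7) of Karras et al. (2022).
    """

    def __init__(self, net: nn.Module, sigma_data: float = 0.5):
        super().__init__()
        self.net = net  # any net whose output shape matches its input shape
        self.sigma_data = sigma_data

    def forward(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
        # sigma has shape (batch,); reshape it to broadcast over non-batch dims.
        sigma_b = sigma.view(-1, *([1] * (x.ndim - 1)))
        c_skip = self.sigma_data**2 / (sigma_b**2 + self.sigma_data**2)
        c_out = sigma_b * self.sigma_data / (sigma_b**2 + self.sigma_data**2).sqrt()
        c_in = 1 / (sigma_b**2 + self.sigma_data**2).sqrt()
        c_noise = sigma.log() / 4  # noise-level embedding input, shape (batch,)
        return c_skip * x + c_out * self.net(c_in * x, c_noise)
```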
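In the same hedged spirit, here is a rough sketch of how a training step could combine these pieces: a log-normal noise-level sampler and the preconditioned MSE loss with the λ(σ) weighting, with formulas and default hyperparameters taken from Karras et al. (2022). Function names are hypothetical, not the package's API.

```python
import torch


def sample_sigmas(batch_size: int, p_mean: float = -1.2, p_std: float = 1.2) -> torch.Tensor:
    # Log-normal noise-level sampler from Karras et al. (2022): ln(sigma) ~ N(P_mean, P_std^2).
    return (p_mean + p_std * torch.randn(batch_size)).exp()


def preconditioned_mse_loss(denoiser, x: torch.Tensor, sigma_data: float = 0.5) -> torch.Tensor:
    sigma = sample_sigmas(x.shape[0]).to(x.device)
    sigma_b = sigma.view(-1, *([1] * (x.ndim - 1)))
    noisy = x + sigma_b * torch.randn_like(x)  # forward process: add Gaussian noise
    weight = (sigma_b**2 + sigma_data**2) / (sigma_b * sigma_data) ** 2  # lambda(sigma)
    return (weight * (denoiser(noisy, sigma) - x) ** 2).mean()
```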
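For inference, a sketch of deterministic sampling with a plain Euler solver for the probability-flow ODE dx/dσ = (x − D(x, σ)) / σ, using the σ-schedule from Karras et al. (2022); again, names and defaults are illustrative.

```python
import torch


@torch.no_grad()
def euler_sample(denoiser, shape, num_steps: int = 50, sigma_min: float = 0.002,
                 sigma_max: float = 80.0, rho: float = 7.0, device: str = "cpu") -> torch.Tensor:
    # Noise-level schedule from Karras et al. (2022), Eq. (5), with a final sigma of 0.
    i = torch.arange(num_steps, device=device)
    sigmas = (sigma_max ** (1 / rho)
              + i / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
    sigmas = torch.cat([sigmas, torch.zeros(1, device=device)])

    x = sigmas[0] * torch.randn(shape, device=device)  # start from pure noise at sigma_max
    for cur, nxt in zip(sigmas[:-1], sigmas[1:]):
        d = (x - denoiser(x, cur.expand(shape[0]))) / cur  # probability-flow ODE derivative
        x = x + (nxt - cur) * d  # Euler step
    return x
```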

Playground

1. 2D points diffusion

This is a minimal toy example where each data instance is a 2D point that lies on a 1D swiss-roll manifold. Because the data is so simple, it is a perfect playground for experimenting with different approaches to training and inference, visualizing diffusion trajectories, and building intuition. Both training and inference run comfortably on a laptop (it takes a minute or so to train the model to convergence).
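
Data of this kind can be generated in a few lines with scikit-learn; the following is only a sketch of the setup described above, not necessarily the code the example uses.

```python
import numpy as np
from sklearn.datasets import make_swiss_roll


def make_2d_swiss_roll(n: int = 10_000, noise: float = 0.1, seed: int = 0) -> np.ndarray:
    # The 3D swiss roll is flat along its second axis, so dropping that axis
    # yields 2D points that lie on a 1D spiral manifold.
    points, _ = make_swiss_roll(n_samples=n, noise=noise, random_state=seed)
    return (points[:, [0, 2]] / 10.0).astype(np.float32)  # rescale to roughly unit range
```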

Colab notebook: (TODO: add link to the notebook)

2. MNIST diffusion

Another toy example, where a diffusion model is trained on MNIST. The model architectures are scaled-down versions of the U-Nets used on the CIFAR10 and ImageNet benchmarks (all of the architecture code is copied verbatim from https://github.com/NVlabs/edm/blob/main/training/networks.py). Training an MNIST denoiser for 20 epochs or so takes about an hour in Google Colab on a T4 GPU, and running inference takes just a few seconds.

Colab notebook: (TODO: add link to the notebook)

3. CIFAR10 diffusion

In this example, we train a U-Net diffusion model on the CIFAR10 benchmark. The model can be trained with the playground/cifar10/train.py script (training takes a few days on multiple GPUs), using the architecture and the best hyperparameters reported by Karras et al. (2022). Running inference takes just a few seconds and can be done with different ODE solvers.

Colab notebook: (TODO: add link to the notebook)
