
Pytorch/JAX implementation of examples in Generative Deep Learning 2nd Edition by David Foster


terrence-ou/Generative-Deep-Learning-2nd-Edition-PyTorch-JAX

Generative Deep Learning 2nd Edition in JAX and PyTorch 2.0

This repository includes PyTorch 2.0 and JAX implementations of the examples in Generative Deep Learning, 2nd Edition, by David Foster. You can buy the print or Kindle edition on Amazon, or read it online through the O'Reilly online library (paid subscription required).

Motivation for this project

I started my journey in deep learning with the help of the first edition of this book. The author introduces generative deep learning topics in a clear and concise way, which helped me quickly grasp the key ideas behind each type of algorithm without being intimidated by heavy mathematics. The code in the book, written in TensorFlow and Keras, helped me quickly put theory into practice.
Other popular deep learning frameworks, such as PyTorch and JAX, are now used by various ML communities. I therefore want to translate the TensorFlow and Keras code provided in the book to PyTorch and JAX so that more people can study this valuable book more easily.

File structure

The files are organized by frameworks:

├── JAX_FLAX
│   ├── chapter_**_**
│   │   ├── **.ipynb
│   ├── requirements.txt
├── PyTorch
│   ├── chapter_**_**
│   │   ├── **.ipynb
│   ├── requirements.txt
├── .gitignore
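Assuming the layout above (notebooks stored one level down in chapter_* folders inside each framework folder), a short script like the following can list the available notebooks per framework. This is a sketch; the helper name list_notebooks and the glob pattern are illustrative and not part of the repository.

```python
from pathlib import Path


def list_notebooks(repo_root: str) -> dict[str, list[str]]:
    """Map each framework folder to its notebooks, sorted by path.

    Hypothetical helper: assumes notebooks live at <framework>/chapter_*/<name>.ipynb.
    """
    root = Path(repo_root)
    result: dict[str, list[str]] = {}
    for framework in ("JAX_FLAX", "PyTorch"):
        folder = root / framework
        # Only look one level into the chapter folders, matching the tree above;
        # a missing framework folder simply yields an empty list.
        notebooks = sorted(folder.glob("chapter_*/*.ipynb"))
        result[framework] = [str(nb.relative_to(root)) for nb in notebooks]
    return result


if __name__ == "__main__":
    for framework, notebooks in list_notebooks(".").items():
        print(framework, len(notebooks), "notebooks")
```

Run from the repository root, it prints one count per framework folder.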

Environment setup

I recommend using separate environments for PyTorch and JAX to avoid potential conflicts between CUDA versions or other related packages. I use Miniconda to manage packages for both environments.
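To confirm that the two environments really stay separate, a quick import check can be run inside each activated environment. This is a minimal sketch; the helper names installed_frameworks and check_isolation are hypothetical, and it only checks whether the top-level packages torch, jax and flax are importable.

```python
import importlib.util


def installed_frameworks(names=("torch", "jax", "flax")) -> set[str]:
    """Return the subset of `names` that can be imported in the active environment."""
    # find_spec returns None for a top-level module that is not installed
    return {name for name in names if importlib.util.find_spec(name) is not None}


def check_isolation() -> str:
    """Warn if PyTorch and JAX are installed side by side (hypothetical helper)."""
    found = installed_frameworks()
    if {"torch", "jax"} <= found:
        return "warning: torch and jax share this environment"
    return "ok: " + (", ".join(sorted(found)) or "no framework found")


if __name__ == "__main__":
    print(check_isolation())
```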

Configure PyTorch environment:

cd PyTorch
conda create -n GDL_PyTorch python==3.9
conda activate GDL_PyTorch
pip install -r requirements.txt
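After installing the requirements, a one-off check can show whether PyTorch sees a GPU in this environment. A sketch, assuming nothing beyond the standard torch.cuda.is_available() call; the helper name torch_device is hypothetical.

```python
import importlib.util


def torch_device() -> str:
    """Report the device PyTorch would use, or note that torch is missing."""
    if importlib.util.find_spec("torch") is None:
        return "torch not installed"
    import torch

    # is_available() is False on CPU-only builds or when no GPU is visible
    return "cuda" if torch.cuda.is_available() else "cpu"


if __name__ == "__main__":
    print(torch_device())
```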

NOTE: If you're using PyTorch on WSL, add export LD_LIBRARY_PATH=/usr/lib/wsl/lib:$LD_LIBRARY_PATH to ~/.bashrc to keep the Jupyter kernel from restarting unexpectedly.
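To check whether the WSL library path from the note is actually active in a new shell, something like the following sketch can help. It assumes WSL can be detected by "microsoft" appearing in the kernel release string (as in "...-microsoft-standard-WSL2"); the helper names running_on_wsl and wsl_lib_on_path are hypothetical.

```python
import os
import platform

WSL_LIB = "/usr/lib/wsl/lib"


def running_on_wsl() -> bool:
    # Assumption: WSL kernels report "microsoft" in their release string
    return "microsoft" in platform.uname().release.lower()


def wsl_lib_on_path(env=None) -> bool:
    """True if WSL_LIB is one of the colon-separated LD_LIBRARY_PATH entries."""
    env = os.environ if env is None else env
    entries = env.get("LD_LIBRARY_PATH", "").split(":")
    return WSL_LIB in entries


if __name__ == "__main__":
    if running_on_wsl() and not wsl_lib_on_path():
        print(f"{WSL_LIB} is missing from LD_LIBRARY_PATH; see the WSL note above.")
```

Splitting on ":" rather than using a substring test avoids false matches such as /usr/lib/wsl/lib64.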

Configure JAX environment:

cd JAX_FLAX
conda create -n GDL_JAX python==3.9
conda activate GDL_JAX
pip install -r requirements.txt
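The JAX environment can be checked in the same spirit. This sketch uses jax.default_backend() and jax.devices() to report which platform JAX will run on; the helper name jax_backend is hypothetical.

```python
import importlib.util


def jax_backend() -> str:
    """Describe the platform JAX will use, or note that jax is missing."""
    if importlib.util.find_spec("jax") is None:
        return "jax not installed"
    import jax

    # default_backend() names the platform, e.g. "cpu" or "gpu"
    return f"{jax.default_backend()} ({len(jax.devices())} device(s))"


if __name__ == "__main__":
    print(jax_backend())
```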

.ipynb is the file extension for Jupyter notebooks; I use JupyterLab to run the notebooks in this repository.
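Because an .ipynb file is plain JSON, a notebook can also be inspected without starting JupyterLab. This sketch assumes the nbformat 4 layout, where cells are stored in a top-level "cells" list and each cell has a "cell_type"; the helper name cell_summary is hypothetical.

```python
import json
from collections import Counter


def cell_summary(path: str) -> Counter:
    """Count the cells of a notebook by type ("code", "markdown", "raw")."""
    with open(path, encoding="utf-8") as f:
        notebook = json.load(f)
    return Counter(cell["cell_type"] for cell in notebook.get("cells", []))


if __name__ == "__main__":
    import sys

    for nb_path in sys.argv[1:]:
        print(nb_path, dict(cell_summary(nb_path)))
```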

Model list

Chapter 2 Deep Learning

Chapter 3 Variational Autoencoders

Chapter 4 Generative Adversarial Networks

Chapter 5 Autoregressive Models

Chapter 6 Normalizing Flow Models

Chapter 7 Energy-Based Models