Skip to content

teddykoker/image-gpt

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

34 Commits
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

Image GPT

PyTorch implementation of Image GPT, based on paper Generative Pretraining from Pixels (Chen et al.) and accompanying code.


Model-generated completions of half-images from test set. First column is input; last column is original image


iGPT-S pretrained on CIFAR10. Completions are fairly poor as the model was only trained on CIFAR10, not all of ImageNet.

WIP

  • Batched k-means on GPU for quantization of larger datasets (currently using sklearn.cluster.MiniBatchKMeans.)
  • BERT-style pretraining (currently only generative is supported.)
  • Load pretrained models from OpenAI.
  • Reproduce at least iGPT-S results.

According to their blog post, the largest model, iGPT-L (1.4 M parameters), was trained for 2500 V100-days. By greatly reducing the number of attention head, number of layers, and input size (which effects model size quadratically), we can train our own model (26 K parameters) on Fashion-MNIST on a single NVIDIA 2070 in less than 2 hours.

Usage

Pre-trained Models

Some pre-trained models are located in models directory. Run ./download.sh to download the cifar10 pretrained iGPT-S model.

Compute Centroids

Images are downloaded, and centroids are computed using k-means with num_clusters clusters. These centroids are used to quantize the images before they are fed into the model.

# options: mnist, fmnist, cifar10
python src/compute_centroids.py --dataset mnist --num_clusters=8

# creates data/<dataset>_centroids.npy

Note: Use the same num_clusters as num_vocab in your model.

Training

Models can be trained using src/run.py with the train subcommand.

Generative Pre-training

Models can be pretrained by specifying a dataset and model config. configs/s_gen.yml corresponds to iGPT-S from the paper, configs/xxs_gen.yml is an extra small model for trying on toy datasets with limited compute.

python src/run.py --dataset mnist train configs/xxs_gen.yml

Classification Fine-tuning

Pre-trained models can be fine-tuned by passing the path to the pre-trained checkpoint to --pretrained, along with the config file and dataset.

python src/run.py --dataset mnist train configs/xxs_clf.yml --pretrained=models/mnist_gen.ckpt`

Sampling

Figures like those seen above can be created using random images from test set:

# outputs to figure.png
python src/sample.py models/mnist_gen.ckpt

Gifs like the one seen in my tweet can be made like so:

# outputs to out.gif
python src/gif.py models/mnist_gen.ckpt