function2-llx/PUMIT

PUMIT

This is the code repository for the paper Building Universal Foundation Models for Medical Image Analysis with Spatially Adaptive Networks (arXiv; formerly titled Pre-trained Universal Medical Image Transformer).

Pre-trained model weights can be found on the releases page.

This code release is intended for reproducing the results of the paper and was also used for double-blind review, so it may not look polished. You can switch to the dev branch for a cleaner version of the code (note that it has also been developed further, so it may not be fully consistent with this branch).

Create Environment

Mamba is recommended to manage virtual environments.

./create-env.zsh

Prepare data

Download datasets and put them under the ./datasets folder, e.g.:

datasets
├── AbdomenCT-1K
├── ACDC
├── AMOS22
├── BCV
├── BrainPTM-2021
├── BraTS2023
...

Then run the pre-processing script: python scripts/data/process.py <dataset name>. More details can be found in the script.
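For example, a dry run that just prints the pre-processing command for several datasets (the names are the examples from the tree above; edit the list to match what you actually downloaded):

```shell
# Dry run: print the pre-processing command for each example dataset.
# Remove the `echo` to actually run the script.
for ds in AbdomenCT-1K ACDC AMOS22 BCV; do
    echo python scripts/data/process.py "$ds"
done
```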

Run Visual Tokenizer Pre-Training

python scripts/tokenizer/main.py -c conf/tokenizer/simple/main.yaml --data.dl_conf.train_batch_size 8 --data.dl_conf.num_workers 10 --training.benchmark true --model.quantize.mode soft --model.quantize.num_embeddings 1024 --loss.entropy_weight 1 --loss.quant_weight 0.03

Note that you may need to adjust the batch size according to the number of GPUs.
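The per-GPU batch size times the number of GPUs gives the effective global batch size, so halving the GPU count roughly doubles the per-GPU value you should pass. A minimal sketch of that arithmetic (the global batch size of 64 here is an assumed target, not a value from the repository):

```shell
# Hypothetical: keep the effective global batch size constant across GPU counts.
GLOBAL_BATCH=64  # assumed target global batch size
NUM_GPUS=8
echo $(( GLOBAL_BATCH / NUM_GPUS ))  # prints 8: the per-GPU train_batch_size to pass
```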

Run ViT Pre-Training

scripts/model/mim-b.zsh --data.dl_conf.train_batch_size 14 --data.dl_conf.num_workers 10 --model.tokenizer.path <tokenizer checkpoint path>

Run Downstream Tasks

Assume that the pre-trained checkpoint is placed at ./pre-trained/pumit.ckpt.

Classification

Execute the scripts under scripts/downstream/medmnistv2 to train and evaluate each model.

BTCV Segmentation

Download the BTCV data from the official challenge, download the train/validation split file from SMIT's repository, and organize the files as follows:

data
├── BTCV
│   ├── smit.json
│   ├── Testing
│   └── Training

Then run fine-tuning and inference:

scripts/downstream/btcv/pumit-b.zsh --data.num_workers 10 --data.ratio 1 --trainer.logger.name pumit-b --data.train_batch_size 4
scripts/downstream/btcv/test-b.zsh --data.num_workers 10 --ckpt_path <output checkpoint path> --trainer.logger.name pumit-b

CHAOS Segmentation

First, run the pre-processing script to convert the DICOM series into NIfTI format: python scripts/downstream/chaos/preprocess.py

Then run fine-tuning and inference:

scripts/downstream/chaos/pumit-b.zsh --data.num_workers 10 --data.ratio 1 --trainer.logger.name pumit-b --data.train_batch_size 8
scripts/downstream/chaos/predict-b.zsh --data.num_workers 10 --ckpt_path <output checkpoint path> --trainer.logger.name pumit-b