December 8, 2024
ImageFolder: Autoregressive Image Generation with Folded Tokens
Updates
- (2024.12.02) Code released. Also try our new work XQ-GAN for more extensions of ImageFolder.
- (2024.10.03) We are working on advanced training for the ImageFolder tokenizer.
- (2024.10.01) Repo created. Code and checkpoints will be released soon.
Model Zoo
We provide pre-trained tokenizers for image reconstruction on ImageNet.
| Training | Eval | Codebook Size | rFID ↓ | Link | Resolution | Utilization |
|---|---|---|---|---|---|---|
| ImageNet | ImageNet | 4096 | 0.80 | Huggingface | 256x256 | 100% |
| ImageNet | ImageNet | 8192 | 0.70 | Huggingface | 256x256 | 100% |
| ImageNet | ImageNet | 16384 | 0.67 | Huggingface | 256x256 | 100% |
We provide a pre-trained generator for class-conditional image generation on ImageNet at 256x256 resolution.
| Type | Dataset | Model Size | gFID ↓ | Link | Resolution |
|---|---|---|---|---|---|
| VAR | ImageNet | 362M | 2.60 | Huggingface | 256x256 |
Installation
Install all packages with:
conda env create -f environment.yml
Dataset
Download ImageNet2012 from the ImageNet website and organize it as follows:
ImageNet2012
├── train
└── val
If you want to train or finetune on other datasets, arrange them in a layout that PyTorch's ImageFolder dataset class (torchvision.datasets.ImageFolder) can read:
Dataset
├── train
│   ├── Class1
│   │   ├── 1.png
│   │   └── 2.png
│   ├── Class2
│   │   ├── 1.png
│   │   └── 2.png
└── val
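Before launching a long training run, it can help to confirm that a custom dataset really follows the split/class/image layout shown above. The following is a minimal sketch using only the Python standard library; the function names, the list of image suffixes, and the assumption that val mirrors the class folders of train are illustrative and not part of this repository.

```python
from pathlib import Path

# Assumption: the image extensions to count; adjust to match your data.
IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg", ".bmp", ".webp"}


def summarize_split(split_dir):
    """Return {class_name: image_count} for a split laid out as split/class/image."""
    split = Path(split_dir)
    if not split.is_dir():
        raise FileNotFoundError(f"missing split directory: {split}")
    counts = {}
    for class_dir in sorted(p for p in split.iterdir() if p.is_dir()):
        # Count image files anywhere below the class folder.
        counts[class_dir.name] = sum(
            1
            for f in class_dir.rglob("*")
            if f.is_file() and f.suffix.lower() in IMAGE_SUFFIXES
        )
    return counts


def check_layout(root, splits=("train", "val")):
    """Raise ValueError if a split has no class folders or a class has no images."""
    report = {}
    for split_name in splits:
        counts = summarize_split(Path(root) / split_name)
        if not counts:
            raise ValueError(f"{split_name}: no class folders found")
        empty = [name for name, n in counts.items() if n == 0]
        if empty:
            raise ValueError(f"{split_name}: class folders without images: {empty}")
        report[split_name] = counts
    return report
```

Calling check_layout("Dataset") returns the per-class image counts for each split, or raises with a message naming the first problem it finds. Pass splits=("train",) if your val folder is organized differently.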
Training code for tokenizer
Please log in to Wandb first using
wandb login
rFID will be automatically evaluated and reported on Wandb. The best checkpoint on the val set will be saved.
torchrun --nproc_per_node=8 tokenizer/tokenizer_image/msvq_train.py --config configs/tokenizer.yaml
Modify the configuration file as needed for your dataset. Some important options are listed here:
vq_ckpt: ckpt_best.pt # resume
cloud_save_path: output/exp-xx # output dir
data_path: ImageNet2012/train # training set dir
val_data_path: ImageNet2012/val # val set dir
enc_tuning_method: 'full' # ['full', 'lora', 'frozen']
dec_tuning_method: 'full' # ['full', 'lora', 'frozen']
codebook_embed_dim: 32 # codebook dim
codebook_size: 4096 # codebook size
product_quant: 2 # branch number
codebook_drop: 0.1 # quantizer dropout rate
semantic_guide: dinov2 # ['none', 'dinov2']
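Several of the options above accept only a fixed set of values, so a typo can be caught before any GPUs are allocated. The sketch below is a hypothetical pre-flight check, not the repository's own config handling: the allowed values are copied from the comments in the listing above, while the function name and the numeric range checks are assumptions made for illustration.

```python
# Allowed values copied from the comments in the option listing above.
ALLOWED_CHOICES = {
    "enc_tuning_method": {"full", "lora", "frozen"},
    "dec_tuning_method": {"full", "lora", "frozen"},
    "semantic_guide": {"none", "dinov2"},
}

# Assumption: these options should be positive integers.
POSITIVE_INT_KEYS = ("codebook_embed_dim", "codebook_size", "product_quant")


def validate_tokenizer_config(cfg):
    """Return a list of problems found in cfg; an empty list means all checks passed."""
    problems = []
    for key, allowed in ALLOWED_CHOICES.items():
        if cfg.get(key) not in allowed:
            problems.append(f"{key}={cfg.get(key)!r} is not one of {sorted(allowed)}")
    for key in POSITIVE_INT_KEYS:
        value = cfg.get(key)
        if isinstance(value, bool) or not isinstance(value, int) or value < 1:
            problems.append(f"{key}={value!r} should be a positive integer")
    drop = cfg.get("codebook_drop")
    # Assumption: a dropout rate is a probability in [0, 1].
    if isinstance(drop, bool) or not isinstance(drop, (int, float)) or not 0 <= drop <= 1:
        problems.append(f"codebook_drop={drop!r} should be a number in [0, 1]")
    return problems


# The values shown in the listing above, written as a plain dict.
example_config = {
    "enc_tuning_method": "full",
    "dec_tuning_method": "full",
    "codebook_embed_dim": 32,
    "codebook_size": 4096,
    "product_quant": 2,
    "codebook_drop": 0.1,
    "semantic_guide": "dinov2",
}

print(validate_tokenizer_config(example_config))
```

To apply it to configs/tokenizer.yaml, load the file into a dict with the YAML parser of your choice and pass the result to validate_tokenizer_config.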
Tokenizer linear probing
torchrun --nproc_per_node=8 tokenizer/tokenizer_image/linear_probing.py --config configs/tokenizer.yaml
Training code for VAR
We follow the VAR training code. For reproducibility, our training command is:
torchrun --nproc_per_node=8 train.py --bs=768 --alng=1e-4 --fp16=1 --wpe=0.01 --tblr=8e-5 --data_path /mnt/localssd/ImageNet2012/ --encoder_model vit_base_patch14_dinov2.lvd142m --decoder_model vit_base_patch14_dinov2.lvd142m --product_quant 2 --semantic_guide dinov2 --num_latent_tokens 121 --v_patch_nums 1 1 2 3 3 4 5 6 8 11 --pn 1_1_2_3_3_4_5_6_8_11 --patch_size 11 --vae_ckpt /path/to/ckpt.pt --sem_half True
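The command above passes the multi-scale schedule in two forms (--v_patch_nums and --pn) alongside --patch_size and --num_latent_tokens, and these values need to agree with each other. The sketch below checks that agreement for the values in this command. It assumes, as in VAR-style multi-scale tokenization, that a scale with side length n holds n x n token positions; the helper names are hypothetical, and the relationships it asserts are observations about this particular command rather than checks performed by train.py.

```python
def parse_pn(pn):
    """Split a --pn string such as '1_1_2_3' into a tuple of ints."""
    return tuple(int(part) for part in pn.split("_"))


def tokens_per_scale(patch_nums):
    """Assumption: a scale with side length n holds n * n token positions."""
    return [n * n for n in patch_nums]


# Values copied from the training command above.
v_patch_nums = (1, 1, 2, 3, 3, 4, 5, 6, 8, 11)
pn = "1_1_2_3_3_4_5_6_8_11"
patch_size = 11
num_latent_tokens = 121

# The two schedule flags describe the same sequence of scales.
assert parse_pn(pn) == v_patch_nums

per_scale = tokens_per_scale(v_patch_nums)

# In this command, the final scale matches both --patch_size and --num_latent_tokens.
assert v_patch_nums[-1] == patch_size
assert per_scale[-1] == num_latent_tokens

print(per_scale)       # [1, 1, 4, 9, 9, 16, 25, 36, 64, 121]
print(sum(per_scale))  # 286
```

If you change the schedule, updating v_patch_nums and rerunning this check is a quick way to spot a --pn string or --num_latent_tokens value that was left behind. This README does not state how the 286 positions computed here relate to the Length column of the ablation table.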
Inference code for ImageFolder
torchrun --nproc_per_node=8 inference.py --infer_ckpt /path/to/ckpt --data_path /path/to/ImageNet --encoder_model vit_base_patch14_dinov2.lvd142m --decoder_model vit_base_patch14_dinov2.lvd142m --product_quant 2 --semantic_guide dinov2 --num_latent_tokens 121 --v_patch_nums 1 1 2 3 3 4 5 6 8 11 --pn 1_1_2_3_3_4_5_6_8_11 --patch_size 11 --sem_half True --cfg 3.25 3.25 --top_k 750 --top_p 0.95
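The --top_k and --top_p flags in the command above restrict which tokens can be drawn at each sampling step. The sketch below illustrates the generic top-k followed by nucleus (top-p) filtering technique on a plain list of probabilities. It is not the code inference.py runs: the function name is hypothetical, and the order of operations (top-k, renormalize, then top-p) is an assumption, since implementations differ on such details.

```python
def top_k_top_p_filter(probs, top_k=0, top_p=1.0):
    """Keep the top_k most likely tokens, then the smallest prefix of those whose
    renormalized mass reaches top_p; return a renormalized distribution."""
    # Token indices sorted from most to least likely.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    if top_k > 0:
        order = order[:top_k]

    # Renormalize over the top-k survivors before applying the nucleus cut.
    mass = sum(probs[i] for i in order)
    kept = []
    cumulative = 0.0
    for i in order:
        kept.append(i)
        cumulative += probs[i] / mass
        if cumulative >= top_p:
            break

    kept_set = set(kept)
    kept_mass = sum(probs[i] for i in kept)
    return [probs[i] / kept_mass if i in kept_set else 0.0 for i in range(len(probs))]


# Toy distribution over four tokens.
toy = [0.5, 0.3, 0.15, 0.05]
print(top_k_top_p_filter(toy, top_k=3, top_p=0.95))  # least likely token is removed
print(top_k_top_p_filter(toy, top_k=0, top_p=0.7))   # only the first two tokens remain
```

Larger top_k and top_p values keep more of the distribution and tend to give more diverse samples, while smaller values concentrate sampling on the most likely tokens.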
Ablation
| ID | Method | Length | rFID ↓ | gFID ↓ | ACC ↑ |
|---|---|---|---|---|---|
| 🔶1 | Multi-scale residual quantization (Tian et al., 2024) | 680 | 1.92 | 7.52 | - |
| 🔶2 | + Quantizer dropout | 680 | 1.71 | 6.03 | - |
| 🔶3 | + Smaller patch size K = 11 | 265 | 3.24 | 6.56 | - |
| 🔶4 | + Product quantization & Parallel decoding | 265 | 2.06 | 5.96 | - |
| 🔶5 | + Semantic regularization on all branches | 265 | 1.97 | 5.21 | - |
| 🔶6 | + Semantic regularization on one branch | 265 | 1.57 | 3.53 | 40.5 |
| 🔷7 | + Stronger discriminator | 265 | 1.04 | 2.94 | 50.2 |
| 🔷8 | + Equilibrium enhancement | 265 | 0.80 | 2.60 | 58.0 |
🔶1-6 are already reported in the released paper. 🔷7 and later are advanced training settings, similar to those used by VAR (gFID 3.30).
Generation
License
Acknowledgements
We would like to thank the following repositories: LlamaGen, VAR and ControlVAR.
Citation
If our work assists your research, feel free to give us a star ⭐ or cite us using:
@misc{li2024imagefolderautoregressiveimagegeneration,
title={ImageFolder: Autoregressive Image Generation with Folded Tokens},
author={Xiang Li and Hao Chen and Kai Qiu and Jason Kuen and Jiuxiang Gu and Bhiksha Raj and Zhe Lin},
year={2024},
eprint={2410.01756},
archivePrefix={arXiv},
primaryClass={cs.CV},
url={https://arxiv.org/abs/2410.01756},
}