README.md


CrossGET: Cross-Guided Ensemble of Tokens for Accelerating Vision-Language Transformers

Paper: arXiv:2305.17455
Sections: On LLaVA-1.5 | On CLIP | On CoOp Benchmark

On LLaVA-1.5

⚙️ Installation

The code is tested with PyTorch 2.1.1, CUDA 12.1, and Python 3.10.13. Please follow LLaVA-1.5 to install the other dependencies.
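If you are setting up from scratch, a minimal sketch is shown below; the environment name and the assumption that the bundled LLaVA directory installs with pip install -e . (as in upstream LLaVA-1.5) are ours, so adapt as needed.

# Minimal setup sketch; the environment name is a placeholder and the editable
# install mirrors the upstream LLaVA-1.5 instructions.
conda create -n crossget python=3.10.13 -y
conda activate crossget
cd LLaVA
pip install -e .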

📑 Evaluation

  1. Download playground/data from LLaVA-1.5.

  2. Follow the instructions here for preparing the datasets.

  3. Download the following checkpoints and put them under checkpoints/.

    | Model | Link |
    | --- | --- |
    | LLaVA-1.5-7B with CrossGET | Google Drive |
    | LLaVA-1.5-13B with CrossGET | Google Drive |
  4. Use the scripts under LLaVA/scripts/v1_5/eval and follow the instructions here for evaluation; an example invocation is shown after the tables below. Logs are provided under LLaVA/log.

    | Dataset | VQAv2 | GQA | VisWiz | SQA^I | VQA^T | POPE | MME | MMB | MMB^CN | SEED^I |
    | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
    | LLaVA-1.5-7B | 78.5 | 62.0 | 50.0 | 66.8 | 58.2 | 85.9 | 1510.7 | 64.3 | 58.3 | 66.2 |
    | w/ CrossGET (~1.9x Tput) | 77.3 | 61.4 | 47.7 | 66.7 | 54.9 | 83.9 | 1510.2 | 64.7 | 55.2 | 64.4 |

    | Dataset | VQAv2 | GQA | VisWiz | SQA^I | VQA^T | POPE | MME | MMB | MMB^CN | SEED^I |
    | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
    | LLaVA-1.5-13B | 80.0 | 63.3 | 53.6 | 71.6 | 61.3 | 85.9 | 1531.3 | 67.7 | 63.6 | 68.2 |
    | w/ CrossGET (~2.0x Tput) | 78.7 | 62.6 | 51.8 | 71.4 | 58.0 | 84.9 | 1548.8 | 66.3 | 62.0 | 67.5 |
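    For example, assuming the upstream LLaVA-1.5 evaluation scripts keep their original names (e.g. textvqa.sh for the VQA^T benchmark), a run might look like the sketch below:

    # Hypothetical invocation; the script name follows upstream LLaVA-1.5, and the
    # checkpoint path inside the script should point at the CrossGET checkpoint under checkpoints/.
    cd LLaVA
    CUDA_VISIBLE_DEVICES=0 bash scripts/v1_5/eval/textvqa.sh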

📚 Visual Instruction Tuning

  1. Download playground/data from LLaVA-1.5.

  2. Follow the instructions here for preparing the datasets.

  3. Run python LLaVA/scripts/construct_dataset.py to create llava_v1_5_mix67k.json.

  4. Follow the instructions here for visual instruction tuning. For example, use LLaVA/scripts/v1_5/finetune_task.sh:

    #!/bin/bash
    
    deepspeed llava/train/train_mem.py \
        --deepspeed ./scripts/zero3.json \
        --model_name_or_path liuhaotian/llava-v1.5-7b \
        --version v1 \
        --data_path ./playground/data/llava_v1_5_mix67k.json \
        --image_folder ./playground/data \
        --vision_tower openai/clip-vit-large-patch14-336 \
        --mm_projector_type mlp2x_gelu \
        --mm_vision_select_layer -2 \
        --mm_use_im_start_end False \
        --mm_use_im_patch_token False \
        --image_aspect_ratio pad \
        --group_by_modality_length True \
        --bf16 True \
        --output_dir ./checkpoints/llava-v1.5-7b-mix67k-ours \
        --num_train_epochs 1 \
        --per_device_train_batch_size 16 \
        --per_device_eval_batch_size 4 \
        --gradient_accumulation_steps 1 \
        --evaluation_strategy "no" \
        --save_strategy "steps" \
        --save_steps 50000 \
        --save_total_limit 1 \
        --learning_rate 2e-5 \
        --weight_decay 0. \
        --warmup_ratio 0.03 \
        --lr_scheduler_type "cosine" \
        --logging_steps 1 \
        --tf32 True \
        --model_max_length 2048 \
        --gradient_checkpointing True \
        --dataloader_num_workers 4 \
        --lazy_preprocess True \
        --report_to wandb
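    
    # Usage note (ours, not part of the original script): the relative paths above
    # (./scripts/zero3.json, ./playground/data) suggest launching from the LLaVA directory:
    #   cd LLaVA
    #   bash scripts/v1_5/finetune_task.sh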
    

On CLIP

⚙️ Installation

The code is tested with PyTorch 2.0.0, CUDA 11.7, and Python 3.9.16. The dependencies can be installed by:

conda env create -f environment.yml
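
The environment name is defined by the name: field inside environment.yml; activate it afterwards (the name below is only a placeholder):

# Placeholder environment name; use the name: field from environment.yml
conda activate crossget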

📑 Evaluation

  1. Download the Flickr30k dataset, unzip it under the datasets folder, and modify image_root in the config accordingly.

  2. Download the annotations from Google Drive or Baidu Drive, unzip them under the annotation folder, and modify annotation in the config accordingly.

  3. Download the following checkpoint and put it under output/.

    | Model | Link |
    | --- | --- |
    | CLIP with CrossGET | Google Drive |
  4. Use the following script to evaluate. Logs are provided under CLIP/log.

    python -m torch.distributed.run --nproc_per_node=8 train_retrieval_clip.py \
    --config ./configs/retrieval_flickr_clip.yaml --evaluate --reduce ours --rv 16 --rl 0 \
    --pretrained output/train_retrieval_flickr_clip_ours/checkpoint_best.pth \
    --output_dir output/evaluate_retrieval_flickr_clip_ours_test
    
    | Image to Text (Flickr30k) | Recall@1 | Recall@5 | Recall@10 |
    | --- | --- | --- | --- |
    | CLIP | 92.1 | 99.1 | 99.7 |
    | w/ CrossGET (~58% GFLOPs) | 92.1 | 99.7 | 99.8 |

    | Text to Image (Flickr30k) | Recall@1 | Recall@5 | Recall@10 |
    | --- | --- | --- | --- |
    | CLIP | 79.3 | 95.7 | 98.0 |
    | w/ CrossGET (~58% GFLOPs) | 79.6 | 95.7 | 98.0 |

📚 Fine-tuning

  1. Download the Flickr30k dataset, unzip it under the datasets folder, and modify image_root in the config accordingly.

  2. Download the annotations from Google Drive or Baidu Drive, unzip them under the annotation folder, and modify annotation in the config accordingly.

  3. Download the following checkpoint and put it under output/.

    | Model | Link |
    | --- | --- |
    | CLIP | Google Drive |
  4. Use the following scripts for fine-tuning. Logs are provided under CLIP/log.

    # Vision-only
    python -W ignore -m torch.distributed.run --nproc_per_node=8 train_retrieval_clip.py \
    --config ./configs/retrieval_flickr_clip.yaml --lr 1e-5 --epoch 12 --reduce ours --rv 16 --rl 0 \
    --pretrained output/finetune_retrieval_flickr_clip/checkpoint_best.pth \
    --output_dir output/train_retrieval_flickr_clip_ours
    
    # Language-only
    python -W ignore -m torch.distributed.run --nproc_per_node=8 train_retrieval_clip.py \
    --config ./configs/retrieval_flickr_clip.yaml --lr 1e-5 --epoch 12 --reduce ours --rv 0 --rl 6 \
    --pretrained output/finetune_retrieval_flickr_clip/checkpoint_best.pth \
    --output_dir output/train_retrieval_flickr_clip_ours
    
    # Vision and Language
    python -W ignore -m torch.distributed.run --nproc_per_node=8 train_retrieval_clip.py \
    --config ./configs/retrieval_flickr_clip.yaml --lr 1e-5 --epoch 12 --reduce ours --rv 16 --rl 6 \
    --pretrained output/finetune_retrieval_flickr_clip/checkpoint_best.pth \
    --output_dir output/train_retrieval_flickr_clip_ours
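    
    After fine-tuning, the resulting checkpoint can presumably be evaluated with the command from the Evaluation section above, with --rv/--rl matching the fine-tuning setting; a sketch for the vision-and-language variant:

    # Sketch: evaluation command from the section above, with --rv/--rl matching the fine-tuned model
    python -m torch.distributed.run --nproc_per_node=8 train_retrieval_clip.py \
    --config ./configs/retrieval_flickr_clip.yaml --evaluate --reduce ours --rv 16 --rl 6 \
    --pretrained output/train_retrieval_flickr_clip_ours/checkpoint_best.pth \
    --output_dir output/evaluate_retrieval_flickr_clip_ours_test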
    

On CoOp Benchmark

⚙️ Installation

The code is tested with PyTorch 2.0.0, CUDA 11.7, and Python 3.9.17. Please follow CoOp to install the other dependencies.
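
For reference, a rough setup sketch is shown below. It assumes the upstream CoOp dependencies (the Dassl.pytorch toolbox plus CoOp's requirements.txt); the environment name is a placeholder, so follow the official CoOp instructions if anything differs.

# Rough sketch assuming the upstream CoOp setup; adapt names and paths as needed.
conda create -n crossget_coop python=3.9.17 -y
conda activate crossget_coop
# Install the Dassl.pytorch toolbox that CoOp builds on
git clone https://github.com/KaiyangZhou/Dassl.pytorch.git
cd Dassl.pytorch
pip install -r requirements.txt
python setup.py develop
# Install the remaining CoOp requirements
cd ../CoOp
pip install -r requirements.txt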

📑 Evaluation

  1. Prepare the datasets following the instructions here.

  2. Download the checkpoints whose names do not end with _original from the following link and put them under output/. Logs are also provided at the same link.

    | Model | Link |
    | --- | --- |
    | CoOp with CrossGET | Google Drive |
  3. Adjust model-dir in CoOp/eval.sh accordingly.

  4. Use the following scripts to evaluate; a note on the other GFLOPs budgets follows the table below.

    # Second last parameter: 16 for ~58% GFLOPs, 12 for ~69% GFLOPs, and 8 for ~80% GFLOPs
    ./eval.sh stanford_cars vit_b16 ours 16 0 
    ./eval.sh oxford_flowers vit_b16 ours 16 0 
    ./eval.sh food101 vit_b16 ours 16 0 
    ./eval.sh fgvc_aircraft vit_b16 ours 16 0 
    ./eval.sh sun397 vit_b16 ours 16 0 
    ./eval.sh dtd vit_b16 ours 16 0 
    ./eval.sh eurosat vit_b16 ours 16 0 
    ./eval.sh ucf101 vit_b16 ours 16 0 
    ./eval.sh caltech101 vit_b16 ours 16 0 
    ./eval.sh oxford_pets vit_b16 ours 16 0 
    ./eval.sh imagenet vit_b16_ep50 ours 16 0 
    
    | Dataset | ImageNet | Caltech101 | OxfordPets | StanfordCars | Flowers102 | Food101 | FGVCAircraft | SUN397 | DTD | EuroSAT | UCF101 | Average |
    | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
    | CoOp | 71.1 | 95.4 | 93.3 | 77.5 | 95.6 | 86.5 | 37.3 | 75.1 | 65.8 | 82.6 | 83.7 | 78.5 |
    | w/ CrossGET (~80% GFLOPs) | 70.8 | 94.6 | 90.8 | 81.9 | 95.8 | 82.0 | 43.7 | 74.1 | 65.7 | 88.4 | 82.2 | 79.1 |
    | w/ CrossGET (~69% GFLOPs) | 70.2 | 94.9 | 90.1 | 81.1 | 95.0 | 81.5 | 43.1 | 73.5 | 65.9 | 86.9 | 81.9 | 78.6 |
    | w/ CrossGET (~58% GFLOPs) | 67.6 | 93.9 | 89.5 | 76.6 | 93.3 | 79.7 | 41.3 | 72.1 | 64.2 | 84.5 | 80.5 | 76.7 |
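    
    The rows above correspond to different reduction levels. Per the parameter mapping noted in the scripts, the other budgets can presumably be evaluated by changing the second-to-last argument and pointing model-dir at the matching checkpoint, e.g.:

    # ~80% GFLOPs setting (second-to-last argument = 8); model-dir in eval.sh must match the checkpoint
    ./eval.sh imagenet vit_b16_ep50 ours 8 0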

📚 Prompt Tuning

  1. Prepare the datasets following the instructions here.

  2. Download the checkpoints whose names end with _original from the following link and put them under output/. Logs are also provided at the same link.

    | Model | Link |
    | --- | --- |
    | CoOp | Google Drive |
  3. Adjust model-dir in CoOp/train.sh accordingly.

  4. Use the following scripts for prompt tuning.

    # Second last parameter: 16 for ~58% GFLOPs, 12 for ~69% GFLOPs, and 8 for ~80% GFLOPs
    ./train.sh stanford_cars vit_b16 end 16 16 False ours 16 0
    ./train.sh oxford_flowers vit_b16 end 16 16 False ours 16 0
    ./train.sh food101 vit_b16 end 16 16 False ours 16 0
    ./train.sh fgvc_aircraft vit_b16 end 16 16 False ours 16 0
    ./train.sh sun397 vit_b16 end 16 16 False ours 16 0
    ./train.sh dtd vit_b16 end 16 16 False ours 16 0
    ./train.sh eurosat vit_b16 end 16 16 False ours 16 0
    ./train.sh ucf101 vit_b16 end 16 16 False ours 16 0
    ./train.sh caltech101 vit_b16 end 16 16 False ours 16 0
    ./train.sh oxford_pets vit_b16 end 16 16 False ours 16 0
    ./train.sh imagenet vit_b16_ep50 end 16 16 False ours 16 0
    

💬 Acknowledgments

This code is built upon LLaVA, ToMe, UPop, BLIP, and CoOp. Thanks to these awesome open-source projects!

✨ Citation

@article{shi2023crossget,
  title={Crossget: Cross-guided ensemble of tokens for accelerating vision-language transformers},
  author={Shi, Dachuan and Tao, Chaofan and Rao, Anyi and Yang, Zhendong and Yuan, Chun and Wang, Jiaqi},
  journal={arXiv preprint arXiv:2305.17455},
  year={2023}
}