SIMLA: Single-Stream Multi-Level Alignment for Vision-Language Pretraining, ECCV 2022 (NEC Labs)
August 19, 2022
This is the official PyTorch implementation of SIMLA. The repository is heavily based on salesforce/ALBEF, and supports vision-language pretraining and finetuning on several downstream tasks.
Setup
Dependencies
conda env create --name simla --file environment.yaml
Data
See individual sections below for instructions.
Checkpoints
- pretrained on 4M images
- Use this one if you want to finetune the model for another downstream VL task, like VQA.
- finetuned on COCO
- Use this one for retrieval tasks.
The checkpoints are around 3GB, and contain the optimizer state and everything else needed to resume training.
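If you only need the weights for inference or for finetuning from scratch, the training state can be dropped to get a smaller file. The sketch below is one way to do that, not part of the repository. It assumes the checkpoint is a dictionary whose model weights live under a `model` key, as in ALBEF; that key name and the helper names `keep_model_only` and `slim_checkpoint` are assumptions, so inspect the keys of your checkpoint first.

```python
from typing import Any, Dict, Mapping


def keep_model_only(checkpoint: Mapping[str, Any], model_key: str = "model") -> Dict[str, Any]:
    """Return a new dict that holds only the model weights entry.

    The key name "model" is an assumption (ALBEF convention); the error
    message lists the keys that are actually present if it is wrong.
    """
    if model_key not in checkpoint:
        raise KeyError(f"{model_key!r} not found; available keys: {sorted(checkpoint)}")
    return {model_key: checkpoint[model_key]}


def slim_checkpoint(src_path: str, dst_path: str, model_key: str = "model") -> None:
    """Load a checkpoint on CPU, drop the training state, and save the result."""
    import torch  # imported lazily so keep_model_only can be used without PyTorch

    checkpoint = torch.load(src_path, map_location="cpu")
    torch.save(keep_model_only(checkpoint, model_key), dst_path)
```

A slimmed file can no longer be used to resume training, since the optimizer state is gone, so keep the original if you plan to continue pretraining.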
Pretraining
Downloading these exact datasets is unnecessary - pretraining only requires image-text pairs, so any image-text pair dataset will do.
- Download COCO from the official website (use COCO2014, download it all).
- Download SBU captions using Huggingface.
- Download Conceptual Captions using Huggingface.
- Download the weights of DALL-E's D-VAE (encoder, decoder), and place them in a folder.
- Edit configs/Pretrain.yaml and change image_tokenizer_path: /net/acadia10a/data/zkhan/dall-e-tokenizer-weights to the folder where you downloaded the DALL-E tokenizer weights.
- Generate the pretraining JSON. You can download an example from ALBEF.
- The JSON is a list of dictionaries, one for each image: {'image': '/absolute/path/to/image', 'caption': 'the caption of image'}.
- We made a JSON file for each dataset we used (COCO, SBU, CC3M), but you can just have one file for all the image-text pairs.
- Edit configs/Pretrain.yaml and point it to your JSON, so train_file: /path/to/your/pretraining.json.
- Run the command below.
python -m torch.distributed.launch --nproc_per_node=8 --use_env Pretrain.py --config configs/Pretrain.yaml --output_dir <where to save>
Pretraining on 8x A100s with a batch size of 64 for 30 epochs takes roughly 7 days, and uses about 73GB of GPU memory per GPU.
If you're using A100s or A6000s, you may need to run export NCCL_P2P_DISABLE=1 in the shell before training.
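For the JSON-generation step above, a minimal sketch of a script that writes the list-of-dictionaries layout is shown here. It assumes you already have your data as (image path, caption) pairs; the function names `build_pretraining_entries` and `write_pretraining_json` are illustrative and do not exist in the repository.

```python
import json
from pathlib import Path
from typing import Dict, Iterable, List, Tuple


def build_pretraining_entries(pairs: Iterable[Tuple[str, str]]) -> List[Dict[str, str]]:
    """Turn (image_path, caption) pairs into one dict per image."""
    entries = []
    for image_path, caption in pairs:
        entries.append({
            # The README asks for absolute paths, so relative ones are resolved.
            "image": str(Path(image_path).resolve()),
            "caption": caption.strip(),
        })
    return entries


def write_pretraining_json(pairs: Iterable[Tuple[str, str]], out_path: str) -> int:
    """Write the entries to out_path and return how many were written."""
    entries = build_pretraining_entries(pairs)
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(entries, f)
    return len(entries)
```

For example, `write_pretraining_json([("images/0001.jpg", "a dog on a beach")], "pretraining.json")` produces a file that `train_file` in configs/Pretrain.yaml can point to.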
Image Text Retrieval
- Download the JSON files for finetuning here.
- Next, download the COCO2017 train images and val images from the official website, and move all the images into one directory.
Finetuned (COCO)
Edit train_file, val_file and test_file in configs/Retrieval_coco.yaml to point to their respective JSON files you downloaded in Step 1.
Note that the test annotations are not public, so we report results on the validation split following previous work.
python -m torch.distributed.launch --master_port=49770 --nproc_per_node=2 --use_env Retrieval.py \
--config ./configs/Retrieval_coco.yaml \
--output_dir <path/to/output> \
--checkpoint <path/to/checkpoint.pth>
Zero-Shot (Flickr)
Download Flickr30k from Kaggle.
Download the annotations for Flickr30k here.
Edit train_file, val_file and test_file in configs/Retrieval_flickr.yaml to point to the annotation JSON files you downloaded above.
We do not use the validation split, so that key can be set to the name of the train or test file.
python -m torch.distributed.launch --master_port=47770 --nproc_per_node=2 --use_env zero_shot_retrieval.py --config ./configs/Retrieval_flickr.yaml --output_dir <where to save> --checkpoint <path of .pth checkpoint file> --evaluate
Finetuned (Flickr)
Data setup is the same as for the zero-shot Flickr evaluation above.
python -m torch.distributed.launch --master_port=49770 --nproc_per_node=2 --use_env Retrieval.py \
--config ./configs/Retrieval_flickr.yaml \
--output_dir <path/to/output> \
--checkpoint <path/to/checkpoint.pth>
RefCOCO+ (Visual Grounding)
python -m torch.distributed.launch --master_port=49121 --nproc_per_node=2 --use_env Grounding.py \
--config ./configs/Grounding.yaml \
--output_dir <path/to/output> \
--gradcam_mode itm \
--block_num 8 \
--checkpoint <path/to/checkpoint.pth>
VQA (Visual Question Answering)
python -m torch.distributed.launch --nproc_per_node=2 --use_env VQA.py \
--config ./configs/VQA.yaml \
--output_dir <path/to/output> \
--checkpoint <path/to/checkpoint.pth>
NLVR (Natural Language Visual Reasoning)
Pretraining
python -m torch.distributed.launch --nproc_per_node=2 --use_env Pretrain_NLVR.py \
--config ./configs/NLVR_Pretrain.yaml \
--output_dir <path/to/output> \
--checkpoint <path/to/checkpoint.pth>
Finetuning
python -m torch.distributed.launch --nproc_per_node=2 --use_env NLVR.py \
--config ./configs/NLVR.yaml \
--output_dir <path/to/output> \
--checkpoint <path/to/checkpoint.pth>
SNLI-VE (Visual Entailment)
python -m torch.distributed.launch --master_port=47770 --nproc_per_node=2 \
--use_env VE.py \
--config ./configs/VE.yaml \
--output_dir <path/to/output> \
--checkpoint <path/to/checkpoint.pth>
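The launch commands in this README all follow the same pattern, so when running several tasks in a row it can be convenient to build them from Python. The following is a sketch under that assumption, not something the repository ships. The helper names `build_launch_command`, `launch_env` and `run` are hypothetical; the flags are the ones used in the commands above, and `NCCL_P2P_DISABLE` is the variable mentioned in the pretraining section.

```python
import os
import subprocess
import sys
from typing import Dict, List, Optional


def build_launch_command(
    script: str,
    config: str,
    output_dir: str,
    checkpoint: Optional[str] = None,
    nproc_per_node: int = 2,
    master_port: Optional[int] = None,
    evaluate: bool = False,
) -> List[str]:
    """Assemble a torch.distributed.launch command as an argument list."""
    # sys.executable stands in for `python` so the current interpreter is reused.
    cmd = [sys.executable, "-m", "torch.distributed.launch"]
    if master_port is not None:
        cmd.append(f"--master_port={master_port}")
    cmd += [f"--nproc_per_node={nproc_per_node}", "--use_env", script]
    cmd += ["--config", config, "--output_dir", output_dir]
    if checkpoint is not None:
        cmd += ["--checkpoint", checkpoint]
    if evaluate:
        cmd.append("--evaluate")
    return cmd


def launch_env(disable_p2p: bool = False) -> Dict[str, str]:
    """Copy the current environment, optionally setting NCCL_P2P_DISABLE=1."""
    env = dict(os.environ)
    if disable_p2p:
        env["NCCL_P2P_DISABLE"] = "1"
    return env


def run(cmd: List[str], disable_p2p: bool = False) -> int:
    """Run the command and return its exit code."""
    return subprocess.run(cmd, env=launch_env(disable_p2p), check=False).returncode
```

For instance, `run(build_launch_command("VE.py", "./configs/VE.yaml", "out/ve", checkpoint="ckpt.pth", master_port=47770))` corresponds to the SNLI-VE command above.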
Citation
@inproceedings{SIMLA,
title={Single-Stream Multi-Level Alignment for Vision-Language Pretraining},
author={Zaid Khan and Vijay Kumar BG and Xiang Yu and Samuel Schulter and Manmohan Chandraker and Yun Fu},
year={2022},
booktitle={ECCV}
}