PyTorch Implementation of No Token Left Behind: Explainability-Aided Image Classification and Generation
June 4, 2022
Usage
1. Notebook for spatially conditioned image generation
2. Notebook for image editing
3. Notebook for image generation
4. Prompt engineering running instructions
First, follow DATASETS.md to install the datasets. Then create the required environment with:
conda env create -f external/CoOp/dassl_env.yml
conda activate dassl
pip install torch==1.7.1+cu110 torchvision==0.8.2+cu110 -f https://download.pytorch.org/whl/torch_stable.html
Then clone Dassl into the 'external' directory and install it:
cd external/Dassl.pytorch/
python setup.py develop
cd ../../
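As an optional sanity check after these steps, the sketch below compares the installed torch and torchvision versions against the ones pinned in the pip command above and checks that the dassl package is importable. The helper names (PINNED, module_version, check_environment) are hypothetical and not part of the repository. It assumes the cu110 wheels report their versions with the +cu110 suffix; adjust PINNED if your build differs.

```python
import importlib
import importlib.util

# Versions pinned by the pip command in the setup steps.
# Assumption: the cu110 wheels report versions including the "+cu110" local tag.
PINNED = {"torch": "1.7.1+cu110", "torchvision": "0.8.2+cu110"}


def module_version(name):
    """Return module.__version__, or None if the module cannot be imported."""
    try:
        module = importlib.import_module(name)
    except ImportError:
        return None
    return getattr(module, "__version__", None)


def check_environment(pinned=PINNED, version_of=module_version,
                      is_importable=lambda name: importlib.util.find_spec(name) is not None):
    """Collect human-readable problems with the setup; an empty list means OK."""
    problems = []
    for name, wanted in pinned.items():
        found = version_of(name)
        if found is None:
            problems.append(f"{name}: not importable (expected {wanted})")
        elif found != wanted:
            problems.append(f"{name}: found {found}, expected {wanted}")
    # `python setup.py develop` in external/Dassl.pytorch should make `dassl` importable.
    if not is_importable("dassl"):
        problems.append("dassl: not importable; rerun setup.py develop in external/Dassl.pytorch")
    return problems


if __name__ == "__main__":
    issues = check_environment()
    print("\n".join(issues) if issues else "Environment matches the pinned setup.")
```

Run it inside the activated dassl environment; the version lookup and importability check are injectable so they can be stubbed out when testing the helper itself.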
To run the experiment, use:
python external/CoOp/train.py --root <dataset_root> --trainer CoOp --dataset-config-file <dataset config file> --config-file external/CoOp/configs/trainers/CoOp/<base model>_ep50.yaml --output-dir <output_dir> --model-dir <model_dir> --seed 1 DATASET.NUM_SHOTS 1 TRAINER.COOP.EXPL_WEIGHT <expl_lambda> TRAINER.COOP.CSC False TRAINER.COOP.RETURN_EXPL_SCORE True TRAINER.COOP.CLASS_TOKEN_POSITION middle TRAINER.COOP.N_CTX 16
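The command above has several placeholders. A small Python helper can fill them in, for example to sweep over several <expl_lambda> values. The function build_coop_command is hypothetical and not part of the repository; the fixed options are copied verbatim from the command above, while the dataset paths, the vit_b32 base model name (assumed to match a CoOp config file named vit_b32_ep50.yaml) and the lambda grid are illustrative assumptions only.

```python
import shlex


def build_coop_command(root, dataset_config, base_model, output_dir, model_dir, expl_lambda):
    """Return the README training command as an argument list.

    Only the README placeholders are parameters; every other option is
    copied verbatim from the README command.
    """
    return [
        "python", "external/CoOp/train.py",
        "--root", root,
        "--trainer", "CoOp",
        "--dataset-config-file", dataset_config,
        "--config-file", f"external/CoOp/configs/trainers/CoOp/{base_model}_ep50.yaml",
        "--output-dir", output_dir,
        "--model-dir", model_dir,
        "--seed", "1",
        # Trailing KEY VALUE pairs are config overrides passed through to CoOp.
        "DATASET.NUM_SHOTS", "1",
        "TRAINER.COOP.EXPL_WEIGHT", str(expl_lambda),
        "TRAINER.COOP.CSC", "False",
        "TRAINER.COOP.RETURN_EXPL_SCORE", "True",
        "TRAINER.COOP.CLASS_TOKEN_POSITION", "middle",
        "TRAINER.COOP.N_CTX", "16",
    ]


if __name__ == "__main__":
    # Hypothetical sweep: paths, base model and lambda values are placeholders.
    for expl_lambda in (0.5, 1.0, 2.0):
        cmd = build_coop_command(
            root="/path/to/datasets",
            dataset_config="external/CoOp/configs/datasets/<dataset>.yaml",
            base_model="vit_b32",
            output_dir=f"output/expl_{expl_lambda}",
            model_dir="/path/to/model_dir",
            expl_lambda=expl_lambda,
        )
        # Print a copy-pasteable shell command for each setting.
        print(" ".join(shlex.quote(part) for part in cmd))
        # To launch training directly: import subprocess; subprocess.run(cmd, check=True)
```

Building an argument list rather than a single shell string avoids quoting problems in paths; shlex.quote is used only to print a version that can be pasted into a shell.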
Citation
@misc{Paiss2022NoTL,
url = {https://arxiv.org/abs/2204.04908},
author = {Paiss, Roni and Chefer, Hila and Wolf, Lior},
title = {No Token Left Behind: Explainability-Aided Image Classification and Generation},
publisher = {arXiv},
year = {2022}
}
Acknowledgements
- Image manipulation code is based on StyleCLIP
- Image generation code is based on FuseDream
- Image generation with spatial conditioning code is based on VQGAN+CLIP and VQGAN
- Prompt engineering code is based on CoOp and Dassl
- Explainability method code is based on Transformer-MM-Explainability
License
This sample code is released under the terms of the LICENSE file.