ProtoS-ViT
June 18, 2024
This repository contains the implementation of ProtoS-ViT along with the resources required to evaluate the model. ProtoS-ViT is a novel architecture that turns any frozen ViT backbone into a prototypical model, as shown in the figure below:
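To give an intuition of what "turning a frozen backbone into a prototypical model" involves, the sketch below shows a generic prototype head in NumPy: patch tokens from a frozen ViT are projected, compared to learned prototype vectors with cosine similarity, max-pooled over patches, and mapped to class logits by a sparse linear layer. This is an illustrative assumption in the spirit of prototype-based models, not the ProtoS-ViT implementation; the projection, the max-pooling choice, all names and all shapes are hypothetical, and the random arrays stand in for real backbone outputs and trained weights.

```python
import numpy as np


def l2_normalize(x: np.ndarray, axis: int = -1, eps: float = 1e-8) -> np.ndarray:
    """Scale vectors along `axis` to unit L2 norm (eps avoids division by zero)."""
    return x / (np.linalg.norm(x, axis=axis, keepdims=True) + eps)


def prototype_head(patch_tokens, projection, prototypes, class_weights):
    """Hypothetical prototypical head on top of frozen patch tokens.

    patch_tokens:  (num_patches, embed_dim)       frozen ViT output, not trained here
    projection:    (embed_dim, proj_dim)          assumed trainable projection
    prototypes:    (num_prototypes, proj_dim)     assumed learned prototype vectors
    class_weights: (num_classes, num_prototypes)  linear layer: prototype scores -> logits
    """
    z = l2_normalize(patch_tokens @ projection)  # (num_patches, proj_dim)
    p = l2_normalize(prototypes)                 # (num_prototypes, proj_dim)
    similarity = z @ p.T                         # cosine similarity, (num_patches, num_prototypes)
    scores = similarity.max(axis=0)              # best-matching patch for each prototype
    logits = class_weights @ scores              # (num_classes,)
    return logits, scores, similarity


rng = np.random.default_rng(0)
tokens = rng.normal(size=(196, 768))  # e.g. 14x14 patch tokens of a ViT-B/16 at 224x224
proj = rng.normal(size=(768, 64))
protos = rng.normal(size=(10, 64))

# Sparse classifier: each class relies on only a few prototypes.
weights = np.zeros((3, 10))
weights[0, :4] = 1.0
weights[1, 4:7] = 1.0
weights[2, 7:] = 1.0

logits, scores, sim = prototype_head(tokens, proj, protos, weights)
print(logits.shape, scores.shape, sim.shape)  # (3,) (10,) (196, 10)
```

Because the similarity is computed per patch, a column such as `sim[:, k].reshape(14, 14)` can be displayed as a heatmap showing where prototype `k` is activated in the image, which is what makes this family of models self-explainable.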

The model is evaluated in terms of both classification performance and explainability, outperforming current state-of-the-art models across a range of metrics:
Installation
A Dockerfile is provided to easily run the project with pdm and the required packages.
Get Started
If you are running the container, you can start training your model with:
pdm run train_classification # train a model on the CUB dataset
The CUB dataset will be automatically saved in /data/.
Configurations are managed with Hydra. Pre-made configurations for a range of experiments can be found in /config/experiments, and you can also design your own experiment based on these examples. To run your own experiment, simply create a new my_experiment.yaml file in /config/experiments and then run:
python src/main_train experiment=my_experiment # train a model with your own experiment configuration
Configuration files for all experiments presented in the paper can be found under /config/experiments.
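Conceptually, an experiment file only needs to list the values that differ from the defaults: Hydra composes the final configuration by merging the selected file on top of the default config, recursively for nested sections. The plain-Python sketch below mimics that recursive merge with dictionaries to show which values an experiment overrides. It is not Hydra's code, and every key shown (`data.batch_size`, `model.backbone`, ...) is hypothetical rather than this repository's actual configuration schema.

```python
import copy


def deep_merge(base: dict, override: dict) -> dict:
    """Return a new dict in which `override` is merged recursively into `base`.

    Nested dicts are merged key by key; any other value (scalar, list)
    in `override` replaces the corresponding value in `base`.
    """
    merged = copy.deepcopy(base)  # never mutate the defaults
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


# Hypothetical default config and experiment file contents (illustrative keys only).
defaults = {
    "data": {"name": "cub", "batch_size": 64, "download": True},
    "model": {"backbone": "vit_small", "n_prototypes": 10},
    "trainer": {"max_epochs": 30},
}
my_experiment = {
    "data": {"batch_size": 32},
    "model": {"backbone": "dinov2_vitb14"},
}

config = deep_merge(defaults, my_experiment)
print(config["data"])   # {'name': 'cub', 'batch_size': 32, 'download': True}
print(config["model"])  # {'backbone': 'dinov2_vitb14', 'n_prototypes': 10}
```

Keeping experiment files minimal in this way makes it easy to see at a glance how each experiment differs from the baseline configuration.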
Datasets
General Datasets
The CUB and PETS datasets can be downloaded by setting download=True in the dataloader arguments, i.e. in the Hydra config.
Stanford Cars
The download URL used by the StanfordCars dataset in torchvision is currently broken. The dataset can instead be downloaded by following these instructions.
Funny Birds
The FunnyBirds dataset can be downloaded from the original repository with the following commands:
cd /path/to/dataset/
wget download.visinf.tu-darmstadt.de/data/funnybirds/FunnyBirds.zip
unzip FunnyBirds.zip
rm FunnyBirds.zip
Biomedical Datasets
Three biomedical datasets are used: ISIC 2019, RSNA and LC25000 (Lungs). Each uses a random split between the training and test sets; the split for each dataset is provided in /data/dataset_name.
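The provided split files are the reference for reproducing the results. As an illustration of why such splits are stored rather than regenerated, the sketch below shows one way to produce a reproducible random train/test split with a fixed seed and save it to disk. The function names, the 20% test fraction, the seed and the JSON format are all assumptions for illustration, not the format actually used in /data/dataset_name.

```python
import json
import random
from pathlib import Path


def make_random_split(sample_ids, test_fraction=0.2, seed=42):
    """Shuffle sample IDs with a fixed seed and split them into train/test lists."""
    ids = sorted(sample_ids)   # sort first so the result does not depend on input order
    rng = random.Random(seed)  # local RNG: leaves the global random state untouched
    rng.shuffle(ids)
    n_test = round(len(ids) * test_fraction)
    return {"train": sorted(ids[n_test:]), "test": sorted(ids[:n_test])}


def save_split(split, path):
    """Write the split to a JSON file (hypothetical format)."""
    Path(path).write_text(json.dumps(split, indent=2))


def load_split(path):
    """Read a split previously written by save_split."""
    return json.loads(Path(path).read_text())


images = [f"img_{i:04d}.jpg" for i in range(100)]  # placeholder file names
split = make_random_split(images, test_fraction=0.2, seed=0)
print(len(split["train"]), len(split["test"]))  # 80 20
```

Shipping the resulting file with the code, instead of only the seed, guards against differences in file listing order or library versions that could otherwise silently change which images end up in the test set.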
ISIC 2019
Classification of skin lesions across nine diagnostic categories. The dataset can be downloaded from here and copied into data/isic_2019.
RSNA
Binary classification of chest X-rays for the presence of pneumonia. The data can be downloaded from Kaggle.
LC25000
Classification of lung and colon histopathological images. In this work, we focus on the lung subset, in which images are classified into three classes. The dataset can be downloaded from here.
Explainability evaluation
The quantitative evaluation of the model's explainability relies on the FunnyBirds dataset and on a subset of the metrics introduced by Hesse et al. in the paper presenting this dataset. These metrics can be computed for a model trained on FunnyBirds as follows:
python src/evaluation/main_evaluate_funny_birds --path_sim=path_sim # replace path_sim with the path to the folder where the trained model is saved
Citation
If you find this code or idea useful, please consider citing our work:
@article{turbe2024protosvit,
title={ProtoS-ViT: Visual foundation models for sparse self-explainable classifications},
author={Hugues Turb\'{e} and Mina Bjelogrlic and Gianmarco Mengaldo and Christian Lovis},
journal={arXiv:2406.10025},
year={2024}
}
Acknowledgements
The repository architecture was built on the initial template found here.