ProtoS-ViT

June 18, 2024

This is the implementation of ProtoS-ViT, along with the resources required to evaluate the model. ProtoS-ViT is a novel architecture that turns any frozen ViT backbone into a prototypical model, as shown in the figure below:

Figure: Model architecture

The model is evaluated in terms of both classification performance and explainability, outperforming current state-of-the-art models across a range of metrics.

Installation

A Dockerfile is provided to make it easy to run the project with pdm and the required packages.

Get Started

If you are running the container, you can start training your model with:

pdm run train_classification # train a model on the CUB dataset

The CUB dataset will be automatically saved in /data/.

Configurations are managed with Hydra. Pre-made configurations for a range of experiments can be found in /config/experiments, and you can design your own experiment based on these examples. To run your own experiment, create a new my_experiment.yaml file in /config/experiments and then run:

python src/main_train experiment=my_experiment # train a model with your own experiment configuration

Configuration files for all experiments presented in the paper can be found under /config/experiments.

Datasets

General Datasets

The CUB and PETS datasets can be downloaded by setting download=True in the dataloader arguments, i.e. in the Hydra config.

Stanford Cars

The download URL used by the StanfordCars dataset in torchvision is currently broken. The dataset can be downloaded using the following instructions.

Funny Birds

The FunnyBirds dataset can be downloaded from its original repository with the following commands:

cd /path/to/dataset/
wget download.visinf.tu-darmstadt.de/data/funnybirds/FunnyBirds.zip
unzip FunnyBirds.zip
rm FunnyBirds.zip

Biomedical Datasets

The three biomedical datasets, ISIC 2019, RSNA, and LC25000 (Lungs), are randomly split into a training set and a test set. The split for each dataset is provided in /data/dataset_name.
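For reference, a reproducible random split of this kind could be generated along the following lines. This is a minimal sketch, not the procedure used to produce the provided split files: the `random_split` helper, the 80/20 ratio, the seed, and the file names are all assumptions made for illustration.

```python
import random


def random_split(items, test_fraction=0.2, seed=0):
    """Shuffle items with a fixed seed and split them into (train, test).

    Hypothetical helper: the fraction and seed are illustrative assumptions,
    not the values used for the splits shipped with the repository.
    """
    if not 0.0 < test_fraction < 1.0:
        raise ValueError("test_fraction must be strictly between 0 and 1")
    # Sort first so the result does not depend on the input order.
    shuffled = sorted(items)
    # A dedicated Random instance avoids touching the global RNG state.
    random.Random(seed).shuffle(shuffled)
    n_test = round(len(shuffled) * test_fraction)
    return shuffled[n_test:], shuffled[:n_test]


# Example with placeholder file names (hypothetical, for illustration only).
image_ids = [f"image_{i:04d}.jpg" for i in range(100)]
train_ids, test_ids = random_split(image_ids, test_fraction=0.2, seed=0)
```

Because the seed is fixed, repeated calls return the same partition. To reproduce the repository's setup, use the provided split files rather than regenerating a split.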

ISIC 2019

Classification of skin lesions across nine different diagnostic categories. The dataset can be downloaded from here and copied into data/isic_2019.

RSNA

Binary classification of chest X-rays for the presence of pneumonia. Data can be downloaded from Kaggle.

LC25000

Classification of lung and colon histopathological images. In this work, we focus on the lung dataset, which aims to classify the images into three different classes. The dataset can be downloaded from here.

Explainability evaluation

The quantitative evaluation of the model's explainability relies on the FunnyBirds dataset as well as some of the metrics presented in the paper introducing this dataset by Hesse et al. The metrics presented by Hesse et al. can be computed for a model trained on the FunnyBirds dataset as follows:

python src/evaluation/main_evaluate_funny_birds --path_sim=path_sim # path_sim: path to the folder where the trained model is saved

Citation

If you find this code or idea useful, please consider citing our work:

@article{turbe2024protosvit,
  title={ProtoS-ViT: Visual foundation models for sparse self-explainable classifications},
  author={Hugues Turb\'{e} and Mina Bjelogrlic and Gianmarco Mengaldo and Christian Lovis},
  journal={arXiv:2406.10025},
  year={2024}
}

Acknowledgements

The repository architecture was built on the initial template found here.