B-cos Networks v2

March 3, 2024

DOI | arXiv | code

B-cos Alignment for Inherently Interpretable CNNs and Vision Transformers

Moritz Böhle, Navdeeppal Singh, Mario Fritz, Bernt Schiele. TPAMI, 2024.

Introduction

This repository contains the code for the B-cos v2 models.

These models are more efficient and easier to train than the original v1 B-cos models. Furthermore, we make a large number of pretrained B-cos models available for use.

If you want to take a quick look at the explanations the models generate, you can try out the Gradio web demo on Hugging Face Spaces.

If you prefer a more hands-on approach, you can take a look at the demo notebook on Colab or load the models directly via torch.hub as explained below.

If you simply want to copy the model definitions, we provide a minimal, single-file reference implementation including explanation mode in extra/minimal_bcos_resnet.py!

UPDATE: We have also released our ViT models! See Model Zoo.

Quick Start

You only need to make sure you have torch and torchvision installed.

Then, loading the models via torch.hub is as easy as:

import torch

# list all available models
torch.hub.list('B-cos/B-cos-v2')

# load a pretrained model
model = torch.hub.load('B-cos/B-cos-v2', 'resnet50', pretrained=True)

Running inference and visualizing the explanation is just as simple:

from PIL import Image
import matplotlib.pyplot as plt

# load image
img = model.transform(Image.open('cat.jpg'))
img = img[None].requires_grad_()

# predict and explain
model.eval()
expl_out = model.explain(img)
print("Prediction:", expl_out["prediction"])  # predicted class idx
plt.imshow(expl_out["explanation"])
plt.show()

Each model has its inference transform attached to it, accessible via model.transform. Each model also has an .explain() method that takes an image tensor and returns a dictionary containing the prediction, the explanation, and some extras.

See the demo notebook for more details on the .explain() method.

Furthermore, each model has get_classifier and get_feature_extractor methods that return the classifier and feature extractor modules, respectively. These can be useful for fine-tuning the models!
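As a rough illustration of the fine-tuning use case, the sketch below freezes the feature extractor and leaves only the classifier trainable. The helper name freeze_feature_extractor is hypothetical and not part of this repository. It only assumes that both methods return standard torch.nn.Module objects, and it relies on duck typing so it does not need to import torch itself.

```python
def freeze_feature_extractor(model):
    """Freeze the feature extractor and return the still-trainable classifier parameters.

    Hypothetical helper (not part of the bcos package). It assumes that
    get_feature_extractor() and get_classifier() return torch.nn.Module-like
    objects exposing .parameters().
    """
    for param in model.get_feature_extractor().parameters():
        param.requires_grad_(False)  # exclude from gradient updates
    # Only parameters that still require gradients should go to the optimizer.
    return [p for p in model.get_classifier().parameters() if p.requires_grad]
```

With a model loaded via torch.hub as in the Quick Start, the returned list could be handed to an optimizer, for example torch.optim.SGD(freeze_feature_extractor(model), lr=1e-3). The optimizer and learning rate here are placeholders, not recommended settings.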

Installation

Depending on your use case, you can either install the bcos package or set up the development environment for training the models (for your custom models or for reproducing the results).

bcos Package

If you are simply interested in using the models (pretrained or otherwise), then we provide a bcos package that can be installed via pip:

pip install bcos

This contains the models, their modules, transforms, and other utilities making it easy to use and build B-cos models. Take a look at the public API here. (I'll add a proper docs site if I have time or there's enough interest. Nonetheless, I have tried to keep the code well-documented, so it should be easy to follow.)

Training Environment Setup

If you want to train your own B-cos models using this repository or are interested in reproducing the results, you can set up the development environment as follows:

Using conda (recommended, especially if you want to reproduce the results):

conda env create -f environment.yml
conda activate bcos

Using pip:

pip install -r requirements-train.txt

Setting Data Paths

You can either set the paths in bcos/settings.py or set the environment variables

  1. DATA_ROOT
  2. IMAGENET_PATH

to the paths of the data directories.

The DATA_ROOT environment variable should point to the data root directory for CIFAR-10 (will be automatically downloaded). For ImageNet, the IMAGENET_PATH environment variable should point to the directory containing the train and val directories.
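A small sketch of the environment-variable route: it sets both variables from Python and checks that the ImageNet directory has the expected train and val subdirectories. The function name missing_imagenet_splits and the example paths are placeholders, not part of this repository. It also assumes the variables are set before the bcos code reads them.

```python
import os
from pathlib import Path


def missing_imagenet_splits(imagenet_path):
    """Return the expected split directories ('train', 'val') that are absent."""
    root = Path(imagenet_path)
    return [split for split in ("train", "val") if not (root / split).is_dir()]


# Placeholder locations; replace with your own data directories.
# setdefault keeps any value that is already set in the environment.
os.environ.setdefault("DATA_ROOT", "/path/to/data")
os.environ.setdefault("IMAGENET_PATH", "/path/to/imagenet")

missing = missing_imagenet_splits(os.environ["IMAGENET_PATH"])
if missing:
    print("IMAGENET_PATH is missing:", ", ".join(missing))
```

Running such a check before a long evaluation or training job can catch a mistyped path early.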

Usage

For the bcos package, as mentioned earlier, take a look at the public API here.

For evaluating or training the models, you can use the evaluate.py and train.py scripts, as follows:

Evaluation

You can evaluate the accuracy of the models on the ImageNet validation set using:

python evaluate.py --dataset ImageNet --hubconf resnet18

This will download the model from torch.hub and evaluate it on the ImageNet validation set. The default batch size is 1, but you can change it using the --batch-size argument. Replace resnet18 with any of the other models listed in Model Zoo that you wish to evaluate.
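To evaluate several entrypoints in one go, the command above can be generated in a loop. The sketch below only builds the argument lists, using the flags shown in this section. The helper build_eval_command and the batch size of 64 are illustrative choices, not part of the repository.

```python
import sys


def build_eval_command(entrypoint, batch_size=None, dataset="ImageNet"):
    """Build the evaluate.py argument list for one model zoo entrypoint."""
    cmd = [sys.executable, "evaluate.py", "--dataset", dataset, "--hubconf", entrypoint]
    if batch_size is not None:
        # Omitting --batch-size keeps the script's default.
        cmd += ["--batch-size", str(batch_size)]
    return cmd


for name in ("resnet18", "resnet50", "densenet121"):
    print(" ".join(build_eval_command(name, batch_size=64)))
```

Each list can be passed to subprocess.run(cmd, check=True) from the repository root to actually run the evaluation.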

Training

Short version:

python train.py \
  --dataset ImageNet \
  --base_network bcos_final \
  --experiment_name resnet18

Long version: See TRAINING.md for more details on how the setup works and how to train your own models.

Model Zoo

Here are the ImageNet pre-trained models available in the model zoo. You can find the links to the model weights below (uploaded to the Weights GitHub release).

| Model/Entrypoint | Top-1 Accuracy | Top-5 Accuracy | #Params | Download |
|---|---|---|---|---|
| resnet18 | 68.736% | 87.430% | 11.69M | link |
| resnet34 | 72.284% | 90.052% | 21.80M | link |
| resnet50 | 75.882% | 92.528% | 25.52M | link |
| resnet101 | 76.532% | 92.538% | 44.50M | link |
| resnet152 | 76.484% | 92.398% | 60.13M | link |
| resnext50_32x4d | 75.820% | 91.810% | 25.00M | link |
| densenet121 | 73.612% | 91.106% | 7.95M | link |
| densenet161 | 76.622% | 92.554% | 28.58M | link |
| densenet169 | 75.186% | 91.786% | 14.08M | link |
| densenet201 | 75.480% | 91.992% | 19.91M | link |
| vgg11_bnu | 69.310% | 88.388% | 132.86M | link |
| convnext_tiny | 77.488% | 93.192% | 28.54M | link |
| convnext_base | 79.650% | 94.614% | 88.47M | link |
| convnext_tiny_bnu | 76.826% | 93.090% | 28.54M | link |
| convnext_base_bnu | 80.142% | 94.834% | 88.47M | link |
| densenet121_long | 77.302% | 93.234% | 7.95M | link |
| resnet50_long | 79.468% | 94.452% | 25.52M | link |
| resnet152_long | 80.144% | 94.116% | 60.13M | link |
| simple_vit_ti_patch16_224 | 59.960% | 81.838% | 5.80M | link |
| simple_vit_s_patch16_224 | 69.246% | 88.096% | 22.28M | link |
| simple_vit_b_patch16_224 | 74.408% | 91.156% | 86.90M | link |
| simple_vit_l_patch16_224 | 75.060% | 91.378% | 178.79M | link |
| vitc_ti_patch1_14 | 67.260% | 86.774% | 5.32M | link |
| vitc_s_patch1_14 | 74.504% | 91.288% | 20.88M | link |
| vitc_b_patch1_14 | 77.152% | 92.926% | 81.37M | link |
| vitc_l_patch1_14 | 77.782% | 92.966% | 167.44M | link |
| standard_simple_vit_ti_patch16_224 | 70.230% | 89.380% | 5.67M | link |
| standard_simple_vit_s_patch16_224 | 74.470% | 91.226% | 21.96M | link |
| standard_simple_vit_b_patch16_224 | 75.300% | 91.026% | 86.38M | link |
| standard_simple_vit_l_patch16_224 | 75.710% | 90.050% | 178.10M | link |
| standard_vitc_ti_patch1_14 | 72.590% | 90.788% | 5.33M | link |
| standard_vitc_s_patch1_14 | 75.756% | 91.994% | 20.91M | link |
| standard_vitc_b_patch1_14 | 76.790% | 92.024% | 81.39M | link |
| standard_vitc_l_patch1_14 | 77.866% | 92.298% | 167.54M | link |

You can find these entrypoints in bcos/models/pretrained.py.
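When choosing an entrypoint, one simple criterion is the best top-1 accuracy within a parameter budget. The sketch below hard-codes a few rows copied from the model zoo table. The helper best_under_budget and the 30M budget are purely illustrative.

```python
# (entrypoint, top-1 accuracy in %, parameters in millions), copied from the model zoo table
MODEL_ZOO_SUBSET = [
    ("resnet18", 68.736, 11.69),
    ("resnet50", 75.882, 25.52),
    ("densenet121", 73.612, 7.95),
    ("densenet121_long", 77.302, 7.95),
    ("convnext_base_bnu", 80.142, 88.47),
    ("resnet152_long", 80.144, 60.13),
]


def best_under_budget(rows, max_params_m):
    """Return the entrypoint with the highest top-1 accuracy within the budget, or None."""
    candidates = [row for row in rows if row[2] <= max_params_m]
    if not candidates:
        return None
    return max(candidates, key=lambda row: row[1])[0]


print(best_under_budget(MODEL_ZOO_SUBSET, max_params_m=30))  # densenet121_long
```

The selected name can then be used as the entrypoint argument in the torch.hub.load call from the Quick Start.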

License

This repository's code is licensed under the Apache License 2.0 which you can find in the LICENSE file.

The pre-trained models are trained on ImageNet (and are hence derived from it), which is licensed under the ImageNet Terms of access. Among other things, these terms only allow non-commercial use of the dataset. It is therefore your responsibility to check whether you have permission to use the pre-trained models for your use case.

Citation

@article{Boehle2024TPAMI,
  author = {Böhle, Moritz and Singh, Navdeeppal and Fritz, Mario and Schiele, Bernt},
  title = {B-cos Alignment for Inherently Interpretable CNNs and Vision Transformers},
  journal = {IEEE Transactions on Pattern Analysis and Machine Intelligence},
  year = {2024},
  pages = {1-15},
  doi = {10.1109/TPAMI.2024.3355155},
}