CaiT: Going deeper with Image Transformers

May 8, 2022

This repository contains PyTorch evaluation code, training code and pretrained models for:

  • DeiT (Data-Efficient Image Transformers), ICML 2021
  • CaiT (Going deeper with Image Transformers), ICCV 2021 (Oral)
  • ResMLP (ResMLP: Feedforward networks for image classification with data-efficient training)
  • PatchConvnet (Augmenting Convolutional networks with attention-based aggregation)
  • 3Things (Three things everyone should know about Vision Transformers)
  • DeiT III (DeiT III: Revenge of the ViT)

CaiT models obtain competitive trade-offs between FLOPs and accuracy.

For details, see Going deeper with Image Transformers by Hugo Touvron, Matthieu Cord, Alexandre Sablayrolles, Gabriel Synnaeve and Hervé Jégou.

If you use this code for a paper, please cite:

@InProceedings{Touvron_2021_ICCV,
    author    = {Touvron, Hugo and Cord, Matthieu and Sablayrolles, Alexandre and Synnaeve, Gabriel and J\'egou, Herv\'e},
    title     = {Going Deeper With Image Transformers},
    booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
    month     = {October},
    year      = {2021},
    pages     = {32-42}
}

Model Zoo

We provide baseline CaiT models pretrained on ImageNet-1k (2012) only, using the distilled version of our method.

name    acc@1 (%)  res (px)  FLOPs    #params  url
S24     83.5       224       9.4B     47M      model
XS24    84.1       384       19.3B    27M      model
S24     85.1       384       32.2B    47M      model
S36     85.4       384       48.0B    68M      model
M36     86.1       384       173.3B   271M     model
M48     86.5       448       329.6B   356M     model
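One way to read the table as a compute/accuracy trade-off is to pick the most accurate model that fits a FLOPs budget. The sketch below is purely illustrative: `CaiTModel`, `MODEL_ZOO` and `best_under_budget` are hypothetical names, not part of this repository, and the rows are copied from the table above (the two S24 entries differ by resolution).

```python
from typing import NamedTuple, Optional


class CaiTModel(NamedTuple):
    name: str
    acc1: float     # acc@1 column, in %
    res: int        # evaluation resolution, in pixels
    gflops: float   # FLOPs, in billions
    params_m: int   # parameters, in millions


# Rows copied from the model-zoo table (hypothetical container, for illustration only).
MODEL_ZOO = [
    CaiTModel("S24", 83.5, 224, 9.4, 47),
    CaiTModel("XS24", 84.1, 384, 19.3, 27),
    CaiTModel("S24", 85.1, 384, 32.2, 47),
    CaiTModel("S36", 85.4, 384, 48.0, 68),
    CaiTModel("M36", 86.1, 384, 173.3, 271),
    CaiTModel("M48", 86.5, 448, 329.6, 356),
]


def best_under_budget(max_gflops: float) -> Optional[CaiTModel]:
    """Most accurate listed model whose FLOPs fit within max_gflops, or None."""
    fitting = [m for m in MODEL_ZOO if m.gflops <= max_gflops]
    return max(fitting, key=lambda m: m.acc1, default=None)


print(best_under_budget(20.0))
```

With a 20 GFLOPs budget this selects XS24 at 384 px; with 50 GFLOPs it selects S36.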

The models are also available via PyTorch Hub (torch.hub). Before using them, make sure you have installed Ross Wightman's pytorch-image-models package, timm==0.3.2.
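A minimal sketch of loading a model through torch.hub, assuming the hub repository is `facebookresearch/deit` on the `main` branch and that CaiT entry points have names starting with `cait`. The entry-point name used in the usage comment (`cait_S24_224`) is an assumption, so list the available entry points first. Both hub calls fetch code from GitHub and need network access; `select_cait`, `list_cait_models` and `load_cait` are hypothetical helpers.

```python
import torch

HUB_REPO = "facebookresearch/deit:main"  # assumed hub repository and branch


def select_cait(entrypoints):
    """Keep only entry-point names starting with 'cait' (case-insensitive), sorted."""
    return sorted(n for n in entrypoints if n.lower().startswith("cait"))


def list_cait_models(repo=HUB_REPO):
    """List the CaiT entry points exposed by the hub repository (needs network access)."""
    return select_cait(torch.hub.list(repo))


def load_cait(name, repo=HUB_REPO):
    """Load pretrained weights for entry point `name` and switch the model to eval mode."""
    # pretrained=True is forwarded as a keyword argument to the entry-point function.
    model = torch.hub.load(repo, name, pretrained=True)
    return model.eval()


# Usage (needs network access and timm installed):
#   print(list_cait_models())
#   model = load_cait("cait_S24_224")  # hypothetical name; pick one from the list above
```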

Evaluation transforms

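CaiT uses slightly different pre-processing, in particular a crop ratio of 1.0 at test time: the image is resized so that its shorter side equals the input size, and the center crop keeps that full shorter side. To reproduce the results of our paper, please use the following pre-processing: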

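from torchvision import transforms

def get_test_transforms(input_size):
    # ImageNet normalization statistics
    mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
    return transforms.Compose([
        # Resize the shorter side to input_size; interpolation=3 is bicubic (PIL.Image.BICUBIC)
        transforms.Resize(input_size, interpolation=3),
        # Crop ratio 1.0: the center crop is as large as the resize target
        transforms.CenterCrop(input_size),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])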
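In the function above, the resize target equals the crop size, which is what a crop ratio of 1.0 means. For comparison, a common ImageNet evaluation convention (an assumption here, not stated in this README) first resizes to input_size / 0.875, e.g. 256 for a 224 crop, so the crop keeps only 87.5% of the shorter side. The hypothetical helper `resize_target` below just does this arithmetic.

```python
def resize_target(input_size: int, crop_ratio: float) -> int:
    """Shorter-side length to resize to before a center crop of size input_size.

    The center crop then covers a crop_ratio fraction of the resized shorter side.
    """
    if not 0.0 < crop_ratio <= 1.0:
        raise ValueError("crop_ratio must be in (0, 1]")
    return int(input_size / crop_ratio)


for size in (224, 384, 448):
    # crop ratio 0.875 (common convention, assumed) vs 1.0 (CaiT evaluation)
    print(size, resize_target(size, 0.875), resize_target(size, 1.0))
# 224 256 224
# 384 438 384
# 448 512 448
```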

Remark: for CaiT-M48, it is best to evaluate in FP32 precision.
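To illustrate the FP32 remark, here is a hypothetical evaluation helper (`evaluate_fp32` is not part of this repository) that keeps the model weights and inputs in float32. It does not enable autocast or call `.half()` itself, so it should be called outside any autocast context. `loader` is assumed to yield `(images, targets)` batches, for example a DataLoader over images pre-processed with `get_test_transforms`.

```python
import torch


@torch.no_grad()
def evaluate_fp32(model, loader, device="cpu"):
    """Top-1 accuracy (%) with model weights and inputs kept in float32."""
    # Cast parameters to float32 and disable dropout and similar training-only behavior.
    model = model.float().to(device).eval()
    correct, total = 0, 0
    for images, targets in loader:
        images = images.to(device, dtype=torch.float32)
        targets = targets.to(device)
        preds = model(images).argmax(dim=1)
        correct += (preds == targets).sum().item()
        total += targets.numel()
    return 100.0 * correct / max(total, 1)
```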

Other: Unofficial Implementations

License

This repository is released under the Apache 2.0 license as found in the LICENSE file.

Contributing

We actively welcome your pull requests! Please see CONTRIBUTING.md and CODE_OF_CONDUCT.md for more info.