Dataset Reinforcement

September 22, 2023

A light-weight implementation of Dataset Reinforcement, pretrained checkpoints, and reinforced datasets.

Reinforce Data, Multiply Impact: Improved Model Accuracy and Robustness with Dataset Reinforcement. Faghri, F., Pouransari, H., Mehta, S., Farajtabar, M., Farhadi, A., Rastegari, M., & Tuzel, O. Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV), 2023.

Update 2023/09/22: The Average column of Table 7 was corrected in arXiv v3. Correct numbers: 30.4, 37.1, 37.9, 43.7, 39.6, 51.1.

Reinforced ImageNet, ImageNet+, improves accuracy at similar iterations/wall-clock.

ImageNet validation accuracy of ResNet-50 is shown as a function of training duration with (1) the ImageNet dataset, (2) knowledge distillation (KD), and (3) the ImageNet+ dataset (ours). Each point is a full training run, with the number of epochs varying from 50 to 1000. An epoch has the same number of iterations for ImageNet and ImageNet+.

Illustration of Dataset Reinforcement.

Data augmentation and knowledge distillation are common approaches to improving accuracy. Dataset reinforcement combines the benefits of both by bringing the advantages of large models trained on large datasets to other datasets and models. Training new models with a reinforced dataset is as fast as training on the original dataset for the same total number of iterations. Creating a reinforced dataset is a one-time process (e.g., ImageNet to ImageNet+), the cost of which is amortized over repeated uses.
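To make the idea concrete, the sketch below shows one way a training step could consume a reinforced sample: instead of running a teacher model, it reads teacher probabilities that were stored alongside an augmentation and computes a distillation loss against them. This is a minimal, framework-free illustration, not the repository's implementation; the function names, the record layout, and the toy values are all hypothetical.

```python
import math


def softmax(logits):
    """Convert raw scores into probabilities (numerically stable)."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def distillation_loss(student_logits, teacher_probs, eps=1e-12):
    """KL(teacher || student) for a single sample.

    teacher_probs is read from the reinforced dataset, so no teacher
    forward pass is needed at training time.
    """
    student_probs = softmax(student_logits)
    return sum(
        t * (math.log(t + eps) - math.log(s + eps))
        for t, s in zip(teacher_probs, student_probs)
        if t > 0.0
    )


# Hypothetical stored record: augmentation parameters plus teacher output.
reinforced_sample = {
    "augmentation": {"crop": (10, 20, 200, 180), "flip": True},
    "teacher_probs": [0.7, 0.2, 0.1],
}

student_logits = [2.0, 0.5, -1.0]
loss = distillation_loss(student_logits, reinforced_sample["teacher_probs"])
print(f"distillation loss: {loss:.4f}")
```

Because the teacher output is a lookup rather than a forward pass, the per-iteration cost stays close to that of ordinary training, which is the property the paragraph above describes.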

Requirements

Install the requirements using:

   pip install -r requirements.txt

We support loading models from the Timm and CVNets libraries.

To install the CVNets library, follow its installation instructions.

Reinforced Data

The following is a list of reinforcements for ImageNet/CIFAR-100/Food-101/Flowers-102. We recommend ImageNet+-RA/RE based on the analysis in the paper.

| Reinforce Data | Task ID | Size (GBs) | Comments |
|---|---|---|---|
| ImageNet+-RRC | rdata | 33.4 | [NS=400] |
| ImageNet+-+M* | rdata | 46.3 | [NS=400] |
| ImageNet+-+RA/RE | rdata | 37.5 | [NS=400] |
| ImageNet+-+M*+R* | rdata | 53.3 | [NS=400] |
| ImageNet+-RRC-Small | rdata | 4.7 | [NS=100, K=5] |
| ImageNet+-+M*-Small | rdata | 7.8 | [NS=100, K=5] |
| ImageNet+-+RA/RE-Small | rdata | 5.6 | [NS=100, K=5] |
| ImageNet+-+M*+R*-Small | rdata | 9.4 | [NS=100, K=5] |
| ImageNet+-RRC-Mini | rdata | 4.4 | [NS=50] |
| ImageNet+-+M*-Mini | rdata | 6.1 | [NS=50] |
| ImageNet+-+RA/RE-Mini | rdata | 4.9 | [NS=50] |
| ImageNet+-+M*+R*-Mini | rdata | 7.0 | [NS=50] |
| CIFAR-100 | rdata | 2.5 | [NS=800] |
| Food-101 | rdata | 4.2 | [NS=800] |
| Flowers-102 | rdata | 0.5 | [NS=8000] |
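The Comments column is not explained in this README. Assuming NS denotes the number of stored augmentation samples per image and K the number of top teacher probabilities kept per sample (both are assumptions here; see the paper for the exact definitions), the following sketch illustrates how keeping only the top-K probabilities could reduce storage. The helper names are hypothetical, and spreading the unstored probability mass uniformly is just one possible choice, not necessarily what the released code does.

```python
def sparsify_top_k(probs, k):
    """Keep the k largest probabilities as (class_index, probability) pairs."""
    ranked = sorted(enumerate(probs), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]


def densify(sparse, num_classes):
    """Rebuild a dense vector from (class_index, probability) pairs.

    Probability mass that was not stored is spread uniformly over the
    remaining classes (an illustrative choice).
    """
    dense = [0.0] * num_classes
    for idx, p in sparse:
        dense[idx] = p
    kept = {idx for idx, _ in sparse}
    missing = max(0.0, 1.0 - sum(p for _, p in sparse))
    rest = num_classes - len(kept)
    if rest > 0:
        for i in range(num_classes):
            if i not in kept:
                dense[i] = missing / rest
    return dense


teacher_probs = [0.05, 0.6, 0.1, 0.25]
sparse = sparsify_top_k(teacher_probs, k=2)
dense = densify(sparse, num_classes=len(teacher_probs))
print(sparse)
print(dense)
```

Storing K pairs instead of one value per class matters most when the number of classes is large, as with ImageNet's 1000 classes.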

Pretrained Checkpoints

CVNets Checkpoints

We provide pretrained checkpoints for various models in CVNets. The accuracies can be verified using the CVNets library.

Selected results trained for 1000 epochs:

| Name | Mode | Params | ImageNet | ImageNet+ | ImageNet (EMA) | ImageNet+ (EMA) | Links |
|---|---|---|---|---|---|---|---|
| MobileNetV3 | large | 5.5M | 74.8 | 77.9 (+3.1) | 75.8 | 77.9 (+2.1) | [best.pt] [ema_best.pt] [config.yaml] [metrics.jb] |
| ResNet | 50 | 25.6M | 80.0 | 82.0 (+2.0) | 80.1 | 82.0 (+1.9) | [best.pt] [ema_best.pt] [config.yaml] [metrics.jb] |
| ViT | base | 86.7M | 76.8 | 85.1 (+8.3) | 80.8 | 85.1 (+4.3) | [best.pt] [ema_best.pt] [config.yaml] [metrics.jb] |
| ViT-384 | base | 86.7M | 79.4 | 85.4 (+6.0) | 83.1 | 85.5 (+2.4) | [best.pt] [ema_best.pt] [config.yaml] [metrics.jb] |
| Swin | tiny | 28.3M | 81.3 | 84.0 (+2.7) | 80.5 | 83.5 (+3.0) | [best.pt] [ema_best.pt] [config.yaml] [metrics.jb] |
| Swin | small | 49.7M | 81.3 | 85.0 (+3.7) | 81.9 | 84.5 (+2.6) | [best.pt] [ema_best.pt] [config.yaml] [metrics.jb] |
| Swin | base | 87.8M | 81.5 | 85.4 (+3.9) | 81.8 | 85.2 (+3.4) | [best.pt] [ema_best.pt] [config.yaml] [metrics.jb] |
| Swin-384 | base | 87.8M | 83.6 | 85.8 (+2.2) | 83.8 | 85.5 (+1.7) | [best.pt] [ema_best.pt] [config.yaml] [metrics.jb] |

Timm Checkpoints

We provide pretrained checkpoints for ResNet50d from the Timm library, trained for 150 epochs using various reinforced datasets:

  • imagenet-timm.tar: All Timm checkpoints trained on ImageNet and ImageNet+ (2.3 GB).

| Model | Reinforce Data | Accuracy | Links |
|---|---|---|---|
| ResNet50d [ERM] | N/A | 78.9 | [best.pt] [config.yaml] [metrics.jb] |
| ResNet50d | ImageNet+-RRC | 80.0 | [best.pt] [config.yaml] [metrics.jb] |
| ResNet50d | ImageNet+-+M* | 80.5 | [best.pt] [config.yaml] [metrics.jb] |
| ResNet50d | ImageNet+-+RA/RE | 80.4 | [best.pt] [config.yaml] [metrics.jb] |
| ResNet50d | ImageNet+-+M*+R* | 80.2 | [best.pt] [config.yaml] [metrics.jb] |
| ResNet50d | ImageNet+-RRC-Small | 80.0 | [best.pt] [config.yaml] [metrics.jb] |
| ResNet50d | ImageNet+-+M*-Small | 80.6 | [best.pt] [config.yaml] [metrics.jb] |
| ResNet50d | ImageNet+-+RA/RE-Small | 80.2 | [best.pt] [config.yaml] [metrics.jb] |
| ResNet50d | ImageNet+-+M*+R*-Small | 80.1 | [best.pt] [config.yaml] [metrics.jb] |
| ResNet50d | ImageNet+-RRC-Mini | 80.1 | [best.pt] [config.yaml] [metrics.jb] |
| ResNet50d | ImageNet+-+M*-Mini | 80.5 | [best.pt] [config.yaml] [metrics.jb] |
| ResNet50d | ImageNet+-+RA/RE-Mini | 80.4 | [best.pt] [config.yaml] [metrics.jb] |
| ResNet50d | ImageNet+-+M*+R*-Mini | 80.2 | [best.pt] [config.yaml] [metrics.jb] |

Training

We provide YAML configurations for training ResNet-50 in CFG_FILE=configs/${DATASET}/${TRAINER}.yaml, with the following options:

  • DATASET: imagenet, cifar100, flowers102, and food101.
  • TRAINER: standard training (erm), knowledge distillation (kd), and with reinforced data (plus).
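The CFG_FILE pattern above can be assembled programmatically. The sketch below builds the path from the two options listed and rejects values outside those lists. The helper name config_path is hypothetical and not part of this repository, and the sketch does not check that a file exists for every dataset and trainer combination.

```python
from pathlib import Path

# Option values as listed in this README.
DATASETS = ("imagenet", "cifar100", "flowers102", "food101")
TRAINERS = ("erm", "kd", "plus")


def config_path(dataset, trainer, root="configs"):
    """Return the path configs/<dataset>/<trainer>.yaml after validating both options."""
    if dataset not in DATASETS:
        raise ValueError(f"unknown dataset {dataset!r}; expected one of {DATASETS}")
    if trainer not in TRAINERS:
        raise ValueError(f"unknown trainer {trainer!r}; expected one of {TRAINERS}")
    return Path(root) / dataset / f"{trainer}.yaml"


print(config_path("imagenet", "plus").as_posix())  # configs/imagenet/plus.yaml
```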

Follow the steps:

  • Choose the dataset and trainer from the choices above.
  • Download ImageNet data and set data_path in $CFG_FILE.
  • Download reinforcement metadata and set reinforce.data_path in $CFG_FILE.

   python train.py --config configs/imagenet/erm.yaml  # ImageNet training without reinforcements (ERM)
   python train.py --config configs/imagenet/kd.yaml  # Knowledge distillation
   python train.py --config configs/imagenet/plus.yaml  # ImageNet+ training with reinforcements
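The steps above name two settings, data_path and reinforce.data_path. Assuming the dotted name refers to a data_path key nested under a reinforce section of the YAML file (an assumption about the config layout, which this README does not show), the sketch below sets both values on a config that has already been loaded into a dictionary. The helper name and the placeholder paths are hypothetical.

```python
def set_by_dotted_key(config, dotted_key, value):
    """Set config['a']['b'] = value for dotted_key 'a.b', creating sections as needed."""
    *parents, leaf = dotted_key.split(".")
    node = config
    for name in parents:
        node = node.setdefault(name, {})
    node[leaf] = value
    return config


# Stand-in for a parsed config file; the real files may contain many more keys.
config = {"reinforce": {}}

set_by_dotted_key(config, "data_path", "/path/to/imagenet")  # placeholder path
set_by_dotted_key(config, "reinforce.data_path", "/path/to/rdata")  # placeholder path
print(config)
```

In practice the dictionary would come from a YAML parser (for example, PyYAML's yaml.safe_load) and be written back with yaml.safe_dump; editing the file by hand works equally well.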

Hyperparameters such as batch size for ImageNet training are optimized for running on a single node with 8xA100 40GB GPUs. For CIFAR-100/Flowers-102/Food-101, the configurations are optimized for training on a single GPU.

Reinforce ImageNet

Follow the steps:

  • Download ImageNet data and set data_path in $CFG_FILE.
  • If needed, change the teacher in $CFG_FILE to a smaller architecture.

   python reinforce.py --config configs/imagenet/reinforce/randaug.yaml

Reference

If you found this code useful, please cite the following paper:

@InProceedings{faghri2023reinforce,
    author    = {Faghri, Fartash and Pouransari, Hadi and Mehta, Sachin and Farajtabar, Mehrdad and Farhadi, Ali and Rastegari, Mohammad and Tuzel, Oncel},
    title     = {Reinforce Data, Multiply Impact: Improved Model Accuracy and Robustness with Dataset Reinforcement},
    booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
    month     = {October},
    year      = {2023},
}

License

This sample code is released under the LICENSE terms.