SimKD

June 16, 2022

Knowledge Distillation with the Reused Teacher Classifier (CVPR-2022) https://arxiv.org/abs/2203.14001

Toolbox for KD research

This repository aims to provide a compact and easy-to-use implementation of several representative knowledge distillation approaches on standard image classification tasks (e.g., CIFAR100, ImageNet).

  • Generally, these KD approaches combine a classification loss, a logit-level distillation loss, and an additional feature distillation loss. For fair comparison and ease of tuning, we fix the weights of the first two loss terms to 1 throughout all experiments (--cls 1 --div 1).

  • The approaches currently supported by this toolbox cover vanilla KD, feature-map and feature-embedding distillation, and instance-level and pairwise-level distillation. Those evaluated in the CIFAR-100 results below are KD, FitNet, AT, SP, VID, CRD, SRRL, SemCKD, and SimKD.

  • This toolbox is built on an open-source benchmark and our previous repository. Implementations of more KD approaches can be found there.

  • Computing Infrastructure:

    • CIFAR-100 experiments use one NVIDIA GeForce RTX 2080Ti GPU with PyTorch 1.0. ImageNet experiments use four NVIDIA A40 GPUs with PyTorch 1.10.
    • For ImageNet, we use DALI for data loading and pre-processing.
  • The current code has been reorganized and has not been tested thoroughly. If you have any questions, please feel free to contact us.

  • Please put the CIFAR-100 and ImageNet datasets in ../data/.
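The three-term objective described in the list above can be sketched in plain Python. This is only an illustration, not the toolbox's implementation: the function names, the `beta` weight for the feature term, the temperature of 4.0, and the use of mean squared error as the feature loss are all assumptions made for the example (each KD approach defines its own feature loss).

```python
import math

def softmax(logits, temperature=1.0):
    """Numerically stable softmax over a list of logits."""
    scaled = [z / temperature for z in logits]
    peak = max(scaled)
    exps = [math.exp(z - peak) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(student_logits, label):
    """Classification loss: negative log-probability of the true class."""
    return -math.log(softmax(student_logits)[label])

def kd_divergence(student_logits, teacher_logits, temperature=4.0):
    """Logit-level loss: KL(teacher || student) on softened outputs, scaled by T^2.

    The temperature value is an assumption for this sketch.
    """
    p_t = softmax(teacher_logits, temperature)
    p_s = softmax(student_logits, temperature)
    kl = sum(pt * math.log(pt / ps) for pt, ps in zip(p_t, p_s) if pt > 0)
    return kl * temperature ** 2

def feature_mse(student_feat, teacher_feat):
    """Placeholder feature loss (assumed here): mean squared error."""
    diffs = [(s - t) ** 2 for s, t in zip(student_feat, teacher_feat)]
    return sum(diffs) / len(diffs)

def total_loss(student_logits, teacher_logits, label,
               student_feat, teacher_feat, cls=1.0, div=1.0, beta=1.0):
    """Weighted sum of the classification, logit-level, and feature terms."""
    return (cls * cross_entropy(student_logits, label)
            + div * kd_divergence(student_logits, teacher_logits)
            + beta * feature_mse(student_feat, teacher_feat))

if __name__ == "__main__":
    loss = total_loss([2.0, 0.5, -1.0], [1.5, 0.2, -0.5], 0,
                      [0.1, 0.4], [0.0, 0.5])
    print(f"total loss: {loss:.4f}")
```

With `cls=0, div=0, beta=1`, only the feature term remains in this sketch, which is how one might read the `-c 0 -d 0 -b 1` settings in the SimKD commands, assuming those short flags are the three loss weights.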

Get the pretrained teacher models

# CIFAR-100
python train_teacher.py --batch_size 64 --epochs 240 --dataset cifar100 --model resnet32x4 --learning_rate 0.05 --lr_decay_epochs 150,180,210 --weight_decay 5e-4 --trial 0 --gpu_id 0

# ImageNet
python train_teacher.py --batch_size 256 --epochs 120 --dataset imagenet --model ResNet18 --learning_rate 0.1 --lr_decay_epochs 30,60,90 --weight_decay 1e-4 --num_workers 32 --gpu_id 0,1,2,3 --dist-url tcp://127.0.0.1:23333 --multiprocessing-distributed --dali gpu --trial 0 

The pretrained teacher models used in our paper are provided in this link [GoogleDrive].
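The `--lr_decay_epochs` values in the commands above suggest a step-wise learning rate schedule. A minimal sketch of how such a schedule could be evaluated follows. The decay factor of 0.1 and the convention that the rate drops at the start of each listed epoch are assumptions, since neither is stated here, and `step_lr` and `parse_epochs` are hypothetical helper names.

```python
def parse_epochs(text):
    """Turn a comma-separated string such as '150,180,210' into a list of ints."""
    return [int(e) for e in text.split(",")]

def step_lr(epoch, base_lr, decay_epochs, decay_rate=0.1):
    """Learning rate at `epoch`, multiplied by `decay_rate` once per milestone reached.

    decay_rate=0.1 and the `epoch >= milestone` convention are assumptions.
    """
    steps = sum(1 for milestone in decay_epochs if epoch >= milestone)
    return base_lr * decay_rate ** steps

if __name__ == "__main__":
    milestones = parse_epochs("150,180,210")  # CIFAR-100 setting from the command above
    for epoch in (0, 149, 150, 180, 210, 239):
        print(epoch, step_lr(epoch, 0.05, milestones))
```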

Train the student models with various KD approaches

# CIFAR-100
python train_student.py --path_t ./save/teachers/models/resnet32x4_vanilla/resnet32x4_best.pth --distill simkd --model_s resnet8x4 -c 0 -d 0 -b 1 --trial 0

# ImageNet
python train_student.py --path_t './save/teachers/models/ResNet50_vanilla/ResNet50_best.pth' --batch_size 256 --epochs 120 --dataset imagenet --model_s ResNet18 --distill simkd -c 0 -d 0 -b 1 --learning_rate 0.1 --lr_decay_epochs 30,60,90 --weight_decay 1e-4 --num_workers 32 --gpu_id 0,1,2,3 --dist-url tcp://127.0.0.1:23444 --multiprocessing-distributed --dali gpu --trial 0

More scripts are provided in ./scripts
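To repeat a run over several trials, the CIFAR-100 command above can be generated programmatically. The following is a small sketch rather than one of the provided scripts: `student_command` is a hypothetical helper, and it uses only the flags that appear in the command above.

```python
import shlex

def student_command(path_t, model_s, distill, trial, cls=0, div=0, beta=1):
    """Build a train_student.py command line as a shell-safe string."""
    args = [
        "python", "train_student.py",
        "--path_t", path_t,
        "--distill", distill,
        "--model_s", model_s,
        "-c", str(cls), "-d", str(div), "-b", str(beta),
        "--trial", str(trial),
    ]
    return shlex.join(args)

if __name__ == "__main__":
    teacher = "./save/teachers/models/resnet32x4_vanilla/resnet32x4_best.pth"
    # Print one command per trial; run them with a shell or subprocess as needed.
    for trial in range(3):
        print(student_command(teacher, "resnet8x4", "simkd", trial))
```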

Some results on CIFAR-100

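| Method | ResNet-8x4 | VGG-8 | ShuffleNetV2x1.5 |
| --- | --- | --- | --- |
| Student | 73.09 | 70.46 | 74.15 |
| KD | 74.42 | 72.73 | 76.82 |
| FitNet | 74.32 | 72.91 | 77.12 |
| AT | 75.07 | 71.90 | 77.51 |
| SP | 74.29 | 73.12 | 77.18 |
| VID | 74.55 | 73.19 | 77.11 |
| CRD | 75.59 | 73.54 | 77.66 |
| SRRL | 75.39 | 73.23 | 77.55 |
| SemCKD | 76.23 | 75.27 | 79.13 |
| SimKD (f=8) | 76.73 | 74.74 | 78.96 |
| SimKD (f=4) | 77.88 | 75.62 | 79.48 |
| SimKD (f=2) | 78.08 | 75.76 | 79.54 |
| Teacher (ResNet-32x4) | 79.42 | 79.42 | 79.42 |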
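As a quick reading aid for the numbers above, the following sketch computes how much SimKD (f=2) gains over the baseline student and how far it remains from the teacher, per student architecture. The values are copied from the results above; the variable and function names are illustrative only.

```python
students = ("ResNet-8x4", "VGG-8", "ShuffleNetV2x1.5")

# Accuracy values copied from the results above, one per student architecture.
results = {
    "Student": (73.09, 70.46, 74.15),
    "SimKD (f=2)": (78.08, 75.76, 79.54),
    "Teacher (ResNet-32x4)": (79.42, 79.42, 79.42),
}

def differences(row_a, row_b):
    """Per-architecture difference row_a - row_b, rounded to two decimals."""
    return [round(a - b, 2) for a, b in zip(results[row_a], results[row_b])]

gain_over_student = differences("SimKD (f=2)", "Student")
gap_to_teacher = differences("Teacher (ResNet-32x4)", "SimKD (f=2)")

if __name__ == "__main__":
    for name, gain, gap in zip(students, gain_over_student, gap_to_teacher):
        print(f"{name}: gain over student {gain:+.2f}, gap to teacher {gap:+.2f}")
```

The gains work out to 4.99, 5.30, and 5.39 points. The gap to the teacher is 1.34 and 3.66 points for the first two students and -0.12 for ShuffleNetV2x1.5, meaning that student slightly exceeds the teacher's reported accuracy.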

[Figure: result]

(Left) The cross-entropy loss between model predictions and test labels.
(Right) The top-1 test accuracy (%) (Student: ResNet-8x4, Teacher: ResNet-32x4).

Citation

If you find this repository useful, please consider citing the following papers:

@inproceedings{chen2022simkd,
  title={Knowledge Distillation with the Reused Teacher Classifier},
  author={Chen, Defang and Mei, Jian-Ping and Zhang, Hailin and Wang, Can and Feng, Yan and Chen, Chun},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  pages={11933--11942},
  year={2022}
}
@inproceedings{chen2021cross,
  author    = {Defang Chen and Jian{-}Ping Mei and Yuan Zhang and Can Wang and Zhe Wang and Yan Feng and Chun Chen},
  title     = {Cross-Layer Distillation with Semantic Calibration},
  booktitle = {Proceedings of the AAAI Conference on Artificial Intelligence},
  pages     = {7028--7036},
  year      = {2021},
}