Knowledge Distillation from A Stronger Teacher (DIST)

December 28, 2022

Official implementation of the paper "Knowledge Distillation from A Stronger Teacher" (DIST), NeurIPS 2022.
By Tao Huang, Shan You, Fei Wang, Chen Qian, Chang Xu.

DIST: a simple and effective KD method.

Updates

  • December 27, 2022: Update CIFAR-100 distillation code and logs.

  • September 20, 2022: Release code for semantic segmentation task.

  • September 15, 2022: DIST was accepted by NeurIPS 2022!

  • May 30, 2022: Code for object detection is available.

  • May 27, 2022: Code for ImageNet classification is available.

Getting started

Clone training code

git clone https://github.com/hunto/DIST_KD.git --recurse-submodules
cd DIST_KD

The loss function of DIST is in classification/lib/models/losses/dist_kd.py.
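
For orientation, the objective described in the DIST paper can be summarized as below. This is a sketch of the paper's formulation, not a transcription of dist_kd.py. The notation is assumed here: B is the batch size, C the number of classes, Y^s and Y^t the student and teacher softmax outputs (B x C matrices), and alpha, beta, gamma the loss weights. Details such as temperature scaling are omitted.

```latex
\begin{aligned}
d_p(u, v) &= 1 - \rho_p(u, v), \qquad
\rho_p(u, v) = \frac{\operatorname{Cov}(u, v)}{\operatorname{Std}(u)\,\operatorname{Std}(v)} \\
\mathcal{L}_{\text{inter}} &= \frac{1}{B} \sum_{i=1}^{B} d_p\big(Y^{s}_{i,:},\, Y^{t}_{i,:}\big) \\
\mathcal{L}_{\text{intra}} &= \frac{1}{C} \sum_{j=1}^{C} d_p\big(Y^{s}_{:,j},\, Y^{t}_{:,j}\big) \\
\mathcal{L} &= \alpha\,\mathcal{L}_{\text{cls}} + \beta\,\mathcal{L}_{\text{inter}} + \gamma\,\mathcal{L}_{\text{intra}}
\end{aligned}
```

Because the Pearson correlation is invariant to shifting and positive scaling, the student only needs to match the teacher's relative preferences across classes (rows) and across samples (columns), not its exact probabilities.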

  • classification: prepare your environment and datasets following the README.md in classification.
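  • object detection: the training code is in MasKD/mmrazor; see the COCO Detection section below.
  • semantic segmentation: follow the README in the segmentation folder; see the Cityscapes Segmentation section below.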

Reproducing our results

ImageNet

cd classification
sh tools/dist_train.sh 8 ${CONFIG} ${MODEL} --teacher-model ${T_MODEL} --experiment ${EXP_NAME}
  • Baseline settings (R34-R18 and R50-MBV1):

    CONFIG=configs/strategies/distill/resnet_dist.yaml
    
    | Student | Teacher | DIST | MODEL | T_MODEL | Log | Ckpt |
    |---|---|---|---|---|---|---|
    | ResNet-18 (69.76) | ResNet-34 (73.31) | 72.07 | tv_resnet18 | tv_resnet34 | log | ckpt |
    | MobileNet V1 (70.13) | ResNet-50 (76.16) | 73.24 | mobilenet_v1 | tv_resnet50 | log | ckpt |
  • Stronger teachers (R18 and R34 students with various ResNet teachers):

    | Student | Teacher | KD (T=4) | DIST |
    |---|---|---|---|
    | ResNet-18 (69.76) | ResNet-34 (73.31) | 71.21 | 72.07 |
    | ResNet-18 (69.76) | ResNet-50 (76.13) | 71.35 | 72.12 |
    | ResNet-18 (69.76) | ResNet-101 (77.37) | 71.09 | 72.08 |
    | ResNet-18 (69.76) | ResNet-152 (78.31) | 71.12 | 72.24 |
    | ResNet-34 (73.31) | ResNet-50 (76.13) | 74.73 | 75.06 |
    | ResNet-34 (73.31) | ResNet-101 (77.37) | 74.89 | 75.36 |
    | ResNet-34 (73.31) | ResNet-152 (78.31) | 74.87 | 75.42 |
  • Stronger training strategies:

    CONFIG=configs/strategies/distill/dist_b2.yaml
    

    ResNet-50-SB: a stronger ResNet-50 trained with TIMM (ResNet strikes back).

    | Student | Teacher | KD (T=4) | DIST | MODEL | T_MODEL | Log |
    |---|---|---|---|---|---|---|
    | ResNet-18 (73.4) | ResNet-50-SB (80.1) | 72.6 | 74.5 | tv_resnet18 | timm_resnet50 | log |
    | ResNet-34 (76.8) | ResNet-50-SB (80.1) | 77.2 | 77.8 | tv_resnet34 | timm_resnet50 | log |
    | MobileNet V2 (73.6) | ResNet-50-SB (80.1) | 71.7 | 74.4 | tv_mobilenet_v2 | timm_resnet50 | log |
    | EfficientNet-B0 (78.0) | ResNet-50-SB (80.1) | 77.4 | 78.6 | timm_tf_efficientnet_b0 | timm_resnet50 | log |
    | ResNet-50 (78.5) | Swin-L (86.3) | 80.0 | 80.2 | tv_resnet50 | timm_swin_large_patch4_window7_224 | log, ckpt |
    | Swin-T (81.3) | Swin-L (86.3) | 81.5 | 82.3 | - | - | log |
    • Swin-T student (Swin-L teacher): DIST is implemented on top of the official Swin-Transformer code.
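
As a worked example of the ImageNet command template, the sketch below fills in the variables for the ResNet-18 student with a ResNet-34 teacher, using the CONFIG, MODEL, and T_MODEL values listed for the baseline settings. The experiment name is a hypothetical placeholder, and the command is only printed so the substitution can be checked before launching an 8-GPU run.

```shell
# Run from the classification directory (cd classification).
# Values taken from the ResNet-18 / ResNet-34 baseline row.
CONFIG=configs/strategies/distill/resnet_dist.yaml
MODEL=tv_resnet18
T_MODEL=tv_resnet34
EXP_NAME=dist_r18_r34   # hypothetical experiment name; choose your own

# Assemble the training command and print it as a dry run.
CMD="sh tools/dist_train.sh 8 ${CONFIG} ${MODEL} --teacher-model ${T_MODEL} --experiment ${EXP_NAME}"
echo "${CMD}"
```

Once the printed command looks right, run it directly instead of echoing it.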

CIFAR-100

Download and extract the teacher checkpoints to your disk, then specify the path of the corresponding checkpoint pth file using --teacher-ckpt:

cd classification
sh tools/dist_train.sh 1 configs/strategies/distill/dist_cifar.yaml ${MODEL} --teacher-model ${T_MODEL} --experiment ${EXP_NAME} --teacher-ckpt ${CKPT}

NOTE: For MobileNetV2, ShuffleNetV1, and ShuffleNetV2 students, set both --lr and --warmup-lr to 0.01:

sh tools/dist_train.sh 1 configs/strategies/distill/dist_cifar.yaml ${MODEL} --teacher-model ${T_MODEL} --experiment ${EXP_NAME} --teacher-ckpt ${CKPT} --lr 0.01 --warmup-lr 0.01

| Student | Teacher | DIST | MODEL | T_MODEL | Log |
|---|---|---|---|---|---|
| WRN-40-1 (71.98) | WRN-40-2 (75.61) | 74.43±0.24 | cifar_wrn_40_1 | cifar_wrn_40_2 | log |
| ResNet-20 (69.06) | ResNet-56 (72.34) | 71.75±0.30 | cifar_resnet20 | cifar_resnet56 | log |
| ResNet-8x4 (72.50) | ResNet-32x4 (79.42) | 76.31±0.19 | cifar_resnet8x4 | cifar_resnet32x4 | log |
| MobileNetV2 (64.60) | ResNet-50 (79.34) | 68.66±0.23 | cifar_mobile_half | cifar_ResNet50 | log |
| ShuffleNetV1 (70.50) | ResNet-32x4 (79.42) | 76.34±0.18 | cifar_ShuffleV1 | cifar_resnet32x4 | log |
| ShuffleNetV2 (71.82) | ResNet-32x4 (79.42) | 77.35±0.25 | cifar_ShuffleV2 | cifar_resnet32x4 | log |
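
The CIFAR-100 commands above differ only in the learning-rate flags, so a small wrapper can pick them from the student name. The following is a minimal sketch, assuming the MODEL names listed in the results for this section; the checkpoint path and experiment name are hypothetical placeholders, and the command is printed, not executed.

```shell
# Run from the classification directory (cd classification).
MODEL=cifar_ShuffleV1
T_MODEL=cifar_resnet32x4
CKPT=/path/to/cifar_resnet32x4.pth   # hypothetical path to the extracted teacher checkpoint
EXP_NAME=dist_cifar_shufflev1        # hypothetical experiment name

# MobileNetV2 (cifar_mobile_half), ShuffleNetV1, and ShuffleNetV2 students
# use lr = warmup-lr = 0.01, following the NOTE above; other students keep
# the defaults from the config file.
case "${MODEL}" in
  cifar_mobile_half|cifar_ShuffleV1|cifar_ShuffleV2)
    EXTRA="--lr 0.01 --warmup-lr 0.01" ;;
  *)
    EXTRA="" ;;
esac

# Assemble the training command and print it as a dry run.
CMD="sh tools/dist_train.sh 1 configs/strategies/distill/dist_cifar.yaml ${MODEL} --teacher-model ${T_MODEL} --experiment ${EXP_NAME} --teacher-ckpt ${CKPT} ${EXTRA}"
echo "${CMD}"
```

Changing MODEL, T_MODEL, and CKPT to another student and teacher pair reuses the same logic; run the printed command directly once it looks right.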

COCO Detection

The training code is in MasKD/mmrazor. For example, to train cascade_mask_rcnn_x101-fpn_r50:

sh tools/mmdet/dist_train_mmdet.sh configs/distill/dist/dist_cascade_mask_rcnn_x101-fpn_x50_coco.py 8 work_dirs/dist_cmr_x101-fpn_x50

| Student | Teacher | DIST | DIST+mimic | Config | Log |
|---|---|---|---|---|---|
| Faster RCNN-R50 (38.4) | Cascade Mask RCNN-X101 (45.6) | 40.4 | 41.8 | [DIST] [DIST+Mimic] | [DIST] [DIST+Mimic] |
| RetinaNet-R50 (37.4) | RetinaNet-X101 (41.0) | 39.8 | 40.1 | [DIST] [DIST+Mimic] | [DIST] [DIST+Mimic] |

Cityscapes Segmentation

Detailed instructions for reproducing our results are in the README of the segmentation folder.

| Student | Teacher | DIST | Log |
|---|---|---|---|
| DeepLabV3-R18 (74.21) | DeepLabV3-R101 (78.07) | 77.10 | log |
| PSPNet-R18 (72.55) | DeepLabV3-R101 (78.07) | 76.31 | log |

License

This project is released under the Apache 2.0 license.

Citation

@article{huang2022knowledge,
  title={Knowledge Distillation from A Stronger Teacher},
  author={Huang, Tao and You, Shan and Wang, Fei and Qian, Chen and Xu, Chang},
  journal={arXiv preprint arXiv:2205.10536},
  year={2022}
}