SSLD 知识蒸馏实战

September 21, 2022 · View on GitHub

目录

1. 算法介绍

1.1 简介

PaddleClas 融合已有的知识蒸馏方法 [2,3],提供了一种简单的半监督标签知识蒸馏方案(SSLD,Simple Semi-supervised Label Distillation),基于 ImageNet1k 分类数据集,在 ResNet_vd 以及 MobileNet 系列上的精度均有超过 3% 的绝对精度提升,具体指标如下图所示。

1.2 SSLD蒸馏策略

SSLD 的流程图如下图所示。

首先,我们从 ImageNet22k 中挖掘出了近 400 万张图片,同时与 ImageNet-1k 训练集整合在一起,得到了一个新的包含 500 万张图片的数据集。然后,我们将学生模型与教师模型组合成一个新的网络,该网络分别输出学生模型和教师模型的预测分布,与此同时,固定教师模型整个网络的梯度,而学生模型可以做正常的反向传播。最后,我们将两个模型的 logits 经过 softmax 激活函数转换为 soft label,并将二者的 soft label 做 JS 散度作为损失函数,用于蒸馏模型训练。

以 MobileNetV3(该模型直接训练,精度为 75.3%)的知识蒸馏为例,该方案的核心策略优化点如下所示。

实验ID策略Top-1 acc
1baseline75.60%
2更换教师模型精度为82.4%的权重76.00%
3使用改进的JS散度损失函数76.20%
4迭代轮数增加至360epoch77.10%
5添加400W挖掘得到的无标注数据78.50%
6基于ImageNet1k数据微调78.90%
  • 注:其中baseline的训练条件为
    • 训练数据:ImageNet1k数据集
    • 损失函数:Cross Entropy Loss
    • 迭代轮数:120epoch

SSLD 蒸馏方案的一大特色就是无需使用图像的真值标签,因此可以任意扩展数据集的大小,考虑到计算资源的限制,我们在这里仅基于 ImageNet22k 数据集对蒸馏任务的训练集进行扩充。在 SSLD 蒸馏任务中,我们使用了 Top-k per class 的数据采样方案 [3] 。具体步骤如下。

(1)训练集去重。我们首先基于 SIFT 特征相似度匹配的方式对 ImageNet22k 数据集与 ImageNet1k 验证集进行去重,防止添加的 ImageNet22k 训练集中包含 ImageNet1k 验证集图像,最终去除了 4511 张相似图片。部分过滤的相似图片如下所示。

(2)大数据集 soft label 获取,对于去重后的 ImageNet22k 数据集,我们使用 ResNeXt101_32x16d_wsl 模型进行预测,得到每张图片的 soft label 。

(3)Top-k 数据选择,ImageNet1k 数据共有 1000 类,对于每一类,找出属于该类并且得分最高的 k 张图片,最终得到一个数据量不超过 1000*k 的数据集(某些类上得到的图片数量可能少于 k 张)。

(4)将该数据集与 ImageNet1k 的训练集融合组成最终蒸馏模型所使用的数据集,数据量为 500 万。

1.3 SKL-UGI蒸馏策略

此外,在无标注数据选择的过程中,我们发现使用更加通用的数据,即使不需要严格的数据筛选过程,也可以帮助知识蒸馏任务获得稳定的精度提升,因而提出了SKL-UGI (Symmetrical-KL Unlabeled General Images distillation)知识蒸馏方案。

通用数据可以使用ImageNet数据或者与场景相似的数据集。更多关于SKL-UGI的应用,请参考:超轻量图像分类方案PULC使用教程

2. 预训练模型库

移动端预训练模型库列表如下所示。

模型FLOPs(M)Params(M)top-1 accSSLD top-1 acc精度收益下载链接
PPLCNetV2_base604.166.5477.04%80.10%+3.06%链接
PPLCNet_x2_5906.499.0476.60%80.82%+4.22%链接
PPLCNet_x1_0160.812.9671.32%74.39%+3.07%链接
PPLCNet_x0_547.281.8963.14%66.10%+2.96%链接
PPLCNet_x0_2518.431.5251.86%53.43%+1.57%链接
MobileNetV1578.884.1971.00%77.90%+6.90%链接
MobileNetV2327.843.4472.20%76.74%+4.54%链接
MobileNetV3_large_x1_0229.665.4775.30%79.00%+3.70%链接
MobileNetV3_small_x1_063.672.9468.20%71.30%+3.10%链接
MobileNetV3_small_x0_3514.561.6653.00%55.60%+2.60%链接
GhostNet_x1_3_ssld236.897.3075.70%79.40%+3.70%链接
  • 注:其中的top-1 acc表示使用普通训练方式得到的模型精度,SSLD top-1 acc表示使用SSLD知识蒸馏训练策略得到的模型精度。

服务端预训练模型库列表如下所示。

模型FLOPs(G)Params(M)top-1 accSSLD top-1 acc精度收益下载链接
PPHGNet_base25.1471.62-85.00%-链接
PPHGNet_small8.5324.3881.50%83.80%+2.30%链接
PPHGNet_tiny4.5414.7579.83%81.95%+2.12%链接
ResNet50_vd8.6725.5879.10%83.00%+3.90%链接
ResNet101_vd16.144.5780.20%83.70%+3.50%链接
ResNet34_vd7.3921.8276.00%79.70%+3.70%链接
Res2Net50_vd_26w_4s8.3725.0679.80%83.10%+3.30%链接
Res2Net101_vd_26w_4s16.6745.2280.60%83.90%+3.30%链接
Res2Net200_vd_26w_4s31.4976.2181.20%85.10%+3.90%链接
HRNet_W18_C4.1421.2976.90%81.60%+4.70%链接
HRNet_W48_C34.5877.4779.00%83.60%+4.60%链接
SE_HRNet_W64_C57.83128.97-84.70%-链接

3. SSLD使用方法

3.1 加载SSLD模型进行微调

如果希望直接使用预训练模型,可以在训练的时候,加入参数-o Arch.pretrained=True -o Arch.use_ssld=True,表示使用基于SSLD的预训练模型,示例如下所示。

# 单机单卡训练
python3 tools/train.py -c ppcls/configs/ImageNet/ResNet/ResNet50_vd.yaml -o Arch.pretrained=True -o Arch.use_ssld=True
# 单机多卡训练
python3 -m paddle.distributed.launch --gpus="0,1,2,3" tools/train.py -c ppcls/configs/ImageNet/ResNet/ResNet50_vd.yaml -o Arch.pretrained=True -o Arch.use_ssld=True

3.2 使用SSLD方案进行知识蒸馏

相比于其他大多数知识蒸馏算法,SSLD摆脱对数据标注的依赖,通过引入无标注数据,可以进一步提升模型精度。

对于无标注数据,需要按照与有标注数据完全相同的整理方式,将文件与当前有标注的数据集放在相同目录下,将其标签值记为0,假设整理的标签文件名为train_list_unlabel.txt,则可以通过下面的命令生成用于SSLD训练的标签文件。

cat train_list.txt train_list_unlabel.txt > train_list_all.txt

更多关于图像分类任务的数据标签说明,请参考:PaddleClas图像分类数据集格式说明

PaddleClas中集成了PULC超轻量图像分类实用方案,里面包含SSLD ImageNet预训练模型的使用以及更加通用的无标签数据的知识蒸馏方案,更多详细信息,请参考PULC超轻量图像分类实用方案使用教程

4. 参考文献

[1] Hinton G, Vinyals O, Dean J. Distilling the knowledge in a neural network[J]. arXiv preprint arXiv:1503.02531, 2015.

[2] Bagherinezhad H, Horton M, Rastegari M, et al. Label refinery: Improving imagenet classification through label progression[J]. arXiv preprint arXiv:1805.02641, 2018.

[3] Yalniz I Z, Jégou H, Chen K, et al. Billion-scale semi-supervised learning for image classification[J]. arXiv preprint arXiv:1905.00546, 2019.

[4] Touvron H, Vedaldi A, Douze M, et al. Fixing the train-test resolution discrepancy[C]//Advances in Neural Information Processing Systems. 2019: 8250-8260.