[CVPR 2024] MaxQ: Multi-Axis Query for N:M Sparsity Network (Paper Link)
March 5, 2024
Jingyang Xiang, Siqi Li, Junhao Chen, Zhuangzhi Chen, Tianxin Huang, Linpeng Peng, Yong Liu
PyTorch implementation of MaxQ (CVPR 2024).
The main code is in ./models/conv_type/mullt_axis_query.py

Introduction
In this paper, we propose MaxQ, an efficient and effective Multi-Axis Query method that dynamically generates soft N:M masks by considering weight importance across multiple axes. In addition, MaxQ applies a sparsity strategy that gradually increases the percentage of N:M weight blocks, which lets the network progressively recover from pruning-induced damage.
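For readers new to N:M sparsity, the sketch below shows what a plain *hard* N:M mask looks like: in every group of M consecutive weights, only the N with the largest magnitude are kept. This is the generic magnitude-based definition, not MaxQ's multi-axis soft mask; the function name `nm_mask` and the use of flat Python lists are illustrative assumptions.

```python
from typing import List


def nm_mask(weights: List[float], n: int, m: int) -> List[int]:
    """Hypothetical helper: 0/1 mask keeping the n largest-magnitude weights
    in each consecutive group of m (plain magnitude-based N:M sparsity)."""
    if len(weights) % m != 0:
        raise ValueError("number of weights must be a multiple of M")
    mask = [0] * len(weights)
    for start in range(0, len(weights), m):
        group = range(start, start + m)
        # Rank the group's indices by |w|; the stable sort keeps earlier indices first on ties.
        keep = sorted(group, key=lambda i: abs(weights[i]), reverse=True)[:n]
        for i in keep:
            mask[i] = 1
    return mask


w = [0.9, -0.1, 0.05, -0.7, 0.2, 0.3, -0.4, 0.01]
print(nm_mask(w, 2, 4))
```

With 2:4 sparsity, every group of four weights ends up with exactly two non-zeros, which is the pattern sparse tensor-core hardware can accelerate.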
Prepare ImageNet1K
Create a data directory as the base for all datasets.
For example, if your base directory is /datadir/dataset, ImageNet should be located at /datadir/dataset/imagenet.
Place the training and validation data in /datadir/dataset/imagenet/train and /datadir/dataset/imagenet/val, respectively.
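Before launching a long run, it can help to confirm the layout described above. The helper below is a hypothetical sketch (the name `missing_imagenet_dirs` is not part of this repository); it only checks that the `train` and `val` folders exist under `<base>/imagenet` and does not validate their contents.

```python
from pathlib import Path
from typing import List


def missing_imagenet_dirs(base_dir: str) -> List[str]:
    """Hypothetical check: return the expected ImageNet split folders
    under <base_dir>/imagenet that do not exist (empty list means OK)."""
    root = Path(base_dir) / "imagenet"
    expected = [root / "train", root / "val"]
    return [str(p) for p in expected if not p.is_dir()]


# Prints an empty list if /datadir/dataset/imagenet/{train,val} both exist.
print(missing_imagenet_dirs("/datadir/dataset"))
```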
Training on ImageNet1K
All training scripts are generated by ./scripts/generate_scripts.py:
python ./scripts/generate_scripts.py
python pruning_train.py [DATA_PATH] --set ImageNet -a [ARCH] \
--no-bn-decay True --save_dir [SAVE_DIR] --warmup-length 0 --N 1 --M 16 --decay 0.0002 --conv-bn-type SoftMaxQConv2DBN \
--weight-decay 0.0001 --nesterov False --workers 16 --increase-start 0 --increase-end 90
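If you prefer to assemble the command programmatically, the sketch below rebuilds the training command above from its flags. `build_train_command` is a hypothetical helper, not the contents of ./scripts/generate_scripts.py; the flag values are copied from the command shown, and `extra` lets you append further flags such as the `--pretrained`/`--evaluate` pair used for testing.

```python
import shlex
from typing import List, Optional


def build_train_command(data_path: str, arch: str, save_dir: str,
                        n: int = 1, m: int = 16,
                        extra: Optional[List[str]] = None) -> str:
    """Hypothetical helper: return the pruning_train.py command as one shell string."""
    args = [
        "python", "pruning_train.py", data_path,
        "--set", "ImageNet", "-a", arch,
        "--no-bn-decay", "True", "--save_dir", save_dir,
        "--warmup-length", "0", "--N", str(n), "--M", str(m),
        "--decay", "0.0002", "--conv-bn-type", "SoftMaxQConv2DBN",
        "--weight-decay", "0.0001", "--nesterov", "False",
        "--workers", "16", "--increase-start", "0", "--increase-end", "90",
    ]
    args += list(extra or [])
    # shlex.join quotes any argument containing spaces or shell metacharacters.
    return shlex.join(args)


print(build_train_command("/datadir/dataset", "resnet50", "./runs/resnet50_1x16"))
```

Building the argument list first and joining it once avoids mistakes such as two flags running together without a separating space.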
Results on ImageNet1K
All models are available from the OpenI community. Many thanks to OpenI for providing the storage space!
| Model | N (N:M) | M (N:M) | Training epochs | DALI | Top-1 acc. (%) | Top-5 acc. (%) | Model & log |
|---|---|---|---|---|---|---|---|
| resnet34 | 1 | 4 | 120 | ✘ | 74.2 | 91.7 | link |
| resnet34 | 2 | 4 | 120 | ✘ | 74.5 | 92.1 | link |
| resnet50 | 1 | 4 | 120 | ✘ | 77.3 | 93.4 | link |
| resnet50 | 1 | 16 | 200 | ✘ | 75.2 | 92.6 | link |
| resnet50 | 2 | 4 | 120 | ✘ | 77.6 | 93.7 | link |
| resnet50 | 2 | 8 | 120 | ✘ | 77.2 | 93.5 | link |
| mobilenetv1 | 1 | 4 | 120 | ✘ | 70.4 | 89.5 | link |
| mobilenetv1 | 2 | 4 | 120 | ✘ | 72.3 | 90.8 | link |
| mobilenetv2 | 1 | 4 | 120 | ✘ | 67.0 | 87.5 | link |
| mobilenetv2 | 2 | 4 | 120 | ✘ | 69.8 | 89.3 | link |
| mobilenetv3_small | 1 | 4 | 120 | ✘ | 55.3 | 78.9 | link |
| mobilenetv3_small | 2 | 4 | 120 | ✘ | 60.8 | 82.9 | link |
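The N:M settings in the table translate directly into the fraction of zeroed weights within each sparsified layer: N of every M weights are kept, so sparsity is 1 − N/M. The sketch below works this out for the ratios used above; it describes the sparsified layers only, and the README does not state whether every layer (e.g. the first convolution or the classifier) is made N:M sparse.

```python
from fractions import Fraction


def nm_sparsity(n: int, m: int) -> Fraction:
    """Fraction of weights set to zero by N:M sparsity (N kept out of every M)."""
    return 1 - Fraction(n, m)


for n, m in [(2, 4), (1, 4), (2, 8), (1, 16)]:
    print(f"{n}:{m} -> {float(nm_sparsity(n, m)):.2%} sparse")
```

For example, 1:16 keeps one weight in sixteen, i.e. 93.75% sparsity, which helps explain why that ResNet-50 setting is trained for 200 epochs and still trails the 2:4 result.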
Testing on ImageNet1K
python pruning_train.py [DATA_PATH] --set ImageNet -a [ARCH] \
--no-bn-decay True --save_dir [SAVE_DIR] --warmup-length 0 --N 1 --M 16 --decay 0.0002 --conv-bn-type SoftMaxQConv2DBN \
--weight-decay 0.0001 --nesterov False --workers 16 --increase-start 0 --increase-end 90 \
--pretrained [PRETRAINED_PATH] --evaluate
Testing log for 2:4 ResNet50 on ImageNet1K
[2024-03-02 14:05:24] Test: [0/782] Time 6.638 (6.638) Loss 1.3484 (1.3484) Prec@1 93.750 (93.750) Prec@5 98.438 (98.438)
[2024-03-02 14:05:36] Test: [100/782] Time 0.118 (0.187) Loss 1.4476 (1.6217) Prec@1 90.625 (83.261) Prec@5 96.875 (95.699)
[2024-03-02 14:05:49] Test: [200/782] Time 0.116 (0.155) Loss 1.6416 (1.6144) Prec@1 85.938 (83.225) Prec@5 96.875 (96.183)
[2024-03-02 14:06:01] Test: [300/782] Time 0.128 (0.145) Loss 1.5182 (1.6168) Prec@1 84.375 (82.942) Prec@5 95.312 (96.288)
[2024-03-02 14:06:13] Test: [400/782] Time 0.117 (0.140) Loss 1.5314 (1.7098) Prec@1 81.250 (80.556) Prec@5 98.438 (95.242)
[2024-03-02 14:06:26] Test: [500/782] Time 0.118 (0.137) Loss 1.3581 (1.7604) Prec@1 90.625 (79.416) Prec@5 98.438 (94.633)
[2024-03-02 14:06:38] Test: [600/782] Time 0.131 (0.135) Loss 1.6157 (1.8020) Prec@1 87.500 (78.557) Prec@5 96.875 (94.169)
[2024-03-02 14:06:51] Test: [700/782] Time 0.117 (0.133) Loss 1.8406 (1.8357) Prec@1 79.688 (77.739) Prec@5 95.312 (93.741)
[2024-03-02 14:07:01] * Prec@1 77.576 Prec@5 93.708 Error@1 22.424
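In the log above, each `Test:` line shows the current batch value followed by the running average in parentheses, and the final `*` line reports the averages over the full validation set (77.576% Top-1, reported as 77.6 in the table). A minimal parser for that summary line is sketched below; it assumes the exact `* Prec@1 ... Prec@5 ... Error@1 ...` format shown here, and `parse_summary` is a hypothetical name.

```python
import re
from typing import Optional, Tuple

# Matches the final summary line, e.g. "* Prec@1 77.576 Prec@5 93.708 Error@1 22.424".
SUMMARY = re.compile(r"\* Prec@1 ([\d.]+) Prec@5 ([\d.]+) Error@1 ([\d.]+)")


def parse_summary(log_text: str) -> Optional[Tuple[float, float, float]]:
    """Return (top1, top5, error1) from the last summary line, or None if absent."""
    matches = SUMMARY.findall(log_text)
    if not matches:
        return None
    return tuple(float(x) for x in matches[-1])


line = "[2024-03-02 14:07:01] * Prec@1 77.576 Prec@5 93.708 Error@1 22.424"
print(parse_summary(line))
```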
Optional arguments
optional arguments:
# misc
--save_dir Path to save directory
# for model
--arch Choose model
default: resnet18
choice: ['resnet18', 'resnet34', 'resnet50', 'mobilenet_v1', 'mobilenet_v2', 'mobilenet_v3_small', 'mobilenet_v3_large']
--conv-bn-type Conv-BN layer type for the network
default: SoftMaxQConv2DBN
# for dataset
data Path to dataset
--set Choose dataset
default: ImageNet
choice: ["ImageNet", "ImageNetDali"]
# for pretrain, resume or evaluate
--evaluate Evaluate model on validation set
--pretrained Path to pretrained checkpoint
# N:M sparsity
--N N for N:M sparsity
default: 2
--M M for N:M sparsity
default: 4
--decay decay for SR-STE method
default: 0.0002
--decay-type decay type for conv type
default: v1
# MaxQ method
--increase-start Start epoch to increase ratio of N:M blocks
default: 0
--increase-end End epoch to increase ratio of N:M blocks
default: 90
--tau Tau for MaxQ method
default: 0.01
--prune-schedule Prune scheduler for incremental sparsity in MaxQ method
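The `--increase-start` and `--increase-end` options bound the epochs over which the ratio of N:M blocks grows. The sketch below illustrates one way such a ramp could behave; it is an assumption, not the repository's scheduler. The `"linear"` and `"cubic"` names are hypothetical and may not match the values accepted by `--prune-schedule`, and the cubic form is borrowed from common gradual-pruning schedules.

```python
def nm_block_ratio(epoch: float, start: int = 0, end: int = 90,
                   schedule: str = "linear") -> float:
    """Hypothetical ramp for the fraction of weight blocks under N:M sparsity:
    0 at or before `start`, 1 at or after `end`, shaped by `schedule` in between."""
    if end <= start:
        raise ValueError("end must be greater than start")
    # Normalized progress through the ramp, clamped to [0, 1].
    t = min(max((epoch - start) / (end - start), 0.0), 1.0)
    if schedule == "linear":
        return t
    if schedule == "cubic":
        # 1 - (1 - t)^3: adds blocks quickly early on, then flattens near `end`.
        return 1.0 - (1.0 - t) ** 3
    raise ValueError(f"unknown schedule: {schedule}")


for e in (0, 30, 45, 90, 120):
    print(e, round(nm_block_ratio(e), 3), round(nm_block_ratio(e, schedule="cubic"), 3))
```

With the defaults (0 to 90), the model is fully N:M sparse for the remaining epochs of a 120-epoch run, giving the network time to recover before training ends.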
Dependencies
- Python 3.9.16
- PyTorch 2.0.0
- Torchvision 0.15.1
- nvidia-dali-nightly-cuda110 1.27.0.dev20230531
- nvidia-dali-tf-plugin-nightly-cuda110 1.27.0.dev20230531
THANKS
Special thanks to the authors and contributors of the following projects: