[CVPR 2024] MaxQ: Multi-Axis Query for N:M Sparsity Network (Paper Link)

March 5, 2024

Jingyang Xiang, Siqi Li, Junhao Chen, Zhuangzhi Chen, Tianxin Huang, Linpeng Peng, Yong Liu

PyTorch implementation of MaxQ (CVPR 2024).

The main code is in ./models/conv_type/mullt_axis_query.py

Introduction

In this paper, we propose an efficient and effective Multi-Axis Query methodology, dubbed MaxQ, which dynamically generates soft N:M masks that account for weight importance across multiple axes. In addition, a sparsity strategy that gradually increases the percentage of N:M weight blocks lets the network recover progressively from pruning-induced damage.
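For readers new to N:M sparsity, the hard constraint that MaxQ's soft masks relax can be illustrated with a short example. This is a minimal sketch in plain Python, not the repository's implementation: the function name `nm_mask` is hypothetical, and it ranks weights by magnitude along a single axis only, whereas MaxQ derives soft masks from importance along multiple axes.

```python
def nm_mask(weights, n, m):
    """Return a 0/1 mask keeping the n largest-magnitude values in each group of m.

    Hypothetical helper for illustration; `weights` is a flat list whose
    length is assumed to be a multiple of m.
    """
    if len(weights) % m != 0:
        raise ValueError("len(weights) must be a multiple of m")
    mask = []
    for start in range(0, len(weights), m):
        group = weights[start:start + m]
        # Indices of the n entries with the largest absolute value in this group.
        keep = sorted(range(m), key=lambda i: abs(group[i]), reverse=True)[:n]
        mask.extend(1 if i in keep else 0 for i in range(m))
    return mask


weights = [0.9, -0.1, 0.4, -0.7, 0.05, 0.3, -0.2, 0.6]
mask = nm_mask(weights, n=2, m=4)
pruned = [w * k for w, k in zip(weights, mask)]
print(mask)  # [1, 0, 0, 1, 0, 1, 0, 1]
```

With N = 2 and M = 4, exactly two of every four consecutive weights survive, which is the 2:4 pattern listed in the results.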

Prepare ImageNet1K

Create a data directory as a base for all datasets. For example, if your base directory is /datadir/dataset, then ImageNet would be located at /datadir/dataset/imagenet. Place the training data and validation data in /datadir/dataset/imagenet/train and /datadir/dataset/imagenet/val, respectively.
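Before launching a long run, it can help to confirm that the directories exist. The following is a small sketch under the layout described above; `missing_imagenet_dirs` is a hypothetical helper and is not part of this repository.

```python
from pathlib import Path


def missing_imagenet_dirs(base_dir):
    """Return the expected ImageNet sub-directories that do not exist under base_dir.

    Hypothetical helper; the layout mirrors the paths described above.
    """
    root = Path(base_dir) / "imagenet"
    expected = [root / "train", root / "val"]
    return [str(p) for p in expected if not p.is_dir()]


if __name__ == "__main__":
    missing = missing_imagenet_dirs("/datadir/dataset")
    if missing:
        print("Missing directories:", ", ".join(missing))
    else:
        print("ImageNet layout looks complete.")
```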

Training on ImageNet1K

All training scripts can be generated with ./scripts/generate_scripts.py

python ./scripts/generate_scripts.py
python pruning_train.py [DATA_PATH] --set ImageNet -a [ARCH] \
--no-bn-decay True --save_dir [SAVE_DIR] --warmup-length 0 --N 1 --M 16 --decay 0.0002 --conv-bn-type SoftMaxQConv2DBN \
--weight-decay 0.0001 --nesterov False --workers 16 --increase-start 0 --increase-end 90
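When launching runs from Python, passing the arguments as a list sidesteps shell quoting and spacing mistakes when the placeholders are substituted. The sketch below only reuses flags that appear in the command above; `build_train_command` is a hypothetical helper, and the data path, architecture, and save directory in the example call are placeholder values.

```python
import shlex


def build_train_command(data_path, arch, save_dir, n=1, m=16):
    """Assemble the training command as an argument list (hypothetical helper).

    Only flags that appear in the command above are used.
    """
    return [
        "python", "pruning_train.py", data_path,
        "--set", "ImageNet", "-a", arch,
        "--no-bn-decay", "True", "--save_dir", save_dir,
        "--warmup-length", "0", "--N", str(n), "--M", str(m),
        "--decay", "0.0002", "--conv-bn-type", "SoftMaxQConv2DBN",
        "--weight-decay", "0.0001", "--nesterov", "False",
        "--workers", "16", "--increase-start", "0", "--increase-end", "90",
    ]


# Placeholder values; substitute your own paths and architecture.
cmd = build_train_command("/datadir/dataset/imagenet", "resnet50", "./runs/resnet50_1x16")
print(shlex.join(cmd))
# To launch: import subprocess; subprocess.run(cmd, check=True)
```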

Results on ImageNet1K

All models are available from the OpenI community. Many thanks to OpenI for the storage space!

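| name | N for N:M Sparsity | M for N:M Sparsity | training epochs | use dali | Top-1 Accuracy (%) | Top-5 Accuracy (%) | model & log |
|---|---|---|---|---|---|---|---|
| resnet34 | 1 | 4 | 120 | | 74.2 | 91.7 | link |
| resnet34 | 2 | 4 | 120 | | 74.5 | 92.1 | link |
| resnet50 | 1 | 4 | 120 | | 77.3 | 93.4 | link |
| resnet50 | 1 | 16 | 200 | | 75.2 | 92.6 | link |
| resnet50 | 2 | 4 | 120 | | 77.6 | 93.7 | link |
| resnet50 | 2 | 8 | 120 | | 77.2 | 93.5 | link |
| mobilenetv1 | 1 | 4 | 120 | | 70.4 | 89.5 | link |
| mobilenetv1 | 2 | 4 | 120 | | 72.3 | 90.8 | link |
| mobilenetv2 | 1 | 4 | 120 | | 67.0 | 87.5 | link |
| mobilenetv2 | 2 | 4 | 120 | | 69.8 | 89.3 | link |
| mobilenetv3_small | 1 | 4 | 120 | | 55.3 | 78.9 | link |
| mobilenetv3_small | 2 | 4 | 120 | | 60.8 | 82.9 | link |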

Testing on ImageNet1K

python pruning_train.py [DATA_PATH] --set ImageNet -a [ARCH] \
--no-bn-decay True --save_dir [SAVE_DIR] --warmup-length 0 --N 1 --M 16 --decay 0.0002 --conv-bn-type SoftMaxQConv2DBN \
--weight-decay 0.0001 --nesterov False --workers 16 --increase-start 0 --increase-end 90 \
--pretrained [PRETRAINED_PATH] --evaluate

Testing log for 2:4 ResNet50 on ImageNet1K

[2024-03-02 14:05:24] Test: [0/782]	Time 6.638 (6.638)	Loss 1.3484 (1.3484)	Prec@1 93.750 (93.750)	Prec@5 98.438 (98.438)
[2024-03-02 14:05:36] Test: [100/782]	Time 0.118 (0.187)	Loss 1.4476 (1.6217)	Prec@1 90.625 (83.261)	Prec@5 96.875 (95.699)
[2024-03-02 14:05:49] Test: [200/782]	Time 0.116 (0.155)	Loss 1.6416 (1.6144)	Prec@1 85.938 (83.225)	Prec@5 96.875 (96.183)
[2024-03-02 14:06:01] Test: [300/782]	Time 0.128 (0.145)	Loss 1.5182 (1.6168)	Prec@1 84.375 (82.942)	Prec@5 95.312 (96.288)
[2024-03-02 14:06:13] Test: [400/782]	Time 0.117 (0.140)	Loss 1.5314 (1.7098)	Prec@1 81.250 (80.556)	Prec@5 98.438 (95.242)
[2024-03-02 14:06:26] Test: [500/782]	Time 0.118 (0.137)	Loss 1.3581 (1.7604)	Prec@1 90.625 (79.416)	Prec@5 98.438 (94.633)
[2024-03-02 14:06:38] Test: [600/782]	Time 0.131 (0.135)	Loss 1.6157 (1.8020)	Prec@1 87.500 (78.557)	Prec@5 96.875 (94.169)
[2024-03-02 14:06:51] Test: [700/782]	Time 0.117 (0.133)	Loss 1.8406 (1.8357)	Prec@1 79.688 (77.739)	Prec@5 95.312 (93.741)
[2024-03-02 14:07:01]  * Prec@1 77.576 Prec@5 93.708 Error@1 22.424
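To compare runs, the final accuracy can be pulled out of a log like the one above with a regular expression. This is a sketch that assumes the summary line keeps the `* Prec@1 ... Prec@5 ...` format shown; `parse_final_accuracy` is a hypothetical helper, not a function from this repository.

```python
import re

# Matches the summary line, e.g. " * Prec@1 77.576 Prec@5 93.708 Error@1 22.424"
SUMMARY_RE = re.compile(r"\*\s+Prec@1\s+([\d.]+)\s+Prec@5\s+([\d.]+)")


def parse_final_accuracy(log_text):
    """Return (top1, top5) from the last summary line, or None if absent."""
    matches = SUMMARY_RE.findall(log_text)
    if not matches:
        return None
    top1, top5 = matches[-1]
    return float(top1), float(top5)


log = "[2024-03-02 14:07:01]  * Prec@1 77.576 Prec@5 93.708 Error@1 22.424"
print(parse_final_accuracy(log))  # (77.576, 93.708)
```

The per-batch `Test: [...]` lines are skipped because they lack the leading `*` that marks the summary.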

Optional arguments

optional arguments:
    # misc
    --save_dir                  Path to save directory
    
    # for model                         
    --arch                      Choose model
                                default: resnet18
                                choice: ['resnet18', 'resnet34', 'resnet50', 'mobilenet_v1', 'mobilenet_v2', 'mobilenet_v3_small', 'mobilenet_v3_large']
    --conv-bn-type              convbn type for network
                                default: SoftMaxQConv2DBN
      
    # for dataset
    data                        Path to dataset
    --set                       Choose dataset
                                default: ImageNet
                                choice: ["ImageNet", "ImageNetDali"]   
                                                 
    # for pretrain, resume or evaluate
    --evaluate                  Evaluate model on validation set
    --pretrained                Path to pretrained checkpoint

    # N:M sparsity                            
    --N                         N for N:M sparsity 
                                default: 2  
    --M                         M for N:M sparsity 
                                default: 4
    --decay                     decay for SR-STE method
                                default: 0.0002
    --decay-type                decay type for conv type
                                default: v1

    # MaxQ method                            
    --increase-start            Start epoch to increase ratio of N:M blocks
                                default: 0
    --increase-end              End epoch to increase ratio of N:M blocks
                                default: 90
    --tau                       Tau for MaxQ method
                                default: 0.01
    --prune-schedule            Prune scheduler for incremental sparsity in MaxQ method
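The `--increase-start` and `--increase-end` options bound the epochs over which the ratio of N:M blocks grows. The sketch below illustrates the idea with a linear ramp, which is an assumption made for illustration only: the actual curve is selected by `--prune-schedule` and may differ, and `nm_block_ratio` is a hypothetical name.

```python
def nm_block_ratio(epoch, increase_start=0, increase_end=90):
    """Fraction of weight blocks under the N:M constraint at a given epoch.

    Assumption: a linear ramp from 0 to 1 between increase_start and
    increase_end; the real curve depends on --prune-schedule.
    """
    if increase_end <= increase_start:
        # Degenerate range: treat every block as N:M from the start.
        return 1.0
    progress = (epoch - increase_start) / (increase_end - increase_start)
    return min(1.0, max(0.0, progress))


for epoch in (0, 45, 90, 120):
    print(epoch, nm_block_ratio(epoch))
```

Under this assumed ramp and the default values, half of the blocks are N:M sparse at epoch 45 and all of them from epoch 90 onward.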

Dependencies

  • Python 3.9.16
  • PyTorch 2.0.0
  • Torchvision 0.15.1
  • nvidia-dali-nightly-cuda110 1.27.0.dev20230531
  • nvidia-dali-tf-plugin-nightly-cuda110 1.27.0.dev20230531

THANKS

Special thanks to the authors and contributors of the following projects: