NTK-SAP: Improving neural network pruning by aligning training dynamics
May 1, 2023
Yite Wang, Dawei Li, Ruoyu Sun
In ICLR 2023.
Overview
This is the PyTorch implementation of NTK-SAP: Improving neural network pruning by aligning training dynamics.
Installation
To run our code, first install all dependencies:
pip install -r requirements.txt
Running
Below is a description of the major sections of the code base. Run python main.py --help for a complete description of flags and hyperparameters.
1. Prepare the datasets
MNIST, CIFAR-10, CIFAR-100, and Tiny ImageNet are downloaded automatically. For the ImageNet experiments, please download the dataset to Data/imagenet_raw/, or change the corresponding path in Utils/load.py.
2. Run foresight pruning experiments
Note that the ImageNet experiments require pruning and training to be run separately; see the --experiment argument. For all other experiments, models are trained right after pruning. A few important arguments:
--experiment: For CIFAR-10, CIFAR-100, and Tiny-ImageNet experiments, you can use either singleshot or multishot. For ImageNet experiments, please use multishot_ddp_prune to get the mask, then train with multishot_ddp_train.
--dataset: Which dataset to use. To reproduce our results, use cifar10, cifar100, tiny-imagenet, or imagenet.
--model-class: For CIFAR-10 and CIFAR-100 experiments, please use lottery. For Tiny-ImageNet and ImageNet experiments, please use imagenet.
--model: Which model architecture to use. In our experiments, we use resnet20, vgg16-bn, resnet18, and resnet50.
--pruner: Which pruning algorithm to use. Choose from rand, mag, snip, grasp, synflow, itersnip, NTKSAP.
--prune-batch-size: Batch size for the pruning dataset.
--compression: Controls the sparsity for singleshot experiments. Specifically, the target density is $0.8^{\text{compression}}$. For multishot experiments, please refer to --compression-list.
--prune-train-mode: Set this to True for all pruning algorithms except Synflow.
--prune-epochs: Number of pruning iterations.
--ntksap_R: Number of resampling procedures. Only change this for CIFAR-10 experiments.
--ntk_epsilon: Perturbation hyperparameter used in NTK-SAP.
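To make the --compression setting concrete, the following minimal sketch evaluates the density rule stated above (density = 0.8 raised to the compression value) and inverts it. The helper names target_density and compression_for_density are hypothetical and are not part of this code base.

```python
import math

# Base of the density rule quoted in the README: density = 0.8 ** compression.
DENSITY_BASE = 0.8


def target_density(compression: float) -> float:
    """Fraction of weights kept for a given --compression value."""
    return DENSITY_BASE ** compression


def compression_for_density(density: float) -> float:
    """Invert the rule: the compression value that yields a desired density."""
    if not 0.0 < density <= 1.0:
        raise ValueError("density must be in (0, 1]")
    return math.log(density) / math.log(DENSITY_BASE)


if __name__ == "__main__":
    # Illustrative compression values only; not the settings used in the paper.
    for c in (1, 5, 10, 20):
        d = target_density(c)
        print(f"compression={c:>2}  density={d:.4f}  sparsity={1 - d:.4f}")
```

Because the rule is exponential, each additional unit of compression removes another 20% of the remaining weights, so sparsity grows quickly at first and then flattens out.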
A sample script can be found in scripts/run.sh.
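As a rough orientation, the sketch below assembles one possible single-shot command line from the arguments listed above. The flag names and choices come from that list, but this particular combination, the compression value of 5, and the way True is passed to --prune-train-mode are assumptions for illustration; scripts/run.sh and python main.py --help remain the authoritative references. The helper build_argv is hypothetical.

```python
import shlex

# Flag names and allowed choices are taken from the argument list above.
# The combination and the numeric value are illustrative assumptions.
flags = {
    "--experiment": "singleshot",
    "--dataset": "cifar10",
    "--model-class": "lottery",
    "--model": "resnet20",
    "--pruner": "NTKSAP",
    "--compression": 5,          # hypothetical value; density would be 0.8 ** 5
    "--prune-train-mode": True,  # assumed to be passed as a literal "True"
}


def build_argv(flag_values: dict) -> list:
    """Turn a {flag: value} mapping into an argv list for main.py."""
    argv = ["python", "main.py"]
    for name, value in flag_values.items():
        argv += [name, str(value)]
    return argv


argv = build_argv(flags)
# shlex.join quotes each argument safely, giving a copy-pasteable shell command.
command = shlex.join(argv)
print(command)
```

Building the command as a list first keeps each flag and its value as separate arguments, so the same list could also be handed to subprocess.run without going through a shell.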
Acknowledgement
Our code is developed based on the Synflow code: https://github.com/ganguli-lab/Synaptic-Flow.