Momentum-SAM: Sharpness Aware Minimization without Computational Overhead

January 23, 2024

Official implementation of “Momentum-SAM: Sharpness Aware Minimization without Computational Overhead”.

How to Use

Import the optimizer into your code

Simply import the optimizer into your code:

from optimizer.msam import MSAM
from optimizer.adamW_msam import AdamW_MSAM

and use it as a drop-in replacement for SGD or AdamW, respectively. If you do not decay ρ during training, call optimizer.move_back_from_momentumAscent() at the end of training to recover the unperturbed parameters (see main.py).
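The note about move_back_from_momentumAscent() suggests that MSAM stores the model's parameters at a perturbed point between steps, so the final weights stay shifted unless that shift is removed. The sketch below illustrates this bookkeeping on a toy 2-D quadratic in pure Python. It is a hypothetical illustration under stated assumptions, not the code in optimizer/msam.py: the class ToyMomentumSAM, its move_back() method, the toy loss and all hyperparameters are made up for this example. It assumes PyTorch-style momentum (the buffer accumulates gradients) and a perturbation of length ρ along the normalized momentum buffer, so each step needs only one gradient, evaluated at the perturbed point.

```python
import math


def toy_grad(w):
    # Gradient of the toy loss f(w) = 0.5 * (w[0]**2 + 10 * w[1]**2).
    return [w[0], 10.0 * w[1]]


def toy_loss(w):
    return 0.5 * (w[0] ** 2 + 10.0 * w[1] ** 2)


class ToyMomentumSAM:
    """Hypothetical pure-Python sketch, NOT the repository's MSAM class.

    Assumption: the momentum buffer v accumulates gradients (PyTorch SGD style),
    and the stored parameters are kept at w + rho * v / ||v||.
    """

    def __init__(self, w, lr=0.05, momentum=0.9, rho=0.1):
        self.w = list(w)              # parameters as seen by the model (perturbed)
        self.lr, self.mu, self.rho = lr, momentum, rho
        self.v = [0.0] * len(w)       # momentum buffer
        self.offset = [0.0] * len(w)  # perturbation currently applied to self.w

    def step(self, g):
        # g is the gradient evaluated at the perturbed parameters self.w.
        # 1. Undo the previous perturbation to get the unperturbed weights.
        w = [wi - oi for wi, oi in zip(self.w, self.offset)]
        # 2. Ordinary momentum-SGD update of the unperturbed weights.
        self.v = [self.mu * vi + gi for vi, gi in zip(self.v, g)]
        w = [wi - self.lr * vi for wi, vi in zip(w, self.v)]
        # 3. Re-perturb by rho along the normalized momentum buffer.
        norm = math.sqrt(sum(vi * vi for vi in self.v))
        if norm > 0.0:
            self.offset = [self.rho * vi / norm for vi in self.v]
        else:
            self.offset = [0.0] * len(w)
        self.w = [wi + oi for wi, oi in zip(w, self.offset)]

    def move_back(self):
        # Hypothetical analogue of move_back_from_momentumAscent():
        # strip the perturbation so self.w holds the unperturbed weights.
        self.w = [wi - oi for wi, oi in zip(self.w, self.offset)]
        self.offset = [0.0] * len(self.w)


w0 = [3.0, -2.0]
opt = ToyMomentumSAM(w0)
for _ in range(200):
    opt.step(toy_grad(opt.w))  # one gradient per step, at the perturbed point

perturbed = list(opt.w)
opt.move_back()
shift = math.dist(perturbed, opt.w)  # length of the removed perturbation
print(f"perturbation norm removed: {shift:.3f}")
print(f"loss: {toy_loss(w0):.1f} -> {toy_loss(opt.w):.4f}")
```

Because ρ stays constant in this sketch, the stored weights are still shifted by exactly ρ after the last step, and move_back() removes that shift. If ρ were decayed to zero by the end of training, the final offset would vanish and the call would be unnecessary.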

Run Examples

Baselines:

python -m torch.distributed.run main.py --logSubDir CIFAR_WRN_baseline --ifile configs/CIFAR100_WRN16_4.ini 
python -m torch.distributed.run main.py --logSubDir CIFAR_ResNet_baseline --ifile configs/CIFAR100_ResNet50.ini 
python -m torch.distributed.run main.py --logSubDir ImageNet_ResNet_baseline --ifile configs/ImageNet_ResNet50.ini 
python -m torch.distributed.run main.py --logSubDir ImageNet_ViT_baseline --ifile configs/ImageNet_ViT.ini 

SAM[1]:

python -m torch.distributed.run main.py --logSubDir CIFAR_WRN_SAM --ifile configs/CIFAR100_WRN16_4.ini --optimizer SAM --rho 0.2
python -m torch.distributed.run main.py --logSubDir CIFAR_ResNet_SAM --ifile configs/CIFAR100_ResNet50.ini --optimizer SAM --rho 0.2
python -m torch.distributed.run main.py --logSubDir ImageNet_ResNet_SAM --ifile configs/ImageNet_ResNet50.ini --optimizer SAM --rho 0.2
python -m torch.distributed.run main.py --logSubDir ImageNet_ViT_SAM --ifile configs/ImageNet_ViT.ini --optimizer AdamW_SAM --rho 0.2

MSAM:

python -m torch.distributed.run main.py --logSubDir CIFAR_WRN_MSAM --ifile configs/CIFAR100_WRN16_4.ini --optimizer MSAM --rho 3
python -m torch.distributed.run main.py --logSubDir CIFAR_ResNet_MSAM --ifile configs/CIFAR100_ResNet50.ini --optimizer MSAM --rho 3
python -m torch.distributed.run main.py --logSubDir ImageNet_ResNet_MSAM --ifile configs/ImageNet_ResNet50.ini --optimizer MSAM --rho 3
python -m torch.distributed.run main.py --logSubDir ImageNet_ViT_MSAM --ifile configs/ImageNet_ViT.ini --optimizer AdamW_MSAM --rho 3

Additional supported optimizers: ESAM[2], lookSAM[3]

References

[1] Foret et al. 2021 “Sharpness-Aware Minimization for Efficiently Improving Generalization”
[2] Du et al. 2022 “Efficient Sharpness-Aware Minimization for Improved Training of Neural Networks”
[3] Liu et al. 2022 “Towards Efficient and Scalable Sharpness-Aware Minimization”