AEGDM
May 28, 2021
This repository contains code to reproduce the experiments in "AEGDM: Adaptive gradient descent with energy and momentum".
Usage
The aegdm.py file provides a PyTorch implementation of AEGDM. Create the optimizer as follows:
optimizer = aegdm.AEGDM(model.parameters(), lr=0.02)
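The repository's aegdm.py is the reference implementation. For intuition only, here is a scalar sketch of an energy-adaptive step with momentum. It assumes the AEGD-style update, in which an energy variable r starts at sqrt(f(θ0) + c) and v = f'(θ) / (2 sqrt(f(θ) + c)). It also assumes that an exponential moving average m of v replaces v in both the energy and the parameter update. The momentum form, the constants c and beta, and the function name aegdm_like_minimize are illustrative assumptions, not taken from aegdm.py.

```python
import math


def aegdm_like_minimize(f, grad_f, theta0, lr=0.1, beta=0.9, c=1.0, steps=200):
    """Hypothetical scalar sketch of an energy-adaptive update with momentum.

    Assumes f(theta) + c > 0 everywhere. This is NOT the code in aegdm.py;
    the momentum form (EMA of v) is an assumption made for illustration.
    """
    theta = theta0
    r = math.sqrt(f(theta) + c)  # energy variable, initialised to sqrt(f + c)
    m = 0.0                      # momentum buffer (assumed EMA of v)
    energies = [r]
    for _ in range(steps):
        # v is the gradient of sqrt(f + c) at the current point
        v = grad_f(theta) / (2.0 * math.sqrt(f(theta) + c))
        m = beta * m + (1.0 - beta) * v    # assumed momentum form
        r = r / (1.0 + 2.0 * lr * m * m)   # denominator >= 1, so r never increases
        theta = theta - 2.0 * lr * r * m   # step size scaled by the current energy
        energies.append(r)
    return theta, energies


# Usage: minimise f(x) = (x - 3)^2 starting from x = 0.
theta, energies = aegdm_like_minimize(lambda x: (x - 3.0) ** 2,
                                      lambda x: 2.0 * (x - 3.0), theta0=0.0)
print(f"theta={theta:.4f}, final energy={energies[-1]:.4f}")
```

Because the energy update divides r by a factor of at least 1, r is non-increasing for any learning rate. The effective step 2·lr·r therefore adapts through r rather than through a hand-tuned schedule.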
Examples on CIFAR-10 and CIFAR-100
We test AEGDM on the standard CIFAR-10 and CIFAR-100 image classification tasks, comparing it with SGD with momentum (SGDM), Adam and AEGD. We also provide a notebook that presents our results for these experiments.
The supported models are VGG, ResNet, DenseNet and CifarNet for CIFAR-10, and SqueezeNet and GoogleNet for CIFAR-100.
For VGG, the weight decay is set to 5e-4; for all other architectures, it is set to 1e-4.
For DenseNet, the batch size is set to 64; for all other architectures, it is set to 128. The candidate initial learning rates for each optimizer are:
- SGDM: {0.03, 0.05, 0.1, 0.2, 0.3}
- Adam: {0.0001, 0.0003, 0.0005, 0.001, 0.002}
- AEGD: {0.1, 0.2, 0.3, 0.4}
- AEGDM: {0.005, 0.008, 0.01, 0.02, 0.03}
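The weight decay, batch size and learning-rate grids above can be captured as data when scripting a sweep. This is a minimal sketch that assumes model names are lowercase strings beginning with the architecture family (resnet32 appears in the commands below; "vgg16" and "densenet121" are hypothetical identifiers). LR_GRID, training_settings and sweep are illustrative helpers, not part of the repository.

```python
from typing import Iterator

# Candidate initial learning rates per optimizer, as listed above.
LR_GRID = {
    "SGDM": [0.03, 0.05, 0.1, 0.2, 0.3],
    "Adam": [0.0001, 0.0003, 0.0005, 0.001, 0.002],
    "AEGD": [0.1, 0.2, 0.3, 0.4],
    "AEGDM": [0.005, 0.008, 0.01, 0.02, 0.03],
}


def training_settings(model: str) -> dict:
    """Weight decay and batch size for a model (assumes a family-name prefix)."""
    name = model.lower()
    weight_decay = 5e-4 if name.startswith("vgg") else 1e-4
    batch_size = 64 if name.startswith("densenet") else 128
    return {"weight_decay": weight_decay, "batch_size": batch_size}


def sweep(model: str, optimizer: str) -> Iterator[dict]:
    """Yield one settings dict per candidate learning rate of `optimizer`."""
    base = training_settings(model)
    for lr in LR_GRID[optimizer]:
        yield {**base, "optim": optimizer, "lr": lr}


for run in sweep("densenet121", "AEGDM"):
    print(run)
```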
The best base learning rate for each method on each task can be found in the curve/pretrained folder to ease reproduction.
For example, to train ResNet-32 on CIFAR-10 using AEGDM with a learning rate of 0.008, run
python cifar.py --dataset cifar10 --model resnet32 --optim AEGDM --lr 0.008
To train SqueezeNet on CIFAR-100 using AEGDM with a learning rate of 0.02, run
python cifar.py --dataset cifar100 --model squeezenet --optim AEGDM --lr 0.02
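To try every AEGDM learning rate in the grid rather than a single value, the commands can be generated programmatically. This sketch uses only the four flags shown above (--dataset, --model, --optim, --lr); cifar_command is a hypothetical helper, and it only prints the commands (pass the argument list to subprocess.run to actually launch them). Any other options of cifar.py are not set here.

```python
import shlex


def cifar_command(dataset: str, model: str, optim: str, lr: float) -> str:
    """Build a cifar.py invocation using only the flags documented in this README."""
    args = ["python", "cifar.py", "--dataset", dataset, "--model", model,
            "--optim", optim, "--lr", str(lr)]
    return shlex.join(args)  # shell-safe quoting (Python 3.8+)


# AEGDM candidate learning rates for ResNet-32 on CIFAR-10.
AEGDM_LRS = [0.005, 0.008, 0.01, 0.02, 0.03]
for lr in AEGDM_LRS:
    print(cifar_command("cifar10", "resnet32", "AEGDM", lr))
```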
The checkpoints will be saved in the checkpoint folder and the data points of the learning curve will be saved in the curve folder.