
November 27, 2024

WSAM Optimizer

Weighted Sharpness as a Regularization Term


We present PyTorch code for "Sharpness-Aware Minimization Revisited: Weighted Sharpness as a Regularization Term" (KDD '23). The code is based on https://github.com/davda54/sam.

The generalization of deep neural networks (DNNs) is known to be closely related to the flatness of minima, which led to the development of Sharpness-Aware Minimization (SAM) for seeking flatter minima and better generalization. We propose a more general method, called WSAM, which incorporates sharpness as a regularization term. WSAM achieves improved generalization, or is at least highly competitive, compared to the vanilla optimizer, SAM, and its variants.
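Conceptually (our paraphrase of the objective; see the paper for the exact formulation), the sharpness term is weighted by γ/(1−γ), so that γ = 1/2 recovers SAM and a larger γ puts more weight on flatness:

    ℓ_WSAM(w) = ℓ(w) + γ/(1−γ) · ( max_{‖ε‖ ≤ ρ} ℓ(w+ε) − ℓ(w) )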

Figure: WSAM can achieve different (flatter) minima by choosing different γ.

Usage

Similar to SAM, WSAM can be used either in a two-step manner or with a single closure-based step() call.

from atorch.optimizers.wsam import WeightedSAM
from atorch.optimizers.utils import enable_running_stats, disable_running_stats

...

model = YourModel()
base_optimizer = torch.optim.SGD(model.parameters(), lr=0.001) # initialize the base optimizer
optimizer = WeightedSAM(model, base_optimizer, rho=0.05, gamma=0.9, adaptive=False, decouple=True, max_norm=None)
...
# 1. two-step method
for input, output in data:
  enable_running_stats(model)
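  # no_sync() skips gradient synchronization on the first pass; this assumes the model is wrapped in DistributedDataParallel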
  with model.no_sync():
    # first forward-backward pass
    loss = loss_function(output, model(input))  # use this loss for any training statistics
    loss.backward()
  optimizer.first_step(zero_grad=True)
  disable_running_stats(model)

  # second forward-backward pass
  loss_function(output, model(input)).backward()  # make sure to do a full forward pass
  optimizer.second_step(zero_grad=True)
...
# 2. closure-based method
for input, output in data:
  def closure():
    loss = loss_function(output, model(input))
    loss.backward()
    return loss

  loss = loss_function(output, model(input))
  loss.backward()
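  # step() perturbs the weights, re-evaluates closure() for the second forward-backward pass, then applies the update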
  optimizer.step(closure)
  optimizer.zero_grad()
...
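
For reference, below is a minimal end-to-end sketch of the two-step loop. The toy model, toy data, and the single-process gloo process group are illustrative assumptions only; they are there so that model.no_sync() (a DistributedDataParallel method) works in a self-contained script.

import os
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from atorch.optimizers.wsam import WeightedSAM
from atorch.optimizers.utils import enable_running_stats, disable_running_stats

# single-process "distributed" setup so that DDP and model.no_sync() work in this toy example
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

# hypothetical toy model and data, for illustration only
model = DDP(nn.Sequential(nn.Linear(10, 64), nn.BatchNorm1d(64), nn.ReLU(), nn.Linear(64, 2)))
loss_function = nn.CrossEntropyLoss()
data = [(torch.randn(32, 10), torch.randint(0, 2, (32,))) for _ in range(8)]

base_optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
optimizer = WeightedSAM(model, base_optimizer, rho=0.05, gamma=0.9, adaptive=False, decouple=True, max_norm=None)

for input, target in data:
  enable_running_stats(model)            # let BatchNorm update running stats on the first pass only
  with model.no_sync():                  # no gradient all-reduce on the first pass
    loss = loss_function(model(input), target)
    loss.backward()
  optimizer.first_step(zero_grad=True)   # perturb the weights towards the local worst case
  disable_running_stats(model)           # freeze BatchNorm stats for the second pass
  loss_function(model(input), target).backward()  # second pass: gradients are synchronized here
  optimizer.second_step(zero_grad=True)  # restore the weights and apply the WSAM update

dist.destroy_process_group()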

Extra Notes

  • Regularization mode: We recommend a decoupled update of the sharpness term (decouple=True), as used in our paper; see the example configuration after this list.
  • Gradient clipping: If max_norm is not None, WSAM clips gradients to that norm to keep training stable.
  • Gradient sync: This implementation synchronizes gradients correctly, corresponding to the m-sharpness used in the SAM paper.
  • Rho selection: If you try to reproduce the ViT results from this paper, use a larger rho when training with fewer GPUs. For more information, see this related link.
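
Putting these options together, a hypothetical configuration that keeps the recommended decoupled update and enables gradient clipping (the max_norm value and SGD hyperparameters are illustrative only):

base_optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
optimizer = WeightedSAM(
  model,
  base_optimizer,
  rho=0.05,        # radius of the neighborhood used for the sharpness term
  gamma=0.9,       # γ, the weight of the sharpness term
  adaptive=False,  # True enables element-wise (ASAM-style) perturbation scaling
  decouple=True,   # decoupled update of the sharpness term (recommended)
  max_norm=1.0,    # clip gradients to this norm for training stability
)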