NeurIPS 2025 OPT: Sharpness-Aware Minimization with Z-Score Gradient Filtering (Official Repository)

September 29, 2025 ยท View on GitHub

Author: Vincent-Daniel Yun (Juyoung Yun)
Institute: University of Southern California
[Paper] [Workshop] [Author Google Scholar]

Image

The settings of this repository are based on the default configuration of:

https://github.com/omihub777/ViT-CIFAR

Dependencies

# Please check requirements.txt
torchsummary
pytorch-lightning==1.2.1
comet-ml==3.3.5
matplotlib
numpy
pandas
scipy
numpy==1.26.4
scikit-learn
warmup_scheduler

Tiny ImageNet Download

mkdir -p data/tiny-imagenet
cd data/tiny-imagenet
wget http://cs231n.stanford.edu/tiny-imagenet-200.zip
unzip tiny-imagenet-200.zip

Auto Run

bash run.sh

Manual Run

# Ours (Zsharp)
python main.py --dataset c10 --seed 52 --model-name res20s --sam --zsam

# Adaptive SAM
python main.py --dataset c10 --seed 52 --model-name res20s --samsung

# Standard SAM
python main.py --dataset c10 --seed 52 --model-name res20s --sam

# Friendly SAM
python main.py --dataset c10 --seed 52 --model-name res20s --fsam

# AdamW
python main.py --dataset c10 --seed 52 --model-name res20s

Args

--dataset [c10, c100]
--model-name [res20s, res56s, vgg16_bn, vit]

Adaptation to other codes

Zsharp Class:

class SAM(torch.optim.Optimizer):
    def __init__(self, params, base_optimizer, rho=0.05, adaptive=False, hparams=None, **kwargs):
        assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"
        defaults = dict(rho=rho, adaptive=adaptive, **kwargs)
        super().__init__(params, defaults)
        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        self.param_groups = self.base_optimizer.param_groups
        self.defaults.update(self.base_optimizer.defaults)
        self.zsam_enabled = getattr(hparams, "zsam", False) if hparams else False
        self.zsam_alpha = getattr(hparams, "alpha", False) if hparams else False

    @torch.no_grad()
    def first_step(self, closure=None, zero_grad=False):
        grad_norm = self._grad_norm()

        for group in self.param_groups:
            rho = group["rho"]
            scale = rho / (grad_norm + 1e-12)
            for p in group["params"]:
                if p.grad is None:
                    continue

                self.state[p]["old_p"] = p.data.clone()
                grad = p.grad
                if self.zsam_enabled and grad.ndim > 1:
                    grad_mean = grad.mean()
                    grad_std = grad.std() + 1e-12
                    z_norm_grad = (grad - grad_mean) / grad_std
                    threshold = torch.quantile(z_norm_grad.abs(), self.zsam_alpha) 
                    mask = z_norm_grad.abs() > threshold  
                    filtered_grad = grad.clone()
                    filtered_grad[~mask] = 0.0
                    e_w = rho * filtered_grad / (filtered_grad.norm(p=2) + 1e-12) if filtered_grad.norm(p=2) > 0 else grad * scale
                else:
                    e_w = grad * scale
                p.add_(e_w)

                self.state[p]["prev_ew"] = e_w.clone()

        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                p.data = self.state[p]["old_p"]
        self.base_optimizer.step()
        if zero_grad:
            self.zero_grad()

    def step(self, closure=None):
        assert closure is not None, "SAM requires closure"
        closure = torch.enable_grad()(closure)
        self.first_step(closure=closure)
        closure()
        self.second_step()

    def _grad_norm(self):
        shared_device = self.param_groups[0]["params"][0].device
        norm = torch.norm(torch.stack([
            (torch.abs(p) if group["adaptive"] else 1.0) * p.grad.norm(p=2).to(shared_device)
            for group in self.param_groups for p in group["params"]
            if p.grad is not None
        ]), p=2)
        return norm

class Zharp(SAM):
    def __init__(self, params, hparams, rho=0.05, adaptive=False, lr=1e-3, betas=(0.9, 0.999), weight_decay=0.0, eps=1e-8):
        if hparams.sgd:
            base_opt = torch.optim.SGD
        else:
            base_opt = torch.optim.Adam
        super().__init__(
            params=params,
            base_optimizer=base_opt,
            rho=rho,
            adaptive=adaptive,
            hparams=hparams, 
            lr=lr,
            betas=betas,
            weight_decay=weight_decay,
            eps=eps
        )

Optimizer Block:

optimizer = self.Zharp(
    self.model.parameters(),
    hparams=self.hparams,
    lr=self.hparams.lr,
    betas=(self.hparams.beta1, self.hparams.beta2),
    weight_decay=self.hparams.weight_decay
        )



[Error Handling] If you have "zero_gradient" error in Pytorch

This error occurs because the zero_gradients function was removed in recent PyTorch updates. In PyTorch 2.0 and above, the function

torch.autograd.gradcheck.zero_gradients()

is no longer supported. It was available in earlier versions but has since been removed. As a result, some parts of the advertorch library are not compatible with the latest versions of PyTorch, leading to this issue.

In summary: the error happens because zero_gradients has been removed from torch.autograd.gradcheck in PyTorch.


Solution

To fix this issue, manually define zero_gradients and replace its usage in the affected file.

1. Remove the Import Statement

Open the file:

/opt/conda/lib/python3.11/site-packages/advertorch/attacks/fast_adaptive_boundary.py

Find and delete the following line:

from torch.autograd.gradcheck import zero_gradients

2. Define zero_gradients Manually

At the top of the same file (fast_adaptive_boundary.py), add the following function:

def zero_gradients(x):
    if x.grad is not None:
        x.grad.detach_()
        x.grad.zero_()




Acknowledgement

This research was supported by Brian Impact Foundation, a non-profit organization dedicated to the advancement of science and technology for all.

Reference

@article{yun2025zsharp,
  title     = {Sharpness-Aware Minimization with Z-Score Gradient Filtering},
  author    = {Yun, Vincent-Daniel},
  journal   = {arXiv preprint arXiv:2505.02369},
  year      = {2025},
  doi       = {10.48550/arXiv.2505.02369},
  url       = {https://arxiv.org/abs/2505.02369},
}