Simplified-AdEMAMix

February 4, 2025

This is the official implementation of the Sim-AdEMAMix optimizer. To use it, copy the simplified_AdEMAMix.py file into your codebase and construct the optimizer as follows (here T is the total number of steps in the run):

from simplified_AdEMAMix import SimAdEMAMix

optim = SimAdEMAMix(lr=1e-4, betas=(0.99, 0.95), alpha=0.0, min_beta1=0.9, beta1_warmup=T, weight_decay=0.0)

By default, the optimizer maintains momentum in theory style (not EMA style) with bias correction turned off, which generally seems to help in practice with cosine decay. The optimal value of $\alpha$ depends on the batch size and, from theory, should scale down linearly as the batch size increases. Our optimal $\alpha$ at a batch size of 1M tokens was close to 0 ($\approx 0.05$), while at 32k tokens it was close to 100. At higher batch sizes (i.e., in the curvature-dominated regime rather than the noise-dominated one), $\alpha$ should be set to a small multiple of $1-\beta_1$ (inspired by Nesterov).
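To make the distinction between the two momentum styles concrete, here is a minimal sketch on scalar gradients. It assumes that "theory style" means the accumulation $m \leftarrow \beta m + g$ (without the $(1-\beta)$ factor used by an EMA), and that neither variant applies bias correction. The function names are hypothetical and are not part of simplified_AdEMAMix.py.

```python
def ema_momentum(grads, beta):
    """EMA style: m <- beta * m + (1 - beta) * g, no bias correction."""
    m = 0.0
    for g in grads:
        m = beta * m + (1.0 - beta) * g
    return m


def theory_momentum(grads, beta):
    """Assumed 'theory style': m <- beta * m + g, no bias correction."""
    m = 0.0
    for g in grads:
        m = beta * m + g
    return m


if __name__ == "__main__":
    grads = [1.0, -2.0, 0.5, 3.0]
    beta = 0.9
    # Starting from zero, the two buffers differ only by the factor 1 / (1 - beta).
    print(ema_momentum(grads, beta), theory_momentum(grads, beta) * (1.0 - beta))
```

Under this assumption the theory-style buffer is larger than the EMA buffer by a factor of $1/(1-\beta)$, which may be one reason the learning rate has to be retuned when $\beta_1$ changes.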

To tune $\eta$, $\beta_1$, $\beta_2$ and min_beta (the `min_beta1` argument), start from an optimal Adam run with hyperparameters $\eta^{adam}$, $\beta_1^{adam}$ and $\beta_2^{adam}$. We recommend that the optimal AdEMAMix hyperparameters should be around min_beta = $\beta_1^{adam}$, $\beta_1$ higher than min_beta (for min_beta = 0.9, try 0.95, 0.99 or 0.999), $\beta_2 = \beta_2^{adam}$ and $\eta = \eta^{adam} \sqrt{(1-\text{min beta}) \cdot (1-\beta_1)}$ (so the optimal $\eta$ is coupled with the value of $\beta_1$). Note also that $\beta_1$ should generally decrease as the batch size increases.
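The recipe above can be sketched as a small helper that turns a tuned Adam configuration into a starting point for Sim-AdEMAMix. The function name `suggest_sim_ademamix_hparams` is hypothetical and not part of simplified_AdEMAMix.py. The returned keys mirror the keyword arguments shown in the usage example, and the result is only an initial guess to tune around.

```python
import math


def suggest_sim_ademamix_hparams(lr_adam, beta1_adam, beta2_adam, beta1):
    """Suggest starting hyperparameters from a tuned Adam run (hypothetical helper).

    min_beta is set to Adam's beta1, beta2 is reused unchanged, and the
    learning rate is rescaled by sqrt((1 - min_beta) * (1 - beta1)).
    """
    if not (beta1_adam < beta1 < 1.0):
        raise ValueError("beta1 should be higher than min_beta and below 1")
    min_beta = beta1_adam
    lr = lr_adam * math.sqrt((1.0 - min_beta) * (1.0 - beta1))
    return {"lr": lr, "betas": (beta1, beta2_adam), "min_beta1": min_beta}


if __name__ == "__main__":
    # Sweep the beta1 candidates suggested in the text for min_beta = 0.9.
    for beta1 in (0.95, 0.99, 0.999):
        print(suggest_sim_ademamix_hparams(1e-3, 0.9, 0.95, beta1))
```

For example, an Adam run with $\eta^{adam} = 10^{-3}$ and $\beta_1^{adam} = 0.9$ combined with $\beta_1 = 0.99$ gives $\eta \approx 3.2 \times 10^{-5}$, which illustrates how strongly the learning rate is coupled to $\beta_1$.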