MuonClip
July 24, 2025
PyTorch and JAX implementations of the MuonClip optimizer from the Kimi K2 Technical Report.
MuonClip is an optimizer that combines:
- Muon momentum-based updates with Newton-Schulz orthogonalization
- Consistent RMS scaling for stability
- Per-head QK-Clip mechanism to prevent attention logit explosion
- Weight decay for regularization
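The first two components in the list above can be sketched in a few lines of NumPy. This is a minimal illustration under stated assumptions, not this repository's code: the quintic coefficients are those of the public Muon reference implementation, the `0.2 * sqrt(max(n, m))` scaling is the RMS-matching factor as commonly described for Muon in the Moonshot reports, and the names `newton_schulz` and `muon_update` are hypothetical.

```python
import numpy as np

# Quintic Newton-Schulz coefficients from the public Muon reference
# implementation (an assumption; this repository may use other values).
NS_COEFFS = (3.4445, -4.7750, 2.0315)


def newton_schulz(g, steps=5, eps=1e-7):
    """Approximately orthogonalize a 2-D matrix (hypothetical helper)."""
    a, b, c = NS_COEFFS
    # Dividing by the Frobenius norm keeps every singular value at or below 1.
    x = g / (np.linalg.norm(g) + eps)
    # Work on the wide orientation so the Gram matrix is the smaller one.
    transposed = x.shape[0] > x.shape[1]
    if transposed:
        x = x.T
    for _ in range(steps):
        gram = x @ x.T
        x = a * x + (b * gram + c * gram @ gram) @ x
    return x.T if transposed else x


def muon_update(momentum_buf, rms_target=0.2):
    """Orthogonalize the momentum buffer, then rescale to a fixed update RMS."""
    n, m = momentum_buf.shape
    return newton_schulz(momentum_buf) * np.sqrt(max(n, m)) * rms_target


rng = np.random.default_rng(0)
grad = rng.standard_normal((8, 16))  # stand-in for a momentum buffer
ortho = newton_schulz(grad)
singular_values = np.linalg.svd(ortho, compute_uv=False)
update = muon_update(grad)
update_rms = float(np.sqrt(np.mean(update ** 2)))
```

The iteration pushes the singular values toward a band around 1 rather than to exactly 1, so the rescaled update has an RMS close to the target for any matrix shape. That shape independence is what lets one learning rate and weight decay setting be shared across layers.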
Install
python -m venv venv
source venv/bin/activate
pip install -r requirements.txt
Usage
PyTorch
from src.muon_clip_pytorch import MuonClip
# Create optimizer
optimizer = MuonClip(
    model.parameters(),
    lr=2e-4,
    momentum=0.95,
    weight_decay=0.1,
    tau=100.0,
)
optimizer.set_model(model) # Required for QK-Clip
# Training step
loss.backward()
optimizer.step()
JAX
from src.muon_clip_jax import muonclip
# Create optimizer
optimizer = muonclip(learning_rate=2e-4, momentum=0.95, weight_decay=0.1)
# In the training loop, QK-Clip is applied separately, after the gradient update
state = state.apply_gradients(grads=grads)
# Then apply QK-Clip to attention weights based on max_logits
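The separate QK-Clip step might look roughly like the following. This is a sketch for plain multi-head attention, written in NumPy for brevity (the same arithmetic applies to `jax.numpy` arrays). The function name `qk_clip`, the `(num_heads, d_model, head_dim)` weight layout, and the way `max_logits` is collected are all assumptions for illustration; the helper shipped in `src/muon_clip_jax` may have a different interface.

```python
import numpy as np


def qk_clip(w_q, w_k, max_logits, tau=100.0):
    """Rescale per-head query/key weights whose max logit exceeds tau.

    w_q, w_k:   (num_heads, d_model, head_dim) projection weights (assumed layout)
    max_logits: (num_heads,) largest attention logit seen per head this step
    """
    # gamma is 1 for heads under the threshold and tau / max_logit otherwise.
    gamma = np.minimum(1.0, tau / np.maximum(max_logits, 1e-12))
    # Splitting the factor evenly scales the q.k product by exactly gamma.
    scale = np.sqrt(gamma)[:, None, None]
    return w_q * scale, w_k * scale


rng = np.random.default_rng(0)
w_q = rng.standard_normal((2, 4, 3))
w_k = rng.standard_normal((2, 4, 3))
max_logits = np.array([50.0, 400.0])  # only head 1 exceeds tau
w_q_new, w_k_new = qk_clip(w_q, w_k, max_logits, tau=100.0)
```

Here head 0 is left untouched, while head 1 has its query and key weights halved, so its logits shrink by a factor of 4 and the previous maximum of 400 maps to the threshold of 100.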
Examples
See examples/ for complete training examples comparing MuonClip with AdamW on a GPT model:
# PyTorch example
python examples/example_pytorch.py
# JAX example
python examples/example_jax.py
Both examples train two identical transformer models (one with MuonClip, one with AdamW) on the tiny_shakespeare dataset.
Tests
pytest