June 5, 2026
Partner-whitened steepest descent for the QK and OV attention circuits. [blog post]
Compositional Muon (CM) extends Muon from single-matrix steepest descent to the composed operators a transformer actually applies: the QK product $W_Q W_K^\top$ and the OV product $W_O W_V$. Each factor's gradient is whitened by its partner's inverse Gram root before the spectral sign and multiplied by it again afterward, so the step taken by each matrix adapts to the geometry of its partner.
## How it works
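The loss sees $W_Q$ and $W_K$ only through the product $W_Q W_K^\top$, whose first-order change is $\Delta W_Q\,W_K^\top + W_Q\,\Delta W_K^\top$. Constraining the operator norm of this composed update and splitting the budget equally between the two factors (at most $\eta/2$ per term) gives the partner-whitened half-split rule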
$$\Delta W_Q = -\tfrac{\eta}{2}\,\mathrm{msign}\left(G_Q C_K^{-1}\right) C_K^{-1}, \qquad \Delta W_K = -\tfrac{\eta}{2}\,\mathrm{msign}\left(G_K C_Q^{-1}\right) C_Q^{-1},$$

where $G_Q$ and $G_K$ are the gradients with respect to $W_Q$ and $W_K$, $C_K = (W_K^\top W_K + \lambda I)^{1/2}$, and symmetrically $C_Q$. The same construction applies to the OV product $W_O W_V$ (with $W_V$ per-head and $W_O$ per-matrix). When each partner Gram is near-isotropic, the inverse root collapses to a scalar, $C^{-1} \approx c^{-1} I$, recovering a cheap per-head dynamic learning rate.

## Usage

Requires `torch`. `cm_ov` / `cm_qk` take the attention weights in `nn.Linear` convention plus caller-managed momentum buffers, and apply one CM update in place:

```python
import torch
from compositional_muon import cm_ov, cm_qk

# attn has q_proj/k_proj/v_proj/o_proj as bias-free nn.Linear layers and lr is your
# learning rate. The momentum buffers are caller-managed, one per weight: create them
# once as zeros, before the training loop, and pass the same tensors at every step.
m_q, m_k, m_v, m_o = (torch.zeros_like(proj.weight) for proj in
                      (attn.q_proj, attn.k_proj, attn.v_proj, attn.o_proj))

# After loss.backward(), apply one CM update to each attention pair.
cm_qk(attn.q_proj.weight, attn.k_proj.weight,
      attn.q_proj.weight.grad, attn.k_proj.weight.grad,
      m_q, m_k, head_dim=attn.head_dim, eta=lr)
cm_ov(attn.v_proj.weight, attn.o_proj.weight,
      attn.v_proj.weight.grad, attn.o_proj.weight.grad,
      m_v, m_o, head_dim=attn.head_dim, eta=lr)
```

CM governs only the attention QK and OV pairs; update the other parameters with your optimizer of choice. `src/main.py` is a runnable demo (a small transformer trained with CM on attention and Muon on the rest).
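To make the half-split rule concrete outside the package, here is a minimal single-head sketch in plain PyTorch. It is not the library's `cm_qk`: `half_split_qk`, `msign`, and `inv_gram_root` are hypothetical helpers. The sketch assumes the math convention $W_Q, W_K \in \mathbb{R}^{d_\text{model} \times d_\text{head}}$ (the transpose of an `nn.Linear` head slice), uses no momentum, takes the spectral sign by exact SVD rather than Muon's Newton–Schulz iteration, and picks the isotropic scalar as $c^2 = \lVert W \rVert_F^2 / d_\text{head} + \lambda$, which is an assumed choice.

```python
import torch


def msign(g: torch.Tensor) -> torch.Tensor:
    # Spectral sign: the orthogonal polar factor U V^T of g = U S V^T (exact SVD here;
    # Muon approximates it with a Newton-Schulz iteration instead).
    u, _, vh = torch.linalg.svd(g, full_matrices=False)
    return u @ vh


def inv_gram_root(w: torch.Tensor, lam: float) -> torch.Tensor:
    # C^{-1} = (W^T W + lam I)^{-1/2}, via a symmetric eigendecomposition.
    eye = torch.eye(w.shape[1], dtype=w.dtype, device=w.device)
    evals, evecs = torch.linalg.eigh(w.T @ w + lam * eye)
    return evecs @ torch.diag(evals.rsqrt()) @ evecs.T


def half_split_qk(w_q, w_k, g_q, g_k, eta, lam=1e-6, isotropic=False):
    """Hypothetical single-head half-split step for the QK product w_q @ w_k.T.

    w_q, w_k, g_q, g_k are (d_model, d_head); returns (delta_w_q, delta_w_k).
    """
    if isotropic:
        # Assumed scalar approximation C ~ c I with c^2 = ||W||_F^2 / d_head + lam.
        c_q = (w_q.pow(2).sum() / w_q.shape[1] + lam).sqrt()
        c_k = (w_k.pow(2).sum() / w_k.shape[1] + lam).sqrt()
        # msign is scale-invariant, so only the per-head step size eta / (2c) remains.
        return -eta / (2 * c_k) * msign(g_q), -eta / (2 * c_q) * msign(g_k)
    ck_inv = inv_gram_root(w_k, lam)  # whitens W_Q's gradient (partner: W_K)
    cq_inv = inv_gram_root(w_q, lam)  # whitens W_K's gradient (partner: W_Q)
    # Whiten by the partner's inverse root, take the spectral sign, whiten again.
    delta_q = -eta / 2 * msign(g_q @ ck_inv) @ ck_inv
    delta_k = -eta / 2 * msign(g_k @ cq_inv) @ cq_inv
    return delta_q, delta_k


# Demo on random float64 weights and gradients for one head.
torch.manual_seed(0)
d_model, d_head, eta = 16, 4, 0.1
w_q = torch.randn(d_model, d_head, dtype=torch.float64)
w_k = torch.randn(d_model, d_head, dtype=torch.float64)
g_q, g_k = torch.randn_like(w_q), torch.randn_like(w_k)
d_q, d_k = half_split_qk(w_q, w_k, g_q, g_k, eta)

# Operator norm of each leg of the first-order change in W_Q W_K^T (eta / 2 up to lam).
leg_q = torch.linalg.matrix_norm(d_q @ w_k.T, ord=2)
leg_k = torch.linalg.matrix_norm(w_q @ d_k.T, ord=2)
print(f"{leg_q.item():.6f} {leg_k.item():.6f}")
```

In the isotropic branch the spectral sign no longer depends on the partner, so each head's step reduces to Muon's `msign(G)` with a per-head learning rate $\eta / (2c)$; when the partner Gram is exactly a multiple of the identity, the two branches coincide.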
## Variants

| argument | default | values | description |
| --- | --- | --- | --- |
| `method` | `"half_split"` | `"half_split"`, `"joint"` | split the budget per factor, or take one shared spectral sign over the stacked factors |
| `isotropic` | `False` | `False`, `True` | full-matrix partner whitening, or its per-head scalar approximation |
| `hybrid` (OV) | `True` | `True`, `False` | $W_O$ per-matrix spectral sign with $W_V$ per-head, or both per-head |
| `whitening` | `"both"` | `"both"`, `"pre"`, `"post"`, `"none"` | which side(s) of the spectral sign the partner whitening is applied on |
| `connection` | `"none"` | `"none"`, `"frobenius"`, `"scale_aware"`, `"frobenius_scalar"`, `"scale_aware_scalar"` | gauge fix removing the vertical (gauge) component of the update |
| `momentum_reproject` | `False` | `False`, `True` | project the momentum onto the horizontal (gauge-fixed) bundle |
| `per_mat_renorm` | `False` | `False`, `True` | restore each leg to its pre-whitening Frobenius norm |

## License

Apache 2.0