FedMuon: Accelerating Federated Learning with Matrix Orthogonalization

May 12, 2026

A structure-aware federated optimizer for large vision and language models.
FedMuon stabilizes non-IID federated training by coupling matrix-orthogonalized local updates, local-global alignment, and cross-round momentum aggregation.


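  • Training runs on a single RTX 4090 or two RTX 2080 Ti GPUs, which is enough for top-conference experiments. For code questions or discussion, add the author on WeChat: 15653218567

  • My other papers use the same code configuration, and all of them are reproducible.

  • Homepage: https://junkangliu0.github.io/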

Overview

Federated learning usually relies on element-wise local optimizers such as SGD or AdamW. These optimizers treat matrix-shaped parameters as flattened vectors and may amplify ill-conditioned directions during multi-step local training, especially when client data are heterogeneous.

FedMuon introduces matrix orthogonalization into federated optimization. It first studies Local Muon, where each client applies Muon-style orthogonalized updates locally, and then addresses the instability of Local Muon under non-IID data with two mechanisms, plus an optional compression step that reduces communication:

  • Local-Global Alignment: aligns client-side orthogonalized updates with the global update direction to reduce client drift.
  • Momentum Aggregation: aggregates client momentum states across communication rounds to avoid momentum reinitialization.
  • SVD Momentum Compression: optionally communicates a low-rank approximation of momentum states to reduce communication overhead.
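The orthogonalization step behind Muon-style updates can be illustrated with a short NumPy sketch. It assumes the classical cubic Newton-Schulz iteration; Muon implementations typically use a tuned quintic variant, and the optimizer in this repository may differ. The function name `orthogonalize` and the toy `momentum` matrix are illustrative and do not come from the repository.

```python
import numpy as np


def orthogonalize(update: np.ndarray, steps: int = 30) -> np.ndarray:
    """Approximate the nearest semi-orthogonal matrix to `update`.

    Sketch only: uses the cubic Newton-Schulz iteration
    X <- 1.5 X - 0.5 X X^T X, which converges when all singular
    values start in (0, sqrt(3)).
    """
    # Frobenius normalization bounds every singular value by 1.
    x = update / (np.linalg.norm(update) + 1e-12)
    for _ in range(steps):
        x = 1.5 * x - 0.5 * x @ x.T @ x
    return x


# Toy 4x3 momentum matrix with full column rank (illustrative values).
momentum = np.array([[2.0, 0.0, 1.0],
                     [0.0, 1.0, 0.0],
                     [1.0, 0.0, 3.0],
                     [0.0, 2.0, 1.0]])
direction = orthogonalize(momentum)

# The columns of `direction` are approximately orthonormal.
print(np.round(direction.T @ direction, 6))
```

The result has all singular values close to 1, so no single ill-conditioned direction dominates the step. This is the property that element-wise optimizers such as SGD or AdamW do not enforce.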

The repository provides Ray-based federated simulations for vision models and LoRA fine-tuning support for large language models.


Repository Structure

.
├── main_FedMuon.py          # Vision federated training entry point
├── new_llm.py               # Language / GLUE LoRA training entry point, if included
├── dirichlet_data.py        # Dirichlet non-IID partitioning
├── dataset.py               # Tiny-ImageNet dataset wrapper
├── model.py                 # Swin Transformer backbones
├── vit_model.py             # ViT backbones
├── models/
│   ├── resnet.py            # ResNet with GN variants
│   ├── resnet_bn.py         # ResNet with BN variants
│   └── DeiTTiny.py          # DeiT-Tiny backbone
├── data/                    # Dataset root
├── log/                     # Training logs
├── checkpoint/              # Checkpoints
└── plot/                    # Saved curves / numpy results

Installation

conda create -n fedmuon python=3.8 -y
conda activate fedmuon

pip install torch torchvision
pip install numpy matplotlib filelock tensorboardX ray==1.0.0
pip install peft transformers

Recommended package versions used by the original implementation:

python >= 3.8
torch >= 2.0
torchvision >= 0.15
ray == 1.0.0
tensorboardX == 2.6.2.2
peft == 0.13.2
transformers == 4.46.3

Datasets

Vision

The vision entry point supports:

| Dataset argument | Dataset | Notes |
| --- | --- | --- |
| CIFAR10 | CIFAR-10 | Automatically downloaded by torchvision |
| CIFAR100 | CIFAR-100 | Automatically downloaded by torchvision |
| imagenet | Tiny-ImageNet-200 | Place under ./data/tiny-imagenet-200 |

For non-IID experiments, client partitions are generated with a Dirichlet distribution:

alpha_value = 0.6  # mild heterogeneity
alpha_value = 0.1  # strong heterogeneity

Generated partition files are cached with names such as:

num_workers_100-alpha_value_0.1-data_CIFAR100
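A label-based Dirichlet split can be sketched as follows. This is a common construction written under assumptions, not a copy of `dirichlet_data.py`: the function names `dirichlet_partition` and `cache_name`, the seed handling, and the toy labels are all illustrative. Only the cache-name pattern is taken from the example above.

```python
import numpy as np


def dirichlet_partition(labels, num_workers, alpha_value, seed=0):
    """Split sample indices across clients, class by class.

    For every class, client shares are drawn from Dir(alpha_value);
    a smaller alpha_value yields more skewed (more non-IID) clients.
    """
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    client_indices = [[] for _ in range(num_workers)]
    for cls in np.unique(labels):
        idx = rng.permutation(np.flatnonzero(labels == cls))
        proportions = rng.dirichlet(alpha_value * np.ones(num_workers))
        # Convert cumulative shares into split points for this class.
        cuts = (np.cumsum(proportions)[:-1] * len(idx)).astype(int)
        for client, shard in enumerate(np.split(idx, cuts)):
            client_indices[client].extend(shard.tolist())
    return client_indices


def cache_name(num_workers, alpha_value, data_name):
    """Build a cache file name following the pattern shown above."""
    return f"num_workers_{num_workers}-alpha_value_{alpha_value}-data_{data_name}"


# Toy labels: 10 classes with 100 samples each (not a real dataset).
labels = np.repeat(np.arange(10), 100)
parts = dirichlet_partition(labels, num_workers=20, alpha_value=0.1)
print(cache_name(100, 0.1, "CIFAR100"))
```

Every sample index is assigned to exactly one client, and fixing the seed reproduces the same split. Caching the partition serves the same purpose across runs.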

Language

For LoRA-based language experiments, the paper evaluates GLUE tasks and OpenWebText with RoBERTa / GPT-style models. Use the language training script if it is included in your repository.

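Pretrained model weights: place the downloaded weights directly in the repository root, or change the path in the code if you prefer another directory.

vit-base: https://huggingface.co/Junkang2/vit/tree/main

swin_transformer: https://huggingface.co/Junkang2/swin_transformer/tree/main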

Dataset Downloads

Tiny-ImageNet: https://huggingface.co/datasets/Junkang2/Tiny-ImageNet/upload/main

The code supports multiple datasets:

  • CIFAR-10 / CIFAR-100
  • Tiny-ImageNet


Quick Start

FedMuon on CIFAR-100, Dirichlet-0.1

python main_FedMuon.py --alg FedMuon --data_name CIFAR100 --CNN deit_tiny --lr 3e-4 --epoch 301 --num_workers 100 --selection 0.1 --alpha_value 0.1 --batch_size 50 --E 5 --K 50 --lr_decay 2 --gamma 0.5 --alpha 10 --beta1 0.9 --beta2 0.999 --rho 0.01 --pix 32 --lora 0 --pre 1 --gpu 0 --num_gpus_per 0.1 --p 1 --preprint 10 --normalization BN --extname fedmuon_cifar100_dir01_deit

Local Muon baseline

python main_FedMuon.py --alg Local_Muon --data_name CIFAR100 --CNN deit_tiny --lr 3e-4 --epoch 301 --num_workers 100 --selection 0.1 --alpha_value 0.1 --batch_size 50 --E 5 --K 50 --lr_decay 2 --gamma 0.5 --alpha 10 --beta1 0.9 --beta2 0.999 --rho 0.01 --pix 32 --lora 0 --pre 1 --gpu 0 --num_gpus_per 0.1 --p 1 --preprint 10 --normalization BN --extname local_muon_cifar100_dir01_deit

FedAvg and AdamW baselines

python main_FedMuon.py --alg FedAvg --data_name CIFAR100 --CNN deit_tiny --lr 1e-1 --epoch 301 --num_workers 100 --selection 0.1 --alpha_value 0.1 --batch_size 50 --E 5 --K 50 --lr_decay 2 --gpu 0 --num_gpus_per 0.1 --p 1 --preprint 10 --normalization BN --extname fedavg_cifar100_dir01_deit
python main_FedMuon.py --alg FedAvg_adamw --data_name CIFAR100 --CNN deit_tiny --lr 3e-4 --epoch 301 --num_workers 100 --selection 0.1 --alpha_value 0.1 --batch_size 50 --E 5 --K 50 --lr_decay 2 --gamma 0.5 --alpha 10 --beta1 0.9 --beta2 0.999 --rho 0.01 --pix 32 --lora 0 --pre 1 --gpu 0 --num_gpus_per 0.1 --p 1 --preprint 10 --normalization BN --extname fedadamw_cifar100_dir01_deit

ResNet-18

python main_FedMuon.py --alg FedMuon --lr 3e-2 --data_name CIFAR100 --alpha_value 0.1 --alpha 0.5 --epoch 301 --extname FedMuon_resnet18 --lr_decay 2 --gamma 0.5 --CNN resnet18 --E 5 --batch_size 50 --gpu 0 --p 1 --num_gpus_per 0.1 --normalization BN --selection 0.1 --print 0 --pre 1 --num_workers 100 --preprint 10 --beta1 0.9 --beta2 0.999 --rho 0.01 --pix 32 --lora 0 --K 50
python main_FedMuon.py --alg Local_Muon --lr 3e-2 --data_name CIFAR100 --alpha_value 0.1 --alpha 0.5 --epoch 301 --extname LocalMuon_resnet18 --lr_decay 2 --gamma 0.5 --CNN resnet18 --E 5 --batch_size 50 --gpu 0 --p 1 --num_gpus_per 0.1 --normalization BN --selection 0.1 --print 0 --pre 1 --num_workers 100 --preprint 10 --beta1 0.9 --beta2 0.999 --rho 0.01 --pix 32 --lora 0 --K 50
python main_FedMuon.py --alg FedAvg --lr 1e-1 --data_name CIFAR100 --alpha_value 0.1 --alpha 0.5 --epoch 301 --extname FedAvg_resnet18 --lr_decay 2 --gamma 0.5 --CNN resnet18 --E 5 --batch_size 50 --gpu 0 --p 1 --num_gpus_per 0.1 --normalization BN --selection 0.1 --print 0 --pre 1 --num_workers 100 --preprint 10 --beta1 0.9 --beta2 0.999 --rho 0.01 --pix 32 --lora 0 --K 50
python main_FedMuon.py --alg FedAvg_adamw --lr 3e-4 --data_name CIFAR100 --alpha_value 0.1 --alpha 0.5 --epoch 301 --extname FedAvgAdamW_resnet18 --lr_decay 2 --gamma 0.5 --CNN resnet18 --E 5 --batch_size 50 --gpu 0 --p 1 --num_gpus_per 0.1 --normalization BN --selection 0.1 --print 0 --pre 1 --num_workers 100 --preprint 10 --beta1 0.9 --beta2 0.999 --rho 0.01 --pix 32 --lora 0 --K 50

Supported Algorithms

The current training script includes the following algorithm choices:

| Argument | Description |
| --- | --- |
| FedMuon | Proposed matrix-orthogonalized FL optimizer with momentum aggregation and local-global alignment |
| Local_Muon | Local Muon baseline without FedMuon correction |
| FedAvg | Local SGD / FedAvg baseline |
| FedAvg_adamw | AdamW-style local baseline |
| FedAdam | Server-side adaptive FedAdam baseline |
| FedAdamW | AdamW-based federated baseline |
| FedCM | Federated client-momentum baseline |
| SCAFFOLD | Control-variate correction baseline |
| FedLADA | Adaptive moment aggregation baseline |
| Local_Soap | SOAP-style local optimizer baseline |

Note: FedMuon_SVD is implemented as a communication-efficient momentum-compression variant in the worker dispatch. If your local branch does not expose it in the main algorithm allow-list, add it before running --alg FedMuon_SVD.
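The idea behind low-rank momentum compression can be sketched with a truncated SVD. This is a minimal illustration under assumptions, not the code path of `FedMuon_SVD`: the helper names `compress_momentum` and `decompress_momentum`, the chosen rank, and the synthetic momentum matrix are all hypothetical.

```python
import numpy as np


def compress_momentum(momentum: np.ndarray, rank: int):
    """Keep only the top-`rank` singular triplets of a momentum matrix."""
    u, s, vt = np.linalg.svd(momentum, full_matrices=False)
    return u[:, :rank], s[:rank], vt[:rank, :]


def decompress_momentum(u: np.ndarray, s: np.ndarray, vt: np.ndarray) -> np.ndarray:
    """Rebuild the low-rank approximation from its factors."""
    return (u * s) @ vt


rng = np.random.default_rng(0)
# Synthetic 64x32 momentum that is exactly rank 4 (illustrative only).
momentum = rng.standard_normal((64, 4)) @ rng.standard_normal((4, 32))

factors = compress_momentum(momentum, rank=4)
restored = decompress_momentum(*factors)

dense_floats = momentum.size                       # values sent without compression
compressed_floats = sum(f.size for f in factors)   # values sent as SVD factors
print(dense_floats, compressed_floats)
```

For an m x n matrix, the factors cost r(m + n + 1) values instead of mn, so the saving grows as the rank r shrinks relative to the matrix dimensions. Real momentum states are only approximately low-rank, so the reconstruction is lossy in practice.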


Supported Models

| Model argument | Architecture |
| --- | --- |
| lenet5 | LeNet-style CNN |
| resnet10, resnet18, resnet34, resnet50 | ResNet variants |
| resnet18pre, resnet50pre, resnet101pre | ImageNet-pretrained ResNet variants |
| deit_tiny | DeiT-Tiny |
| VIT-B, VIT-L | Vision Transformer backbones |
| swin_tiny, swin_small, swin_base, swin_large | Swin Transformer backbones |

LoRA is available for Transformer-style vision backbones and pretrained ResNet classifiers through --lora 1.


Important Arguments

| Argument | Default | Description |
| --- | --- | --- |
| --alg | FedLESAM | Federated algorithm name. Use FedMuon for the proposed method. |
| --data_name | CIFAR100 | Dataset name: CIFAR10, CIFAR100, or imagenet. |
| --CNN | lenet5 | Model architecture. |
| --lr | 0.1 | Client learning rate. |
| --epoch | 1001 | Number of communication rounds. |
| --num_workers | 100 | Number of simulated clients. |
| --selection | 0.1 | Client participation ratio per round. |
| --alpha_value | 0.1 | Dirichlet concentration parameter for non-IID partitioning. |
| --batch_size | 50 | Client mini-batch size. |
| --E | 5 | Local epochs / local update budget. |
| --K | 50 | Maximum local steps per round. |
| --lr_decay | 0.998 | Learning-rate decay setting. |
| --gpu | 0 | Visible GPU device IDs. |
| --num_gpus_per | 1 | GPU fraction assigned to each Ray worker. |
| --p | 10 | Parallelism factor for client updates. |
| --preprint | 10 | Evaluation interval. |
| --lora | 0 | Enable LoRA fine-tuning. |
| --r | 16 | LoRA rank. |
| --pix | 224 | Input image resolution. Use 32 for CIFAR-style training. |
| --pre | 1 | Use pretrained weights when available. |
| --normalization | BN | Normalization type for ResNet variants. |
| --datapath | ./data | Dataset root. |
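A quick way to sanity-check a configuration is to work out how many clients train per round and how many Ray workers can share the visible GPUs. The helper below is a hypothetical sketch: `round_budget` is not part of the repository, it assumes the script selects about `num_workers * selection` clients per round, and it relies on Ray allowing several workers to share a device when each requests a fractional GPU.

```python
import math


def round_budget(num_workers, selection, num_gpus_per, num_gpus=1):
    """Return (clients per round, workers that fit on the GPUs).

    Assumption: the participation ratio is applied to the client count
    and rounded to the nearest integer, with at least one client.
    """
    clients_per_round = max(1, round(num_workers * selection))
    # A worker requesting `num_gpus_per` of a GPU shares the device
    # with others; the epsilon guards against float error in 1 / 0.1.
    workers_per_gpu = math.floor(1.0 / num_gpus_per + 1e-9)
    return clients_per_round, workers_per_gpu * num_gpus


# Quick Start setting: 100 clients, 10% participation, 0.1 GPU per worker.
print(round_budget(num_workers=100, selection=0.1, num_gpus_per=0.1))
```

With the Quick Start values this gives 10 clients per round and room for 10 workers on one GPU. If GPU memory runs out, raising `--num_gpus_per` reduces how many workers share a device.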

Paper Results

CIFAR-100, 100 clients, 10% participation, batch size 50, K = 50

| Method | ResNet-18 Dir-0.6 | ResNet-18 Dir-0.1 | ViT-Tiny Dir-0.6 | ViT-Tiny Dir-0.1 |
| --- | --- | --- | --- | --- |
| FedAvg | 64.08 ± 0.18 | 60.25 ± 0.20 | 32.36 ± 0.08 | 27.14 ± 0.12 |
| Local AdamW | 62.84 ± 0.08 | 58.97 ± 0.10 | 40.47 ± 0.09 | 37.86 ± 0.11 |
| Local Muon | 71.66 ± 0.15 | 66.71 ± 0.15 | 46.69 ± 0.15 | 40.53 ± 0.17 |
| FedMuon | 74.12 ± 0.18 | 73.05 ± 0.16 | 50.22 ± 0.14 | 48.18 ± 0.12 |

GLUE with RoBERTa-Base + LoRA, 20 clients, 20% participation, K = 50

| Method | Average Accuracy |
| --- | --- |
| FedAvg | 76.73 |
| Local AdamW | 78.77 |
| Local Muon | 80.17 |
| FedMuon | 81.00 |

Outputs

The training script automatically writes logs, checkpoints, and curve files.

log/         # Training logs, e.g., alg-dataset-lr-workers-batch-E-lr_decay.txt
checkpoint/  # Checkpoints, e.g., ckpt-{alg}-{lr}-{extname}-{alpha_value}-{timestamp}/
plot/        # Saved numpy arrays for accuracy / loss curves
runs/        # TensorBoard summaries

During training, the script reports:

Iter r: accuracy, train loss, test loss, learning rate, algorithm, model, data split

Visualize TensorBoard logs with:

tensorboard --logdir runs

Reproducibility Checklist

  • Use the same client partition by keeping num_workers, alpha_value, and data_name unchanged.
  • Keep selection, batch_size, E, and K fixed when comparing FL algorithms.
  • Use the same backbone and input size across methods.
  • Run multiple seeds and report mean ± standard deviation.
  • For fair non-IID comparison, reuse cached Dirichlet partition files.
  • For Ray simulation, tune --num_gpus_per and --p according to available GPU memory.
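Reporting mean ± standard deviation over seeds takes only a few lines of standard-library Python. The sketch below is illustrative: `summarize` is a hypothetical helper, the accuracy values are placeholders rather than results, and it uses the sample standard deviation, which may differ from the convention behind the paper's tables.

```python
import statistics


def summarize(accuracies):
    """Format per-seed accuracies as 'mean ± std' with two decimals."""
    mean = statistics.mean(accuracies)
    std = statistics.stdev(accuracies)  # sample standard deviation (n - 1)
    return f"{mean:.2f} ± {std:.2f}"


# Placeholder final accuracies from three seeds (not real results).
runs = [70.0, 71.0, 72.0]
print(summarize(runs))  # 71.00 ± 1.00
```

Use `statistics.pstdev` instead if the population standard deviation is the intended convention, and state which one is reported.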

Citation

@inproceedings{fedmuon2026,
  title     = {FedMuon: Accelerating Federated Learning with Matrix Orthogonalization for Large Models},
  author    = {Anonymous Authors},
  booktitle = {Advances in Neural Information Processing Systems},
  year      = {2026}
}

Acknowledgements

This implementation builds on PyTorch, Ray, torchvision, PEFT, and Transformers. We thank the open-source community for providing reliable tools for scalable federated learning research.