FedAdamW: A Communication-Efficient Optimizer with Convergence and Generalization Guarantees for Federated Large Models 被AAAI 2026录用!!!
May 12, 2026 · View on GitHub
A Communication-Efficient AdamW Optimizer for Federated Large Models
Mitigating client drift and stabilizing adaptive optimization for federated Transformer training
Federated Learning · Adaptive Optimization · Non-IID Generalization · Large Model Fine-Tuning
-
有代码问题 +vx 15653218567 马上回复!帮忙引用论文一下就行!
-
一张4090或者两张2080ti即可训练!!发顶会!!代码问题或者讨论+vx 15653218567
-
我的其他论文也都是这一套代码配置,均可复现!差分隐私,联邦泛化,联邦大模型,联邦优化,联邦大模型微调lora
Quick Start
Requirements
- Python 3.8
- PyTorch
- torchvision
- numpy
- matplotlib
- tensorboardX
- ray==1.0.0
- filelock
Installation
conda create -n fedmuon python=3.8 -y
conda activate fedmuon
pip install torch torchvision
pip install numpy matplotlib filelock tensorboardX ray==1.0.0
pip install peft transformers
Recommended package versions used by the original implementation:
python >= 3.8
torch >= 2.0
torchvision >= 0.15
ray == 1.0.0
tensorboardX == 2.6.2.2
peft == 0.13.2
transformers == 4.46.3
pip install -r requirements.txt
2. CNN Training (ResNet-18)
python main_FedAdamW.py --alg FedLADA --lr 3e-4 --data_name CIFAR100 --alpha_value 0.1 --alpha 10 --epoch 301 --extname FedMuon --lr_decay 2 --gamma 0.85 --CNN resnet18 --E 5 --batch_size 50 --gpu 0 --p 1 --num_gpus_per 0.1 --normalization BN --selection 0.1 --print 0 --pre 1 --num_workers 100 --preprint 10 --rho 0.01 --pix 32 --lora 0 --K 50
python main_FedAdamW.py --alg FedAdamW --lr 3e-4 --data_name CIFAR100 --alpha_value 0.1 --alpha 10 --epoch 301 --extname FedMuon --lr_decay 2 --gamma 0.85 --CNN resnet18 --E 5 --batch_size 50 --gpu 0 --p 1 --num_gpus_per 0.1 --normalization BN --selection 0.1 --print 0 --pre 1 --num_workers 100 --preprint 10 --rho 0.01 --pix 32 --lora 0 --K 50
python main_FedAdamW.py --alg FedAvg_adamw --lr 3e-4 --data_name CIFAR100 --alpha_value 0.1 --alpha 10 --epoch 301 --extname FedMuon --lr_decay 2 --gamma 0.85 --CNN resnet18 --E 5 --batch_size 50 --gpu 0 --p 1 --num_gpus_per 0.1 --normalization BN --selection 0.1 --print 0 --pre 1 --num_workers 100 --preprint 10 --rho 0.01 --pix 32 --lora 0 --K 50
- 这里解释一下 --num_gpus_per 0.1的意思是如果你用的是4090显卡24g显存,那么你每个客户端将分配0.1张显卡,即2.4g显存。
- --lr_decay 2 解释一下,这个是余弦学习率下降
- --gpu 0 是指使用的是第0块gpu(gpu序号)
- --alpha_value 0.1 是迪利克雷非立同分布常数
- --alpha_value 1 这个时候是iid情况
- --lora 0 是否使用lora微调,从头训练的情况下,不用lora微调 选0就行
- --normalization BN resnet的归一化层,我选的是BN层,这个效果更好,选择GN也行,收敛的慢
- --data_name timy imagenet数据集需要自己下载,网址在下面
3. Vision Transformer Training
python main_FedAdamW.py --alg FedAdamW --lr 3e-4 --data_name CIFAR100 --alpha_value 0.1 --alpha 0.001 --epoch 301 --extname FedAvg_adamw --lr_decay 2 --gamma 0.5 --CNN deit_tiny --E 5 --batch_size 50 --gpu 2 --p 1 --num_gpus_per 0.1 --normalization BN --selection 0.1 --print 0 --pre 1 --num_workers 100 --preprint 10 --rho 0.01 --pix 32 --lora 0 --K 50
联邦大模型微调 vit
python main_FedAdamW.py --alg FedAdamW --lr 1e-3 --data_name CIFAR100 --alpha_value 0.1 --alpha 10 --epoch 101 --extname FedMuon --lr_decay 2 --gamma 0.85 --CNN VIT-B --E 5 --batch_size 16 --gpu 0 --p 1 --num_gpus_per 0.2 --normalization BN --selection 0.1 --print 0 --pre 1 --num_workers 50 --preprint 10 --rho 0.01 --pix 224 --lora 1 --K 50
python main_FedAdamW.py --alg FedAvg_adamw --lr 1e-3 --data_name CIFAR100 --alpha_value 0.1 --alpha 10 --epoch 101 --extname FedMuon --lr_decay 2 --gamma 0.85 --CNN VIT-B --E 5 --batch_size 16 --gpu 0 --p 1 --num_gpus_per 0.2 --normalization BN --selection 0.1 --print 0 --pre 1 --num_workers 50 --preprint 10 --rho 0.01 --pix 224 --lora 1 --K 50
- --lora 1 使用lora微调
- --batch_size 16 显存限制原因,16效果还可以
- --num_gpus_per 0.2 五个客户端,每个客户端使用0.2张卡
- --lr 1e-3 这个学习率微调lora最好
下载模型权重网址: 下载下来的权重直接放主文件夹下面就行,你也可以自己该目类
vit-base: https://huggingface.co/Junkang2/vit/tree/main
swin_transformer https://huggingface.co/Junkang2/swin_transformer/tree/main
Dataset
数据集下载网址
Tiny-ImageNet: https://huggingface.co/datasets/Junkang2/Tiny-ImageNet/upload/main
The code supports multiple datasets:
- CIFAR-10 / CIFAR-100
- Tiny-ImageNet
🤖 大语言模型训练示例(RoBERTa-base + GLUE-SST2)
python new_llm.py --alg FedAdamW --lr 2e-4 --data_name sst2 --alpha_value 0.8 --alpha 0.9 --epoch 101 --extname RoBERTa_SST2 --lr_decay 2 --gamma 0.9 --CNN roberta_base --E 10 --batch_size 16 --gpu 0 --p 1 --num_gpus_per 0.25 --selection 0.2 --pre 1 --num_workers 20 --preprint 5 --K 50 --freeze 1 --r 16 --lora 1 --print 1
数据集和模型权重下载地址:
-
RoBERTa_base模型权重下载地址,下载完之后放入 roberta_base 文件夹即可。 https://huggingface.co/FacebookAI/roberta-base/tree/main
-
数据集下载地址在hugging face上 sst2 https://huggingface.co/datasets/SetFit/sst2/tree/main 全部数据集下载地址: https://huggingface.co/datasets/Junkang2/glue/tree/main
Parameter Reference
Core Federated Learning Parameters
| Parameter | Description |
|---|---|
--alg | Algorithm choice: FedAvg, FedAdamW, FedCM, SCAFFOLD, etc. |
--lr | Client learning rate |
--lr_decay | Learning rate decay strategy (1=exponential, 2=cosine annealing) |
--gamma | Momentum parameter for certain algorithms |
--alpha | Weight decay coefficient for AdamW optimizer |
Data Parameters
| Parameter | Description |
|---|---|
--data_name | Dataset: CIFAR10, CIFAR100, imagenet, QQP, MNLI, etc. |
--alpha_value | Dirichlet distribution parameter for non-IID data splitting (0.1=highly non-IID, 1=IID) |
--num_workers | Total number of clients |
--selection | Fraction of clients selected per round (0.1=10%) |
Model Parameters
| Parameter | Description |
|---|---|
--CNN | Model architecture: resnet18, swin_tiny, deit_tiny, roberta_base |
--pre | Use pretrained weights (1=True, 0=False) |
--normalization | Normalization type: BN (BatchNorm) or GN (GroupNorm) |
--pix | Input image size (32 for CIFAR, 224 for ImageNet) |
Training Parameters
| Parameter | Description |
|---|---|
--epoch | Total communication rounds |
--E | Local epochs per client |
--batch_size | Client batch size |
--K | Maximum local steps per round (overrides E if smaller) |
--p | Parallelism factor for client updates |
LoRA Parameters
| Parameter | Description |
|---|---|
--lora | Enable LoRA fine-tuning (1=True, 0=False) |
--r | LoRA rank |
--lora_alpha | LoRA scaling parameter |
Optimization Parameters
| Parameter | Description |
|---|---|
--beta1 | Adam optimizer β1 parameter |
--beta2 | Adam optimizer β2 parameter |
--rho | SAM optimizer perturbation radius |
--optimizer | Base optimizer: SGD or AdamW |
System Parameters
| Parameter | Description |
|---|---|
--gpu | GPU device IDs (e.g., "0,1,2") |
--num_gpus_per | GPU fraction per client (0.2=20% of a GPU) |
--print | Print detailed logs (1=True, 0=False) |
--preprint | Evaluation frequency (in epochs) |
Output Files
- Logs:
./log/alg-dataset-lr-workers-batch-epochs-lr_decay.txt - Checkpoints:
./checkpoint/ckpt-alg-lr-extname-alpha_value-timestamp/ - Plots:
./plot/alg-dataset-...-timestamp.npy(contains accuracy/loss arrays) - Models:
./model/model-alg-...-timestamp.pth
Notes
- LoRA Usage: When
--lora 1, only LoRA parameters are trainable by default - Pretrained Models: Automatically downloads required pretrained weights
- Data Splitting: Uses Dirichlet distribution for non-IID splits when
--alpha_value < 1 - Memory: Adjust
--num_gpus_perbased on your GPU memory capacity
For transformer training with GLUE tasks, use new_llm.py with appropriate --data_name (QQP, MNLI, SST2, etc.).
🌌 联邦学习实验平台 · 中文文档
(支持 CNN & Transformer 双栈训练)
📂 一键安装依赖
# 基础环境
pip install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1 -i https://pypi.tuna.tsinghua.edu.cn/simple
# 联邦学习 & 日志
pip install ray==1.0.0 tensorboardX==2.6.2.2 tqdm==4.67.1 -i https://pypi.tuna.tsinghua.edu.cn/simple
# Transformer & 数据集
pip install transformers==4.46.3 datasets==3.1.0 peft==0.13.2 -i https://pypi.tuna.tsinghua.edu.cn/simple
# 科学计算
pip install scikit-learn==1.3.2 scipy==1.9.3 matplotlib==3.7.5 -i https://pypi.tuna.tsinghua.edu.cn/simple
🎯 CNN 训练示例(CIFAR-100)
1. Swin-Tiny 联邦训练
python new_adamw.py \
--alg FedAdamW \
--lr 3e-4 \
--data_name CIFAR100 \
--alpha_value 0.1 \
--alpha 0.01 \
--epoch 101 \
--extname Swin_CIFAR100 \
--lr_decay 2 \
--gamma 0.5 \
--CNN swin_tiny \
--E 5 \
--batch_size 16 \
--gpu 2 \
--p 1 \
--num_gpus_per 0.2 \
--normalization BN \
--selection 0.05 \
--pre 1 \
--num_workers 100 \
--K 50
2. ResNet-18 联邦训练
python new_adamw.py \
--alg FedAdamW \
--lr 3e-4 \
--data_name CIFAR100 \
--alpha_value 0.1 \
--alpha 0.001 \
--epoch 301 \
--extname ResNet18_CIFAR100 \
--lr_decay 2 \
--gamma 0.5 \
--CNN resnet18 \
--E 5 \
--batch_size 50 \
--gpu 1 \
--pix 32 \
--lora 0 \
--K 50
🤖 大语言模型训练示例(RoBERTa-base + GLUE-SST2)
python new_llm.py \
--alg FedAdamW \
--lr 3e-4 \
--data_name sst2 \
--alpha_value 0.8 \
--alpha 0.9 \
--epoch 101 \
--extname RoBERTa_SST2 \
--lr_decay 2 \
--gamma 0.9 \
--CNN roberta_base \
--E 10 \
--batch_size 32 \
--gpu 0 \
--p 1 \
--num_gpus_per 0.25 \
--selection 0.2 \
--pre 1 \
--num_workers 20 \
--preprint 5 \
--K 50 \
--freeze 1 \
--r 16 \
--lora 1 \
--print 1
🎛️ 参数速查表(中文)
| 参数 | 说明 | 推荐值 |
|---|---|---|
--alg | 联邦算法 | FedAdamW(Adam优化) / FedAvg(SGD优化) |
--lr | 客户端学习率 | 3e-4(Adam) / 0.1(SGD) |
--data_name | 数据集 | CIFAR100 / sst2 / qnli |
--alpha_value | 数据异构度 | 1.0(IID) / 0.1(高度非IID) |
--CNN | 模型架构 | swin_tiny / resnet18 / roberta_base |
--lora | LoRA微调 | 1(启用)→显存占用降低90% |
--r | LoRA秩 | 16(平衡性能与效率) |
--gpu | GPU设备 | "0"(单卡) / "0,1,2"(多卡) |
--epoch | 通信轮数 | 100(CNN) / 50(大模型) |
--E | 本地轮数 | 5(CNN) / 10(大模型) |
--K | 本地步数上限 | 覆盖--E的步数限制 |
--selection | 每轮参与比例 | 0.1(10%客户端) |
--batch_size | 本地批次大小 | 16(GPU显存紧张时) |
📊 输出文件结构
实验结果/
├── log/ # 训练日志(txt)
├── plot/ # 训练曲线(npy)
├── model/ # 最终权重(pth)
└── checkpoint/ # 断点续训(ckpt)
💡 实用技巧
-
显存优化:
- CNN训练:
--batch_size 16+--num_gpus_per 0.2 - 大模型:
--lora 1+--r 8(显存占用 < 4GB)
- CNN训练:
-
数据异构可视化:
修改--alpha_value(0.1~1.0)观察精度变化曲线 -
快速验证:
添加--print 1实时查看loss,减少--epoch至20
🌈 支持的完整任务列表
| 任务类型 | 数据集 | 模型 | 示例命令 |
|---|---|---|---|
| 图像分类 | CIFAR-10/100 | ResNet/Swin/DeiT | 见上方CNN示例 |
| 文本分类 | SST-2(情感分析) | RoBERTa-base | 见上方LLM示例 |
| 自然语言推理 | QNLI/MNLI | RoBERTa-base | 替换--data_name qnli |
| 句子对匹配 | MRPC/QQP | RoBERTa-base | 替换--data_name mrpc |
🎉 祝实验顺利!有任何问题欢迎提Issue交流~