FedAdamW: A Communication-Efficient Optimizer with Convergence and Generalization Guarantees for Federated Large Models 被AAAI 2026录用!!!

May 12, 2026 · View on GitHub

A Communication-Efficient AdamW Optimizer for Federated Large Models

Mitigating client drift and stabilizing adaptive optimization for federated Transformer training


Python PyTorch FedAdamW Vision Language


Federated Learning · Adaptive Optimization · Non-IID Generalization · Large Model Fine-Tuning

  • 有代码问题 +vx 15653218567 马上回复!帮忙引用论文一下就行!

  • 一张4090或者两张2080ti即可训练!!发顶会!!代码问题或者讨论+vx 15653218567

  • 我的其他论文也都是这一套代码配置,均可复现!差分隐私,联邦泛化,联邦大模型,联邦优化,联邦大模型微调lora

  • 个人主页:https://junkangliu0.github.io/


Quick Start

Requirements

  • Python 3.8
  • PyTorch
  • torchvision
  • numpy
  • matplotlib
  • tensorboardX
  • ray==1.0.0
  • filelock

Installation

conda create -n fedmuon python=3.8 -y
conda activate fedmuon

pip install torch torchvision
pip install numpy matplotlib filelock tensorboardX ray==1.0.0
pip install peft transformers

Recommended package versions used by the original implementation:

python >= 3.8
torch >= 2.0
torchvision >= 0.15
ray == 1.0.0
tensorboardX == 2.6.2.2
peft == 0.13.2
transformers == 4.46.3

pip install -r requirements.txt

2. CNN Training (ResNet-18)

python  main_FedAdamW.py --alg FedLADA --lr 3e-4 --data_name CIFAR100 --alpha_value 0.1 --alpha  10  --epoch 301  --extname FedMuon --lr_decay 2 --gamma 0.85  --CNN   resnet18 --E 5 --batch_size 50   --gpu 0 --p 1 --num_gpus_per 0.1 --normalization BN --selection 0.1 --print 0 --pre 1 --num_workers 100 --preprint 10  --rho 0.01 --pix 32 --lora 0 --K 50
python  main_FedAdamW.py --alg FedAdamW --lr 3e-4 --data_name CIFAR100 --alpha_value 0.1 --alpha  10  --epoch 301  --extname FedMuon --lr_decay 2 --gamma 0.85  --CNN   resnet18 --E 5 --batch_size 50   --gpu 0 --p 1 --num_gpus_per 0.1 --normalization BN --selection 0.1 --print 0 --pre 1 --num_workers 100 --preprint 10  --rho 0.01 --pix 32 --lora 0 --K 50
python  main_FedAdamW.py --alg FedAvg_adamw --lr 3e-4 --data_name CIFAR100 --alpha_value 0.1 --alpha  10  --epoch 301  --extname FedMuon --lr_decay 2 --gamma 0.85  --CNN   resnet18 --E 5 --batch_size 50   --gpu 0 --p 1 --num_gpus_per 0.1 --normalization BN --selection 0.1 --print 0 --pre 1 --num_workers 100 --preprint 10  --rho 0.01 --pix 32 --lora 0 --K 50
  • 这里解释一下 --num_gpus_per 0.1的意思是如果你用的是4090显卡24g显存,那么你每个客户端将分配0.1张显卡,即2.4g显存。
  • --lr_decay 2 解释一下,这个是余弦学习率下降
  • --gpu 0 是指使用的是第0块gpu(gpu序号)
  • --alpha_value 0.1 是迪利克雷非立同分布常数
  • --alpha_value 1 这个时候是iid情况
  • --lora 0 是否使用lora微调,从头训练的情况下,不用lora微调 选0就行
  • --normalization BN resnet的归一化层,我选的是BN层,这个效果更好,选择GN也行,收敛的慢
  • --data_name timy imagenet数据集需要自己下载,网址在下面

3. Vision Transformer Training

python main_FedAdamW.py --alg FedAdamW --lr 3e-4 --data_name CIFAR100 --alpha_value 0.1 --alpha 0.001 --epoch 301 --extname FedAvg_adamw --lr_decay 2 --gamma 0.5 --CNN deit_tiny --E 5 --batch_size 50 --gpu 2 --p 1 --num_gpus_per 0.1 --normalization BN --selection 0.1 --print 0 --pre 1 --num_workers 100 --preprint 10 --rho 0.01 --pix 32 --lora 0 --K 50

联邦大模型微调 vit

python  main_FedAdamW.py --alg FedAdamW --lr 1e-3 --data_name CIFAR100 --alpha_value 0.1 --alpha  10  --epoch 101  --extname FedMuon --lr_decay 2 --gamma 0.85  --CNN   VIT-B --E 5 --batch_size 16   --gpu 0 --p 1 --num_gpus_per 0.2 --normalization BN --selection 0.1 --print 0 --pre 1 --num_workers 50 --preprint 10  --rho 0.01 --pix 224 --lora 1 --K 50
python  main_FedAdamW.py --alg FedAvg_adamw --lr 1e-3 --data_name CIFAR100 --alpha_value 0.1 --alpha  10  --epoch 101  --extname FedMuon --lr_decay 2 --gamma 0.85  --CNN   VIT-B --E 5 --batch_size 16   --gpu 0 --p 1 --num_gpus_per 0.2 --normalization BN --selection 0.1 --print 0 --pre 1 --num_workers 50 --preprint 10  --rho 0.01 --pix 224 --lora 1 --K 50
  • --lora 1 使用lora微调
  • --batch_size 16 显存限制原因,16效果还可以
  • --num_gpus_per 0.2 五个客户端,每个客户端使用0.2张卡
  • --lr 1e-3 这个学习率微调lora最好

下载模型权重网址: 下载下来的权重直接放主文件夹下面就行,你也可以自己该目类

vit-base: https://huggingface.co/Junkang2/vit/tree/main

swin_transformer https://huggingface.co/Junkang2/swin_transformer/tree/main

Dataset

数据集下载网址

Tiny-ImageNet: https://huggingface.co/datasets/Junkang2/Tiny-ImageNet/upload/main

The code supports multiple datasets:

  • CIFAR-10 / CIFAR-100
  • Tiny-ImageNet

🤖 大语言模型训练示例(RoBERTa-base + GLUE-SST2)

python new_llm.py --alg FedAdamW --lr 2e-4 --data_name sst2 --alpha_value 0.8 --alpha 0.9 --epoch 101 --extname RoBERTa_SST2 --lr_decay 2 --gamma 0.9 --CNN roberta_base --E 10 --batch_size 16 --gpu 0 --p 1 --num_gpus_per 0.25 --selection 0.2 --pre 1 --num_workers 20 --preprint 5 --K 50 --freeze 1 --r 16 --lora 1 --print 1

数据集和模型权重下载地址:

Parameter Reference

Core Federated Learning Parameters

ParameterDescription
--algAlgorithm choice: FedAvg, FedAdamW, FedCM, SCAFFOLD, etc.
--lrClient learning rate
--lr_decayLearning rate decay strategy (1=exponential, 2=cosine annealing)
--gammaMomentum parameter for certain algorithms
--alphaWeight decay coefficient for AdamW optimizer

Data Parameters

ParameterDescription
--data_nameDataset: CIFAR10, CIFAR100, imagenet, QQP, MNLI, etc.
--alpha_valueDirichlet distribution parameter for non-IID data splitting (0.1=highly non-IID, 1=IID)
--num_workersTotal number of clients
--selectionFraction of clients selected per round (0.1=10%)

Model Parameters

ParameterDescription
--CNNModel architecture: resnet18, swin_tiny, deit_tiny, roberta_base
--preUse pretrained weights (1=True, 0=False)
--normalizationNormalization type: BN (BatchNorm) or GN (GroupNorm)
--pixInput image size (32 for CIFAR, 224 for ImageNet)

Training Parameters

ParameterDescription
--epochTotal communication rounds
--ELocal epochs per client
--batch_sizeClient batch size
--KMaximum local steps per round (overrides E if smaller)
--pParallelism factor for client updates

LoRA Parameters

ParameterDescription
--loraEnable LoRA fine-tuning (1=True, 0=False)
--rLoRA rank
--lora_alphaLoRA scaling parameter

Optimization Parameters

ParameterDescription
--beta1Adam optimizer β1 parameter
--beta2Adam optimizer β2 parameter
--rhoSAM optimizer perturbation radius
--optimizerBase optimizer: SGD or AdamW

System Parameters

ParameterDescription
--gpuGPU device IDs (e.g., "0,1,2")
--num_gpus_perGPU fraction per client (0.2=20% of a GPU)
--printPrint detailed logs (1=True, 0=False)
--preprintEvaluation frequency (in epochs)

Output Files

  • Logs: ./log/alg-dataset-lr-workers-batch-epochs-lr_decay.txt
  • Checkpoints: ./checkpoint/ckpt-alg-lr-extname-alpha_value-timestamp/
  • Plots: ./plot/alg-dataset-...-timestamp.npy (contains accuracy/loss arrays)
  • Models: ./model/model-alg-...-timestamp.pth

Notes

  1. LoRA Usage: When --lora 1, only LoRA parameters are trainable by default
  2. Pretrained Models: Automatically downloads required pretrained weights
  3. Data Splitting: Uses Dirichlet distribution for non-IID splits when --alpha_value < 1
  4. Memory: Adjust --num_gpus_per based on your GPU memory capacity

For transformer training with GLUE tasks, use new_llm.py with appropriate --data_name (QQP, MNLI, SST2, etc.).

🌌 联邦学习实验平台 · 中文文档

(支持 CNN & Transformer 双栈训练)


📂 一键安装依赖

# 基础环境
pip install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1 -i https://pypi.tuna.tsinghua.edu.cn/simple

# 联邦学习 & 日志
pip install ray==1.0.0 tensorboardX==2.6.2.2 tqdm==4.67.1 -i https://pypi.tuna.tsinghua.edu.cn/simple

# Transformer & 数据集
pip install transformers==4.46.3 datasets==3.1.0 peft==0.13.2 -i https://pypi.tuna.tsinghua.edu.cn/simple

# 科学计算
pip install scikit-learn==1.3.2 scipy==1.9.3 matplotlib==3.7.5 -i https://pypi.tuna.tsinghua.edu.cn/simple

🎯 CNN 训练示例(CIFAR-100)

1. Swin-Tiny 联邦训练

python new_adamw.py \
  --alg FedAdamW \
  --lr 3e-4 \
  --data_name CIFAR100 \
  --alpha_value 0.1 \
  --alpha 0.01 \
  --epoch 101 \
  --extname Swin_CIFAR100 \
  --lr_decay 2 \
  --gamma 0.5 \
  --CNN swin_tiny \
  --E 5 \
  --batch_size 16 \
  --gpu 2 \
  --p 1 \
  --num_gpus_per 0.2 \
  --normalization BN \
  --selection 0.05 \
  --pre 1 \
  --num_workers 100 \
  --K 50

2. ResNet-18 联邦训练

python new_adamw.py \
  --alg FedAdamW \
  --lr 3e-4 \
  --data_name CIFAR100 \
  --alpha_value 0.1 \
  --alpha 0.001 \
  --epoch 301 \
  --extname ResNet18_CIFAR100 \
  --lr_decay 2 \
  --gamma 0.5 \
  --CNN resnet18 \
  --E 5 \
  --batch_size 50 \
  --gpu 1 \
  --pix 32 \
  --lora 0 \
  --K 50

🤖 大语言模型训练示例(RoBERTa-base + GLUE-SST2)

python new_llm.py \
  --alg FedAdamW \
  --lr 3e-4 \
  --data_name sst2 \
  --alpha_value 0.8 \
  --alpha 0.9 \
  --epoch 101 \
  --extname RoBERTa_SST2 \
  --lr_decay 2 \
  --gamma 0.9 \
  --CNN roberta_base \
  --E 10 \
  --batch_size 32 \
  --gpu 0 \
  --p 1 \
  --num_gpus_per 0.25 \
  --selection 0.2 \
  --pre 1 \
  --num_workers 20 \
  --preprint 5 \
  --K 50 \
  --freeze 1 \
  --r 16 \
  --lora 1 \
  --print 1

🎛️ 参数速查表(中文)

参数说明推荐值
--alg联邦算法FedAdamW(Adam优化) / FedAvg(SGD优化)
--lr客户端学习率3e-4(Adam) / 0.1(SGD)
--data_name数据集CIFAR100 / sst2 / qnli
--alpha_value数据异构度1.0(IID) / 0.1(高度非IID)
--CNN模型架构swin_tiny / resnet18 / roberta_base
--loraLoRA微调1(启用)→显存占用降低90%
--rLoRA秩16(平衡性能与效率)
--gpuGPU设备"0"(单卡) / "0,1,2"(多卡)
--epoch通信轮数100(CNN) / 50(大模型)
--E本地轮数5(CNN) / 10(大模型)
--K本地步数上限覆盖--E的步数限制
--selection每轮参与比例0.1(10%客户端)
--batch_size本地批次大小16(GPU显存紧张时)

📊 输出文件结构

实验结果/
├── log/              # 训练日志(txt)
├── plot/             # 训练曲线(npy)
├── model/            # 最终权重(pth)
└── checkpoint/       # 断点续训(ckpt)

💡 实用技巧

  1. 显存优化

    • CNN训练:--batch_size 16 + --num_gpus_per 0.2
    • 大模型:--lora 1 + --r 8(显存占用 < 4GB)
  2. 数据异构可视化
    修改--alpha_value(0.1~1.0)观察精度变化曲线

  3. 快速验证
    添加--print 1实时查看loss,减少--epoch20


🌈 支持的完整任务列表

任务类型数据集模型示例命令
图像分类CIFAR-10/100ResNet/Swin/DeiT见上方CNN示例
文本分类SST-2(情感分析)RoBERTa-base见上方LLM示例
自然语言推理QNLI/MNLIRoBERTa-base替换--data_name qnli
句子对匹配MRPC/QQPRoBERTa-base替换--data_name mrpc

🎉 祝实验顺利!有任何问题欢迎提Issue交流~