ATP

December 25, 2023

This is the official implementation of the following paper:

Wenxuan Bao, Tianxin Wei, Haohan Wang, Jingrui He. Adaptive Test-Time Personalization for Federated Learning. NeurIPS 2023.

[Arxiv] [Poster] [Slides]

Introduction

  • We consider a novel setting named Test-Time Personalized Federated Learning. It addresses the challenge of personalizing a global model, at test time, to each client that did not participate in training, without requiring any labeled data.
  • We propose ATP, which adaptively learns the adaptation rate for each module, enabling it to handle different types of distribution shifts among FL clients.
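The sketch below illustrates what "an adaptation rate for each module" means. It assumes personalization takes a single unsupervised gradient step (for example on prediction entropy) and that each module's step is scaled by its own rate. The function `personalize` and the parameter layout (module name mapped to a flat list of floats) are hypothetical and are not the repository's API.

```python
from typing import Dict, List

Params = Dict[str, List[float]]  # hypothetical layout: module name -> flat parameter list


def personalize(global_params: Params, unsup_grads: Params,
                rates: Dict[str, float]) -> Params:
    """Apply one adaptation step with a separate rate per module.

    Each module m is updated as theta_m - rates[m] * g_m, where g_m is the
    gradient of an unsupervised loss on the client's unlabeled test data.
    A rate of 0.0 freezes a module; larger rates adapt it more aggressively.
    """
    adapted: Params = {}
    for name, theta in global_params.items():
        alpha = rates.get(name, 0.0)  # assumption: modules without a rate stay frozen
        grad = unsup_grads[name]
        adapted[name] = [t - alpha * g for t, g in zip(theta, grad)]
    return adapted
```

For example, with rates `{"conv1": 0.5, "fc": 0.0}` the `conv1` parameters move along their unsupervised gradient while `fc` keeps its global values. Because the rates are per module, the same mechanism can adapt some layers strongly and leave others untouched, depending on the type of distribution shift.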

Requirements

  • python 3.8.5
  • cudatoolkit 10.2.89
  • cudnn 7.6.5
  • pytorch 1.11.0
  • torchvision 0.12.0
  • numpy 1.18.5
  • tqdm 4.65.0
  • matplotlib 3.7.1

If you prefer to generate CIFAR-10C and CIFAR-100C yourself, the following packages may also be required:

  • wandb 0.16.0
  • scikit-image 0.17.2
  • opencv-python 4.8.0.74

(The codebase should not be very sensitive to package versions.)

Run

CIFAR-10C Experiments

We consider three types of distribution shifts in our CIFAR-10C experiments: feature shift, label shift, and hybrid shift.

cd ./exp/cifar10/${shift}

where ${shift} should be replaced by feat (feature shift), label (label shift), or hybrid (hybrid shift).
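The stages described in the rest of this section can also be chained from the repository root. The following is a hypothetical driver: `build_plan` and `run_plan` are illustrative names and are not part of the repository. The only assumptions are the directory layout and script names given in this README. With `dry_run=True` (the default) it only prints the commands.

```python
import subprocess
from pathlib import Path
from typing import Iterable, List, Tuple

SHIFTS = ("feat", "label", "hybrid")  # feature, label, and hybrid shift
MODELS = ("resnet18", "cnn")          # ResNet-18 and shallow CNN


def build_plan(shift: str, models: Iterable[str] = MODELS) -> List[Tuple[Path, str]]:
    """Return (working directory, script) pairs for one shift, in run order."""
    if shift not in SHIFTS:
        raise ValueError(f"unknown shift {shift!r}; expected one of {SHIFTS}")
    workdir = Path("exp") / "cifar10" / shift
    plan = [(workdir, "./data_prepare.sh")]  # data is prepared once per shift
    for model in models:
        if model not in MODELS:
            raise ValueError(f"unknown model {model!r}; expected one of {MODELS}")
        plan += [
            (workdir, f"./pretrain_fedavg_{model}.sh"),  # FedAvg global model
            (workdir, f"./atp_train_{model}.sh"),        # learn adaptation rates
            (workdir, f"./atp_test_{model}.sh"),         # ATP-batch / ATP-online
        ]
    return plan


def run_plan(plan: List[Tuple[Path, str]], dry_run: bool = True) -> None:
    """Print each stage, and execute it only when dry_run is False."""
    for workdir, script in plan:
        print(f"[{workdir}] {script}")
        if not dry_run:
            # On POSIX, cwd makes the relative ./script path resolve inside workdir;
            # check=True aborts the pipeline as soon as one stage fails.
            subprocess.run([script], cwd=workdir, check=True)
```

For example, `run_plan(build_plan("label", ["cnn"]), dry_run=False)` would run the label-shift pipeline for the shallow CNN only.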

Generate Dataset

./data_prepare.sh

This shell script partitions the CIFAR-10 dataset into 300 clients (240 source clients and 60 target clients) and saves the partition indices to ~/data/atp/partition/cifar10/. When corruptions are involved (feature shift and hybrid shift), it also caches the corrupted dataset to ~/data/atp/cifar10 to save time.
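The bookkeeping of a 300-client partition with a 240/60 source/target split can be sketched as follows. This is a simplified illustration and does not reproduce the repository's scheme. It assumes a uniform (IID) assignment of samples, whereas the label-shift and hybrid-shift settings require non-IID label distributions that this sketch does not produce. Writing the indices to ~/data/atp/partition/cifar10/ is also omitted. The function `partition` is hypothetical.

```python
import random
from typing import Dict, List, Tuple


def partition(num_samples: int, num_clients: int = 300, num_source: int = 240,
              seed: int = 0) -> Tuple[Dict[int, List[int]], List[int], List[int]]:
    """IID split of sample indices across clients, plus a source/target client split."""
    if not 0 < num_source < num_clients:
        raise ValueError("num_source must be between 1 and num_clients - 1")
    rng = random.Random(seed)  # fixed seed -> reproducible partition
    indices = list(range(num_samples))
    rng.shuffle(indices)
    # Deal the shuffled indices round-robin, so client sizes differ by at most one.
    clients = {c: indices[c::num_clients] for c in range(num_clients)}
    client_ids = list(range(num_clients))
    rng.shuffle(client_ids)
    source = sorted(client_ids[:num_source])  # clients whose data trains the global model
    target = sorted(client_ids[num_source:])  # remaining clients, personalized at test time
    return clients, source, target
```

Saving only indices, rather than copies of the images, keeps each partition small and lets every experiment reload the same split.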

We also upload these

Train Global Model with FedAvg

Before running ATP, we need to train a global model on the source clients' training sets. We use the FedAvg algorithm to train the global model.

./pretrain_fedavg_${model}.sh

Here ${model} specifies the model architecture: resnet18 (ResNet-18) or cnn (shallow CNN), the two architectures used in our paper.
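For reference, the server-side aggregation step of standard FedAvg averages the client models, weighting each by its number of training samples. The sketch below shows only that step. The local SGD epochs each client runs before aggregation are omitted, and whether the repository weights clients exactly this way is an assumption. `fedavg` and the parameter layout are hypothetical.

```python
from typing import Dict, List, Sequence

Params = Dict[str, List[float]]  # hypothetical layout: module name -> flat parameter list


def fedavg(client_params: Sequence[Params], num_samples: Sequence[int]) -> Params:
    """Average client models, weighting each by its number of training samples."""
    if not client_params or len(client_params) != len(num_samples):
        raise ValueError("need at least one client and one sample count per client")
    total = float(sum(num_samples))
    averaged: Params = {}
    for name, first in client_params[0].items():
        averaged[name] = [
            sum(p[name][i] * n for p, n in zip(client_params, num_samples)) / total
            for i in range(len(first))
        ]
    return averaged
```

A client holding three times as many samples pulls the average three times as hard. For example, averaging `[1.0, 0.0]` (1 sample) with `[3.0, 4.0]` (3 samples) gives `[2.5, 3.0]`.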

Learn Adaptation Rates with ATP

./atp_train_${model}.sh

In addition to learning the adaptation rates, this script prints the evaluation results of ATP-batch at each iteration.
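The idea of learning per-module rates can be illustrated with a toy. The sketch assumes, as the workflow above suggests, that the rates are fit on labeled source clients by gradient descent on a supervised loss evaluated *after* the adaptation step. To stay dependency-free, a quadratic distance to a client-specific optimum `theta_star` stands in for cross-entropy, and the chain rule is written out by hand instead of using autograd. `rate_gradient` and `learn_rates` are hypothetical names.

```python
from typing import Dict, List, Sequence, Tuple

Params = Dict[str, List[float]]  # module name -> flat parameter list
Client = Tuple[Params, Params]   # (unsupervised gradient g, supervised optimum theta_star)


def rate_gradient(theta: Params, rates: Dict[str, float], client: Client) -> Dict[str, float]:
    """Gradient w.r.t. each alpha_m of the toy supervised loss
    0.5 * sum_m ||theta'_m - theta_star_m||^2, where theta'_m = theta_m - alpha_m * g_m."""
    grads, theta_star = client
    out: Dict[str, float] = {}
    for name, t in theta.items():
        g = grads[name]
        adapted = [ti - rates[name] * gi for ti, gi in zip(t, g)]
        # Chain rule through the adaptation step: d theta'_mi / d alpha_m = -g_mi.
        out[name] = sum((a - s) * -gi for a, s, gi in zip(adapted, theta_star[name], g))
    return out


def learn_rates(theta: Params, clients: Sequence[Client], lr: float = 0.1,
                steps: int = 200) -> Dict[str, float]:
    """Gradient descent on the per-module rates, averaged over labeled source clients."""
    rates = {name: 0.0 for name in theta}  # start from "no adaptation"
    for _ in range(steps):
        total = {name: 0.0 for name in theta}
        for client in clients:
            for name, d in rate_gradient(theta, rates, client).items():
                total[name] += d
        for name in rates:
            rates[name] -= lr * total[name] / len(clients)
    return rates
```

In a toy where following the unsupervised gradient helps a `bn` module but not an `fc` module, `learn_rates` converges to a positive rate for `bn` and keeps the rate for `fc` at zero. This is the module-specific behavior the rates are meant to capture.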

Test-Time Personalization with ATP-batch and ATP-online

./atp_test_${model}.sh
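The two test-time variants differ in what carries over between test batches. The sketch below reflects our reading of them: ATP-batch adapts every batch independently, starting from the global model, while ATP-online keeps updating the same adapted model as batches arrive. `grad_fn` is a hypothetical callback returning the unsupervised gradient for a batch. `adapt_step`, `atp_batch`, and `atp_online` are illustrative names, not the repository's functions.

```python
from typing import Callable, Dict, Iterable, List

Params = Dict[str, List[float]]
GradFn = Callable[[Params, object], Params]  # (params, unlabeled batch) -> unsupervised gradient


def adapt_step(params: Params, grads: Params, rates: Dict[str, float]) -> Params:
    """One per-module adaptation step; returns new parameters, inputs are not modified."""
    return {n: [t - rates[n] * g for t, g in zip(p, grads[n])] for n, p in params.items()}


def atp_batch(global_params: Params, batches: Iterable[object], grad_fn: GradFn,
              rates: Dict[str, float]) -> List[Params]:
    """Adapt each test batch independently, always starting from the global model."""
    return [adapt_step(global_params, grad_fn(global_params, b), rates) for b in batches]


def atp_online(global_params: Params, batches: Iterable[object], grad_fn: GradFn,
               rates: Dict[str, float]) -> List[Params]:
    """Carry the adapted parameters over from one test batch to the next."""
    params, models = global_params, []
    for b in batches:
        params = adapt_step(params, grad_fn(params, b), rates)
        models.append(params)
    return models
```

With a constant gradient of 1.0 and a rate of 0.5, ATP-batch yields the same adapted weight (-0.5) for every batch, whereas ATP-online drifts further with each batch (-0.5, -1.0, -1.5). The online variant accumulates information from the client's test stream over time.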

The results of

Expected Accuracies

Note that these results are from a single random seed, whereas the paper reports results over five different random seeds.

ResNet-18

| Algorithm | Feature shift | Label shift | Hybrid shift |
|---|---|---|---|
| No adaptation | 69.62 | 72.58 | 63.55 |
| ATP-batch | 73.48 | 80.09 | 72.85 |
| ATP-online | 73.83 | 81.72 | 75.34 |

Shallow CNN

| Algorithm | Feature shift | Label shift | Hybrid shift |
|---|---|---|---|
| No adaptation | 64.36 | 69.15 | 61.87 |
| ATP-batch | 67.02 | 76.14 | 68.48 |
| ATP-online | 67.22 | 78.38 | 70.86 |

CIFAR-100C Experiments

Coming soon.

Digits-5

Coming soon.

PACS

Coming soon.

Citation

If you are also interested in test-time personalization, please consider giving a star ⭐️ to our repo and citing our paper:

@inproceedings{
  bao2023adaptive,
  title={Adaptive Test-Time Personalization for Federated Learning},
  author={Wenxuan Bao and Tianxin Wei and Haohan Wang and Jingrui He},
  booktitle={Thirty-seventh Conference on Neural Information Processing Systems},
  year={2023},
  url={https://openreview.net/forum?id=rbw9xCU6Ci}
}