🌄 Efficient Test-Time Model Adaptation without Forgetting

May 27, 2022

This is the official project repository for Efficient Test-Time Model Adaptation without Forgetting 🔗 by Shuaicheng Niu, Jiaxiang Wu, Yifan Zhang, Yaofo Chen, Shijian Zheng, Peilin Zhao and Mingkui Tan (ICML 2022).

🌄 EATA conducts model learning at test time to adapt a pre-trained model to test data that has distributional shifts ☀️ 🌧 ❄️, such as corruptions, simulation-to-real discrepancies, and other differences between training and testing data.

  • 1๏ธโƒฃ: EATA conducts selective/sample-adaptative optimization, i.e., only perform backward propagation for reliable and non-redundant test samples, which improves adaptaion efficiency and performance simtaneously.

  • 2๏ธโƒฃ: EATA regularizes the model parameters during testing to prevent the forgetting on in-distribution/clean test samples ๐Ÿ˜‹.

Installation:

EATA depends on

Data preparation:

This repository contains code for evaluation on ImageNet and ImageNet-C 🔗 with ResNet models. But feel free to use your own data and models!

Usage:

import eata

model = TODO_model()  # your pre-trained model (placeholder)

model = eata.configure_model(model)
params, param_names = eata.collect_params(model)
optimizer = TODO_optimizer(params, lr=2.5e-4)  # your optimizer of choice (placeholder)
adapt_model = eata.EATA(model, optimizer, fishers, e_margin, d_margin)

outputs = adapt_model(inputs)  # now it infers and adapts!

Notes:

  • fishers are pre-calculated (see main.py) for preventing forgetting,
  • e_margin and d_margin are two parameters for selective (sample-adaptive) optimization.
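
As a rough illustration of how the pre-calculated fishers prevent forgetting, the regularizer can be thought of as an EWC-style penalty that anchors each parameter to its pre-trained value in proportion to its Fisher importance. A minimal NumPy sketch, where the function name, weighting, and values are illustrative assumptions rather than the repository's API:

```python
import numpy as np

def fisher_penalty(theta, theta_src, fisher, lam=1.0):
    """EWC-style anti-forgetting penalty: lam * sum_i F_i * (theta_i - theta0_i)^2.

    theta      -- current (adapted) parameters, flattened
    theta_src  -- frozen pre-trained parameters
    fisher     -- per-parameter importance weights (pre-calculated on source data)
    """
    diff = theta - theta_src
    return lam * np.sum(fisher * diff ** 2)

theta_src = np.array([1.0, -2.0, 0.5])
fisher    = np.array([2.0, 0.1, 1.0])   # important, unimportant, moderate
theta     = np.array([1.5, -1.0, 0.5])  # parameters after some adaptation steps

# Drifting a high-Fisher parameter costs far more than drifting a low-Fisher one:
# 2.0*0.5^2 + 0.1*1.0^2 + 1.0*0.0^2 = 0.6
print(fisher_penalty(theta, theta_src, fisher))
```

Adding this penalty to the adaptation loss discourages updates that would move parameters important to the source task, which is why clean accuracy is maintained in the results below.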

Example: Adapting a pre-trained ResNet-50 model on ImageNet-C (Corruption).

Usage:

python3 main.py --data /path/to/imagenet --data_corruption /path/to/imagenet-c --exp_type 'continual' or 'each_shift_reset' --algorithm 'eata' or 'eta' or 'tent' --output /output/dir

'--exp_type' is chosen from:

  • 'continual' means the model parameters will never be reset, also called online adaptation;

  • 'each_shift_reset' means after each type of distribution shift, e.g., ImageNet-C Gaussian Noise Level 5, the model parameters will be reset.

'--algorithm' is chosen from:

  • 'tent' (baseline);
  • 'eta' (ours, EATA without the anti-forgetting regularization);
  • 'eata' (ours)

Results:

Here, we report the results on ImageNet-C, severity level = 5, with ResNet-50.

  • Corruption error rates (%, lower is better; equivalently, ETA/EATA achieve higher corruption accuracy while using fewer backward passes than TTT and Tent, and are therefore more efficient):

| Method | Gauss. | Shot | Impul. | Defoc. | Glass | Motion | Zoom | Snow | Frost | Fog | Brit. | Contr. | Elastic | Pixel | JPEG | Avg. #Forwards | Avg. #Backwards |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| R-50 (GN)+JT | 94.9 | 95.1 | 94.2 | 88.9 | 91.7 | 86.7 | 81.6 | 82.5 | 81.8 | 80.6 | 49.2 | 87.4 | 76.9 | 79.2 | 68.5 | 50,000 | 0 |
| +TTT 🔗 | 69.0 | 66.4 | 66.6 | 71.9 | 92.2 | 66.8 | 63.2 | 59.1 | 81.0 | 49.0 | 38.2 | 61.1 | 50.6 | 48.3 | 52.0 | 50,000✖️21 | 50,000✖️20 |
| R-50 (BN) | 97.8 | 97.1 | 98.2 | 82.1 | 90.2 | 85.2 | 77.5 | 83.1 | 76.7 | 75.6 | 41.1 | 94.6 | 83.1 | 79.4 | 68.4 | 50,000 | 0 |
| +Tent 🔗 | 71.6 | 69.8 | 69.9 | 71.8 | 72.7 | 58.6 | 50.5 | 52.9 | 58.7 | 42.5 | 32.6 | 74.9 | 45.2 | 41.5 | 47.7 | 50,000 | 50,000 |
| +ETA (ours) | 64.9 | 62.1 | 63.4 | 66.1 | 67.1 | 52.2 | 47.4 | 48.1 | 54.2 | 39.9 | 32.1 | 55.0 | 42.1 | 39.1 | 45.1 | 50,000 | 26,031 |
| +EATA (ours) | 65.0 | 63.1 | 64.3 | 66.3 | 66.6 | 52.9 | 47.2 | 48.6 | 54.3 | 40.1 | 32.0 | 55.7 | 42.4 | 39.3 | 45.0 | 50,000 | 25,150 |
| +EATA (lifelong) | 65.0 | 61.9 | 63.2 | 66.2 | 65.8 | 52.7 | 46.8 | 48.9 | 54.4 | 40.3 | 32.0 | 55.8 | 42.8 | 39.6 | 45.3 | 50,000 | 28,243 |
  • Clean accuracy (the model's source accuracy on the clean/original ImageNet test set). EATA improves the model's corruption performance while maintaining its source accuracy, whereas Tent does not.

(Figure: forgetting results, comparing clean accuracy during adaptation)

Please see our PAPER 🔗 for detailed results.

Correspondence

Please contact Shuaicheng Niu by niushuaicheng [at] gmail.com 📬.

Citation

If the EATA method or the setting of fully test-time adaptation without forgetting is helpful in your research, please consider citing our paper:

@InProceedings{niu2022efficient,
  title={Efficient Test-Time Model Adaptation without Forgetting},
  author={Niu, Shuaicheng and Wu, Jiaxiang and Zhang, Yifan and Chen, Yaofo and Zheng, Shijian and Zhao, Peilin and Tan, Mingkui},
  booktitle = {The International Conference on Machine Learning},
  year = {2022}
}

Acknowledgment

The code is heavily based on Tent 🔗 and TTT 🔗.