Fine-tuning Tutorial
May 27, 2024
This tutorial explains how to fine-tune VisionFM for multi-class disease recognition.
Download the datasets
We use the datasets prepared by RETFound as examples. Please download them from its official GitHub page. We suggest keeping each dataset's folder structure unchanged after unzipping.
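Before fine-tuning, it can help to check that a dataset was unzipped with its layout intact. The sketch below assumes that each RETFound-prepared dataset contains train/, val/ and test/ folders with one sub-folder per class. This layout is an assumption, so adjust SPLITS if your copy differs. The function names check_dataset and count_classes are hypothetical helpers, not part of VisionFM.

```shell
# Sketch: sanity-check an unzipped dataset before fine-tuning.
# ASSUMPTION: the dataset folder contains train/, val/ and test/ splits,
# each holding one sub-folder per class. Adjust SPLITS if your copy differs.
SPLITS="train val test"

# Fails (and names the missing folders on stderr) if any split is absent under $1.
check_dataset() {
    missing=0
    for split in $SPLITS; do
        if [ ! -d "$1/$split" ]; then
            echo "missing: $1/$split" >&2
            missing=1
        fi
    done
    return "$missing"
}

# Prints the number of class folders in $1/train; compare it with --num_labels.
count_classes() {
    find "$1/train" -mindepth 1 -maxdepth 1 -type d | wc -l | tr -d ' '
}

# Usage:
#   check_dataset ./data/PAPILA && count_classes ./data/PAPILA
```

If the layout assumption holds, the class count printed for a dataset should match the --num_labels value used for it in the commands of this tutorial.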
Fine-tune VisionFM for the PAPILA dataset
Fine-tuning using the train set and selecting the best checkpoint using the val set
Run the following command to start fine-tuning VisionFM on the PAPILA dataset. Set --pretrained_weights to the local path of the VisionFM pre-trained fundus weights, which can be downloaded from VisionFM fundus encoder, and set --data_path to the local path of the downloaded PAPILA dataset.
CUDA_VISIBLE_DEVICES=0 nohup python3 -m torch.distributed.launch --nproc_per_node=1 --master_port=20030 finetune_visionfm_for_multiclass_classification.py --pretrained_weights ./pretrain_weights/VFM_Fundus_weights.pth --arch vit_base --avgpool_patchtokens 0 --input_size 224 --output_dir ./results/PAPILA_FT_VisionFM_Val --data_path ./data/PAPILA --task PAPILA_FT_VisionFM_Val --modality Fundus --num_workers 4 --batch_size_per_gpu 128 --num_labels 3 --extra 10 > PAPILA_FT_VisionFM_Val.log 2>&1 &
The script fine-tunes VisionFM for 100 epochs and runs in the background (nohup ... &). After fine-tuning completes, the results (AUROC and AUPR) on the validation set can be found in PAPILA_FT_VisionFM_Val.log:
cat PAPILA_FT_VisionFM_Val.log
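Because the command ends with `&`, the shell returns immediately while fine-tuning keeps running. A minimal sketch for tracking such a job: `$!` holds the PID of the most recent background job, so it must be saved in the same shell, right after the launch command. The .pid file name and the helpers job_running and watch_job are hypothetical.

```shell
# Sketch: keep track of a job started with `nohup ... &`.
# Right after the launch command, in the same shell, save the PID of the
# background process ($! expands to the PID of the most recent background job):
#   echo $! > PAPILA_FT_VisionFM_Val.pid

# Succeeds while the process whose PID is stored in file $1 is still alive.
job_running() {
    kill -0 "$(cat "$1")" 2>/dev/null
}

# Prints the last lines of log $2 every minute until the job in pid file $1 exits.
watch_job() {
    while job_running "$1"; do
        tail -n 5 "$2"
        sleep 60
    done
    echo "job finished; see $2 for the full log"
}

# Usage: watch_job PAPILA_FT_VisionFM_Val.pid PAPILA_FT_VisionFM_Val.log
```

`kill -0` sends no signal; it only tests whether the process exists, which makes it a cheap liveness check.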
The fine-tuned weights (checkpoint_best_finetune.pth) are saved in the output directory given by --output_dir.
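Since evaluation loads this checkpoint, a small guard can stop the test command from starting before the file exists. This is a sketch: require_checkpoint is a hypothetical helper, and the file may appear before fine-tuning ends (when an early epoch is the best so far), so also confirm that the fine-tuning process has exited.

```shell
# Sketch: refuse to evaluate until fine-tuning has written its best checkpoint.
# Note: the file may exist while training is still running, so also confirm
# that the fine-tuning process has finished before relying on it.
require_checkpoint() {
    # $1: directory passed to --output_dir during fine-tuning
    ckpt="$1/checkpoint_best_finetune.pth"
    if [ -s "$ckpt" ]; then
        echo "$ckpt"                       # print the path for --pretrained_weights
    else
        echo "no checkpoint at $ckpt (is fine-tuning still running?)" >&2
        return 1
    fi
}

# Usage:
#   CKPT=$(require_checkpoint ./results/PAPILA_FT_VisionFM_Val) || exit 1
#   ... --pretrained_weights "$CKPT" ...
```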
Evaluating on the test set
The following command loads the fine-tuned weights and evaluates VisionFM on the test set. Set --pretrained_weights to the local path of the previously fine-tuned weights.
CUDA_VISIBLE_DEVICES=0 nohup python3 -m torch.distributed.launch --nproc_per_node=1 --master_port=20030 inference_visionfm_for_multiclass_classification.py --pretrained_weights ./results/PAPILA_FT_VisionFM_Val/checkpoint_best_finetune.pth --arch vit_base --avgpool_patchtokens 0 --input_size 224 --output_dir ./results/PAPILA_FT_VisionFM_test --data_path ./data/PAPILA --task PAPILA_FT_VisionFM_test --modality Fundus --num_workers 4 --batch_size_per_gpu 128 --num_labels 3 --extra 10 > PAPILA_FT_VisionFM_test.log 2>&1 &
The results (AUROC and AUPR) on the test set are stored in PAPILA_FT_VisionFM_test.log:
cat PAPILA_FT_VisionFM_test.log
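Because both stdout and stderr are redirected to the log, it may contain much more than the final metrics. A minimal sketch for showing only the metric lines, assuming (not confirmed by this tutorial) that they contain the words "AUROC" or "AUPR"; adjust the pattern to the actual log format. show_metrics is a hypothetical helper.

```shell
# Sketch: print only the metric lines of a log.
# ASSUMPTION: metric lines contain "AUROC" or "AUPR" (matched case-insensitively).
show_metrics() {
    # $1: log file; $2: how many of the last matching lines to keep (default 5)
    grep -iE 'auroc|aupr' "$1" | tail -n "${2:-5}"
}

# Usage: show_metrics PAPILA_FT_VisionFM_test.log
```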
Testing our fine-tuned weights
We also provide fine-tuned weights for reproducibility. Download them from VisionFM Fine-tuned PAPILA and set --pretrained_weights to the local path of the downloaded file (path/to/checkpoint_papila.pth in the command below).
CUDA_VISIBLE_DEVICES=0 nohup python3 -m torch.distributed.launch --nproc_per_node=1 --master_port=20030 inference_visionfm_for_multiclass_classification.py --pretrained_weights path/to/checkpoint_papila.pth --arch vit_base --avgpool_patchtokens 0 --input_size 224 --output_dir ./results/PAPILA_FT_VisionFM_test --data_path ./data/PAPILA --task PAPILA_FT_VisionFM_test --modality Fundus --num_workers 4 --batch_size_per_gpu 128 --num_labels 3 --extra 10 > PAPILA_FT_VisionFM_test.log 2>&1 &
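The command above writes to the same --output_dir and log file as the evaluation of your own fine-tuned model, so running both overwrites the earlier results. A sketch of the same command with separate names: the "_released" suffix and the DRY_RUN/CKPT variables are hypothetical choices, and by default the script only prints the command. Set DRY_RUN=0 to launch it.

```shell
# Sketch: evaluate the released checkpoint without overwriting earlier test results.
# Only --output_dir, --task and the log name differ from the command above.
CKPT=${CKPT:-path/to/checkpoint_papila.pth}   # local path of the downloaded weights
TASK=PAPILA_FT_VisionFM_test_released         # hypothetical naming choice

set -- python3 -m torch.distributed.launch --nproc_per_node=1 --master_port=20030 \
    inference_visionfm_for_multiclass_classification.py \
    --pretrained_weights "$CKPT" --arch vit_base --avgpool_patchtokens 0 --input_size 224 \
    --output_dir "./results/$TASK" --data_path ./data/PAPILA --task "$TASK" \
    --modality Fundus --num_workers 4 --batch_size_per_gpu 128 --num_labels 3 --extra 10

if [ "${DRY_RUN:-1}" = 1 ]; then
    echo "$@"                                 # preview only (default)
else
    CUDA_VISIBLE_DEVICES=0 nohup "$@" > "$TASK.log" 2>&1 &
fi
```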
Fine-tuning VisionFM on the remaining multi-class datasets
Fine-tuning VisionFM on the remaining seven datasets follows the same procedure. Make sure the dataset and output directory paths are correct, and use a distinct log file name for each dataset. Set --num_labels to the number of classes in each dataset. For the OCTID dataset, set --modality to OCT and use the pre-trained OCT encoder weights, which can be downloaded from VisionFM OCT encoder. For the IDRiD, MESSIDOR2, and Kaggle APTOS-2019 datasets, fine-tuning for 5 epochs is sufficient to produce reasonable results; longer fine-tuning leads to overfitting. For the other datasets, we suggest fine-tuning for 100 epochs.
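The per-dataset differences (number of labels, modality, and epochs) can be kept in one small table so that each command is generated rather than edited by hand. A sketch under these assumptions: the values are copied from the table below, "-" means the --epoch flag is omitted (100 epochs per this tutorial), and finetune_cmd, print_all and DATASETS are hypothetical names. The commands are printed without nohup and `&` so that, saved to a file, they run one after another on the single GPU.

```shell
# Sketch: generate the fine-tuning command for every dataset from one table.
# Columns: data folder, --num_labels, --modality, epochs ("-" = omit --epoch)
DATASETS='IDRiD_data 5 Fundus 5
MESSIDOR2 5 Fundus 5
APTOS2019 5 Fundus 5
PAPILA 3 Fundus -
Glaucoma_fundus 3 Fundus -
JSIEC 39 Fundus -
Retina 4 Fundus -
OCTID 5 OCT -'

finetune_cmd() {
    task="$1_FT_VisionFM_Val"
    weights="./pretrain_weights/VFM_$3_weights.pth"   # VFM_Fundus_... or VFM_OCT_...
    epoch_flag=""
    [ "$4" != "-" ] && epoch_flag=" --epoch $4"
    echo "CUDA_VISIBLE_DEVICES=0 python3 -m torch.distributed.launch --nproc_per_node=1 --master_port=20030" \
         "finetune_visionfm_for_multiclass_classification.py --pretrained_weights $weights" \
         "--arch vit_base --avgpool_patchtokens 0 --input_size 224 --output_dir ./results/$task" \
         "--data_path ./data/$1 --task $task --modality $3 --num_workers 4 --batch_size_per_gpu 128" \
         "--num_labels $2$epoch_flag --extra 10 > $task.log 2>&1"
}

# Prints one command per dataset.
print_all() {
    echo "$DATASETS" | while read -r name labels modality epochs; do
        finetune_cmd "$name" "$labels" "$modality" "$epochs"
    done
}

# Usage: review the output, then run it sequentially in the background:
#   print_all > run_all.sh && nohup sh run_all.sh > run_all.log 2>&1 &
```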
Alternatively, use the commands in the following table to fine-tune and test VisionFM on PAPILA and the remaining seven datasets. The last column links to our fine-tuned weights for each dataset.
| Dataset | Fine-tuning | Testing | Our fine-tuned weights |
|---|---|---|---|
| IDRiD | CUDA_VISIBLE_DEVICES=0 nohup python3 -m torch.distributed.launch --nproc_per_node=1 --master_port=20030 finetune_visionfm_for_multiclass_classification.py --pretrained_weights ./pretrain_weights/VFM_Fundus_weights.pth --arch vit_base --avgpool_patchtokens 0 --input_size 224 --output_dir ./results/IDRiD_data_FT_VisionFM_Val --data_path ./data/IDRiD_data --task IDRiD_data_FT_VisionFM_Val --modality Fundus --num_workers 4 --batch_size_per_gpu 128 --num_labels 5 --epoch 5 --extra 10 > IDRiD_data_FT_VisionFM_Val.log 2>&1 & | CUDA_VISIBLE_DEVICES=0 nohup python3 -m torch.distributed.launch --nproc_per_node=1 --master_port=20030 inference_visionfm_for_multiclass_classification.py --pretrained_weights ./results/IDRiD_data_FT_VisionFM_Val/checkpoint_best_finetune.pth --arch vit_base --avgpool_patchtokens 0 --input_size 224 --output_dir ./results/IDRiD_data_FT_VisionFM_test --data_path ./data/IDRiD_data --task IDRiD_data_FT_VisionFM_test --modality Fundus --num_workers 4 --batch_size_per_gpu 128 --num_labels 5 --extra 10 > IDRiD_data_FT_VisionFM_test.log 2>&1 & | Weights |
| MESSIDOR2 | CUDA_VISIBLE_DEVICES=0 nohup python3 -m torch.distributed.launch --nproc_per_node=1 --master_port=20030 finetune_visionfm_for_multiclass_classification.py --pretrained_weights ./pretrain_weights/VFM_Fundus_weights.pth --arch vit_base --avgpool_patchtokens 0 --input_size 224 --output_dir ./results/MESSIDOR2_FT_VisionFM_Val --data_path ./data/MESSIDOR2 --task MESSIDOR2_FT_VisionFM_Val --modality Fundus --num_workers 4 --batch_size_per_gpu 128 --num_labels 5 --epoch 5 --extra 10 > MESSIDOR2_FT_VisionFM_Val.log 2>&1 & | CUDA_VISIBLE_DEVICES=0 nohup python3 -m torch.distributed.launch --nproc_per_node=1 --master_port=20030 inference_visionfm_for_multiclass_classification.py --pretrained_weights ./results/MESSIDOR2_FT_VisionFM_Val/checkpoint_best_finetune.pth --arch vit_base --avgpool_patchtokens 0 --input_size 224 --output_dir ./results/MESSIDOR2_FT_VisionFM_test --data_path ./data/MESSIDOR2 --task MESSIDOR2_FT_VisionFM_test --modality Fundus --num_workers 4 --batch_size_per_gpu 128 --num_labels 5 --extra 10 > MESSIDOR2_FT_VisionFM_test.log 2>&1 & | Weights |
| APTOS-2019 | CUDA_VISIBLE_DEVICES=0 nohup python3 -m torch.distributed.launch --nproc_per_node=1 --master_port=20030 finetune_visionfm_for_multiclass_classification.py --pretrained_weights ./pretrain_weights/VFM_Fundus_weights.pth --arch vit_base --avgpool_patchtokens 0 --input_size 224 --output_dir ./results/APTOS2019_FT_VisionFM_Val --data_path ./data/APTOS2019 --task APTOS2019_FT_VisionFM_Val --modality Fundus --num_workers 4 --batch_size_per_gpu 128 --num_labels 5 --epoch 5 --extra 10 > APTOS2019_FT_VisionFM_Val.log 2>&1 & | CUDA_VISIBLE_DEVICES=0 nohup python3 -m torch.distributed.launch --nproc_per_node=1 --master_port=20030 inference_visionfm_for_multiclass_classification.py --pretrained_weights ./results/APTOS2019_FT_VisionFM_Val/checkpoint_best_finetune.pth --arch vit_base --avgpool_patchtokens 0 --input_size 224 --output_dir ./results/APTOS2019_FT_VisionFM_test --data_path ./data/APTOS2019 --task APTOS2019_FT_VisionFM_test --modality Fundus --num_workers 4 --batch_size_per_gpu 128 --num_labels 5 --extra 10 > APTOS2019_FT_VisionFM_test.log 2>&1 & | Weights |
| PAPILA | CUDA_VISIBLE_DEVICES=0 nohup python3 -m torch.distributed.launch --nproc_per_node=1 --master_port=20030 finetune_visionfm_for_multiclass_classification.py --pretrained_weights ./pretrain_weights/VFM_Fundus_weights.pth --arch vit_base --avgpool_patchtokens 0 --input_size 224 --output_dir ./results/PAPILA_FT_VisionFM_Val --data_path ./data/PAPILA --task PAPILA_FT_VisionFM_Val --modality Fundus --num_workers 4 --batch_size_per_gpu 128 --num_labels 3 --extra 10 > PAPILA_FT_VisionFM_Val.log 2>&1 & | CUDA_VISIBLE_DEVICES=0 nohup python3 -m torch.distributed.launch --nproc_per_node=1 --master_port=20030 inference_visionfm_for_multiclass_classification.py --pretrained_weights ./results/PAPILA_FT_VisionFM_Val/checkpoint_best_finetune.pth --arch vit_base --avgpool_patchtokens 0 --input_size 224 --output_dir ./results/PAPILA_FT_VisionFM_test --data_path ./data/PAPILA --task PAPILA_FT_VisionFM_test --modality Fundus --num_workers 4 --batch_size_per_gpu 128 --num_labels 3 --extra 10 > PAPILA_FT_VisionFM_test.log 2>&1 & | Weights |
| Glaucoma Fundus | CUDA_VISIBLE_DEVICES=0 nohup python3 -m torch.distributed.launch --nproc_per_node=1 --master_port=20030 finetune_visionfm_for_multiclass_classification.py --pretrained_weights ./pretrain_weights/VFM_Fundus_weights.pth --arch vit_base --avgpool_patchtokens 0 --input_size 224 --output_dir ./results/Glaucoma_fundus_FT_VisionFM_Val --data_path ./data/Glaucoma_fundus --task Glaucoma_fundus_FT_VisionFM_Val --modality Fundus --num_workers 4 --batch_size_per_gpu 128 --num_labels 3 --extra 10 > Glaucoma_fundus_FT_VisionFM_Val.log 2>&1 & | CUDA_VISIBLE_DEVICES=0 nohup python3 -m torch.distributed.launch --nproc_per_node=1 --master_port=20030 inference_visionfm_for_multiclass_classification.py --pretrained_weights ./results/Glaucoma_fundus_FT_VisionFM_Val/checkpoint_best_finetune.pth --arch vit_base --avgpool_patchtokens 0 --input_size 224 --output_dir ./results/Glaucoma_fundus_FT_VisionFM_test --data_path ./data/Glaucoma_fundus --task Glaucoma_fundus_FT_VisionFM_test --modality Fundus --num_workers 4 --batch_size_per_gpu 128 --num_labels 3 --extra 10 > Glaucoma_fundus_FT_VisionFM_test.log 2>&1 & | Weights |
| JSIEC | CUDA_VISIBLE_DEVICES=0 nohup python3 -m torch.distributed.launch --nproc_per_node=1 --master_port=20030 finetune_visionfm_for_multiclass_classification.py --pretrained_weights ./pretrain_weights/VFM_Fundus_weights.pth --arch vit_base --avgpool_patchtokens 0 --input_size 224 --output_dir ./results/JSIEC_FT_VisionFM_Val --data_path ./data/JSIEC --task JSIEC_FT_VisionFM_Val --modality Fundus --num_workers 4 --batch_size_per_gpu 128 --num_labels 39 --extra 10 > JSIEC_FT_VisionFM_Val.log 2>&1 & | CUDA_VISIBLE_DEVICES=0 nohup python3 -m torch.distributed.launch --nproc_per_node=1 --master_port=20030 inference_visionfm_for_multiclass_classification.py --pretrained_weights ./results/JSIEC_FT_VisionFM_Val/checkpoint_best_finetune.pth --arch vit_base --avgpool_patchtokens 0 --input_size 224 --output_dir ./results/JSIEC_FT_VisionFM_test --data_path ./data/JSIEC --task JSIEC_FT_VisionFM_test --modality Fundus --num_workers 4 --batch_size_per_gpu 128 --num_labels 39 --extra 10 > JSIEC_FT_VisionFM_test.log 2>&1 & | Weights |
| Retina | CUDA_VISIBLE_DEVICES=0 nohup python3 -m torch.distributed.launch --nproc_per_node=1 --master_port=20030 finetune_visionfm_for_multiclass_classification.py --pretrained_weights ./pretrain_weights/VFM_Fundus_weights.pth --arch vit_base --avgpool_patchtokens 0 --input_size 224 --output_dir ./results/Retina_FT_VisionFM_Val --data_path ./data/Retina --task Retina_FT_VisionFM_Val --modality Fundus --num_workers 4 --batch_size_per_gpu 128 --num_labels 4 --extra 10 > Retina_FT_VisionFM_Val.log 2>&1 & | CUDA_VISIBLE_DEVICES=0 nohup python3 -m torch.distributed.launch --nproc_per_node=1 --master_port=20030 inference_visionfm_for_multiclass_classification.py --pretrained_weights ./results/Retina_FT_VisionFM_Val/checkpoint_best_finetune.pth --arch vit_base --avgpool_patchtokens 0 --input_size 224 --output_dir ./results/Retina_FT_VisionFM_test --data_path ./data/Retina --task Retina_FT_VisionFM_test --modality Fundus --num_workers 4 --batch_size_per_gpu 128 --num_labels 4 --extra 10 > Retina_FT_VisionFM_test.log 2>&1 & | Weights |
| OCTID | CUDA_VISIBLE_DEVICES=0 nohup python3 -m torch.distributed.launch --nproc_per_node=1 --master_port=20030 finetune_visionfm_for_multiclass_classification.py --pretrained_weights ./pretrain_weights/VFM_OCT_weights.pth --arch vit_base --avgpool_patchtokens 0 --input_size 224 --output_dir ./results/OCTID_FT_VisionFM_Val --data_path ./data/OCTID --task OCTID_FT_VisionFM_Val --modality OCT --num_workers 4 --batch_size_per_gpu 128 --num_labels 5 --extra 10 > OCTID_FT_VisionFM_Val.log 2>&1 & | CUDA_VISIBLE_DEVICES=0 nohup python3 -m torch.distributed.launch --nproc_per_node=1 --master_port=20030 inference_visionfm_for_multiclass_classification.py --pretrained_weights ./results/OCTID_FT_VisionFM_Val/checkpoint_best_finetune.pth --arch vit_base --avgpool_patchtokens 0 --input_size 224 --output_dir ./results/OCTID_FT_VisionFM_test --data_path ./data/OCTID --task OCTID_FT_VisionFM_test --modality OCT --num_workers 4 --batch_size_per_gpu 128 --num_labels 5 --extra 10 > OCTID_FT_VisionFM_test.log 2>&1 & | Weights |
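After running the test commands in the table, the results of all datasets can be collected in one place. A minimal sketch, assuming the log names follow the table's NAME_FT_VisionFM_test.log pattern and that metric lines contain "AUROC" or "AUPR" (the log format is an assumption; adjust the pattern if needed). summarize_results is a hypothetical helper.

```shell
# Sketch: print the last metric lines of every test log, grouped by dataset.
# ASSUMPTION: metric lines contain "AUROC" or "AUPR" (matched case-insensitively).
summarize_results() {
    # $1: directory containing the *_FT_VisionFM_test.log files (default: .)
    for log in "${1:-.}"/*_FT_VisionFM_test.log; do
        [ -e "$log" ] || continue              # the glob matched no file
        name=$(basename "$log" _FT_VisionFM_test.log)
        echo "== $name"
        grep -iE 'auroc|aupr' "$log" | tail -n 2
    done
}

# Usage: summarize_results .
```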