PRIOR: Prototype Representation Joint Learning from Medical Images and Reports
November 9, 2023
Official repository for the paper Prototype Representation Joint Learning from Medical Images and Reports, ICCV 2023.
Introduction
Contrastive-learning-based vision-language joint pre-training has emerged as a successful representation learning strategy. In this paper, we present a prototype representation learning framework incorporating both global and local alignment between medical images and reports. In contrast to standard global multi-modality alignment methods, we employ a local alignment module for fine-grained representation. Furthermore, a cross-modality conditional reconstruction module is designed to interchange information across modalities in the training phase by reconstructing masked images and reports. For reconstructing long reports, a sentence-wise prototype memory bank is constructed, enabling the network to focus on low-level localized visual and high-level clinical linguistic features. Additionally, a non-auto-regressive generation paradigm is proposed for reconstructing non-sequential reports. Experimental results on five downstream tasks, including supervised classification, zero-shot classification, image-to-text retrieval, semantic segmentation, and object detection, show that the proposed method outperforms other state-of-the-art methods across multiple datasets and under different dataset size settings.
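For readers unfamiliar with the global alignment mentioned above, the snippet below is a minimal, generic sketch of a symmetric contrastive (InfoNCE-style) loss between paired image and report embeddings. It is not taken from this repository: the function names, the toy embeddings, and the `temperature` value are all assumptions made for illustration, and PRIOR's actual losses (local alignment, reconstruction, prototype memory bank) are not covered.

```python
import math


def cosine_similarity(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def cross_entropy_diagonal(logits):
    """Mean cross-entropy where the target for row i is column i."""
    total = 0.0
    for i, row in enumerate(logits):
        row_max = max(row)  # subtract the max for numerical stability
        log_sum_exp = row_max + math.log(sum(math.exp(x - row_max) for x in row))
        total += log_sum_exp - row[i]
    return total / len(logits)


def symmetric_contrastive_loss(image_embs, text_embs, temperature=0.1):
    """Average of the image-to-text and text-to-image contrastive losses.

    image_embs[i] and text_embs[i] are assumed to come from the same study.
    The temperature is an illustrative value, not the one used in the paper.
    """
    logits = [
        [cosine_similarity(img, txt) / temperature for txt in text_embs]
        for img in image_embs
    ]
    logits_t = [list(col) for col in zip(*logits)]  # text-to-image direction
    return 0.5 * (cross_entropy_diagonal(logits) + cross_entropy_diagonal(logits_t))


# Toy 2-D embeddings: correctly paired inputs give a much lower loss than swapped ones.
aligned = symmetric_contrastive_loss([[1.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [0.0, 1.0]])
swapped = symmetric_contrastive_loss([[1.0, 0.0], [0.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]])
print(aligned < swapped)
```

The loss is low when each image embedding is closest to its own report embedding and high otherwise, which is the behaviour a global alignment objective is meant to encourage.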
Setup
Run
pip install -r requirements.txt
Data Preparation
MIMIC-CXR Dataset
- Download Version 2 of MIMIC-CXR-JPG from https://physionet.org/content/mimic-cxr-jpg/2.0.0/ to <image_dataset_path>
- Download the reports from MIMIC-CXR https://physionet.org/content/mimic-cxr/2.0.0/ to <report_dataset_path>
- Run the script to make the JSON file for pre-training
cd codes/
python prior/data/pretrain/mimiccxr.py build --root_image_path <image_dataset_path> --root_report_path <report_dataset_path> --save_path <dataset_json_path> --meta_csv_path <meta_csv_path>
- Add <dataset_json_path> to configs/data/mimiccxr.yaml
_target_: prior.data.pretrain.mimiccxr.MimicCxrDataset
dataset_path: <dataset_json_path> # update this line
image_transform: ???
text_transform: ???
num_colors: 1
rate: 1.0
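If you prefer to drive the JSON build step from Python, the sketch below assembles the same command line shown above and checks that the two dataset directories exist first. The helper name `build_mimiccxr_command` is hypothetical and not part of this repository; only the script path and the four flags come from the instructions above.

```python
from pathlib import Path


def build_mimiccxr_command(image_root, report_root, save_path, meta_csv_path):
    """Return the argument list for the documented build command.

    Hypothetical convenience wrapper: it only verifies that the two dataset
    directories exist and then mirrors the flags listed in this README.
    """
    for label, folder in (("image", image_root), ("report", report_root)):
        if not Path(folder).is_dir():
            raise FileNotFoundError(f"{label} dataset directory not found: {folder}")
    return [
        "python", "prior/data/pretrain/mimiccxr.py", "build",
        "--root_image_path", str(image_root),
        "--root_report_path", str(report_root),
        "--save_path", str(save_path),
        "--meta_csv_path", str(meta_csv_path),
    ]


# Example with placeholder paths ("." always exists, so the check passes).
command = build_mimiccxr_command(".", ".", "mimiccxr.json", "meta.csv")
print(" ".join(command))
```

The returned list can be passed to `subprocess.run(command, cwd="codes", check=True)`, assuming you launch it from the repository root.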
Pre-train
Run
cd codes/
python scripts/train.py +experiments/pre_train=train_prior
Pre-trained weights
We release our pre-trained model at pretrained/prior_resnet50.pt; you can download the image encoder here. The whole image-text model is released on Google Drive.
Downstream tasks
Supervised finetuning
This part is similar to LOVT. The classification model is based on TorchVision, the segmentation model is based on SMP, and the detection model is based on Lightning-flash. Since this part makes no additional technical contribution, we do not currently provide the code. We will update this part in the future.
Zero-shot classification
Before running the code, make sure you have downloaded the CheXpert dataset to <root_path> and downloaded the whole image-text pre-trained weights of PRIOR from Google Drive to <pretrained_path>.
Add <root_path> and <dataset_path> to configs/data/chexpert_zero_shot.yaml
_target_: prior.data.zero_shot_classification.chexpert_zls.CheXPertZeroClsDataset
dataset_path: ??? # update this line
transform: ???
num_colors: 1
root_path: ??? # update this line
rate: 1.0
Add <pretrained_path> to configs/experiments/zero_shot_classification/test_prior.yaml
zero_shot_classification_model:
......
pre_trained_path: # update this line
......
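A forgotten placeholder in these files tends to surface only at launch time, so a quick pre-flight check can help. The sketch below is a hypothetical helper (not part of this repository) that scans flat `key: value` text for the keys marked "update this line" and reports those that are still empty or set to `???`. It assumes simple one-line entries and does not handle nested YAML or `#` characters inside values.

```python
def unset_keys(yaml_text, required_keys):
    """Return the required keys whose value is empty or still '???'."""
    values = {}
    for raw_line in yaml_text.splitlines():
        line = raw_line.split("#", 1)[0].strip()  # drop trailing comments
        if ":" not in line:
            continue
        key, value = line.split(":", 1)
        values[key.strip()] = value.strip()
    # A key that is absent altogether is also reported as unset.
    return [key for key in required_keys if values.get(key, "") in ("", "???")]


# Example: root_path has been filled in, the other two have not.
example = """
dataset_path: ??? # update this line
transform: ???
root_path: /data/chexpert
pre_trained_path: # update this line
"""
print(unset_keys(example, ["dataset_path", "root_path", "pre_trained_path"]))
```

Only the keys you are asked to edit are checked; entries such as `transform: ???` are left alone because they are presumably supplied elsewhere in the configuration.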
Then run
cd codes/
python scripts/downstream.py +experiments/zero_shot_classification=test_prior
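As background on what a zero-shot evaluation of an image-text model typically computes, the snippet below gives a generic sketch: an image embedding is compared against one text-prompt embedding per label, and the most similar label wins. This is an illustration under assumptions, not the logic of `scripts/downstream.py`. The function names and the toy 3-D vectors are made up, and real embeddings would come from the pre-trained image and text encoders.

```python
import math


def cosine_similarity(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def zero_shot_predict(image_emb, prompt_embs):
    """Pick the label whose prompt embedding is most similar to the image.

    prompt_embs maps a label name to the embedding of its text prompt.
    Returns the best label and the full score dictionary.
    """
    scores = {
        label: cosine_similarity(image_emb, emb)
        for label, emb in prompt_embs.items()
    }
    best_label = max(scores, key=scores.get)
    return best_label, scores


# Toy embeddings standing in for encoder outputs.
prompts = {
    "Cardiomegaly": [1.0, 0.0, 0.0],
    "Edema": [0.0, 1.0, 0.0],
}
label, scores = zero_shot_predict([0.9, 0.1, 0.0], prompts)
print(label)
```

No classifier is trained in this setting; the label set is defined purely by the text prompts, which is why the full image-text weights are needed rather than the image encoder alone.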
Acknowledgement
Some of the code is borrowed from LOVT and GLoRIA. Thanks for their great work.
Citation
If you find this work useful in your research, please cite:
@inproceedings{PRIOR,
title={PRIOR: Prototype Representation Joint Learning from Medical Images and Reports},
author={Cheng, Pujin and Lin, Li and Lyu, Junyan and Huang, Yijin and Luo, Wenhan and Tang, Xiaoying},
booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
year={2023}
}