Pretraining Deformable Image Registration Networks with Random Images
May 6, 2025
keywords: Image Registration, Self-supervised Learning, Pretraining
This is a PyTorch implementation of my paper, "Pretraining Deformable Image Registration Networks with Random Images".
The core idea of this work is to leverage randomly generated images to initialize (or pretrain) the encoder of an image registration network. To achieve this, we designed temporary lightweight decoders that are attached to the encoder of the registration DNN, and pretrained the resulting network using a standard image registration loss function on the task of aligning pairs of random images.
This approach is implemented in the MIR package. The source code for generating random images can be found here, and the lightweight decoder is implemented here. The repository also includes training and inference scripts to reproduce the results reported in the paper.
Pretraining and Fine-tuning Pipeline
Step 1: Pretraining the encoder on a proxy task of registering random images
Run python -u train_SSL.py to initiate the pretraining. We first extract the encoder from the registration DNN and connect it to a lightweight decoder for pretraining.
https://github.com/junyuchen245/Pretraining_Image_Registration_DNNs/blob/88a330b9b26313a25d346725df0464cfcdc32968/scripts/train_SSL.py#L36-L45
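As a rough, self-contained illustration of this step, the sketch below attaches a temporary decoder to an encoder and predicts a displacement field for a pair of images. It is not the code from the MIR package: `ToyEncoder`, `LightDecoder`, and `PretrainNet` are hypothetical names, and the example assumes 2D inputs and a single-convolution decoder for brevity.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyEncoder(nn.Module):
    """Hypothetical stand-in for the registration network's encoder."""

    def __init__(self, in_ch=2, feat_ch=16):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, feat_ch, 3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(feat_ch, feat_ch * 2, 3, stride=2, padding=1)

    def forward(self, x):
        f1 = F.leaky_relu(self.conv1(x), 0.2)
        f2 = F.leaky_relu(self.conv2(f1), 0.2)
        return f2  # coarsest feature map, 1/4 of the input resolution


class LightDecoder(nn.Module):
    """Hypothetical lightweight decoder: one conv that predicts a 2-channel flow."""

    def __init__(self, feat_ch=32):
        super().__init__()
        self.flow = nn.Conv2d(feat_ch, 2, 3, padding=1)

    def forward(self, feat, out_size):
        flow = self.flow(feat)
        # Upsample the coarse displacement field to the input resolution.
        return F.interpolate(flow, size=out_size, mode="bilinear", align_corners=False)


class PretrainNet(nn.Module):
    """Encoder plus temporary decoder, used only during pretraining."""

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, moving, fixed):
        # Concatenate the image pair along the channel dimension.
        x = torch.cat([moving, fixed], dim=1)
        return self.decoder(self.encoder(x), out_size=moving.shape[2:])


encoder = ToyEncoder()
model = PretrainNet(encoder, LightDecoder(feat_ch=32))
moving = torch.rand(1, 1, 32, 32)
fixed = torch.rand(1, 1, 32, 32)
flow = model(moving, fixed)  # displacement field with shape (1, 2, 32, 32)
```

Because `model.encoder` is the same object as `encoder`, its weights can be kept after pretraining while the lightweight decoder is discarded.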
In each iteration, a pair of random images is generated using data = rs.gen_shapes(.), where data[0] and data[1] contain the moving and fixed random images, and data[2] and data[3] store their corresponding binary label maps.
https://github.com/junyuchen245/Pretraining_Image_Registration_DNNs/blob/88a330b9b26313a25d346725df0464cfcdc32968/scripts/train_SSL.py#L84-L86
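To make the data layout concrete, here is a minimal sketch of a paired random-image generator that returns a list in the same order as described above (moving image, fixed image, moving label map, fixed label map). It does not reproduce `rs.gen_shapes`: `random_shape_pair` and `draw_ellipses` are hypothetical helpers, the shapes are plain 2D ellipses, and the jitter settings are arbitrary assumptions.

```python
import torch


def draw_ellipses(params, size):
    """Rasterize ellipses given as (cy, cx, ry, rx, intensity) rows."""
    ys, xs = torch.meshgrid(
        torch.arange(size, dtype=torch.float32),
        torch.arange(size, dtype=torch.float32),
        indexing="ij",
    )
    image = torch.zeros(size, size)
    label = torch.zeros(size, size)
    for cy, cx, ry, rx, val in params.tolist():
        inside = ((ys - cy) / ry) ** 2 + ((xs - cx) / rx) ** 2 <= 1.0
        image[inside] = val
        label[inside] = 1.0
    return image, label


def random_shape_pair(size=64, n_shapes=4, jitter=3.0, seed=None):
    """Hypothetical generator: returns [moving, fixed, moving_label, fixed_label]."""
    gen = torch.Generator()
    if seed is None:
        gen.seed()  # draw a fresh non-deterministic seed
    else:
        gen.manual_seed(seed)
    centers = torch.rand(n_shapes, 2, generator=gen) * size
    radii = 4.0 + torch.rand(n_shapes, 2, generator=gen) * size / 6
    vals = 0.2 + 0.8 * torch.rand(n_shapes, 1, generator=gen)
    moving_params = torch.cat([centers, radii, vals], dim=1)
    # The fixed image contains the same shapes, shifted and rescaled slightly,
    # so a plausible deformation between the two images exists.
    shift = (torch.rand(n_shapes, 2, generator=gen) - 0.5) * 2 * jitter
    scale = 1.0 + (torch.rand(n_shapes, 2, generator=gen) - 0.5) * 0.2
    fixed_params = torch.cat([centers + shift, radii * scale, vals], dim=1)
    moving, moving_label = draw_ellipses(moving_params, size)
    fixed, fixed_label = draw_ellipses(fixed_params, size)
    # Add batch and channel dimensions: (1, 1, H, W).
    return [t[None, None] for t in (moving, fixed, moving_label, fixed_label)]


data = random_shape_pair(seed=0)
moving, fixed = data[0], data[1]
moving_label, fixed_label = data[2], data[3]
```

Generating the pair on the fly means no dataset has to be stored on disk, and every iteration sees a new pair.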
We then simply compute the registration loss for pretraining.
https://github.com/junyuchen245/Pretraining_Image_Registration_DNNs/blob/88a330b9b26313a25d346725df0464cfcdc32968/scripts/train_SSL.py#L105-L106
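For readers unfamiliar with registration losses, the sketch below shows one common form: an image similarity term on the warped moving image plus a smoothness penalty on the displacement field. The specific choices here (mean squared error, a diffusion regularizer, 2D images, and the helper names `warp`, `smoothness`, and `registration_loss`) are assumptions for illustration and may differ from the terms used in train_SSL.py.

```python
import torch
import torch.nn.functional as F


def warp(image, flow):
    """Warp `image` (N, C, H, W) with a pixel displacement field `flow` (N, 2, H, W)."""
    _, _, h, w = image.shape
    ys, xs = torch.meshgrid(
        torch.arange(h, dtype=image.dtype, device=image.device),
        torch.arange(w, dtype=image.dtype, device=image.device),
        indexing="ij",
    )
    # flow[:, 0] displaces along y (rows), flow[:, 1] along x (columns).
    new_y = ys + flow[:, 0]
    new_x = xs + flow[:, 1]
    # grid_sample expects (x, y) coordinates normalized to [-1, 1].
    grid = torch.stack(
        [2.0 * new_x / (w - 1) - 1.0, 2.0 * new_y / (h - 1) - 1.0], dim=-1
    )
    return F.grid_sample(image, grid, mode="bilinear", align_corners=True)


def smoothness(flow):
    """Diffusion regularizer: mean squared forward difference of the flow."""
    dy = flow[:, :, 1:, :] - flow[:, :, :-1, :]
    dx = flow[:, :, :, 1:] - flow[:, :, :, :-1]
    return (dy ** 2).mean() + (dx ** 2).mean()


def registration_loss(moving, fixed, flow, reg_weight=1.0):
    """Similarity between the warped moving image and the fixed image, plus regularization."""
    warped = warp(moving, flow)
    return F.mse_loss(warped, fixed) + reg_weight * smoothness(flow)


moving = torch.rand(1, 1, 16, 16)
fixed = torch.rand(1, 1, 16, 16)
flow = torch.zeros(1, 2, 16, 16, requires_grad=True)
loss = registration_loss(moving, fixed, flow)
loss.backward()  # gradients reach the flow, and hence the network that predicts it
```

Since the whole loss is differentiable with respect to the flow, it needs no ground-truth deformation, which is what makes random image pairs usable for pretraining.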
Step 2: Fine-tuning the DNN on a downstream registration task
Run python -u train_TransMorph_SSL.py to initiate the fine-tuning. We first load the pretrained encoder obtained in Step 1.
https://github.com/junyuchen245/Pretraining_Image_Registration_DNNs/blob/adae0b487e6968a47192d1e1088ba5c2ca4cd192/scripts/train_TransMorph_SSL.py#L53-L62
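The general pattern of transferring only the encoder weights can be sketched as follows. This is an illustration under assumptions, not the loading code from train_TransMorph_SSL.py: `make_encoder` and `RegNet` are hypothetical, the checkpoint layout (a dictionary with an "encoder" entry) is invented for the example, and an in-memory buffer stands in for the checkpoint file.

```python
import io

import torch
import torch.nn as nn


def make_encoder():
    # Hypothetical encoder; it must match the architecture used in pretraining.
    return nn.Sequential(nn.Conv2d(2, 8, 3, padding=1), nn.LeakyReLU(0.2))


class RegNet(nn.Module):
    """Hypothetical registration network: encoder plus a full decoder."""

    def __init__(self):
        super().__init__()
        self.encoder = make_encoder()
        self.decoder = nn.Conv2d(8, 2, 3, padding=1)

    def forward(self, x):
        return self.decoder(self.encoder(x))


# Stand-in for the checkpoint written at the end of pretraining.
pretrained_encoder = make_encoder()
checkpoint = io.BytesIO()
torch.save({"encoder": pretrained_encoder.state_dict()}, checkpoint)
checkpoint.seek(0)

model = RegNet()
state = torch.load(checkpoint, map_location="cpu")
# strict=True raises an error if any encoder key is missing or unexpected.
model.encoder.load_state_dict(state["encoder"], strict=True)
```

In this sketch the decoder keeps its random initialization, so only the encoder starts from pretrained weights when fine-tuning begins.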
We then train the model as usual.
Pretraining Strategy Overview
Generating Paired Random Images
Pretraining Reduces Amount of Data Needed to Achieve Competitive Performance
