Self-Supervised Learning for Visual Relationship Detection through Masked Bounding Box Reconstruction

May 23, 2024

deeplab.ai

Zacharias Anastasakis, Dimitrios Mallis, Markos Diomataris, George Alexandridis, Stefanos Kollias, Vassilis Pitsikalis

We propose Masked Bounding Box Reconstruction, a variation of Masked Image Modeling where a percentage of the entities/objects within a scene are masked and subsequently reconstructed based on the unmasked objects. Through object-level masked modeling, our proposed network learns context-aware representations that capture the interaction of objects within a scene and are highly predictive of visual object relationships.
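The object-level masking idea can be illustrated with a toy example. This is a sketch under stated assumptions, not the repository's implementation. Each object is a small feature vector, and that representation is assumed only for illustration. The mask ratio is a free parameter. A hypothetical `reconstruct_masked` helper predicts each masked object as the mean of the visible ones; in the actual method, that role belongs to the transformer pre-trained in step 1. The mean-squared-error loss over masked objects only is also an assumption.

```python
import random
from typing import List, Sequence

Vector = List[float]


def choose_masked(num_objects: int, mask_ratio: float, rng: random.Random) -> List[int]:
    """Return sorted indices of the objects to mask.

    At least one object is masked and at least one stays visible, so there is
    always something to reconstruct and some context to reconstruct it from.
    """
    if num_objects < 2:
        raise ValueError("need at least two objects: one masked, one visible")
    num_masked = max(1, min(num_objects - 1, round(mask_ratio * num_objects)))
    return sorted(rng.sample(range(num_objects), num_masked))


def reconstruct_masked(objects: Sequence[Vector], masked: Sequence[int]) -> List[Vector]:
    """Hypothetical stand-in for the MBBR transformer.

    Predicts every masked object as the mean of the visible (unmasked) objects.
    In the actual method a trained transformer produces these predictions.
    """
    masked_set = set(masked)
    visible = [obj for i, obj in enumerate(objects) if i not in masked_set]
    dim = len(objects[0])
    mean = [sum(v[d] for v in visible) / len(visible) for d in range(dim)]
    return [list(mean) for _ in masked]


def masked_mse(objects: Sequence[Vector], masked: Sequence[int],
               predictions: Sequence[Vector]) -> float:
    """Mean squared error computed over the masked objects only (assumed loss)."""
    errors = [
        (p - t) ** 2
        for idx, pred in zip(masked, predictions)
        for p, t in zip(pred, objects[idx])
    ]
    return sum(errors) / len(errors)


if __name__ == "__main__":
    scene = [[0.0, 0.0], [2.0, 2.0], [4.0, 4.0], [6.0, 6.0]]  # toy object features
    masked = choose_masked(len(scene), mask_ratio=0.5, rng=random.Random(0))
    predictions = reconstruct_masked(scene, masked)
    print(masked, masked_mse(scene, masked, predictions))
```

The loss is computed only on the masked objects. As a result, the model is rewarded only for inferring hidden objects from their context, which is what makes the learned representations context-aware.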

This repository contains the code for reproducing our paper from the 2024 IEEE/CVF Winter Conference on Applications of Computer Vision (WACV 2024). It is based on the grounding-consistent-vrd repository. The paper is available via DOI 10.1109/WACV57701.2024.00124.

Environment Setup

After cloning this repository, you can set up a conda environment using the mbbr.yml config file:

conda env create -f mbbr.yml
conda activate mbbr

Dataset Setup

Download the VRD and/or VG200 datasets by running main_prerequisites.py and passing the dataset name as an argument:

python3 main_prerequisites.py VG200

Train

Training involves two steps:

  1. Pre-train a transformer network in a self-supervised manner through Masked Bounding Box Reconstruction (MBBR):
python3 main_research.py --model=MBBR --net_name=MBBRNetwork --projection_head --dataset=VG200 --pretrain_arch=encoder
  2. Train an MLP network in a few-shot setting on random samples, using the pre-trained network from the previous step:
python3 main_research.py --model=SSL_finetune --net_name=FinetunedNetwork --dataset=VG200 --pretrain_arch=encoder --random_few_shot=10 --random_seed=4 --pretrained_model=MBBRNetwork --projection_head --normal --pretrain_task=reconstruction

The step 2 command trains a 2-layer MLP network on 10 random samples (--random_few_shot=10) from the VG200 dataset. In our work, we also manually selected 1, 2 or 5 accurate relationships per predicate category and used them to train our classifier. These relationships are listed in prerequisites/VG200_few_shot_dict.json and prerequisites/VRD_few_shot_dict.json. To train a classifier on these manually selected samples (here, 5 per predicate category), run:

python3 main_research.py --model=SSL_finetune --net_name=FinetunedNetwork --dataset=VG200 --pretrain_arch=encoder --few_shot=5 --pretrained_model=MBBRNetwork --projection_head --normal --pretrain_task=reconstruction
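To run all three manually selected settings (1, 2 and 5 samples per predicate category) in one go, a small driver can build the command above once per value of --few_shot. This is a sketch, not a script shipped with the repository. It reuses only the flags shown in this README and assumes an MBBR network has already been pre-trained on the same dataset in step 1. The per-run --net_name values (for example FinetunedNetwork_VG200_5shot) are a hypothetical naming scheme. They are chosen on the assumption that checkpoints are keyed by --net_name, so that one run's checkpoint is not reused by the next and testing is not triggered without training.

```python
import subprocess
from typing import List, Sequence


def finetune_command(dataset: str, few_shot: int) -> List[str]:
    """Build the manual few-shot fine-tuning command from this README.

    The --net_name value is a hypothetical per-run name (an assumption, not a
    repository convention) so each setting gets its own checkpoint.
    """
    return [
        "python3", "main_research.py",
        "--model=SSL_finetune",
        f"--net_name=FinetunedNetwork_{dataset}_{few_shot}shot",
        f"--dataset={dataset}",
        "--pretrain_arch=encoder",
        f"--few_shot={few_shot}",
        "--pretrained_model=MBBRNetwork",
        "--projection_head",
        "--normal",
        "--pretrain_task=reconstruction",
    ]


def run_sweep(dataset: str, shots: Sequence[int], dry_run: bool = True) -> List[List[str]]:
    """Run one fine-tuning job per shot count; with dry_run=True only print them."""
    commands = [finetune_command(dataset, k) for k in shots]
    for cmd in commands:
        if dry_run:
            print(" ".join(cmd))
        else:
            subprocess.run(cmd, check=True)  # stop at the first failing run
    return commands


if __name__ == "__main__":
    run_sweep("VG200", shots=[1, 2, 5])
```

The driver defaults to a dry run, so the generated commands can be checked before launching training. Call `run_sweep("VG200", [1, 2, 5], dry_run=False)` to execute them in sequence.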

Test

Testing runs automatically after training. It prints micro- and macro-averaged Recall@{20, 50, 100} for both the constrained and unconstrained settings and also reports zero-shot results.

Because checkpoints are saved, re-running step 2 for an already trained model skips training and only runs testing.
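The reported metrics can be illustrated with a small re-implementation that follows common visual relationship detection conventions. It is not the repository's evaluation code, which may differ in details such as per-image averaging or how zero-shot triplets are selected. The sketch assumes objects are given and only predicates are scored. A ground-truth (subject, object, predicate) triplet counts as recalled if it appears among the top K predictions. "Constrained" keeps only the best predicate per subject-object pair, while "unconstrained" lets every predicate of every pair compete. Micro recall pools all ground-truth triplets, and macro recall averages the per-predicate recalls.

```python
from collections import defaultdict
from typing import Dict, List, Sequence, Set, Tuple

Pair = Tuple[int, int]          # (subject index, object index) within an image
Triplet = Tuple[int, int, int]  # (subject index, object index, predicate id)


def top_k_triplets(scores: Dict[Pair, Sequence[float]], k: int,
                   constrained: bool) -> Set[Triplet]:
    """Return the k highest-scoring (subject, object, predicate) triplets.

    constrained=True keeps only the best predicate for each pair;
    constrained=False lets every predicate of every pair compete.
    """
    candidates = []
    for (subj, obj), pred_scores in scores.items():
        if constrained:
            best = max(range(len(pred_scores)), key=lambda p: pred_scores[p])
            candidates.append((pred_scores[best], (subj, obj, best)))
        else:
            candidates.extend((s, (subj, obj, p)) for p, s in enumerate(pred_scores))
    candidates.sort(key=lambda c: c[0], reverse=True)
    return {triplet for _, triplet in candidates[:k]}


def recall_at_k(images: List[Tuple[Dict[Pair, Sequence[float]], List[Triplet]]],
                k: int, constrained: bool) -> Tuple[float, float]:
    """Return (micro, macro) Recall@k over a list of (scores, ground_truth) images.

    micro: recovered ground-truth triplets / all ground-truth triplets.
    macro: the same ratio computed per predicate, then averaged over predicates.
    """
    hits: Dict[int, int] = defaultdict(int)
    totals: Dict[int, int] = defaultdict(int)
    for scores, ground_truth in images:
        predicted = top_k_triplets(scores, k, constrained)
        for triplet in ground_truth:
            predicate = triplet[2]
            totals[predicate] += 1
            hits[predicate] += int(triplet in predicted)
    micro = sum(hits.values()) / sum(totals.values())
    macro = sum(hits[p] / totals[p] for p in totals) / len(totals)
    return micro, macro


if __name__ == "__main__":
    scores = {(0, 1): [0.1, 0.7, 0.2], (1, 2): [0.6, 0.3, 0.1]}
    ground_truth = [(0, 1, 1), (0, 1, 2), (1, 2, 0)]
    # Constrained: pair (0, 1) can only contribute predicate 1, so (0, 1, 2) is missed.
    print(recall_at_k([(scores, ground_truth)], k=4, constrained=True))
    # → (0.6666666666666666, 0.6666666666666666)
    print(recall_at_k([(scores, ground_truth)], k=4, constrained=False))
    # → (1.0, 1.0)
```

The toy example shows why unconstrained recall is never lower than constrained recall at the same K. A pair annotated with two predicates can only be fully recovered when more than one predicate per pair is allowed into the ranking.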

Citation

If you plan to use this code in your work or experiments, please use the following citation:

@INPROCEEDINGS{Anastasakis_WACV_2024,
   author={Anastasakis, Zacharias and Mallis, Dimitrios and Diomataris, Markos and Alexandridis, George and Kollias, Stefanos and Pitsikalis, Vassilis},
   booktitle={2024 IEEE/CVF Winter Conference on Applications of Computer Vision (WACV)},
   title={Self-Supervised Learning for Visual Relationship Detection through Masked Bounding Box Reconstruction},
   year={2024},
   pages={1195-1204},
   keywords={Representation learning;Visualization;Computer vision;Codes;Self-supervised learning;Predictive models;Task analysis;Algorithms;Image recognition and understanding;Algorithms;Machine learning architectures;formulations;and algorithms},
   doi={10.1109/WACV57701.2024.00124}}

Feel free to contact us about any issues!