discrete-continuous action space policy gradient-based attention for image-text matching
February 1, 2022 · View on GitHub
PyTorch code for VSRN described in the paper "discrete-continuous action space policy gradient-based attention for image-text matching". The paper will appear in CVPR 2021. It is built on top of the VSRN.
discrete-continuous action space policy gradient-based attention for image-text matching https://arxiv.org/abs/2104.10406
Introduction
Image-text matching is an important multi-modal task with massive applications. It tries to match the image and the text with similar semantic information. Existing approaches do not explicitly transform the different modalities into a common space. Meanwhile, the attention mechanism which is widely used in image-text matching models does not have supervision. We propose a novel attention scheme which projects the image and text embedding into a common space and optimises the attention weights directly towards the evaluation metrics. The proposed attention scheme can be considered as a kind of supervised attention and requiring no additional annotations. It is trained via a novel Discrete-continuous action space policy gradient algorithm, which is more effective in modelling complex action space than previous continuous action space policy gradient. We evaluate the proposed methods on two widely-used benchmark datasets: Flickr30k and MS-COCO, outperforming the previous approaches by a large margin.
Requirements
We recommended the following dependencies.
-
Python 3.8
-
Punkt Sentence Tokenizer:
import nltk
nltk.download()
> d punkt
Download data
Download the dataset files and pre-trained models. We use splits produced by Andrej Karpathy.
We follow bottom-up attention model and SCAN to obtain image features for fair comparison. More details about data pre-processing (optional) can be found here. Data needed for reproducing the experiments in the paper, including image features and vocabularies, can be downloaded from SCAN by using:
wget https://scanproject.blob.core.windows.net/scan-data/data.zip
For Glove word embedding data, please check: 链接(Link):https://pan.baidu.com/s/1x87flEG_hq0FM3Z9CJCqeQ 提取码 (Extraction Code):icrx
Training new models
Run train.py:
For MSCOCO:
python train.py --data_path $DATA_PATH --data_name coco_precomp --logger_name runs/coco_VSRN --max_violation
For Flickr30K:
python train.py