EMC-Click: Efficient Mask Correction for Click-Based Interactive Image Segmentation (CVPR2023)
August 8, 2023
The PyTorch code for "Efficient Mask Correction for Click-Based Interactive Image Segmentation" in CVPR 2023.
Abstract
The goal of click-based interactive image segmentation is to extract target masks from user-provided positive/negative clicks. Every time a new click is placed, existing methods run the whole segmentation network to obtain a corrected mask. This is inefficient, because several clicks may be needed to reach satisfactory accuracy. To this end, we propose an efficient method that corrects the mask with a lightweight mask correction network. From the second click onward, the whole network keeps a low computational cost, even with a large backbone. However, a simple correction network with limited capacity is unlikely to match the performance of a classic segmentation network. We therefore propose a click-guided self-attention module and a click-guided correlation module that effectively exploit the click information to boost performance. First, several templates are selected based on their semantic similarity to the click features. Then the self-attention module propagates the template information to other pixels, while the correlation module uses the templates directly to obtain target outlines. With the efficient architecture and the two click-guided modules, our method achieves favorable performance and efficiency compared to existing methods.
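The data flow of the two click-guided modules can be illustrated with a toy sketch. This is not the authors' implementation. It treats pixel and click features as plain Python lists and selects templates as the k pixels with the highest cosine similarity to any click feature. In place of the learned self-attention module, it uses a parameter-free attention from every pixel (query) to the templates (keys/values), added residually. The correlation map is computed as each pixel's best cosine similarity to a template. All function names (`select_templates`, `propagate`, `correlation_map`) are hypothetical.

```python
import math


def cosine(a, b):
    """Cosine similarity between two equal-length feature vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    return dot / (na * nb + 1e-8)  # epsilon guards against zero vectors


def select_templates(pixel_feats, click_feats, k):
    """Indices of the k pixels most similar to any click feature (hypothetical rule)."""
    scores = [max(cosine(f, c) for c in click_feats) for f in pixel_feats]
    order = sorted(range(len(pixel_feats)), key=lambda i: scores[i], reverse=True)
    return order[:k]


def softmax(xs):
    """Numerically stable softmax over a list of scores."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def propagate(pixel_feats, templates):
    """Parameter-free attention: each pixel attends to the templates, result added residually."""
    d = len(pixel_feats[0])
    out = []
    for f in pixel_feats:
        # Scaled dot-product scores between this pixel and every template.
        weights = softmax([sum(x * y for x, y in zip(f, t)) / math.sqrt(d) for t in templates])
        # Weighted sum of template features, one output channel at a time.
        agg = [sum(w * t[j] for w, t in zip(weights, templates)) for j in range(d)]
        out.append([x + a for x, a in zip(f, agg)])
    return out


def correlation_map(pixel_feats, templates):
    """Per-pixel max cosine similarity to the templates: high values trace the target."""
    return [max(cosine(f, t) for t in templates) for f in pixel_feats]


if __name__ == "__main__":
    # Four toy "pixels" with 2-D features; the first two belong to the target.
    pixels = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]]
    clicks = [[1.0, 0.0]]  # feature sampled at a single positive click
    idx = select_templates(pixels, clicks, k=2)
    templates = [pixels[i] for i in idx]
    print("template indices:", idx)
    print("correlation:", [round(c, 3) for c in correlation_map(pixels, templates)])
    print("propagated:", propagate(pixels, templates))
```

In the real network these steps operate on feature maps produced by the backbone, and any learned projections would replace the identity mappings assumed here. The sketch only shows how clicks select templates and how templates then inform every pixel.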
Environment setup
- Install the requirements by executing
pip install -r requirements.txt
- Prepare the datasets and pretrained backbone weights by following Data_Weight_Preparation.md
Evaluation
Download the pretrained checkpoints from Releases and put them into the ./weights/ directory.
Run
python -m torch.distributed.launch --master_port=4321 --nproc_per_node=8 scripts/evaluate_model.py EMC-Click \
--model_dir='./weights/' \
--checkpoint=hr18s.pth,hr18.pth,hr32.pth,segb0.pth,segb3.pth \
--n-clicks=20 \
--gpus=0,1,2,3,4,5,6,7 \
--target-iou=0.9 \
--thresh=0.5 \
--eval-mode='emc-click' \
--datasets=GrabCut,Berkeley,SBD,DAVIS,PascalVOC
to evaluate all five checkpoints on the GrabCut, Berkeley, SBD, DAVIS, and PascalVOC datasets.
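The `--n-clicks`, `--target-iou` and `--thresh` flags suggest the NoC (number of clicks) protocol used in RITM-style evaluation. Under that protocol, each predicted probability map is binarized at `--thresh`. NoC is the first click at which the IoU with the ground truth reaches `--target-iou`, capped at `--n-clicks` if the target is never reached. The sketch below is an assumption about that protocol, not a copy of `scripts/evaluate_model.py`. The helper names are hypothetical, and masks are flattened 0/1 lists.

```python
def binarize(probs, thresh=0.5):
    """Turn a flattened probability map into a 0/1 mask (strictly above thresh)."""
    return [1 if p > thresh else 0 for p in probs]


def iou(pred, gt):
    """Intersection over union of two binary masks of equal length."""
    inter = sum(1 for p, g in zip(pred, gt) if p and g)
    union = sum(1 for p, g in zip(pred, gt) if p or g)
    return inter / union if union else 1.0


def number_of_clicks(ious, target_iou=0.9, max_clicks=20):
    """1-based index of the first click whose IoU reaches target_iou, else max_clicks."""
    for n, value in enumerate(ious[:max_clicks], start=1):
        if value >= target_iou:
            return n
    return max_clicks


def mean_noc(per_image_ious, target_iou=0.9, max_clicks=20):
    """Average NoC over a dataset, given one per-click IoU curve per image."""
    nocs = [number_of_clicks(ious, target_iou, max_clicks) for ious in per_image_ious]
    return sum(nocs) / len(nocs)


if __name__ == "__main__":
    gt = [1, 1, 1, 0, 0, 0]
    pred = binarize([0.9, 0.7, 0.4, 0.6, 0.1, 0.0])  # -> [1, 1, 0, 1, 0, 0]
    print("IoU:", iou(pred, gt))  # 2 / 4 = 0.5
    curves = [[0.62, 0.81, 0.88, 0.91, 0.95], [0.93, 0.96]]
    print("mean NoC@90:", mean_noc(curves))  # (4 + 1) / 2 = 2.5
```

A lower mean NoC is better. Images that never reach the target within the click budget each contribute the full `--n-clicks` value.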
Train
Run
CONFIG=models/emcclick/hrnet18s_att_cclvis.py
EXP_NAME=hrnet18s_att_cclvis
nGPUS=4
nBS=64
nWORKERS=4
PORT=`expr $RANDOM + 5000`
python -m torch.distributed.launch --nproc_per_node=$nGPUS --master_port=$PORT \
train.py $CONFIG \
--ngpus=$nGPUS \
--workers=$nWORKERS \
--batch-size=$nBS \
--exp-name=$EXP_NAME
to train with the hrnet18s backbone.
A template script can be found in ./trainval_scripts/train_xxx.sh.
Acknowledgement
The code is implemented based on RITM and ClickSEG. We would like to express our sincere thanks to the contributors.
License
The code is released under the MIT License. It is a short, permissive software license. Basically, you can do whatever you want as long as you include the original copyright and license notice in any copy of the software/source.
Citation
If you find this work useful for your research, please cite our paper:
@inproceedings{emcclick,
title={Efficient Mask Correction for Click-Based Interactive Image Segmentation},
author={Du, Fei and Yuan, Jianlong and Wang, Zhibin and Wang, Fan},
booktitle={CVPR},
year={2023}
}