Leveraging Text Localization for Scene Text Removal via Text-aware Masked Image Modeling

September 23, 2024

This is a PyTorch implementation of the paper "Leveraging Text Localization for Scene Text Removal via Text-aware Masked Image Modeling" (TMIM).

Installation

1. Requirements

  • Python==3.8.12
  • PyTorch==1.11.0
  • CUDA==11.3
conda create -n tmim python==3.8.12
conda activate tmim
pip install --upgrade pip
pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html 
pip install -r requirements.txt
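To confirm that the active environment matches the versions listed above, a small check can help. This is a minimal sketch, not part of the repository: `check_versions` and `main` are hypothetical names, and the comparison is a strict string match, so any other version is reported even if it might still work.

```python
import sys

# Versions listed in the Requirements section of this README.
REQUIRED = {"python": "3.8.12", "torch": "1.11.0", "cuda": "11.3"}


def check_versions(python_version, torch_version, cuda_version):
    """Return human-readable mismatches against REQUIRED (an empty list means OK)."""
    found = {
        "python": python_version,
        # Drop local build tags such as "+cu113" before comparing.
        "torch": torch_version.split("+")[0] if torch_version else None,
        "cuda": cuda_version,
    }
    problems = []
    for name, expected in REQUIRED.items():
        if found[name] != expected:
            problems.append(f"{name}: expected {expected}, found {found[name]}")
    return problems


def main():
    try:
        import torch  # imported lazily so check_versions works without PyTorch
    except ImportError:
        print("PyTorch is not installed; run the pip commands above first.")
        return
    py = ".".join(str(part) for part in sys.version_info[:3])
    # torch.version.cuda is None for CPU-only builds.
    issues = check_versions(py, torch.__version__, torch.version.cuda)
    print("\n".join(issues) if issues else "Environment matches the listed requirements.")
    print("CUDA available:", torch.cuda.is_available())


if __name__ == "__main__":
    main()
```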

2. Datasets

  • Create a "data" folder. Download the text removal dataset (SCUT-EnsText) and the text detection datasets (TextOCR, Total-Text, ICDAR2015, COCO-Text, MLT19, ArT, LSVT (fully annotated), ReCTS).

  • Create COCO-style annotations for the text detection datasets with the scripts in utils/prepare_dataset/, or download the prepared annotations (data.zip).

  • The structure of the data folder is shown below.

    data
    ├── text_det
    │   ├── art
    │   │   ├── train_images
    │   │   └── annotation.json
    │   ├── cocotext
    │   │   ├── train2014
    │   │   └── cocotext.v2.json
    │   ├── ic15
    │   │   ├── train_images
    │   │   └── annotation.json
    │   ├── lsvt
    │   │   ├── train_images
    │   │   └── annotation.json
    │   ├── mlt19
    │   │   ├── train_images
    │   │   └── annotation.json
    │   ├── rects
    │   │   ├── img
    │   │   └── annotation.json
    │   ├── textocr
    │   │   ├── train_images
    │   │   ├── TextOCR_0.1_train.json
    │   │   └── TextOCR_0.1_val.json
    │   └── totaltext
    │       ├── train_images
    │       └── annotation.json
    └── text_rmv
        └── SCUT-EnsText
            ├── train
            │   ├── all_images
            │   ├── all_labels
            │   └── mask
            └── test
                ├── all_images
                ├── all_labels
                └── mask
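
A quick way to catch path mistakes before training is to compare the tree above against what is on disk. This hypothetical helper, `find_missing`, lists the entries from the tree that do not exist under a given root. It only checks existence, not file contents, and it will also report any dataset you chose not to download.

```python
from pathlib import Path

# Relative paths copied from the data-folder tree in this README.
EXPECTED_PATHS = [
    "text_det/art/train_images", "text_det/art/annotation.json",
    "text_det/cocotext/train2014", "text_det/cocotext/cocotext.v2.json",
    "text_det/ic15/train_images", "text_det/ic15/annotation.json",
    "text_det/lsvt/train_images", "text_det/lsvt/annotation.json",
    "text_det/mlt19/train_images", "text_det/mlt19/annotation.json",
    "text_det/rects/img", "text_det/rects/annotation.json",
    "text_det/textocr/train_images",
    "text_det/textocr/TextOCR_0.1_train.json",
    "text_det/textocr/TextOCR_0.1_val.json",
    "text_det/totaltext/train_images", "text_det/totaltext/annotation.json",
] + [
    f"text_rmv/SCUT-EnsText/{split}/{sub}"
    for split in ("train", "test")
    for sub in ("all_images", "all_labels", "mask")
]


def find_missing(root):
    """Return the expected entries that do not exist under root."""
    root = Path(root)
    return [rel for rel in EXPECTED_PATHS if not (root / rel).exists()]


if __name__ == "__main__":
    missing = find_missing("data")
    for rel in missing:
        print("missing:", rel)
    print(f"{len(EXPECTED_PATHS) - len(missing)}/{len(EXPECTED_PATHS)} entries found")
```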
    
    
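For reference, a COCO-style annotation file generally contains `images`, `annotations` and `categories` lists, with each text instance stored as a flat polygon (`segmentation`), an `[x, y, width, height]` box (`bbox`) and an `area`. The sketch below builds such a structure from per-image polygons. The input format, the helper names and the single `text` category are assumptions; the scripts in utils/prepare_dataset/ may add fields (for example flags for illegible text) or use different ids, so use them or the downloaded data.zip for actual training.

```python
import json


def polygon_area(points):
    """Shoelace formula for a simple polygon given as [(x, y), ...]."""
    total = 0.0
    for (x1, y1), (x2, y2) in zip(points, points[1:] + points[:1]):
        total += x1 * y2 - x2 * y1
    return abs(total) / 2.0


def to_coco(samples):
    """samples: list of {"file_name", "width", "height", "polygons"} dicts (hypothetical input)."""
    coco = {"images": [], "annotations": [], "categories": [{"id": 1, "name": "text"}]}
    ann_id = 1
    for image_id, sample in enumerate(samples, start=1):
        coco["images"].append({
            "id": image_id,
            "file_name": sample["file_name"],
            "width": sample["width"],
            "height": sample["height"],
        })
        for poly in sample["polygons"]:
            poly = list(poly)
            xs = [x for x, _ in poly]
            ys = [y for _, y in poly]
            coco["annotations"].append({
                "id": ann_id,
                "image_id": image_id,
                "category_id": 1,
                # COCO stores polygons as flat [x1, y1, x2, y2, ...] lists.
                "segmentation": [[c for point in poly for c in point]],
                "bbox": [min(xs), min(ys), max(xs) - min(xs), max(ys) - min(ys)],
                "area": polygon_area(poly),
                "iscrowd": 0,
            })
            ann_id += 1
    return coco


if __name__ == "__main__":
    demo = [{"file_name": "img_1.jpg", "width": 640, "height": 480,
             "polygons": [[(10, 10), (110, 10), (110, 40), (10, 40)]]}]
    print(json.dumps(to_coco(demo), indent=2))
```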

Models

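| Model | Method | PSNR | MSSIM | MSE | AGE | Download |
| --- | --- | --- | --- | --- | --- | --- |
| Uformer-B | Pretrained | 36.66 | 97.66 | 0.0637 | 1.70 | uformer_b_tmim.pth |
| Uformer-B | Finetuned | 37.42 | 97.70 | 0.0459 | 1.52 | uformer_b_tmim_str.pth |
| PERT | Pretrained | 34.51 | 96.63 | 0.1231 | 2.11 | pert_tmim.pth |
| PERT | Finetuned | 35.66 | 97.18 | 0.0729 | 1.76 | pert_tmim_str.pth |
| EraseNet | Pretrained | 34.25 | 97.03 | 0.1141 | 2.23 | erasenet_tmim.pth |
| EraseNet | Finetuned | 35.47 | 97.30 | 0.0765 | 1.95 | erasenet_tmim_str.pth |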
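
As a rough reference for the PSNR, MSE and AGE columns, the sketch below uses definitions common in the scene text removal literature (MSSIM is omitted). Whether test.py uses exactly these definitions is an assumption: value scaling (for example whether MSE is computed on [0, 1] images or multiplied by a constant), color space and resizing may differ, so this sketch should not be expected to reproduce the table. The function names are hypothetical.

```python
import numpy as np


def mse(pred, target):
    """Mean squared error between two float images scaled to [0, 1]."""
    diff = pred.astype(np.float64) - target.astype(np.float64)
    return float(np.mean(diff ** 2))


def psnr(pred, target, max_val=1.0):
    """Peak signal-to-noise ratio in dB; infinite for identical images."""
    err = mse(pred, target)
    return float("inf") if err == 0 else float(10.0 * np.log10(max_val ** 2 / err))


def age(pred_rgb, target_rgb):
    """Average absolute grayscale error for RGB images (H, W, 3) in the 0-255 range."""
    weights = np.array([0.299, 0.587, 0.114])  # ITU-R BT.601 luma coefficients
    gray_pred = pred_rgb.astype(np.float64) @ weights
    gray_target = target_rgb.astype(np.float64) @ weights
    return float(np.mean(np.abs(gray_pred - gray_target)))
```

Per-image values would then be averaged over the test set.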

Inference

  • Download a checkpoint from the table above and run the following command for inference (the example uses the finetuned Uformer-B checkpoint).
python -m torch.distributed.launch --master_port 29501 --nproc_per_node=1 demo.py --cfg configs/uformer_b_str.py --resume path/to/uformer_b_tmim_str.pth --test-dir path/to/image/folder --visualize-dir path/to/result/folder
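
To script inference (for example over several image folders), the documented command can be assembled from Python. This is a sketch: `build_demo_command` and `run_demo` are hypothetical helpers that only reproduce the flags shown in the command above; they know nothing about any other demo.py options.

```python
import shlex
import subprocess


def build_demo_command(cfg, checkpoint, test_dir, visualize_dir, master_port=29501):
    """Assemble the single-GPU inference command documented in this README."""
    return [
        "python", "-m", "torch.distributed.launch",
        "--master_port", str(master_port),
        "--nproc_per_node=1",
        "demo.py",
        "--cfg", cfg,
        "--resume", checkpoint,
        "--test-dir", test_dir,
        "--visualize-dir", visualize_dir,
    ]


def run_demo(**kwargs):
    """Print and run the command; raises CalledProcessError if demo.py fails."""
    cmd = build_demo_command(**kwargs)
    print("Running:", shlex.join(cmd))
    subprocess.run(cmd, check=True)
```

Passing a distinct `master_port` per run avoids port clashes if several launches run at the same time.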

Training and Testing

  • Set "snapshot_dir" (where checkpoints are saved) and "dataroot" (where the datasets are stored) in configs/*.py.
  • EraseNet and PERT require 4 GTX 1080 Ti GPUs; Uformer requires 8 GTX 1080 Ti GPUs.

1. Pretraining

  • Run the following command to pretrain the model on text detection datasets.
python -m torch.distributed.launch --master_port 29501 --nproc_per_node=8 train.py --cfg configs/uformer_b_tmim.py --ckpt-name uformer_b_tmim --save-log 
  • Run the following command to test the performance of the pretrained model.
python test.py --cfg configs/uformer_b_tmim.py --ckpt-name uformer_b_tmim/latest.pth --save-log --visualize

2. Finetuning

  • Run the following command to finetune the model on text removal datasets.
python -m torch.distributed.launch --master_port 29501 --nproc_per_node=8 train.py --cfg configs/uformer_b_str.py --ckpt-name uformer_b_tmim_str --save-log --resume 'ckpt/uformer_b_tmim/latest.pth'
  • Run the following command to test the performance of the finetuned model.
python test.py --cfg configs/uformer_b_str.py --ckpt-name uformer_b_tmim_str/latest.pth --save-log --visualize
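
The pretraining and finetuning commands appear to chain through the checkpoint directory: pretraining with `--ckpt-name uformer_b_tmim` is later resumed from `ckpt/uformer_b_tmim/latest.pth`, and test.py is given `uformer_b_tmim/latest.pth`. This suggests checkpoints are written to `<snapshot_dir>/<ckpt-name>/latest.pth` with "snapshot_dir" set to `ckpt`; that layout is an inference from the commands, not something the repository states. The hypothetical helper below derives both arguments under that assumption and warns if the pretrained checkpoint is missing before finetuning.

```python
from pathlib import Path


def pretrain_outputs(snapshot_dir, ckpt_name):
    """Derive the finetuning --resume path and the test.py --ckpt-name value.

    Assumes checkpoints are saved to <snapshot_dir>/<ckpt_name>/latest.pth,
    as the commands in this README imply when snapshot_dir is "ckpt".
    """
    resume_path = Path(snapshot_dir) / ckpt_name / "latest.pth"
    # Assumed: test.py resolves --ckpt-name relative to snapshot_dir.
    test_ckpt_name = f"{ckpt_name}/latest.pth"
    return resume_path, test_ckpt_name


if __name__ == "__main__":
    resume, test_name = pretrain_outputs("ckpt", "uformer_b_tmim")
    if not resume.exists():
        print(f"{resume} not found; run pretraining first or check snapshot_dir.")
    print("--resume", resume.as_posix())
    print("--ckpt-name", test_name)
```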