Marten: Visual Question Answering with Mask Generation for Multi-modal Document Understanding (CVPR 2025)
July 27, 2025
Introduction
Paper: (Accepted by CVPR 2025)
MTMask6M

- MTMask6M: Coming soon
- Original Document Data Sources:
Usage
Installation
Ensure you have Python 3.8 or higher installed in your environment.
git clone https://github.com/Token-family/Marten.git
cd Marten
pip install -r requirements.txt
Creating Your Own Dataset
Step 1: Obtain Word-level Bounding Boxes
Use an OCR engine (e.g., PaddleOCR or CRAFT) to generate word-level bounding boxes. Each line holds an image path, a tab, and a list of boxes given as four corner points:
/path/to/image\t[[x1_1,y1_1,x2_1,y2_1,x3_1,y3_1,x4_1,y4_1], ... ,[x1_n,y1_n,x2_n,y2_n,x3_n,y3_n,x4_n,y4_n]]
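As a minimal sketch of how a line in this format could be read back, the snippet below splits the path from the box list and groups each box into four (x, y) corners. It assumes the box list is written as JSON-compatible nested brackets; `parse_ocr_line` is a hypothetical helper, not part of the repository.

```python
import json
from typing import List, Tuple

Quad = List[Tuple[float, float]]  # four (x, y) corner points of one word box


def parse_ocr_line(line: str) -> Tuple[str, List[Quad]]:
    """Split one annotation line into the image path and its word quadrilaterals."""
    path, boxes_text = line.rstrip("\n").split("\t", 1)
    quads: List[Quad] = []
    for flat in json.loads(boxes_text):
        if len(flat) != 8:
            raise ValueError(f"expected 8 coordinates per box, got {len(flat)}")
        # Pair up (x1, y1), (x2, y2), (x3, y3), (x4, y4).
        quads.append([(flat[i], flat[i + 1]) for i in range(0, 8, 2)])
    return path, quads


# Illustrative values only; real lines come from your OCR engine.
example = "/path/to/image\t[[10,20,110,20,110,50,10,50],[120,20,200,22,199,52,119,50]]"
image_path, word_quads = parse_ocr_line(example)
print(image_path, len(word_quads), word_quads[0])
```

Validating the coordinate count at parse time surfaces malformed OCR output before it reaches mask generation.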
Step 2: Generate Masks
python mask_utils/mask_generation.py
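The script above is the authoritative implementation, and its options are not described here. Purely as an illustration of the idea, the sketch below assumes the mask is a binary image in which every pixel whose center falls inside a word quadrilateral is set to 1. The helper names and the pure-Python rasterization are assumptions made for this example, not the repository's method.

```python
from typing import List, Sequence, Tuple

Point = Tuple[float, float]


def point_in_polygon(x: float, y: float, polygon: Sequence[Point]) -> bool:
    """Ray casting: toggle 'inside' for each edge crossed by a ray going right."""
    inside = False
    n = len(polygon)
    for i in range(n):
        (x1, y1), (x2, y2) = polygon[i], polygon[(i + 1) % n]
        if (y1 > y) != (y2 > y):  # edge spans the ray's height
            x_cross = x1 + (y - y1) * (x2 - x1) / (y2 - y1)
            if x < x_cross:
                inside = not inside
    return inside


def rasterize_mask(width: int, height: int,
                   quads: Sequence[Sequence[Point]]) -> List[List[int]]:
    """Return a height x width grid with 1 where a pixel center lies in any quad."""
    mask = [[0] * width for _ in range(height)]
    for quad in quads:
        xs = [p[0] for p in quad]
        ys = [p[1] for p in quad]
        # Only visit pixels inside the quad's bounding box, clipped to the image.
        x_lo, x_hi = max(0, int(min(xs))), min(width - 1, int(max(xs)))
        y_lo, y_hi = max(0, int(min(ys))), min(height - 1, int(max(ys)))
        for row in range(y_lo, y_hi + 1):
            for col in range(x_lo, x_hi + 1):
                if point_in_polygon(col + 0.5, row + 0.5, quad):
                    mask[row][col] = 1
    return mask


mask = rasterize_mask(8, 4, [[(1, 1), (5, 1), (5, 3), (1, 3)]])
for row in mask:
    print("".join(str(v) for v in row))
```

For real images, a library rasterizer would be much faster; the loop here only makes the geometry explicit.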
Step 3: Data Format
Refer to InternVL2 for the complete format specification. Each sample adds a mask_path field next to the image path:
{
  "id": 1,
  "image": "/path/to/image",
  "mask_path": "/path/to/mask",
  "conversations": [
    {
      "from": "human",
      "value": "<image>\nRecognize all text:"
    },
    {
      "from": "gpt",
      "value": "Fill in the visual text content here"
    }
  ]
}
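A record like the one above could be assembled and serialized as follows. This is a sketch under assumptions: `make_record` is a hypothetical helper, and writing one JSON object per line (JSONL) is an assumption about the annotation file layout that should be checked against the InternVL2 documentation.

```python
import json


def make_record(sample_id, image_path, mask_path, text,
                prompt="<image>\nRecognize all text:"):
    """Build one training sample with the fields shown in the format above."""
    return {
        "id": sample_id,
        "image": image_path,
        "mask_path": mask_path,
        "conversations": [
            {"from": "human", "value": prompt},
            {"from": "gpt", "value": text},
        ],
    }


record = make_record(1, "/path/to/image", "/path/to/mask",
                     "Fill in the visual text content here")
# json.dumps escapes the newline in the prompt, so the record stays on one line.
line = json.dumps(record, ensure_ascii=False)
print(line)
```

Serializing through `json.dumps` also avoids hand-written mistakes such as trailing commas, which strict JSON parsers reject.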
Training
Training follows the InternVL2 methodology:
Pre-training
bash ./shell/marten_internlm2_intervit_pretrain.sh
Fine-tuning
bash ./shell/marten_internlm2_intervit_finetune.sh
Training with MGM
To integrate the MGM module into your own model, refer to the following code. Elided parts of the model are marked with "...".
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
...
from transformers.modeling_utils import PreTrainedModel
from ..marten_module.MGM import MGM


def dice_loss(pred, target, smooth=1e-6):
    """
    Compute the Dice loss for binary classification.
    :param pred: predicted logits, shape [N, 1, H, W]
    :param target: ground-truth labels, shape [N, 1, H, W]
    :param smooth: smoothing term that prevents division by zero
    :return: the Dice loss value
    """
    pred = torch.sigmoid(pred)  # logits -> probabilities
    pred_flat = pred.reshape(-1)
    target_flat = target.reshape(-1)
    intersection = (pred_flat * target_flat).sum()
    union = pred_flat.sum() + target_flat.sum()
    dice = (2. * intersection + smooth) / (union + smooth)
    return 1 - dice


class CustomModel(PreTrainedModel):
    def __init__(self, config, ..., use_mgm=False):
        ...
        llm_hidden_size = config.llm_config.hidden_size
        self.use_mgm = use_mgm
        if self.use_mgm:
            # Mask decoder that consumes LLM hidden features.
            self.MGM_Decoder = MGM(llm_hidden_size, hidden_size=512, dev_convs_nums=4, out_channels=1, layer_num=4)
            self.MGM_Decoder._initialize_weights()
            # BCE on mask logits; positive pixels are weighted 1.5x.
            self.MGM_loss = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([1.5]))
            self.MGM_aug_loss = dice_loss
            self.select_llm_layer_idx = -4

    def forward(
        self,
        ...
        pixel_values_mask: Optional[torch.Tensor] = None,
        ...
    ):
        ...
        """
        image_llm_hidden_features: image-related features from an LLM hidden layer
        text_llm_hidden_features: text-related features from an LLM hidden layer
        output_size: equal to the image size
        image_patch_size: the patch size of the image
        image_token_num: the number of image tokens
        loss: the original LLM loss
        """
        if self.use_mgm and pixel_values_mask is not None:
            # The other parameters are customized for InternVL dynamic slicing. If you use other VFMs, you can delete them.
            MGM_output = self.MGM_Decoder(image_llm_hidden_features, text_llm_hidden_features, output_size, image_patch_size, image_token_num)
            # Add both mask losses (BCE and Dice) to the LLM loss.
            loss += self.MGM_loss(MGM_output, pixel_values_mask)
            loss += self.MGM_aug_loss(MGM_output, pixel_values_mask.float().long())
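To make the arithmetic of the two mask losses concrete, here is a dependency-free sketch on a tiny flattened example. It mirrors the Dice formula from the code above and the weighted binary cross-entropy that `nn.BCEWithLogitsLoss` documents for `pos_weight` with mean reduction. The pure-Python functions are stand-ins written for illustration and are not part of the repository.

```python
import math


def sigmoid(z: float) -> float:
    # Numerically stable logistic function.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def log_sigmoid(z: float) -> float:
    # Stable log(sigmoid(z)); note that log(1 - sigmoid(z)) == log_sigmoid(-z).
    return -math.log1p(math.exp(-z)) if z >= 0 else z - math.log1p(math.exp(z))


def dice_loss_py(logits, targets, smooth=1e-6):
    probs = [sigmoid(z) for z in logits]
    intersection = sum(p * t for p, t in zip(probs, targets))
    total = sum(probs) + sum(targets)
    return 1.0 - (2.0 * intersection + smooth) / (total + smooth)


def weighted_bce_py(logits, targets, pos_weight=1.5):
    # Mean over pixels of -[w * t * log(p) + (1 - t) * log(1 - p)].
    per_pixel = [
        -(pos_weight * t * log_sigmoid(z) + (1 - t) * log_sigmoid(-z))
        for z, t in zip(logits, targets)
    ]
    return sum(per_pixel) / len(per_pixel)


targets = [1, 1, 0, 0]
good_logits = [4.0, 3.0, -3.0, -4.0]   # confident and correct
bad_logits = [-4.0, -3.0, 3.0, 4.0]    # confident and wrong

good_total = weighted_bce_py(good_logits, targets) + dice_loss_py(good_logits, targets)
bad_total = weighted_bce_py(bad_logits, targets) + dice_loss_py(bad_logits, targets)
print(f"good prediction: {good_total:.4f}, bad prediction: {bad_total:.4f}")
```

The BCE term penalizes each pixel independently, while the Dice term measures overlap over the whole mask, which tends to help when text pixels are a small fraction of the image.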
Evaluate
bash ./shell/eval.sh
TODO List
- Release training / evaluation code for Marten series
- Release code for mask generation
- Release dataset of MTMask6M
Acknowledgement
Marten is built with reference to the code of the following project: InternVL2.
Citation