Marten: Visual Question Answering with Mask Generation for Multi-modal Document Understanding (CVPR 2025)

July 27, 2025

📖 Introduction

Paper (🚀🚀🚀 Accepted by CVPR 2025 🚀🚀🚀):

📄 MTMask6M

Datasets

📚 Usage

📦 Installation

Ensure you have Python 3.8 or higher installed in your environment.

git clone https://github.com/Token-family/Marten.git
cd Marten
pip install -r requirements.txt

🛠️ Creating Your Own Dataset

Step 1: Obtain Word-level Bounding Boxes

Use an OCR engine (e.g., PaddleOCR or CRAFT) to generate word-level bounding boxes. Each entry pairs an image path with a list of four-corner boxes, in the following format:

/path/to/image\t[[x1_1,y1_1,x2_1,y2_1,x3_1,y3_1,x4_1,y4_1], ... ,[x1_n,y1_n,x2_n,y2_n,x3_n,y3_n,x4_n,y4_n]]
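As a minimal sketch of how such a line could be read and written, the snippet below assumes that `\t` denotes a literal tab character and that the bracketed box list is valid JSON (both are assumptions, not confirmed by the repository). The helper names `parse_box_line` and `format_box_line` are hypothetical and are not part of Marten.

```python
import json
from typing import List, Tuple


def parse_box_line(line: str) -> Tuple[str, List[List[float]]]:
    """Split one annotation line into (image_path, list of 8-value boxes)."""
    # Assumption: the path and the box list are separated by a single tab.
    path, boxes_text = line.rstrip("\n").split("\t", 1)
    boxes = json.loads(boxes_text)
    for box in boxes:
        # Each box holds four (x, y) corners, i.e. eight numbers.
        if len(box) != 8:
            raise ValueError(f"expected 8 coordinates per box, got {len(box)}")
    return path, boxes


def format_box_line(path: str, boxes: List[List[float]]) -> str:
    """Serialize an image path and its boxes back into one line."""
    return f"{path}\t{json.dumps(boxes, separators=(',', ':'))}"


line = "/path/to/image\t[[10,20,110,20,110,50,10,50],[120,20,200,20,200,50,120,50]]"
path, boxes = parse_box_line(line)
```

Validating the box length at parse time surfaces malformed OCR output before it reaches mask generation.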

Step 2: Generate Masks

python mask_utils/mask_generation.py
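The behavior of `mask_utils/mask_generation.py` is not described here, so the following is only an illustrative sketch of one way a binary mask could be derived from four-corner boxes. It assumes text pixels are set to 1 and background to 0, and it uses a plain even-odd point-in-polygon test on pixel centers. The functions `point_in_quad` and `rasterize_boxes` are hypothetical and may differ from what the repository script does.

```python
from typing import List, Sequence


def point_in_quad(px: float, py: float, quad: Sequence[float]) -> bool:
    """Even-odd ray-casting test of a point against a four-corner polygon."""
    xs, ys = quad[0::2], quad[1::2]
    inside = False
    j = 3
    for i in range(4):
        # Only edges that straddle the horizontal line y = py can be crossed.
        if (ys[i] > py) != (ys[j] > py):
            x_cross = xs[i] + (py - ys[i]) * (xs[j] - xs[i]) / (ys[j] - ys[i])
            if px < x_cross:
                inside = not inside
        j = i
    return inside


def rasterize_boxes(width: int, height: int,
                    boxes: List[Sequence[float]]) -> List[List[int]]:
    """Return a height x width mask with 1 inside any box and 0 elsewhere."""
    mask = [[0] * width for _ in range(height)]
    for quad in boxes:
        xs, ys = quad[0::2], quad[1::2]
        # Restrict the scan to the box's bounding rectangle, clipped to the image.
        x0, x1 = max(0, int(min(xs))), min(width, int(max(xs)) + 1)
        y0, y1 = max(0, int(min(ys))), min(height, int(max(ys)) + 1)
        for y in range(y0, y1):
            for x in range(x0, x1):
                if point_in_quad(x + 0.5, y + 0.5, quad):
                    mask[y][x] = 1
    return mask


mask = rasterize_boxes(8, 4, [[1, 1, 5, 1, 5, 3, 1, 3]])
```

For real images, a library rasterizer would be far faster than this pure-Python loop; the sketch only illustrates the mapping from boxes to mask pixels.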

Step 3: Data Format

Refer to InternVL2 for the complete format specification. Each sample adds a mask_path field alongside the image path:

{
    "id": 1,
    "image": "/path/to/image",
    "mask_path": "/path/to/mask",
    "conversations": [
        {
            "from": "human",
            "value": "<image>\nRecognize all text:"
        },
        {
            "from": "gpt",
            "value": "Fill in the visual text content here"
        }
    ]
}
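A record of this shape can be assembled programmatically. The sketch below uses only the field names shown above; the helper `make_record` is hypothetical, and whether records are stored one per line or as a single array is left to the InternVL2 specification rather than assumed here.

```python
import json
from typing import Any, Dict


def make_record(sample_id: int, image_path: str, mask_path: str,
                prompt: str, answer: str) -> Dict[str, Any]:
    """Build one training sample with an image, its mask, and one Q/A turn."""
    return {
        "id": sample_id,
        "image": image_path,
        "mask_path": mask_path,
        "conversations": [
            # The human turn starts with the <image> placeholder on its own line.
            {"from": "human", "value": f"<image>\n{prompt}"},
            {"from": "gpt", "value": answer},
        ],
    }


record = make_record(
    1,
    "/path/to/image",
    "/path/to/mask",
    "Recognize all text:",
    "Fill in the visual text content here",
)
# json.dumps never emits trailing commas, so the output is always valid JSON.
serialized = json.dumps(record, ensure_ascii=False)
```

Generating records through `json.dumps` avoids hand-written syntax errors such as trailing commas, which strict JSON parsers reject.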

🚀 Training

Training follows the InternVL2 methodology:

Pre-training

bash ./shell/marten_internlm2_intervit_pretrain.sh

Fine-tuning

bash ./shell/marten_internlm2_intervit_finetune.sh

Training with MGM

To integrate the MGM module into your own model structure, refer to the following code:

from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
...
from transformers.modeling_utils import PreTrainedModel
from ..marten_module.MGM import MGM

def dice_loss(pred, target, smooth=1e-6):
    """
    Compute the Dice loss for a binary segmentation problem.
    :param pred: predictions (logits; sigmoid is applied inside), shape [N, 1, H, W]
    :param target: ground-truth labels, shape [N, 1, H, W]
    :param smooth: smoothing term that prevents division by zero
    :return: the Dice loss value
    """
    # Convert logits to probabilities
    pred = torch.sigmoid(pred)

    # Flatten so the overlap is computed over the whole batch at once
    pred_flat = pred.reshape(-1)
    target_flat = target.reshape(-1)

    intersection = (pred_flat * target_flat).sum()
    # Sum of both areas, which is the Dice denominator
    union = pred_flat.sum() + target_flat.sum()

    dice = (2. * intersection + smooth) / (union + smooth)

    return 1 - dice


class CustomModel(PreTrainedModel):
    
    def __init__(self, config, ..., use_mgm=False):
        
        ...
        
        llm_hidden_size = config.llm_config.hidden_size
        self.use_mgm = use_mgm

        if self.use_mgm:
            self.MGM_Decoder = MGM(llm_hidden_size, hidden_size=512, dev_convs_nums=4, out_channels=1, layer_num=4)
            self.MGM_Decoder._initialize_weights()
            self.MGM_loss = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([1.5]))    
            self.MGM_aug_loss = dice_loss 
            self.select_llm_layer_idx = -4

    def forward(
        self,
        ...
        pixel_values_mask: Optional[torch.Tensor] = None,
        ...
    ):

        ...

        """
        image_llm_hidden_features: Image-related features in the hidden layer of LLM
        text_llm_hidden_features: Text-related features in the hidden layer of LLM
        output_size: equal to image size
        image_patch_size: The size of patch of image
        image_token_num: The number of image tokens
        loss: LLM original loss

        """

        if self.use_mgm and pixel_values_mask is not None:

            # The other parameters are customized for InternVL dynamic slicing. If you use another VFM, you can remove them.
            MGM_output = self.MGM_Decoder(image_llm_hidden_features, text_llm_hidden_features, output_size, image_patch_size, image_token_num)
            # Weighted BCE term on the predicted mask logits
            loss += self.MGM_loss(MGM_output, pixel_values_mask)
            # Auxiliary Dice term, registered as self.MGM_aug_loss in __init__
            loss += self.MGM_aug_loss(MGM_output, pixel_values_mask.float().long())
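The Dice term in the code above can be checked by hand on a tiny input. The following sketch re-implements the same formula in plain Python, without PyTorch, so the arithmetic is easy to follow. The function `dice_loss_plain` is a hypothetical illustration, not part of the repository.

```python
import math
from typing import Sequence


def dice_loss_plain(logits: Sequence[float], targets: Sequence[int],
                    smooth: float = 1e-6) -> float:
    """Dice loss over flat lists: 1 - (2*|P.T| + s) / (|P| + |T| + s)."""
    # Sigmoid turns each logit into a probability in (0, 1).
    probs = [1.0 / (1.0 + math.exp(-z)) for z in logits]
    intersection = sum(p * t for p, t in zip(probs, targets))
    total = sum(probs) + sum(targets)
    return 1.0 - (2.0 * intersection + smooth) / (total + smooth)


# All-zero logits give probability 0.5 everywhere:
# intersection = 1.0, total = 2.0 + 2.0, so the loss is close to 0.5.
uninformative = dice_loss_plain([0.0, 0.0, 0.0, 0.0], [1, 0, 1, 0])

# Confident, correct logits drive the loss toward 0.
confident = dice_loss_plain([20.0, -20.0, 20.0, -20.0], [1, 0, 1, 0])
```

Because Dice measures overlap instead of per-pixel error, it complements the weighted BCE term when text pixels are a small fraction of the image.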

🔍 Evaluate

bash ./shell/eval.sh

📌 TODO List

  • Release training / evaluation code for Marten series
  • Release code for mask generation
  • Release dataset of MTMask6M

🙏 Acknowledgement

Marten is built with reference to the code of the following project: InternVL2.

📜 Citation