Rex-Omni Finetuning Code

January 10, 2026 ยท View on GitHub

We provide code for both SFT and GRPO finetuning of the Rex-Omni model.

Table of Contents


Installation

# Install Rex-Omni, skip if you have already installed it
conda create -n rexomni -m python=3.10
pip install torch==2.6.0 torchvision==0.21.0 --index-url https://download.pytorch.org/whl/cu124
git clone https://github.com/IDEA-Research/Rex-Omni.git
cd Rex-Omni
pip install -v -e .

# Install required dependencies
cd finetuning
pip install -v -e .


Stage 1: SFT Finetuning

Stage 1 SFT finetuning uses supervised learning to finetune the model with annotated data.

1.1. Data Format

This project uses TSV (Tab-Separated Values) format to store datasets. Each dataset requires three files:

  1. Image TSV file (*.images.tsv): Stores base64-encoded images
  2. Annotation TSV file (*.annotations.tsv): Stores annotation data
  3. Line index file (*.annotations.tsv.lineidx): Stores byte offsets for each annotation line in the file

Here is an example code to parse the dataset:

import json
import os
from base64 import b64decode
from io import BytesIO

import numpy as np
from torch.utils.data import Dataset

class TSVBase(Dataset):
    """Base class for TSV dataset. This class is used to load image and annotations from TSV file.

    Args:
        img_tsv_file (str): The path to the image TSV file.
        ann_tsv_file (str): The path to the annotation TSV file.
        ann_lineidx_file (str): The path to the annotation lineidx file.
        num_workers (int): The number of workers.
        data_ratio (float, optional): The ratio of data to use. Defaults to 1.0.
        filter_empty (bool): If filter the samples without annotations. When training, set it to True.
        dataset_type (str): The data source.
    """

    def __init__(
        self,
        img_tsv_file: str,
        ann_tsv_file: str,
        ann_lineidx_file: str,
    ):

        self.data = []
        f = open(ann_lineidx_file)
        for line in tqdm(f):
            self.data.append(int(line.strip()))

        self.img_handle = None
        self.ann_handle = None
        self.img_tsv_file = img_tsv_file
        self.ann_tsv_file = ann_tsv_file

        self.preparer = None
        self.captionbuilder = None
        self._transforms = None

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):

        ann_line_idx = self.data[idx]
        if self.ann_handle is None:
            self.ann_handle = open(self.ann_tsv_file)
        self.ann_handle.seek(ann_line_idx)
        img_line_idx, ann = self.ann_handle.readline().strip().split("\t")

        img_line_idx = int(img_line_idx)
        if self.img_handle is None:
            self.img_handle = open(self.img_tsv_file)
        self.img_handle.seek(img_line_idx)
        img = self.img_handle.readline().strip().split("\t")[1]
        if img.startswith("b'"):
            img = img[1:-1]
        img = BytesIO(b64decode(img))
        img = Image.open(img).convert("RGB")
        target = json.loads(ann)
        return img, target

1.2. Download Toy Data

We have released toy datasets for both grounding and pointing tasks with 1000 samples for finetuning examples. You can download them at Huggingface

1.2.1 Grounding Data Format Example

The Grounding task is used to train the model for region localization (phrase grounding). Below is the format specification for annotation data:

Annotation TSV file format (*.annotations.tsv):

  • Line format: {image_line_idx}\t{annotation_json}
  • image_line_idx: Line index of the corresponding image in the images.tsv file
  • annotation_json: JSON-formatted annotation data

Annotation JSON data structure for Grounding:

{
  "boxes": [
    {
      "bbox": [x0, y0, x1, y1],  // Bounding box coordinates in xyxy format
      "phrase": "category name",  // the category or description for this box
    },
    ...
  ]
}

Annotation JSON data structure for Pointing:

{
  "points": [
    {
      "point": [x, y],  // Point coordinates in absolute pixel values
      "phrase": "category name",  // the category or description for this point
    },
    ...
  ]
}

Key Differences:

  • Grounding: Uses "boxes" field with "bbox": [x0, y0, x1, y1] (4 coordinates)
  • Pointing: Uses "points" field with "point": [x, y] (2 coordinates)
  • Both use "phrase" to specify the object category

Negative Sample Support:

Rex-Omni supports negative samples (categories not present in the image) directly in the data format:

For Grounding task:

{
  "boxes": [
    {"bbox": [10, 20, 100, 200], "phrase": "person"},        // Positive sample
    {"bbox": null, "phrase": "car"}                          // Negative sample
  ]
}

For Pointing task:

{
  "points": [
    {"point": [65, 110], "phrase": "person"},                // Positive sample
    {"point": null, "phrase": "car"}                         // Negative sample
  ]
}

When a bbox or point is null, the model will learn to respond with "None" for that category, indicating the object is not present in the image. This helps reduce hallucinations.

Alternative: Using extra_categories

You can also dynamically add negative samples during training without including them in the data:

task_fn=dict(
    type=GroundingTaskFn,
    task_prompts=GROUNDING_SINGLE_REGION_STAGE_XYXY,
    image_min_pixels=min_pixels,
    image_max_pixels=max_pixels,
    extra_categories=["car", "dog", "cat", ...],  # Categories for negative sampling
)

This will randomly sample categories from extra_categories that don't appear in the image and add them as negative samples.

1.2.2 Visualize the Toy Dataset

  • Visualize Grounding Data
python tools/vis_tsv_dataset.py \
  --img_tsv_file Mountchicken/Rex-Omni-Finetune-ToyData/toy_data.images.tsv \
  --ann_tsv_file Mountchicken/Rex-Omni-Finetune-ToyData/toy_data.annotations.tsv \
  --ann_lineidx_file Mountchicken/Rex-Omni-Finetune-ToyData/toy_data.annotations.tsv.lineidx \
  --output_dir Mountchicken/Rex-Omni-Finetune-ToyData/vis \
  --num_samples 20 \
  • Visualize Pointing Data
python tools/vis_tsv_dataset.py \
  --img_tsv_file Mountchicken/Rex-Omni-Finetune-ToyData/toy_point_data.images.tsv \
  --ann_tsv_file Mountchicken/Rex-Omni-Finetune-ToyData/toy_point_data.annotations.tsv \
  --ann_lineidx_file Mountchicken/Rex-Omni-Finetune-ToyData/toy_point_data.annotations.tsv.lineidx \
  --output_dir Mountchicken/Rex-Omni-Finetune-ToyData/vis \
  --num_samples 20 \

1.2.3 Convert Custom Data to TSV Format

We also provide a script to convert custom data in JSON format to TSV format.

Usage:

python tools/convert_json_data_to_tsv.py \
  --json_file /path/to/annotations.json \
  --image_root_path /path/to/images \
  --save_image_tsv_path output/images.tsv \
  --save_ann_tsv_path output/annotations.tsv \
  --save_ann_lineidx_path output/annotations.tsv.lineidx

JSON Format:

The input JSON file should contain one JSON object per line. Each line must have two keys: image_name and annotation.

  • image_name: The relative path or filename of the image (will be joined with image_root_path)
  • annotation: The annotation data (can be any valid JSON object, will be stored as-is in the TSV file)

Example JSON file (annotations.json):

For Grounding task:

{"image_name": "image1.jpg", "annotation": {"boxes": [{"bbox": [10, 20, 100, 200], "phrase": "person"}]}}
{"image_name": "subdir/image2.jpg", "annotation": {"boxes": [{"bbox": [50, 50, 150, 250], "phrase": "car"}]}}
{"image_name": "image3.png", "annotation": {"boxes": [{"bbox": [0, 0, 200, 300], "phrase": "dog"}]}}

For Pointing task:

{"image_name": "image1.jpg", "annotation": {"points": [{"point": [65, 110], "phrase": "person"}]}}
{"image_name": "subdir/image2.jpg", "annotation": {"points": [{"point": [100, 150], "phrase": "car"}]}}
{"image_name": "image3.png", "annotation": {"points": [{"point": [100, 150], "phrase": "dog"}]}}

Parameters:

  • --json_file: Path to the input JSON file (one JSON object per line)
  • --image_root_path: Root directory where images are stored
  • --save_image_tsv_path: Output path for the image TSV file
  • --save_ann_tsv_path: Output path for the annotation TSV file
  • --save_ann_lineidx_path: Output path for the annotation lineidx file

The script will:

  1. Read JSON file line by line
  2. Load images from image_root_path/image_name
  3. Convert images to base64-encoded JPEG format
  4. Generate TSV files with proper line indices

1.3. Launch Training

Use the provided script to launch training:

bash scripts/sft.sh

Main parameter descriptions:

  • --config: Specify the configuration file path (contains dataset configuration)
  • --deepspeed: DeepSpeed configuration file for memory optimization
  • --output_dir: Model output directory
  • --per_device_train_batch_size: Training batch size per device
  • --gradient_accumulation_steps: Gradient accumulation steps
  • --learning_rate: Main learning rate
  • --mm_projector_lr: Multimodal projector learning rate
  • --vision_tower_lr: Vision encoder learning rate

Stage 2: GRPO Finetuning

Stage 2 GRPO finetuning uses reinforcement learning to further optimize model performance through reward functions. We mainly adopt the code from Easy-R1.

2.1 Download Toy Data

We use the same toy data as the SFT stage.

2.2 Launch Training

bash scripts/grpo.sh

Main parameter descriptions:

  • --data.config_path: Specify the configuration file path (contains dataset configuration)

2.3 Convert Checkpoint to Huggingface Version

python tools/merge_rl_checkpoints_to_hg_version.py --local_dir PATH_TO_CHECKPOINT_AT_ACTOR_DIR

--local_dir: The path to the checkpoint at the actor directory.

2.4 Reward Function

We implement the following reward functions in verl/configs/reward_func.py:

For Grounding Task:

  • Box IoU: Computes F1 score based on IoU between predicted and ground truth bounding boxes
    • Precision: Average best IoU for each predicted box
    • Recall: Average best IoU for each GT box
    • Use reward_name="box_iou" in config

For Pointing Task:

  • Point in Box: Checks if predicted points fall within ground truth bounding boxes

    • Requires GT data with both points and boxes
    • Use reward_name="point_in_box" in config
  • Point in Mask: Checks if predicted points fall within ground truth masks

    • Requires GT data with both points and masks (RLE format)
    • Use reward_name="point_in_mask" in config

Example GRPO Config for Pointing:

# configs/pointing_grpo.py
from dataset.task_fns import PointingTaskFn
from dataset.task_fns.task_prompts.pointing_task import POINTING_TASK_PROMPTS
from verl.utils.dataset import TSVRLHFDataset

min_pixels = 16 * 28 * 28
max_pixels = 2560 * 28 * 28

pointing_data = dict(
    type=TSVRLHFDataset,
    image_tsv_file="path/to/pointing_data.images.tsv",
    anno_tsv_file="path/to/pointing_data.annotations.tsv",
    anno_idx_file="path/to/pointing_data.annotations.tsv.lineidx",
    min_pixels=min_pixels,
    max_pixels=max_pixels,
    task_fn=dict(
        type=PointingTaskFn,
        task_prompts=POINTING_TASK_PROMPTS,
        image_min_pixels=min_pixels,
        image_max_pixels=max_pixels,
    ),
    dataset_name="pointing_grpo",
    reward_name="point_in_box",  # or "point_in_mask"
)

train_dataset = [
    pointing_data,
]

All reward functions return an F1 score (0-1) calculated from precision and recall.