Rex-Omni Finetuning Code
January 10, 2026
We provide code for both SFT and GRPO finetuning of the Rex-Omni model.
Installation
# Install Rex-Omni, skip if you have already installed it
conda create -n rexomni -m python=3.10
conda activate rexomni
pip install torch==2.6.0 torchvision==0.21.0 --index-url https://download.pytorch.org/whl/cu124
git clone https://github.com/IDEA-Research/Rex-Omni.git
cd Rex-Omni
pip install -v -e .
# Install required dependencies
cd finetuning
pip install -v -e .
Stage 1: SFT Finetuning
Stage 1 SFT finetuning uses supervised learning to finetune the model with annotated data.
1.1. Data Format
This project uses TSV (Tab-Separated Values) format to store datasets. Each dataset requires three files:
- Image TSV file (*.images.tsv): Stores base64-encoded images
- Annotation TSV file (*.annotations.tsv): Stores annotation data
- Line index file (*.annotations.tsv.lineidx): Stores the byte offset of each annotation line in the file
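To make the role of the line index file concrete, here is a minimal sketch of how such a file could be generated from an existing annotation TSV. The helper names `build_lineidx` and `write_lineidx` are hypothetical and are not part of the Rex-Omni code; the sketch only assumes that the index holds one byte offset per line, as described above.

```python
import os
import tempfile


def build_lineidx(tsv_path: str) -> list[int]:
    """Return the byte offset at which each line of tsv_path starts."""
    offsets = []
    # Binary mode so that len(line) is a true byte count.
    with open(tsv_path, "rb") as f:
        offset = 0
        for line in f:
            offsets.append(offset)
            offset += len(line)
    return offsets


def write_lineidx(tsv_path: str, lineidx_path: str) -> None:
    """Write one offset per line so a reader can parse each with int(line)."""
    with open(lineidx_path, "w") as f:
        for offset in build_lineidx(tsv_path):
            f.write(f"{offset}\n")


# Tiny demo on a throwaway file (the contents are made up for illustration).
with tempfile.TemporaryDirectory() as tmp:
    demo_tsv = os.path.join(tmp, "demo.annotations.tsv")
    with open(demo_tsv, "wb") as f:
        f.write(b'0\t{"boxes": []}\n')
        f.write(b'57\t{"boxes": []}\n')
    demo_offsets = build_lineidx(demo_tsv)
```

Storing offsets rather than line numbers lets a reader jump straight to a sample with `seek()` instead of scanning the file from the start.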
Here is example code for parsing the dataset:
import json
from base64 import b64decode
from io import BytesIO

from PIL import Image
from torch.utils.data import Dataset
from tqdm import tqdm


class TSVBase(Dataset):
    """Base class for TSV datasets. Loads images and annotations from TSV files.

    Args:
        img_tsv_file (str): The path to the image TSV file.
        ann_tsv_file (str): The path to the annotation TSV file.
        ann_lineidx_file (str): The path to the annotation lineidx file.
    """

    def __init__(
        self,
        img_tsv_file: str,
        ann_tsv_file: str,
        ann_lineidx_file: str,
    ):
        # Each line of the lineidx file is the byte offset of one annotation line.
        self.data = []
        with open(ann_lineidx_file) as f:
            for line in tqdm(f):
                self.data.append(int(line.strip()))
        # File handles are opened lazily on first access in __getitem__.
        self.img_handle = None
        self.ann_handle = None
        self.img_tsv_file = img_tsv_file
        self.ann_tsv_file = ann_tsv_file
        self.preparer = None
        self.captionbuilder = None
        self._transforms = None

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Jump to the annotation line and split it into image offset and JSON.
        ann_line_idx = self.data[idx]
        if self.ann_handle is None:
            self.ann_handle = open(self.ann_tsv_file)
        self.ann_handle.seek(ann_line_idx)
        img_line_idx, ann = self.ann_handle.readline().strip().split("\t")
        img_line_idx = int(img_line_idx)

        # Jump to the image line; the second column holds the base64 string.
        if self.img_handle is None:
            self.img_handle = open(self.img_tsv_file)
        self.img_handle.seek(img_line_idx)
        img = self.img_handle.readline().strip().split("\t")[1]
        if img.startswith("b'"):
            # Strip the b'...' wrapper left by a stringified bytes object.
            img = img[2:-1]
        img = BytesIO(b64decode(img))
        img = Image.open(img).convert("RGB")
        target = json.loads(ann)
        return img, target
1.2. Download Toy Data
We have released toy datasets with 1000 samples for both the grounding and pointing tasks, to serve as finetuning examples. You can download them from Hugging Face.
1.2.1 Grounding Data Format Example
The Grounding task is used to train the model for region localization (phrase grounding). Below is the format specification for annotation data:
Annotation TSV file format (*.annotations.tsv):
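- Line format: {image_line_idx}\t{annotation_json}
  - image_line_idx: Position of the corresponding image line in the images.tsv file. The example loader above passes this value to seek(), so it is a byte offset rather than a line count.
  - annotation_json: JSON-formatted annotation data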
Annotation JSON data structure for Grounding:
{
  "boxes": [
    {
      "bbox": [x0, y0, x1, y1],  // Bounding box coordinates in xyxy format
      "phrase": "category name"  // The category or description for this box
    },
    ...
  ]
}
Annotation JSON data structure for Pointing:
{
  "points": [
    {
      "point": [x, y],           // Point coordinates in absolute pixel values
      "phrase": "category name"  // The category or description for this point
    },
    ...
  ]
}
Key Differences:
- Grounding: Uses the "boxes" field with "bbox": [x0, y0, x1, y1] (4 coordinates)
- Pointing: Uses the "points" field with "point": [x, y] (2 coordinates)
- Both use "phrase" to specify the object category
Negative Sample Support:
Rex-Omni supports negative samples (categories not present in the image) directly in the data format:
For Grounding task:
{
  "boxes": [
    {"bbox": [10, 20, 100, 200], "phrase": "person"},  // Positive sample
    {"bbox": null, "phrase": "car"}                    // Negative sample
  ]
}
For Pointing task:
{
  "points": [
    {"point": [65, 110], "phrase": "person"},  // Positive sample
    {"point": null, "phrase": "car"}           // Negative sample
  ]
}
When a bbox or point is null, the model will learn to respond with "None" for that category, indicating the object is not present in the image. This helps reduce hallucinations.
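As an illustration of the null convention, the sketch below splits an annotation into phrases that are present and phrases that are absent, and maps each absent phrase to the "None" answer mentioned above. The helper `split_by_presence` is hypothetical; how the training code actually builds its target text is not shown in this document.

```python
import json


def split_by_presence(annotation: dict, key: str = "boxes", coord: str = "bbox"):
    """Split entries into (present, absent) phrase lists based on null coordinates."""
    present, absent = [], []
    for item in annotation.get(key, []):
        # JSON null becomes Python None after json.loads.
        (absent if item[coord] is None else present).append(item["phrase"])
    return present, absent


ann = json.loads(
    '{"boxes": [{"bbox": [10, 20, 100, 200], "phrase": "person"},'
    ' {"bbox": null, "phrase": "car"}]}'
)
present, absent = split_by_presence(ann)

# Absent phrases are answered with "None", per the description above.
targets = {phrase: "None" for phrase in absent}
```

For pointing data, the same helper can be called with `key="points"` and `coord="point"`.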
Alternative: Using extra_categories
You can also dynamically add negative samples during training without including them in the data:
task_fn=dict(
    type=GroundingTaskFn,
    task_prompts=GROUNDING_SINGLE_REGION_STAGE_XYXY,
    image_min_pixels=min_pixels,
    image_max_pixels=max_pixels,
    extra_categories=["car", "dog", "cat", ...],  # Categories for negative sampling
)
This will randomly sample categories from extra_categories that don't appear in the image and add them as negative samples.
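The sampling behaviour described above can be sketched roughly as follows. This is not the code inside `GroundingTaskFn`; the function name `sample_negative_categories` and the parameter `k` (how many negatives to draw) are assumptions made for illustration.

```python
import random


def sample_negative_categories(present, extra_categories, k, rng=None):
    """Pick up to k categories from extra_categories that are not in the image."""
    rng = rng or random.Random()
    present_set = set(present)
    # Only categories that do not appear in the image can serve as negatives.
    candidates = [c for c in extra_categories if c not in present_set]
    return rng.sample(candidates, min(k, len(candidates)))


negatives = sample_negative_categories(
    present=["person", "dog"],
    extra_categories=["car", "dog", "cat"],
    k=2,
    rng=random.Random(0),  # fixed seed so the demo is repeatable
)
```

Here "dog" is excluded because it already appears in the image, so only "car" and "cat" can be drawn.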
1.2.2 Visualize the Toy Dataset
- Visualize Grounding Data
python tools/vis_tsv_dataset.py \
  --img_tsv_file Mountchicken/Rex-Omni-Finetune-ToyData/toy_data.images.tsv \
  --ann_tsv_file Mountchicken/Rex-Omni-Finetune-ToyData/toy_data.annotations.tsv \
  --ann_lineidx_file Mountchicken/Rex-Omni-Finetune-ToyData/toy_data.annotations.tsv.lineidx \
  --output_dir Mountchicken/Rex-Omni-Finetune-ToyData/vis \
  --num_samples 20
- Visualize Pointing Data
python tools/vis_tsv_dataset.py \
  --img_tsv_file Mountchicken/Rex-Omni-Finetune-ToyData/toy_point_data.images.tsv \
  --ann_tsv_file Mountchicken/Rex-Omni-Finetune-ToyData/toy_point_data.annotations.tsv \
  --ann_lineidx_file Mountchicken/Rex-Omni-Finetune-ToyData/toy_point_data.annotations.tsv.lineidx \
  --output_dir Mountchicken/Rex-Omni-Finetune-ToyData/vis \
  --num_samples 20
1.2.3 Convert Custom Data to TSV Format
We also provide a script to convert custom data in JSON format to TSV format.
Usage:
python tools/convert_json_data_to_tsv.py \
--json_file /path/to/annotations.json \
--image_root_path /path/to/images \
--save_image_tsv_path output/images.tsv \
--save_ann_tsv_path output/annotations.tsv \
--save_ann_lineidx_path output/annotations.tsv.lineidx
JSON Format:
The input JSON file should contain one JSON object per line. Each line must have two keys: image_name and annotation.
- image_name: The relative path or filename of the image (will be joined with image_root_path)
- annotation: The annotation data (can be any valid JSON object, will be stored as-is in the TSV file)
Example JSON file (annotations.json):
For Grounding task:
{"image_name": "image1.jpg", "annotation": {"boxes": [{"bbox": [10, 20, 100, 200], "phrase": "person"}]}}
{"image_name": "subdir/image2.jpg", "annotation": {"boxes": [{"bbox": [50, 50, 150, 250], "phrase": "car"}]}}
{"image_name": "image3.png", "annotation": {"boxes": [{"bbox": [0, 0, 200, 300], "phrase": "dog"}]}}
For Pointing task:
{"image_name": "image1.jpg", "annotation": {"points": [{"point": [65, 110], "phrase": "person"}]}}
{"image_name": "subdir/image2.jpg", "annotation": {"points": [{"point": [100, 150], "phrase": "car"}]}}
{"image_name": "image3.png", "annotation": {"points": [{"point": [100, 150], "phrase": "dog"}]}}
Parameters:
- --json_file: Path to the input JSON file (one JSON object per line)
- --image_root_path: Root directory where images are stored
- --save_image_tsv_path: Output path for the image TSV file
- --save_ann_tsv_path: Output path for the annotation TSV file
- --save_ann_lineidx_path: Output path for the annotation lineidx file
The script will:
- Read the JSON file line by line
- Load images from image_root_path/image_name
- Convert images to base64-encoded JPEG format
- Generate TSV files with proper line indices
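Before running the conversion, it may help to check that every input line has the two required keys. The following is a small standalone sketch, not part of `tools/convert_json_data_to_tsv.py`; the function name `validate_jsonl_lines` is hypothetical.

```python
import json


def validate_jsonl_lines(lines):
    """Return a list of (line_number, problem) tuples; empty means no problems found."""
    problems = []
    for lineno, line in enumerate(lines, start=1):
        line = line.strip()
        if not line:
            continue  # skipping blank lines is a choice made for this sketch
        try:
            record = json.loads(line)
        except json.JSONDecodeError as exc:
            problems.append((lineno, f"invalid JSON: {exc.msg}"))
            continue
        if not isinstance(record, dict):
            problems.append((lineno, "line is not a JSON object"))
            continue
        for key in ("image_name", "annotation"):
            if key not in record:
                problems.append((lineno, f"missing key: {key}"))
    return problems


sample = [
    '{"image_name": "image1.jpg", "annotation": {"boxes": []}}',
    '{"image_name": "image2.jpg"}',
    "not json",
]
problems = validate_jsonl_lines(sample)
```

To check a real file, pass an open file object, for example `validate_jsonl_lines(open("/path/to/annotations.json"))`.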
1.3. Launch Training
Use the provided script to launch training:
bash scripts/sft.sh
Main parameter descriptions:
- --config: Specify the configuration file path (contains dataset configuration)
- --deepspeed: DeepSpeed configuration file for memory optimization
- --output_dir: Model output directory
- --per_device_train_batch_size: Training batch size per device
- --gradient_accumulation_steps: Gradient accumulation steps
- --learning_rate: Main learning rate
- --mm_projector_lr: Multimodal projector learning rate
- --vision_tower_lr: Vision encoder learning rate
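The two batch-related parameters combine with the number of devices to give the effective batch size per optimizer step. The sketch below assumes standard data-parallel training; the numbers are hypothetical and are not taken from scripts/sft.sh.

```python
def effective_batch_size(per_device_train_batch_size: int,
                         gradient_accumulation_steps: int,
                         num_devices: int) -> int:
    """Samples contributing to one optimizer step under data parallelism."""
    return per_device_train_batch_size * gradient_accumulation_steps * num_devices


# Hypothetical setup: 8 GPUs, 4 samples per GPU, 8 accumulation steps.
example = effective_batch_size(4, 8, 8)
```

When changing the number of GPUs, adjusting `--gradient_accumulation_steps` is one way to keep this product constant.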
Stage 2: GRPO Finetuning
Stage 2 GRPO finetuning uses reinforcement learning to further optimize model performance through reward functions. The code is mainly adapted from Easy-R1.
2.1 Download Toy Data
We use the same toy data as the SFT stage.
2.2 Launch Training
bash scripts/grpo.sh
Main parameter descriptions:
--data.config_path: Specify the configuration file path (contains dataset configuration)
2.3 Convert Checkpoint to Huggingface Version
python tools/merge_rl_checkpoints_to_hg_version.py --local_dir PATH_TO_CHECKPOINT_AT_ACTOR_DIR
--local_dir: The path to the checkpoint at the actor directory.
2.4 Reward Function
We implement the following reward functions in verl/configs/reward_func.py:
For Grounding Task:
- Box IoU: Computes F1 score based on IoU between predicted and ground truth bounding boxes
  - Precision: Average best IoU for each predicted box
  - Recall: Average best IoU for each GT box
  - Use reward_name="box_iou" in config
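The Box IoU reward described above can be sketched as follows. This is a minimal illustration written from the description, not the code in verl/configs/reward_func.py; the function names are hypothetical, and returning 0 when either list is empty is an assumption.

```python
def box_iou(a, b):
    """IoU of two boxes in xyxy format."""
    ix0, iy0 = max(a[0], b[0]), max(a[1], b[1])
    ix1, iy1 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix1 - ix0) * max(0.0, iy1 - iy0)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def box_iou_f1(pred_boxes, gt_boxes):
    """F1 of soft precision and recall built from best-match IoUs."""
    if not pred_boxes or not gt_boxes:
        return 0.0  # assumption: an empty prediction or GT list scores zero
    # Precision: average best IoU for each predicted box.
    precision = sum(max(box_iou(p, g) for g in gt_boxes) for p in pred_boxes) / len(pred_boxes)
    # Recall: average best IoU for each GT box.
    recall = sum(max(box_iou(g, p) for p in pred_boxes) for g in gt_boxes) / len(gt_boxes)
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


perfect = box_iou_f1([[0, 0, 10, 10]], [[0, 0, 10, 10]])
half = box_iou_f1([[0, 0, 10, 10]], [[0, 0, 10, 5]])
extra_pred = box_iou_f1([[0, 0, 10, 10], [100, 100, 110, 110]], [[0, 0, 10, 10]])
```

In the third case the spurious second prediction halves the precision while recall stays at 1, so the reward drops below 1 even though the ground truth box is matched exactly.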
For Pointing Task:
- Point in Box: Checks if predicted points fall within ground truth bounding boxes
  - Requires GT data with both points and boxes
  - Use reward_name="point_in_box" in config
- Point in Mask: Checks if predicted points fall within ground truth masks
  - Requires GT data with both points and masks (RLE format)
  - Use reward_name="point_in_mask" in config
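A rough sketch of a Point in Box style reward is shown below. The precision and recall definitions used here (fraction of predicted points inside any GT box, and fraction of GT boxes containing at least one predicted point) are assumptions made for illustration; the actual definitions live in verl/configs/reward_func.py and may differ.

```python
def point_in_box(point, box):
    """True if point (x, y) lies inside box (x0, y0, x1, y1), borders included."""
    x, y = point
    x0, y0, x1, y1 = box
    return x0 <= x <= x1 and y0 <= y <= y1


def point_in_box_f1(pred_points, gt_boxes):
    """F1 score in [0, 1] under the assumed precision/recall definitions."""
    if not pred_points or not gt_boxes:
        return 0.0  # assumption: empty predictions or GT score zero
    # Assumed precision: share of predicted points that land in some GT box.
    hits = sum(any(point_in_box(p, b) for b in gt_boxes) for p in pred_points)
    # Assumed recall: share of GT boxes that contain some predicted point.
    covered = sum(any(point_in_box(p, b) for p in pred_points) for b in gt_boxes)
    precision = hits / len(pred_points)
    recall = covered / len(gt_boxes)
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


score_perfect = point_in_box_f1([[5, 5]], [[0, 0, 10, 10]])
score_partial = point_in_box_f1([[5, 5], [50, 50]], [[0, 0, 10, 10], [20, 20, 30, 30]])
```

A mask-based variant would replace the box test with a lookup into the decoded mask at the predicted pixel.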
Example GRPO Config for Pointing:
# configs/pointing_grpo.py
from dataset.task_fns import PointingTaskFn
from dataset.task_fns.task_prompts.pointing_task import POINTING_TASK_PROMPTS
from verl.utils.dataset import TSVRLHFDataset

min_pixels = 16 * 28 * 28
max_pixels = 2560 * 28 * 28

pointing_data = dict(
    type=TSVRLHFDataset,
    image_tsv_file="path/to/pointing_data.images.tsv",
    anno_tsv_file="path/to/pointing_data.annotations.tsv",
    anno_idx_file="path/to/pointing_data.annotations.tsv.lineidx",
    min_pixels=min_pixels,
    max_pixels=max_pixels,
    task_fn=dict(
        type=PointingTaskFn,
        task_prompts=POINTING_TASK_PROMPTS,
        image_min_pixels=min_pixels,
        image_max_pixels=max_pixels,
    ),
    dataset_name="pointing_grpo",
    reward_name="point_in_box",  # or "point_in_mask"
)

train_dataset = [
    pointing_data,
]
All reward functions return an F1 score (0-1) calculated from precision and recall.