Training MetaMorph Guide

April 15, 2025

This guide explains how to train MetaMorph.

Overview

MetaMorph training follows a two-stage approach:

  1. Pretraining the MLP Connector: Connects vision and language representations.
  2. Fine-tuning: Optimizes both the LLM and the connector together.

Key Training Parameters

Basic Model Configuration

| Parameter | Description | Default Value |
|---|---|---|
| --model_name_or_path | Path to the base language model | PATH_TO_LLAMA3-8B |
| --version | Conversation template version | llama3 |
| --model_max_length | Maximum sequence length for training | 4096 |
| --output_dir | Directory to save model checkpoints | PATH_TO_OUTPUT_DIR |

Vision Tower Configuration

| Parameter | Description | Default Value |
|---|---|---|
| --vision_tower | Vision model used for processing images | siglip/CLIP-ViT-SO400M-14-384 |
| --mm_vision_select_layer | Which layer of the vision model to use | -1 (last layer) |
| --freeze_vision | Whether to freeze the vision backbone | True |
| --normalize_vision | Whether to normalize vision embeddings | True |

Image Token Configuration

| Parameter | Description | Default Value |
|---|---|---|
| --image_token_reduction | Method to reduce image tokens | interpolation |
| --num_image_tokens | Number of tokens used per image | 64 |
| --mm_use_im_start_end | Use special image start/end tokens | True |
| --mm_use_im_patch_token | Use image patch tokens | False |
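To see how these settings affect the sequence budget, the sketch below works out the token cost of one image. It is a rough illustration, not MetaMorph code: the function name is hypothetical, the 384-pixel input and 14-pixel patch are read off the vision tower name, the square token grid is an assumption, and so is the idea that the start/end markers add exactly two tokens per image.

```python
import math

def image_token_budget(num_image_tokens=64, use_im_start_end=True,
                       image_size=384, patch_size=14):
    """Return (source_side, target_side, tokens_per_image) for one image.

    Hypothetical helper: assumes a square patch grid before and after reduction.
    """
    # Patches per side produced by the vision tower: 384 // 14 = 27.
    source_side = image_size // patch_size
    # Side length of the reduced grid: 64 tokens -> 8 x 8.
    target_side = math.isqrt(num_image_tokens)
    if target_side * target_side != num_image_tokens:
        raise ValueError("num_image_tokens must be a perfect square for a square grid")
    # Assumption: the start/end markers cost one token each.
    tokens_per_image = num_image_tokens + (2 if use_im_start_end else 0)
    return source_side, target_side, tokens_per_image

source_side, target_side, per_image = image_token_budget()
print(f"{source_side}x{source_side} patches -> {target_side}x{target_side} tokens, "
      f"{per_image} sequence positions per image")
```

Under these assumptions, each image takes a small, fixed share of the 4096-token --model_max_length, which is what leaves room for multi-image or long-text samples.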

Multimodal Projector Configuration

| Parameter | Description | Default Value |
|---|---|---|
| --mm_projector_type | Type of projector that maps vision features to the language model | mlp2x_gelu |
| --tune_mm_mlp_adapter | Whether to tune only the adapter (for pretraining) | True (pretraining), False (finetuning) |
| --pretrain_mm_mlp_adapter | Path to the pretrained adapter (for finetuning) | PATH_TO_Pretrained_Adapter |

Visual Auto-Regressive Parameters

| Parameter | Description | Default Value |
|---|---|---|
| --use_vision_ar | Enable vision auto-regressive prediction | False (pretraining), True (finetuning) |
| --vision_head_type | Type of vision head for AR prediction | mlp |

Optimization Parameters

| Parameter | Description | Default Value |
|---|---|---|
| --learning_rate | Base learning rate | Adjust based on your batch size (see Learning Rate Calculation) |
| --weight_decay | Weight decay for the AdamW optimizer | 0.0 |
| --warmup_ratio | Fraction of training steps used for learning rate warmup | 0.03 |
| --lr_scheduler_type | Type of learning rate scheduler | cosine |
| --bf16 | Use bfloat16 precision | True |
| --fp16 | Use float16 precision | False |
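As an illustration of how --warmup_ratio and --lr_scheduler_type cosine interact, the sketch below computes the learning rate at a given step. It assumes the usual linear-warmup-then-cosine-decay shape used by Hugging Face style trainers; the function name and the way warmup steps are rounded are assumptions, so check the training code for the exact schedule.

```python
import math

def lr_at_step(step, total_steps, base_lr, warmup_ratio=0.03):
    """Learning rate at `step` for linear warmup followed by cosine decay.

    Hypothetical helper; assumes warmup length = ceil(total_steps * warmup_ratio).
    """
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    if step < warmup_steps:
        # Linear ramp from 0 up to base_lr over the warmup steps.
        return base_lr * step / max(1, warmup_steps)
    # Fraction of the post-warmup phase completed, in [0, 1].
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    # Half-cosine decay from base_lr down to 0.
    return base_lr * max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))

# Print a few points of the schedule for a 1000-step run.
for s in (0, 15, 30, 500, 1000):
    print(s, f"{lr_at_step(s, 1000, 6.93e-5):.3e}")
```

The learning rate passed on the command line is the peak of this curve, reached at the end of warmup.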

Data Parameters

| Parameter | Description | Recommended Value |
|---|---|---|
| --data_path | Path to training data | Path_To_Data_JSONL |
| --image_folder | Path to image directory | PATH_TO_Images |
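The parameter tables list three flags whose values differ between the two stages. The sketch below collects those differences in one place. The flag names and values come from the tables; the helper itself, the selection of common flags, and the "--flag value" argument style are illustrative assumptions. The provided shell scripts remain the source of truth for how arguments are actually passed.

```python
# Values copied from the parameter tables; the dictionary layout and function
# name are illustrative, not part of the MetaMorph codebase.
COMMON_FLAGS = {
    "--version": "llama3",
    "--model_max_length": "4096",
    "--vision_tower": "siglip/CLIP-ViT-SO400M-14-384",
    "--mm_projector_type": "mlp2x_gelu",
    "--num_image_tokens": "64",
    "--bf16": "True",
}

def stage_flags(stage, adapter_path=None):
    """Return a flat [flag, value, ...] list for 'pretrain' or 'finetune'."""
    flags = dict(COMMON_FLAGS)
    if stage == "pretrain":
        # Stage 1: train only the connector, no visual prediction.
        flags["--tune_mm_mlp_adapter"] = "True"
        flags["--use_vision_ar"] = "False"
    elif stage == "finetune":
        if adapter_path is None:
            raise ValueError("finetuning needs the adapter produced by stage 1")
        # Stage 2: train the LLM and connector together, with visual prediction.
        flags["--tune_mm_mlp_adapter"] = "False"
        flags["--use_vision_ar"] = "True"
        flags["--pretrain_mm_mlp_adapter"] = adapter_path
    else:
        raise ValueError(f"unknown stage: {stage!r}")
    # Flatten {"--flag": "value"} into ["--flag", "value", ...].
    return [token for pair in flags.items() for token in pair]

print(" ".join(stage_flags("finetune", adapter_path="PATH_TO_Pretrained_Adapter")))
```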

Training Scripts

Debug / Non-SLURM system

  1. Stage 1: Use the scripts/pretrain_1node.sh script to pretrain the MLP connector.
  2. Stage 2: Use the scripts/finetune_1node.sh script to finetune the full model.

SLURM / Multi-Node Training

  1. Stage 1: Use the scripts/slurm_pretrain.sh script to pretrain the MLP connector.
  2. Stage 2: Use the scripts/slurm_finetune.sh script to finetune the full model.

Data Format

MetaMorph expects training data as a JSONL file (see --data_path) whose records have the following structure, shown here expanded for readability:

{
  "id": "unique_id",
  "image": "path/to/image.jpg",
  "conversations": [
    {
      "from": "human",
      "value": "<image> What is shown in this image?"
    },
    {
      "from": "gpt",
      "value": "This is a detailed description of the image."
    }
  ]
}
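
Before launching a long run, it can help to sanity-check the data file. The following is a minimal sketch of such a check, based only on the record shown above. The function name is hypothetical, and the rules it enforces (only "human" and "gpt" roles, an `<image>` placeholder whenever an "image" key is present) are assumptions inferred from the example, not a specification of what the data loader accepts.

```python
import json

def validate_record(record):
    """Return a list of problems found in one training record (empty if none)."""
    problems = []
    for key in ("id", "conversations"):
        if key not in record:
            problems.append(f"missing key: {key}")
    turns = record.get("conversations", [])
    for i, turn in enumerate(turns):
        # Assumption: only the two roles seen in the example are valid.
        if turn.get("from") not in ("human", "gpt"):
            problems.append(f"turn {i}: unexpected role {turn.get('from')!r}")
        if not isinstance(turn.get("value"), str):
            problems.append(f"turn {i}: 'value' must be a string")
    # Assumption: a record with an image needs an <image> placeholder in a human turn.
    has_placeholder = any(
        turn.get("from") == "human"
        and isinstance(turn.get("value"), str)
        and "<image>" in turn["value"]
        for turn in turns
    )
    if "image" in record and not has_placeholder:
        problems.append("record has an image but no <image> placeholder")
    return problems

# One JSONL line, i.e. one record serialized on a single line.
line = ('{"id": "unique_id", "image": "path/to/image.jpg", "conversations": ['
        '{"from": "human", "value": "<image> What is shown in this image?"}, '
        '{"from": "gpt", "value": "This is a detailed description of the image."}]}')
print(validate_record(json.loads(line)))
```

To check a whole file, call `validate_record(json.loads(line))` on each non-empty line and report the line numbers that return problems.
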

Training Tips

  1. Pretrain the MLP connector with tune_mm_mlp_adapter=True and use_vision_ar=False. This stage focuses on connecting the vision and language models.

  2. Fine-tune the full model starting from the pretrained adapter, and set use_vision_ar=True to enable visual generation.

  3. With limited resources, adjust the per-device batch size and gradient accumulation steps so that the effective global batch size stays the same. The formula is:

    • Global Batch Size = per_device_train_batch_size × gradient_accumulation_steps × num_gpus × num_nodes

    Here num_gpus is the number of GPUs per node.
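
The global batch size formula above can be inverted to find the gradient accumulation steps needed on smaller hardware. The helper below is a small sketch of that arithmetic. Its name is hypothetical, and the target of 1536 is borrowed from the worked learning-rate example in this guide rather than from the training scripts.

```python
def gradient_accumulation_steps(global_batch_size, per_device_batch_size,
                                num_gpus, num_nodes=1):
    """Accumulation steps needed to reach the target global batch size.

    num_gpus is the number of GPUs per node.
    """
    samples_per_step = per_device_batch_size * num_gpus * num_nodes
    if global_batch_size % samples_per_step != 0:
        raise ValueError(
            f"{global_batch_size} is not a multiple of {samples_per_step}; "
            "pick a different per-device batch size"
        )
    return global_batch_size // samples_per_step

# One node with 8 GPUs and 8 samples per GPU: 64 samples per optimizer micro-step.
print(gradient_accumulation_steps(1536, per_device_batch_size=8, num_gpus=8))  # 24
# Four such nodes: 256 samples per micro-step.
print(gradient_accumulation_steps(1536, per_device_batch_size=8, num_gpus=8, num_nodes=4))  # 6
```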

Learning Rate Calculation

When your global batch size differs from the one used in the original implementation, scale the learning rate with the square-root rule:

Optimal Learning Rate = Base Learning Rate × √(Batch Size / Base Batch Size)

For example, if the base learning rate is 6.93e-5 at a batch size of 1536 and you train with a batch size of 768, the scaled learning rate is 6.93e-5 × √(768/1536) ≈ 4.90e-5.
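
The same calculation as a short Python sketch, in case you want to script it alongside your launch configuration. The function name is illustrative. The numbers reproduce the worked example above.

```python
import math

def scale_learning_rate(base_lr, base_batch_size, batch_size):
    """Square-root scaling of the learning rate with global batch size."""
    return base_lr * math.sqrt(batch_size / base_batch_size)

lr = scale_learning_rate(base_lr=6.93e-5, base_batch_size=1536, batch_size=768)
print(f"{lr:.2e}")  # 4.90e-05
```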

Visualization Training

For training the visualization component (to generate images from SigLIP embeddings), use the scripts in the visualization/ directory. The key parameters are explained in visualization/Train_Visualization.md.