Data preparation

May 25, 2025

We provide example data for text-to-image (T2I), image editing, and vision-language model (VLM) tasks. The T2I samples were generated with FLUX.1-dev, the editing examples are randomly sampled from SEED-Data-Edit-Part3, and the VLM set is drawn from LLaVA-OneVision-Data.

Examples are provided both as raw image folders and as parquet shards. For other data formats, use our dataset code as a template and extend it as needed.

  1. Download the sample dataset

    wget -O bagel_example.zip \
      https://lf3-static.bytednsdoc.com/obj/eden-cn/nuhojubrps/bagel_example.zip
    unzip bagel_example.zip -d /data
    
  2. Expected hierarchy

    bagel_example
    ├── t2i/                           # text-to-image (parquet)
    ├── editing/                       # image editing (parquet)
    │   ├── seedxedit_multi/
    │   └── parquet_info/
    └── vlm/
        ├── images/                    # JPEG / PNG frames
        └── llava_ov_si.jsonl          # vision‑language SFT conversations
    
  3. In data/dataset_info.py, replace every your_data_path placeholder with the actual location of the downloaded data.

  4. (Optional) Extend DATASET_INFO with your own parquet shards or JSONL files to mix extra data.
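
Steps 2 and 3 can be scripted. The helpers below are a minimal sketch and are not part of the BAGEL repo: check_layout and fill_data_path are hypothetical names, the layout check mirrors the hierarchy above, and fill_data_path assumes the placeholder appears literally as the string your_data_path in the file.

```shell
# Hypothetical helpers for preparing the example data (not part of the BAGEL repo).

# Check that the unpacked example data has the layout shown in step 2.
# Prints each missing entry to stderr and returns 1 if anything is absent.
check_layout() {
  local root="$1" missing=0 d
  for d in t2i editing/seedxedit_multi editing/parquet_info vlm/images; do
    if [ ! -d "$root/$d" ]; then
      echo "missing directory: $root/$d" >&2
      missing=1
    fi
  done
  if [ ! -f "$root/vlm/llava_ov_si.jsonl" ]; then
    echo "missing file: $root/vlm/llava_ov_si.jsonl" >&2
    missing=1
  fi
  return "$missing"
}

# Replace every literal "your_data_path" in FILE with PREFIX.
# Assumes PREFIX contains no '|' or '&' (both are special in the sed expression).
# Rewrites FILE through a temp copy, which works with both GNU and BSD sed
# and keeps the original file's permissions.
fill_data_path() {
  local file="$1" prefix="$2" tmp
  tmp="$(mktemp)" || return 1
  sed "s|your_data_path|$prefix|g" "$file" > "$tmp" && cat "$tmp" > "$file"
  local status=$?
  rm -f "$tmp"
  return "$status"
}
```

For example, run check_layout /data/bagel_example, look at how the placeholder is used with grep -n your_data_path data/dataset_info.py, and then call fill_data_path data/dataset_info.py with whatever prefix completes those paths. Because the substitution is purely textual, inspecting the file first is what tells you the right prefix.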


Training

The baseline training recipe looks like this (replace environment variables with real paths or values):

# Pre-training
torchrun \
  --nnodes=$num_nodes \
  --node_rank=$node_rank \
  --nproc_per_node=8 \
  --master_addr=$master_addr \
  --master_port=$master_port \
  train/pretrain_unified_navit.py \
  --dataset_config_file ./data/configs/example.yaml \
  --llm_path $llm_path \
  --vae_path $vae_path \
  --vit_path $vit_path \
  --layer_module Qwen2MoTDecoderLayer \
  --use_flex True \
  --resume_from $resume_from \
  --results_dir $output_path \
  --checkpoint_dir $ckpt_path \
  --max_latent_size 64  # 32 for low-resolution pre-training

# Fine-tuning
torchrun \
  --nnodes=$num_nodes \
  --node_rank=$node_rank \
  --nproc_per_node=8 \
  --master_addr=$master_addr \
  --master_port=$master_port \
  train/pretrain_unified_navit.py \
  --dataset_config_file ./data/configs/example.yaml \
  --model_path $model_path \
  --layer_module Qwen2MoTDecoderLayer \
  --max_latent_size 64 \
  --resume_from $model_path \
  --finetune_from_hf True \
  --auto_resume True \
  --resume_model_only True \
  --finetune_from_ema True \
  --log_every 1 \
  --lr 2e-5 \
  --num_workers 1 \
  --expected_num_tokens 10240 \
  --max_num_tokens 11520 \
  --max_num_tokens_per_sample 10240
  • When fine-tuning BAGEL, set max_latent_size=64 to ensure the correct pretrained weights are loaded. If this is not set, an out-of-bounds error may occur.
  • The sum of num_used_data over all datasets in the dataset config file should be greater than NUM_GPUS × NUM_WORKERS. (For toy data, use num_workers=1.)
  • For T2I-only fine-tuning, set visual_und=False. For VLM-only fine-tuning, set visual_gen=False.
  • For debugging purposes, use smaller values for expected_num_tokens, max_num_tokens, and max_num_tokens_per_sample.
  • When fine-tuning on toy data, the loss should look roughly like this:
    [2025-05-25 17:01:37] (step=0000000) Train Loss mse: 0.4063, Train Loss ce: 0.5504, Train Steps/Sec: 0.01, 
    [2025-05-25 17:01:40] (step=0000001) Train Loss mse: 0.4121, Train Loss ce: 0.8152, Train Steps/Sec: 0.44, 
    [2025-05-25 17:01:42] (step=0000002) Train Loss mse: 0.3876, Train Loss ce: 1.3411, Train Steps/Sec: 0.40, 
    [2025-05-25 17:01:45] (step=0000003) Train Loss mse: 0.3825, Train Loss ce: 0.7360, Train Steps/Sec: 0.44, 
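
The num_used_data rule above can be checked before launching with plain shell arithmetic. This is a sketch: check_num_used_data is a hypothetical helper, and you pass it the num_used_data values from your dataset config yourself rather than having the YAML parsed.

```shell
# Hypothetical pre-launch check for the rule
#   sum(num_used_data) > NUM_GPUS * NUM_WORKERS
# Usage: check_num_used_data NUM_GPUS NUM_WORKERS N1 [N2 ...]
check_num_used_data() {
  local gpus="$1" workers="$2" total=0 n
  shift 2
  for n in "$@"; do
    total=$((total + n))
  done
  local needed=$((gpus * workers))
  if [ "$total" -gt "$needed" ]; then
    echo "ok: num_used_data total $total > $gpus GPUs x $workers workers = $needed"
  else
    echo "too little data: total $total <= $needed; add data or lower num_workers" >&2
    return 1
  fi
}
```

With the illustrative values 10 10 10 on one 8-GPU node, check_num_used_data 8 1 10 10 10 passes (30 > 8) while check_num_used_data 8 4 10 10 10 fails (30 ≤ 32). This is why small toy datasets call for num_workers=1.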
    

You are encouraged to adjust any of these hyperparameters to fit your GPU budget and the scale of your dataset. If you encounter any issues, please open an issue for assistance. 🎉

Model config

| Argument | Default | Description |
| --- | --- | --- |
| llm_path | hf/Qwen2.5-0.5B-Instruct | Language-model backbone (HuggingFace repo or local folder). |
| vae_path | flux/vae/ae.safetensors | Pre-trained VAE checkpoint for latent diffusion. |
| vit_path | hf/siglip-so400m-14-980-flash-attn2-navit | SigLIP ViT used for image understanding. |
| max_latent_size | 32 | Maximum latent grid side; defines the highest generable resolution. |
| latent_patch_size | 2 | Side length, in VAE latent pixels, of one latent patch. |
| vit_max_num_patch_per_side | 70 | Maximum number of ViT patches per image side after resizing. |
| text_cond_dropout_prob | 0.1 | Probability of dropping text conditioning during training. |
| vae_cond_dropout_prob | 0.3 | Dropout probability for VAE latent conditioning inputs. |
| vit_cond_dropout_prob | 0.3 | Dropout probability for ViT visual features. |

(See ModelArguments for many more options.)
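
To see how max_latent_size bounds the output resolution, multiply it by latent_patch_size and by the VAE's spatial downsampling factor. The factor of 8 is an assumption about the FLUX VAE and is not stated in this document; likewise, the ViT patch size of 14 is only read off the checkpoint name in vit_path. Both function names are hypothetical.

```shell
# Largest image side, in pixels, that the latent position grid can cover:
#   max_latent_size * latent_patch_size * VAE downsampling factor.
# Assumption: the VAE downsamples 8x per side (typical of the FLUX VAE).
# Usage: max_image_side MAX_LATENT_SIZE [LATENT_PATCH_SIZE] [VAE_DOWNSAMPLE]
max_image_side() {
  local latent="$1" patch="${2:-2}" down="${3:-8}"
  echo $((latent * patch * down))
}

# Largest ViT input side: vit_max_num_patch_per_side * ViT patch size.
# Assumption: patch size 14, as suggested by "siglip-so400m-14-980".
# Usage: max_vit_side [MAX_PATCHES_PER_SIDE] [VIT_PATCH_SIZE]
max_vit_side() {
  local patches="${1:-70}" vit_patch="${2:-14}"
  echo $((patches * vit_patch))
}

max_image_side 32   # prints 512  (default; low-resolution pre-training)
max_image_side 64   # prints 1024 (value used when fine-tuning BAGEL)
max_vit_side        # prints 980
```

Under these assumptions, the default of 32 caps generation at 512-pixel sides, while 64 reaches 1024. This is consistent with 32 being the low-resolution pre-training setting.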

Data config

| Argument | Default | Description |
| --- | --- | --- |
| dataset_config_file | data/configs/example.yaml | YAML that groups datasets and assigns sampling weights. |
| num_workers | 4 | Background workers per rank for the PyTorch DataLoader. |
| prefetch_factor | 2 | Batches prefetched by each worker. |
| max_num_tokens_per_sample | 16384 | Skip raw samples longer than this. |
| max_num_tokens | 36864 | Hard cap for a packed batch (prevents OOM). |
| max_buffer_size | 50 | Overflow buffer length for oversized samples. |
| data_seed | 42 | Seed for reproducible shuffling and sampling. |
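
When you shrink the token limits for debugging, keep them mutually consistent. The sketch below assumes that expected_num_tokens and max_num_tokens_per_sample should not exceed max_num_tokens. The fine-tuning command (10240 / 11520 / 10240) follows this pattern, but it is an inference rather than a documented rule, and check_token_budget is a hypothetical helper.

```shell
# Hypothetical consistency check for the packing budgets (assumed rule:
# neither the expected nor the per-sample budget may exceed max_num_tokens).
# Usage: check_token_budget EXPECTED_NUM_TOKENS MAX_NUM_TOKENS MAX_NUM_TOKENS_PER_SAMPLE
check_token_budget() {
  local expected="$1" max_total="$2" max_sample="$3" bad=0
  if [ "$expected" -gt "$max_total" ]; then
    echo "expected_num_tokens ($expected) > max_num_tokens ($max_total)" >&2
    bad=1
  fi
  if [ "$max_sample" -gt "$max_total" ]; then
    echo "max_num_tokens_per_sample ($max_sample) > max_num_tokens ($max_total)" >&2
    bad=1
  fi
  return "$bad"
}
```

For instance, check_token_budget 10240 11520 10240 accepts the fine-tuning values, and a reduced debugging set such as 4096 5120 4096 (illustrative numbers) passes as well.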

Training config

| Argument | Default | Description |
| --- | --- | --- |
| total_steps | 500_000 | Optimiser steps to run. |
| lr | 1e-4 | Peak learning rate after warm-up. |
| lr_scheduler | constant | Learning-rate schedule (constant or cosine). |
| warmup_steps | 2000 | Linear warm-up duration (steps). |
| ema | 0.9999 | Exponential moving-average decay for model weights. |
| max_grad_norm | 1.0 | Gradient-clipping threshold. |
| save_every | 2000 | Checkpoint frequency (steps). |
| visual_gen / visual_und | True | Enable image generation / understanding branches. |
| freeze_llm / freeze_vit / freeze_vae | False / False / True | Freeze selected modules to save VRAM or for ablations. |
| use_flex | True (in example) | Enable FLEX packing for higher GPU utilisation. |
| sharding_strategy | HYBRID_SHARD | FSDP sharding mode. |
| num_shard | 8 | Number of GPUs that parameters are sharded across in HYBRID_SHARD mode. |
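
The interaction between lr, warmup_steps, lr_scheduler, and total_steps can be approximated as follows: a linear ramp to the peak lr over warmup_steps, then either a constant rate or a cosine decay. This is a sketch, not the trainer's code. The cosine branch assumes decay to exactly zero at total_steps, and the real scheduler may use a non-zero floor. The helper ema_horizon shows the usual rule of thumb that an EMA with decay d averages over roughly 1 / (1 − d) steps. Both function names are hypothetical.

```shell
# Approximate learning rate at STEP: linear warm-up to PEAK_LR over
# WARMUP_STEPS, then constant, or cosine decay to 0 at TOTAL_STEPS.
# Sketch only; the trainer's own scheduler may differ (e.g. a non-zero floor).
# Usage: lr_at STEP PEAK_LR WARMUP_STEPS TOTAL_STEPS constant|cosine
lr_at() {
  awk -v s="$1" -v lr="$2" -v w="$3" -v t="$4" -v sched="$5" 'BEGIN {
    s += 0; lr += 0; w += 0; t += 0                     # force numeric values
    pi = atan2(0, -1)
    if (s < w)
      v = lr * s / w                                    # linear warm-up
    else if (sched == "cosine")
      v = lr * 0.5 * (1 + cos(pi * (s - w) / (t - w)))  # cosine decay to 0
    else
      v = lr                                            # constant
    printf "%.6g\n", v
  }'
}

# Rough averaging window of an EMA with the given decay: about 1 / (1 - decay) steps.
ema_horizon() {
  awk -v d="$1" 'BEGIN { printf "%d\n", 1 / (1 - d) + 0.5 }'
}

lr_at 1000 1e-4 2000 500000 constant   # prints 5e-05 (halfway through warm-up)
ema_horizon 0.9999                     # prints 10000
```

With the defaults, the EMA weights therefore reflect roughly the last 10,000 optimiser steps.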

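Distributed-launch environment variables

| Variable | Meaning |
| --- | --- |
| num_nodes / node_rank | Total number of nodes / 0-based index of the current node. |
| nproc_per_node | Number of GPUs per node. |
| master_addr / master_port | NCCL rendezvous endpoint (address and port of the node with node_rank 0). |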
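
On a single machine, the variables used in the torchrun commands can be filled in as shown below. The values are assumptions: 29500 is PyTorch's conventional default port, and 127.0.0.1 only works when every rank runs on this host. For multi-node runs, set master_addr to an address of the rank-0 node that every node can reach, and give each node its own node_rank. The helper world_size is hypothetical. It shows how the total GPU count (NUM_GPUS in the num_used_data rule) follows from these values.

```shell
# Single-node example values (assumed; adjust for your cluster).
num_nodes=1
node_rank=0                 # 0 .. num_nodes-1, unique per node
master_addr=127.0.0.1       # address of the node with node_rank=0
master_port=29500           # any free TCP port, identical on all nodes

# Total number of training processes (one per GPU) across the job.
# Usage: world_size NUM_NODES NPROC_PER_NODE
world_size() {
  echo $(( $1 * $2 ))
}

world_size "$num_nodes" 8   # prints 8, matching --nproc_per_node=8 in the commands above
```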

Logging config

| Argument | Default | Description |
| --- | --- | --- |
| results_dir | results | Root directory for logs and metrics. |
| checkpoint_dir | results/checkpoints | Checkpoints are saved here. |
| log_every | 10 | Steps between console / W&B logs. |
| wandb_project | bagel | Weights & Biases project name. |
| wandb_name | run | Run name inside the project. |
| wandb_offline | False | Switch to offline mode (logs locally, sync later). |
| wandb_resume | allow | Resumption policy if an existing run ID is detected. |

Tip: export WANDB_API_KEY before launching if you want online dashboards.
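
A small guard can catch a missing key before a long launch. This is a sketch: it relies only on the WANDB_API_KEY variable from the tip and on the True/False style used for flags such as --use_flex. The function name check_wandb is hypothetical. Runs logged with wandb_offline True can later be uploaded with the wandb sync command.

```shell
# Warn before launching if online W&B logging is requested without a key.
# Usage: check_wandb OFFLINE   (OFFLINE is "True" or "False", as passed to --wandb_offline)
check_wandb() {
  local offline="$1"
  if [ "$offline" = "True" ]; then
    echo "wandb offline: logs stay local; upload later with 'wandb sync'"
  elif [ -n "${WANDB_API_KEY:-}" ]; then
    echo "wandb online: WANDB_API_KEY is set"
  else
    echo "WANDB_API_KEY is not set; export it or pass --wandb_offline True" >&2
    return 1
  fi
}
```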