Pretrain and SFT
April 21, 2025
For pretraining and SFT, please follow the instructions below to install the environment:
python3 -m venv env
source env/bin/activate
pip install -e ".[train]"
Data Preparation
We cache the visual tokens for efficient training. Below is the command to extract visual tokens with Cosmos Tokenizer:
torchrun \
--nnodes=1 --nproc_per_node=8 \
simpar/data/extract_token.py \
--dataset_type "image" \
--dataset_name "example" \
--code_path "/path_to_saved_tokens" \
--gen_data_path "/path_to_meta_json" \
--gen_resolution 1024
Specify the metadata file with --gen_data_path. It should be a JSON file in the following format:
{
"image_path": "path_to_image",
"caption": "a photo of a cat"
}
After extracting the tokens, you can use ./scripts/tokens/generate_meta.py to prepare a meta file.
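As an illustration, the input metadata format shown above (the file passed to --gen_data_path) could be written with a short script like the one below. This is only a sketch: it assumes the file holds a list of such records, which the example above does not confirm, and the helper names `build_meta` and `write_meta` are hypothetical and not part of the repository.

```python
import json
from pathlib import Path


def build_meta(pairs):
    """Turn (image_path, caption) pairs into records with the two keys shown above."""
    return [{"image_path": str(path), "caption": caption} for path, caption in pairs]


def write_meta(records, out_path):
    """Write the records as JSON. ASSUMPTION: the file is a list of records."""
    text = json.dumps(records, indent=2, ensure_ascii=False)
    Path(out_path).write_text(text, encoding="utf-8")


if __name__ == "__main__":
    # Placeholder paths and captions, for illustration only.
    records = build_meta([("path_to_image", "a photo of a cat")])
    write_meta(records, "meta.json")
```

If the extraction script expects a different top-level layout (for example one JSON object per line), only `write_meta` would need to change.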
Launch Training
For both pretraining and SFT, we use the following command to train the model:
ACCELERATE_CPU_AFFINITY=1 \
torchrun \
--nnodes=4 \
--nproc_per_node=8 \
simpar/train/train_mem.py \
--deepspeed scripts/zero3.json \
--model_name_or_path "/path_to_your_dir/Qwen2.5-0.5B-Instruct" \
--version "qwen_1_5" \
--gen_data_path /path_to_annotation_file \
--gen_image_folder "" \
--sample_short True \
--mm_tunable_parts="mm_language_model" \
--p_drop_cond 0.1 \
--mm_use_im_start_end False \
--mm_use_im_patch_token False \
--mm_patch_merge_type spatial_unpad \
--bf16 True \
--run_name test \
--output_dir /path_to_output_dir \
--num_train_epochs 1 \
--per_device_train_batch_size 8 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 2 \
--evaluation_strategy "no" \
--save_strategy "steps" \
--save_steps 5000 \
--learning_rate 1e-4 \
--weight_decay 0.01 \
--warmup_ratio 0.0 \
--lr_scheduler_type "constant" \
--logging_steps 1 \
--tf32 True \
--model_max_length 1536 \
--dataloader_num_workers 16 \
--lazy_preprocess True \
--torch_compile True \
--torch_compile_backend "inductor" \
--dataloader_drop_last True \
--report_to wandb \
--attn_implementation sdpa
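For reference, the effective global batch size implied by the command above can be worked out as follows. The sketch assumes plain data-parallel training, where every process consumes its own per-device batch; the function name is illustrative only.

```python
def effective_batch_size(nnodes, nproc_per_node, per_device_batch, grad_accum_steps):
    """Samples consumed per optimizer step under data-parallel training."""
    world_size = nnodes * nproc_per_node
    return world_size * per_device_batch * grad_accum_steps


# Values taken from the launch command: 4 nodes x 8 processes,
# --per_device_train_batch_size 8, --gradient_accumulation_steps 2.
print(effective_batch_size(4, 8, 8, 2))  # 512
```

When training on fewer GPUs, raising --gradient_accumulation_steps by the same factor keeps this number unchanged.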
We set --model_max_length to the number of visual tokens plus 512, i.e., 1536 for pretraining at 512 resolution and 4608 for SFT at 1024 resolution.
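The arithmetic behind these two values can be sketched as below. The 16x spatial downsampling factor is an assumption chosen because it reproduces the stated numbers (1024 visual tokens at 512 resolution, 4096 at 1024 resolution); check it against the tokenizer you actually use. The function names are illustrative.

```python
def num_visual_tokens(resolution, downsample=16):
    """Tokens for a square image. ASSUMPTION: 16x spatial downsampling per side."""
    if resolution % downsample != 0:
        raise ValueError("resolution must be divisible by the downsampling factor")
    side = resolution // downsample
    return side * side


def model_max_length(resolution, text_budget=512, downsample=16):
    """Visual tokens plus the 512-token budget mentioned above."""
    return num_visual_tokens(resolution, downsample) + text_budget


print(model_max_length(512))   # 1536
print(model_max_length(1024))  # 4608
```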
GRPO Training
We strongly recommend maintaining separate Python environments for pretraining/SFT and GRPO training, using venv or conda:
python3 -m venv env_rl
source env_rl/bin/activate
pip install -e ".[train]"
pip install vllm==0.7.2 # important!
pip install wheel
pip install flash-attn --no-build-isolation
pip install "transformers@git+https://github.com/huggingface/transformers.git@7bbc62474391aff64f63fcc064c975752d1fa4de"
git clone https://github.com/huggingface/trl
cd trl
git reset --hard 69ad852e5654a77f1695eb4c608906fe0c7e8624 # specify the commit id!
pip install -e .
cd ..
mv trl trl_arxiv
mv trl_arxiv/trl ./
rm -rf trl_arxiv
pip uninstall bitsandbytes -y
pip install outlines==0.0.46
pip install latex2sympy2_extended math_verify
pip install clint
sudo apt-get install python3-tk -y
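Because several packages above are pinned to exact versions, it may help to verify the environment before launching. The following is a minimal sketch using only the standard library; the `PINNED` table simply mirrors the two version pins from the commands above, and `check_pins` is a hypothetical helper, not part of the repository.

```python
from importlib.metadata import PackageNotFoundError, version

# Version pins copied from the install commands above.
PINNED = {"vllm": "0.7.2", "outlines": "0.0.46"}


def check_pins(pins):
    """Return {package: problem} for every pin that is missing or mismatched."""
    problems = {}
    for name, expected in pins.items():
        try:
            installed = version(name)
        except PackageNotFoundError:
            problems[name] = f"not installed (expected {expected})"
            continue
        if installed != expected:
            problems[name] = f"found {installed}, expected {expected}"
    return problems


if __name__ == "__main__":
    for name, message in check_pins(PINNED).items():
        print(f"{name}: {message}")
```

An empty result means both pins match. The git-pinned transformers and trl commits are not covered by this check.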
We follow Open-R1 to implement GRPO training with trl. Please first set up the environment following the instructions in INSTALL.md, then run:
accelerate launch --main_process_port 1234 --config_file simpar/configs/accelerate_configs/zero3.yaml \
--num_processes=7 simpar/train/llava_trainer_grpo.py \
--config simpar/configs/config_grpo.yaml \
--data_path /path_to_annotation_file
Note that trl uses one separate GPU for online generation (with vLLM), so we recommend using at least 2 GPUs for training. Please refer to the trl documentation for more details.
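The --num_processes=7 value in the command above is consistent with an 8-GPU node where one GPU is held back for vLLM generation. A small sketch of that bookkeeping, assuming exactly one generation GPU as described (the helper name is hypothetical):

```python
def training_processes(total_gpus, generation_gpus=1):
    """Number of training processes left after reserving GPUs for vLLM generation."""
    remaining = total_gpus - generation_gpus
    if remaining < 1:
        raise ValueError("need at least one GPU for training besides generation")
    return remaining


print(training_processes(8))  # 7, matching --num_processes=7
print(training_processes(2))  # 1, the minimum recommended setup
```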
We spent a lot of time tuning the hyper-parameters and improving training efficiency. After this, we observed quite promising reward curves.
