Model Training

May 22, 2026

This document introduces how to use DiffSynth-Studio for model training.

Script Parameters

Training scripts typically include the following parameters:

  • Dataset base configuration
    • --dataset_base_path: Root directory of the dataset.
    • --dataset_metadata_path: Metadata file path of the dataset.
    • --dataset_repeat: Number of times the dataset is repeated in each epoch.
    • --dataset_num_workers: Number of processes for each Dataloader.
    • --data_file_keys: Comma-separated field names to load from the metadata, usually image or video file paths.
  • Model loading configuration
    • --model_paths: Paths of models to be loaded. JSON format.
    • --model_id_with_origin_paths: Model IDs with original paths, for example "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors". Separated by commas.
    • --extra_inputs: Extra input parameters required by the model pipeline, comma-separated. For example, training the image editing model Qwen-Image-Edit requires the extra parameter edit_image.
    • --fp8_models: Models loaded in FP8 format, consistent with the format of --model_paths or --model_id_with_origin_paths. Currently only supports models whose parameters are not updated by gradients (no gradient backpropagation, or gradients only update their LoRA).
    • --resume_from_checkpoint: Load model weights from a checkpoint file and resume training. Currently only supports single model loading without LoRA.
  • Training base configuration
    • --learning_rate: Learning rate.
    • --num_epochs: Number of epochs.
    • --trainable_models: Trainable models, for example dit, vae, text_encoder.
    • --find_unused_parameters: Whether there are unused parameters in DDP training. Some models contain redundant parameters that do not participate in gradient calculation, and this setting needs to be enabled to avoid errors in multi-GPU training.
    • --weight_decay: Weight decay coefficient. See torch.optim.AdamW for details.
    • --task: Training task, default is sft. Some models support more training modes. Please refer to the documentation for each specific model.
  • Output configuration
    • --output_path: Model save path.
    • --remove_prefix_in_ckpt: Remove prefixes in the state dict of model files.
    • --save_steps: Interval of training steps for saving models. If this parameter is left blank, the model will be saved once per epoch.
  • LoRA configuration
    • --lora_base_model: Which model LoRA is added to.
    • --lora_target_modules: Which layers LoRA is added to.
    • --lora_rank: Rank of LoRA.
    • --lora_checkpoint: Path of LoRA checkpoint. If this path is provided, LoRA will be loaded from this checkpoint.
    • --preset_lora_path: Preset LoRA checkpoint path. If this path is provided, this LoRA will be loaded in the form of being merged into the base model. This parameter is used for LoRA differential training.
    • --preset_lora_model: Model that preset LoRA is merged into, for example dit.
  • Gradient configuration
    • --use_gradient_checkpointing: Whether to enable gradient checkpointing.
    • --use_gradient_checkpointing_offload: Whether to offload gradient checkpoints from VRAM to CPU memory (RAM).
    • --gradient_accumulation_steps: Number of gradient accumulation steps.
  • CPU Offload training configuration
    • --cpu_offload: Enable CPU offload training. Weights are kept on CPU and loaded to GPU one layer at a time.
    • --optimize_on_cpu: When --cpu_offload is enabled, run optimizer on CPU. All params are offloaded to CPU. Default is False (trainable params stay on GPU, optimizer on GPU).
    • --param_size_threshold: (Experimental) When --cpu_offload is enabled, modules with total params above this threshold (in MB) are recursively split into children. None means offload every leaf module directly. Default: None.
  • Image dimension configuration (applicable to image generation models and video generation models)
    • --height: Height of images or videos. Leave height and width blank to enable dynamic resolution.
    • --width: Width of images or videos. Leave height and width blank to enable dynamic resolution.
    • --max_pixels: Maximum pixel area (height × width) of images or video frames. When dynamic resolution is enabled, images whose pixel area exceeds this value are scaled down, and smaller images are left unchanged.

Some models' training scripts also contain additional parameters. See the documentation for each model for details.
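The effect of --max_pixels under dynamic resolution can be illustrated with a small calculation. This is only a sketch of the idea: the function name fit_to_max_pixels and the rounding to a multiple of 16 are assumptions made for illustration, and the framework's own resizing logic may round differently.

```python
import math


def fit_to_max_pixels(height: int, width: int, max_pixels: int, multiple: int = 16) -> tuple[int, int]:
    """Return a (height, width) whose pixel area does not exceed max_pixels.

    Images already within the budget are returned unchanged. Larger images are
    scaled down with the aspect ratio roughly preserved, then floored to a
    multiple of `multiple` (an assumed constraint, not taken from the framework).
    """
    if height * width <= max_pixels:
        return height, width
    # Uniform scale factor that brings the area down to max_pixels.
    scale = math.sqrt(max_pixels / (height * width))
    new_height = max(multiple, int(height * scale) // multiple * multiple)
    new_width = max(multiple, int(width * scale) // multiple * multiple)
    return new_height, new_width


small = fit_to_max_pixels(512, 512, 1048576)    # within budget, unchanged
large = fit_to_max_pixels(3000, 2000, 1048576)  # scaled down below the budget
```

With --max_pixels 1048576, the budget corresponds to a 1024 × 1024 image.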

Preparing Datasets

DiffSynth-Studio adopts a universal dataset format. The dataset contains a series of data files (images, videos, etc.) and annotated metadata files. We recommend organizing dataset files as follows:

data/example_image_dataset/
├── metadata.csv
├── image_1.jpg
└── image_2.jpg

Here image_1.jpg and image_2.jpg are the training images, and metadata.csv is the metadata file, for example:

image,prompt
image_1.jpg,"a dog"
image_2.jpg,"a cat"
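A metadata file in this layout can be generated with Python's standard csv module, which also takes care of quoting prompts that contain commas. The helper name render_metadata_csv is hypothetical and the column names simply follow the example above; this is a sketch, not part of DiffSynth-Studio.

```python
import csv
import io


def render_metadata_csv(rows: list[dict]) -> str:
    """Render rows with `image` and `prompt` fields as CSV text."""
    buffer = io.StringIO()
    writer = csv.DictWriter(buffer, fieldnames=["image", "prompt"], lineterminator="\n")
    writer.writeheader()
    writer.writerows(rows)  # values containing commas are quoted automatically
    return buffer.getvalue()


metadata_text = render_metadata_csv([
    {"image": "image_1.jpg", "prompt": "a dog"},
    {"image": "image_2.jpg", "prompt": "a cat, sleeping"},
])
```

The resulting text can then be written to data/example_image_dataset/metadata.csv, for example with pathlib.Path.write_text. The csv module only quotes values when necessary, so "a dog" is written without quotes; both forms are equivalent CSV.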

We have built sample datasets for your testing. To understand how the universal dataset architecture is implemented, please refer to diffsynth.core.data.

Sample Dataset
modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --local_dir ./data/diffsynth_example_dataset

Loading Models

Similar to model loading during inference, we support two ways to configure model paths, and the two methods can be mixed.

Download and load models from remote sources

If we load models during inference through the following settings:

model_configs=[
    ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
    ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
    ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
]

Then during training, fill in the following parameters to load the corresponding models:

--model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors"
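The mapping from the inference-time ModelConfig entries to this argument is mechanical: each entry becomes model_id:origin_file_pattern, and the entries are joined with commas. A minimal sketch, where the helper name join_model_ids is hypothetical and not part of the framework:

```python
def join_model_ids(entries: list[tuple[str, str]]) -> str:
    """Build the value for --model_id_with_origin_paths.

    Each entry is a (model_id, origin_file_pattern) pair.
    """
    parts = []
    for model_id, pattern in entries:
        # A comma inside an entry would be read as an entry separator.
        if "," in model_id or "," in pattern:
            raise ValueError(f"entry must not contain a comma: {model_id}:{pattern}")
        parts.append(f"{model_id}:{pattern}")
    return ",".join(parts)


argument = join_model_ids([
    ("Qwen/Qwen-Image", "transformer/diffusion_pytorch_model*.safetensors"),
    ("Qwen/Qwen-Image", "text_encoder/model*.safetensors"),
    ("Qwen/Qwen-Image", "vae/diffusion_pytorch_model.safetensors"),
])
```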

Model files are downloaded to the ./models path by default, which can be modified through environment variable DIFFSYNTH_MODEL_BASE_PATH.

By default, even after models have been downloaded, the program will still query remotely for missing files. To completely disable remote requests, set environment variable DIFFSYNTH_SKIP_DOWNLOAD to True.

Load models from local file paths

If loading models from local files during inference, for example:

model_configs=[
    ModelConfig([
        "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00001-of-00009.safetensors",
        "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00002-of-00009.safetensors",
        "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00003-of-00009.safetensors",
        "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00004-of-00009.safetensors",
        "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00005-of-00009.safetensors",
        "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00006-of-00009.safetensors",
        "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00007-of-00009.safetensors",
        "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00008-of-00009.safetensors",
        "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00009-of-00009.safetensors"
    ]),
    ModelConfig([
        "models/Qwen/Qwen-Image/text_encoder/model-00001-of-00004.safetensors",
        "models/Qwen/Qwen-Image/text_encoder/model-00002-of-00004.safetensors",
        "models/Qwen/Qwen-Image/text_encoder/model-00003-of-00004.safetensors",
        "models/Qwen/Qwen-Image/text_encoder/model-00004-of-00004.safetensors"
    ]),
    ModelConfig("models/Qwen/Qwen-Image/vae/diffusion_pytorch_model.safetensors")
]

Then during training, pass the same paths as:

--model_paths '[
    [
        "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00001-of-00009.safetensors",
        "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00002-of-00009.safetensors",
        "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00003-of-00009.safetensors",
        "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00004-of-00009.safetensors",
        "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00005-of-00009.safetensors",
        "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00006-of-00009.safetensors",
        "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00007-of-00009.safetensors",
        "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00008-of-00009.safetensors",
        "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00009-of-00009.safetensors"
    ],
    [
        "models/Qwen/Qwen-Image/text_encoder/model-00001-of-00004.safetensors",
        "models/Qwen/Qwen-Image/text_encoder/model-00002-of-00004.safetensors",
        "models/Qwen/Qwen-Image/text_encoder/model-00003-of-00004.safetensors",
        "models/Qwen/Qwen-Image/text_encoder/model-00004-of-00004.safetensors"
    ],
    "models/Qwen/Qwen-Image/vae/diffusion_pytorch_model.safetensors"
]' \

Note that --model_paths must be valid JSON. Trailing commas are not allowed and will cause parsing to fail.
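One way to avoid hand-written JSON mistakes is to build the value with Python's json module. The helper names below (shard_paths, build_model_paths_arg) are hypothetical, and the shard naming pattern is taken from the file names in the example above; this is a sketch rather than a framework utility.

```python
import json


def shard_paths(directory: str, stem: str, count: int) -> list[str]:
    """List sharded safetensors files named like stem-00001-of-00009.safetensors."""
    return [f"{directory}/{stem}-{i:05d}-of-{count:05d}.safetensors" for i in range(1, count + 1)]


def build_model_paths_arg(groups: list) -> str:
    """Serialize model path groups (a string or a list of strings each) as JSON."""
    text = json.dumps(groups, indent=4)
    json.loads(text)  # round-trip check: raises if the text were not valid JSON
    return text


groups = [
    shard_paths("models/Qwen/Qwen-Image/transformer", "diffusion_pytorch_model", 9),
    shard_paths("models/Qwen/Qwen-Image/text_encoder", "model", 4),
    "models/Qwen/Qwen-Image/vae/diffusion_pytorch_model.safetensors",
]
model_paths_arg = build_model_paths_arg(groups)
```

The resulting string can be passed as the value of --model_paths. Because json.dumps never emits trailing commas, the output always parses.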

Setting Trainable Modules

The training framework supports training of any model. Taking Qwen-Image as an example, to fully train the DiT model, set:

--trainable_models "dit"

To train a LoRA on the DiT model, set:

--lora_base_model dit --lora_target_modules "to_q,to_k,to_v" --lora_rank 32

We hope to leave enough room for technical exploration, so the framework supports training any number of modules simultaneously. For example, to train the text encoder, controlnet, and LoRA of the DiT simultaneously:

--trainable_models "text_encoder,controlnet" --lora_base_model dit --lora_target_modules "to_q,to_k,to_v" --lora_rank 32

Additionally, since the training script loads multiple modules (text encoder, dit, vae, etc.), prefixes need to be removed when saving model files. For example, when fully training the DiT part or training a LoRA on the DiT part, set --remove_prefix_in_ckpt "pipe.dit." (the trailing dot is part of the prefix). If multiple modules are trained simultaneously, developers need to write code to split the state dict in the model file after training is completed.
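Splitting a combined state dict by prefix might look like the following sketch. It works on any mapping of parameter names to values. The prefixes "pipe.text_encoder." and "pipe.controlnet." are assumptions made by analogy with "pipe.dit." and should be checked against the keys actually present in the saved file.

```python
def split_state_dict(state_dict: dict, prefixes: list[str]) -> tuple[dict, dict]:
    """Group entries by the first matching prefix and strip that prefix from the keys.

    Returns (groups, unmatched), where groups maps each prefix to its own state dict.
    """
    groups = {prefix: {} for prefix in prefixes}
    unmatched = {}
    for key, value in state_dict.items():
        for prefix in prefixes:
            if key.startswith(prefix):
                groups[prefix][key[len(prefix):]] = value
                break
        else:
            # No prefix matched; keep the entry so nothing is silently dropped.
            unmatched[key] = value
    return groups, unmatched


# Placeholder values stand in for tensors; the key names are illustrative only.
combined = {
    "pipe.text_encoder.layer.weight": 1,
    "pipe.controlnet.block.weight": 2,
    "pipe.controlnet.block.bias": 3,
}
groups, unmatched = split_state_dict(combined, ["pipe.text_encoder.", "pipe.controlnet."])
```

Each per-module dict can then be saved to its own file with whichever serialization library the checkpoint uses.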

Starting the Training Program

The training framework is built on accelerate. Training commands are written in the following format:

accelerate launch xxx/train.py \
  --xxx yyy \
  --xxxx yyyy

We have written preset training scripts for each model. See the documentation for each model for details.

By default, accelerate will train according to the configuration in ~/.cache/huggingface/accelerate/default_config.yaml. Use accelerate config to configure interactively in the terminal, including multi-GPU training, DeepSpeed, etc.

We provide recommended accelerate configuration files for some models, which can be set through --config_file. For example, full training of the Qwen-Image model:

accelerate launch --config_file examples/qwen_image/model_training/full/accelerate_config_zero2offload.yaml examples/qwen_image/model_training/train.py \
  --dataset_base_path data/example_image_dataset \
  --dataset_metadata_path data/example_image_dataset/metadata.csv \
  --max_pixels 1048576 \
  --dataset_repeat 50 \
  --model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \
  --learning_rate 1e-5 \
  --num_epochs 2 \
  --remove_prefix_in_ckpt "pipe.dit." \
  --output_path "./models/train/Qwen-Image_full" \
  --trainable_models "dit" \
  --use_gradient_checkpointing \
  --find_unused_parameters

Training Considerations

  • In addition to the csv format, dataset metadata also supports json and jsonl formats. For how to choose the best metadata format, please refer to ../API_Reference/core/data.md#metadata
  • Training effectiveness is usually strongly correlated with training steps and weakly correlated with epoch count. Therefore, we recommend using the --save_steps parameter to save model files at training step intervals.
  • When data volume * dataset_repeat exceeds $10^{9}$, we observed that the dataset becomes significantly slower, which seems to be a PyTorch bug. We are not sure whether newer versions of PyTorch have fixed this issue.
  • For the learning rate (--learning_rate), we recommend 1e-4 for LoRA training and 1e-5 for full training.
  • The training framework does not support batch size > 1. The reasons are complex. See Q&A: Why doesn't the training framework support batch size > 1?
  • Some models contain redundant parameters, for example the text-encoding part of the last layer of Qwen-Image's DiT. When training these models, --find_unused_parameters needs to be set to avoid errors in multi-GPU training. For compatibility with community models, we do not intend to remove these redundant parameters.
  • The loss function value of Diffusion models has little relationship with actual effects. Therefore, we do not record loss function values during training. We recommend setting --num_epochs to a sufficiently large value, testing while training, and manually closing the training program after the effect converges.
  • --use_gradient_checkpointing is usually enabled unless GPU VRAM is sufficient; --use_gradient_checkpointing_offload is enabled as needed. See diffsynth.core.gradient for details.
  • To load a previously trained model checkpoint and resume training, use --lora_checkpoint to load a LoRA checkpoint, or use --resume_from_checkpoint to load a base model checkpoint. Currently, only loading a single model is supported.
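When choosing --dataset_repeat, --num_epochs, and --save_steps, a rough step count is useful. The sketch below assumes one sample per GPU per step (the framework does not support batch size > 1) and that samples are divided evenly across GPUs. The function name training_plan and the exact step accounting are illustrative assumptions, not the framework's own bookkeeping.

```python
def training_plan(num_samples: int, dataset_repeat: int, num_epochs: int,
                  num_gpus: int = 1, gradient_accumulation_steps: int = 1) -> dict:
    """Estimate per-epoch sample counts and optimizer updates."""
    samples_per_epoch = num_samples * dataset_repeat
    # With batch size 1 per GPU, the effective batch is GPUs x accumulation steps.
    effective_batch_size = num_gpus * gradient_accumulation_steps
    updates_per_epoch = samples_per_epoch // effective_batch_size
    return {
        "samples_per_epoch": samples_per_epoch,
        "effective_batch_size": effective_batch_size,
        "updates_per_epoch": updates_per_epoch,
        "total_updates": updates_per_epoch * num_epochs,
        # Threshold above which the slowdown described in the list was observed.
        "exceeds_slowdown_threshold": samples_per_epoch > 10**9,
    }


plan = training_plan(num_samples=2, dataset_repeat=50, num_epochs=2)
```

For the two-image example dataset with --dataset_repeat 50 and --num_epochs 2 on a single GPU, this gives 100 samples per epoch and 200 updates in total.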

Low VRAM Training

The framework supports multiple methods to reduce VRAM usage during training:

| Name | How to Enable | Technical Principle | Effect | When to Enable | Reference |
| --- | --- | --- | --- | --- | --- |
| Gradient Checkpointing | Enable via --use_gradient_checkpointing | Does not retain gradient-related parameters during forward pass; recomputes them during backward pass | Significantly reduces VRAM usage, increases computation time | Recommended in most cases | Docs |
| Gradient Checkpointing Offload | Enable via --use_gradient_checkpointing_offload | On top of Gradient Checkpointing, moves checkpointed parameters from VRAM to RAM | Further reduces VRAM usage and increases computation time, also increases RAM usage | Only recommended for video generation model training | Docs |
| DeepSpeed | Configure interactively via accelerate config | DeepSpeed supports sharding gradients, optimizer states, etc. across multiple GPUs | Reduces VRAM usage, increases communication cost between GPUs and machines, increases computation time | Only recommended for multi-GPU and multi-node cluster training | Docs |
| FP8 Training | Set which model components to switch to FP8 mode via --fp8_models | Stores model parameters in FP8 precision in VRAM, temporarily converts to higher precision during inference; only supports models that don't require gradient updates | Reduces VRAM usage, slightly increases computation time, introduces minor training error | Only recommended for non-training modules like text_encoder and vae; can also be enabled for dit during LoRA training | Docs |
| Two-Stage Split Training | Complex setup, please refer to the docs. Note that Wan series models have different activation methods, please refer to the corresponding code examples. | Splits training into two stages: first stage performs gradient-free computation and saves intermediate results to disk; second stage computes gradients and updates model parameters. | Reduces VRAM usage, increases computation speed, uses additional disk space | Some models' two-stage training has not been verified, use with caution | Docs |
| CPU Offload | Enable via --enable_model_cpu_offload | Keeps model in RAM during training, moves layers to VRAM one by one for forward and backward passes | Reduces VRAM usage, increases computation time, increases RAM usage | Only recommended for single GPU with extremely limited VRAM | Docs |