
May 6, 2026

[TPAMI] Parallel Diffusion Solver via Residual Dirichlet Policy Optimization

Our paper has been accepted to IEEE Transactions on Pattern Analysis and Machine Intelligence (TPAMI). 🎉

Note: This work extends EPD-Solver (ICCV 2025). You are currently in the default branch, EPD-Solver++. For those interested in our previous work, the original ICCV 2025 implementation is available in the EPD-Solver branch.


Algorithm Overview

[Figure: Training pipeline]

EPD-Solver (Ensemble Parallel Direction) mitigates truncation errors in diffusion sampling by leveraging parallel gradient evaluations within a single step. We introduce a novel two-stage optimization framework that aligns the solver with human preferences without fine-tuning the heavy diffusion backbone.

1. Parallel Gradient Estimation

Instead of chaining evaluations sequentially, EPD-Solver computes gradients at multiple learned intermediate timesteps ($\tau_n^k$) in parallel. By aggregating these gradients via a simplex-weighted sum, it achieves a higher-order approximation of the integral direction with negligible latency overhead on modern GPUs.
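As a rough illustration, the aggregation can be written for a generic ODE $dx/dt = f(x, t)$. The notation is ours, not the paper's:

$$x_{n+1} = x_n + (t_{n+1} - t_n) \sum_{k=1}^{K} \lambda_n^k \, f\big(x_{\tau_n^k}, \tau_n^k\big), \qquad \lambda_n^k \ge 0,\ \sum_{k=1}^{K} \lambda_n^k = 1$$

The NumPy sketch below applies this update to a toy problem. Two parts of it are our assumptions:

- Each intermediate state $x_{\tau_n^k}$ is predicted with a single Euler step from $x_n$.
- The $K$ evaluations are written as a loop, although a real solver would batch them into one forward pass.

The function name `epd_step` is hypothetical.

```python
import numpy as np


def epd_step(f, x_n, t_n, t_next, taus, lambdas):
    """One EPD-style update (illustrative sketch, not the repository's code).

    f(x, t) : drift of the ODE dx/dt = f(x, t)
    taus    : K intermediate times between t_n and t_next
    lambdas : K non-negative weights summing to 1 (a point on the simplex)
    """
    lambdas = np.asarray(lambdas, dtype=float)
    if np.any(lambdas < 0) or not np.isclose(lambdas.sum(), 1.0):
        raise ValueError("lambdas must lie on the probability simplex")

    # One evaluation at the current point, used here only to predict the
    # intermediate states (assumption).
    d_n = f(x_n, t_n)

    # Predict each intermediate state with an Euler step (assumption).
    x_taus = [x_n + (tau - t_n) * d_n for tau in taus]

    # The K evaluations are independent of each other, so a real solver can
    # run them as one batched forward pass; a loop is used here for clarity.
    grads = np.stack([f(x_tau, tau) for x_tau, tau in zip(x_taus, taus)])

    # Simplex-weighted combination of the parallel directions.
    direction = np.tensordot(lambdas, grads, axes=1)
    return x_n + (t_next - t_n) * direction


# Toy problem: dx/dt = -x, one step from t=0 to t=0.5 starting at x=1.
f = lambda x, t: -x
euler = epd_step(f, 1.0, 0.0, 0.5, taus=[0.0], lambdas=[1.0])
midpoint = epd_step(f, 1.0, 0.0, 0.5, taus=[0.25], lambdas=[1.0])
```

With `taus=[t_n]` the update reduces to plain Euler. Moving the single τ to the midpoint gives the explicit midpoint rule, which lands closer to the exact value $e^{-0.5}$ here. EPD-Solver learns the τ values and weights instead of fixing them by hand.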

2. Two-Stage Optimization

  • Stage 1: Distillation-Based Initialization. We first distill a few-step student solver by minimizing the trajectory error against a high-fidelity teacher (e.g., DPM-Solver). This provides a robust initialization that captures the trajectory curvature.

  • Stage 2: Residual Dirichlet Policy Optimization (RDPO). We reformulate the solver as a stochastic Dirichlet policy. Using a lightweight PPO variant, we fine-tune the solver's low-dimensional parameters (time segments and weights) to maximize human-aligned rewards (e.g., HPSv2, ImageReward). This ensures high perceptual quality and semantic alignment even at low NFEs (e.g., 20 steps).
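The Stage 2 ingredients can be sketched in NumPy for the weights of a single solver step. This is not the repository's implementation. The following are our assumptions:

- The residual parameterization of the Dirichlet mean.
- The leave-one-out (RLOO) baseline, suggested by the ppo.rloo_k config key.
- The standard PPO clipped surrogate, suggested by ppo.clip_range.

All function names are hypothetical and the reward values are made up. The policy over time segments is omitted.

```python
import math

import numpy as np


def dirichlet_log_prob(x, alpha):
    """Log-density of Dirichlet(alpha) at a point x on the simplex."""
    x = np.asarray(x, dtype=float)
    alpha = np.asarray(alpha, dtype=float)
    log_norm = math.lgamma(alpha.sum()) - sum(math.lgamma(a) for a in alpha)
    return log_norm + float(np.sum((alpha - 1.0) * np.log(x)))


def residual_alpha(base_weights, residual_logits, concentration):
    """Hypothetical residual parameterization: the policy mean is the distilled
    weights tilted by learned residual logits, scaled by a concentration."""
    logits = np.log(np.asarray(base_weights, dtype=float)) + residual_logits
    mean = np.exp(logits - logits.max())  # numerically stable softmax
    mean /= mean.sum()
    return concentration * mean


def rloo_advantages(rewards):
    """Leave-one-out baseline: compare each sample with the mean of the others."""
    r = np.asarray(rewards, dtype=float)
    return r - (r.sum() - r) / (r.size - 1)


def ppo_clip_loss(new_logp, old_logp, advantages, clip_range=0.2):
    """Standard PPO clipped surrogate objective, negated so it can be minimized."""
    ratio = np.exp(np.asarray(new_logp, dtype=float) - np.asarray(old_logp, dtype=float))
    adv = np.asarray(advantages, dtype=float)
    clipped = np.clip(ratio, 1.0 - clip_range, 1.0 + clip_range)
    return -float(np.mean(np.minimum(ratio * adv, clipped * adv)))


# Usage: draw rloo_k = 4 weight vectors for one solver step.
rng = np.random.default_rng(0)
base = np.array([0.5, 0.3, 0.2])  # hypothetical distilled weights
alpha = residual_alpha(base, np.zeros(3), concentration=10.0)
samples = rng.dirichlet(alpha, size=4)  # each row lies on the simplex
old_logp = np.array([dirichlet_log_prob(s, alpha) for s in samples])
rewards = np.array([0.27, 0.25, 0.30, 0.26])  # made-up reward values
loss = ppo_clip_loss(old_logp, old_logp, rloo_advantages(rewards))
```

In this sketch the concentration (ppo.dirichlet_concentration in the configs) controls exploration: larger values keep the sampled weights closer to the policy mean.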

Installation

  1. Create Environment

    conda env create -f environment.yml -n epd
    conda activate epd
    
  2. Install Dependencies

    # Core dependencies
    pip install omegaconf gdown lightning fairscale piq accelerate timm einops kornia HPSv2
    pip install --upgrade "diffusers[torch]"
    
    # CLIP & Transformers
    pip install git+https://github.com/openai/CLIP.git
    pip install transformers
    pip install -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
    
  3. Set Up Environment Variables. Important: run this before training or inference.

    export PYTHONPATH="$PWD/training/ppo/reward_models/HPSv2:$PWD/src/taming-transformers:$PYTHONPATH"
    

    FLUX.1-dev is a gated Hugging Face model. For sampling, set FLUX_MODEL_PATH to a local snapshot when one is available. For training, exps/flux/flux-start.pkl stores the default black-forest-labs/FLUX.1-dev model reference. Make sure that model is in your Hugging Face cache, or authenticate with Hugging Face and set FLUX_ALLOW_REMOTE=1.
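You can confirm that the PYTHONPATH export from step 3 took effect with a small standard-library check. It compares PYTHONPATH against the two directories the export prepends. The helper name is hypothetical. It only checks path entries, not whether the packages actually import.

```python
import os
from pathlib import Path

# Directories the step-3 export prepends, relative to the repository root.
REQUIRED_DIRS = ("training/ppo/reward_models/HPSv2", "src/taming-transformers")


def missing_pythonpath_entries(repo_root, pythonpath):
    """Return the required directories that are absent from a PYTHONPATH string."""
    entries = {Path(p).resolve() for p in pythonpath.split(os.pathsep) if p}
    missing = []
    for rel in REQUIRED_DIRS:
        path = Path(repo_root) / rel
        if path.resolve() not in entries:
            missing.append(str(path))
    return missing


if __name__ == "__main__":
    gaps = missing_pythonpath_entries(os.getcwd(), os.environ.get("PYTHONPATH", ""))
    print("PYTHONPATH OK" if not gaps else "Missing from PYTHONPATH: " + ", ".join(gaps))
```

Run it from the repository root, the same directory where `$PWD` was expanded by the export.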

Model Zoo & Checkpoints

We provide pre-trained predictors (Stage 1: Distilled) and RL-finetuned solvers (Stage 2: Best). Stage 2 models are optimized using Residual Dirichlet Policy Optimization for better human preference alignment. For FLUX.1-dev, flux-start.pkl is a manually initialized start table for PPO, not a distilled checkpoint.

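| Model | Resolution | Type | Download |
|---|---|---|---|
| Stable Diffusion v1.5 | 512x512 | RL-Best (Stage 2) | sd15-best.pkl |
| Stable Diffusion v1.5 | 512x512 | Distilled (Stage 1) | sd15-distilled.pkl |
| SD3-Medium | 1024x1024 | RL-Best (Stage 2) | sd3-1024-best.pkl |
| SD3-Medium | 1024x1024 | Distilled (Stage 1) | sd3-1024-distilled.pkl |
| SD3-Medium | 512x512 | RL-Best (Stage 2) | sd3-512-best.pkl |
| SD3-Medium | 512x512 | Distilled (Stage 1) | sd3-512-distilled.pkl |
| FLUX.1-dev | 1024x1024 | RL-Best (Stage 2) | flux-best.pkl |
| FLUX.1-dev | 1024x1024 | Start (manual init) | flux-start.pkl |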

We also provide a detailed guide for each part below.

RDPO Training

To train EPD-Solver using RDPO:

```bash
# Available configs: sd15.yaml, sd3_512.yaml, sd3_1024.yaml, flux_dev.yaml
torchrun --master_port=12345 --nproc_per_node=1 -m training.ppo.launch \
    --config training/ppo/cfgs/sd3_1024.yaml

# FLUX.1-dev. This uses the checked-in exps/flux/flux-start.pkl table.
./train_flux.sh

# Or launch manually from the checked-in initial table. The FLUX model
# reference is read from the predictor metadata in exps/flux/flux-start.pkl.
python -m training.ppo.launch \
    --config training/ppo/cfgs/flux_dev.yaml
```

Note: RDPO training was performed using a single NVIDIA H200 GPU. Refer to launch.sh for full scripts.

Inference

To generate images with an EPD-Solver, use the examples below (replace checkpoint paths with your own exports as needed):

```bash
## SD1.5
MASTER_PORT=12345 python sample.py \
    --predictor_path exps/sd15/sd15-best.pkl \
    --prompt-file src/prompts/test.txt \
    --seeds "0-19" \
    --batch 4 \
    --outdir samples/sd15

## SD3-Medium
python sample_sd3.py --predictor exps/sd3-1024/sd3-1024-best.pkl \
    --seeds "0" \
    --outdir samples/sd3 \
    --prompt "..."

## FLUX.1-dev EPD
python sample_flux.py --predictor exps/flux/flux-best.pkl \
    --model-id /path/to/local/FLUX.1-dev \
    --prompt-file src/prompts/test.txt \
    --seeds "0" \
    --outdir samples/flux
```
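In the commands above, a spec such as "0-19" yields 20 images. The sketch below shows how such a spec expands. It assumes the comma-separated, inclusive-range convention common to EDM-derived sampling scripts. The script's own parser is not shown in this README, and `parse_seeds` is a hypothetical name.

```python
import re


def parse_seeds(spec):
    """Expand a seed spec such as "0-19", "0" or "0,3,5-7" into a list of ints."""
    seeds = []
    for part in spec.split(","):
        part = part.strip()
        match = re.fullmatch(r"(\d+)-(\d+)", part)
        if match:
            start, end = int(match.group(1)), int(match.group(2))
            seeds.extend(range(start, end + 1))  # ranges are inclusive
        elif part:
            seeds.append(int(part))
    return seeds
```

Because each seed produces one image, the default 0-63 of sample.py corresponds to 64 images.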

Evaluation

We provide six metrics to evaluate generated images: HPSv2.1, PickScore, ImageReward, CLIP, Aesthetic, and MPS. Please refer to the evaluation script section in launch.sh.

Parameter Description

Sampling (sample.py)

| Parameter | Default | What it controls |
|---|---|---|
| predictor_path | required | EPD predictor snapshot (.pkl); numeric IDs auto-resolve to the latest matching checkpoint in ./exps. |
| model_path | None | Optional backbone checkpoint override; for SD3/FLUX this maps to model_name_or_path. |
| max_batch_size (--batch) | 64 | Per-process batch size; seeds are split across ranks. |
| seeds | 0-63 | Seed list or range; determines how many images are generated. |
| prompt | None | Single text prompt for all seeds; if omitted, falls back to prompt-file, or to MS-COCO eval captions when dataset_name=ms_coco. |
| prompt-file | None | Text or CSV file (column text) with prompts; used when prompt is empty. |
| backend | Predictor metadata | Override backbone (ldm/sd3/flux); defaults to what is stored in the predictor. |
| backend-config | None | JSON object overriding backend options (e.g., SD3/FLUX resolution, torch_dtype, offload, token). |
| use_fp16 | False | Reserved flag for mixed precision (not currently wired). |
| return_inters | False | Reserved flag for saving intermediates (not currently wired). |
| outdir | Auto (./samples/{dataset} or ./samples/grids/{dataset}) | Output root; falls back to a derived path when unset. |
| grid | False | Save a grid per batch instead of per-image files. |
| subdirs | True | When saving per-image files, create 1k-chunked subfolders. |
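The table says max_batch_size is per process and that seeds are split across ranks. One common way to do this split is the pattern used by EDM-style sampling scripts, sketched below. Whether sample.py follows exactly this scheme is an assumption, and `split_seeds` is a hypothetical name.

```python
import numpy as np


def split_seeds(seeds, max_batch_size, world_size, rank):
    """Return the seed batches handled by one rank (illustrative sketch)."""
    if not seeds:
        return []
    # Enough batches that none exceeds max_batch_size, rounded up to a
    # multiple of world_size so every rank gets the same number of batches.
    per_rank = (len(seeds) - 1) // (max_batch_size * world_size) + 1
    all_batches = np.array_split(np.asarray(seeds), per_rank * world_size)
    # Ranks take batches in a round-robin (strided) order.
    return [batch.tolist() for batch in all_batches[rank::world_size]]
```

For example, 20 seeds with --batch 4 on two processes produce six batches of three or four seeds, three per rank.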

Sampling (sample_sd3.py)

| Parameter | Default | What it controls |
|---|---|---|
| predictor | required | SD3 EPD predictor snapshot (.pkl). |
| seeds | 0-3 | Seed list or range; determines how many images are generated. |
| prompt | None | Single prompt for all seeds; if empty, uses prompt-file or falls back to empty prompts. |
| prompt-file | None | Text/CSV file with prompts; repeated to match the number of seeds. |
| outdir | ./samples/sd3_epd | Output directory. |
| grid | False | Save a grid per batch. |
| max-batch-size | 4 | Per-batch sample count (--max-batch-size). |
| resolution | Predictor/backend config (512 or 1024) | Optional override; must match predictor metadata if set. |
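The prompt-file behavior described in these tables (plain text or CSV with a text column, repeated to match the number of seeds) can be sketched with the standard library. Several details are our assumptions: one prompt per line, skipping blank lines in text files, and cyclic repetition. The helper names are hypothetical.

```python
import csv
from itertools import cycle, islice
from pathlib import Path


def load_prompts(prompt_file):
    """Read prompts from a .csv file (column "text") or a plain .txt file."""
    path = Path(prompt_file)
    if path.suffix.lower() == ".csv":
        with path.open(newline="", encoding="utf-8") as fh:
            return [row["text"] for row in csv.DictReader(fh)]
    # Assumption: one prompt per line, blank lines ignored.
    lines = path.read_text(encoding="utf-8").splitlines()
    return [line.strip() for line in lines if line.strip()]


def prompts_for_seeds(prompts, num_seeds):
    """Repeat the prompt list cyclically until there is one prompt per seed."""
    if not prompts:
        return [""] * num_seeds  # mirrors the "empty prompts" fallback
    return list(islice(cycle(prompts), num_seeds))
```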

Sampling (sample_flux.py)

| Parameter | Default | What it controls |
|---|---|---|
| predictor | required | FLUX EPD predictor snapshot (.pkl). |
| model-id | Predictor metadata or FLUX_MODEL_PATH | FLUX.1-dev repo id or local snapshot path. |
| seeds | 0 | Seed list or range; determines how many images are generated. |
| prompt | None | Single prompt for all seeds. |
| prompt-file | None | Text/CSV file with prompts; repeated to match the number of seeds. |
| outdir | ./samples/flux_epd | Output directory. |
| max-batch-size | 1 | Per-batch sample count. |

FLUX.1-dev Notes

  • Supported FLUX variant: black-forest-labs/FLUX.1-dev.
  • FLUX support is fixed to 1024x1024, schedule_type=flowmatch, and embedded guidance scale 3.5.
  • The sampling scripts resolve FLUX locally first via FLUX_MODEL_PATH or the Hugging Face cache, then fall back to the Hugging Face repo id. Set FLUX_ALLOW_REMOTE=1 when intentionally loading the gated Hugging Face repo instead of a local snapshot.
  • exps/flux/flux-best.pkl is the released RL-best inference checkpoint. FLUX training starts from exps/flux/flux-start.pkl, a manually initialized start table that train_flux.sh uses directly before PPO launch.
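The FLUX lookup order described above is: FLUX_MODEL_PATH, then the Hugging Face cache, then the gated repo id. It can be sketched as follows. The cache layout used here (models--org--name with refs/main and snapshots/<commit>) and the HF_HOME / HF_HUB_CACHE variables are standard Hugging Face Hub conventions. Requiring FLUX_ALLOW_REMOTE=1 before returning the remote repo id is our reading of these notes, and the function name is hypothetical.

```python
import os
from pathlib import Path

FLUX_REPO_ID = "black-forest-labs/FLUX.1-dev"


def resolve_flux_source(env=None, hub_cache=None):
    """Pick where to load FLUX.1-dev from: local path, HF cache, or remote repo id."""
    env = os.environ if env is None else env

    # 1) Explicit local snapshot.
    local = env.get("FLUX_MODEL_PATH")
    if local and Path(local).is_dir():
        return local

    # 2) Hugging Face cache: <cache>/models--org--name/{refs/main, snapshots/<commit>}.
    if hub_cache is None:
        hf_home = env.get("HF_HOME", str(Path.home() / ".cache" / "huggingface"))
        hub_cache = env.get("HF_HUB_CACHE", str(Path(hf_home) / "hub"))
    repo_dir = Path(hub_cache) / ("models--" + FLUX_REPO_ID.replace("/", "--"))
    ref = repo_dir / "refs" / "main"
    if ref.is_file():
        snapshot = repo_dir / "snapshots" / ref.read_text().strip()
        if snapshot.is_dir():
            return str(snapshot)

    # 3) Gated remote repo, only when explicitly allowed (assumption).
    if env.get("FLUX_ALLOW_REMOTE") == "1":
        return FLUX_REPO_ID
    raise FileNotFoundError(
        "FLUX.1-dev not found locally; set FLUX_MODEL_PATH or FLUX_ALLOW_REMOTE=1"
    )
```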

Solver metadata (read from predictor checkpoints)

| Parameter | Default source | Notes |
|---|---|---|
| dataset_name | Predictor ckpt | Dataset tag (e.g., ms_coco); drives prompt fallback and output paths. |
| backend / backend_config | Predictor ckpt | Backbone type plus stored options (resolution, flow-match params, offload/token settings for SD3/FLUX, etc.). |
| num_steps | Predictor ckpt | Inference steps; base NFE is 2*(num_steps-1) (minus one evaluation when afs=True, doubled again for CFG in ms_coco). |
| num_points | Predictor ckpt | Number of intermediate points per step; used for NFE reporting and outdir naming. |
| guidance_type / guidance_rate | Predictor ckpt | CFG sampling (e.g., 4.5 for SD3 PPO configs, 7.5 for SD1.5). |
| schedule_type / schedule_rho | Predictor ckpt | flowmatch for SD3/FLUX, discrete for SD1.5. |
| sigma_min / sigma_max | Predictor or backend | Noise range passed to the scheduler (falls back to backend defaults when unset). |
| flowmatch_mu / flowmatch_shift | Predictor or backend | Flow-matching parameters used by SD3/FLUX schedules. |
| afs, max_order, predict_x0, lower_order_final | Predictor ckpt | EPD/DPM solver behavior flags. |
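Two rows of this table involve simple arithmetic, sketched below with hypothetical helper names.

The NFE rule comes from the num_steps row. In it, the parallel intermediate points of one step count as a single evaluation. Applying the AFS saving before the CFG doubling is our assumption.

For flowmatch_shift / flowmatch_mu, the two functions are the standard time-shift forms used by diffusers' FlowMatchEulerDiscreteScheduler: a static shift, and a mu-based dynamic shift. Exactly how this repository maps its parameters onto them is an assumption.

```python
import math


def epd_nfe(num_steps, afs=False, cfg=False):
    """NFE per the num_steps row (parallel points within a step count as one)."""
    nfe = 2 * (num_steps - 1)
    if afs:
        nfe -= 1  # afs=True saves one evaluation
    if cfg:
        nfe *= 2  # conditional + unconditional passes
    return nfe


def shift_sigma(sigma, shift):
    """Static flow-matching time shift."""
    return shift * sigma / (1.0 + (shift - 1.0) * sigma)


def shift_sigma_mu(sigma, mu):
    """Exponential (mu-based) shift; equals shift_sigma with shift = exp(mu)."""
    return math.exp(mu) / (math.exp(mu) + (1.0 / sigma - 1.0))
```

For example, num_steps=5 gives a base NFE of 8, or 7 with afs=True.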

RDPO Training configs (training/ppo/cfgs/*.yaml)

| Key | sd3_512 | sd3_1024 | sd15 | flux_dev | Purpose |
|---|---|---|---|---|---|
| data.predictor_snapshot | exps/sd3-512/...-distilled.pkl | exps/sd3-1024/...-distilled.pkl | exps/sd15/...-distilled.pkl | exps/flux/flux-start.pkl | Starting EPD predictor. |
| model.backend | sd3 | sd3 | ldm | flux | Backbone family used during RL. |
| model.resolution | 512 | 1024 | n/a | 1024 | Training resolution for flow-matching backbones. |
| model.schedule_type | flowmatch | flowmatch | discrete | flowmatch | Diffusion schedule during RL. |
| model.guidance_rate | 4.5 | 4.5 | 7.5 | 3.5 | Guidance scale used while training the solver. |
| ppo.rollout_batch_size | 16 | 8 | 8 | 8 | Samples per PPO rollout. |
| ppo.dirichlet_concentration | 10 | 10 | 20 | 10 | Dirichlet policy concentration. |
| reward.batch_size | 4 | 4 | 4 | 1 | Reward evaluation batch size. |
| reward.multi.weights | hps:1.0 (others 0) | same | same | same | Per-head reward weights. |

Shared defaults across configs:

  • model.dataset_name=ms_coco, model.guidance_type=cfg, model.schedule_rho=1.0; model.num_steps and model.num_points are left null so they are inherited from the predictor.
  • reward.type=multi, reward.enable_amp=true, reward.weights_path=weights/HPS_v2.1_compressed.pt.
  • ppo.learning_rate=7e-5, ppo.minibatch_size=4, ppo.ppo_epochs=1, ppo.rloo_k=4, ppo.clip_range=0.2, ppo.kl_coef=0.0, ppo.entropy_coef=0.0, ppo.max_grad_norm=1.0, ppo.decode_rgb=true.
  • logging.log_interval=1, run.output_root=exps, run.seed=0.

Config-specific values: the SD configs use ppo.steps=99999 and logging.save_interval=500; flux_dev sets sigma_min=0.001, sigma_max=1.0, ppo.steps=20000, and logging.save_interval=200.
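The following plain-Python sketch shows how the shared defaults and per-config values combine. It uses a subset of the keys listed above. The real launcher reads the YAML files, so this illustrates only the layering, not the loader. Two assumptions apply:

- `rollout_layout` assumes each prompt receives ppo.rloo_k samples within a rollout.
- For brevity, the SD values of ppo.steps and logging.save_interval serve as the base.

All helper names are hypothetical.

```python
# Shared defaults (subset of the list above); SD values of ppo.steps and
# logging.save_interval are used as the base for brevity.
SHARED = {
    "model.dataset_name": "ms_coco",
    "model.guidance_type": "cfg",
    "reward.type": "multi",
    "ppo.learning_rate": 7e-5,
    "ppo.minibatch_size": 4,
    "ppo.rloo_k": 4,
    "ppo.clip_range": 0.2,
    "ppo.steps": 99999,
    "logging.save_interval": 500,
}

# Per-config values taken from the table above.
PER_CONFIG = {
    "sd3_512": {"model.backend": "sd3", "model.resolution": 512, "model.guidance_rate": 4.5,
                "ppo.rollout_batch_size": 16, "ppo.dirichlet_concentration": 10},
    "sd3_1024": {"model.backend": "sd3", "model.resolution": 1024, "model.guidance_rate": 4.5,
                 "ppo.rollout_batch_size": 8, "ppo.dirichlet_concentration": 10},
    "sd15": {"model.backend": "ldm", "model.guidance_rate": 7.5,
             "ppo.rollout_batch_size": 8, "ppo.dirichlet_concentration": 20},
    "flux_dev": {"model.backend": "flux", "model.resolution": 1024, "model.guidance_rate": 3.5,
                 "ppo.rollout_batch_size": 8, "ppo.dirichlet_concentration": 10,
                 "ppo.steps": 20000, "logging.save_interval": 200},
}


def effective_config(name):
    """Shared defaults overlaid with one config's specific values."""
    cfg = dict(SHARED)
    cfg.update(PER_CONFIG[name])
    return cfg


def rollout_layout(cfg):
    """Assumption: each prompt gets ppo.rloo_k samples within a rollout."""
    rollout, k, mb = cfg["ppo.rollout_batch_size"], cfg["ppo.rloo_k"], cfg["ppo.minibatch_size"]
    if rollout % k or rollout % mb:
        raise ValueError("rollout_batch_size must be divisible by rloo_k and minibatch_size")
    return {"prompts": rollout // k, "minibatches": rollout // mb}
```

Under that assumption, sd3_512 covers 4 prompts per rollout and flux_dev covers 2.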

Performance Highlights

[Figure: T2I performance]

Citation

```bibtex
@misc{wang2025paralleldiffusionsolverresidual,
      title={Parallel Diffusion Solver via Residual Dirichlet Policy Optimization},
      author={Ruoyu Wang and Ziyu Li and Beier Zhu and Liangyu Yuan and Hanwang Zhang and Xun Yang and Xiaojun Chang and Chi Zhang},
      year={2025},
      eprint={2512.22796},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2512.22796},
}

@inproceedings{zhu2025distilling,
      title={Distilling Parallel Gradients for Fast ODE Solvers of Diffusion Models},
      author={Beier Zhu and Ruoyu Wang and Tong Zhao and Hanwang Zhang and Chi Zhang},
      booktitle={International Conference on Computer Vision (ICCV)},
      year={2025}
}
```