May 6, 2026
[TPAMI] Parallel Diffusion Solver via Residual Dirichlet Policy Optimization
Our paper has been accepted to IEEE Transactions on Pattern Analysis and Machine Intelligence (TPAMI).
Note: This work extends EPD-Solver (ICCV 2025). You are currently on the default branch, EPD-Solver++. For those interested in our previous work, the original ICCV 2025 implementation is available in the EPD-Solver branch.
Algorithm Overview
EPD-Solver (Ensemble Parallel Direction) mitigates truncation errors in diffusion sampling by leveraging parallel gradient evaluations within a single step. We introduce a novel two-stage optimization framework that aligns the solver with human preferences without fine-tuning the heavy diffusion backbone.
1. Parallel Gradient Estimation
Instead of sequential evaluations, EPD-Solver computes gradients at multiple learned intermediate timesteps in parallel. By aggregating these gradients via a simplex-weighted sum, it achieves a higher-order approximation of the integral direction with negligible latency overhead on modern GPUs.
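
A minimal sketch of the idea above, assuming a scalar ODE dx/dt = f(x, t) and intermediate states reached by an Euler move from the current point. The name `epd_step` and its arguments are illustrative, not the repository's API. In the real solver the intermediate timesteps and weights are learned, and the gradient evaluations would be batched on the GPU rather than looped.

```python
import math

def epd_step(f, x, t, t_next, taus, weights):
    """One illustrative parallel-direction step for dx/dt = f(x, t).

    taus:    intermediate timesteps between t and t_next (hypothetical values)
    weights: non-negative weights on the simplex (must sum to 1)
    """
    if any(w < 0 for w in weights) or abs(sum(weights) - 1.0) > 1e-9:
        raise ValueError("weights must lie on the probability simplex")
    d0 = f(x, t)  # gradient at the current point
    # Each intermediate state depends only on (x, t, d0), so these
    # evaluations are independent and could run as a single batch.
    grads = [f(x + (tau - t) * d0, tau) for tau in taus]
    direction = sum(w * g for w, g in zip(weights, grads))
    return x + (t_next - t) * direction

# Toy check on dx/dt = -x, whose exact solution is x(t) = exp(-t).
decay = lambda x, t: -x
exact = math.exp(-0.5)
euler = 1.0 + 0.5 * decay(1.0, 0.0)
epd = epd_step(decay, 1.0, 0.0, 0.5, taus=[0.2, 0.3], weights=[0.5, 0.5])
print(abs(euler - exact), abs(epd - exact))
```

On this toy problem the weighted direction lands closer to the exact solution than a plain Euler step of the same size, which is the effect the aggregation is meant to achieve.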
2. Two-Stage Optimization
- Stage 1: Distillation-Based Initialization. We first distill a few-step student solver by minimizing the trajectory error against a high-fidelity teacher (e.g., DPM-Solver). This provides a robust initialization that captures the trajectory curvature.
- Stage 2: Residual Dirichlet Policy Optimization (RDPO). We reformulate the solver as a stochastic Dirichlet policy. Using a lightweight PPO variant, we fine-tune the solver's low-dimensional parameters (time segments and weights) to maximize human-aligned rewards (e.g., HPSv2, ImageReward). This ensures high perceptual quality and semantic alignment even at low NFEs (e.g., 20 steps).
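
To make the "stochastic Dirichlet policy" in Stage 2 concrete, here is a small sketch of sampling simplex weights around a Stage 1 initialization. It assumes the Dirichlet parameters are `concentration * base_weights`, so the policy mean equals the distilled weights. How the repository actually parameterizes the residual is not shown in this README, and the names `sample_dirichlet`, `sample_weights`, and the `base` values are hypothetical.

```python
import random

def sample_dirichlet(alphas, rng):
    """Draw one point on the simplex from Dirichlet(alphas) via Gamma draws."""
    draws = [rng.gammavariate(a, 1.0) for a in alphas]
    total = sum(draws)
    return [d / total for d in draws]

def sample_weights(base_weights, concentration, rng):
    """Sample solver weights centered on the distilled (Stage 1) weights.

    Assumption: alphas = concentration * base_weights, so the policy mean
    equals base_weights and a larger concentration means less exploration.
    """
    alphas = [concentration * w for w in base_weights]
    return sample_dirichlet(alphas, rng)

rng = random.Random(0)
base = [0.3, 0.7]  # hypothetical distilled weights for two parallel points
samples = [sample_weights(base, 10.0, rng) for _ in range(2000)]
mean_first = sum(s[0] for s in samples) / len(samples)
print(round(mean_first, 2))
```

Every sample stays on the simplex by construction, so the aggregated direction remains a convex combination of the parallel gradients during exploration.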
Installation
- Create Environment
conda env create -f environment.yml -n epd
conda activate epd
- Install Dependencies
# Core dependencies
pip install omegaconf gdown lightning fairscale piq accelerate timm einops kornia HPSv2
pip install --upgrade diffusers[torch]
# CLIP & Transformers
pip install git+https://github.com/openai/CLIP.git
pip install transformers
pip install -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
- Setup Environment Variables
Important: Run this before training or inference.
export PYTHONPATH="$PWD/training/ppo/reward_models/HPSv2:$PWD/src/taming-transformers:$PYTHONPATH"
FLUX.1-dev is a gated Hugging Face model. For sampling, set FLUX_MODEL_PATH to a local snapshot when available. For training, exps/flux/flux-start.pkl stores the default black-forest-labs/FLUX.1-dev reference; make sure it is available in your Hugging Face cache, or set FLUX_ALLOW_REMOTE=1 after authentication.
Model Zoo & Checkpoints
We provide pre-trained predictors (Stage 1: Distilled) and RL-finetuned solvers (Stage 2: Best).
Stage 2 models are optimized using Residual Dirichlet Policy Optimization for better human preference alignment.
For FLUX.1-dev, flux-start.pkl is a manually initialized start table for PPO, not a distilled checkpoint.
| Model | Resolution | Type | Download |
|---|---|---|---|
| Stable Diffusion v1.5 | 512x512 | RL-Best (Stage 2) | sd15-best.pkl |
| Stable Diffusion v1.5 | 512x512 | Distilled (Stage 1) | sd15-distilled.pkl |
| SD3-Medium | 1024x1024 | RL-Best (Stage 2) | sd3-1024-best.pkl |
| SD3-Medium | 1024x1024 | Distilled (Stage 1) | sd3-1024-distilled.pkl |
| SD3-Medium | 512x512 | RL-Best (Stage 2) | sd3-512-best.pkl |
| SD3-Medium | 512x512 | Distilled (Stage 1) | sd3-512-distilled.pkl |
| FLUX.1-dev | 1024x1024 | RL-Best (Stage 2) | flux-best.pkl |
| FLUX.1-dev | 1024x1024 | Start (manual init) | flux-start.pkl |
We also provide a detailed guide for each part below.
RDPO Training
To train EPD-Solver using RDPO:
# Available configs: sd15.yaml, sd3_512.yaml, sd3_1024.yaml, flux_dev.yaml
torchrun --master_port=12345 --nproc_per_node=1 -m training.ppo.launch \
--config training/ppo/cfgs/sd3_1024.yaml
# FLUX.1-dev. This uses the checked-in exps/flux/flux-start.pkl table.
./train_flux.sh
# Or launch manually from the checked-in initial table. The FLUX model
# reference is read from the predictor metadata in exps/flux/flux-start.pkl.
python -m training.ppo.launch \
--config training/ppo/cfgs/flux_dev.yaml
Note: RDPO training was performed using a single NVIDIA H200 GPU. Refer to launch.sh for full scripts.
Inference
To generate images with an EPD-Solver, use the examples below (replace checkpoint paths with your own exports as needed):
## SD1.5
MASTER_PORT=12345 python sample.py \
--predictor_path exps/sd15/sd15-best.pkl \
--prompt-file src/prompts/test.txt \
--seeds "0-19" \
--batch 4 \
--outdir samples/sd15
## SD3-Medium
python sample_sd3.py --predictor exps/sd3-1024/sd3-1024-best.pkl \
--seeds "0" \
--outdir samples/sd3 \
--prompt "..."
## FLUX.1-dev EPD
python sample_flux.py --predictor exps/flux/flux-best.pkl \
--model-id /path/to/local/FLUX.1-dev \
--prompt-file src/prompts/test.txt \
--seeds "0" \
--outdir samples/flux
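
The `--seeds` values in the commands above are given as a list or range. As a rough illustration of that syntax (not the scripts' actual parser), the hypothetical helper below expands such a string, assuming ranges are inclusive, which matches a default like `0-63` yielding 64 images.

```python
def parse_seeds(spec):
    """Expand a seed spec such as "0-19" or "0,2,5-7" into a list of ints."""
    seeds = []
    for part in spec.split(","):
        part = part.strip()
        if not part:
            continue
        if "-" in part:
            lo, hi = part.split("-", 1)
            seeds.extend(range(int(lo), int(hi) + 1))  # assumed inclusive
        else:
            seeds.append(int(part))
    return seeds

print(len(parse_seeds("0-19")))
print(parse_seeds("0,2,5-7"))
```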
Evaluation
We provide six metrics to evaluate generated images: HPSv2.1, PickScore, ImageReward, CLIP, Aesthetic, and MPS. Please refer to the evaluation script section in launch.sh.
Parameter Description
Sampling (sample.py)
| Parameter | Default | What it controls |
|---|---|---|
| predictor_path | required | EPD predictor snapshot (.pkl); numeric IDs auto-resolve to the latest matching checkpoint in ./exps. |
| model_path | None | Optional backbone checkpoint override; for SD3/FLUX this maps to model_name_or_path. |
| max_batch_size (--batch) | 64 | Per-process batch size; seeds are split across ranks. |
| seeds | 0-63 | Seed list or range; determines how many images are generated. |
| prompt | None | Single text prompt for all seeds; if omitted, falls back to prompt-file or MS-COCO eval captions for dataset_name=ms_coco. |
| prompt-file | None | Text or CSV (column text) with prompts; used when prompt is empty. |
| backend | Predictor metadata | Override backbone (ldm/sd3/flux); defaults to what is stored in the predictor. |
| backend-config | None | JSON object overriding backend options (e.g., SD3/FLUX resolution, torch_dtype, offload, token). |
| use_fp16 | False | Reserved flag for mixed precision (not currently wired). |
| return_inters | False | Reserved flag for saving intermediates (not currently wired). |
| outdir | Auto (./samples/{dataset} or ./samples/grids/{dataset}) | Output root; falls back to a derived path when unset. |
| grid | False | Save a grid per batch instead of per-image files. |
| subdirs | True | When saving per-image files, create 1k-chunked subfolders. |
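
The table says seeds are split across ranks and that per-image files go into 1k-chunked subfolders. The sketch below shows one common way to do both, assuming batches are dealt round-robin to ranks and that chunk folders are named after the first seed in each block of 1000. Both choices and the names `batches_for_rank` and `chunk_dir` are assumptions, not taken from sample.py.

```python
def batches_for_rank(seeds, max_batch_size, rank, world_size):
    """Chunk seeds into batches, then deal the batches round-robin to ranks."""
    batches = [seeds[i:i + max_batch_size]
               for i in range(0, len(seeds), max_batch_size)]
    return batches[rank::world_size]

def chunk_dir(seed):
    """Hypothetical folder name grouping seeds into blocks of 1000."""
    return f"{seed - seed % 1000:06d}"

seeds = list(range(10))
print(batches_for_rank(seeds, 4, rank=0, world_size=2))
print(batches_for_rank(seeds, 4, rank=1, world_size=2))
print(chunk_dir(1234))
```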
Sampling (sample_sd3.py)
| Parameter | Default | What it controls |
|---|---|---|
| predictor | required | SD3 EPD predictor snapshot (.pkl). |
| seeds | 0-3 | Seed list or range; determines how many images are generated. |
| prompt | None | Single prompt for all seeds; if empty, uses prompt-file or falls back to empty prompts. |
| prompt-file | None | Text/CSV file with prompts; repeats to match seeds length. |
| outdir | ./samples/sd3_epd | Output directory. |
| grid | False | Save a grid per batch. |
| max-batch-size | 4 | Per-batch sample count (--max-batch-size). |
| resolution | Predictor/back-end config (512 or 1024) | Optional override; must match predictor metadata if set. |
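
The prompt-file row notes that prompts repeat to match the number of seeds. A minimal sketch of that behavior, assuming cyclic repetition and an empty-prompt fallback when no prompts are given (the helper name `match_prompts` is hypothetical):

```python
from itertools import cycle, islice

def match_prompts(prompts, num_seeds):
    """Repeat prompts cyclically until there is one per seed."""
    if not prompts:
        return [""] * num_seeds  # fall back to empty prompts
    return list(islice(cycle(prompts), num_seeds))

print(match_prompts(["a cat", "a dog"], 5))
```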
Sampling (sample_flux.py)
| Parameter | Default | What it controls |
|---|---|---|
| predictor | required | FLUX EPD predictor snapshot (.pkl). |
| model-id | Predictor metadata or FLUX_MODEL_PATH | FLUX.1-dev repo id or local snapshot path. |
| seeds | 0 | Seed list or range; determines how many images are generated. |
| prompt | None | Single prompt for all seeds. |
| prompt-file | None | Text/CSV file with prompts; repeats to match seeds length. |
| outdir | ./samples/flux_epd | Output directory. |
| max-batch-size | 1 | Per-batch sample count. |
FLUX.1-dev Notes
- Supported FLUX variant: black-forest-labs/FLUX.1-dev.
- FLUX support is fixed to 1024x1024, schedule_type=flowmatch, and embedded guidance scale 3.5.
- The sampling scripts resolve FLUX locally first via FLUX_MODEL_PATH or the Hugging Face cache, then fall back to the Hugging Face repo id. Set FLUX_ALLOW_REMOTE=1 when intentionally loading the gated Hugging Face repo instead of a local snapshot.
- exps/flux/flux-best.pkl is the released RL-best inference checkpoint. FLUX training starts from exps/flux/flux-start.pkl, a manually initialized start table that train_flux.sh uses directly before PPO launch.
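
The local-first resolution order described in these notes can be sketched as follows. This is an illustration under stated assumptions, not the scripts' code: the function name `resolve_flux_model` and the `cached_snapshot` argument are hypothetical, the cache lookup itself is left out, and raising an error when remote loading is not enabled is an assumed behavior.

```python
DEFAULT_REPO_ID = "black-forest-labs/FLUX.1-dev"

def resolve_flux_model(env, cached_snapshot=None):
    """Pick a FLUX model reference, preferring local sources.

    env:             mapping of environment variables (e.g. os.environ)
    cached_snapshot: path to a Hugging Face cache snapshot, if one was found
                     (how the cache is searched is outside this sketch)
    """
    local = env.get("FLUX_MODEL_PATH")
    if local:
        return local
    if cached_snapshot:
        return cached_snapshot
    if env.get("FLUX_ALLOW_REMOTE") == "1":
        return DEFAULT_REPO_ID  # gated repo; requires prior authentication
    raise RuntimeError(
        "No local FLUX.1-dev snapshot found; set FLUX_MODEL_PATH or "
        "FLUX_ALLOW_REMOTE=1 after authenticating with Hugging Face."
    )

print(resolve_flux_model({"FLUX_MODEL_PATH": "/path/to/local/FLUX.1-dev"}))
print(resolve_flux_model({"FLUX_ALLOW_REMOTE": "1"}))
```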
Solver metadata (read from predictor checkpoints)
| Parameter | Default source | Notes |
|---|---|---|
| dataset_name | Predictor ckpt | Dataset tag (e.g., ms_coco); drives prompt fallback and output paths. |
| backend / backend_config | Predictor ckpt | Backbone type plus stored options (resolution, flow-match params, offload/token settings for SD3/FLUX, etc.). |
| num_steps | Predictor ckpt | Inference steps; base NFE 2*(num_steps-1) (minus one eval when afs=True, doubled again for CFG in ms_coco). |
| num_points | Predictor ckpt | Number of intermediate points per step; used for NFE reporting/outdir naming. |
| guidance_type / guidance_rate | Predictor ckpt | CFG sampling (e.g., 4.5 for SD3 PPO configs, 7.5 for SD1.5). |
| schedule_type / schedule_rho | Predictor ckpt | flowmatch for SD3/FLUX, discrete for SD1.5. |
| sigma_min / sigma_max | Predictor or backend | Noise range passed to scheduler (falls back to backend defaults when unset). |
| flowmatch_mu / flowmatch_shift | Predictor or backend | Flow-matching parameters used by SD3/FLUX schedules. |
| afs, max_order, predict_x0, lower_order_final | Predictor ckpt | EPD/DPM solver behavior flags. |
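
The NFE rule in the num_steps row can be written out as a short calculation. The sketch assumes the afs deduction is applied before the CFG doubling; that ordering and the helper name `estimate_nfe` are assumptions rather than something the table states.

```python
def estimate_nfe(num_steps, afs=False, uses_cfg=False):
    """Estimate function evaluations from the rule in the table above."""
    nfe = 2 * (num_steps - 1)  # base count
    if afs:
        nfe -= 1               # one evaluation saved when afs=True
    if uses_cfg:
        nfe *= 2               # doubled for classifier-free guidance
    return nfe

print(estimate_nfe(11))
print(estimate_nfe(11, afs=True))
print(estimate_nfe(11, afs=True, uses_cfg=True))
```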
RDPO Training configs (training/ppo/cfgs/*.yaml)
| Key | sd3_512 | sd3_1024 | sd15 | flux_dev | Purpose |
|---|---|---|---|---|---|
| data.predictor_snapshot | exps/sd3-512/...-distilled.pkl | exps/sd3-1024/...-distilled.pkl | exps/sd15/...-distilled.pkl | exps/flux/flux-start.pkl | Starting EPD predictor. |
| model.backend | sd3 | sd3 | ldm | flux | Backbone family used during RL. |
| model.resolution | 512 | 1024 | n/a | 1024 | Training resolution for flow-matching backbones. |
| model.schedule_type | flowmatch | flowmatch | discrete | flowmatch | Diffusion schedule during RL. |
| model.guidance_rate | 4.5 | 4.5 | 7.5 | 3.5 | Guidance scale used while training the solver. |
| ppo.rollout_batch_size | 16 | 8 | 8 | 8 | Samples per PPO rollout. |
| ppo.dirichlet_concentration | 10 | 10 | 20 | 10 | Dirichlet policy concentration. |
| reward.batch_size | 4 | 4 | 4 | 1 | Reward evaluation batch size. |
| reward.multi.weights | hps:1.0 (others 0) | same | same | same | Per-head reward weights. |
Shared defaults across configs: model.dataset_name=ms_coco, model.guidance_type=cfg, model.schedule_rho=1.0, model.num_steps/num_points left null to inherit predictors, reward.type=multi, reward.enable_amp=true, reward.weights_path=weights/HPS_v2.1_compressed.pt, ppo.learning_rate=7e-5, ppo.minibatch_size=4, ppo.ppo_epochs=1, ppo.rloo_k=4, ppo.clip_range=0.2, ppo.kl_coef=0.0, ppo.entropy_coef=0.0, ppo.max_grad_norm=1.0, ppo.decode_rgb=true, logging.log_interval=1, run.output_root=exps, run.seed=0. The SD configs use ppo.steps=99999 and logging.save_interval=500; flux_dev sets sigma_min=0.001, sigma_max=1.0, ppo.steps=20000, and logging.save_interval=200.
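
For readers unfamiliar with the ppo.rloo_k and ppo.clip_range settings, the sketch below shows the standard textbook forms of a leave-one-out baseline and the PPO clipped surrogate. It assumes rloo_k is the number of samples sharing a baseline, and it is not a copy of the repository's training loop. The function names are illustrative.

```python
def rloo_advantages(rewards):
    """Leave-one-out advantages: each reward minus the mean of the others."""
    k = len(rewards)
    if k < 2:
        raise ValueError("leave-one-out needs at least two samples")
    total = sum(rewards)
    return [r - (total - r) / (k - 1) for r in rewards]

def clipped_surrogate(ratio, advantage, clip_range=0.2):
    """PPO clipped objective for one sample (to be maximized)."""
    clipped = max(1.0 - clip_range, min(ratio, 1.0 + clip_range))
    return min(ratio * advantage, clipped * advantage)

advs = rloo_advantages([1.0, 2.0, 3.0, 4.0])  # k = 4, as in ppo.rloo_k=4
print(advs)
print(clipped_surrogate(1.5, 1.0), clipped_surrogate(0.5, -1.0))
```

With a leave-one-out baseline the advantages within a group sum to zero, so no separate value network is needed to reduce variance.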
Citation
@misc{wang2025paralleldiffusionsolverresidual,
title={Parallel Diffusion Solver via Residual Dirichlet Policy Optimization},
author={Ruoyu Wang and Ziyu Li and Beier Zhu and Liangyu Yuan and Hanwang Zhang and Xun Yang and Xiaojun Chang and Chi Zhang},
year={2025},
eprint={2512.22796},
archivePrefix={arXiv},
primaryClass={cs.CV},
url={https://arxiv.org/abs/2512.22796},
}
@inproceedings{zhu2025distilling,
title={Distilling Parallel Gradients for Fast ODE Solvers of Diffusion Models},
author={Zhu, Beier and Wang, Ruoyu and Zhao, Tong and Zhang, Hanwang and Zhang, Chi},
booktitle={International Conference on Computer Vision (ICCV)},
year={2025}
}