Dual Diffusion for Unified Image Generation & Understanding

March 3, 2025

paper | webpage

1. Environment

We provide a requirements file for installing the environment with pip:

pip install -r requirements.txt

2. Inference of pretrained model

The pretrained checkpoints can be downloaded from the links below:

| Name | Specification | Link |
| --- | --- | --- |
| 512 Base | Dual-diffusion pretrained; supports generation and captioning | model |
| 512 SFT | SFT on LLaVA data; supports generation and VQA | model |

After downloading the checkpoints, check the Jupyter notebook notebooks/demo.ipynb for example usage.

Minimal working example:

import torch

from sd3_modules.dual_diff_pipeline import DualDiffSD3Pipeline

# Load the 512 base checkpoint in bfloat16 and move it to the GPU.
dual_diff_pipe = DualDiffSD3Pipeline.from_pretrained("./pretrained_models/dual_diff_sd3_512_base", torch_dtype=torch.bfloat16).to('cuda')
# Generate one 512x512 image from a text prompt.
imgs = dual_diff_pipe(
    prompt="A gourmet hamburger set on a rustic wooden table. The burger is made with a perfectly grilled, juicy beef patty topped with melted gourmet cheese, crispy bacon, fresh lettuce, ripe tomatoes, and caramelized onions.",
    height=512,
    width=512,
    num_images_per_prompt=1)

3. Data

We support two kinds of image-text data: wrapped webdataset (.tar) archives and unwrapped data. Unwrapped data requires a JSON file that stores the metadata as a list of dictionaries, like the one below:

[
    {
        "image_path": "images/img1.jpg",
        "ratio": 1.33,
        "height": 600,
        "width": 800,
        "caption": "A sunny day in the park.",
        "re_caption": "A bright, lively park scene."
    },
    {
        "image_path": "images/img2.jpg",
        "ratio": 0.75,
        "height": 400,
        "width": 300,
        "caption": "A night sky full of stars.",
        "re_caption": "The starry night illuminates the scene."
    }
]
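
Before training on unwrapped data, it can help to sanity-check the metadata file. The sketch below is not part of the repository: `REQUIRED_KEYS`, `make_entry`, and `validate_entry` are hypothetical names, and it assumes that `ratio` is width divided by height. That assumption is consistent with the two example entries above, but it is not stated explicitly.

```python
import json

# Keys taken from the example metadata above.
REQUIRED_KEYS = ("image_path", "ratio", "height", "width", "caption", "re_caption")


def make_entry(image_path, width, height, caption, re_caption):
    """Build one metadata dictionary (hypothetical helper).

    Assumption: ratio = width / height, rounded to two decimals.
    """
    return {
        "image_path": image_path,
        "ratio": round(width / height, 2),
        "height": height,
        "width": width,
        "caption": caption,
        "re_caption": re_caption,
    }


def validate_entry(entry, tol=0.01):
    """Return a list of problems found in one entry (empty if none)."""
    problems = [f"missing key: {k}" for k in REQUIRED_KEYS if k not in entry]
    if problems:
        return problems
    if entry["height"] <= 0 or entry["width"] <= 0:
        problems.append("height and width must be positive")
    elif abs(entry["ratio"] - entry["width"] / entry["height"]) > tol:
        problems.append("ratio does not match width / height")
    return problems


entries = [
    make_entry("images/img1.jpg", 800, 600, "A sunny day in the park.",
               "A bright, lively park scene."),
    make_entry("images/img2.jpg", 300, 400, "A night sky full of stars.",
               "The starry night illuminates the scene."),
]
assert all(validate_entry(e) == [] for e in entries)

# Serialize as a JSON list, the layout shown above.
metadata_json = json.dumps(entries, indent=4)
```

Writing `metadata_json` to a file yields a metadata file in the layout described above. Whether the training code reads any further fields is not covered by this sketch.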

The following datasets are used in our project:

| Name | Usage | Link |
| --- | --- | --- |
| Datacomp-recap | Base pretraining | data |
| ShareGPT4V pretrain | T5 embedding alignment, text diffusion training | data |
| LAION aesthetic | Image diffusion training | data |
| MidJourney 1.1M | Image diffusion training | data |
| LLaVA 1.5 | Text SFT | data |

4. Training

To train the model, an SD3-medium checkpoint is needed, which can be downloaded from here. In addition, we provide an aligned embedding that corresponds to the "mask token" in T5's vocabulary here.

Example training configurations are provided in the configs directory. We used 32 H100 GPUs for pretraining and 16 A100 GPUs for SFT.

  • Dual-diffusion training on image-text data (fill in the torchrun arguments with your machine's settings):
export OMP_NUM_THREADS=8
precision=bf16

torchrun --nnodes $WORKER_NUM \
    --node_rank $ID \
    --nproc_per_node $WORKER_GPU \
    --master_addr $WORKER_0_HOST \
    --master_port $port \
    train_dual_diffusion_sd3.py  \
        --config configs/dual_diff_pretrain.py \
        --results_dir results/ \
        --model_parallel_size 1 \
        --data_parallel h_sdp \
        --precision ${precision} --grad_precision fp32
  • Supervised fine-tuning (text diffusion with prompt + image diffusion) on a visual-instruction dataset (we used LLaVA 1.5's):
export OMP_NUM_THREADS=8
precision=bf16

torchrun --nnodes $WORKER_NUM \
    --node_rank $ID \
    --nproc_per_node $WORKER_GPU \
    --master_addr $WORKER_0_HOST \
    --master_port $port \
    train_dual_diffusion_sd3_sft.py  \
        --config configs/dual_diff_sft.py \
        --results_dir results/ \
        --model_parallel_size 1 \
        --data_parallel h_sdp \
        --precision ${precision} --grad_precision fp32 \
        --resume_t5 ${t5_mask_emb_pth}
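
On a single machine, the torchrun placeholders in the commands above reduce to a few fixed values. The sketch below assembles either command as an argument list. `build_torchrun_cmd` is a hypothetical helper that is not part of the repository, and its defaults (one node, 8 GPUs, `127.0.0.1`, port 29500) are assumptions to replace with your own setup.

```python
import shlex


def build_torchrun_cmd(script, config, nnodes=1, node_rank=0, nproc_per_node=8,
                       master_addr="127.0.0.1", master_port=29500,
                       precision="bf16", resume_t5=None):
    """Assemble the torchrun command shown above as a list of arguments."""
    cmd = [
        "torchrun",
        "--nnodes", str(nnodes),
        "--node_rank", str(node_rank),
        "--nproc_per_node", str(nproc_per_node),
        "--master_addr", master_addr,
        "--master_port", str(master_port),
        script,
        "--config", config,
        "--results_dir", "results/",
        "--model_parallel_size", "1",
        "--data_parallel", "h_sdp",
        "--precision", precision,
        "--grad_precision", "fp32",
    ]
    # The SFT command additionally passes the aligned T5 mask-token embedding.
    if resume_t5 is not None:
        cmd += ["--resume_t5", resume_t5]
    return cmd


cmd = build_torchrun_cmd("train_dual_diffusion_sd3.py",
                         "configs/dual_diff_pretrain.py")
print(shlex.join(cmd))  # shell-quoted form, convenient for logging
```

The resulting list can be passed to `subprocess.run(cmd, check=True)`, with `OMP_NUM_THREADS` set in the environment as in the shell version. For multi-node runs, every node needs the same `master_addr` and `master_port` and its own `node_rank`.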

Acknowledgement

The implementation of this project is inspired by the great codebases of PixArt, MDLM, and Lumina-Next.

Our DiT backbone is finetuned from SD3-medium.

Reference

If you find this project useful, please kindly consider citing our work:

@misc{li2024dualdiffusionunifiedimage,
      title={Dual Diffusion for Unified Image Generation and Understanding}, 
      author={Zijie Li and Henry Li and Yichun Shi and Amir Barati Farimani and Yuval Kluger and Linjie Yang and Peng Wang},
      year={2024},
      eprint={2501.00289},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2501.00289}, 
}