AR-RAG: Autoregressive Retrieval Augmentation for Image Generation

June 17, 2025


This repository contains the official implementation of AR-RAG: Autoregressive Retrieval Augmentation for Image Generation.

AR-RAG Showcase


Overview

AR-RAG introduces a retrieval augmentation paradigm for photorealistic image generation that augments the model's patch predictions with k-nearest neighbor (k-NN) retrievals at the patch level. Existing approaches retrieve full images conditioned on the text caption. AR-RAG instead retrieves locally similar patches using their surrounding visual context as the query. Retrieval therefore needs no caption, and because each retrieved patch is matched on its neighborhood, it encourages spatial coherence and semantic consistency in the generated image.
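To make the idea of a context-conditioned patch lookup concrete, here is a minimal pure-Python sketch, not the repository's implementation. It assumes that patches are generated in raster order, that the query for position (row, col) is the concatenation of the embeddings of its four already-generated neighbours (top-left, top, top-right, left), and that the database stores (context key, patch token id) pairs searched by brute-force L2 distance. The names `CAUSAL_OFFSETS`, `context_key`, and `knn_patches` are hypothetical; the repository builds a FAISS index for the actual search.

```python
import math
from typing import Dict, List, Sequence, Tuple

Vector = Tuple[float, ...]

# Hypothetical neighbourhood: raster-order neighbours that already exist
# when the patch at (row, col) is about to be predicted.
CAUSAL_OFFSETS = [(-1, -1), (-1, 0), (-1, 1), (0, -1)]


def context_key(emb: Dict[Tuple[int, int], Vector], row: int, col: int, dim: int) -> Vector:
    """Concatenate neighbour embeddings into one query vector.

    Neighbours that are missing (image border) are zero-filled, so every key
    has length len(CAUSAL_OFFSETS) * dim.
    """
    key: List[float] = []
    for dr, dc in CAUSAL_OFFSETS:
        key.extend(emb.get((row + dr, col + dc), (0.0,) * dim))
    return tuple(key)


def knn_patches(query: Vector, database: Sequence[Tuple[Vector, int]], k: int) -> List[Tuple[float, int]]:
    """Brute-force k-NN: (L2 distance, patch token id) for the k closest keys."""
    scored = [(math.dist(query, key), token) for key, token in database]
    scored.sort(key=lambda pair: pair[0])
    return scored[:k]


# Toy usage: 2-d patch embeddings around position (1, 1) of a partially generated grid.
emb = {(0, 0): (1.0, 0.0), (0, 1): (0.0, 1.0), (0, 2): (1.0, 1.0), (1, 0): (0.5, 0.5)}
query = context_key(emb, row=1, col=1, dim=2)
database = [(query, 7), (tuple(x + 1.0 for x in query), 3), ((0.0,) * 8, 9)]
print(knn_patches(query, database, k=2))  # token 7 (exact match), then token 9
```

In practice the keys would be high-dimensional tokenizer embeddings and the search would go through an index rather than a Python loop; brute force is used here only for readability.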

We propose two parallel frameworks:

  1. Distribution-Augmentation in Decoding (DAiD): A training-free decoding strategy that directly merges the distribution of model-predicted patches with the distribution of retrieved patches.

  2. Feature-Augmentation in Decoding (FAiD): A parameter-efficient fine-tuning method that smoothly integrates retrieved patches into the generation process via convolution operations.
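The DAiD step of merging two distributions can be sketched as a kNN-LM-style linear interpolation. This is an illustrative assumption, not necessarily the exact rule used in the repository: the `(distance, token id)` pairs from a patch-level k-NN search are turned into a distribution over the VQ codebook with a softmax over negative distances, and that distribution is mixed with the model's next-patch distribution using a weight `lam`. The function names, `temperature`, and `lam` are hypothetical.

```python
import math
from typing import List, Sequence, Tuple


def retrieval_distribution(
    neighbours: Sequence[Tuple[float, int]], vocab_size: int, temperature: float = 1.0
) -> List[float]:
    """Convert k-NN results (distance, token id) into a distribution over codebook ids.

    Weights are softmax(-distance / temperature). Distances are shifted by their
    minimum first, so the largest weight is exactly 1 and the sum cannot underflow.
    Neighbours that share a token id accumulate their weights.
    """
    d_min = min(d for d, _ in neighbours)
    weights = [math.exp(-(d - d_min) / temperature) for d, _ in neighbours]
    total = sum(weights)
    probs = [0.0] * vocab_size
    for w, (_, token) in zip(weights, neighbours):
        probs[token] += w / total
    return probs


def daid_merge(model_probs: Sequence[float], retrieved_probs: Sequence[float], lam: float) -> List[float]:
    """Training-free merge: (1 - lam) * p_model + lam * p_retrieval."""
    if not 0.0 <= lam <= 1.0:
        raise ValueError("lam must lie in [0, 1]")
    return [(1.0 - lam) * p + lam * q for p, q in zip(model_probs, retrieved_probs)]


# Toy example with a 4-entry codebook: two neighbours vote for token 2, one for token 1.
p_ret = retrieval_distribution([(0.0, 2), (0.0, 2), (math.log(2.0), 1)], vocab_size=4)
p_model = [0.1, 0.2, 0.3, 0.4]
print(daid_merge(p_model, p_ret, lam=0.5))  # ≈ [0.05, 0.2, 0.55, 0.2]
```

With `lam = 0` decoding reduces to the base model; larger values trust the retrieved patches more. Because both inputs sum to 1, any `lam` in [0, 1] yields a valid distribution, and no parameters are trained, which matches DAiD being training-free.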

Performance Highlights

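Both methods improve on the Janus-Pro baseline on every benchmark below (higher GenEval and DPG-Bench overall scores, lower FID), with FAiD giving the largest gains. Individual categories can still regress, e.g., Counting for FAiD on GenEval: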

GenEval Benchmark

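| Method | Single Obj. | Two Obj. | Counting | Colors | Position | Color Attri. | Overall ↑ |
|---|---|---|---|---|---|---|---|
| Janus-Pro | 0.98 | 0.77 | 0.52 | 0.84 | 0.61 | 0.55 | 0.71 |
| DAiD (ours) | 0.98 | 0.82 | 0.54 | 0.87 | 0.63 | 0.49 | 0.72 |
| FAiD (ours) | 1.00 | 0.92 | 0.41 | 0.87 | 0.71 | 0.60 | 0.75 |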
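The Overall column is consistent with the unweighted mean of the six per-category scores, which is how GenEval aggregates its categories. The short check below recomputes it from the table values; it is only a sanity check on the reported numbers, not part of the AR-RAG code.

```python
# Per-category GenEval scores from the table, in column order:
# Single Obj., Two Obj., Counting, Colors, Position, Color Attri.
geneval = {
    "Janus-Pro": [0.98, 0.77, 0.52, 0.84, 0.61, 0.55],
    "DAiD": [0.98, 0.82, 0.54, 0.87, 0.63, 0.49],
    "FAiD": [1.00, 0.92, 0.41, 0.87, 0.71, 0.60],
}

# Overall = unweighted mean of the six categories, rounded to two decimals.
overall = {name: round(sum(s) / len(s), 2) for name, s in geneval.items()}
print(overall)  # {'Janus-Pro': 0.71, 'DAiD': 0.72, 'FAiD': 0.75}
```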

DPG-Bench

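| Method | Global | Entity | Attribute | Relation | Other | Overall ↑ |
|---|---|---|---|---|---|---|
| Janus-Pro | 81.76 | 84.53 | 84.34 | 92.22 | 75.20 | 77.26 |
| DAiD (ours) | 83.58 | 84.46 | 84.76 | 91.49 | 76.40 | 77.88 |
| FAiD (ours) | 82.67 | 85.80 | 85.38 | 92.3 | 76.80 | 79.36 |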

MSCOCO and Midjourney Benchmarks (FID ↓)

| Model | MSCOCO FID ↓ | Midjourney FID ↓ |
|---|---|---|
| Janus-Pro | 19.59 | 12.81 |
| DAiD (ours) | 18.02 | 11.93 |
| FAiD (ours) | 17.60 | 9.31 |
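To express the FID results in relative terms, the snippet below computes the percentage reduction with respect to Janus-Pro (lower FID is better). It uses only the values in the table; `relative_reduction` is a hypothetical helper name.

```python
# FID values from the table: (MSCOCO, Midjourney); lower is better.
fid = {
    "Janus-Pro": (19.59, 12.81),
    "DAiD": (18.02, 11.93),
    "FAiD": (17.60, 9.31),
}


def relative_reduction(baseline: float, value: float) -> float:
    """Percentage by which `value` lowers FID relative to `baseline`."""
    return 100.0 * (baseline - value) / baseline


base = fid["Janus-Pro"]
reductions = {
    name: tuple(round(relative_reduction(b, v), 1) for b, v in zip(base, fid[name]))
    for name in ("DAiD", "FAiD")
}
print(reductions)  # {'DAiD': (8.0, 6.9), 'FAiD': (10.2, 27.3)}
```

FAiD's largest relative gain is therefore on the Midjourney set, roughly a 27% lower FID than the baseline.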

Model Zoo

| Model | Description | Size | HF Link |
|---|---|---|---|
| AR-RAG-FAiD | Fine-tuned model with Smooth Feature Blending | 1.2B | 🤗 Model |

Patch-level Retrieval Database

| Data Source | # Images | Suggested GPU Memory | HF Link |
|---|---|---|---|
| JourneyDB | 1M | 12 GB | ZIP |
| CC12M | 12M | 96 GB | ZIP |
| DataCamp | 70M | - | 🤗 Coming soon |

Installation

git clone https://github.com/PLUM-Lab/AR-RAG.git
cd AR-RAG

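# Create the conda environment, then activate it with: conda activate <env-name>
# (<env-name> is the environment name defined in arrag.yml)
conda env create -f arrag.yml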

Patch-level Retrieval Database & Retriever Construction

Download the LlamaGen VQ-VAE checkpoint:

wget -P arrag/Janus/janus https://huggingface.co/peizesun/llamagen_t2i/resolve/main/vq_ds16_t2i.pt
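The `ds16` in the checkpoint name refers to the tokenizer's 16× spatial downsampling: each 16×16 pixel block becomes one discrete patch token. Assuming the patch-level database is built on this token grid (the build step below relies on this checkpoint), the hypothetical helpers here compute the grid side length and map a raster-order token index to its (row, col) position. The 384×384 example resolution is an assumption chosen because it matches Janus-Pro's generation resolution.

```python
from typing import Tuple


def token_grid_side(image_size: int, downsample: int = 16) -> int:
    """Number of patch tokens along one side of a square image."""
    if image_size % downsample:
        raise ValueError(f"{image_size} is not divisible by {downsample}")
    return image_size // downsample


def raster_to_grid(index: int, side: int) -> Tuple[int, int]:
    """Map a raster-order token index to its (row, col) on a side x side grid."""
    if not 0 <= index < side * side:
        raise IndexError(index)
    return divmod(index, side)


side = token_grid_side(384)       # 24 tokens per side
print(side, side * side)          # 24 576
print(raster_to_grid(25, side))   # (1, 1): second row, second column
```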

Construct the Retriever from Image Data

bash arrag/build_retriever/build_retriever.sh

The output FAISS index is written to data/retriever/index_L.

Download Pre-built Retrieval Database

# Download pre-built retrieval database
wget http://nlplab1.cs.vt.edu/~jingyuan/AR-RAG/retrieval_db.zip

Training

FAiD Model Training

bash ./arrag/train/train_FAiD.sh

By default, checkpoints are saved to result/ckpts/ckpts_FAiD_bx_hx.

Text to Image Generation


DAiD

bash arrag/t2i_example/t2i_daid_L.sh

By default, the generated image is saved to result/generated_imgs/example_t2i_daid.jpg.

FAiD

bash arrag/t2i_example/t2i_faid_L.sh

By default, the generated image is saved to result/generated_imgs/example_t2i_faid.jpg.

License

This project is licensed under the MIT License - see the LICENSE file for details.