RAPID: Long-Context Inference with Retrieval-Augmented Speculative Decoding

March 2, 2025

This repo provides the official implementation of our paper "Long-Context Inference with Retrieval-Augmented Speculative Decoding".

Updates

  • [2025.2.28] 🚀 Released the paper and code of RAPID.

Highlights

  • Uses RAG to speculate on (and thereby accelerate) the generation of long-context LLMs.
  • Uses a retrieval-augmented target distribution to combine the benefits of RAG and long-context LLMs.
  • Achieves a 2x speedup along with performance improvements when using self-speculation.
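For readers new to speculative decoding, here is a minimal sketch of the generic draft-then-verify acceptance rule (standard speculative sampling) that a RAG drafter plugs into. It is not RAPID's retrieval-augmented target distribution (Eq. (6) with margin η); the function name `speculative_verify` and its array layout are hypothetical, chosen only for illustration.

```python
import numpy as np


def speculative_verify(draft_tokens, draft_probs, target_probs, rng):
    """Generic speculative-sampling verification (hypothetical helper).

    draft_tokens: k token ids proposed by the draft model.
    draft_probs:  (k, V) array, draft distribution q_i at each position.
    target_probs: (k + 1, V) array, target distribution p_i at each position;
                  the extra last row is used for the bonus token.
    Returns the tokens emitted in this decoding step.
    """
    out = []
    for i, tok in enumerate(draft_tokens):
        p, q = target_probs[i], draft_probs[i]
        # Accept the drafted token with probability min(1, p(tok) / q(tok)).
        if rng.random() < min(1.0, p[tok] / q[tok]):
            out.append(int(tok))
            continue
        # On rejection, resample from the normalized residual max(0, p - q)
        # and stop: later draft tokens were conditioned on a rejected prefix.
        residual = np.maximum(p - q, 0.0)
        residual /= residual.sum()
        out.append(int(rng.choice(len(residual), p=residual)))
        return out
    # Every draft token was accepted: sample one bonus token from the target.
    out.append(int(rng.choice(target_probs.shape[1], p=target_probs[-1])))
    return out


rng = np.random.default_rng(0)
q = np.array([[0.1, 0.7, 0.1, 0.1], [0.25, 0.25, 0.25, 0.25]])
p = np.vstack([q, [[0.0, 0.0, 1.0, 0.0]]])
print(speculative_verify([1, 3], q, p, rng))
```

When the draft and target distributions agree, every drafted token is accepted and a bonus token is drawn from the target, so one target forward pass can emit several tokens. This is where a good drafter pays off in speed.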

Quick Start

  1. Retrieve relevant chunks from a long context for a query. We provide a simple and effective RAG pipeline based on the BGE-M3 embedding model.
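```python
import math

from transformers import AutoTokenizer

from src.rag import get_rag_context

# Tokenizer used to measure chunk lengths in tokens (assumption: the same
# tokenizer as the model you load in step 3).
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")

long_context = ""  # the full long document
query = ""         # the question about the document

target_rag_length = 8192  # total token budget for the retrieved context
token_per_chunk = 512     # maximum tokens per chunk
num_to_retrieve = math.ceil(target_rag_length / token_per_chunk)  # 16 chunks

rag_context, num_rag_chunks = get_rag_context(
    long_context,
    query,
    embed_model_path="BAAI/bge-m3",
    tokenizer=tokenizer,
    num_to_retrieve=num_to_retrieve,
    max_chunk_token=token_per_chunk,
    rag_threshold=0.3,
)
```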
  2. We implement RAPID by monkey-patching the `generate` method of Hugging Face transformers (together with two assisted-decoding helpers):
```python
from transformers.generation.utils import GenerationMixin

from src.utils import generate, _get_candidate_generator, _assisted_decoding

# Replace the GenerationMixin methods so that model.generate(...) on any
# transformers model is routed through RAPID's speculative decoding path.
GenerationMixin.generate = generate
GenerationMixin._get_candidate_generator = _get_candidate_generator
GenerationMixin._assisted_decoding = _assisted_decoding
```
  3. RAPID can now be used by calling `model.generate`:
```python
from src.utils import load_model_tokenizer

# Target model: reads the full long context.
target_model, target_tok = load_model_tokenizer(
    "meta-llama/Llama-3.1-8B-Instruct",
    device_list="0",
)

# Self-speculation: the same model drafts from the retrieved (RAG) context.
draft_model, draft_tok = target_model, target_tok

# Upward-speculation: use a larger model as the RAG drafter instead.
# draft_model, draft_tok = load_model_tokenizer(
#     "meta-llama/Llama-3.1-70B-Instruct",
#     device_list="1,2,3,4,5,6,7",  # depends on your GPUs
# )

long_input = long_context + "\n" + query  # add your instructions here as well
rag_input = rag_context + "\n" + query    # add your instructions here as well

input_ids = target_tok(
    [long_input], return_tensors="pt", add_special_tokens=False
).input_ids.to(target_model.device)
draft_input_ids = draft_tok(
    [rag_input], return_tensors="pt", add_special_tokens=False
).input_ids.to(draft_model.device)

# Optional decoding settings, e.g. {"do_sample": False} for greedy decoding
# or {"do_sample": True, "top_p": 1.0, "temperature": 1.0} for sampling.
generation_kwargs = {}

outputs = target_model.generate(
    input_ids,
    assistant_input_ids=draft_input_ids,  # the draft model sees the RAG input
    eos_token_id=target_tok.eos_token_id,
    assistant_model=draft_model,
    speculative_margin=10,  # \eta in Eq. (6)
    use_cache=True,
    max_new_tokens=1024,
    **generation_kwargs,
)

# Drop the prompt tokens and decode only the newly generated text.
response = target_tok.batch_decode(
    outputs[:, input_ids.shape[1]:], skip_special_tokens=True
)[0]
print(response)
```
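The retrieval in step 1 is handled by `get_rag_context` with BGE-M3 embeddings. As a rough mental model only, the sketch below shows a threshold-plus-top-k chunk selector over precomputed embeddings, reusing the same budget arithmetic (8192 / 512 → 16 chunks). The function `select_chunks`, the cosine-similarity scoring, treating the threshold as a minimum score, and restoring document order are all assumptions, not the repo's implementation.

```python
import math

import numpy as np


def select_chunks(chunk_embs, query_emb, target_rag_length=8192,
                  token_per_chunk=512, threshold=0.3):
    """Hypothetical retriever: return indices of the chunks to keep.

    chunk_embs: (n_chunks, d) array of chunk embeddings.
    query_emb:  (d,) query embedding.
    Assumes cosine-similarity scoring and `threshold` as a minimum score.
    """
    # Chunk budget, as in step 1: ceil(8192 / 512) = 16.
    num_to_retrieve = math.ceil(target_rag_length / token_per_chunk)
    # Cosine similarity between the query and every chunk.
    chunk_norm = chunk_embs / np.linalg.norm(chunk_embs, axis=1, keepdims=True)
    query_norm = query_emb / np.linalg.norm(query_emb)
    scores = chunk_norm @ query_norm
    # Take the top-k chunks by score, then drop those below the threshold.
    top = np.argsort(-scores)[:num_to_retrieve]
    kept = [int(i) for i in top if scores[i] >= threshold]
    # Restore document order so the concatenated RAG context reads coherently.
    return sorted(kept)


embs = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 0.0]])
print(select_chunks(embs, np.array([1.0, 0.0])))
```

The selected chunks would then be concatenated into `rag_context`, which only the draft model reads, while the target model still receives the full long context.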

Evaluation

LongBench V2

  • We evaluate with temperature=0.1, following the official settings. To avoid the randomness issues of evaluating with vLLM in the official repo, we provide an evaluation script based on transformers `generate`.
LLaMA-3.1 Series
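
| Target Model | Draft Model | η | CoT | Overall | Easy | Hard | Short | Medium | Long |
|---|---|---|---|---|---|---|---|---|---|
| LLaMA-3.1-8B-Instruct | - | - | ❌ | 28.0 | 29.2 | 27.3 | 33.3 | 25.1 | 25.0 |
| LLaMA-3.1-8B-Instruct | LLaMA-3.1-8B-Instruct | 10 | ❌ | 32.4 | 34.9 | 30.9 | 37.8 | 29.8 | 28.7 |
| LLaMA-3.1-8B-Instruct | - | - | ✅ | 30.4 | 35.4 | 27.3 | 36.7 | 27.9 | 25.0 |
| LLaMA-3.1-8B-Instruct | LLaMA-3.1-8B-Instruct | 10 | ✅ | 34.2 | 39.1 | 31.2 | 41.1 | 31.2 | 28.7 |
| LLaMA-3.1-70B-Instruct | - | - | ❌ | 31.6 | 32.3 | 31.2 | 41.1 | 27.4 | 24.1 |
| LLaMA-3.1-8B-Instruct | LLaMA-3.1-70B-Instruct | 50 | ❌ | 38.8 | 40.6 | 37.6 | 37.8 | 38.1 | 41.7 |
| LLaMA-3.1-70B-Instruct | LLaMA-3.1-70B-Instruct | 20 | ❌ | 40.2 | 42.2 | 38.9 | 42.8 | 37.2 | 41.7 |
| LLaMA-3.1-70B-Instruct | - | - | ✅ | 36.2 | 35.9 | 36.3 | 45.0 | 34.0 | 25.9 |
| LLaMA-3.1-8B-Instruct | LLaMA-3.1-70B-Instruct | 40 | ✅ | 40.2 | 44.3 | 37.6 | 41.1 | 37.7 | 43.5 |
| LLaMA-3.1-70B-Instruct | LLaMA-3.1-70B-Instruct | 20 | ✅ | 40.2 | 45.3 | 37.0 | 44.4 | 36.3 | 40.7 |
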
Qwen2.5 Series
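
| Target Model | Draft Model | η | CoT | Overall | Easy | Hard | Short | Medium | Long |
|---|---|---|---|---|---|---|---|---|---|
| Qwen2.5-7B-Instruct | - | - | ❌ | 30.2 | 31.2 | 29.6 | 41.7 | 24.7 | 22.2 |
| Qwen2.5-7B-Instruct | Qwen2.5-7B-Instruct | 20 | ❌ | 32.0 | 35.4 | 29.9 | 40.6 | 27.4 | 26.9 |
| Qwen2.5-7B-Instruct | - | - | ✅ | 33.2 | 36.5 | 31.2 | 46.7 | 24.2 | 28.7 |
| Qwen2.5-7B-Instruct | Qwen2.5-7B-Instruct | 5 | ✅ | 35.4 | 40.6 | 32.2 | 42.2 | 33.0 | 28.7 |
| Qwen2.5-72B-Instruct | - | - | ❌ | 40.0 | 41.7 | 38.9 | 42.2 | 37.2 | 41.7 |
| Qwen2.5-7B-Instruct | Qwen2.5-72B-Instruct | 50 | ❌ | 35.6 | 38.5 | 33.8 | 42.2 | 30.2 | 35.2 |
| Qwen2.5-72B-Instruct | Qwen2.5-72B-Instruct | 20 | ❌ | 42.9 | 44.3 | 42.1 | 48.9 | 37.7 | 43.5 |
| Qwen2.5-72B-Instruct | - | - | ✅ | 43.9 | 49.5 | 40.5 | 46.7 | 40.5 | 46.3 |
| Qwen2.5-7B-Instruct | Qwen2.5-72B-Instruct | 50 | ✅ | 41.2 | 41.7 | 40.8 | 43.3 | 36.7 | 46.3 |
| Qwen2.5-72B-Instruct | Qwen2.5-72B-Instruct | 20 | ✅ | 44.1 | 45.3 | 43.4 | 47.2 | 42.8 | 41.7 |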

InfiniteBench

  • We evaluate with top_p=1 and temperature=1.
LLaMA-3.1 Series

| Target Model | Draft Model | η | En.QA | En.MC | En.Sum | AVG. |
|---|---|---|---|---|---|---|
| LLaMA-3.1-8B-Instruct | - | - | 34.58 | 53.28 | 30.14 | 39.33 |
| LLaMA-3.1-8B-Instruct | LLaMA-3.1-8B-Instruct | 10 | 34.90 | 63.32 | 30.27 | 42.83 |
| LLaMA-3.1-70B-Instruct | - | - | 36.48 | 68.56 | 30.18 | 45.07 |
| LLaMA-3.1-8B-Instruct | LLaMA-3.1-70B-Instruct | 10 | 40.94 | 79.04 | 29.96 | 49.98 |
| LLaMA-3.1-70B-Instruct | LLaMA-3.1-70B-Instruct | 10 | 40.56 | 81.66 | 29.64 | 50.62 |

Qwen2.5 Series
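
| Target Model | Draft Model | η | En.QA | En.MC | En.Sum | AVG. |
|---|---|---|---|---|---|---|
| Qwen2.5-7B-Instruct | - | - | 16.93 | 66.81 | 30.62 | 38.12 |
| Qwen2.5-7B-Instruct | Qwen2.5-7B-Instruct | 20 | 19.81 | 75.98 | 31.64 | 42.48 |
| Qwen2.5-72B-Instruct | - | - | 39.21 | 81.66 | 32.45 | 51.11 |
| Qwen2.5-7B-Instruct | Qwen2.5-72B-Instruct | 20 | 30.10 | 83.84 | 32.21 | 48.72 |
| Qwen2.5-72B-Instruct | Qwen2.5-72B-Instruct | 10 | 40.52 | 85.59 | 32.94 | 53.02 |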
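
In the InfiniteBench tables, the AVG. column appears to be the unweighted mean of En.QA, En.MC, and En.Sum. That is our reading of the numbers, not a definition stated elsewhere in this README. A quick check over a few rows copied from the tables above (the helper `task_mean` is hypothetical):

```python
# (label, En.QA, En.MC, En.Sum, reported AVG.) copied from the InfiniteBench tables.
ROWS = [
    ("LLaMA-3.1-8B-Instruct", 34.58, 53.28, 30.14, 39.33),
    ("LLaMA-3.1-8B-Instruct + LLaMA-3.1-8B-Instruct draft", 34.90, 63.32, 30.27, 42.83),
    ("Qwen2.5-7B-Instruct", 16.93, 66.81, 30.62, 38.12),
    ("Qwen2.5-72B-Instruct + Qwen2.5-72B-Instruct draft", 40.52, 85.59, 32.94, 53.02),
]


def task_mean(qa: float, mc: float, summ: float) -> float:
    """Unweighted mean of the three InfiniteBench task scores."""
    return (qa + mc + summ) / 3


for label, qa, mc, summ, reported in ROWS:
    recomputed = task_mean(qa, mc, summ)
    # Reported values are rounded to two decimals.
    print(f"{label}: recomputed {recomputed:.2f}, reported {reported:.2f}")
```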

Citation

If you find our paper useful, please consider starring our repo and citing our paper as follows:

@article{chen2025longcontextinferenceretrievalaugmentedspeculative,
      title={Long-Context Inference with Retrieval-Augmented Speculative Decoding}, 
      author={Guanzheng Chen and Qilong Feng and Jinjie Ni and Xin Li and Michael Qizhe Shieh},
      year={2025},
      eprint={2502.20330},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2502.20330}, 
}