BanditSpec: Bandit-Based Speculative Decoding for Efficient Autoregressive Generation
September 2, 2025
This repository implements BanditSpec, a speculative decoding framework that adaptively balances exploration and exploitation using bandit algorithms to accelerate autoregressive generation in large language models (LLMs). The framework is compatible with both LLaMA and Qwen2 architectures.
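To make the bandit idea concrete, here is a minimal sketch. It is not the code in generate_utils.py. It assumes that each arm is a candidate speculation length (a gamma value) and that a scalar reward in [0, 1] is observed after every decoding step. The class name UCBGammaSelector, the choice of the UCB1 rule, the reward definition, and the simulated reward means are all illustrative assumptions.

```python
import math
import random


class UCBGammaSelector:
    """UCB1 bandit over candidate gamma values (hypothetical helper)."""

    def __init__(self, gammas, exploration=2.0):
        self.gammas = list(gammas)
        self.exploration = exploration
        self.counts = [0] * len(self.gammas)   # number of pulls per arm
        self.means = [0.0] * len(self.gammas)  # running mean reward per arm
        self.total = 0                         # total number of pulls

    def select(self):
        """Return the index of the arm to use for the next decoding step."""
        # Pull every arm once before applying the UCB rule.
        for i, n in enumerate(self.counts):
            if n == 0:
                return i
        # Exploitation term (mean) plus exploration bonus (shrinks with pulls).
        scores = [
            m + math.sqrt(self.exploration * math.log(self.total) / n)
            for m, n in zip(self.means, self.counts)
        ]
        return max(range(len(scores)), key=scores.__getitem__)

    def update(self, arm, reward):
        """Fold an observed reward into the running mean of the chosen arm."""
        self.counts[arm] += 1
        self.total += 1
        self.means[arm] += (reward - self.means[arm]) / self.counts[arm]


def simulate(steps=2000, seed=0):
    """Toy run: each gamma has a hidden Bernoulli reward mean (made up)."""
    rng = random.Random(seed)
    gammas = [1, 2, 3, 4]
    hidden_means = [0.3, 0.5, 0.8, 0.4]  # assumption, for illustration only
    selector = UCBGammaSelector(gammas)
    for _ in range(steps):
        arm = selector.select()
        reward = 1.0 if rng.random() < hidden_means[arm] else 0.0
        selector.update(arm, reward)
    return selector
```

In this toy setting the selector tries every gamma at first and then concentrates its pulls on the arm with the highest hidden mean. In a real decoder the reward would come from measured decoding behavior rather than from a simulated coin flip.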
Key Components
- eagle_llama.py: Defines the Eagle (Li Y. et al. 2024) draft model based on LLaMA.
- eagle_qwen.py: Defines the Eagle (Li Y. et al. 2024) draft model based on Qwen2.
- llama.py, qwen.py: Customized versions of the LLaMA and Qwen2 architectures.
- generate_utils.py: Implements the core decoding strategies, including BanditSpec.
- inference_length.py: Main script that benchmarks throughput across different batch sizes and strategies.
- llama_long.png: Visualization comparing throughput improvements.
Setup
Install Dependencies
pip install torch transformers fairscale flash-attn tqdm
⚠️ Make sure flash-attn is compiled for your CUDA and PyTorch version.
Download EAGLE models from their repo (https://github.com/SafeAILab/EAGLE)
Folder Structure
project/
├── inference_length.py
├── eagle_llama.py
├── eagle_qwen.py
├── llama.py
├── qwen.py
├── generate_utils.py
├── llama_long.png
├── llama_model/   # contains config.json and pytorch_model.bin for LLaMA
└── eagle_model/   # contains config.json and pytorch_model.bin for Eagle
Modify inference_length.py to set:
target_path = "llama_model"
eagle_path = "eagle_model"
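Before launching a benchmark it can help to confirm that both folders contain the files listed in the folder structure. The helper below is a small sketch for that check. The function name missing_model_files is hypothetical and not part of this repository, and the required file names are taken from the folder structure above.

```python
from pathlib import Path

# File names taken from the folder structure comments in this README.
REQUIRED_FILES = ("config.json", "pytorch_model.bin")


def missing_model_files(model_dir, required=REQUIRED_FILES):
    """Return the required file names that are absent from model_dir."""
    root = Path(model_dir)
    return [name for name in required if not (root / name).is_file()]


if __name__ == "__main__":
    for path in ("llama_model", "eagle_model"):
        missing = missing_model_files(path)
        status = "ok" if not missing else "missing " + ", ".join(missing)
        print(f"{path}: {status}")
```

An empty list means the folder has both files. Checkpoints saved under other file names would be reported as missing by this sketch.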
Running BanditSpec
python inference_length.py
This will run decoding experiments across:
- Different batch sizes
- Various gamma values
- Baselines like Best Arm, Worst Arm, and fixed gamma
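The sweep described above can be pictured as a grid of batch sizes and strategy labels. The sketch below only enumerates such a grid and does not reflect how inference_length.py is actually organized. The function build_sweep and the example values are hypothetical, and whether Best Arm and Worst Arm are run as separate configurations or derived afterwards from the fixed-gamma runs is an assumption here.

```python
from itertools import product


def build_sweep(batch_sizes, gammas):
    """Enumerate (batch size, strategy label) pairs for a benchmark sweep."""
    # Labels follow the ones shown in this README; the grid itself is assumed.
    strategies = ["BanditSpec", "Best Arm", "Worst Arm"]
    strategies += [f"gamma={g}" for g in gammas]
    return list(product(batch_sizes, strategies))


if __name__ == "__main__":
    # Example values only; the real script defines its own ranges.
    for bsz, strategy in build_sweep([10, 20], [1, 2, 3]):
        print(bsz, strategy)
```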
Output Format
bsz spec_quota gamma throughput
10 256 BanditSpec 1.43
20 256 gamma=1 1.61
...
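For post-processing, rows in this format can be loaded into typed records. The parser below is a sketch that assumes the columns are whitespace-separated exactly as shown above, and that the third column may hold a multi-word label such as Best Arm. The names Row and parse_results are hypothetical and not part of this repository.

```python
from typing import NamedTuple


class Row(NamedTuple):
    bsz: int
    spec_quota: int
    strategy: str      # value of the "gamma" column, e.g. "BanditSpec"
    throughput: float


def parse_results(text):
    """Parse benchmark output lines into Row records, skipping other lines."""
    rows = []
    for line in text.splitlines():
        parts = line.split()
        # Skip the header, blank lines, and elided lines such as "...".
        if len(parts) < 4 or not parts[0].isdigit():
            continue
        # Everything between the second and last field is the strategy label.
        rows.append(Row(
            bsz=int(parts[0]),
            spec_quota=int(parts[1]),
            strategy=" ".join(parts[2:-1]),
            throughput=float(parts[-1]),
        ))
    return rows
```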
Reference
Li Y, Wei F, Zhang C, et al. EAGLE: Speculative Sampling Requires Rethinking Feature Uncertainty. arXiv preprint arXiv:2401.15077, 2024.