Multi-Token Assisted Decoding (MTAD)

April 11, 2025

This repository contains the implementation of MTAD from the paper: "Optimized Multi-Token Joint Decoding with Auxiliary Model for LLM Inference", ICLR 2025.

The implementation is based on MCSD.


Update

2025.4.9: Implemented Multi-Candidate MTAD, which incorporates tree-wise parallel decoding for better efficiency and output quality. The details of the algorithm will be released on arXiv.

🚀 Dependencies

Ensure you have the following installed:

  • PyTorch: >= 2.4.1
  • Python: >= 3.8
  • Transformers: >= 4.34.0
  • pandas
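
To confirm an environment against the list above before running anything, a small check along these lines can help. This is a minimal sketch, not part of the repository: the helper names are hypothetical, the pip distribution names (torch, transformers, pandas) are assumed, and version suffixes such as "+cu121" or ".dev0" are simply ignored.

```python
import re
import sys
from importlib import metadata

# Minimum versions copied from the dependency list; pandas has no stated minimum.
MINIMUMS = {"torch": "2.4.1", "transformers": "4.34.0", "pandas": None}
MIN_PYTHON = (3, 8)


def version_tuple(version):
    """Turn '2.4.1+cu121' into (2, 4, 1), ignoring any non-numeric suffix."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        return ()
    return tuple(int(part) for part in match.group(0).split("."))


def meets_minimum(installed, required):
    """True when the installed version is at least the required one."""
    if required is None:
        return True
    return version_tuple(installed) >= version_tuple(required)


def check_dependencies():
    """Return a list of problems found; an empty list means none were found."""
    problems = []
    if sys.version_info[:2] < MIN_PYTHON:
        problems.append("Python >= 3.8 is required")
    for name, required in MINIMUMS.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            problems.append(f"{name} is not installed")
            continue
        if not meets_minimum(installed, required):
            problems.append(f"{name} {installed} is older than {required}")
    return problems


if __name__ == "__main__":
    for problem in check_dependencies():
        print(problem)
```

Running the file directly prints one line per problem and nothing when every requirement is met.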

📂 Datasets

Spider

Download the Spider dataset from their official website: https://yale-lily.github.io/spider

Human-Eval

Install Human-Eval from its GitHub repository: https://github.com/openai/human-eval

MT-Bench

evaluation.py does not directly support MT-Bench. To evaluate on it, modify the answer-generation script from FastChat so that it uses our decoding method, then run FastChat's evaluation on the generated answers.


🛠 Usage

Setting up the Environment

To run the official Llama models, set your Hugging Face token first:

env HFTOKEN=your_huggingface_token

Then, run evaluation.py with the appropriate options.

Important Options

| Argument | Description |
| --- | --- |
| --dataset | Name of the dataset (spider or human_eval) |
| --draft-model | Path to the draft model |
| --target-model | Path to the target model |
| --tokenizer | Path to the tokenizer (defaults to the target model if not provided) |
| --mtad | Run MTAD decoding |
| --beam-width | Beam width of the draft model for MTAD (default: 4) |
| --accept-thres | Acceptance threshold for MTAD (default: 0.5) |
| --fp16 | Use float16 dtype for the target model |
| --k-config | Branch factor for SpecInfer (comma-separated values, e.g., --k-config 4,2,2) |
| --datapath | Path to the JSON data file |
| --max-new-tokens | Maximum number of new tokens |
| --replacement | Enable sampling with replacement |
| --disable-tqdm | Disable the tqdm progress bar |
| --disable-tree-attn | Disable tree parallel decoding; use this to run the original MTAD |
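
To make the --k-config format concrete, the sketch below parses a value such as 4,2,2 and counts the candidate nodes it would imply at each depth. It assumes that every node at depth i expands into exactly k_i children, so the widths multiply. That tree shape and the helper names are illustrative assumptions, not taken from evaluation.py.

```python
def parse_k_config(text):
    """Parse a comma-separated string such as '4,2,2' into a tuple of ints."""
    values = tuple(int(part) for part in text.split(","))
    if any(value <= 0 for value in values):
        raise ValueError("every branch factor must be a positive integer")
    return values


def nodes_per_level(k_config):
    """Candidate nodes at each depth, assuming a full tree with these branch factors."""
    counts = []
    width = 1
    for k in k_config:
        width *= k  # each node at the previous depth expands into k children
        counts.append(width)
    return counts


if __name__ == "__main__":
    config = parse_k_config("4,2,2")
    levels = nodes_per_level(config)
    print(config, levels, sum(levels))
```

Under that assumption, 4,2,2 gives 4, 8 and 16 candidates at depths one to three, or 28 draft tokens for the target model to verify in one pass, which is why larger branch factors trade extra compute for a higher chance of acceptance.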

📌 Example Commands and Outputs

For detailed example scripts and outputs, refer to examples.md.
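
As a quick orientation before opening examples.md, the following sketch assembles an MTAD invocation from the options listed under Important Options. It is an illustration under stated assumptions: build_command is a hypothetical helper, the model paths are placeholders, and --mtad and --fp16 are assumed to be boolean switches that take no value.

```python
import shlex

SUPPORTED_DATASETS = ("spider", "human_eval")


def build_command(dataset, draft_model, target_model, *, beam_width=4,
                  accept_thres=0.5, fp16=False, max_new_tokens=None):
    """Assemble an argument list for evaluation.py that runs MTAD decoding."""
    if dataset not in SUPPORTED_DATASETS:
        raise ValueError(f"dataset must be one of {SUPPORTED_DATASETS}")
    cmd = [
        "python", "evaluation.py",
        "--dataset", dataset,
        "--draft-model", draft_model,
        "--target-model", target_model,
        "--mtad",  # assumed to be a boolean switch
        "--beam-width", str(beam_width),
        "--accept-thres", str(accept_thres),
    ]
    if fp16:
        cmd.append("--fp16")  # assumed to be a boolean switch
    if max_new_tokens is not None:
        cmd += ["--max-new-tokens", str(max_new_tokens)]
    return cmd


if __name__ == "__main__":
    # Placeholder paths; substitute your own checkpoints.
    command = build_command("human_eval", "path/to/draft-model",
                            "path/to/target-model", fp16=True, max_new_tokens=128)
    print(shlex.join(command))
```

The printed line can be pasted into a shell, or the list can be passed to subprocess.run directly. The defaults for beam width and acceptance threshold mirror the documented ones, so they can be omitted when the defaults are wanted.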


โš ๏ธ Notes

  • SpecInfer utilizes tree attention, which is only implemented for the Llama model.
  • MTAD does not require tree attention, so you can directly use AutoModelForCausalLM with MTAD.

Now, you're all set to use MTAD for efficient LLM inference! 🚀