STree

July 26, 2025 ยท View on GitHub

This repo contains the code for STree: Speculative Tree Decoding for Hybrid State-Space Models. It is adapted from MambaInLlama repo.

Setup

We ran our experiments on a Nvidia RTX3090 machine with CUDA 11.8

Creating

pip install torch==2.3.0 --index-url https://download.pytorch.org/whl/cu118
pip install packaging
pip install evaluate
pip install causal_conv1d==1.4
pip install flash-attn==2.6.3
pip install transformers==4.49.0
pip install triton==3.0.0
# for evaluation
pip install shortuuid fastchat psutil accelerate

Usage

Distilling small mamba2 model for drafting

ACCELERATE_LOG_LEVEL=info accelerate launch --config_file deepspeed_zero3.yaml train_mamba2/train_distill_mamba2.py llama3.2_3B/distilled_llama2.yaml

Or

Downloading our distilled draft model

Our draft model can be found on huggingface To download and put it in the correct location:

cd checkpoint
git clone https://huggingface.co/ycwu97/mamba2-distilled-small

Inference

The inference entry point is located at benchmarks/benchmark_generation_mamba_simple.py. It provides a lot of options:

Baseline auto-regressive inference with Mamba2InLlama8B (use JunxiongWang/Llama3.2-Mamba2-3B-distill for 3B models for testing)

python benchmarks/benchmark_generation_mamba_simple.py --model-name JunxiongWang/Llama3.1-Mamba2-8B-distill --prompt "Earth is a planet." --cg

On a Nvidia 3090 GPU, the decoding time is around 2118ms

Speculative tree decoding with static tree as draft

python benchmarks/benchmark_generation_mamba_simple.py --model-name JunxiongWang/Llama3.1-Mamba2-8B-distill --prompt "Earth is a planet." --cg --spec_Ngram --use_tree_decoding --activation_replay --jit_state_copy --npad=4 --ndraft=1 --draft_num_beam=3 --strategy MIL-st

On a Nvidia 3090 GPU, the decoding time is around 1456ms

Vanilla Speculative Decoding

python benchmarks/benchmark_generation_mamba_simple.py --model-name JunxiongWang/Llama3.1-Mamba2-8B-distill --prompt "Earth is a planet." --cg --spec_Ngram --use_Nstep_kernel --activation_replay --jit_state_copy --npad=4 --ndraft=1 --strategy MIL

On a Nvidia 3090 GPU, the decoding time is around 1585ms

Evaluation

Evaluation scripts are found in scripts folder. Each block of commands can be copied into terminal and ran.

Code explanation

Key kernels are implemented in triton in mamba_ssm/ops/tree_scan.py and mamba_ssm/ops/selective_scan_update.py Mamba2 layers in mamba_ssm/modules/mamba2.py are modified to use tree scan kernel and selective Nstep kernel, together with other features to speed up.