Speculative Decoding
May 13, 2026 · View on GitHub
llama.cpp supports speculative decoding, a technique that can significantly accelerate token generation by predicting multiple tokens ahead of the main model.
Speculative decoding leverages the fact that computing n tokens in a batch (as in prompt processing) is more efficient than computing n sequentially (as in response generation). By generating draft tokens quickly and then verifying them with the target model in a single batch, this approach can achieve substantial speedups when the draft predictions are frequently correct.
Implementations
The llama-server application supports several implementations of speculative decoding. An implementation with draft model can be mixed with an implementation without draft model.
Multimodal (--mmproj) compatibility (atomic-llama-cpp-turboquant)
When --mmproj is set, mtp, nextn, and eagle3 speculative types remain enabled at load: their draft paths do not depend on get_text_tokens() / prompt_tgt the way draft and ngram_* do. Other types are auto-disabled at load with a warning. Mixed speculative chains (e.g. ngram_simple + draft) are rejected at slot init if any impl is not multimodal-safe.
Per-turn behaviour:
- Text-only turns on a multimodal slot — draft head runs as usual (same acceptance as without
--mmproj). - Turns containing an image chunk — the slot logs
skipping speculative prime for multimodal promptand falls back to plain target decoding for that turn only. The image is still recognised correctly, just without draft speedup.
The fallback is required because NextN / MTP prime needs per-token target hidden states for every prompt position, but mtmd image-decode currently only emits an output row for the last token of each image batch. Lifting this restriction is on the roadmap (mtmd batches need to mark every position with logits[i] = true, or be replayed teacher-forced).
See common_speculative_is_mtmd_safe / common_speculative_all_impls_mtmd_safe in common/speculative.cpp, the mmproj gates and the skip_draft_mtmd per-turn gate in tools/server/server-context.cpp. Validated configurations and the end-to-end recipe are documented in NEXTN.md §10.
Draft Model (draft)
A much smaller model (called the draft model) generates drafts. A draft model is the most used approach in speculative decoding.
Gemma 4 MTP assistant (mtp)
For Gemma 4 targets with the gemma4_assistant (MTP head) GGUF, use --spec-type mtp. The assistant is not a second llama_context: weights are loaded into the target model via llama_model_load_mtp_from_file (done automatically when using the server/CLI init path). Cross-attention in the MTP graph reads K/V from the target KV cache (shared full/sliding layers).
- Prefer
--mtp-head /path/to/assistant.gguffor clarity;--model-draft(-md) is accepted as a backward-compatible alias (same path field). - The draft block size (B) is
--draft-block-size(the head proposesB - 1tokens per round; default 4). --gpu-layers-draft/-ngldand-ctkd/-ctvdstill apply to how the assistant tensors are placed and typed when the assistant GGUF is loaded; the target uses-ngland-ctk/-ctv.
Example (paths illustrative). TurboQuant KV on the target: -ctk/-ctv. Assistant-side cache types follow the draft flags if you use them for offload/quant selection.
llama-server \
-m /path/to/gemma-4-target.gguf \
--mtp-head /path/to/gemma-4-assistant.gguf \
--spec-type mtp \
--draft-block-size 4 \
-c 16384 \
-ngl 99 -ngld 99 \
-ctk turbo3 -ctv turbo3 \
-ctkd turbo3 -ctvd turbo3 \
-fa on \
--host 127.0.0.1 --port 8080
Repo helper (defaults under .scratch/): scripts/run-gemma4-mtp-server.sh.
Async MTP draft pipeline
The MTP draft head runs on a dedicated ggml_backend_sched (sched_mtp) and a
worker thread, isolated from the target's scheduler. This is exposed via two C
APIs:
LLAMA_API int32_t llama_decode_mtp_async(
struct llama_context * ctx,
llama_seq_id seq_id,
llama_pos attn_pos,
llama_token last_token,
const float * h_prev,
int32_t n_steps);
LLAMA_API int32_t llama_decode_mtp_wait(
struct llama_context * ctx,
llama_token * out_drafts,
float * out_h_prev_last);
Contract:
- At most one in-flight request per context. Calling
_asyncwhile a previous request has not been waited on returns-7. h_previs copied into the request, so the caller may free or reuse it immediately after_asyncreturns.- Target KV positions ≤
attn_posmust remain stable until_waitreturns. The current append-only KV cache satisfies this as long as no eviction or overlappingseq_rmhappens between submit and wait. llama_decode_mtp(...)is preserved as a backward-compatible synchronous facade (=_asyncimmediately followed by_wait).
Why use the async pair?
- Graph-cache isolation: MTP graph reuse is no longer invalidated by target
decode resets. On Gemma 4 + Q4_K_XL, this alone delivered ~+8% throughput
in single-slot benchmarks (95.3 → 102.8 tps at
--draft-block-size 3), with identical accept rate. - Pipeline depth-2 (pure overlap,
llama-server): after target sample/accept andllama_memory_seq_rm, the server callscommon_speculative_prepare_next(spec, last_accepted_token), which submitsllama_decode_mtp_asyncfor the next round. The blockingllama_decode_mtp_waitis deferred to the start of the nextcommon_speculative_draft, so MTP work can overlap with post-accept bookkeeping and token I/O. This uses the real sampled token (no optimistic guess). KV contract: submit usesattn_posafterseq_rm; the nextllama_decodeappends only positions> attn_pos, so backbone cells read by MTP remain stable until_wait(append-only cache). Stale in-flight requests are drained incommon_speculative_beginand on skip / param-mismatch paths. - In-graph argmax: the MTP graph publishes a
ggml_argmaxof the final logits (I32 [1]) viallm_graph_result::get_argmax(). Per draft step the host reads back 4 bytes (one token id) instead of the full F32 [n_vocab] logits row, and the serial CPU argmax overn_vocabis gone. The full logits are still computed in-graph and can be fetched on demand by passing a non-nullout_logitsto the synchronousllama_decode_mtp(...)API (diagnostic / legacy path). On Gemma 4 + Q4_K_XL this delivered an additional ~+2-3% throughput (109.5 → 112.5 tpsatn=128,95.8 → 97.8 tpsatn=512), with bit-identical greedy drafts. - Future work: optimistic last-token prediction could hide an additional
llama_decodelatency on hits but risks wrong drafts on misses.
Diagnostic: per-draft acceptance trace
Set LLAMA_MTP_ACC_TRACE to enable a per-iteration NDJSON tracer in
common/speculative.cpp. Each mtp_draft event records iter, path
(sync / lazy / skip-streak / skip-nsteps), seq_id, id_last,
h_idx, attn_pos, n_steps, h_l2 (L2 norm of the input hidden state),
and the drafted token ids; each mtp_accept event records the iteration
counter, n_accepted, and n_drafted_prev. The host can pair them by iter
to reconstruct per-position acceptance, MTP h_prev stability, and
selection-bias breakdowns by h_idx.
# write trace to stderr
LLAMA_MTP_ACC_TRACE=1 ./llama-server [...]
# write trace to a file
LLAMA_MTP_ACC_TRACE=/tmp/mtp.ndjson ./llama-server [...]
Throughput tuning (Edge assistants / heavy backends)
Ordered-embeddings MTP (use_ordered_embeddings) runs a centroid LM head (top_k, routed get_rows) per draft step; the greedy graph still materializes a full-vocab logits row (masked fill + scatter) so argmax and optional full-row export stay consistent with verify.
LLAMA_MTP_SKIP_STREAK_THRESHOLD: off by default (0/ unset). If set to1–32, after that many consecutive batches with zero accepted MTP drafts, MTP drafting is skipped for one verify-only batch, then retried.- Helper scripts
scripts/run-gemma4-e4b-mtp-server.shandscripts/run-gemma4-e2b-mtp-server.shsupportMTP_PRESET:throughput(block 2 / max 6),lift(block 3 / max 8),balanced, orquality. Override withDRAFT_BLOCK_SIZE,DRAFT_MAX, and optionallyLLAMA_MTP_SKIP_STREAK_THRESHOLD.
Disabled at zero overhead (no env var or value 0 / empty); when enabled the
extra cost is one n_bb-wide L2 reduction per draft and a small NDJSON write.
Reconverting gemma4_assistant from Hugging Face
Example (paths relative to repo root):
python convert_hf_to_gguf.py .scratch/gemma-4-26B-A4B-it-assistant \
--outfile .scratch/gemma-assistant-mtp.gguf --outtype f16
Use the resulting GGUF as --mtp-head (or -md) with --spec-type mtp. Older assistant GGUFs with token_embd.weight first axis 2816 (backbone width) instead of 1024 will fail load; run scripts/verify-gemma4-assistant-gguf.py on the file to check.
Qwen 3.x NextN (nextn)
For Qwen3.6 (and compatible) checkpoints that ship NextN head weights in the combined *_MTP.gguf, use --spec-type nextn with --model-draft (-md) pointing at the same GGUF as the main model. The server detects this and reuses the already-loaded target llama_model — a second llama_context is built over the same weights with llama_context_params.nextn_draft = true, which routes graph construction to llm_build_qwen35_nextn / llm_build_qwen35moe_nextn and sizes the draft KV cache only for the NextN layer (kv_only_nextn = true, mutated transparently inside llama_context ctor). There is no second mmap of the GGUF.
- Drafting reads CPU-copied pre-final-norm hidden states (
embeddings_pre_normpath); it does not use Gemma'sllama_decode_mtp_*APIs. llama_set_nextnonly pairs target and draft forllama_context_nextn_seq_rm; seeNEXTN.mdfor details.- Standalone NEXTN_ONLY GGUFs (
general.architecture = qwen35*_mtp) are still supported as a fallback for users who ship the draft head as a separate artifact (the server then performs a secondllama_model_load_from_filewithoverride_arch); the shared-model path is preferred whenever--modeland--model-draftpoint at the same combined_MTP.gguf.
llama-server \
-m /path/to/qwen3.6-MTP.gguf \
-md /path/to/qwen3.6-MTP.gguf \
--spec-type nextn \
--draft-max 2 --draft-min 1 \
-c 8192 -ngl 99 -ngld 99 -fa on \
--host 127.0.0.1 --port 8080
Pair with -ctk turbo3 -ctv turbo3 to compose with TurboQuant KV — on MoE targets (e.g. Qwen 3.6 35B-A3B) this combination is +24-36% tps over the turbo3 baseline at single-slot in the matrix bench (see NEXTN.md §7).
Repo helpers: scripts/run-qwen36-27b-nextn-server.sh, scripts/run-qwen36-35ba3b-nextn-server.sh.
n-gram Cache (ngram-cache)
An n-gram is a sequence of n tokens. The n-gram cache implementation maintains statistics about short n-gram sequences. A draft is computed using probabilities derived from these statistics. External statistics can also be loaded from files for improved accuracy.
See:
- #5479, #6828, #6848
n-gram Map (ngram-simple, ngram-map-*)
These implementations search the token history for patterns and use matching sequences as draft candidates. They require no additional model but rely on patterns that have already appeared in the generated text. An example to use this approach can be the rewriting of source code by a LLM.
n-gram Map (ngram-simple)
This implementation looks for the last n-gram in history that matches the current n-gram and creates a draft using the m tokens following the matched n-gram. It is the simplest self-speculative approach with minimal overhead.
llama-server [...] --spec-type ngram-simple --draft-max 64
n-gram Map Key (ngram-map-k)
This implementation looks for the current n-gram of size n (called the key) in the token history. If the key n-gram is followed by the same m tokens (called the mgram) multiple times, it creates a draft using these m tokens. This approach requires a minimum number of occurrences (argument --spec-ngram-min-hits, default is 1) before generating drafts.
The number of accepted tokens is stored for each used n-gram.
Example:
llama-server [...] --spec-type ngram-map-k --draft-max 64
n-gram Map Key-4-Values (ngram-map-k4v)
This experimental implementation looks for the current n-gram of size n (called the key) in the token history. For each key, up to four values (n-grams of size m, called mgrams) are tracked. An internal statistic counts the occurrences of each mgram after the key n-gram. If one mgram is significantly more frequent than the others, it is used as the draft.
The number of accepted tokens is stored for each used n-gram.
Example: Server options to be used if there are a lot of longer repetitions.
llama-server [...] --spec-type ngram-map-k4v --spec-ngram-size-n 8 --spec-ngram-size-m 8 --spec-ngram-min-hits 2 --draft-max 64
n-gram Mod (ngram-mod)
Add basic ngram hasher for speculative decoding:
- For each ngram, compute a hash using LCG
- For each computed hash, store the next token
- During speculation, iteratively compute the rolling hash of the last n tokens and pick the next token from the storage
Some characteristics:
- Lightweight (~16 MB)
- Constant memory and complexity
- Can generate variable draft lengths (i.e. m is not fixed)
Currently, a single hash pool is shared across all server slots, so different requests can benefit from each other.
Sample usage:
# notes:
# - small `n` are not recommended
# - MoEs require long drafts
# - dense models: can reduce `--draft-min` and `--draft-max`
llama-server ... --spec-type ngram-mod --spec-ngram-size-n 24 --draft-min 48 --draft-max 64
Applications:
- Iterating over a block of text/code (e.g. in llama.vim)
- Reasoning models (when they have to repeat their thinking in the final answer)
- Summarization
Example Video:
- See #19164
Differences between ngram-simple, ngram-map and ngram-mod
- ngram-simple looks for a previous matching n-gram and inserts the following m-gram.
- ngram-map-k looks for a previous matching n-gram and inserts the following m-gram but uses an internal hash-map of n-grams in the current context window.
- ngram-mod uses a hash pool which is shared across all server slots. The hash pool is a map from n-gram hash to the next token (not the next m-gram as in ngram-map).
Command-Line Options
If a draft model is combined with a draftless decoding the draftless decoding has higher precedence.
--draft, --draft-n, --draft-max N number of tokens to draft for speculative decoding (default: 16)
(env: LLAMA_ARG_DRAFT_MAX)
--draft-min, --draft-n-min N minimum number of draft tokens to use for speculative decoding
(default: 0)
(env: LLAMA_ARG_DRAFT_MIN)
[...]
--spec-type [none|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v|ngram-mod]
type of speculative decoding to use when no draft model is provided
(default: none)
--spec-ngram-size-n N ngram size N for ngram-simple/ngram-map speculative decoding, length
of lookup n-gram (default: 12)
--spec-ngram-size-m N ngram size M for ngram-simple/ngram-map speculative decoding, length
of draft m-gram (default: 48)
--spec-ngram-min-hits N minimum hits for ngram-map speculative decoding (default: 1)
--spec-type TYPE
Specifies a type of speculative decoding without draft model.
| Type | Description |
|---|---|
none | No speculative decoding (default) |
ngram-cache | Use n-gram cache lookup |
ngram-simple | Use simple n-gram pattern matching |
ngram-map-k | Use n-gram pattern matching with n-gram-keys |
ngram-map-k4v | Use n-gram pattern matching with n-gram-keys and up to four m-gram values (experimental) |
ngram-mod | Use basic ngram hasher for speculative decoding with shared pool |
Example: Server-instance used to refactor source code.
./llama-server [...] --spec-type ngram-simple
--spec-ngram-size-n N
Sets the size N of the lookup n-gram for n-gram map based speculative decoding. The n-gram size N determines how many tokens in a row to look back when searching for matching patterns.
--spec-ngram-size-m M
Sets the size M of the draft m-gram for n-gram map based speculative decoding. The m-gram size determines how many tokens to draft when a match is found. Larger values can provide more speedup but may reduce acceptance rate.
--spec-ngram-min-hits H
This option defines how often a key has to appear in the token history to be used as a draft (default is 1).
Statistics
Each speculative decoding implementation prints statistics.
draft acceptance rate = 0.57576 ( 171 accepted / 297 generated)
statistics ngram_simple: #calls = 15, #gen drafts = 5, #acc drafts = 5, #gen tokens = 187, #acc tokens = 73
statistics draft: #calls = 10, #gen drafts = 10, #acc drafts = 10, #gen tokens = 110, #acc tokens = 98
draft acceptance rate = 0.70312 ( 90 accepted / 128 generated)
statistics ngram_mod: #calls = 810, #gen drafts = 15, #acc drafts = 15, #gen tokens = 960, #acc tokens = 730, dur(b,g,a) = 0.149, 0.347, 0.005 ms
statistics ngram_map_k: #calls(b,g,a) = 6 1690 26, #gen drafts = 26, #acc drafts = 26, #gen tokens = 1248, #acc tokens = 968, dur(b,g,a) = 2.234, 1.427, 0.016 ms
#calls(b,g,a): number of calls of begin (new prompt), generation and accumulation of this implementations#gen drafts: number of drafts generated by this implementation#acc drafts: number of drafts accepted (partially) by the main model#gen tokens: number of tokens generated by this implementation (including rejected tokens)#acc tokens: number of tokens accepted by the main model- `dur(b,g,a): durations of begin (new prompt), generation and accumulation (process acceptance).