TrimKV: Token Retention for Memory-Bounded Key-Value Eviction

May 13, 2026

🚀 Updates

  • 🆕 DBTrimKV: a dynamic-budget variant powered by PagedTrimKVCache. A single global KV budget is shared across all layers and heads and reallocated at every step, instead of fixing a per-head budget upfront. The final projection of the retention gate is tied across layers and heads, and the runtime uses PagedTrimKVCache, a paged-attention-style cache whose blocks are dynamically (re)assigned to the heads that currently need them. As a result, DBTrimKV significantly outperforms TrimKV at low KV budgets and matches or even beats the full KV cache, without any per-head tuning. It uses the same training surface as TrimKV, with two env-var flips (RETENTION_GATE=rg10, GLOBAL_CAPACITY=True). Public LLM checkpoints are listed under Released models.
  • 🆕 First VLM release: TrimKV and DBTrimKV go multimodal. Full Qwen3-VL / Qwen2.5-VL / LLaVA support, with end-to-end training recipes in train/vlm/ and evaluation harnesses in experiments/lmms-eval/ and experiments/mmdu/. Auto-downloading data prep for R1-Onevision, M4-Instruct, LLaVA-Video-178K, MMDU, and OpenR1-Math-220k lives under train/vlm/scripts/data/. Public VLM checkpoints (DBTrimKV) are listed under Released models.
  • Codebase refactor for transformers v4.57.0. This release freezes the codebase at a version close to what produced the paper results, so all reported numbers are reproducible. If you hit issues, please open a GitHub issue.

What is TrimKV?

TrimKV is a learnable key–value (KV) eviction strategy that makes long-horizon inference with large language models (LLMs) more efficient.

Imagine if our brain worked like a transformer:

[Figure: teaser]

This is because it tried to remember every single detail (token) forever. TrimKV lets your model forget the parts that aren't very important so it doesn't melt its VRAM. Don't let the brain (or GPU) explode. 💥🧠

The core idea behind TrimKV is to learn the intrinsic importance of each key–value pair at creation time, which we call token retention, and then decay this importance exponentially over time to mimic standard inference running with eviction.

The retention score is query-agnostic and captures the long-term utility of tokens. This is different from attention scores, which are query-dependent: they capture short-term utility for predicting the next token, are recomputed at every step, and are highly dependent on the transient decoding state.
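The decay-and-evict idea can be sketched in a few lines. This is a toy illustration under stated assumptions, not the repository's implementation: it assumes each token gets a scalar retention score beta in (0, 1) at creation time, that its importance at a later step is beta raised to the token's age, and that the cache drops the lowest-importance entries once it exceeds its budget. The names `decayed_importance` and `evict_to_budget` are hypothetical.

```python
def decayed_importance(beta: float, created_at: int, now: int) -> float:
    """Importance of a token at step `now`, decayed exponentially with its age."""
    return beta ** (now - created_at)


def evict_to_budget(cache: dict[int, float], now: int, budget: int) -> dict[int, float]:
    """Keep the `budget` entries with the highest decayed importance.

    `cache` maps token position -> retention score assigned at creation time.
    """
    if len(cache) <= budget:
        return dict(cache)
    ranked = sorted(
        cache,
        key=lambda pos: decayed_importance(cache[pos], pos, now),
        reverse=True,
    )
    kept = sorted(ranked[:budget])
    return {pos: cache[pos] for pos in kept}


# Toy run: the retention scores are made up; a real gate would predict them.
betas = [0.99, 0.50, 0.95, 0.60, 0.90]
cache: dict[int, float] = {}
for step, beta in enumerate(betas):
    cache[step] = beta  # the score is fixed when the KV pair is created
    cache = evict_to_budget(cache, now=step, budget=3)

print(sorted(cache))  # [0, 2, 4]
```

Note that no query enters the ranking: each score is set once and only ages, which is what makes the criterion query-agnostic. Tokens 1 and 3 are evicted because their low retention scores decay fastest.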

TrimKV vs DBTrimKV

Both variants share the same training loop, datasets, and loss surface. They differ in how the KV budget is allocated, which retention-gate parameterisation is used, and which cache class powers inference:

| | TrimKV | DBTrimKV (new) |
| --- | --- | --- |
| Budget semantics | per-layer, per-head local budget M_local = M | single global budget M_global = M × num_layers × num_heads, redistributed dynamically across layers/heads |
| Gate parameterisation | independent retention gate per head | final projection of the gate tied across layers and heads |
| Inference cache | TrimKVCache (fixed per-head allocation) | PagedTrimKVCache: paged-attention-style blocks dynamically (re)assigned to heads that currently need capacity |
| RETENTION_GATE flag | rg | rg10 |
| GLOBAL_CAPACITY flag | False | True |

DBTrimKV combines the global retention gate with PagedTrimKVCache, which lets it run at much tighter average budgets while preserving accuracy: heads with high retention demand temporarily borrow capacity from heads with low demand on a per-step basis. See train/llm/README.md and train/vlm/README.md for the full training surface, and src/trimkv/cache_utils.py for the cache implementations.
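The difference between the two budget semantics can be illustrated with a small sketch. It is a simplified model, not the actual cache logic: it assumes retention scores are directly comparable across heads, flattens (layer, head) pairs into a single list of "heads", and ignores paging and block granularity. The function names are hypothetical.

```python
def local_budget_keep(scores_per_head: list[list[float]], m: int) -> list[int]:
    """Per-head budget: every head keeps at most `m` entries, whatever their scores."""
    return [min(m, len(scores)) for scores in scores_per_head]


def global_budget_keep(scores_per_head: list[list[float]], m: int) -> list[int]:
    """Global budget: one pool of m * num_heads entries, filled by global rank."""
    total = m * len(scores_per_head)
    flat = [
        (score, head)
        for head, scores in enumerate(scores_per_head)
        for score in scores
    ]
    flat.sort(key=lambda item: item[0], reverse=True)
    kept = [0] * len(scores_per_head)
    for _, head in flat[:total]:
        kept[head] += 1
    return kept


scores = [
    [0.9, 0.8, 0.7, 0.6],    # head 0: many tokens worth retaining
    [0.65, 0.1, 0.1, 0.1],   # head 1: mostly forgettable tokens
]
print(local_budget_keep(scores, m=2))   # [2, 2]
print(global_budget_keep(scores, m=2))  # [3, 1]
```

Both schemes keep four entries in total, but the global pool lets head 0 borrow one slot that head 1 would have spent on a low-retention token.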

Why TrimKV?

It's fast

[Figure: teaser]

It's smart

[Figure: teaser]

Getting started

Requirements

  • Python 3.11 or higher (tested with 3.12)
  • PyTorch 2.7.0 or higher (tested with 2.8.0)
  • FlashAttention 2.7.2.post1 or higher (tested with 2.8.0)
  • Transformers 4.57.1
pip install -r requirements.txt

This is a minimal set of requirements for training. Additional dependencies may be needed for individual experiments; see examples/env.yaml for a full reproducible environment.

Installation

pip install trimkv

Quick start

import torch
from trimkv.models.qwen3 import TrimKVQwen3ForCausalLM
from trimkv.cache_utils import TrimKVCache, PagedTrimKVCache
from transformers import AutoTokenizer

# Pick any TrimKV / DBTrimKV checkpoint from the table below
model_path = "ngocbh/DBTrimKV-Qwen3-4B-Math"
download_from = "huggingface"  # also: "wandb", "local"

model = TrimKVQwen3ForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    load_trimkv_weights=True,
    download_from=download_from,
    use_cache=True,
    device_map="cuda",
)
model.config._attn_implementation = "flash_attention_2"

tokenizer = AutoTokenizer.from_pretrained(
    model.config.base_model, use_fast=True, padding_side="left"
)

# PagedTrimKVCache is the inference-time cache used by DBTrimKV. It allocates a
# global pool of blocks and (re)assigns them to heads on the fly so heads with
# high retention demand can borrow capacity from heads with low demand.
# For (non-DB) TrimKV, swap in TrimKVCache(memory_size=..., buffer_size=..., device="cuda").
past_key_values = PagedTrimKVCache(
    num_layers=model.config.num_hidden_layers,
    num_heads=model.config.num_key_value_heads,
    max_seq_len=32768,
    memory_size=128,
    num_blocks_ratio=1.0,
    buffer_size=32,
    strategy="fixed_budget",
    device="cuda",
)

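# Use model.generate as normal; pass past_key_values to enable TrimKV eviction.
# The prompt and max_new_tokens below are illustrative values.
inputs = tokenizer("What is 12 * 7?", return_tensors="pt").to(model.device)
output_ids = model.generate(
    **inputs, past_key_values=past_key_values, max_new_tokens=256
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))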

For a runnable end-to-end example see examples/test_qwen3.py. VLM checkpoints use TrimKVQwen3VLForConditionalGeneration from trimkv.models.qwen3_vl and the same PagedTrimKVCache, but read model.config.text_config.num_hidden_layers / num_key_value_heads instead.
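To get a feel for what a budget like memory_size=128 means in memory terms, here is a back-of-the-envelope estimate. It is a rough sketch only: it assumes memory_size plays the role of the average per-head budget M (so the global pool holds M × num_layers × num_heads token slots), uses bfloat16 (2 bytes per element), and ignores buffer_size, num_blocks_ratio, and any bookkeeping overhead. The model shape below is an illustrative placeholder; read the real values from model.config.

```python
def kv_cache_bytes(tokens_per_head: int, num_layers: int, num_kv_heads: int,
                   head_dim: int, bytes_per_elem: int = 2) -> int:
    """Bytes needed to hold keys and values for `tokens_per_head` tokens per head."""
    slots = tokens_per_head * num_layers * num_kv_heads
    return slots * 2 * head_dim * bytes_per_elem  # 2 = one key + one value


# Illustrative shape only (assumed values, not read from a checkpoint).
num_layers, num_kv_heads, head_dim = 36, 8, 128

budgeted = kv_cache_bytes(128, num_layers, num_kv_heads, head_dim)
full = kv_cache_bytes(32768, num_layers, num_kv_heads, head_dim)

print(f"global slots: {128 * num_layers * num_kv_heads}")
print(f"budgeted: {budgeted / 2**20:.1f} MiB, full: {full / 2**30:.1f} GiB")
print(f"ratio: {full // budgeted}x")
```

Under these assumptions the ratio is simply max_seq_len / memory_size, since every other factor cancels.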


Training

  • LLMs (Qwen3, Qwen2, Llama, Phi-3): train/llm/, built on DeepSpeed + 🤗 Trainer. Two recipes: train_trimkv_long.sh for long-context KL distillation and train_trimkv_math.sh for R1-style math reasoning. The same recipes train both TrimKV and DBTrimKV; flip RETENTION_GATE / GLOBAL_CAPACITY to switch.
  • VLMs (Qwen2.5-VL, Qwen3-VL, LLaVA): train/vlm/, the same harness extended for visual data. Auto-downloading data prep for R1-Onevision, M4-Instruct, LLaVA-Video-178K, MMDU, and OpenR1-Math-220k lives under train/vlm/scripts/data/.

Experiments

Per-benchmark evaluation harnesses live in experiments/; see experiments/README.md for the full index.

  • Baselines: TrimKV, DBTrimKV, R-KV, SeerAttention, SnapKV, StreamingLLM, H2O, KeyDiff, LocRet.
  • Long-horizon generation: GSM8K, MATH-500, AIME-24, LongProc.
  • Long-context understanding: SCBench, LongMemEval, LongBench, LongBench v2.
  • Multimodal: lmms-eval task suite (mathvision_testmini, video_mmmu_*, mmmu_pro_vision, videomme, videomathqa_mcq, mmstar) plus MMDU.

Released models

LLM checkpoints

| Base Model | Variant | Checkpoint | Training Datasets | Max Context Len | Training M |
| --- | --- | --- | --- | --- | --- |
| Qwen3-1.7B | TrimKV | TrimKV-Qwen3-1.7B-Math | OpenR1-Math-220k | 16K | 512 |
| Qwen3-4B | TrimKV | TrimKV-Qwen3-4B-Math | OpenR1-Math-220k | 16K | 512 |
| Qwen3-8B | TrimKV | TrimKV-Qwen3-8B-Math | OpenR1-Math-220k | 16K | 512 |
| Qwen3-14B | TrimKV | TrimKV-Qwen3-14B-Math | OpenR1-Math-220k | 16K | 512 |
| Qwen3-4B-Instruct-2507 | TrimKV | TrimKV-Qwen3-4B-Instruct-2507 | Synth-Long, BookSum, Buddhi | 128K | 4096 |
| Phi-3-mini-128k-instruct | TrimKV | TrimKV-Phi-3-mini-128k-instruct | LongAlpaca | 128K | 2048 |
| Qwen3-4B | DBTrimKV 🆕 | DBTrimKV-Qwen3-4B-Math | OpenR1-Math-220k | 32K | 128 |
| Qwen3-4B-Instruct-2507 | DBTrimKV 🆕 | DBTrimKV-Qwen3-4B-Instruct-2507 | Synth-Long, BookSum, Buddhi | 128K | 512 |

VLM checkpoints: first multimodal release 🆕

| Base Model | Variant | Checkpoint | Training Datasets | Max Context Len | Training M |
| --- | --- | --- | --- | --- | --- |
| Qwen3-VL-8B-Thinking | DBTrimKV | DBTrimKV-Qwen3-VL-8B-Thinking | R1-Onevision, M4-Instruct, LLaVA-Video-178K, MMDU, OpenR1-Math-220k | 32K | 32 |
| Qwen3-VL-4B-Instruct | DBTrimKV | DBTrimKV-Qwen3-VL-4B-Instruct | M4-Instruct, MMDU | 32K | 32 |

If you have trained your own checkpoints for other settings, we are happy to list them here.


Citation

@article{bui2025cache,
  title={Cache what lasts: Token retention for memory-bounded kv cache in llms},
  author={Bui, Ngoc and Sharma, Shubham and Lamba, Simran and Mishra, Saumitra and Ying, Rex},
  journal={arXiv preprint arXiv:2512.03324},
  year={2025}
}
@article{bui2025make,
  title={Make Each Token Count: Towards Improving Long-Context Performance with KV Cache Eviction},
  author={Bui, Ngoc and Nguyen, Hieu Trung and Cohan, Arman and Ying, Rex},
  journal={arXiv preprint arXiv:2512.03324},
  year={2025}
}

Acknowledgements

A large portion of this repository is adapted from or built on top of the following projects: