TurboQuant MLX

March 25, 2026

Reproduction of the KV-cache quantization scheme from TurboQuant (Google, 2025) on Apple Silicon using MLX.

Result: up to 5.5x KV-cache compression. There are two paths: V2 (hardware-accelerated via mx.quantized_matmul) for speed, and V3 (Lloyd-Max codebook, paper-correct) for quality at low bit widths. The implementation uses mostly MLX-native ops, plus a custom Metal kernel for fused QJL sign-bit scoring.

Benchmark Results

Tested on Apple M4 Max (64 GB), models from mlx-community (4-bit weight quantized).

Multi-Model Quality (Perplexity, lower is better)

Strategy            bits/dim  Llama 3.2 3B    Llama 3.1 8B    Mistral 7B      Gemma 3 4B
                              (D=128)         (D=128)         (D=128)         (D=256)
────────────────────────────────────────────────────────────────────────────────────────────
fp16 baseline       16        12.94           9.47            6.79            12.18
V2 3-bit rot+QJL    3         13.63 (+5.3%)   10.21 (+7.8%)   7.14 (+5.1%)    12.05 (-1.1%)
V2 4-bit rotated    4         12.84 (-0.8%)   9.61 (+1.4%)    6.89 (+1.4%)    12.53 (+2.9%)
V2 4-bit LEAN       4         13.02 (+0.6%)   9.85 (+4.0%)    6.87 (+1.2%)    12.37 (+1.6%)
V3 3.5-bit mixed    3.5       12.98 (+0.3%)   10.10 (+6.7%)   7.06 (+4.0%)    12.44 (+2.1%)
V3 3.25-bit mixed   3.25      13.57 (+4.8%)   10.25 (+8.3%)   7.17 (+5.6%)    12.74 (+4.6%)
V3 3-bit Lloyd-Max  3         13.60 (+5.1%)   10.28 (+8.6%)   7.27 (+7.0%)    12.93 (+6.2%)
V3 2.75-bit mixed   2.75      14.95 (+15.5%)  11.21 (+18.4%)  7.33 (+7.9%)    13.88 (+14.0%)
V3 2.5-bit mixed    2.5       16.44 (+27.0%)  12.80 (+35.2%)  7.53 (+10.8%)   13.04 (+7.0%)
V3 2-bit Lloyd-Max  2         21.27 (+64.3%)  15.67 (+65.5%)  8.10 (+19.3%)   14.64 (+20.2%)

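Key finding: V2 3-bit rot+QJL slightly beats fp16 on Gemma 3 (D=256), 12.05 vs 12.18 PPL (-1.1%). One possible explanation is that rotation plus QJL correction acts as a regularizer at larger head dimensions. V3 2.5-bit on Gemma (+7.0%) is dramatically better than on Llama 3B (+27.0%), which is consistent with larger head_dim improving quantization quality, although Mistral 7B (also D=128) degrades only +10.8%, so model-specific factors matter too.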

Throughput (Llama 3.2 3B, tok/s at context length T)

Strategy              T=512   T=1024   T=2048   T=4096   T=8192
──────────────────────────────────────────────────────────────────
Standard fp16          208      199      191      175      148
MLX 4-bit Quant        188      188      184      174      156
V2 4-bit LEAN          188      188      184      174      156
V2 4-bit (rotated)     135      133      131      124      115
V2 3-bit rot+QJL       101       96       84       65       45
V3 3.5-bit mixed        82       74       59       42       24
V3 3-bit Lloyd-Max      98       86       70       47       27
V3 2.5-bit mixed        83       75       59       42       24

V2 uses mx.quantized_matmul (a Metal kernel), so the LEAN variant runs at near-native speed. V3 uses software dequantization (centroid lookup + mx.matmul): slower, but with paper-correct quality.

KV-Cache Compression at T=8192

Strategy              Cache Size   Compression
──────────────────────────────────────────────
fp16                  969 MB       1x
V2 4-bit LEAN         266 MB       3.6x
V3 3.5-bit mixed      236 MB       4.1x
V3 3-bit Lloyd-Max    207 MB       4.7x
V3 2.5-bit mixed      177 MB       5.5x
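As a back-of-envelope check of the V2 row, assume mx.quantize stores one fp16 scale and one fp16 bias per group of group_size values (group_size=64 as in the V2 configuration). The effective cost is then bits + 32/group_size bits per element. The helper names below are hypothetical.

```python
def affine_bits_per_element(bits: int, group_size: int, param_bits: int = 16) -> float:
    """Effective storage of affine group quantization.

    Each group of `group_size` values stores `bits` per value plus one scale
    and one bias of `param_bits` each (assumed fp16 here).
    """
    return bits + 2 * param_bits / group_size


def compression_vs_fp16(bits: int, group_size: int) -> float:
    """Compression ratio relative to a 16-bit cache."""
    return 16 / affine_bits_per_element(bits, group_size)


b = affine_bits_per_element(4, 64)
print(f"4-bit, group 64: {b} bits/element, {compression_vs_fp16(4, 64):.2f}x vs fp16")
```

This gives about 3.56x, close to the measured 3.6x. The V3 rows depend on how indices and any per-vector metadata are packed, which this sketch does not model.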

Recommendation

Use Case                   Strategy              Quality (D=128)   Quality (D=256)   Speed (vs fp16)
─────────────────────────────────────────────────────────────────────────────────────────────────────
Maximum speed              V2 4-bit LEAN         +0.6 to +4.0%     +1.6%             ~105% at 8K
Best quality at 4-bit      V2 4-bit rotated      -0.8 to +1.4%     +2.9%             ~78% at 8K
Best 3-bit (D=256)         V2 3-bit rot+QJL      +5.1 to +7.8%     -1.1%             ~30% at 8K
Near-lossless compression  V3 3.5-bit mixed      +0.3 to +6.7%     +2.1%             ~16% at 8K
Balanced                   V3 3-bit Lloyd-Max    +5.1 to +8.6%     +6.2%             ~18% at 8K
Aggressive compression     V3 2.5-bit mixed      +10.8 to +35.2%   +7.0%             ~16% at 8K

Architecture

┌─────────────────────────────────────────────┐
│  mlx-lm (Llama, Mistral, ...)               │
│    ↓ SDPA dispatch (monkey-patch)           │
├─────────────────────────────────────────────┤
│  turboquant.patch                           │
│    → Detects TurboQuant cache objects       │
│    → Routes to V2 or V3 attention           │
├─────────────────────────────────────────────┤
│                                             │
│  V2 Path (Speed)       V3 Path (Quality)    │
│  ┌─────────────────┐   ┌─────────────────┐  │
│  │ attention_v2    │   │ attention_v3    │  │
│  │ mx.quantized_   │   │ Centroid lookup │  │
│  │ matmul (Metal)  │   │ + mx.matmul     │  │
│  ├─────────────────┤   ├─────────────────┤  │
│  │ cache_v2        │   │ cache_v3        │  │
│  │ mx.quantize     │   │ Lloyd-Max       │  │
│  │ affine quant    │   │ codebook quant  │  │
│  │ ± rotation      │   │ + rotation      │  │
│  │ ± QJL           │   │ ± channel split │  │
│  └─────────────────┘   └─────────────────┘  │
│                                             │
├─────────────────────────────────────────────┤
│  Shared: codebook.py, codebook_ops.py,      │
│  qjl.py, rotation.py                        │
├─────────────────────────────────────────────┤
│  MLX Metal Backend                          │
│    → quantized_matmul (V2 only)             │
│    → fused QJL kernel (V2 rot+QJL)          │
│    → all other ops: MLX-native              │
└─────────────────────────────────────────────┘
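The routing layer can be pictured as a wrapper around mlx-lm's SDPA function that checks the cache type and forwards everything else unchanged. The sketch below illustrates that pattern only; it is not the repository's patch.py. TQCacheBase, DemoV2Cache, DemoV3Cache, apply(module) and host are hypothetical names, and a SimpleNamespace stands in for the mlx-lm module.

```python
from types import SimpleNamespace


class TQCacheBase:
    """Hypothetical marker base class shared by TurboQuant caches."""
    path = None


class DemoV2Cache(TQCacheBase):
    path = "v2"


class DemoV3Cache(TQCacheBase):
    path = "v3"


def attention_v2(q, k, v, cache, **kwargs):
    return "v2: quantized_matmul path"


def attention_v3(q, k, v, cache, **kwargs):
    return "v3: centroid-lookup path"


def apply(module):
    """Replace module.scaled_dot_product_attention with a cache-type dispatcher."""
    original = module.scaled_dot_product_attention

    def dispatch(q, k, v, cache=None, **kwargs):
        if isinstance(cache, TQCacheBase):
            route = attention_v2 if cache.path == "v2" else attention_v3
            return route(q, k, v, cache, **kwargs)
        # Any other cache type (or none) goes to the original implementation.
        return original(q, k, v, cache=cache, **kwargs)

    module.scaled_dot_product_attention = dispatch


# Hypothetical stand-in for the module that owns the SDPA function.
host = SimpleNamespace(
    scaled_dot_product_attention=lambda q, k, v, cache=None, **kw: "default SDPA"
)
apply(host)
print(host.scaled_dot_product_attention(None, None, None, cache=DemoV3Cache()))
```

Wrapping rather than replacing the function keeps models that use mlx-lm's own caches on their original code path.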

V2 Variants (Affine Quantization, Hardware-Accelerated)

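Variant       Speed (vs fp16)   Description
───────────────────────────────────────────────────────────────────────────────────────────────────────
LEAN          ~90-105%          No rotation, no QJL: mx.quantize directly. Matches MLX's built-in QuantizedKVCache.
rotated       ~65-78%           Random QR rotation + norm-baking. Best 4-bit quality.
rotated+QJL   ~30-49%           Rotation plus 1-bit QJL residual correction. Fused Metal kernel for sign-bit scoring.

Speed ranges are taken from the Llama 3.2 3B throughput table (T=512 to T=8192).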
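Rotation is safe for attention because an orthogonal matrix R preserves inner products, (Rq)·(Rk) = q·k, while spreading one large channel's energy across all channels before quantization. A pure-Python sketch, assuming a Haar-random rotation built by Gram-Schmidt on Gaussian vectors (the function names and d=32 are illustrative choices, not the repository's rotation.py):

```python
import math
import random


def random_orthogonal(d: int, seed: int = 0) -> list:
    """Haar-random d x d orthogonal matrix with orthonormal rows.

    Modified Gram-Schmidt on i.i.d. Gaussian vectors equals the QR factor with a
    positive-diagonal R, which makes the result Haar-distributed.
    """
    rng = random.Random(seed)
    rows = []
    for _ in range(d):
        v = [rng.gauss(0.0, 1.0) for _ in range(d)]
        for u in rows:  # remove the component along each accepted row
            p = sum(a * b for a, b in zip(v, u))
            v = [a - p * b for a, b in zip(v, u)]
        n = math.sqrt(sum(a * a for a in v))
        rows.append([a / n for a in v])
    return rows


def matvec(m, x):
    return [sum(a * b for a, b in zip(row, x)) for row in m]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


d = 32
R = random_orthogonal(d)
rng = random.Random(1)
q = [rng.gauss(0.0, 1.0) for _ in range(d)]
k = [rng.gauss(0.0, 1.0) for _ in range(d)]
k[3] = 25.0  # one large "outlier" channel

q_rot, k_rot = matvec(R, q), matvec(R, k)
print(dot(q, k), dot(q_rot, k_rot))            # same score up to float rounding
print(max(map(abs, k)), max(map(abs, k_rot)))  # outlier magnitude before vs after
```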

V3 Variants (Lloyd-Max Codebook, Paper-Correct)

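Variant   Channels              Description
───────────────────────────────────────────────────────────────────────────────────────────
uniform   all @ b bits          Lloyd-Max codebook at b bits. Best quality per bit.
mixed     n @ (b+1), rest @ b   Outlier channel splitting. Fractional bit rates (e.g. 2.5, 3.5).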

Paper Reproduction

What was confirmed

  1. Quality-neutral at 4-bit: on Llama 3.2 3B, PPL 13.02 vs 12.94 fp16 (+0.6%); with rotation, 12.84 (-0.8%)
  2. 3.6-5.5x cache compression at T=8192, depending on bit width
  3. Bandwidth crossover: V2 4-bit LEAN matches fp16 throughput at T≈4K (174 vs 175 tok/s) and overtakes it at 8K (156 vs 148 tok/s)
  4. Random rotation (QR) improves 4-bit quality on the Llama models by distributing outlier channels evenly (rotated vs LEAN: -0.8% vs +0.6% on Llama 3.2 3B, +1.4% vs +4.0% on Llama 3.1 8B). On Mistral 7B and Gemma 3 4B, LEAN is slightly better.
  5. Lloyd-Max codebook beats affine at 3-bit: PPL +5-9% vs +9-23% (V3 vs V2)
  6. Outlier channel splitting enables fractional bit rates: V3 3.5-bit mixed reaches +0.3% PPL on Llama 3.2 3B
  7. QJL improves V2 3-bit as an additional correction: from +6.6% to +5.3% on Llama 3.2 3B
  8. Results generalize across Llama 3.2 3B, Llama 3.1 8B, Mistral 7B, Gemma 3 4B
  9. Larger head_dim appears to improve quantization: Gemma (D=256) degrades much less at low bits than Llama (D=128). V3 2.5-bit: +7% (Gemma) vs +27% (Llama 3B)
  10. V2 3-bit rot+QJL beats fp16 on Gemma: PPL 12.05 vs 12.18 (-1.1%), possibly because rotation + QJL acts as a regularizer at D=256

What differs

  • Hardware: The paper tests on an A100 (80 GB HBM2e, ~2.0 TB/s). We test on an M4 Max (64 GB unified memory, ~546 GB/s).
  • Weight precision: The paper tests unquantized (bfloat16) models. We test 4-bit weight-quantized models, whose weight quantization error compounds with the KV-cache quantization error.
  • Kernels: The paper uses custom CUDA kernels for codebook dequantization. We use MLX-native ops plus one custom Metal kernel for QJL scoring: V2 uses mx.quantized_matmul (Metal kernel, fast), while V3 uses software dequantization via centroid lookup (correct, but slow).
  • TurboQuant_prod: The paper's (b-1)-bit MSE + 1-bit QJL scheme does not improve quality at D=128 or D=256 in our tests. QJL works as an additional correction (V2 3-bit rot+QJL) but not as a replacement for MSE bits. See the analysis below.
  • 2-bit quality: Both V3 Lloyd-Max and V2 affine degrade by ~60% at 2-bit on the Llama models (D=128). With channel splitting (2.5-bit mixed), the degradation drops to +7-35% depending on model and head_dim: Gemma (D=256) reaches +7% vs +27% for Llama 3B (D=128).
  • V3 throughput: Without custom Metal kernels for fused codebook dequant + matmul, V3 runs ~2x (T=512) to ~6x (T=8192) slower than V2 4-bit LEAN. With custom CUDA kernels on the A100, the paper presumably avoids this penalty.

Why TurboQuant_prod doesn't help

The paper's TurboQuant_prod uses (b-1)-bit MSE + 1-bit QJL for inner-product-optimal quantization. The QJL correction estimates <q, residual> via the Johnson-Lindenstrauss sign projection.
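For reference, the standard QJL inner-product estimator (as formulated in the QJL paper; the notation here is ours: S is an m×d Gaussian projection, k̂ the (b-1)-bit MSE reconstruction of key k, and r = k - k̂ its residual) is unbiased, and its variance bound is where the O(π/(2d)) scaling comes from when m = d and q, r have unit norm. TurboQuant_prod adds this estimate to the score of the MSE reconstruction:

$$
\widehat{\langle q, r\rangle} = \sqrt{\frac{\pi}{2}}\,\frac{\lVert r\rVert_2}{m}\,\big\langle Sq,\ \operatorname{sign}(Sr)\big\rangle,
\qquad S \in \mathbb{R}^{m\times d},\ S_{ij} \sim \mathcal{N}(0, 1)
$$

$$
\mathbb{E}\big[\widehat{\langle q, r\rangle}\big] = \langle q, r\rangle,
\qquad
\operatorname{Var}\big[\widehat{\langle q, r\rangle}\big] \le \frac{\pi}{2m}\,\lVert q\rVert_2^2\,\lVert r\rVert_2^2
$$

$$
\langle q, k\rangle \approx \langle q, \hat{k}\rangle + \widehat{\langle q, r\rangle}, \qquad r = k - \hat{k}
$$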

In our tests, TurboQuant_prod consistently degrades quality at both D=128 and D=256:

  • V3 3-bit prod (2-bit MSE + QJL): PPL 19.48 vs V3 3-bit MSE: 13.60 (D=128)
  • At D=256 (Gemma head_dim), the gap does NOT shrink — prod remains worse

Root cause: centroid resolution loss through softmax amplification.

The JL estimator variance scales as expected, O(π/(2d)) for unit-norm queries (verified in tests). The real bottleneck is not JL variance but the sharp loss of centroid resolution when going from b-bit to (b-1)-bit:

  • 3-bit (8 centroids): MSE distortion 0.0345σ²
  • 2-bit (4 centroids): MSE distortion 0.1175σ², about 3.4x worse
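These are the classic tabulated Lloyd-Max distortions for a unit Gaussian (≈0.3634, 0.1175, 0.0345 and 0.0095 for 1 to 4 bits). A pure-Python sketch that reproduces them by running Lloyd's iteration with closed-form Gaussian cell moments (function names are ours, not codebook.py's):

```python
import math


def phi(x: float) -> float:
    """Standard normal pdf; 0 at +/-inf."""
    return 0.0 if math.isinf(x) else math.exp(-0.5 * x * x) / math.sqrt(2.0 * math.pi)


def Phi(x: float) -> float:
    """Standard normal cdf (math.erf handles +/-inf)."""
    return 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))


def x_phi(x: float) -> float:
    return 0.0 if math.isinf(x) else x * phi(x)


def lloyd_max_gaussian(bits: int, iters: int = 1000):
    """Lloyd-Max quantizer for N(0, 1): returns (centroids, MSE in units of sigma^2)."""
    n = 2 ** bits
    c = [-2.0 + 4.0 * (i + 0.5) / n for i in range(n)]  # uniform start on [-2, 2]

    def boundaries(c):
        # Optimal decision thresholds are midpoints between adjacent centroids.
        return [-math.inf] + [(c[i] + c[i + 1]) / 2 for i in range(n - 1)] + [math.inf]

    for _ in range(iters):
        t = boundaries(c)
        # Optimal centroid = conditional mean of N(0, 1) on its cell.
        c = [(phi(t[i]) - phi(t[i + 1])) / (Phi(t[i + 1]) - Phi(t[i])) for i in range(n)]

    t = boundaries(c)
    mse = 0.0
    for i in range(n):
        a, b = t[i], t[i + 1]
        p = Phi(b) - Phi(a)            # P(a < X < b)
        m1 = phi(a) - phi(b)           # integral of x * phi(x) over (a, b)
        m2 = p + x_phi(a) - x_phi(b)   # integral of x^2 * phi(x) over (a, b)
        mse += m2 - 2.0 * c[i] * m1 + c[i] ** 2 * p
    return c, mse


for bits in (1, 2, 3, 4):
    _, dist = lloyd_max_gaussian(bits)
    print(bits, round(dist, 4))
```

The ratio of the 2-bit to the 3-bit distortion comes out at roughly 3.4, which is the resolution loss the argument above relies on.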

The QJL correction applies a linear correction to attention scores, but softmax amplifies score errors exponentially. Having 4 centroids instead of 8 creates coarser score quantization that softmax magnifies into attention weight errors far exceeding what the QJL correction can recover.
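To make the amplification concrete: if each softmax logit s_i picks up an error ε_i, every pairwise ratio of attention weights is rescaled by an exponential factor, so an additive error in score space becomes a multiplicative error in weight space:

$$
w_i = \frac{e^{s_i}}{\sum_j e^{s_j}}
\quad\Longrightarrow\quad
\frac{w_i'}{w_j'} = \frac{e^{s_i + \varepsilon_i}}{e^{s_j + \varepsilon_j}} = \frac{w_i}{w_j}\, e^{\varepsilon_i - \varepsilon_j}
$$

For example, a logit error of 0.5 on one key changes its weight relative to every other key by a factor of e^0.5 ≈ 1.65.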

QJL does work when added as extra information (V2 3-bit rot+QJL: +5.3% vs +6.6% without QJL), but not when it replaces MSE bits (TurboQuant_prod). This holds across all tested dimensions and models.

Note: The paper may achieve different results with custom CUDA kernels, full-precision weight models, and potentially different QJL scaling. Our models use 4-bit weight quantization, which compounds KV cache quantization error.

Quickstart

# Requirements: Apple Silicon Mac with Python 3.10+
pip install mlx mlx-lm

# Demo: text generation with compressed KV cache
python run_llm.py

# Benchmark: speed + quality
python benchmark.py

# Long-context benchmark: throughput at 512-8192 tokens
python benchmark_longseq.py

# Multi-model benchmark: PPL across 4 models (incl. Gemma D=256)
python benchmark_models.py

Custom Models

import mlx_lm
from turboquant.cache_v2 import TurboQuantKVCacheV2
from turboquant.cache_v3 import TurboQuantKVCacheV3
import turboquant.patch as tq_patch

tq_patch.apply()  # Monkey-patch mlx-lm's SDPA dispatch; call once before running the model

model, tokenizer = mlx_lm.load("mlx-community/Llama-3.2-3B-Instruct-4bit")
head_dim = model.layers[0].self_attn.head_dim
n_layers = len(model.layers)

# One cache object per layer. Pick ONE option: as written, Option B replaces Option A.

# Option A: V2 4-bit (fast, hardware-accelerated)
cache = [
    TurboQuantKVCacheV2(
        head_dim=head_dim, bits=4, group_size=64,
        use_rotation=True, use_normalization=True,
    )
    for _ in range(n_layers)
]

# Option B: V3 3.5-bit mixed (near-lossless, 4.1x compression)
cache = [
    TurboQuantKVCacheV3(
        head_dim=head_dim, bits=3,
        # Half the channels @ 4-bit, half @ 3-bit = 3.5 bits/dim (64 + 64 for head_dim=128)
        n_outlier=head_dim // 2, outlier_bits=4,
    )
    for _ in range(n_layers)
]

# Pass `cache` as the per-layer KV cache when running the model (run_llm.py is a full demo).

Project Structure

turboquant/
├── cache_v2.py          # V2: KV cache with mx.quantize (affine, fast)
├── cache_v3.py          # V3: Lloyd-Max codebook + channel splitting
├── attention_v2.py      # V2: SDPA with mx.quantized_matmul
├── attention_v3.py      # V3: SDPA with software dequant
├── codebook.py          # Lloyd-Max optimal centroids (1-4 bit)
├── codebook_ops.py      # Pure MLX pack/unpack for 2/3/4-bit indices
├── qjl.py               # Pure MLX QJL encoding (sign-bit packing)
├── fused_qjl.py         # Fused Metal kernel for QJL sign-bit dot products
├── patch.py             # Monkey-patch for mlx-lm SDPA dispatch
├── rotation.py          # Random rotation (QR) + JL matrix generation
├── kernels.py           # V1: Metal kernels + packing (legacy)
├── cache.py             # V1: cache (legacy)
├── attention.py         # V1: attention (legacy)
└── attention_fused.py   # V1: fused attention (legacy)

benchmark.py             # Speed + quality benchmark
benchmark_common.py      # Shared eval text and perplexity computation
benchmark_longseq.py     # Long-context throughput benchmark
benchmark_models.py      # Multi-model PPL comparison
run_llm.py               # Interactive demo
tests/
├── test_turboquant.py   # 58 unit tests (core components)
└── test_metal_barrier.py # Metal kernel barrier reproduction test

Technical Details

Pre-allocation (step=256)

Both V2 and V3 use pre-allocated buffers with slice assignment instead of per-token concatenation. Buffers grow in chunks of 256 tokens, which reduces the number of allocations from O(T) to O(T/256).
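A pure-Python analogue of the step-wise growth pattern (GrowableBuffer is a hypothetical class, not the repository's cache code; real caches grow MLX arrays rather than Python lists):

```python
import math


class GrowableBuffer:
    """Step-wise pre-allocation sketch.

    Capacity grows by whole multiples of `step` rows; incoming tokens are then
    written into free slots with slice assignment instead of concatenating a
    new array on every decode step.
    """

    def __init__(self, dim: int, step: int = 256):
        self.dim = dim
        self.step = step
        self.data = []        # allocated rows; len(self.data) is the capacity
        self.offset = 0       # number of rows actually filled
        self.allocations = 0  # how many times a larger buffer was built

    def append(self, rows):
        needed = self.offset + len(rows)
        if needed > len(self.data):
            n_steps = math.ceil((needed - len(self.data)) / self.step)
            # Old contents + zero rows in one new buffer, like a single concatenate.
            self.data = self.data + [[0.0] * self.dim for _ in range(n_steps * self.step)]
            self.allocations += 1
        self.data[self.offset:needed] = rows  # write into pre-allocated slots
        self.offset = needed

    def view(self):
        return self.data[: self.offset]


buf = GrowableBuffer(dim=4)
for t in range(1000):  # 1000 decode steps, one token each
    buf.append([[float(t)] * 4])
print(buf.allocations, len(buf.data), len(buf.view()))  # 4 1024 1000
```

Per-token concatenation would have built a new buffer 1000 times; with step=256 it happens 4 times, at the cost of up to 255 unused rows of slack.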

Norm-Baking (V2)

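For the rotated variant, the L2 norm of each vector is baked into its quantized scales and biases. This works because affine dequantization is linear in the scale and bias (dequant(data, scale, bias) = scale * data + bias):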

dequant(data, norm*scale, norm*bias) = norm * dequant(data, scale, bias)

Eliminates 2 element-wise operations from the SDPA hot path.
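A quick numerical check of the identity, using a simple per-group affine quantizer in the spirit of mx.quantize (assumed here: scale = (max - min)/(2^bits - 1), bias = min; the helper names are ours, and MLX's exact rounding rules may differ):

```python
import math
import random


def affine_quantize(x, bits=4, group_size=64):
    """Per-group affine quantization: value ~= scale * q + bias."""
    levels = 2 ** bits - 1
    q, scales, biases = [], [], []
    for g in range(0, len(x), group_size):
        grp = x[g:g + group_size]
        lo, hi = min(grp), max(grp)
        scale = (hi - lo) / levels or 1.0  # avoid division by zero for constant groups
        q.append([round((v - lo) / scale) for v in grp])
        scales.append(scale)
        biases.append(lo)
    return q, scales, biases


def dequantize(q, scales, biases):
    return [s * qi + b for grp, s, b in zip(q, scales, biases) for qi in grp]


rng = random.Random(0)
k = [rng.gauss(0.0, 1.0) for _ in range(128)]
norm = math.sqrt(sum(v * v for v in k))
q, s, b = affine_quantize([v / norm for v in k])  # quantize the unit-norm vector

# Unbaked: multiply by the norm after dequantizing (an extra op at attention time).
unbaked = [norm * v for v in dequantize(q, s, b)]
# Baked: fold the norm into scale and bias once, when the token is written.
baked = dequantize(q, [norm * si for si in s], [norm * bi for bi in b])

print(max(abs(x - y) for x, y in zip(unbaked, baked)))  # float rounding only
```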

Lloyd-Max Codebook (V3)

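After random rotation, each coordinate of a unit-norm vector is approximately N(0, 1/D), i.e. it has standard deviation 1/sqrt(D). Lloyd-Max gives the MSE-optimal scalar quantizer (centroids and decision boundaries) for this distribution: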

  • 4-bit (16 levels): Nearly identical to affine. Both work well.
  • 3-bit (8 levels): Lloyd-Max significantly better. Non-uniform spacing matches Gaussian tails.
  • 2-bit (4 levels): Both degrade substantially. Need channel splitting for usable quality.
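A sketch of encoding a rotated unit vector with a 3-bit codebook scaled by 1/sqrt(D). It assumes the classic Lloyd-Max levels for N(0, 1) (±0.2451, ±0.7560, ±1.3440, ±2.1520), and it samples rotated vectors directly, since a Haar-random rotation of a fixed unit vector is uniformly distributed on the sphere. Names are illustrative, not codebook.py's API; the real cache packs the 3-bit indices.

```python
import math
import random

# Classic 3-bit Lloyd-Max levels for N(0, 1), positive half (symmetric).
MAX_3BIT_POS = [0.2451, 0.7560, 1.3440, 2.1520]
CODEBOOK = sorted([-c for c in MAX_3BIT_POS] + MAX_3BIT_POS)


def random_unit_vector(d, rng):
    """Uniform on the sphere = a randomly rotated unit vector."""
    g = [rng.gauss(0.0, 1.0) for _ in range(d)]
    n = math.sqrt(sum(v * v for v in g))
    return [v / n for v in g]


def encode(x, codebook):
    """Nearest-centroid index per coordinate (3 bits each for 8 centroids)."""
    return [min(range(len(codebook)), key=lambda i: abs(v - codebook[i])) for v in x]


def decode(idx, codebook):
    return [codebook[i] for i in idx]


d = 128
scaled = [c / math.sqrt(d) for c in CODEBOOK]  # coordinates have std 1/sqrt(d)
rng = random.Random(0)
errs = []
for _ in range(200):
    x = random_unit_vector(d, rng)
    x_hat = decode(encode(x, scaled), scaled)
    errs.append(sum((a - b) ** 2 for a, b in zip(x, x_hat)))
avg = sum(errs) / len(errs)
print(avg)  # compare with the 3-bit N(0, 1) distortion, about 0.0345
```

Because each coordinate has variance 1/D, the summed squared error of a unit vector lands near the per-coordinate N(0, 1) distortion, independent of D.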

Outlier Channel Splitting (V3)

After rotation, all channels are statistically equivalent (approximately iid Gaussian). A fixed channel split therefore achieves fractional bit rates (D=128):

  • 3.5-bit: 64 channels @ 4-bit + 64 @ 3-bit = (64×4 + 64×3)/128 = 3.5 bits/dim
  • 2.5-bit: 64 channels @ 3-bit + 64 @ 2-bit = (64×3 + 64×2)/128 = 2.5 bits/dim

The split is fixed (no per-token overhead) because rotation removes channel-dependent outliers, so it does not matter which channels receive the extra bit.
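The same arithmetic as a small helper, plus its inverse for picking how many channels get the extra bit (both function names are hypothetical):

```python
def mixed_bits_per_dim(head_dim: int, n_outlier: int, outlier_bits: int, bits: int) -> float:
    """Average bits per coordinate: n_outlier channels at outlier_bits, the rest at bits."""
    if not 0 <= n_outlier <= head_dim:
        raise ValueError("n_outlier must be between 0 and head_dim")
    return (n_outlier * outlier_bits + (head_dim - n_outlier) * bits) / head_dim


def n_outlier_for_rate(head_dim: int, bits: int, target: float) -> int:
    """Channels needed at (bits + 1) bits to reach a fractional target rate."""
    n = (target - bits) * head_dim
    if not (0 <= n <= head_dim) or n != int(n):
        raise ValueError("target not reachable with a whole number of channels")
    return int(n)


print(mixed_bits_per_dim(128, 64, 4, 3))  # 3.5
print(mixed_bits_per_dim(128, 64, 3, 2))  # 2.5
print(n_outlier_for_rate(128, 3, 3.25))   # 32
print(n_outlier_for_rate(256, 3, 3.5))    # 128
```

At D=256 (Gemma 3), 64 outlier channels would give (64×4 + 192×3)/256 = 3.25 bits/dim rather than 3.5, so the channel count has to scale with head_dim.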

QJL Residual Correction (V2)

The residual (original minus dequantized) is projected through a random JL matrix, and only the signs of the projection are stored (1 bit per projected dimension). During attention, the JL inner-product estimator uses these sign bits to estimate <q, residual>, which is added to the score computed from the dequantized key.

Works as an additional correction on V2 affine quantization (3-bit on Llama 3.2 3B: +6.6% -> +5.3%). Does NOT work as a bit replacement (TurboQuant_prod), because the (b-1)-bit centroid resolution loss is amplified exponentially by softmax, overwhelming the linear QJL correction.
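A pure-Python sketch of the residual path, assuming a Gaussian projection S and the standard estimator sqrt(π/2)·||r||/m·<Sq, sign(Sr)> (function names are ours; the repository's qjl.py packs the signs into bits and fused_qjl.py computes the dot products in Metal, neither of which is reproduced here):

```python
import math
import random


def qjl_encode(r, S):
    """Store one sign per projected dimension, plus the residual's L2 norm."""
    signs = [1.0 if sum(a * b for a, b in zip(row, r)) >= 0 else -1.0 for row in S]
    return signs, math.sqrt(sum(v * v for v in r))


def qjl_inner_product(q, signs, r_norm, S):
    """Unbiased JL estimate of <q, r> from the stored sign bits."""
    m = len(S)
    sq = [sum(a * b for a, b in zip(row, q)) for row in S]  # full-precision S q
    return math.sqrt(math.pi / 2) * r_norm / m * sum(x * s for x, s in zip(sq, signs))


rng = random.Random(0)
d, m = 32, 4096
S = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(m)]  # random projection
q = [rng.gauss(0.0, 1.0) for _ in range(d)]
r = [rng.gauss(0.0, 0.1) for _ in range(d)]  # stand-in for a small quantization residual

signs, r_norm = qjl_encode(r, S)
exact = sum(a * b for a, b in zip(q, r))
est = qjl_inner_product(q, signs, r_norm, S)
print(exact, est)
```

m is set far above d here only so that a single estimate lands close to the exact value; with m = d the same estimate is much noisier, which is the O(π/(2d)) variance discussed in the TurboQuant_prod analysis.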

MLX-LM Bug: QuantizedKVCache.nbytes

MLX-LM's QuantizedKVCache.nbytes property crashes with NameError: name 'tree_reduce' is not defined because tree_reduce is used but not imported in cache.py. Our benchmarks work around this by manually summing tensor sizes.
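A minimal version of that manual summation, assuming the cache state is a nested tuple/list/dict whose leaves are None or array-like objects exposing .nbytes (mx.array does). The function name is ours, and memoryview objects stand in for arrays so the sketch runs without MLX:

```python
def tree_nbytes(tree) -> int:
    """Recursively sum `.nbytes` over the leaves of nested dicts/lists/tuples."""
    if tree is None:
        return 0
    if isinstance(tree, dict):
        return sum(tree_nbytes(v) for v in tree.values())
    if isinstance(tree, (list, tuple)):
        return sum(tree_nbytes(v) for v in tree)
    return int(tree.nbytes)


# Stand-ins for quantized state: (data, scales, biases) for keys and for values.
keys = (memoryview(bytearray(64)), memoryview(bytearray(8)), memoryview(bytearray(8)))
values = (memoryview(bytearray(64)), memoryview(bytearray(8)), memoryview(bytearray(8)))
print(tree_nbytes([keys, values]))  # 160
```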


License

MIT