Raven

May 29, 2026


Raven is a linear-time sequence model built on top of Flash Linear Attention (FLA). It introduces a routing memory mechanism in which a learned sparse router selects which of a fixed set of persistent memory slots each token writes to, keeping complexity sub-quadratic while maintaining strong associative recall.



Sparse Memory Routing in Raven. Unlike SSMs, which update the entire state densely, or sliding-window attention (SWA), which evicts entries in strict FIFO order, Raven uses an input-dependent router. At each step, only a subset of memory slots (highlighted) is selected to undergo decay and receive new information. Unselected slots remain untouched, preventing interference and preserving long-range recall.
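
A toy simulation can make the caption's contrast concrete. It tracks one stored value (a "needle" in slot 0) while distractor tokens arrive under three update rules. This is not Raven's kernel: slot contents are scalars, the decay of 0.9 and the distractor value 0.0 are arbitrary, and the routed case uses a hypothetical router that never selects the needle's slot, so its result follows from that assumption rather than from anything a trained router guarantees. All function names are illustrative.

```python
"""Toy comparison: dense decay vs. FIFO window vs. sparse routing (illustrative only)."""

ALPHA = 0.9  # assumed per-step decay for any slot that gets updated


def dense_step(mem, t, x=0.0):
    # SSM-style: every slot decays and absorbs the new token at every step
    return [ALPHA * m + (1 - ALPHA) * x for m in mem]


def fifo_step(mem, t, x=0.0):
    # SWA-style ring buffer: step t overwrites slot (t + 1) % n,
    # so the needle in slot 0 is evicted after n distractor tokens
    mem = list(mem)
    mem[(t + 1) % len(mem)] = x
    return mem


def routed_step(mem, t, x=0.0, topk=2):
    # Raven-style: only routed slots decay and absorb x; the rest are untouched.
    # Hypothetical router: cycles through slots 1..n-1 and never picks slot 0.
    n = len(mem)
    chosen = {1 + (t * topk + j) % (n - 1) for j in range(topk)}
    return [ALPHA * m + (1 - ALPHA) * x if i in chosen else m for i, m in enumerate(mem)]


def needle_after(step_fn, num_slots=8, steps=64):
    mem = [0.0] * num_slots
    mem[0] = 1.0  # the needle, written before any distractors
    for t in range(steps):
        mem = step_fn(mem, t)
    return mem[0]


if __name__ == "__main__":
    for name, fn in [("dense", dense_step), ("fifo", fifo_step), ("routed", routed_step)]:
        print(f"{name:>6}: needle value after 64 distractors = {needle_after(fn):.4f}")
```

The dense rule shrinks the needle geometrically, the FIFO rule loses it entirely once the window has cycled, and the routed rule leaves it intact only because its slot was never selected.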


Architecture

[Figure: Raven vs. SSM architecture]

Raven replaces the SSM block with an RSM (Routing State Model) layer. Unlike GLA and Mamba2, which densely update all memory slots at every step, Raven learns a per-token sparse router R that selects which slots to update.
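
A minimal sketch of a per-token linear router, assuming one router logit per slot, a sigmoid score, and hard top-k selection that zeroes the scores of every other slot. `route` and its arguments are hypothetical names, not the repository's API, and the MLP router and softmax scoring options are not shown.

```python
import math


def sigmoid(z):
    # numerically stable logistic function
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def route(x, W, topk):
    """Return one score per slot: sigmoid(W @ x) for the top-k slots, 0.0 elsewhere.

    x: token features (length d); W: router weights, one row of length d per slot.
    """
    logits = [sum(w * xj for w, xj in zip(row, x)) for row in W]
    scores = [sigmoid(z) for z in logits]
    # sigmoid is monotonic, so ranking scores is the same as ranking logits
    top = set(sorted(range(len(scores)), key=scores.__getitem__, reverse=True)[:topk])
    return [s if i in top else 0.0 for i, s in enumerate(scores)]


if __name__ == "__main__":
    W = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0]]  # 4 slots, d = 2
    print(route([2.0, 1.0], W, topk=2))  # slots 0 and 1 kept, slots 2 and 3 zeroed
```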


How Raven Works

[Figure: Routing memory mechanism]

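Each Raven layer maintains a matrix memory state H ∈ ℝ^(num_slots × d_v). At each timestep the router selects the top-k most relevant slots and performs a gated update:

route_scores = TopK( sigmoid(r_proj(x)) )   # scores outside the top-k are 0
decay         = exp( route_scores * f )     # sparse forgetting gate
H             = H * decay + (1 - decay) * k ⊗ v
o             = q · H                       # read

Here f ≤ 0 is the log forgetting rate given by the Mamba2 or GLA decay. Because an unselected slot has a route score of 0, its decay is exp(0) = 1 and its write weight 1 − decay is 0, so it is left exactly as it was.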
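
The update above can be run step by step as a small pure-Python sketch. To make the shapes in the pseudocode line up, it assumes k and q are vectors over slots (length num_slots) and v has length d_v, so k ⊗ v and H are both num_slots × d_v and q · H is a length-d_v read-out. Raven's actual kernels (FLA's GSA kernels, per the Upstream section) may arrange these tensors differently, and `raven_step` is a hypothetical name.

```python
import math


def raven_step(H, route_scores, f, k, v, q):
    """One routed, gated memory update followed by a read (illustrative sketch).

    H:            memory, list of num_slots rows, each of length d_v
    route_scores: length num_slots; 0.0 for slots outside the top-k
    f:            log forgetting rate for this token (f <= 0)
    k, q:         length num_slots (assumed slot-space key and query)
    v:            length d_v
    """
    new_H = []
    for row, s, k_i in zip(H, route_scores, k):
        decay = math.exp(s * f)  # exactly 1.0 when s == 0.0, so the row is unchanged
        new_H.append([decay * h + (1.0 - decay) * k_i * v_j for h, v_j in zip(row, v)])
    # read: o_j = sum_i q_i * H[i][j]
    o = [sum(q_i * row[j] for q_i, row in zip(q, new_H)) for j in range(len(v))]
    return new_H, o


if __name__ == "__main__":
    H = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # 3 slots, d_v = 2
    scores = [0.0, 0.9, 0.0]                  # only slot 1 was routed
    H, o = raven_step(H, scores, f=-1.0, k=[1.0, 1.0, 1.0], v=[10.0, 20.0], q=[0.0, 1.0, 0.0])
    print(H)  # rows 0 and 2 are unchanged; row 1 has moved toward v
    print(o)
```

The function returns a new memory instead of mutating H, which keeps the "unselected slots are untouched" property easy to check.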

The table below places Raven in the broader landscape of linear models:

[Table (image): Unified view of linear models]

Key design choices:

  • Sparse top-k routing — each token writes to a small subset of memory slots
  • Gumbel noise during training for exploration (optional)
  • Mamba2 or GLA decay for the forgetting gate
  • Chunked Triton kernels for training, fused recurrent kernels for generation
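
For the optional Gumbel noise, a common recipe is the Gumbel-top-k trick: add independent Gumbel(0, 1) noise to the router logits during training and take the top-k of the perturbed values, which samples k slots without replacement from the categorical distribution softmax(logits); at inference the noise is dropped. This sketch assumes Raven applies the noise that way, before top-k selection; the README does not specify the placement or scale, and `select_slots` is a hypothetical helper.

```python
import math
import random


def gumbel_noise(rng):
    # Gumbel(0, 1) sample via the inverse CDF: -log(-log(U)), U ~ Uniform(0, 1)
    u = max(rng.random(), 1e-12)  # rng.random() can return 0.0; avoid log(0)
    return -math.log(-math.log(u))


def select_slots(logits, topk, training, rng=None):
    """Indices of the selected slots; logits are Gumbel-perturbed only when training."""
    if training:
        rng = rng or random.Random()
        logits = [z + gumbel_noise(rng) for z in logits]
    return sorted(range(len(logits)), key=logits.__getitem__, reverse=True)[:topk]


if __name__ == "__main__":
    logits = [2.0, 1.9, -5.0]
    print(select_slots(logits, topk=1, training=False))  # deterministic: always [0]
    rng = random.Random(0)
    picks = [select_slots(logits, 1, True, rng)[0] for _ in range(1000)]
    # slot 1 is close to slot 0 in logit, so the noise lets it win a large share of draws
    print({i: picks.count(i) for i in range(len(logits))})
```

Without noise, a slot whose logit is slightly lower than a competitor's is never written during training; the noise gives it occasional updates, which is the exploration the design choice refers to.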

Results

In-Context Recall Benchmarks

[Figure: Recall and benchmark results]

Table 2: In-context recall benchmarks and NIAH accuracy vs. context length. Accuracy (%) on SWDE, FDA, and SQuAD, and on single NIAH-1/2/3 (columns N1, N2, N3) at context lengths from 1K to 32K. Bold = best, underline = second best.

~400M parameter models
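| Model | Params | Mem (M) | SWDE | FDA | SQuAD | N1-1K | N1-2K | N1-4K | N1-8K | N1-16K | N1-32K | N2-1K | N2-2K | N2-4K | N2-8K | N2-16K | N2-32K | N3-1K | N3-2K | N3-4K | N3-8K | N3-16K | N3-32K |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Transformer |
| w. RoPE | 340 | ∞ / 0 | 42.3 | 34.5 | 22.1 | 100 | 100 | 0 | 0 | 0 | 0 | 100 | 100 | 0 | 0 | 0 | 0 | 71.6 | 47.6 | 0 | 0 | 0 | 0 |
| w. Gate (FoX) | 376 | ∞ / 0 | 52.5 | 64.3 | 30.1 | 100 | 100 | 32.2 | 8.0 | 4.2 | 0 | 100 | 100 | 100 | 24.0 | 11.6 | 3.2 | 95.4 | 85.6 | 64.2 | 11.6 | 7.2 | 0 |
| SSM |
| GLA | 475 | 12.5 / 0.4 | 29.0 | 11.4 | 30.3 | 74.6 | 25.1 | 8.2 | 2.2 | 0 | 0 | 91.2 | 37.2 | 21.4 | 3.6 | 0 | 0 | 84.2 | 57.1 | 20.8 | 10.2 | 2.3 | 0 |
| GSA | 399 | 12.5 / 0 | 23.8 | 14.5 | 24.9 | 99.2 | 97.1 | 90.0 | 67.4 | 29.6 | 11.0 | 96.6 | 98.8 | 28.0 | 5.1 | 1.0 | 0 | 60.0 | 30.1 | 13.5 | 1.0 | 0 | 0 |
| GDN | 475 | 12.5 / 0.4 | 29.5 | 8.3 | 31.3 | 99.2 | 100 | 99.8 | 92.0 | 41.8 | 22.1 | 99.2 | 92.0 | 43.6 | 17.8 | 6.2 | 4.0 | 92.6 | 80.6 | 37.8 | 5.2 | 6.8 | 2.5 |
| Mamba-2 | 382 | 12.5 / 0.4 | 25.7 | 14.9 | 31.9 | 99.2 | 95.6 | 52.2 | 12.8 | 5.4 | 2.8 | 99.8 | 98.0 | 68.2 | 15.4 | 4.4 | 3.8 | 53.4 | 53.6 | 17.4 | 1.8 | 2.2 | 3.2 |
| SWA | 374 | 12.5 / 0 | 10.0 | 14.4 | 29.7 | 29.8 | 11.0 | 6.2 | 3.4 | 1.2 | 0 | 36.2 | 14.4 | 10.2 | 3.8 | 3.2 | 0 | 26.2 | 9.2 | 7.4 | 1.4 | 1.8 | 0 |
| Raven | 424 | 12.5 / 0 | 34.1 | 22.7 | 35.4 | 99.8 | 100 | 99.8 | 99.8 | 99.4 | 91.4 | 98.8 | 98.0 | 98.8 | 81.6 | 23.0 | 8.8 | 76.8 | 43.6 | 13.4 | 1.0 | 0 | 0 |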

Language Modeling & Zero-Shot Evaluation

Table 3: Language modeling and zero-shot evaluation results. Perplexity on Lambada (LMB.) and zero-shot accuracy across standard benchmarks. Bold = best, underline = second best.

~400M parameter models
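| Model | Params | LMB. ppl↓ | LMB. acc↑ | PIQA↑ | Hella.↑ | Wino.↑ | ARC-e↑ | ARC-c↑ | Avg.↑ |
|---|---|---|---|---|---|---|---|---|---|
| Transformer |
| w. RoPE | 340 | 42.0 | 31.0 | 64.4 | 30.2 | 51.0 | 44.3 | 18.7 | 39.9 |
| w. Gate (FoX) | 376 | 48.1 | 30.6 | 64.9 | 30.7 | 51.1 | 44.7 | 18.9 | 40.1 |
| SSM |
| GLA | 400 | 42.1 | 30.7 | 64.4 | 30.1 | 52.7 | 43.8 | 19.6 | 40.2 |
| GSA | 399 | 44.1 | 30.3 | 64.9 | 30.7 | 51.5 | 45.6 | 20.5 | 40.5 |
| GDN | 475 | 40.1 | 31.6 | 65.6 | 31.4 | 50.2 | 45.7 | 19.3 | 40.6 |
| Mamba-2 | 382 | 43.0 | 29.9 | 65.0 | 31.5 | 51.2 | 47.5 | 20.5 | 40.1 |
| SWA | 374 | 40.7 | 30.5 | 64.5 | 30.4 | 51.6 | 44.9 | 18.6 | 40.0 |
| Raven | 424 | 41.0 | 32.7 | 64.1 | 30.3 | 51.7 | 43.9 | 18.4 | 40.2 |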

Hybrid Models Retrieval Ability

Table 4: Hybrid-Raven vs. other hybrid architectures on retrieval tasks. ✓ = no convolutional memory needed.

~400M parameter models
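| Model | No Conv. | SWDE | FDA | SQuAD | N1-1K | N1-2K | N1-4K | N1-8K | N1-16K | N1-32K | N2-1K | N2-2K | N2-4K | N2-8K | N2-16K | N2-32K | N3-1K | N3-2K | N3-4K | N3-8K | N3-16K | N3-32K |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| GDN |  | 54.6 | 67.2 | 34.5 | 100 | 100 | 100 | 100 | 93.2 | 70.5 | 100 | 100 | 100 | 8.0 | 0 | 0 | 93.2 | 70.2 | 50.0 | 0 | 0 | 0 |
| Mamba-2 |  | 56.3 | 68.8 | 36.0 | 100 | 100 | 16.4 | 0 | 0 | 0 | 100 | 100 | 85.8 | 0 | 0 | 0 | 76.9 | 80.6 | 60.8 | 0 | 0 | 0 |
| SWA-RoPE |  | 51.0 | 68.1 | 34.1 | 100 | 100 | 100 | 100 | 98.2 | 60.4 | 100 | 100 | 100 | 98.2 | 3.1 | 0 | 93.4 | 78.2 | 12.8 | 60.0 | 4.4 | 0 |
| Raven |  | 51.4 | 64.2 | 31.4 | 100 | 100 | 100 | 100 | 98.4 | 78.6 | 100 | 100 | 100 | 100 | 95.4 | 65.4 | 90.0 | 67.0 | 73.8 | 60.0 | 10.2 | 14.4 |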

Model Structure

| Component | Details |
|---|---|
| Layer type | RavenAttention (replaces standard attention) |
| Memory | Fixed-size slot matrix per head (num_slots slots) |
| Router | Linear or MLP projection → top-k sigmoid/softmax |
| Decay | Mamba2 (A_log + dt_bias) or GLA (logsigmoid) |
| Feature map | Swish, ReLU, or T2R |
| Computation | Chunked (training) / Fused recurrent (inference) |
| Hybrid layers | Optional standard attention layers at specified indices |
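
The two decay options can be sketched as functions returning the log forgetting rate f ≤ 0 that enters decay = exp(route_scores * f). The Mamba2 form, f = −exp(A_log) · softplus(dt + dt_bias), and the GLA form, f = logsigmoid(gate logit), follow those models' usual parameterizations; how Raven derives dt and the gate logit from the input, and whether it rescales the GLA gate, is not stated here, so treat the inputs and function names as assumptions.

```python
import math


def softplus(z):
    # numerically stable log(1 + exp(z))
    return z if z > 20.0 else math.log1p(math.exp(z))


def mamba2_log_decay(dt, A_log, dt_bias):
    # Mamba2-style: a negative rate -exp(A_log) times a positive step size, so f <= 0
    return -math.exp(A_log) * softplus(dt + dt_bias)


def gla_log_decay(gate_logit):
    # GLA-style: logsigmoid(x) = -softplus(-x), always <= 0
    return -softplus(-gate_logit)


def routed_decay(route_score, f):
    # unselected slots (route_score == 0) get decay exp(0) = 1, i.e. no forgetting
    return math.exp(route_score * f)


if __name__ == "__main__":
    for name, f in [("Mamba2", mamba2_log_decay(0.5, 0.0, 0.0)), ("GLA", gla_log_decay(2.0))]:
        print(name, round(f, 4), [round(routed_decay(s, f), 4) for s in (0.0, 0.5, 1.0)])
```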

Installation

Install the FLA dependency first, following the official FLA guide:

pip install flash-linear-attention

Then clone this repo:

git clone https://github.com/AvivBick/RoutingMemory
cd RoutingMemory
pip install -e .

Requirements: PyTorch ≥ 2.5, Triton ≥ 3.0, einops, transformers ≥ 4.45.0
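
A small standard-library script can check an environment against these minimums. It compares only the leading numeric release segment (so "2.5.1+cu121" counts as 2.5.1 and pre-release tags are ignored); for full PEP 440 handling, the `packaging` library's `Version` class is the usual choice. The helper names are illustrative.

```python
import re
from importlib import metadata

# Minimum versions from the requirements above; None means "any version".
REQUIREMENTS = {"torch": "2.5", "triton": "3.0", "einops": None, "transformers": "4.45.0"}


def release_tuple(version):
    """Leading numeric release segment, e.g. '2.5.1+cu121' -> (2, 5, 1)."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def meets(installed, minimum):
    a, b = release_tuple(installed), release_tuple(minimum)
    width = max(len(a), len(b))
    # pad with zeros so that 2.5 and 2.5.0 compare equal
    return a + (0,) * (width - len(a)) >= b + (0,) * (width - len(b))


def check_environment():
    ok = True
    for name, minimum in REQUIREMENTS.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            print(f"{name}: not installed")
            ok = False
            continue
        good = minimum is None or meets(installed, minimum)
        ok = ok and good
        status = "ok" if good else f"needs >= {minimum}"
        print(f"{name} {installed}: {status}")
    return ok


if __name__ == "__main__":
    check_environment()
```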


Usage

As a layer

import torch
from raven.layers import RavenAttention

attn = RavenAttention(
    hidden_size=1024,
    num_heads=4,
    num_slots=256,
    topk=32,
    decay_type='Mamba2',    # or 'GLA'
    feature_map='swish',
    router_type='lin',      # or 'mlp'
    router_score='sigmoid', # or 'softmax'
).cuda()

x = torch.randn(1, 2048, 1024).cuda()
y, _, _ = attn(x)  # (batch, seq_len, hidden_size)

As a full causal LM

from raven.models import RavenConfig, RavenForCausalLM
from transformers import AutoModelForCausalLM

config = RavenConfig(
    hidden_size=1024,
    num_hidden_layers=24,
    num_heads=4,
    num_slots=256,
    topk=32,
    decay_type='Mamba2',
    feature_map='swish',
    vocab_size=32000,
)
model = AutoModelForCausalLM.from_config(config).cuda()

Training

Raven uses the flame training framework. Copy a config from configs/raven_340M_*.json into the configs/ directory of your flame checkout, then run from the flame root:

CUDA_VISIBLE_DEVICES=0,1,2,3 NGPU=4 bash train.sh \
    --job.config_file flame/models/fla.toml \
    --job.dump_folder exp/raven-340M \
    --model.config configs/raven_340M_1.json \
    --optimizer.name AdamW \
    --optimizer.lr 3e-4 \
    --lr_scheduler.warmup_steps 1024 \
    --lr_scheduler.decay_type cosine \
    --training.batch_size 16 \
    --training.seq_len 2048 \
    --training.gradient_accumulation_steps 4 \
    --training.steps 30720 \
    --training.dataset /path/to/SlimPajama-627B \
    --training.streaming \
    --training.compile \
    --checkpoint.interval 3072
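
As a sanity check on the command above, the token budget works out as follows, assuming --training.batch_size is the per-GPU micro-batch, so the global batch also multiplies in NGPU=4 and the 4 gradient-accumulation steps. If your flame version treats batch_size as global, drop the GPU factor. `token_budget` is an illustrative helper, not part of flame.

```python
def token_budget(batch_size, seq_len, grad_accum, num_gpus, steps):
    """Tokens per optimizer step and over the whole run (assumes per-GPU batch_size)."""
    tokens_per_step = batch_size * seq_len * grad_accum * num_gpus
    return tokens_per_step, tokens_per_step * steps


if __name__ == "__main__":
    per_step, total = token_budget(batch_size=16, seq_len=2048, grad_accum=4, num_gpus=4, steps=30720)
    print(per_step)               # 524288 tokens (2**19) per optimizer step
    print(f"{total / 1e9:.1f}B")  # about 16.1B tokens in total
    print(1024 / 30720)           # warmup covers about 3.3% of training
    print(30720 // 3072)          # 10 checkpoints at interval 3072
```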

The configs/ directory contains 12 ablation configurations varying the router design (linear vs. MLP, sigmoid vs. softmax, with/without Gumbel noise and bias).


Repository Structure

raven/
├── layers/
│   └── raven.py                # RavenAttention layer
└── models/
    └── raven/
        ├── configuration_raven.py
        └── modeling_raven.py

configs/
└── raven_340M_*.json           # 12 ablation configs (340M scale)

assets/img/                     # figures used in this README

Upstream: Flash Linear Attention

This repo builds on fla-org/flash-linear-attention and depends on it for hardware-efficient Triton kernels. In particular, Raven currently reuses FLA’s GSA chunked and fused recurrent kernels rather than vendoring separate Raven ops in this repository.
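
The reason chunked (training) and fused recurrent (generation) kernels can compute the same thing: for a per-slot gated update h_t = a_t · h_{t−1} + b_t, with a_t the decay and b_t the gated write, a chunk of C steps can be evaluated from prefix sums of log a_t plus the state carried in from the previous chunk, and it reproduces the step-by-step result. The sketch below checks this on a single scalar slot. It is a toy, not FLA's GSA kernel, which works on matrix-valued states and typically handles the intra-chunk part with blocked matrix multiplications.

```python
import math
import random


def recurrent(a, b, h0=0.0):
    """Step-by-step scan h_t = a_t * h_{t-1} + b_t, as a fused recurrent kernel would run it."""
    h, out = h0, []
    for a_t, b_t in zip(a, b):
        h = a_t * h + b_t
        out.append(h)
    return out


def chunked(a, b, chunk=4, h0=0.0):
    """Same outputs, computed chunk by chunk from log-space cumulative decays."""
    out, h = [], h0
    for start in range(0, len(a), chunk):
        a_c, b_c = a[start:start + chunk], b[start:start + chunk]
        cum, s = [], 0.0
        for a_t in a_c:
            s += math.log(a_t)  # requires a_t > 0, which holds when a_t = exp(score * f)
            cum.append(s)
        for j in range(len(a_c)):
            # carried-in state decayed through steps 0..j, plus intra-chunk writes
            val = math.exp(cum[j]) * h
            val += sum(math.exp(cum[j] - cum[i]) * b_c[i] for i in range(j + 1))
            out.append(val)
        h = out[-1]  # state handed to the next chunk
    return out


if __name__ == "__main__":
    rng = random.Random(0)
    f = -0.5
    scores = [rng.choice([0.0, rng.random()]) for _ in range(10)]  # roughly half unselected
    a = [math.exp(s * f) for s in scores]
    b = [(1.0 - a_t) * rng.uniform(-1, 1) for a_t in a]
    # maximum difference is at the level of floating-point rounding
    print(max(abs(x - y) for x, y in zip(recurrent(a, b), chunked(a, b))))
```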



Citation


@article{afzalbick2026raven,
  title={Raven: High-Recall Sequence Modeling with Sparse Memory Routing},
  author={Arshia Afzal and Aviv Bick and Eric P. Xing and Volkan Cevher and Albert Gu},
  year={2026},
  publisher={MDPI}
}