Context Parallel of Linear Attention
May 7, 2026
Context Parallel of Linear Attention (known as KCP at Moonshot) is a context-parallelism scheme designed for delta-rule recurrent models such as GDN (Gated Delta Rule) and KDA (Kimi Delta Attention). It enables efficient distributed training by partitioning the sequence dimension across ranks: each rank processes a local token chunk, and CP automatically synchronizes the recurrent state across ranks.
Quick Start
Build CP Context
```python
import torch
import torch.distributed as dist

from fla.ops.cp import build_cp_context

# Global cu_seqlens before partitioning (device can be CPU or GPU)
cu_seqlens_global = torch.tensor(
    [0, s1, s1 + s2, ..., total],
    dtype=torch.long,
    device=device,
)

# conv1d_kernel_size is required for the causal_conv1d CP path
cp_context = build_cp_context(
    cu_seqlens_global,
    group=dist.group.WORLD,
    conv1d_kernel_size=W,
)
```
Causal Conv1d
```python
from fla.modules.convolution import causal_conv1d

# x_local is the rank-local chunk: [1, T_local, D]
y_local, _ = causal_conv1d(
    x=x_local,
    weight=weight_local,
    bias=bias_local,
    activation="swish",
    cp_context=cp_context,
)
```
Note
- `cp_context` is required; `cp_context.conv1d_kernel_size` and `cp_context.cu_seqlens` must be set.
- Do not pass `cu_seqlens` / `cu_seqlens_cpu` manually; they are taken from the context.
KDA
```python
import torch.nn.functional as F

from fla.ops.kda import chunk_kda

o_local, _ = chunk_kda(
    q=F.normalize(q_local, p=2, dim=-1),
    k=F.normalize(k_local, p=2, dim=-1),
    v=v_local,
    g=g_local,
    beta=beta_local,
    cp_context=cp_context,
    disable_recompute=disable_recompute,
)
```
Note
- CP expects `B == 1` for varlen and uses rank-local `cu_seqlens` from the context.
- `initial_state` and `output_final_state=True` are not supported in CP mode.
Conventions
CP context stores rank-local varlen metadata that tracks how sequences are distributed:
- `FLACPContext.cu_seqlens`: rank-local cumulative sequence lengths, on GPU (int64)
- `FLACPContext.cu_seqlens_cpu`: same data on CPU for host-side indexing

Variable-length inputs start as global `cu_seqlens` before partitioning; `build_cp_context` converts them into rank-local metadata automatically.
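To make the global-to-local conversion concrete, here is a minimal sketch of the bookkeeping. It assumes the tokens are split into equal contiguous slices, one per rank. `local_cu_seqlens` is a hypothetical helper written for illustration; it is not the logic inside `build_cp_context`, which may partition differently.

```python
def local_cu_seqlens(cu_seqlens_global, rank, world_size):
    """Hypothetical helper: clip global sequence boundaries to one rank's token range."""
    total = cu_seqlens_global[-1]
    assert total % world_size == 0, "sketch assumes an even split"
    t_local = total // world_size
    start, end = rank * t_local, (rank + 1) * t_local
    # Boundaries strictly inside the rank's range, shifted to local offsets.
    inner = [b - start for b in cu_seqlens_global if start < b < end]
    local = [0] + inner + [t_local]
    # The first local sequence continues from the previous rank unless a
    # global boundary falls exactly on `start`.
    continues = start not in cu_seqlens_global
    return local, continues


cu_global = [0, 5, 12, 16]  # three sequences of lengths 5, 7, 4
rank0 = local_cu_seqlens(cu_global, rank=0, world_size=2)
rank1 = local_cu_seqlens(cu_global, rank=1, world_size=2)
```

In this toy case the second sequence (global tokens 5 to 11) straddles the rank boundary, so rank 1 sees it as the first of its local sequences and needs an incoming state for it.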
Notation
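We follow the notation from the Kimi Linear technical report (Section 2.1). Throughout this document, subscript $[t]$ denotes chunk index, while subscript $t$ denotes token position.

Vectors and matrices:
- $\mathbf{q}_t, \mathbf{k}_t \in \mathbb{R}^{d_k}$ and $\mathbf{v}_t \in \mathbb{R}^{d_v}$: column vectors in $\mathbb{R}^{d_k}$ or $\mathbb{R}^{d_v}$ at position $t$
- $\mathbf{S}_t \in \mathbb{R}^{d_k \times d_v}$: matrix-form memory state; FLA kernels store it as `[d_k, d_v]`, some backends transpose to `[d_v, d_k]`
- $\square_{[t]}$ with subscript $[t]$: stacked vectors within chunk $t$ (shape $C \times d$); a sequence of length $L$ splits into $L/C$ chunks of size $C$
- $\square_{[t]}^{r}$ with superscript $r$: the $r$-th element in chunk $[t]$, i.e., $\square_{[t]}^{r} = \square_{tC+r}$ where $r \in [1, C]$

State and decay:

$$
\mathbf{S}_{[t]}^{i} := \mathbf{S}_{tC+i}, \qquad
\mathbf{S}_{[t]} := \mathbf{S}_{[t]}^{0} = \mathbf{S}_{[t-1]}^{C}, \qquad
\gamma_{[t]}^{i \to j} := \prod_{k=i}^{j} \alpha_{[t]}^{k}, \qquad
\gamma_{[t]}^{r} := \gamma_{[t]}^{1 \to r}
$$

The decay factor $\alpha_t$ is a scalar for GDN, or per-dim ($\boldsymbol{\alpha}_t \in \mathbb{R}^{d_k}$) for KDA.

Code mapping:
- `g` stores $\log \alpha_t$ (or $\log \boldsymbol{\alpha}_t$ for KDA)
- After `chunk_local_cumsum`, `g` at position $r$ equals $\log \gamma_{[t]}^{r}$ (in base 2 when `scale=RCP_LN2` is passed, so that kernels can use `exp2`)
- Then $\gamma_{[t]}^{r} = \exp_2(g^{r})$ and $\gamma_{[t]}^{C} / \gamma_{[t]}^{r} = \exp_2(g^{C} - g^{r})$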
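The code mapping can be checked by hand for one small chunk. The sketch below uses only the standard library and assumes that `g` holds natural-log decays and that the cumulative sum is rescaled by `RCP_LN2 = 1 / ln 2`. It mimics the arithmetic, not the `chunk_local_cumsum` kernel itself.

```python
import math
from itertools import accumulate

RCP_LN2 = 1.0 / math.log(2.0)

alpha = [0.9, 0.8, 0.95, 0.7]                 # per-token scalar decay in one chunk
g = [math.log(a) for a in alpha]              # what `g` is assumed to store
g_cum = [x * RCP_LN2 for x in accumulate(g)]  # chunk-local cumsum, rescaled to log2

# Cumulative decay from the chunk start up to each position.
gamma = [2.0 ** x for x in g_cum]
# Remaining decay from each position to the chunk end (gamma^C / gamma^r).
tail = [2.0 ** (g_cum[-1] - x) for x in g_cum]
```

The last entry of `tail` is exactly 1: the final token of a chunk is not decayed any further before the chunk boundary.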
Recurrence
Both GDN and KDA are built on the delta rule: a recurrent update in which the state matrix is first decayed, then edited by removing the value currently stored under the incoming key and writing the new value in its place. This enables efficient "memory editing" where stale information can be forgotten.
GDN — Scalar Per-Head Gate
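GDN uses a single scalar gate per head per token. From [Yang et al., 2025], written in the $d_k \times d_v$ state convention used here:

$$
\mathbf{S}_t = \alpha_t \left(\mathbf{I} - \beta_t \mathbf{k}_t \mathbf{k}_t^\top\right) \mathbf{S}_{t-1} + \beta_t \mathbf{k}_t \mathbf{v}_t^\top
$$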
KDA — Per-Dim Gate
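KDA extends GDN with a per-dimension gate, giving finer control over which features to retain or forget. From Eq. 1 in the Kimi Linear technical report:

$$
\mathbf{S}_t = \left(\mathbf{I} - \beta_t \mathbf{k}_t \mathbf{k}_t^\top\right) \mathrm{Diag}(\boldsymbol{\alpha}_t)\, \mathbf{S}_{t-1} + \beta_t \mathbf{k}_t \mathbf{v}_t^\top
$$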
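As a reference point, both updates fit in a few lines of NumPy. This is a token-level sketch in the `[d_k, d_v]` state layout, with `gdn_step` and `kda_step` as illustrative names; the FLA kernels use the chunkwise form instead and do not contain functions like these.

```python
import numpy as np


def gdn_step(S, k, v, alpha, beta):
    """S_t = alpha * (I - beta k k^T) S_{t-1} + beta k v^T, with scalar alpha."""
    I = np.eye(k.shape[0])
    return alpha * (I - beta * np.outer(k, k)) @ S + beta * np.outer(k, v)


def kda_step(S, k, v, alpha, beta):
    """S_t = (I - beta k k^T) Diag(alpha) S_{t-1} + beta k v^T, with alpha in R^{d_k}."""
    I = np.eye(k.shape[0])
    return (I - beta * np.outer(k, k)) @ (alpha[:, None] * S) + beta * np.outer(k, v)


rng = np.random.default_rng(0)
d_k, d_v = 4, 3
S = rng.standard_normal((d_k, d_v))
k = rng.standard_normal(d_k)
k /= np.linalg.norm(k)  # unit-norm key, as in the Quick Start
v = rng.standard_normal(d_v)

S_gdn = gdn_step(S, k, v, alpha=0.9, beta=0.5)
S_kda = kda_step(S, k, v, alpha=np.full(d_k, 0.9), beta=0.5)

# Writing into an empty memory with beta = 1 and no decay stores v under k.
S_write = kda_step(np.zeros((d_k, d_v)), k, v, alpha=np.ones(d_k), beta=1.0)
```

With a gate that is constant across dimensions, the KDA step reduces to the GDN step, which is why the two models can share most of the CP machinery.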
Chunkwise Formulation
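For efficiency, we process tokens in chunks using the WY representation (Eq. 7 in the report), which computes auxiliary matrices $\mathbf{W}_{[t]}$ and $\mathbf{U}_{[t]}$ for each chunk. The inter-chunk state recurrence (Eq. 8) becomes:

$$
\mathbf{S}_{[t+1]} = \mathrm{Diag}\!\left(\boldsymbol{\gamma}_{[t]}^{C}\right) \mathbf{S}_{[t]} + \left(\boldsymbol{\Gamma}_{[t]}^{i \to C} \odot \mathbf{K}_{[t]}\right)^{\top} \left(\mathbf{U}_{[t]} - \mathbf{W}_{[t]} \mathbf{S}_{[t]}\right)
$$

where the rows of $\boldsymbol{\Gamma}_{[t]}^{i \to C}$ hold the decay from each position to the end of the chunk.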
This formulation is key to CP: it lets us compute how the state transforms across a chunk, enabling efficient cross-rank synchronization.
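The property CP relies on is that a chunk acts on its incoming state as an affine map, `S_out = M @ S_in + S_ext`. The sketch below checks this numerically with a token-by-token KDA-style recurrence as the reference. It does not use the WY representation, and `run_chunk` and `chunk_transition` are illustrative names rather than FLA functions.

```python
import numpy as np


def run_chunk(S, K, V, A, B):
    """Token-by-token recurrence over one chunk (reference, not the WY kernel)."""
    I = np.eye(K.shape[1])
    for k, v, a, b in zip(K, V, A, B):
        S = (I - b * np.outer(k, k)) @ (a[:, None] * S) + b * np.outer(k, v)
    return S


def chunk_transition(K, A, B):
    """Ordered product of per-token transitions: the matrix applied to the incoming state."""
    I = np.eye(K.shape[1])
    M = I.copy()
    for k, a, b in zip(K, A, B):
        M = (I - b * np.outer(k, k)) @ np.diag(a) @ M
    return M


rng = np.random.default_rng(0)
C, d_k, d_v = 8, 4, 3
K = rng.standard_normal((C, d_k))
K /= np.linalg.norm(K, axis=1, keepdims=True)
V = rng.standard_normal((C, d_v))
A = rng.uniform(0.8, 1.0, (C, d_k))  # per-dim decay
B = rng.uniform(0.0, 1.0, C)         # beta
S_in = rng.standard_normal((d_k, d_v))

M = chunk_transition(K, A, B)
S_ext = run_chunk(np.zeros((d_k, d_v)), K, V, A, B)  # zero-initial-state result
S_out = run_chunk(S_in, K, V, A, B)
```

Neither `M` nor `S_ext` depends on `S_in`, so a rank can compute both before it knows what state the previous ranks will hand over.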
GDN vs KDA: Gate Handling
While both models share the delta rule structure, they differ in how gating is applied — a distinction that affects the CP implementation.
GDN
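GDN's scalar gate is cheap to apply inside kernels, so we pass the original tensors and let the kernel handle gating internally:
- Gate: $\alpha_t \in \mathbb{R}$, one scalar per head per token
- Code: `g` shape `[B, T, H]` where $g_t = \log \alpha_t$; processed by `chunk_local_cumsum`
- Kernel input: Original $\mathbf{Q}$, $\mathbf{K}$, and scalar `g`
- Internal gating (`USE_G=True`):
  - Inter-chunk decay: $\gamma_{[t]}^{C}\, \mathbf{S}_{[t]}$ (scalar broadcast)
  - Gated key: $\left(\gamma_{[t]}^{C} / \gamma_{[t]}^{r}\right) \mathbf{k}_{[t]}^{r}$
  - Gated query: $\gamma_{[t]}^{r}\, \mathbf{q}_{[t]}^{r}$ (backward only)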
KDA
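KDA's per-dim gate would be expensive to apply inside kernels. Instead, we pre-gate the tensors during the WY representation step:
- Gate: $\boldsymbol{\alpha}_t \in \mathbb{R}^{d_k}$, one value per dimension per token
- Code: `g` shape `[B, T, H, K]` where $g_t = \log \boldsymbol{\alpha}_t$; processed by `kda_gate_chunk_cumsum`
- Pre-gated tensors (from `chunk_kda_fwd_intra` / `recompute_w_u_fwd`):
  - `kg`: row $r$ is $\left(\boldsymbol{\gamma}_{[t]}^{C} / \boldsymbol{\gamma}_{[t]}^{r}\right) \odot \mathbf{k}_{[t]}^{r}$, i.e., `k * exp2(gk_last - gk)`
  - `qg`: row $r$ is $\boldsymbol{\gamma}_{[t]}^{r} \odot \mathbf{q}_{[t]}^{r}$, i.e., `q * exp2(gk)` (saved for backward)
- Kernel input: Pre-gated `kg` (and `qg` in backward), plus `gk=g` for inter-chunk decay
- Kernel gating (`USE_GK=True`): Only chunk-level decay $\mathrm{Diag}\!\left(\boldsymbol{\gamma}_{[t]}^{C}\right)$

This design means CP pre-processing must use the same tensors as the main kernel: original `q` / `k` for GDN, pre-gated `qg` / `kg` for KDA.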
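The pre-gating arithmetic for a single chunk can be sketched in NumPy as follows. It assumes `gk` is the chunk-local cumulative sum of log2 decays and reuses the expressions quoted above (`exp2(gk_last - gk)` and `exp2(gk)`). This illustrates the math only, not the fused kernel.

```python
import numpy as np

rng = np.random.default_rng(0)
C, d_k = 4, 3
q = rng.standard_normal((C, d_k))
k = rng.standard_normal((C, d_k))
alpha = rng.uniform(0.5, 1.0, (C, d_k))  # per-dim decay for each token

gk = np.cumsum(np.log2(alpha), axis=0)   # chunk-local cumulative log2 gate
gk_last = gk[-1]                         # log2 of the whole-chunk decay

kg = k * np.exp2(gk_last - gk)           # decay from each row to the chunk end
qg = q * np.exp2(gk)                     # decay from the chunk start to each row
chunk_decay = np.exp2(gk_last)           # diagonal of the inter-chunk decay
```

Because `kg` already carries the per-dim decay, the state kernel only needs `gk_last` for the chunk-level decay, which is the split described in the list above.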
CP Architecture
Data Flow
The core challenge of CP is that each rank only sees a local chunk, but the recurrent state depends on all previous tokens. We solve this with an all-gather + merge pattern:
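1. Local computation: Each rank computes $(\mathbf{S}^{\text{ext}}, \mathbf{M})$ from its chunk
   - $\mathbf{S}^{\text{ext}}$: accumulated state assuming $\mathbf{S}_0 = \mathbf{0}$
   - $\mathbf{M}$: transition matrix capturing how the chunk transforms incoming state
2. All-gather: Collect $(\mathbf{S}^{\text{ext}}_j, \mathbf{M}_j)$ from all ranks $j$
3. Merge: Rank $n$ reconstructs its initial state by chaining contributions from ranks $j < n$:

$$
\mathbf{S}^{\text{init}}_n = \sum_{j < n} \mathbf{M}_{n-1} \mathbf{M}_{n-2} \cdots \mathbf{M}_{j+1}\, \mathbf{S}^{\text{ext}}_j
$$

where the matrix product is empty (the identity) for $j = n - 1$.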
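The three steps can be simulated in a single process. The sketch below assumes one long sequence split evenly over four "ranks", uses a token-level KDA-style update in place of the chunked kernels, and replaces the all-gather with a Python list. `local_summary` and `merge_initial_state` are illustrative names, not FLA functions.

```python
import numpy as np


def step(S, k, v, a, b):
    """One delta-rule update with per-dim decay ([d_k, d_v] state layout)."""
    return (np.eye(k.size) - b * np.outer(k, k)) @ (a[:, None] * S) + b * np.outer(k, v)


def local_summary(K, V, A, B, d_v):
    """What one rank computes on its own: (S_ext, M) for its token slice."""
    d_k = K.shape[1]
    S_ext, M = np.zeros((d_k, d_v)), np.eye(d_k)
    for k, v, a, b in zip(K, V, A, B):
        S_ext = step(S_ext, k, v, a, b)      # run with zero initial state
        M = step(M, k, np.zeros(d_k), a, b)  # same update with v = 0 tracks the transition
    return S_ext, M


def merge_initial_state(gathered, rank):
    """Chain the summaries of ranks 0..rank-1 into this rank's initial state."""
    S = np.zeros_like(gathered[0][0])
    for S_ext, M in gathered[:rank]:
        S = M @ S + S_ext
    return S


rng = np.random.default_rng(0)
world, T_local, d_k, d_v = 4, 6, 4, 3
T = world * T_local
K = rng.standard_normal((T, d_k))
K /= np.linalg.norm(K, axis=1, keepdims=True)
V = rng.standard_normal((T, d_v))
A = rng.uniform(0.8, 1.0, (T, d_k))
B = rng.uniform(0.0, 1.0, T)

# "All-gather": every rank ends up holding the full list of summaries.
gathered = [
    local_summary(*(x[r * T_local:(r + 1) * T_local] for x in (K, V, A, B)), d_v)
    for r in range(world)
]

# Reference: run the whole sequence on one device and record the state at rank boundaries.
S, boundary_states = np.zeros((d_k, d_v)), []
for t in range(T):
    if t % T_local == 0:
        boundary_states.append(S)
    S = step(S, K[t], V[t], A[t], B[t])
```

The loop in `merge_initial_state` is the recursive form of the sum above: applying `S <- M_j S + S_ext_j` for increasing `j` unrolls into the chained products.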
Pre-Process Forward
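This step computes $(\mathbf{S}^{\text{ext}}, \mathbf{M})$ for the local chunk. Below, $\tilde{\mathbf{K}}_{[i]} = \boldsymbol{\Gamma}_{[i]}^{i \to C} \odot \mathbf{K}_{[i]}$ is the gated key of sub-chunk $i$, and $\mathbf{W}_{[i]}$, $\mathbf{U}_{[i]}$ are its WY matrices (`w`, `u` in code).

Stage 1 (accumulated state $\mathbf{S}^{\text{ext}}$):

We simulate processing the chunk with zero initial state. Initialize $\mathbf{S} = \mathbf{0}$, then for each sub-chunk $i$:

$$
\mathbf{S} \leftarrow \mathrm{Diag}\!\left(\boldsymbol{\gamma}_{[i]}^{C}\right) \mathbf{S} + \tilde{\mathbf{K}}_{[i]}^{\top} \left(\mathbf{U}_{[i]} - \mathbf{W}_{[i]} \mathbf{S}\right)
$$

Stage 2 (transition matrix $\mathbf{M}$):

The transition matrix captures how incoming state is transformed. Initialize $\mathbf{M} = \mathbf{I}$, then for each sub-chunk $i$:

$$
\mathbf{M} \leftarrow \left(\mathrm{Diag}\!\left(\boldsymbol{\gamma}_{[i]}^{C}\right) - \tilde{\mathbf{K}}_{[i]}^{\top} \mathbf{W}_{[i]}\right) \mathbf{M}
$$

Merge (forward direction):

For rank $n$ with $P$ = `pre_num_ranks` previous ranks:

$$
\mathbf{S}^{\text{init}}_n = \sum_{j = n - P}^{n - 1} \mathbf{M}_{n-1} \mathbf{M}_{n-2} \cdots \mathbf{M}_{j+1}\, \mathbf{S}^{\text{ext}}_j
$$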
Pre-Process Backward
The backward pass has the same structure but reversed direction — we merge from ranks after the current rank to propagate gradients backward through the sequence.
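Below, $\tilde{\mathbf{Q}}_{[i]}$ and $\tilde{\mathbf{K}}_{[i]}$ are the gated query and key of sub-chunk $i$, $d\mathbf{O}_{[i]}$ is the output gradient (`do`), and $d\mathbf{V}_{[i]}$ is the locally computed value gradient (`dv`).

Stage 1 (gradient $d\mathbf{S}^{\text{ext}}$):

Initialize $d\mathbf{S} = \mathbf{0}$. For each sub-chunk $i$ in reverse order:

$$
d\mathbf{S} \leftarrow \mathrm{Diag}\!\left(\boldsymbol{\gamma}_{[i]}^{C}\right) d\mathbf{S} + s\, \tilde{\mathbf{Q}}_{[i]}^{\top} d\mathbf{O}_{[i]} - \mathbf{W}_{[i]}^{\top} \left(d\mathbf{V}_{[i]} + \tilde{\mathbf{K}}_{[i]}\, d\mathbf{S}\right)
$$

where $s$ is the scaling factor (`scale` in code).

Stage 2 (gradient transition $d\mathbf{M}$):

Initialize $d\mathbf{M} = \mathbf{I}$. For each sub-chunk $i$ in reverse order:

$$
d\mathbf{M} \leftarrow \left(\mathrm{Diag}\!\left(\boldsymbol{\gamma}_{[i]}^{C}\right) - \mathbf{W}_{[i]}^{\top} \tilde{\mathbf{K}}_{[i]}\right) d\mathbf{M}
$$

Note

$d\mathbf{M}$ is the transpose of forward $\mathbf{M}$ ($\mathbf{W}^{\top} \tilde{\mathbf{K}}$ vs. $\tilde{\mathbf{K}}^{\top} \mathbf{W}$).

Merge (backward direction):

For rank $n$ with $P'$ = `post_num_ranks` following ranks, the gradient flowing into its final state (`dht` in code) is:

$$
d\mathbf{S}^{\text{fin}}_n = \sum_{j = n + 1}^{n + P'} d\mathbf{M}_{n+1}\, d\mathbf{M}_{n+2} \cdots d\mathbf{M}_{j-1}\; d\mathbf{S}^{\text{ext}}_j
$$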
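The backward merge mirrors the forward one with the direction reversed. The sketch below uses random matrices as stand-ins for the per-rank summaries and assumes a single sequence spanning all ranks. `merge_final_grad` is an illustrative name, and `closed_form` spells out the chained sum for comparison.

```python
import numpy as np


def merge_final_grad(gathered, rank):
    """Chain (dS_ext, dM) of the ranks after `rank` into the gradient of its final state."""
    dS = np.zeros_like(gathered[0][0])
    for dS_ext, dM in reversed(gathered[rank + 1:]):
        dS = dM @ dS + dS_ext  # gradient at the start of this later rank
    return dS


rng = np.random.default_rng(0)
world, d_k, d_v = 4, 3, 2
M = [0.5 * rng.standard_normal((d_k, d_k)) for _ in range(world)]  # forward transitions
dS_ext = [rng.standard_normal((d_k, d_v)) for _ in range(world)]   # zero-incoming-gradient results
gathered = [(dS_ext[j], M[j].T) for j in range(world)]             # dM_j is the transpose of M_j


def closed_form(rank):
    """Sum over later ranks j of dM_{rank+1} ... dM_{j-1} dS_ext_j."""
    total = np.zeros((d_k, d_v))
    for j in range(rank + 1, world):
        term = dS_ext[j]
        for i in range(j - 1, rank, -1):  # apply dM_{j-1} first, dM_{rank+1} last
            term = M[i].T @ term
        total += term
    return total
```

The last rank receives a zero gradient, just as the first rank receives a zero initial state in the forward pass.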
Code Flow
The following examples show how CP integrates with the existing kernel interfaces.
GDN Forward
```python
g = chunk_local_cumsum(g, chunk_size=64, scale=RCP_LN2)
w, u = recompute_w_u_fwd(k, v, beta, A, g=g)

# CP pre-process: original k, scalar g
initial_state = chunk_gated_delta_rule_fwd_h_pre_process(
    k=k, w=w, u=u, g=g,  # USE_G=True, USE_GK=False
    context=cp_context,
)

# Main kernel: original k, scalar g
h, v_new, _ = chunk_gated_delta_rule_fwd_h(
    k=k, w=w, u=u, g=g,
    initial_state=initial_state,
)
```
GDN Backward
```python
w, u = recompute_w_u_fwd(k, v, beta, A, g=g)
h, v_new, _ = chunk_gated_delta_rule_fwd_h(k=k, w=w, u=u, g=g, ...)
dv = chunk_bwd_dv_local(q=q, k=k, g=g, do=do, ...)

# CP pre-process: original q, k, scalar g
dht, initial_state = chunk_gated_delta_rule_bwd_dhu_pre_process(
    q=q, k=k, w=w, do=do, dv=dv, g=g,  # USE_G=True, USE_GK=False
    context=cp_context,
)

# Main kernel: original q, k, scalar g
dh, dh0, dv = chunk_gated_delta_rule_bwd_dhu(
    q=q, k=k, w=w, g=g,
    dht=dht, ...
)
```
KDA Forward
# 1. Intra-chunk: compute WY repr + pre-gated tensors
```python
# 1. Intra-chunk: compute WY repr + pre-gated tensors
w, u, qg, kg, Aqk, Akk = chunk_kda_fwd_intra(q, k, v, gk=g, beta=beta, ...)
# kg = k * exp2(gk_last - gk), i.e., rows of Γ^{i→C} ⊙ K
# qg = q * exp2(gk), i.e., rows of Γ^{1→C} ⊙ Q (saved for backward)

# 2. CP pre-process: pre-gated kg, per-dim gk=g
initial_state = chunk_gated_delta_rule_fwd_h_pre_process(
    k=kg, w=w, u=u, gk=g,  # USE_G=False, USE_GK=True
    context=cp_context,
)

# 3. Main kernel: pre-gated kg, per-dim gk=g
h, v_new, _ = chunk_gated_delta_rule_fwd_h(
    k=kg, w=w, u=u, gk=g,
    initial_state=initial_state,
)
```
KDA Backward
```python
# 1. Recompute WY repr
w, u, qg, kg = recompute_w_u_fwd(q, k, v, beta, A=Akk, gk=g, ...)
# qg = q * exp2(gk), kg = k * exp2(gk_last - gk)

# 2. Recompute state
h, v_new, _ = chunk_gated_delta_rule_fwd_h(k=kg, w=w, u=u, gk=g, ...)

# 3. Compute local dv
dAqk, dv = chunk_kda_bwd_dAv(q, k, v=v_new, do=do, A=Aqk, ...)

# 4. CP pre-process: pre-gated qg, kg, per-dim gk=g
dht, initial_state = chunk_gated_delta_rule_bwd_dhu_pre_process(
    q=qg, k=kg, w=w, do=do, dv=dv, gk=g,  # USE_G=False, USE_GK=True
    context=cp_context,
)

# 5. Main kernel: pre-gated qg, kg
dh, dh0, dv = chunk_gated_delta_rule_bwd_dhu(
    q=qg, k=kg, w=w, gk=g,
    dht=dht, ...
)
```
Input Tensor Summary
| Function | GDN | KDA | Gate Path |
|---|---|---|---|
| `pre_process_fwd` | `k=k, g=g` | `k=kg, gk=g` | GDN: `USE_G`, KDA: `USE_GK` |
| `fwd_h` | `k=k, g=g` | `k=kg, gk=g` | Same as `pre_process_fwd` |
| `pre_process_bwd` | `q=q, k=k, g=g` | `q=qg, k=kg, gk=g` | GDN: `USE_G`, KDA: `USE_GK` |
| `bwd_dhu` | `q=q, k=k, g=g` | `q=qg, k=kg, gk=g` | Same as `pre_process_bwd` |
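Key consistency: Pre-process and main kernel must always receive the same tensors. This is critical for correctness.
- KDA: Both receive pre-gated `kg` ($\boldsymbol{\Gamma}^{i \to C} \odot \mathbf{K}$) and `qg` ($\boldsymbol{\Gamma}^{1 \to C} \odot \mathbf{Q}$)
- GDN: Both receive original $\mathbf{Q}$, $\mathbf{K}$ (gating applied inside the kernel)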
Transition Matrix
The transition matrix is central to CP — it captures how a chunk transforms any incoming state, enabling us to chain contributions from multiple ranks.
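Forward, for chunk $[t]$ with gated key $\tilde{\mathbf{K}}_{[t]} = \boldsymbol{\Gamma}_{[t]}^{i \to C} \odot \mathbf{K}_{[t]}$:

$$
\mathbf{M}_{[t]} = \mathrm{Diag}\!\left(\boldsymbol{\gamma}_{[t]}^{C}\right) - \tilde{\mathbf{K}}_{[t]}^{\top} \mathbf{W}_{[t]}
$$

Backward (transposed):

$$
\mathbf{M}_{[t]}^{\top} = \mathrm{Diag}\!\left(\boldsymbol{\gamma}_{[t]}^{C}\right) - \mathbf{W}_{[t]}^{\top} \tilde{\mathbf{K}}_{[t]}
$$

The diagonal term differs between models:
- GDN: $\gamma_{[t]}^{C}\, \mathbf{I}$, a scalar times identity
- KDA: $\mathrm{Diag}\!\left(\boldsymbol{\gamma}_{[t]}^{C}\right)$, a per-dim diagonal, where $\boldsymbol{\gamma}_{[t]}^{C} = \prod_{r=1}^{C} \boldsymbol{\alpha}_{[t]}^{r}$ (stored in log space as `gk_last` in code)

Cross-rank state is computed by chaining the rank-level matrices $\mathbf{M}_j$, each the ordered product of that rank's chunk-level transitions:

$$
\mathbf{S}^{\text{init}}_n = \sum_{j < n} \mathbf{M}_{n-1} \mathbf{M}_{n-2} \cdots \mathbf{M}_{j+1}\, \mathbf{S}^{\text{ext}}_j
$$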
Important
The chain multiply must stay in fp32 to avoid accumulated precision loss. In bf16, repeatedly casting fp32 accumulators back to bf16 between iterations causes significant error growth over many chunks.
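The effect is easy to reproduce on the CPU. The sketch below emulates bf16 storage by zeroing the low 16 bits of each fp32 value, which is truncation; real hardware normally rounds to nearest, so this is only an approximation. It then compares a chain kept in fp32 against one cast to "bf16" after every multiply. The random near-identity matrices are stand-ins for real transition matrices.

```python
import numpy as np


def to_bf16(x):
    """Emulate bf16 storage by zeroing the low 16 bits of each fp32 value (truncation)."""
    bits = np.ascontiguousarray(x, dtype=np.float32).view(np.uint32)
    return (bits & np.uint32(0xFFFF0000)).view(np.float32)


rng = np.random.default_rng(0)
d, n = 16, 64
# Transitions close to a decayed identity, loosely mimicking Diag(gamma) - K^T W.
Ms = [0.98 * np.eye(d) + 0.02 * rng.standard_normal((d, d)) for _ in range(n)]

ref = np.eye(d)                      # float64 reference
acc32 = np.eye(d, dtype=np.float32)  # accumulator stays in fp32
acc16 = np.eye(d, dtype=np.float32)  # accumulator cast to "bf16" after every multiply
for M in Ms:
    ref = M @ ref
    acc32 = M.astype(np.float32) @ acc32
    acc16 = to_bf16(M.astype(np.float32) @ acc16)

err32 = np.abs(acc32 - ref).max() / np.abs(ref).max()
err16 = np.abs(acc16 - ref).max() / np.abs(ref).max()
print(f"fp32 chain error: {err32:.2e}, bf16-cast chain error: {err16:.2e}")
```

The gap between the two errors widens as the chain grows, which is the regime CP operates in when many chunks and ranks are chained.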
Initial State Memory Optimization
In CP mode, only the first sequence in the local batch can be a continuation from a previous rank — all other sequences start fresh. This means only one initial state is non-zero, presenting an opportunity for memory savings:
- `compress_h0`: Extracts just that one state to save memory during `save_for_backward`
- `expand_h0`: Restores the full `[N, H, d_k, d_v]` tensor in backward
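The idea behind the pair can be sketched in NumPy. The two functions below are hypothetical stand-ins with made-up signatures (hence the `_sketch` suffix); they assume the only non-zero state sits at index 0 and are not the FLA implementations.

```python
import numpy as np


def compress_h0_sketch(h0):
    """Keep only the first sequence's state; assumes every other entry is zero."""
    return h0[:1].copy(), h0.shape


def expand_h0_sketch(h0_first, shape):
    """Rebuild the full [N, H, d_k, d_v] tensor with zeros for the other sequences."""
    h0 = np.zeros(shape, dtype=h0_first.dtype)
    h0[:1] = h0_first
    return h0


N, H, d_k, d_v = 5, 2, 4, 3
h0 = np.zeros((N, H, d_k, d_v), dtype=np.float32)
h0[0] = np.random.default_rng(0).standard_normal((H, d_k, d_v))  # continuation from the previous rank

packed, shape = compress_h0_sketch(h0)
restored = expand_h0_sketch(packed, shape)
```

In this sketch the saved tensor is `N` times smaller than the full initial state, which matters when a rank holds many short sequences.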
Test References
Discussion
While this document focuses on delta-rule models such as GDN and KDA, the underlying CP mechanism is not restricted to delta-rule recurrences. In fact, any linear attention formulation that can be expressed in a chunkwise form — i.e., one where the state transition across a chunk can be decomposed into a transition matrix and an accumulated state — can adopt the same pre-process + all-gather + merge strategy for context parallelism.
The only model-specific components are:
- How the accumulated state $\mathbf{S}^{\text{ext}}$ and the transition matrix $\mathbf{M}$ are computed from the local chunk.
- How the merge kernel chains these quantities across ranks.
As long as these two operations are well-defined, the same CP infrastructure (build_cp_context, all-gather, and merge) applies without changing the high-level data flow.
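As an illustration of this generality, the same merge loop works for a plain decayed linear attention, `S_t = a_t * S_{t-1} + k_t v_t^T`, where the transition collapses to a scalar. This toy model is chosen for the sketch only; it is not one of the variants listed as supported, and `summary_decay_attention` is an illustrative name.

```python
import numpy as np


def summary_decay_attention(K, V, a):
    """(S_ext, M) for S_t = a_t * S_{t-1} + k_t v_t^T; M is a scalar here."""
    S_ext, M = np.zeros((K.shape[1], V.shape[1])), 1.0
    for k, v, a_t in zip(K, V, a):
        S_ext = a_t * S_ext + np.outer(k, v)
        M = a_t * M
    return S_ext, M


rng = np.random.default_rng(0)
world, T_local, d_k, d_v = 3, 5, 4, 2
T = world * T_local
K, V = rng.standard_normal((T, d_k)), rng.standard_normal((T, d_v))
a = rng.uniform(0.8, 1.0, T)

parts = [
    summary_decay_attention(K[s:s + T_local], V[s:s + T_local], a[s:s + T_local])
    for s in range(0, T, T_local)
]

# Same merge as for the delta-rule models: S <- M * S + S_ext.
S_init_last = np.zeros((d_k, d_v))
for S_ext, M in parts[:-1]:
    S_init_last = M * S_init_last + S_ext

# Reference: sequential run over all tokens that precede the last rank.
S_ref = np.zeros((d_k, d_v))
for t in range(T - T_local):
    S_ref = a[t] * S_ref + np.outer(K[t], V[t])
```

Only `summary_decay_attention` is model-specific; the merge loop has the same shape as in the delta-rule case.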
At the time of writing, CP has been implemented and verified for GDN, KDA, and DPLR (a.k.a. RWKV-7). If you would like to see support for another linear-attention variant, please feel free to open an issue.
Acknowledgments
Context Parallel of Linear Attention was first introduced in PR #691, implemented by Duyue MA. It is also known as KCP (Kimi Context Parallel) internally at Moonshot AI. The implementation in this repository was independently contributed to FLA and is a separate codebase from the internal Moonshot implementation.