Wind RWKV

March 24, 2025

A repository with optimized kernels for RWKV language models. Currently focused on RWKV-7.

Kernel benchmarks for RWKV-7

The kernels were timed using tests/speed_test.py with modeldim 4096 and varying (batch size, head size, sequence length), as labeled in the tables below.

H100

| Kernel | (8,64,4096) | (8,128,4096) | (8,256,4096) | (1,256,32768) | Peak VRAM [1] | Typical error |
|---|---|---|---|---|---|---|
| Chunked bf16 | 8 ms | 11 ms | 54 ms | 224 ms | 5 - 8 GB | 5e-3 |
| Backstepping fp32 longhead | 23 ms | 46 ms | 80 ms | 124 ms | 8 - 14 GB | 9e-5 |
| Backstepping fp32 smallhead | 17 ms | 101 ms | 862 ms | 1802 ms | 7 - 13 GB | 9e-5 |
| Triton bighead fp32 | 66 ms | 87 ms | 168 ms | 1175 ms | 6 - 12 GB | 5e-5 |
| Triton bighead bf16 | - [2] | 29 ms | 59 ms | 358 ms | 6 - 12 GB | 5e-3 |
| FLA chunk_rwkv7 | 64 ms | 62 ms | 89 ms | 93 ms | 12 - 13 GB | 4e-3 |

MI300X

| Kernel | (8,64,4096) | (8,128,4096) | (8,256,4096) | (1,256,32768) | Peak VRAM [1] | Typical error |
|---|---|---|---|---|---|---|
| Backstepping fp32 longhead | 29 ms | 39 ms | 75 ms | 162 ms | 8 - 14 GB | 9e-5 |
| Backstepping fp32 smallhead | 251 ms | 757 ms | 2706 ms | 15025 ms | 7 - 13 GB | 9e-5 |
| Triton bighead fp32 | 67 ms | 100 ms | 287 ms | 2073 ms | 6 - 12 GB | 5e-5 |
| Triton bighead bf16 | 42 ms | 72 ms | 198 ms | 1453 ms | 6 - 12 GB | 5e-3 |
| FLA chunk_rwkv7 | 52 ms | 61 ms | 98 ms | 202 ms | 12 - 13 GB | 4e-3 |

Kernel descriptions

The RWKV-7 kernels all compute the following:

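import torch as th

def naive(r,w,k,v,a,b,s):
    # r,w,k,v,a,b: [batch, seq len, heads, head size]; s: [batch, heads, head size, head size]
    y = th.empty_like(v)
    for t in range(w.shape[1]):
        # Decay the state, apply the rank-1 (a, b) update, then add the new v k^T term
        s = s * th.exp(-th.exp(w[:,t,:,None,:])) + s @ a[:,t,:,:,None] * b[:,t,:,None,:] + v[:,t,:,:,None] * k[:,t,:,None,:]
        # Read out the output for step t from the updated state
        y[:,t,:,:,None] = s @ r[:,t,:,:,None]
    return y, s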

Here r, w, k, v, a and b have shape [batch size, sequence length, num heads, head size], while the initial state s has shape [batch size, num heads, head size, head size]. All inputs and outputs are in bfloat16 precision.
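For reference, the loop body of naive can be written per batch element and head as a matrix recurrence. The notation is introduced here for illustration only: C is the head size, S_t is the C x C state after step t (with S_0 = s), and r_t, w_t, k_t, v_t, a_t, b_t are the length-C vectors at step t.

```latex
S_t = S_{t-1}\,\operatorname{diag}\!\left(e^{-e^{w_t}}\right)
      + \left(S_{t-1}\,a_t\right) b_t^{\top}
      + v_t\,k_t^{\top},
\qquad
y_t = S_t\,r_t
```

Equivalently, S_t = S_{t-1} M_t + v_t k_t^T with M_t = diag(exp(-exp(w_t))) + a_t b_t^T, so each step is affine in the state.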

Chunked bf16

This is the fastest kernel when applicable. It processes the sequence in chunks of length 16 (chunked formulation) and uses Ampere (CUDA SM80+, i.e., A100 and later) instructions for fast bfloat16 matmuls.
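To give some intuition for why chunking maps well onto matmuls, here is a minimal NumPy sketch for a single head. It is not the algorithm used by the Chunked bf16 kernel. It only illustrates that the effect of a whole chunk on the state is affine, so a chunk can be summarized by two matrices (called P and U here) and applied with one matrix product. The function names, the random inputs and the small sizes are all made up for the example, and the outputs y are left out.

```python
import numpy as np

def step(s, w, k, v, a, b):
    """One step of the recurrence for a single head; s is [C, C], the rest are [C]."""
    return s * np.exp(-np.exp(w)) + np.outer(s @ a, b) + np.outer(v, k)

def summarize_chunk(w, k, v, a, b):
    """Return (P, U) such that running the chunk from state s ends in s @ P + U."""
    C = w.shape[1]
    P, U = np.eye(C), np.zeros((C, C))
    for t in range(w.shape[0]):
        # Per-step transition: s_new = s @ M + outer(v, k)
        M = np.diag(np.exp(-np.exp(w[t]))) + np.outer(a[t], b[t])
        P = P @ M
        U = U @ M + np.outer(v[t], k[t])
    return P, U

rng = np.random.default_rng(0)
T, C, CHUNK = 64, 8, 16  # hypothetical sizes; 16 matches the chunk length above
w, k, v = (rng.standard_normal((T, C)) for _ in range(3))
a, b = (0.1 * rng.standard_normal((T, C)) for _ in range(2))
s0 = rng.standard_normal((C, C))

# Reference: one step at a time
s_ref = s0
for t in range(T):
    s_ref = step(s_ref, w[t], k[t], v[t], a[t], b[t])

# Chunked: one affine update per chunk of 16 steps
s_chunked = s0
for c in range(0, T, CHUNK):
    sl = slice(c, c + CHUNK)
    P, U = summarize_chunk(w[sl], k[sl], v[sl], a[sl], b[sl])
    s_chunked = s_chunked @ P + U

max_diff = float(np.abs(s_chunked - s_ref).max())
print("max state difference:", max_diff)
```

Since P and U do not depend on the incoming state, they could in principle be computed for all chunks in parallel, leaving only a short sequential pass over chunk boundaries.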

Backstepping fp32 smallhead

This is essentially the official kernel which was used to train the RWKV-7 World models. It calculates gradients by iterating the state backwards in time (max 15 steps). This makes the code simple, but requires 32-bit floats and limits the decay to approximately 0.5.
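The idea of iterating the state backwards can be sketched in NumPy for a single head. This is a sketch under an assumption, not the kernel's code: it assumes the product s @ a from the forward pass is saved, so that undoing a step only needs a subtraction and a division by the decay. The function names and random inputs are hypothetical. The division is also why small decays are a problem: each backward step amplifies rounding errors by 1/decay.

```python
import numpy as np

def forward_step(s, w, k, v, a, b):
    """One recurrence step for a single head. Also returns s @ a for later reuse."""
    decay = np.exp(-np.exp(w))
    sa = s @ a  # assumed to be saved for the backward pass
    return s * decay + np.outer(sa, b) + np.outer(v, k), sa

def backward_step(s_new, sa, w, k, v, b):
    """Recover the previous state from the new one by undoing forward_step."""
    decay = np.exp(-np.exp(w))
    return (s_new - np.outer(v, k) - np.outer(sa, b)) / decay

def round_trip_error(w, k, v, a, b, s0):
    """Run all steps forward, then step back to the start; return the max abs error."""
    s, saved = s0, []
    for t in range(len(w)):
        s, sa = forward_step(s, w[t], k[t], v[t], a[t], b[t])
        saved.append(sa)
    for t in reversed(range(len(w))):
        s = backward_step(s, saved[t], w[t], k[t], v[t], b[t])
    return float(np.abs(s - s0).max())

rng = np.random.default_rng(0)
T, C = 15, 8  # 15 backward steps, as in the description above; C is hypothetical
k, v = (rng.standard_normal((T, C)).astype(np.float32) for _ in range(2))
a, b = (0.1 * rng.standard_normal((T, C)).astype(np.float32) for _ in range(2))
s0 = rng.standard_normal((C, C)).astype(np.float32)

# w is chosen so that exp(-exp(w)) equals the target decay per step
w_mild = np.full((T, C), np.log(-np.log(0.6)), dtype=np.float32)    # decay 0.6
w_strong = np.full((T, C), np.log(-np.log(0.1)), dtype=np.float32)  # decay 0.1

err_mild = round_trip_error(w_mild, k, v, a, b, s0)
err_strong = round_trip_error(w_strong, k, v, a, b, s0)
print("fp32 round-trip error, decay 0.6:", err_mild)
print("fp32 round-trip error, decay 0.1:", err_strong)
```

With a decay of 0.6 the worst-case amplification over 15 steps is about (1/0.6)^15, roughly 2e3, which fp32 can absorb. With a decay of 0.1 it is 1e15, and the recovered state is dominated by rounding error.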

Backstepping fp32 longhead

Backstepping fp32 smallhead becomes very slow for large head sizes, since the full state is kept in registers, which overflow into global memory. To fix this, backstepping fp32 longhead uses the observation that the columns of the state are essentially updated independently, so it processes blocks of 64 or 32 columns independently. This increases parallelization and keeps less state in shared memory at a time, while retaining most of the simplicity of backstepping fp32 smallhead.
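The independence can be checked numerically with a small NumPy sketch for a single head. Which axis counts as "columns" depends on the memory layout. In the layout of naive above, the independent slices are those along the second-to-last axis of s, the one indexed like v. The helper name, sizes and random inputs below are hypothetical. The sketch runs the recurrence separately on blocks of 32 such slices and compares against the full computation.

```python
import numpy as np

def run(r, w, k, v, a, b, s):
    """Sequential recurrence for one head. v is [T, Cv], s is [Cv, C], the rest are [T, C]."""
    y = np.empty_like(v)
    for t in range(w.shape[0]):
        s = s * np.exp(-np.exp(w[t])) + np.outer(s @ a[t], b[t]) + np.outer(v[t], k[t])
        y[t] = s @ r[t]
    return y, s

rng = np.random.default_rng(0)
T, C, BLOCK = 16, 128, 32  # hypothetical sizes
r, w, k, v = (rng.standard_normal((T, C)) for _ in range(4))
a, b = (0.02 * rng.standard_normal((T, C)) for _ in range(2))
s0 = rng.standard_normal((C, C))

# Full state at once
y_full, s_full = run(r, w, k, v, a, b, s0)

# One block of the state at a time; each block only needs its own slice of v and s
y_blocks, s_blocks = [], []
for i in range(0, C, BLOCK):
    y_i, s_i = run(r, w, k, v[:, i:i + BLOCK], a, b, s0[i:i + BLOCK])
    y_blocks.append(y_i)
    s_blocks.append(s_i)

y_split = np.concatenate(y_blocks, axis=1)
s_split = np.concatenate(s_blocks, axis=0)
print(np.abs(y_split - y_full).max(), np.abs(s_split - s_full).max())
```

Each block needs all of r, w, k, a and b but only its own part of v and of the state, which is why the blocks can run in parallel with a smaller working set.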

Triton bighead

A simple chunked kernel written in Triton. The kernel stores intermediate states in global memory instead of shared memory, so it handles large head sizes (like 1024) without crashing. It takes a flag to choose fp32 or bf16 precision [3], which affects all matmuls inside the Triton kernel.

FLA chunk_rwkv7

RWKV-7 Triton kernel from Flash Linear Attention. It is a chunked implementation with partial parallelization over the sequence length.

Footnotes

  1. Smallest peak VRAM was typically for (8,64,4096) and largest for (8,256,4096).

  2. Triton fails to compile the kernel, only seen on H100.

  3. The kernel also supports tf32 precision for matmuls, but tf32 seems to run into bugs in the Triton language, so I didn't expose it.