README.md

June 12, 2026

🤖FFPA: Yet another Faster Flash Prefill Attention
with O(1)⚡️GPU SRAM complexity for large headdim


FFPA(Split-D): Yet another Faster Flash Prefill Attention with the Split-D strategy. It achieves O(1) SRAM complexity and O(d/4) register complexity for large headdim (> 256) and runs 1.5~3x 🎉 faster than SDPA. 📚👇Core features:

| Self Attn | GQA/MQA | Cross Attn | Causal/Mask | Dropout | Headdim | Fwd/Bwd |
|---|---|---|---|---|---|---|
| ✔️(Nq=Nkv) | ✔️(Hq!=Hkv) | ✔️(Nq!=Nkv) | ✔️(attn_mask) | ✔️(p>0) | 320~1024 | 1.5~3x↑ |

🎉🎉 Latest News

📖 Quick Start

First, install the prebuilt package from PyPI or build ffpa-attn from source:

# First, install the prebuilt package from PyPI
pip3 install -U ffpa-attn # CUDA 13.0+, PyTorch 2.11+
# Or build ffpa-attn from source with the commands below
git clone https://github.com/xlite-dev/ffpa-attn.git
# Then, build the wheel package (Triton + CuTeDSL backends)
cd ffpa-attn && pip3 install -e . --no-build-isolation
# Optional: install ffpa-attn w/ CUDA backend (forward only)
ENABLE_FFPA_CUDA_IMPL=1 MAX_JOBS=32 pip3 install -e .

Then accelerate attention for large headdim with a single line of code:

>>> import torch.nn.functional as F
>>> from ffpa_attn import ffpa_attn_func
>>> # Monkey-patch SDPA to point to FFPA. Everything that FFPA does
>>> # not support automatically falls back to SDPA: D <= 256, etc.
>>> F.scaled_dot_product_attention = ffpa_attn_func # one-line code
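
The one-line patch above assumes both packages import cleanly. A minimal sketch of a guarded variant follows; `patch_sdpa_with_ffpa` is a hypothetical helper name and is not part of ffpa-attn. It only uses `ffpa_attn_func` and `F.scaled_dot_product_attention` from the snippet above, and leaves stock SDPA in place when either package is missing.

```python
import importlib.util


def patch_sdpa_with_ffpa() -> bool:
    """Point F.scaled_dot_product_attention at ffpa_attn_func when possible.

    Hypothetical helper (not part of ffpa-attn). Returns True if the patch
    was applied, False if torch or ffpa_attn is not installed.
    """
    for package in ("torch", "ffpa_attn"):
        if importlib.util.find_spec(package) is None:
            return False  # leave stock SDPA untouched

    import torch.nn.functional as F
    from ffpa_attn import ffpa_attn_func

    F.scaled_dot_product_attention = ffpa_attn_func  # same one-line patch
    return True


patched = patch_sdpa_with_ffpa()
print("FFPA patch applied:", patched)
```

Because the patch rebinds an attribute on the `torch.nn.functional` module, code that already ran `from torch.nn.functional import scaled_dot_product_attention` keeps its old reference, so the helper should be called before such imports.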

For more advanced features, please refer to our online docs at 📘ffpa-attn.io.

📖 Split-D

We extend FlashAttention to support large headdim ($D > 256$) via fine-grained tiling at the MMA level for the $QK^\top$ and $PV$ matrix multiplications, referred to as Split-D. This design keeps SRAM usage fixed at $B_r \times 16$ (with $B_r = B_c$) for Q, K and V, yielding constant SRAM complexity $O(B_r \times 16) \approx O(1)$ and register complexity $O(d/4)$.

FFPA enables headdim > 256 and outperforms standard SDPA by 1.5~3x🎉.
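
The reason a fixed-width slice is enough can be checked numerically: the score tile is a sum of partial products over head-dim slices. The sketch below is a pure-Python illustration of that identity, not the FFPA kernel. The function names, the toy tile sizes, and the use of plain lists are assumptions made for the example; only the slice width of 16 comes from the text above.

```python
import random


def full_scores(Q, K):
    """Reference S = Q K^T computed over the whole head dimension at once."""
    return [[sum(q * k for q, k in zip(q_row, k_row)) for k_row in K]
            for q_row in Q]


def split_d_scores(Q, K, chunk=16):
    """Accumulate S = Q K^T from head-dim slices of width `chunk`.

    Each step reads only a (rows x chunk) slice of Q and of K, so the
    amount of data staged per step does not grow with the head dim d.
    """
    d = len(Q[0])
    S = [[0.0] * len(K) for _ in Q]
    for start in range(0, d, chunk):
        q_slice = [row[start:start + chunk] for row in Q]
        k_slice = [row[start:start + chunk] for row in K]
        for i, q_row in enumerate(q_slice):
            for j, k_row in enumerate(k_slice):
                S[i][j] += sum(q * k for q, k in zip(q_row, k_row))
    return S


random.seed(0)
Br, Bc, d = 4, 4, 512  # toy sizes for illustration only
Q = [[random.uniform(-1, 1) for _ in range(d)] for _ in range(Br)]
K = [[random.uniform(-1, 1) for _ in range(d)] for _ in range(Bc)]

ref = full_scores(Q, K)
out = split_d_scores(Q, K, chunk=16)
max_err = max(abs(a - b) for r, o in zip(ref, out) for a, b in zip(r, o))
print("max abs difference:", max_err)
```

The two results differ only by floating-point rounding from the changed summation order. The sketch covers the $QK^\top$ half only; the $PV$ product and the online-softmax bookkeeping are omitted.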

Note

FFPA has been tested on Ampere, Ada, Hopper, and Blackwell architectures (e.g., A30, L20, 4090, H200, 5090) and achieves a 1.5~3×↑🎉 speedup over SDPA. FFPA is mainly designed for prefill and large headdim, and may not be faster than SDPA for 😈 small sequence lengths (N<512) or small headdim (D<=256).
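
The guidance in this note can be written down as an explicit check, which may be handy when deciding what to benchmark. This is a sketch only: `prefers_ffpa` is a hypothetical helper, and its thresholds are copied from the note rather than taken from the library, which already falls back to SDPA on its own for unsupported cases.

```python
def prefers_ffpa(seq_len: int, head_dim: int,
                 min_seq_len: int = 512, min_head_dim: int = 256) -> bool:
    """Return True when the note's guidance suggests trying FFPA over SDPA.

    Hypothetical helper: thresholds mirror the note (N < 512 or D <= 256
    may not be faster) and are not read from ffpa-attn itself.
    """
    return seq_len >= min_seq_len and head_dim > min_head_dim


for n, d in [(8192, 512), (256, 512), (8192, 128)]:
    print(f"N={n}, D={d}: prefers FFPA = {prefers_ffpa(n, d)}")
```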

🎉 Benchmark

Runnable benchmarks are provided under bench, along with performance results for large headdims on the NVIDIA L20 (Ada), NVIDIA GeForce RTX 5090 (Blackwell), NVIDIA H800 PCIe (Hopper), and NVIDIA H200 SXM (Hopper, CuTeDSL backend, up to 427 TFLOPS!🎉).


🤖 Backends

FFPA supports multiple backends for the forward and backward passes: SDPA (baseline), CUDA (forward only), Triton, and CuTeDSL. The CuTeDSL backend is still at an early stage and has some constraints, but it can achieve up to 427🎉 TFLOPS on H200! Stay tuned for future updates.

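| Backend | Arch | Fwd | Bwd | Headdim | Autotune | Speedup | Recommend |
|---|---|---|---|---|---|---|---|
| SDPA | sm>=75 | ✔ | ✔ | All | ❌ | 1.0x🤗 | sm>=75 |
| CUDA | sm>=80 | ✔ | ❌ | 320~1024 | ❌ | 1.5x~3x🎉 | sm80~89,120 |
| Triton | sm>=80 | ✔ | ✔ | 320~1024 | ✔ | 1.5x~5x🎉 | sm>=80 |
| CuTeDSL | sm>=80 | ✔ | ✔ | 320~1024 | ❌ | 1.5x~2x🎉 | sm80~89,120 |
| CuTeDSL | sm90 | ✔ | ✔ | 320~512 | ❌ | 3x~6x🎉 | sm90 |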
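
The backend table can also be read as data. The sketch below transcribes its Recommend, Headdim and Bwd columns and filters them for a given GPU; `BackendRow` and `recommended_backends` are hypothetical names for illustration and are not part of ffpa-attn. The order of the result follows the order of the rows in the sketch and is not a performance ranking.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class BackendRow:
    name: str
    archs: tuple      # recommended SM versions as inclusive (low, high) ranges
    head_dims: tuple  # supported head-dim range, inclusive (low, high)
    backward: bool    # whether the backward pass is supported


# Transcribed from the table; 999 and 10**9 stand in for "no upper bound".
ROWS = (
    BackendRow("CuTeDSL", ((90, 90),), (320, 512), True),
    BackendRow("CuTeDSL", ((80, 89), (120, 120)), (320, 1024), True),
    BackendRow("Triton", ((80, 999),), (320, 1024), True),
    BackendRow("CUDA", ((80, 89), (120, 120)), (320, 1024), False),
    BackendRow("SDPA", ((75, 999),), (0, 10**9), True),
)


def recommended_backends(sm: int, head_dim: int, need_backward: bool) -> list:
    """List backend names whose table row matches the given setup."""
    names = []
    for row in ROWS:
        arch_ok = any(lo <= sm <= hi for lo, hi in row.archs)
        dim_ok = row.head_dims[0] <= head_dim <= row.head_dims[1]
        bwd_ok = row.backward or not need_backward
        if arch_ok and dim_ok and bwd_ok and row.name not in names:
            names.append(row.name)
    return names


print(recommended_backends(sm=90, head_dim=512, need_backward=True))
print(recommended_backends(sm=89, head_dim=1024, need_backward=False))
```

With PyTorch installed, the `sm` value can be derived from `torch.cuda.get_device_capability()`, which returns a `(major, minor)` tuple, as `major * 10 + minor`.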

Special thanks to Butterfingrz for contributing to the CuTeDSL backend! Awesome work!🎉

To choose a backend for your own scenario, pass one of the backend configs (SDPABackend, CUDABackend, TritonBackend or CuTeDSLBackend) to ffpa_attn_func, for example:

>>> from ffpa_attn import ffpa_attn_func, CuTeDSLBackend
>>> # CuTeDSL backend, D=512 scenario, fastest on H200!🎉
>>> o = ffpa_attn_func(q, k, v, backend=CuTeDSLBackend())

Persistent Autotune

Generate device-specific tuned configs for production deployment (currently Triton only) to avoid the per-process autotune cost. The generated JSON is saved under the configs directory and is loaded automatically when runtime autotune is disabled (the default). See the Triton Autotune docs for details.

python -m ffpa_attn.autotune --mode max --full-tasks --overwrite # 1 GPU
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 # Multi-GPU (`pip install ray`)
python -m ffpa_attn.autotune --mode max --full-tasks --num-gpus 8 --overwrite

End-to-End (E2E) Training

NVIDIA-NeMo Automodel PR #2436 shows that on Gemma4-31B training (L=8192, 8xH200, FSDP2 + Activation Checkpointing), accelerating the 10/60 D=512 full-attention layers with ffpa-attn delivers about 1.4x-1.5x🎉 higher throughput (E2E) than SDPA at a similar memory footprint, with loss aligned within normal bf16 noise.

©️License

Apache License 2.0

©️Citations

@misc{deftruth2026ffpa,
  author       = {DefTruth and Butterfingrz},
  title        = {FFPA: Efficient Flash Prefill Attention for Large Head Dimensions via Split-D},
  year         = {2026},
  publisher    = {Zenodo},
  version      = {v1.0},
  doi          = {10.5281/zenodo.20638547},
  url          = {https://doi.org/10.5281/zenodo.20638547}
}

📖 References