Flash Bidirectional Linear Attention

March 1, 2026

The aim of this repository is to implement bidirectional linear attention for non-causal modeling using Triton. Contributions and suggestions are welcome!


Update

  • [2026/02] Add support for PISA.
  • [2025/02] Add support for PolaFormer.
  • [2024/12] Add simple_la, a simple form of linear_attn without the norm term.
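For reference, bidirectional (non-causal) linear attention replaces the softmax with a feature map φ and exploits associativity: (φ(Q)φ(K)ᵀ)V = φ(Q)(φ(K)ᵀV), which drops the cost from O(T²d) to O(Td²). Below is a minimal NumPy sketch of both the normalized form and the un-normalized simple_la variant; the function names and the ELU+1 feature map are illustrative assumptions, not this repo's Triton API.

```python
import numpy as np

def feature_map(x):
    # ELU(x) + 1: a common positive feature map for linear attention (illustrative choice)
    return np.where(x > 0, x + 1.0, np.exp(x))

def linear_attn(q, k, v, eps=1e-6):
    # Normalized bidirectional linear attention:
    # out_i = phi(q_i) @ (phi(K)^T V) / (phi(q_i) @ sum_j phi(k_j))
    q, k = feature_map(q), feature_map(k)
    kv = k.T @ v                   # (d, d_v): O(T d d_v), never materializes T x T
    z = q @ k.sum(axis=0)          # (T,): per-query normalizer ("norm term")
    return (q @ kv) / (z[:, None] + eps)

def simple_la(q, k, v):
    # simple_la: the same computation without the norm term
    q, k = feature_map(q), feature_map(k)
    return q @ (k.T @ v)

rng = np.random.default_rng(0)
T, d = 128, 64
q, k, v = rng.normal(size=(3, T, d))

# Associativity check: materializing the full T x T matrix gives the same result
attn = feature_map(q) @ feature_map(k).T   # (T, T), the quadratic form
assert np.allclose(attn @ v, simple_la(q, k, v))
```

Because every query attends to every key, no causal masking is needed, which is what makes the single (Kᵀ V) contraction valid for the whole sequence.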

Models

Models are roughly sorted by the timeline of their support in flash_bla.

| Year | Model | Title | Paper | Code | fla impl |
|------|-------|-------|-------|------|----------|
| 2024 | LinFusion | LinFusion: 1 GPU, 1 Minute, 16K Image | arxiv | official | code |
| 2024 | MLLA | Demystify Mamba in Vision: A Linear Attention Perspective | arxiv | official | code |
| 2025 | PolaFormer | PolaFormer: Polarity-aware Linear Attention for Vision Transformers | arxiv | official | code |
| 2025 | RALA | Breaking the Low-Rank Dilemma of Linear Attention | arxiv | official | code |
| 2026 | PISA | PISA: Piecewise Sparse Attention Is Wiser for Efficient Diffusion Transformers | arxiv | official | code |

Usage

Installation

```shell
git clone https://github.com/fla-org/flash-bidirectional-linear-attention.git
pip install -e flash-bidirectional-linear-attention/.
```

Integrated Models

Several models are already integrated into this library and can be called directly. Take LinFusion as an example:

```python
import torch
from diffusers import AutoPipelineForText2Image
from flash_bla.models import LinFusion

sd_repo = "stabilityai/stable-diffusion-xl-base-1.0"

pipeline = AutoPipelineForText2Image.from_pretrained(
    sd_repo, torch_dtype=torch.float16, variant="fp16"
).to(torch.device("cuda"))

# Patch the pipeline's attention with LinFusion's linear attention modules
linfusion = LinFusion.construct_for(pipeline, pretrained_model_name_or_path="Yuanshi/LinFusion-XL")

image = pipeline(
    "An astronaut floating in space. Beautiful view of the stars and the universe in the background.",
    generator=torch.manual_seed(123)
).images[0]
```

Benchmarks

Profiled on an NVIDIA A800 80GB GPU.

B8-H16-D64 (batch size 8, 16 heads, head dim 64):

| T | torch_la_fwd | flash_bla_fwd | torch_sdpa_fwd | torch_la_bwd | flash_bla_bwd | torch_sdpa_bwd |
|---|---|---|---|---|---|---|
| 1024 | 0.083968 | 0.068608 | 0.073728 | 0.476160 | 0.378880 | 0.405504 |
| 4096 | 0.178176 | 0.083968 | 0.784384 | 1.018880 | 0.444416 | 3.175424 |
| 16384 | 0.549888 | 0.283648 | 11.750400 | 3.556352 | 1.566720 | 44.189184 |
| 32768 | 1.034240 | 0.550912 | 47.788033 | 6.864896 | 3.040256 | 175.127548 |
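The scaling gap in the table matches the asymptotics: torch_sdpa forward grows roughly 16x each time T quadruples (quadratic), while flash_bla grows roughly 4x (linear). A back-of-envelope FLOP sketch, assuming head dim d = 64 from the B8-H16-D64 config and ignoring constant factors:

```python
# Rough per-head FLOP counts for one forward pass (constants ignored):
# softmax attention: QK^T and PV are each O(T^2 d)   -> ~2 * T^2 * d
# linear attention:  K^T V and Q @ (K^T V) are O(T d^2) -> ~2 * T * d^2
d = 64

def sdpa_flops(T):
    return 2 * T * T * d

def linear_flops(T):
    return 2 * T * d * d

for T in (1024, 4096, 16384, 32768):
    # The ratio simplifies to T / d: 16, 64, 256, 512
    print(T, sdpa_flops(T) // linear_flops(T))
```

So at T = 32768 the quadratic form does ~512x the work per head, which is the same order as the ~320x gap between torch_sdpa_bwd and torch_la_bwd in the last row.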

Acknowledgments

Thanks to the following repositories for their inspiration: