FlashLib

June 18, 2026 · View on GitHub

Slack Discord

A GPU library for classical machine-learning operators — kmeans, knn, ivf-flat, pca, svd, dbscan, hdbscan, umap, t-sne, regression, GEMM, and more — built on Triton and CuteDSL.

See the blog post for motivation, design, and benchmarks.

Installation

Install with pip:

pip install flashlib

From source:

git clone https://github.com/FlashML-org/flashlib.git
cd flashlib
pip install -e .

Usage

import torch
from flashlib import flash_kmeans

x = torch.randn(1_000_000, 128, device="cuda", dtype=torch.float32)
labels, centroids, n_iter = flash_kmeans(x, n_clusters=1024, max_iters=20)

Every primitive is exposed as a top-level flash_* function and as a sklearn-style class (KMeans, PCA, HDBSCAN, …).

Index-based primitives like IVF-Flat and IVF-PQ (GPU approximate nearest neighbours) build an index once and query it many times:

import torch
from flashlib import IVFFlat

db = torch.randn(1_000_000, 128, device="cuda")
queries = torch.randn(10_000, 128, device="cuda")

index = IVFFlat(nlist=1024, nprobe=16).fit(db)
distances, indices = index.kneighbors(queries, n_neighbors=10)  # squared L2

nprobe is the recall knob: at a fixed (nlist, nprobe) the probed candidate set — and thus recall — matches a reference IVF-Flat (FAISS / cuVS), so raising it trades speed for recall without changing the kernel.

For billion-scale corpora, IVFPQ adds product-quantization compression: each vector is stored as `m$ 1-\text{byte} \text{codes} (8–32 \times \text{smaller} \text{than} \text{fp32}):

$``python from flashlib import IVFPQ

128-dim fp32 (512 B/vec) -> m=16 PQ codes (16 B/vec) = 32x compression

index = IVFPQ(nlist=1024, m=16, nprobe=16).fit(db) distances, indices = index.kneighbors(queries, n_neighbors=10) # ADC squared L2 print(index.compression_ratio) # 32.0



### Informative API

The `flashlib.info` submodule predicts runtime, FLOPs, and HBM bytes for any
primitive in ~5 µs on pure CPU — useful for budgeting a pipeline before
launching it, and small enough for an LLM agent to call in a GPU-less
environment. It does not import torch, triton, or cutlass.

```python
import flashlib.info as info

est = info.estimate("kmeans",
                    shape=(100_000, 64),
                    params={"K": 256, "max_iters": 20},
                    device="H200")
print(est.summary_line())

See the blog post for the full API, the tolerance-driven dispatch, and per-primitive benchmarks.

Coverage

The current release ships 17 high-level primitives across the following families:

familyprimitives
Clusteringflash_kmeans, flash_dbscan, flash_hdbscan, flash_spectral_clustering
Nearest nbrsflash_knn, flash_ivf_flat (IVF-Flat ANN), flash_ivf_pq (IVF-PQ ANN)
Decompositionflash_pca, flash_truncated_svd
Manifoldflash_umap, flash_tsne
Regressionflash_linear_regression, flash_ridge, flash_logistic_regression
Classificationflash_multinomial_nb, flash_random_forest
Preprocessingflash_standard_scaler

Plus low-level linear-algebra primitives (cov_gemm, gram_gemm, ab_gemm, eigh, polar, msign, cholqr2, split_basis) and a Pareto-frontier set of multi-precision GEMM variants (gemm, gemm_tf32, gemm_3xtf32, gemm_bf16, gemm_fp16, gemm_fp16_x9, gemm_fp16_x3_kahan, gemm_ozaki2_int8, …).

Citation

@misc{yang2026flashlib,
  title  = {FlashLib: Bringing Flash Magic to Classical Machine Learning Operators},
  author = {Yang, Shuo and Xi, Haocheng and Zhao, Yilong and Mang, Qiuyang and
            Wang, Zhe and Sun, Shanlin and Keutzer, Kurt and Gonzalez, Joseph E. and
            Han, Song and Xu, Chenfeng and Stoica, Ion},
  year   = {2026},
  url    = {https://flashml-org.github.io/},
}

License

Apache License 2.0.