KLASS: KL-Guided Fast Inference in Masked Diffusion Models (NeurIPS 2025 Spotlight)

December 12, 2025


🎉 Accepted at NeurIPS 2025 (Spotlight)

Authors: Seo Hyun Kim*, Sunwoo Hong*, Hojung Jung, Youngrok Park, Se-Young Yun
*Equal contribution

KLASS (KL-Adaptive Stability Sampling) is a fast inference method designed to accelerate generation in masked diffusion models while maintaining high-quality outputs.

This repository provides an implementation of KLASS on LLaDA 8B Instruct and Dream 7B Instruct, along with evaluation scripts for standard benchmarks including GSM8K, MATH, HumanEval, and MBPP.

🚀 Installation

  1. Create and activate the conda environment:

    conda create -n klass python=3.12
    conda activate klass
    
  2. Install dependencies and models:

    bash install.sh
    

    This script updates generation_utils.py in Dream with a customized version adapted for KLASS.

📊 Evaluation

We provide ready-to-run evaluation scripts for all supported models and datasets.

LLaDA

# GSM8K
bash scripts/llada_gsm8k.sh

# MATH
bash scripts/llada_math.sh

# HumanEval
bash scripts/llada_humaneval.sh

# MBPP
bash scripts/llada_mbpp.sh

Dream

# GSM8K
bash scripts/dream_gsm8k.sh

# MATH
bash scripts/dream_math.sh

# HumanEval
bash scripts/dream_humaneval.sh

# MBPP
bash scripts/dream_mbpp.sh

โš™๏ธ Configuration & Arguments

You can customize the sampling behavior using the following arguments.

Main Sampling Algorithm

  • alg: Choose the unmasking algorithm.
    • klass: Uses KLASS sampling, which unmasks tokens based on a combination of confidence and KL-divergence stability.
    • default (for LLaDA) / maskgit_plus (for Dream): Top-K confidence-based unmasking.
    • random (for LLaDA) / origin (for Dream): Random unmasking order.
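For comparison, the two baseline orders can be sketched in a few lines of Python. This is a minimal illustration, not the repository's code: `top_k_confidence` and `random_order` are hypothetical helpers, and the sketch assumes a per-position confidence score (for example, the probability of the most likely token) has already been computed for every masked position.

```python
import random
from typing import List, Sequence

def top_k_confidence(confidences: Sequence[float], masked: List[int], k: int) -> List[int]:
    """Baseline: unmask the k masked positions with the highest confidence."""
    return sorted(masked, key=lambda i: confidences[i], reverse=True)[:k]

def random_order(masked: List[int], k: int, rng: random.Random) -> List[int]:
    """Baseline: unmask k masked positions chosen uniformly at random."""
    return rng.sample(masked, min(k, len(masked)))

# Confidence per sequence position; every position is still masked here.
confidences = [0.90, 0.20, 0.75, 0.40, 0.95]
masked = [0, 1, 2, 3, 4]
print(top_k_confidence(confidences, masked, 2))  # [4, 0]
print(random_order(masked, 2, random.Random(0)))  # two distinct random positions
```

In this sketch both baselines unmask a fixed k tokens per step. With unmask_strategy=all, the klass option instead lets the number of tokens unmasked per step vary with how many positions pass its filters.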

KLASS-Specific Parameters

These arguments are used only when alg="klass".

  • conf_threshold: Filter out tokens with confidence lower than this value.
  • kl_threshold: Filter out tokens whose KL score, computed over the most recent history_length steps, is higher than this value.
  • history_length: Number of recent steps to use for the KL divergence stability calculation.
  • unmask_strategy: Defines the strategy for unmasking the tokens that satisfy both the confidence and KL thresholds:
    • all: Unmask all tokens that satisfy the thresholds. (Default)
    • max_conf: Among the tokens satisfying the thresholds, unmask only the one with the maximum confidence.
    • min_kl: Among the tokens satisfying the thresholds, unmask only the one with the minimum KL score.
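The filtering and selection rules above can be sketched as follows. This is a toy illustration under stated assumptions, not the repository's implementation: confidence is taken to be the probability of the top token in the latest prediction, and the KL score is assumed to be the largest KL(p_t || p_{t-1}) between consecutive predictions over the last history_length steps. The names `kl_score` and `klass_select` are hypothetical, and distributions are plain Python lists over a tiny vocabulary.

```python
import math
from typing import Dict, List, Sequence

def kl_divergence(p: Sequence[float], q: Sequence[float], eps: float = 1e-12) -> float:
    """KL(p || q) for two categorical distributions over the same vocabulary."""
    return sum(pi * math.log((pi + eps) / (qi + eps)) for pi, qi in zip(p, q) if pi > 0)

def kl_score(history: List[Sequence[float]], history_length: int) -> float:
    """ASSUMPTION: the largest KL(p_t || p_{t-1}) over the last `history_length`
    consecutive steps. Returns inf when the history is still too short."""
    if history_length < 1:
        raise ValueError("history_length must be at least 1")
    recent = history[-(history_length + 1):]
    if len(recent) < history_length + 1:
        return math.inf
    return max(kl_divergence(recent[i + 1], recent[i]) for i in range(history_length))

def klass_select(
    histories: Dict[int, List[Sequence[float]]],
    conf_threshold: float,
    kl_threshold: float,
    history_length: int,
    unmask_strategy: str = "all",
) -> List[int]:
    """Return the masked positions to unmask at this step (toy sketch)."""
    candidates = []  # (position, confidence, kl score) for positions passing both filters
    for pos, hist in histories.items():
        conf = max(hist[-1])  # ASSUMPTION: confidence = current top-token probability
        score = kl_score(hist, history_length)
        if conf >= conf_threshold and score <= kl_threshold:
            candidates.append((pos, conf, score))
    if not candidates:
        return []  # no fallback in this sketch; a real decoder must still make progress
    if unmask_strategy == "all":
        return sorted(pos for pos, _, _ in candidates)
    if unmask_strategy == "max_conf":
        return [max(candidates, key=lambda c: c[1])[0]]
    if unmask_strategy == "min_kl":
        return [min(candidates, key=lambda c: c[2])[0]]
    raise ValueError(f"unknown unmask_strategy: {unmask_strategy}")

# Toy predictive distributions (vocabulary of size 2) for four masked positions,
# oldest step first.
histories = {
    0: [[0.88, 0.12], [0.90, 0.10]],  # confident, nearly stable
    1: [[0.60, 0.40], [0.95, 0.05]],  # confident but still shifting
    2: [[0.40, 0.60], [0.40, 0.60]],  # stable but not confident
    3: [[0.85, 0.15], [0.85, 0.15]],  # confident and fully stable
}
for strategy in ("all", "max_conf", "min_kl"):
    print(strategy, klass_select(histories, 0.8, 0.01, 1, strategy))
# all [0, 3]
# max_conf [0]
# min_kl [3]
```

With history_length=1, positions 0 and 3 pass both filters: position 1 is confident but its distribution is still moving (KL of about 0.33), and position 2 is stable but below the confidence threshold. A larger history_length makes the stability check stricter; in this sketch a position with too little history is never selected. When nothing passes, the sketch returns an empty list, so a real decoder would need some fallback (for instance, unmasking the single most confident position); how the repository handles that case is not covered here.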

Debugging

  • save_steps: If set, this flag saves the detailed results of each generation step (including position, token ID, confidence, and KL divergence for all tokens) for analysis.
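If you want to analyze per-step results like these yourself, a record layout such as the following is one way to hold them. The fields mirror the quantities listed above plus a step index, but the `TokenRecord` class, the JSON Lines format, and the helper names are hypothetical; the repository's actual save format may differ.

```python
import json
from dataclasses import asdict, dataclass
from typing import List

@dataclass
class TokenRecord:
    """HYPOTHETICAL per-token, per-step record; the real save format may differ."""
    step: int
    position: int
    token_id: int
    confidence: float
    kl: float

def to_jsonl(records: List[TokenRecord]) -> str:
    """Serialize records as JSON Lines (one JSON object per line)."""
    return "\n".join(json.dumps(asdict(r)) for r in records)

def from_jsonl(text: str) -> List[TokenRecord]:
    """Parse JSON Lines text back into records, skipping blank lines."""
    return [TokenRecord(**json.loads(line)) for line in text.splitlines() if line.strip()]

def trajectory(records: List[TokenRecord], position: int) -> List[TokenRecord]:
    """All records for one position, ordered by generation step."""
    return sorted((r for r in records if r.position == position), key=lambda r: r.step)

records = [
    TokenRecord(step=1, position=3, token_id=42, confidence=0.70, kl=0.05),
    TokenRecord(step=0, position=3, token_id=17, confidence=0.40, kl=0.30),
    TokenRecord(step=0, position=5, token_id=9, confidence=0.90, kl=0.01),
]
restored = from_jsonl(to_jsonl(records))
for r in trajectory(restored, 3):
    print(r.step, r.token_id, r.confidence, r.kl)
# 0 17 0.4 0.3
# 1 42 0.7 0.05
```

Sorting one position's records by step shows how its confidence and KL score evolved before it was unmasked, which is exactly what conf_threshold and kl_threshold act on.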

๐Ÿ™ Acknowledgements

This codebase builds upon the official implementations of LLaDA, Dream, and HumanEval. We thank the original authors for their open-source contributions.