Infinite Mask Diffusion for Few-Step Distillation
May 15, 2026
Jaehoon Yoo*, Wonjung Kim*, Chanhyuk Lee, Seunghoon Hong
KAIST
TL;DR
We propose the Infinite Mask Diffusion Model (IMDM), which retains the simple design and effective conditional generation of Masked Diffusion Models (MDMs) while overcoming the theoretical lower bound on their factorization error.
Overview
Masked Diffusion Models (MDMs) have emerged as a promising alternative to autoregressive models in language modeling, offering the advantages of parallel decoding and bidirectional context processing within a simple yet effective framework. Specifically, their explicit distinction between masked tokens and data underlies their simple framework and effective conditional generation. However, MDMs typically require many sampling iterations due to factorization errors stemming from simultaneous token updates. We observe that a theoretical lower bound of the factorization error exists, which standard MDMs cannot reduce due to their use of a deterministic single-state mask. In this paper, we propose the Infinite Mask Diffusion Model (IMDM), which introduces a stochastic infinite-state mask to mitigate the theoretical bound while directly inheriting the benefits of MDMs, including the compatibility with pre-trained weights. We empirically demonstrate that MDM fails to perform few-step generation even in a simple synthetic task due to the factorization error bound, whereas IMDM can find an efficient solution for the same task. Finally, when equipped with appropriate distillation methods, IMDM surpasses existing few-step distillation methods at small step counts on LM1B and OpenWebText.
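As a toy illustration of the factorization error described above (not an experiment from the paper), consider a hypothetical two-token dataset in which only the sequences "AA" and "BB" occur. The sketch assumes a one-step sampler that starts from a fully masked sequence and samples every position independently. Under that assumption, the independent predictions that minimize KL divergence to the data are the per-position marginals, so the remaining KL gap cannot be reduced by better training. The vocabulary, distribution, and function names are all made up for this example.

```python
import itertools
import math

# Hypothetical data distribution over two-token sequences:
# the two positions are perfectly correlated.
vocab = ["A", "B"]
joint = {("A", "A"): 0.5, ("B", "B"): 0.5}


def marginal(dist, pos):
    """Marginal distribution of the token at position `pos`."""
    out = {v: 0.0 for v in vocab}
    for seq, p in dist.items():
        out[seq[pos]] += p
    return out


def one_step_factorized(dist):
    """Distribution produced when both positions are sampled independently
    from their marginals (the best a position-wise independent sampler can do)."""
    m0, m1 = marginal(dist, 0), marginal(dist, 1)
    return {(a, b): m0[a] * m1[b] for a, b in itertools.product(vocab, vocab)}


def kl(p, q):
    """KL(p || q) in nats, summed over the support of p."""
    return sum(pv * math.log(pv / q[s]) for s, pv in p.items() if pv > 0)


factorized = one_step_factorized(joint)
invalid_mass = sum(p for s, p in factorized.items() if s not in joint)
error = kl(joint, factorized)

print(f"mass on sequences never seen in data: {invalid_mass:.2f}")  # 0.50
print(f"KL(data || one-step sampler): {error:.4f} nats")            # 0.6931 = log 2
```

In this toy case half of the one-step samples are "AB" or "BA", which never occur in the data, and the KL gap equals log 2. How IMDM's stochastic infinite-state mask addresses this bound is described in the paper and is not modeled here.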
Project Structure
├── config/                     <- Config files for datasets/denoising networks/noise schedules/LR schedules.
|   ├── config.yaml             <- Main config file
|
├── integral/
|
├── models/                     <- Denoising network architectures. Supports [DiT](https://arxiv.org/abs/2212.09748) and AR transformer.
|   ├── dit.py                  <- DiT structure
|   ├── ema.py                  <- EMA model
|   ├── unit_test_attention.py  <- Attention module
|
├── scripts/                    <- Shell scripts for training/evaluation.
|   ├── lm1b                    <- Shell scripts for the LM1B dataset
|   ├── owt                     <- Shell scripts for the OpenWebText dataset
|   ├── owt_large               <- Shell scripts for the OpenWebText dataset with large (860M) models
|
├── algo.py                     <- Main model structures: algorithms such as DUO, MDLM, AR, SEDD, D3PM, ReDi, IMDM, IM-SDTT, IM-ReDi.
├── dataloader.py               <- Dataloader and tokenizer module
├── eval_mauve.py               <- Evaluates the MAUVE score from conditional generation samples
├── LICENSE                     <- Apache License 2.0
├── main.py                     <- Main entry point
├── metrics.py                  <- Metrics module
├── README.md
├── requirements.txt            <- Python dependencies for environment setup
├── trainer_base.py             <- Boilerplate trainer using PyTorch Lightning.
├── utils.py                    <- LR scheduler, logging, `fsspec` handling.
Usage
To get started, follow these steps:
- Install requirements

      pip install torch==2.8.0 torchvision==0.23.0 torchaudio==2.8.0
      pip install -r requirements.txt
      pip install flash-attn==2.7.4.post1 --no-build-isolation

- Download pretrained models

      # 1. OWT fine-tuned models (IM-SDTT-ReDi1, IM-SDTT-ReDi1 large)
      #    Download from Hugging Face (https://huggingface.co/Ugness/IMDM)
      # 2. Pretrained models from the MDLM paper (OWT)
      #    Download the official MDLM checkpoint from the Google Drive folder
      #    (https://drive.google.com/drive/folders/16LuuptK7Xfk-vzhQYZBZ0SA-B-BFluau?usp=sharing)
      # Put the checkpoints into ./outputs/lm1b, ./outputs/owt, ./outputs/owt_large.

- Download the LM1B and OpenWebText datasets

      # The training code automatically downloads the LM1B and OWT datasets to the local cache dir (./cache/).
      # Alternatively, place the LM1B and OpenWebText datasets in the cache dir (./cache/) yourself.

- Use IMDM

      # LM1B
      ## Train
      bash scripts/lm1b/train_lm1b_mdlm.sh
      bash scripts/lm1b/train_lm1b_imdm.sh
      bash scripts/lm1b/train_lm1b_imsdtt_uniform.sh        # requires MDLM ckpt
      bash scripts/lm1b/train_lm1b_imredi1.sh               # requires MDLM ckpt, Reflow dataset
      bash scripts/lm1b/train_lm1b_imsdtt_uniform_redi1.sh  # requires IM-SDTT ckpt, Reflow dataset
      ### Create Rectified Coupling (perturbed ReDi)
      bash scripts/lm1b/datagen_lm1b_mdlm.sh                # requires MDLM ckpt
      bash scripts/lm1b/datagen_lm1b_imsdtt_uniform.sh      # requires IM-SDTT ckpt
      ## Eval
      ### LM1B PPL eval
      bash scripts/lm1b/eval_lm1b_mdlm.sh
      bash scripts/lm1b/eval_lm1b_imdm.sh
      ### Uncond. generation eval
      bash scripts/lm1b/gen_ppl_lm1b_mdlm.sh
      bash scripts/lm1b/gen_ppl_lm1b_imsdtt_uniform.sh
      bash scripts/lm1b/gen_ppl_lm1b_imredi1.sh
      bash scripts/lm1b/gen_ppl_lm1b_imsdtt_uniform_redi1.sh

      # OpenWebText
      ## Train
      bash scripts/owt/train_owt_imsdtt.sh                  # requires MDLM ckpt
      bash scripts/owt/train_owt_imsdtt_redi1.sh            # requires IM-SDTT ckpt, Reflow dataset
      ### Create Rectified Coupling (perturbed ReDi)
      bash scripts/owt/datagen_owt_imsdtt.sh                # requires IM-SDTT ckpt
      ## Eval
      ### Zero-Shot PPL eval
      bash scripts/owt/zero_shot_mdlm.sh
      bash scripts/owt/zero_shot_imdm.sh
      ### Uncond. generation eval
      bash scripts/owt/gen_ppl_owt_mdlm.sh
      bash scripts/owt/gen_ppl_owt_imsdtt_redi1.sh
      ### Cond. generation eval
      bash scripts/owt/cond_gen_ppl_owt_mdlm.sh
      bash scripts/owt/cond_gen_ppl_owt_imsdtt_redi1.sh
      ### Eval MAUVE
      python eval_mauve.py --generation_path /path/to/your/cond/samples/json/file

      # OpenWebText (large model)
      ## Train
      bash scripts/owt_large/train_owt_imsdtt.sh            # requires MDLM ckpt
      bash scripts/owt_large/train_owt_imsdtt_redi1.sh      # requires IM-SDTT ckpt, Reflow dataset
      ### Create Rectified Coupling (perturbed ReDi)
      bash scripts/owt_large/datagen_owt_imsdtt.sh          # requires IM-SDTT ckpt
      ## Eval
      ### Uncond. generation eval
      bash scripts/owt_large/gen_ppl_owt_imsdtt_redi1.sh
      ### Cond. generation eval
      bash scripts/owt_large/cond_gen_ppl_owt_imsdtt_redi1.sh
      ### Eval MAUVE
      python eval_mauve.py --generation_path /path/to/your/cond/samples/json/file
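The prerequisite notes next to the LM1B training and data-generation scripts imply an execution order. The sketch below derives one valid order with Python's standard `graphlib` module. It is only an aid for planning runs. The mapping from each prerequisite to the script that produces it is an assumption inferred from the script names (for example, that `datagen_lm1b_mdlm.sh` produces the Reflow dataset used by `train_lm1b_imredi1.sh`), and it is not confirmed by the repository.

```python
from graphlib import TopologicalSorter

# Script paths as listed in the usage section.
TRAIN_MDLM = "scripts/lm1b/train_lm1b_mdlm.sh"
TRAIN_IMSDTT = "scripts/lm1b/train_lm1b_imsdtt_uniform.sh"
DATAGEN_MDLM = "scripts/lm1b/datagen_lm1b_mdlm.sh"
DATAGEN_IMSDTT = "scripts/lm1b/datagen_lm1b_imsdtt_uniform.sh"
TRAIN_IMREDI1 = "scripts/lm1b/train_lm1b_imredi1.sh"
TRAIN_IMSDTT_REDI1 = "scripts/lm1b/train_lm1b_imsdtt_uniform_redi1.sh"

# Each script maps to the scripts that must finish first.
# ASSUMPTION: which script yields which checkpoint or Reflow dataset is
# inferred from the file names, not stated in the README.
prerequisites = {
    TRAIN_MDLM: set(),
    TRAIN_IMSDTT: {TRAIN_MDLM},                          # needs MDLM ckpt
    DATAGEN_MDLM: {TRAIN_MDLM},                          # needs MDLM ckpt
    TRAIN_IMREDI1: {TRAIN_MDLM, DATAGEN_MDLM},           # MDLM ckpt + Reflow dataset
    DATAGEN_IMSDTT: {TRAIN_IMSDTT},                      # needs IM-SDTT ckpt
    TRAIN_IMSDTT_REDI1: {TRAIN_IMSDTT, DATAGEN_IMSDTT},  # IM-SDTT ckpt + Reflow dataset
}

# static_order() yields every script after all of its prerequisites.
order = list(TopologicalSorter(prerequisites).static_order())

for script in order:
    print(f"bash {script}")
```

Several orders can satisfy the same constraints, so the printed sequence is one valid choice rather than the only one. If you start from a downloaded checkpoint instead of training MDLM yourself, drop `TRAIN_MDLM` from the mapping.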
Acknowledgments
This repository is built upon the codebases of Duo and ReDi.
BibTeX
@inproceedings{yoo2026imdm,
title={Infinite Mask Diffusion for Few-Step Distillation},
author={Yoo, Jaehoon and Kim, Wonjung and Lee, Chanhyuk and Hong, Seunghoon},
year={2026},
booktitle={ICML}
}