CORE: Context-Robust Remasking for Diffusion Language Models

June 8, 2026

Kevin Zhai, Sabbir Mollah, Zhenyi Wang, Mubarak Shah
University of Central Florida

arXiv: https://arxiv.org/abs/2602.04096

CORE is a training-free, inference-time revision method for Masked Diffusion Models. Standard decoders freeze a token once it is unmasked, even when later context exposes it as wrong. Rather than trusting the stale confidence a token had when it was unmasked, CORE identifies context-brittle tokens by stress-testing them: it masks a small candidate set, measures each token's instability (its drop in likelihood) under that perturbed context, and remasks the most unstable tokens for resampling. The method plugs into the LLaDA sampler and adds only a handful of extra forward passes.
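
The stress test described above can be sketched in a few lines. This is a toy illustration, not the code in generate.py: the names `select_remask`, `toy_score`, and `MASK` are hypothetical, the candidate set is assumed to be the committed tokens with the lowest stored log-probability, and instability is assumed to be the stored log-probability minus the log-probability under the perturbed context.

```python
import math
from typing import Callable, Dict, List, Sequence

MASK = "<mask>"  # placeholder mask symbol for this toy example


def select_remask(
    tokens: Sequence[str],
    committed_logp: Dict[int, float],
    score_fn: Callable[[List[str], int, str], float],
    m: int = 32,
    k_rm: int = 1,
) -> List[int]:
    """Return up to k_rm positions whose likelihood drops most under masking."""
    # Assumption: candidates are the m committed tokens with the lowest stored log-prob.
    candidates = sorted(committed_logp, key=committed_logp.get)[:m]
    cand_set = set(candidates)
    # Perturbed context: every candidate position is masked at once.
    perturbed = [MASK if i in cand_set else t for i, t in enumerate(tokens)]
    # Assumption: instability = stored log-prob minus log-prob under the perturbed context.
    instability = {
        i: committed_logp[i] - score_fn(perturbed, i, tokens[i]) for i in candidates
    }
    ranked = sorted(candidates, key=lambda i: instability[i], reverse=True)
    # Only remask tokens whose likelihood actually dropped.
    return [i for i in ranked[:k_rm] if instability[i] > 0]


def toy_score(context: List[str], pos: int, token: str) -> float:
    """Hypothetical stand-in for a model forward pass; a real scorer would read `context`."""
    table = {(4, "5"): math.log(0.05), (2, "2"): math.log(0.9)}
    return table.get((pos, token), math.log(0.8))


tokens = ["2", "+", "2", "=", "5"]
committed = {2: math.log(0.9), 4: math.log(0.6)}  # stored log-probs of generated tokens
to_remask = select_remask(tokens, committed, toy_score, m=2, k_rm=1)
```

In this toy run the token at position 4 was committed with probability 0.6 but scores only 0.05 once the candidates are masked, so it is the one selected for resampling.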

Setup

Requires Python 3.12 and a CUDA GPU (experiments use a single A100-80GB).

# 1) Install torch from the appropriate CUDA index (see requirements.txt for notes)
pip install torch==2.11.0 --index-url https://download.pytorch.org/whl/cu128

# 2) Install the rest of the pinned dependencies
pip install -r requirements.txt

Dependency compatibility (important):

  • transformers must be 4.46.x. LLaDA's trust_remote_code modeling file predates the transformers 5.x loader refactor and fails to load on 5.x.
  • lm-eval 0.4.x exposes LM.device as a read-only property; eval.py is written to work with this.
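
A preflight check along the following lines can catch a mismatched transformers install before the model download starts. It is a minimal sketch that uses only the standard library; the helper names `major_minor`, `transformers_ok`, and `check_installed` are hypothetical and not part of this repository.

```python
from importlib import metadata


def major_minor(version: str) -> tuple[int, int]:
    """Parse a version string such as '4.46.3' into (4, 46)."""
    parts = version.split(".")
    return int(parts[0]), int(parts[1])


def transformers_ok(version: str) -> bool:
    """True only for the 4.46.x series required above."""
    return major_minor(version) == (4, 46)


def check_installed() -> None:
    """Exit with a readable message if transformers is missing or not 4.46.x."""
    try:
        installed = metadata.version("transformers")
    except metadata.PackageNotFoundError:
        raise SystemExit("transformers is not installed; run: pip install -r requirements.txt")
    if not transformers_ok(installed):
        raise SystemExit(f"transformers {installed} found; this repo needs 4.46.x")
```

Calling `check_installed()` at the top of a launcher script turns a remote-code loading failure into an early, explicit error.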

The base model GSAI-ML/LLaDA-8B-Base (~15 GB) is fetched from the Hugging Face Hub on first run. Set HF_HOME to control the cache location; once the download has completed, set HF_HUB_OFFLINE=1 to run from the cache only:

export HF_HOME=/path/to/hf_cache
hf download GSAI-ML/LLaDA-8B-Base   # optional explicit prefetch

Quick Start

Reproduce the main-results table (runs low_confidence and core across GSM8K, HumanEval, Minerva-MATH, BBH, MBPP with the paper's few-shot counts):

bash run_all.sh

Running a Single Task

run.sh takes the task name as its argument and reads settings from environment variables:

METHOD=core STEPS=128 NUM_FEWSHOT=3 SEED=1234 bash run.sh mbpp

| Variable | Default | Description |
| --- | --- | --- |
| METHOD | low_confidence | Unmasking / remasking strategy (see below) |
| STEPS | 128 | Number of diffusion steps |
| NUM_FEWSHOT | 0 | Few-shot examples |
| SEED | 1234 | Random seed |

Under the hood this calls eval.py (an lm-evaluation-harness plugin registered as llada_dist) with gen_length=512, block_length=512. To pass other harness flags (e.g. --limit for a quick smoke test), call eval.py directly:

python eval.py --llada_seed 1234 --tasks gsm8k --num_fewshot 4 --limit 2 \
  --confirm_run_unsafe_code --model llada_dist --log_samples \
  --output_path ./logs/smoke_gsm8k_core \
  --model_args model_path=GSAI-ML/LLaDA-8B-Base,gen_length=256,steps=64,block_length=256,remasking=core

Remasking strategies (remasking= / METHOD)

| Value | Description |
| --- | --- |
| low_confidence | Standard LLaDA unmasking baseline |
| topk_margin | Top-2 probability margin unmasking |
| random | Random unmasking |
| core | CORE: instability-based context-robust remasking (ours) |
| margin_remask | Compute-matched control: remask by smallest margin |
| random_remask | Compute-matched control: remask at random |

CORE knobs (environment variables, read in generate.py)

| Variable | Default | Description |
| --- | --- | --- |
| REVISE_EVERY | 8 | Run a revision pass every E steps (0 disables) |
| CANDIDATE_M | 32 | Candidate set size m stress-tested per revision pass |
| BASE_MASKING | confidence | Base unmasking score: confidence or margin |
| JOINT_REEVAL | 0 | If 1, add a forward pass on the corrected context (ablation; +1 NFE) |
| MECH_SAVE_DIR | (unset) | If set, dump per-token instability/mechanism stats here |
| CORE_DEBUG | 0 | If 1, print verbose per-step [budget]/[remask] traces |
| TEMPERATURE | 0.0 | Sampling temperature (>0 enables stochastic decoding + per-example seeding) |
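
For reference, here is one way these variables and their defaults could be parsed into a typed config. This is an illustrative sketch, not the parsing code in generate.py; the `CoreKnobs` dataclass and `load_knobs` function are hypothetical names, and the validation of BASE_MASKING is an added assumption.

```python
import os
from dataclasses import dataclass
from typing import Mapping, Optional


@dataclass(frozen=True)
class CoreKnobs:
    """Defaults mirror the table above."""
    revise_every: int = 8
    candidate_m: int = 32
    base_masking: str = "confidence"
    joint_reeval: bool = False
    mech_save_dir: Optional[str] = None
    core_debug: bool = False
    temperature: float = 0.0


def load_knobs(env: Mapping[str, str] = os.environ) -> CoreKnobs:
    """Read the CORE knobs from an environment-like mapping."""
    base = env.get("BASE_MASKING", "confidence")
    if base not in ("confidence", "margin"):
        raise ValueError(f"BASE_MASKING must be 'confidence' or 'margin', got {base!r}")
    return CoreKnobs(
        revise_every=int(env.get("REVISE_EVERY", "8")),
        candidate_m=int(env.get("CANDIDATE_M", "32")),
        base_masking=base,
        joint_reeval=env.get("JOINT_REEVAL", "0") == "1",
        mech_save_dir=env.get("MECH_SAVE_DIR") or None,  # empty string counts as unset
        core_debug=env.get("CORE_DEBUG", "0") == "1",
        temperature=float(env.get("TEMPERATURE", "0.0")),
    )
```

Taking the mapping as a parameter keeps the function easy to test: `load_knobs({})` yields the defaults without touching the real environment.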

Revision passes run only within the intermediate step window [0.25, 0.75), and each pass remasks at most k_rm = 1 token (matching the paper).
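
To get a feel for how often revision fires, the sketch below lists the steps that would trigger a pass. It rests on assumptions not stated above: steps are 0-indexed, the window bounds are fractions of the total step count, and a pass runs on steps divisible by REVISE_EVERY. The function `revision_steps` is hypothetical, and the actual schedule in generate.py may differ.

```python
def revision_steps(steps: int, revise_every: int,
                   lo: float = 0.25, hi: float = 0.75) -> list[int]:
    """List the 0-indexed steps on which a revision pass would run."""
    if revise_every <= 0:  # REVISE_EVERY=0 disables revision
        return []
    return [
        t for t in range(steps)
        if lo <= t / steps < hi and t % revise_every == 0
    ]


DEFAULT_SCHEDULE = revision_steps(steps=128, revise_every=8)
print(DEFAULT_SCHEDULE)
```

Under these assumptions, the defaults (STEPS=128, REVISE_EVERY=8) give eight revision passes, at steps 32 through 88.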

Citation

If you use this code or find the paper useful, please cite:

@article{zhai2026corecontextrobustremaskingdiffusion,
      title={CORE: Context-Robust Remasking for Diffusion Language Models}, 
      author={Kevin Zhai and Sabbir Mollah and Zhenyi Wang and Mubarak Shah},
      year={2026},
      eprint={2602.04096},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/2602.04096}, 
}

Acknowledgements

We thank colleagues and collaborators for discussions and feedback that improved this work. We also acknowledge the authors and maintainers of the open-source libraries, pretrained models (e.g., LLaDA), and evaluation suites (e.g., lm-eval) used in this repository, along with the creators of the benchmarks and datasets used in our experiments. Finally, we appreciate the compute infrastructure and operational support that enabled the runs reported here.