Adding Custom Unmasking Methods

March 20, 2026

This guide explains how to add a new unmasking method to ParallelBench.

How Unmasking Works

At each denoising step, the model predicts tokens for all masked positions. The unmasking method decides which predictions to accept and which positions to leave masked for later refinement. Each method computes a per-token confidence score, and tokens with higher scores are unmasked first.
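A minimal sketch of one top-k denoising step makes the selection concrete. It assumes the current sequence x marks undecided positions with a mask_id token, and that x0 (predictions) and confidence (scores) come from the model and a scorer. The function name and argument layout are illustrative and are not ParallelBench's internal API.

```python
import torch


def topk_unmask_step(x: torch.Tensor, x0: torch.Tensor, confidence: torch.Tensor,
                     mask_id: int, k: int) -> torch.Tensor:
    """Commit the k most confident predictions per row; keep the rest masked.

    x:          current tokens, mask_id at undecided positions   (batch, seq_len)
    x0:         predicted token ids for every position           (batch, seq_len)
    confidence: per-token scores from a confidence scorer        (batch, seq_len)
    """
    masked = x == mask_id
    # Already-decoded positions get -inf so they are never picked again.
    conf = confidence.masked_fill(~masked, float("-inf"))
    idx = conf.topk(min(k, x.shape[1]), dim=1).indices
    select = torch.zeros_like(masked)
    rows = torch.arange(x.shape[0], device=x.device).unsqueeze(1)
    select[rows, idx] = True
    # Rows with fewer than k masked tokens may have picked -inf slots; drop them.
    select &= masked
    return torch.where(select, x0, x)
```

The final "& masked" matters near the end of decoding, when a row has fewer than k masked positions left and topk must still return k indices.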

There are four method types:

| Type | Behavior | CLI parameter | TPS | PBx scoring |
|---|---|---|---|---|
| `topk` | Fixed k tokens unmasked per step | `k` | Deterministic (= k) | Discrete |
| `threshold` | Unmask tokens above a confidence threshold | `alg_threshold` | Measured | Interpolated |
| `factor` | Scale unmask count by a factor | `alg_factor` | Measured | Interpolated |
| `adaptive` | Dynamic per-step unmasking (e.g., KLASS) | Method-specific | Measured | Interpolated |
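Threshold-style methods differ from top-k only in how positions are selected, so the number of tokens unmasked per step depends on the model's confidence. Here is a minimal sketch of a threshold rule, assuming confidence and masked are (batch, seq_len) tensors. The fallback that always keeps each row's most confident masked token is an assumption, added so decoding cannot stall when nothing clears the threshold. ParallelBench's own rule may differ.

```python
import torch


def threshold_select(confidence: torch.Tensor, masked: torch.Tensor,
                     alg_threshold: float) -> torch.Tensor:
    """Bool mask (batch, seq_len) of positions to unmask this step (sketch)."""
    # Decoded positions get -inf, so they can never clear the threshold.
    conf = confidence.masked_fill(~masked, float("-inf"))
    select = conf > alg_threshold
    # Assumed fallback: always include each row's most confident masked token.
    # This only changes the result when nothing in the row clears the threshold.
    rows = torch.arange(conf.shape[0], device=conf.device)
    best = conf.argmax(dim=1)
    select[rows, best] |= masked[rows, best]
    return select
```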

What You Need to Change

| File | What to do |
|---|---|
| `parallelbench/models/confidence_scorers.py` | Implement the confidence score function |
| `parallelbench/models/unmasking_registry.py` | Register the method with its confidence scorer |
| `parallelbench/models/local/<model>/constants.py` | Add the method to each model's valid set |

1. Implement the Confidence Score

Add a function to parallelbench/models/confidence_scorers.py. Every scorer has the same signature:

```python
import torch


def my_scorer(p: torch.Tensor, x0: torch.Tensor, x0_p: torch.Tensor) -> torch.Tensor:
    """
    Args:
        p: Token probability distribution (batch, seq_len, vocab_size).
        x0: Predicted token ids (batch, seq_len).
        x0_p: Pre-computed max/sampled probability (batch, seq_len).

    Returns:
        Per-token confidence tensor (batch, seq_len).
    """
    return ...  # replace with your confidence computation
```
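All three arguments come from the same logits. The sketch below shows one way to build them. It assumes greedy decoding, where x0 is the argmax token and x0_p is its probability. If tokens are sampled instead, x0_p would be the probability of the sampled token. The helper name scorer_inputs is hypothetical.

```python
import torch


def scorer_inputs(logits: torch.Tensor):
    """Build (p, x0, x0_p) from logits of shape (batch, seq_len, vocab_size)."""
    p = torch.softmax(logits.float(), dim=-1)                 # token distribution
    x0 = p.argmax(dim=-1)                                     # greedy token ids
    x0_p = torch.gather(p, -1, x0.unsqueeze(-1)).squeeze(-1)  # probability of each x0
    return p, x0, x0_p
```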

Common confidence patterns

| Pattern | Computation | Intuition |
|---|---|---|
| Max probability | `return x0_p` | How certain the top prediction is |
| Margin | `top1 - top2` | Gap between best and second-best |
| Negative entropy | `sum(p * log(p))` | More concentrated = more confident |
| Random | `torch.rand(...)` | Uniform baseline |
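Each of these patterns maps directly onto a scorer with the signature above. The functions below are illustrative sketches, and the _sketch suffix is there to keep them distinct from the library's versions. The real max_probability, margin, negative_entropy, and random_confidence in confidence_scorers.py may differ in detail, for example in their numerical safeguards.

```python
import torch


def max_probability_sketch(p, x0, x0_p):
    # Confidence is simply the probability of the predicted token.
    return x0_p


def margin_sketch(p, x0, x0_p):
    # Gap between the two most likely tokens at each position.
    top2 = p.topk(2, dim=-1).values
    return top2[..., 0] - top2[..., 1]


def negative_entropy_sketch(p, x0, x0_p):
    # sum(p * log p); torch.xlogy defines 0 * log(0) as 0, avoiding NaN.
    return torch.xlogy(p, p).sum(dim=-1)


def random_sketch(p, x0, x0_p):
    # Uniform random scores of the same shape as x0_p.
    return torch.rand_like(x0_p)
```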

Existing scorers

| Function | Used by |
|---|---|
| `max_probability` | `confidence_topk`, `confidence_threshold`, `confidence_factor` |
| `margin` | `topk_margin` |
| `negative_entropy` | `entropy_topk` |
| `random_confidence` | `random` |

2. Register the Method

Add your method to UNMASKING_REGISTRY in parallelbench/models/unmasking_registry.py:

```python
from parallelbench.models.confidence_scorers import my_scorer

UNMASKING_REGISTRY: dict[str, MethodInfo] = {
    # ... existing entries ...
    "my_method": MethodInfo("topk", "k", derive_topk, my_scorer, ("k",)),
}
```

The five arguments are:

| Argument | Description |
|---|---|
| `method_type` | `"topk"`, `"threshold"`, `"factor"`, or `"adaptive"` |
| `representative_param` | Primary CLI parameter used for deriving steps/block_length |
| `derive_fn` | Function that derives steps/block_length from the representative param |
| `confidence_fn` | Confidence scorer function (or `None`) |
| `config_params` | Tuple of gen_kwargs keys used to distinguish configs in PBx scoring |
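MethodInfo takes its arguments positionally, which makes a wrong argument order or a malformed config_params easy to miss. The sketch below copies the five documented fields into a hypothetical MethodInfoSketch NamedTuple (it is not the real class) and lists any problems with an entry before you register it. It also catches a common pitfall: ("k") is just the string "k", whereas ("k",) is a one-element tuple.

```python
from typing import Callable, NamedTuple, Optional

# The four documented method types.
METHOD_TYPES = {"topk", "threshold", "factor", "adaptive"}


class MethodInfoSketch(NamedTuple):
    """Hypothetical mirror of MethodInfo's five positional fields."""
    method_type: str
    representative_param: str
    derive_fn: Callable
    confidence_fn: Optional[Callable]
    config_params: tuple


def entry_problems(info: MethodInfoSketch) -> list:
    """Return human-readable problems with a registry entry (empty if none)."""
    problems = []
    if info.method_type not in METHOD_TYPES:
        problems.append(f"unknown method_type {info.method_type!r}")
    if not callable(info.derive_fn):
        problems.append("derive_fn is not callable")
    if info.confidence_fn is not None and not callable(info.confidence_fn):
        problems.append("confidence_fn must be callable or None")
    # ("k") is a str, not a tuple; a one-element tuple needs a trailing comma.
    if not isinstance(info.config_params, tuple) or not info.config_params:
        problems.append("config_params must be a non-empty tuple")
    elif not all(isinstance(name, str) for name in info.config_params):
        problems.append("config_params entries must be gen_kwargs key names")
    return problems
```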

The config_params field is critical: it tells pb analyze which hyperparameters to extract from results and how to group configs for PBx score computation. For top-k methods this is ("k",). For methods with multiple hyperparameters (like KLASS), list all of them:

"klass": MethodInfo(
    "adaptive", "k", derive_adaptive, max_probability,
    ("conf_threshold", "kl_threshold", "kl_history_length"),
),
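Grouping shows why every hyperparameter has to be listed. In this sketch, each run is keyed by the tuple of its config_params values taken from gen_kwargs. The helper names and the run layout (a dict with a "gen_kwargs" entry) are hypothetical, and this is not the code behind pb analyze. If a hyperparameter is left out, runs with different settings collapse into a single config.

```python
from collections import defaultdict


def config_key(gen_kwargs: dict, config_params: tuple) -> tuple:
    """Values of the listed hyperparameters, in config_params order."""
    return tuple(gen_kwargs.get(name) for name in config_params)


def group_by_config(runs: list, config_params: tuple) -> dict:
    """Group runs (dicts with a "gen_kwargs" entry) by their config key."""
    groups = defaultdict(list)
    for run in runs:
        groups[config_key(run["gen_kwargs"], config_params)].append(run)
    return dict(groups)


KLASS_PARAMS = ("conf_threshold", "kl_threshold", "kl_history_length")
# Illustrative inputs only: three runs, two distinct KLASS settings.
runs = [
    {"gen_kwargs": {"conf_threshold": 0.9, "kl_threshold": 0.01, "kl_history_length": 2}},
    {"gen_kwargs": {"conf_threshold": 0.9, "kl_threshold": 0.05, "kl_history_length": 2}},
    {"gen_kwargs": {"conf_threshold": 0.9, "kl_threshold": 0.01, "kl_history_length": 2}},
]
```

When keyed on all three KLASS parameters, these runs form two configs. When keyed on conf_threshold alone, they merge into one.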

Reuse existing derive functions (derive_topk, derive_threshold, derive_factor, derive_adaptive) and scorers when possible.
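For a top-k method, the link between k and the number of denoising steps is plain arithmetic. If k tokens are unmasked per step, filling max_tokens positions takes ceil(max_tokens / k) steps. The hypothetical helper below implements only that calculation. It is not derive_topk itself, which also derives block_length and may apply further rules.

```python
def topk_steps(max_tokens: int, k: int) -> int:
    """Steps needed to unmask max_tokens positions at k tokens per step."""
    if k <= 0:
        raise ValueError("k must be positive")
    # Ceiling division without floats; the last step may unmask fewer than k.
    return -(-max_tokens // k)
```

With the settings from the verification command in step 4 (k=4, max_tokens=32), this arithmetic gives 8 steps.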

You can also register dynamically:

```python
from parallelbench.models.confidence_scorers import my_scorer
from parallelbench.models.unmasking_registry import MethodInfo, derive_topk, register_method

register_method("my_method", MethodInfo("topk", "k", derive_topk, my_scorer, ("k",)))
```

3. Add to Model Valid Sets

Each model declares which methods it supports. Add your method name to the relevant VALID_METHODS sets:

```python
# parallelbench/models/local/llada/constants.py
LLADA_VALID_METHODS = {
    "random",
    "confidence_topk",
    "confidence_threshold",
    "confidence_factor",
    "topk_margin",
    "entropy_topk",
    "my_method",  # add here
}
```

Repeat for each model that should support the method (e.g., dream/constants.py, trado/constants.py).
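A method works end to end only when it is both registered and listed in the model's valid set. The quick checklist sketch below takes the registry dict and one model's valid-methods set as inputs. The function name setup_gaps is hypothetical.

```python
def setup_gaps(name: str, registry: dict, valid_methods: set) -> list:
    """List the setup steps still missing for unmasking method `name`."""
    gaps = []
    if name not in registry:
        gaps.append("register it in UNMASKING_REGISTRY (step 2)")
    if name not in valid_methods:
        gaps.append("add it to the model's VALID_METHODS set (step 3)")
    return gaps
```

For example, setup_gaps("my_method", UNMASKING_REGISTRY, LLADA_VALID_METHODS) should return an empty list once steps 2 and 3 are done for LLaDA.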

4. Verify

```bash
pb eval --model parallelbench_llada \
  --model_args model_path=GSAI-ML/LLaDA-1.5 \
  --gen_kwargs k=4,max_tokens=32,unmasking=my_method \
  --tasks parallelbench_waiting_line_copy \
  --include_path parallelbench/tasks \
  --batch_size 1 \
  --limit 2
```
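Before running pb eval, it can save time to test the scorer in isolation. The sketch below feeds random probability tensors through a scorer and checks that the output has shape (batch, seq_len) and contains no NaN or infinite values. Here my_scorer is only a placeholder (a margin scorer), so substitute your own. The helper check_scorer is hypothetical.

```python
import torch


def my_scorer(p: torch.Tensor, x0: torch.Tensor, x0_p: torch.Tensor) -> torch.Tensor:
    # Placeholder scorer (top-1 minus top-2 probability); replace with yours.
    top2 = p.topk(2, dim=-1).values
    return top2[..., 0] - top2[..., 1]


def check_scorer(scorer, batch: int = 2, seq_len: int = 7, vocab: int = 13) -> torch.Tensor:
    """Run scorer on random inputs and check the (batch, seq_len) output contract."""
    gen = torch.Generator().manual_seed(0)
    p = torch.softmax(torch.randn(batch, seq_len, vocab, generator=gen), dim=-1)
    x0 = p.argmax(dim=-1)
    x0_p = p.max(dim=-1).values
    conf = scorer(p, x0, x0_p)
    if conf.shape != (batch, seq_len):
        raise ValueError(f"expected shape {(batch, seq_len)}, got {tuple(conf.shape)}")
    if not torch.isfinite(conf).all():
        raise ValueError("scorer produced NaN or inf values")
    return conf
```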