SFT Structured-TOML Config Reference

June 26, 2026


Overview

Every SFT recipe under examples/toml/sft_config/<recipe>.toml is parsed against the pydantic schema SFTExperimentConfig (see cosmos_framework/configs/toml_config/sft_config.py). Each top-level TOML section ([job], [model], …) maps to one sub-model in that file. The schema is strict: every sub-model sets extra="forbid", so an unknown key raises ValidationError before training starts (typo guard).
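The typo guard can be illustrated without pydantic. The sketch below is a stdlib approximation of what extra="forbid" achieves, not the framework's code: JobSection and parse_strict are hypothetical names, and the dataclass carries only a few of the [job] fields.

```python
from dataclasses import dataclass, fields


@dataclass
class JobSection:
    """Hypothetical stand-in for the [job] sub-model (not the real class)."""

    task: str = "vfm"
    experiment: str = ""
    wandb_mode: str = "disabled"


def parse_strict(cls, raw: dict):
    """Build `cls` from `raw`, rejecting any key the dataclass does not declare."""
    allowed = {f.name for f in fields(cls)}
    unknown = sorted(set(raw) - allowed)
    if unknown:
        raise ValueError(f"unknown key(s) in {cls.__name__}: {unknown}")
    return cls(**raw)


# A well-formed section parses; omitted keys keep their defaults.
ok = parse_strict(JobSection, {"task": "vlm", "experiment": "my_recipe"})

# A misspelled key is rejected up front instead of being silently ignored.
try:
    parse_strict(JobSection, {"taks": "vlm"})  # typo for "task"
    typo_error = None
except ValueError as exc:
    typo_error = str(exc)
```

With pydantic the same rejection surfaces as a ValidationError raised while the TOML is validated, before any training code runs.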

After validation, the TOML dict is converted to a Hydra override list by build_hydra_overrides (see VFM ↔ VLM path remaps), and load_experiment_from_toml(...) loads the base config (chosen by [job].task) and applies the overrides via Hydra. Trailing CLI overrides passed after -- to cosmos_framework.scripts.train are appended last, so they win over TOML values.
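The "appended last, so they win" rule can be pictured as a left-to-right fold over key=value strings. This is only an illustration of the ordering: Hydra performs the real resolution, and apply_overrides is a hypothetical helper.

```python
def apply_overrides(overrides: list[str]) -> dict[str, str]:
    """Fold 'key=value' strings left to right; a later entry replaces an earlier one."""
    resolved: dict[str, str] = {}
    for item in overrides:
        key, _, value = item.partition("=")
        resolved[key] = value
    return resolved


# Overrides derived from the TOML come first ...
toml_overrides = ["trainer.max_iter=500", "optimizer.lr=0.0002"]
# ... and whatever the user typed after `--` is appended at the end.
cli_tail = ["trainer.max_iter=2000"]

resolved = apply_overrides(toml_overrides + cli_tail)
```

Here trainer.max_iter resolves to the CLI value while optimizer.lr keeps the TOML value.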

TOML at a glance

[job]                                # run identity + base-config / experiment selector
[model]                              # top-level model knobs
[model.ema]                          # EMA tracking of generation-pathway weights
[model.parallelism]                  # FSDP / context-parallel / CFG-parallel topology
[model.compile]                      # torch.compile knobs
[model.activation_checkpointing]     # AC mode + recompute knobs
[model.tokenizer]                    # VFM only: Wan VAE
[model.backbone]                     # VLM only: backbone HF id + optional safetensors path
[optimizer]                          # AdamW
[optimizer.lr_multipliers]           # optional per-substring lr multipliers
[scheduler]                          # LambdaLinear / LambdaCosine
[trainer]                            # max_iter, grad_accum, logging
[trainer.callbacks.compile_tokenizer]  # VFM only
[trainer.callbacks.grad_clip]        # clip_norm + force_finite
[checkpoint]                         # load_path, save_iter, key-skip blocklist
[dataloader_train]                   # top-level scalars only
[custom]                             # free-form, project-owned escape hatch (opaque to the framework)

The full pipeline (dataloader class, dataset wiring, model_instance LazyCall, etc.) lives in the experiment SKU Python file under cosmos_framework/configs/base/experiment/sft/<recipe>.py. The TOML only surfaces values the recipe author wants users to tune.

[job]

Run identity + meta-fields that pick the Hydra config tree to load.

| field | default | description |
| --- | --- | --- |
| task | "vfm" | META: chooses which make_config() to call: "vfm" → cosmos_framework/configs/base/config.py, "vlm" → cosmos_framework/configs/base/vlm/config.py. Also picks the path-remap rules in toml_config_helper.PATH_REMAPS. |
| experiment | "" | META: names the Hydra experiment LazyDict registered in ConfigStore under experiment/<name>. Resolved at load time via experiment=<name> (e.g. vision_sft_nano). |
| project | "" | W&B project (team-level bucket). Flows to config.job.project. |
| group | "" | W&B sub-label for clustering related runs (e.g. "sft"). Flows to config.job.group. |
| name | "" | W&B run name; forms part of the output dir $IMAGINAIRE_OUTPUT_ROOT/<project>/<group>/<name>/. Leave empty (or use ${now:%Y-%m-%d}_${now:%H-%M-%S}) for an auto-timestamped subdir. |
| wandb_mode | "disabled" | "online" (real-time, needs WANDB_API_KEY), "offline" (log locally, sync later via wandb sync), or "disabled". |

[model]

Top-level model knobs. Lands at model.config.* on VFM and on VLM; sub-tree paths are remapped per the VFM ↔ VLM path remaps.

| field | default | description |
| --- | --- | --- |
| max_num_tokens_after_packing | 13312 | Token-packing target: max tokens after sequence packing. -1 disables the cap. VFM only; VLM uses data_setting.max_tokens (tail override). |
| joint_attn_implementation | "two_way" | VFM attention layout: "two_way" (U/G blocks with cross-attention), "three_way" (adds a sparsity-aware third block, NATTEN), or "flex" (legacy). VFM only. |
| attn_implementation | "cosmos" | VLM HF attention impl: "cosmos" (NATTEN/Blackwell-FMHA wrapper), "flash_attention_2", "sdpa", or "eager". VLM only. |
| lora_enabled | false | Inject LoRA adapters into the generation pathway BEFORE FSDP wraps the network. Pair with optimizer.keys_to_select=["lora_"] and checkpoint.keys_to_skip_loading=[…, "lora_"]. Used by SUPER-tier recipes; NANO leaves it off. VFM only. |
| lora_rank | 16 | LoRA rank r. Adapter shape is (rank × hidden_dim) per target module. Typical: 4 / 8 / 16 / 32. |
| lora_alpha | 32 | LoRA scaling factor. Effective magnitude is alpha / rank; rank=16 alpha=32 → 2× scale. |
| lora_target_modules | "q_proj_moe_gen,k_proj_moe_gen,v_proj_moe_gen,o_proj_moe_gen" | Comma-separated substrings of param names that receive an adapter. Default targets the four MoE-gen projection matrices. |
| precision | "bfloat16" | Compute dtype for forward/backward (MixedPrecisionPolicy.param_dtype). "bfloat16" is standard for Hopper/Blackwell. (Was [model.parallelism].precision before the ParallelismConfig split.) |
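The two LoRA rules in the table (scale = alpha / rank, and comma-separated substring matching for lora_target_modules) can be sketched as below. The parameter names are invented for illustration, and the exact matching logic inside the framework is an assumption.

```python
def lora_scale(alpha: float, rank: int) -> float:
    """Effective adapter magnitude as described above: alpha divided by rank."""
    return alpha / rank


def select_lora_targets(param_names: list[str], target_modules: str) -> list[str]:
    """Return the names that contain at least one of the comma-separated substrings."""
    substrings = [s.strip() for s in target_modules.split(",") if s.strip()]
    return [name for name in param_names if any(s in name for s in substrings)]


# Hypothetical parameter names; real FQNs depend on the network definition.
names = [
    "blocks.0.attn.q_proj_moe_gen.weight",
    "blocks.0.attn.q_proj.weight",
    "blocks.0.attn.o_proj_moe_gen.weight",
    "blocks.0.mlp.up_proj.weight",
]

targets = select_lora_targets(
    names, "q_proj_moe_gen,k_proj_moe_gen,v_proj_moe_gen,o_proj_moe_gen"
)
scale = lora_scale(alpha=32, rank=16)
```

Only the two *_moe_gen projections are selected, and the default rank/alpha pair gives a scale of 2.0.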

[model.ema]

Exponential Moving Average of generation-pathway weights. Lands at model.config.ema.* on both VFM and VLM. When enabled, the trainer keeps a second fp32 copy of trainable params updated as ema_w = (1 - rate^k) · w_curr + rate^k · ema_w_prev. EMA weights are used for inference; live weights keep training.

| field | default | description |
| --- | --- | --- |
| enabled | true | Turn EMA tracking on/off. Full fine-tunes typically enable it; LoRA recipes leave it off because the adapter weights are tiny. |
| rate | 0.1 | Base EMA decay. Lower = slower decay = EMA tracks live weights more tightly. Per-step rate is ramped by the iteration counter so the EMA "warms up" from init. |
| iteration_shift | 0 | Step offset added before computing the warmup ramp. Use a positive value when resuming so the EMA doesn't reset to "early-iter" decay strength. |

[model.parallelism]

FSDP / context-parallel / classifier-free-guidance topology. Lands at model.config.parallelism.* on both VFM and VLM.

| field | default | description |
| --- | --- | --- |
| data_parallel_shard_degree | -1 | FSDP shard degree. -1 = auto-fit WORLD_SIZE from torchrun. Set explicitly to make the run fail loudly on the wrong GPU count. |
| data_parallel_replicate_degree | 1 | HSDP outer replicate degree. >1 runs the same shard topology N times in parallel; usually only needed at very large cluster scale. |
| context_parallel_shard_degree | 1 | Splits the sequence dimension across this many ranks so long-context models fit in memory. Used by super-tier configs (e.g. DP=4, CP=2 → 8 GPUs). |
| cfg_parallel_shard_degree | 1 | Splits the duplicated conditional/unconditional CFG forward across ranks. Almost always 1 for SFT. |

The product data_parallel_shard_degree × data_parallel_replicate_degree × context_parallel_shard_degree × cfg_parallel_shard_degree must equal WORLD_SIZE.
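A minimal sketch of the product constraint follows. It assumes that -1 resolves to WORLD_SIZE divided by the product of the other three degrees; resolve_shard_degree is a hypothetical helper, not a framework function.

```python
def resolve_shard_degree(
    world_size: int,
    shard: int = -1,
    replicate: int = 1,
    cp: int = 1,
    cfg: int = 1,
) -> int:
    """Return the effective FSDP shard degree or raise if the topology cannot fit."""
    other = replicate * cp * cfg
    if shard == -1:
        # Assumed auto-fit rule: take whatever is left of WORLD_SIZE.
        if world_size % other != 0:
            raise ValueError(f"WORLD_SIZE={world_size} is not divisible by {other}")
        return world_size // other
    if shard * other != world_size:
        raise ValueError(
            f"{shard} x {replicate} x {cp} x {cfg} != WORLD_SIZE={world_size}"
        )
    return shard


# The super-tier example from the table: DP=4, CP=2 on 8 GPUs.
auto = resolve_shard_degree(world_size=8, cp=2)
explicit = resolve_shard_degree(world_size=8, shard=4, cp=2)

# An explicit degree that does not match the GPU count fails loudly.
try:
    resolve_shard_degree(world_size=8, shard=8, cp=2)
    mismatch = None
except ValueError as exc:
    mismatch = str(exc)
```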

[model.compile]

torch.compile knobs. Lands at model.config.compile.* on both VFM and VLM. Both fields used to live under [model.parallelism]; only their location and names changed, not their behavior.

| field | default | description |
| --- | --- | --- |
| enabled | false | torch.compile the network. (Was [model.parallelism].use_torch_compile.) Big speedup on stable shapes; conflicts with some custom CUDA kernels and deterministic modes. |
| compile_dynamic | true | When enabled=true, recompile per-shape rather than specializing for a single static shape. Required for the compile_tokenizer callback's progressive warmup. |

[model.activation_checkpointing]

Recompute activations during backward to trade FLOPs for memory. Lands at model.config.activation_checkpointing.* on both VFM and VLM.

| field | default | description |
| --- | --- | --- |
| mode | "full" | "selective" (per-op SAC: keep matmuls/FMHA, recompute the rest; MoT path only), "full" (checkpoint each whole transformer block), or "none" (no checkpointing: fastest, highest memory). |
| save_ops_regex | ["fmha"] | Regex patterns for ops to KEEP saved under mode="selective". Ignored in "full"/"none". Default keeps flash/multi-head-attention outputs. |
| preserve_rng_state | true | Stash + restore CUDA RNG across recompute boundaries. Required for deterministic equivalence with the non-checkpointed path; small slowdown. |
| determinism_check | "default" | Forwarded to torch.utils.checkpoint. "default" disables the extra determinism check; "match" cross-checks recomputed activations against the original (debug-only, very slow). |

[model.tokenizer] (VFM only)

Video tokenizer (VAE) settings. VLM skips this sub-tree (path-remap blocks it).

| field | default | description |
| --- | --- | --- |
| vae_path | "pretrained/tokenizers/video/wan2pt2/Wan2.2_VAE.pth" | Path to Wan2.2_VAE.pth. SFT recipes typically pass this via env interpolation: vae_path = "${oc.env:WAN_VAE_PATH}". |

[model.backbone] (VLM only)

Foundation backbone settings. VFM skips this sub-tree — VFM keeps its backbone wiring inline in the experiment Python (vlm_config.model_instance).

| field | default | description |
| --- | --- | --- |
| model_name | "???" (MISSING) | HF repo ID or local snapshot path of the VLM backbone (e.g. "Qwen/Qwen3-VL-8B-Instruct"). Drives AutoConfig + AutoModel selection (architecture). Remapped to model.config.policy.backbone.model_name. |
| safetensors_path | "???" (MISSING) | Optional local path to a .safetensors file (or directory) used for weight loading. When set, overrides the auto-downloaded snapshot under model_name; the architecture is still driven by model_name. Useful for pointing at a converted/finetuned checkpoint while keeping the public HF model_name for tokenizer/architecture discovery. Remapped to model.config.policy.backbone.safetensors_path. |

[optimizer]

AdamW-family optimizer parameters. Same shape on VFM and VLM (eps skipped on VLM).

| field | default | description |
| --- | --- | --- |
| betas | [0.9, 0.99] | Adam β1, β2: gradient and squared-gradient EMAs. Standard pair is (0.9, 0.999); SFT recipes commonly use (0.9, 0.99) or (0.9, 0.95) for tighter tracking of recent gradients. |
| eps | 1.0e-8 | Adam numerical-stability epsilon. 1e-8 is PyTorch default; 1e-6 is sometimes used in bf16 to avoid underflow in the squared-gradient denominator. VFM only. |
| fused | true | Use the fused AdamW kernel. Faster on modern GPUs; slightly different numerical behavior vs the foreach implementation. |
| keys_to_select | [] | Substring allowlist for params that the optimizer trains. Empty = train everything. ["lora_"] = LoRA-only fine-tune (freezes everything except adapters). |
| lr | 2.0e-4 | Base learning rate. |
| lr_multipliers | {} | Per-param-group LR multipliers (<substring> = <multiplier>). E.g. action_modality_embed = 5.0 gives that param group 5× the base lr. Substrings not in the dict default to 1.0. |
| weight_decay | 0.0 | AdamW decoupled weight decay. 0 disables. |
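How keys_to_select and lr_multipliers interact can be sketched as a per-parameter learning-rate table. This is a simplified picture under stated assumptions: substring matching on the parameter name, first matching multiplier wins, and invented parameter names. The framework's actual grouping code may differ.

```python
def effective_lrs(
    param_names: list[str],
    base_lr: float,
    keys_to_select: list[str],
    lr_multipliers: dict[str, float],
) -> dict[str, float]:
    """Map each trainable parameter name to its learning rate; frozen params are omitted."""
    lrs: dict[str, float] = {}
    for name in param_names:
        # An empty allowlist means "train everything".
        if keys_to_select and not any(key in name for key in keys_to_select):
            continue
        multiplier = 1.0
        for substring, value in lr_multipliers.items():
            if substring in name:
                multiplier = value
                break
        lrs[name] = base_lr * multiplier
    return lrs


# Hypothetical parameter names.
params = [
    "net.action_modality_embed.weight",
    "net.blocks.0.lora_A.weight",
    "net.blocks.0.attn.weight",
]

full = effective_lrs(params, 2.0e-4, [], {"action_modality_embed": 5.0})
lora_only = effective_lrs(params, 2.0e-4, ["lora_"], {})
```

In the full fine-tune case the embedding trains at five times the base rate; with keys_to_select=["lora_"] only the adapter parameter remains trainable.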

[scheduler]

LambdaLinear / LambdaCosine LR scheduler. All three f_* values (f_start, f_max, f_min) are ratios of optimizer.lr: the effective LR at the corresponding milestone is lr × f_x. Each list has one entry per scheduler cycle.

| field | default | description |
| --- | --- | --- |
| cycle_lengths | [20000] | Length of each cycle in optimizer steps. With one entry, the scheduler completes one full warmup→peak→trough cycle over that many iterations. |
| f_max | [1.0] | Peak LR multiplier reached at the end of warmup. |
| f_min | [0.0] | Final LR multiplier at the end of each cycle (the "floor"). For LambdaCosine the LR decays toward lr × f_min. |
| f_start | [1.0e-6] | Initial LR multiplier at step 0, before warmup ramps up. |
| verbosity_interval | 0 | How often the scheduler logs current LR (in optimizer steps). 0 = silent. VFM only. |
| warm_up_steps | [100] | Linear warmup duration in optimizer steps. LR ramps from lr × f_start to lr × f_max linearly before cosine/linear decay begins. |
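For a single cycle, the multiplier curve described in the table can be sketched as a linear warmup followed by a cosine decay. The decay is assumed to span the steps remaining after warmup; the framework's LambdaCosine may place the boundaries differently, so treat this as an approximation with the table's defaults.

```python
import math


def lr_multiplier(
    step: int,
    warm_up_steps: int = 100,
    cycle_length: int = 20000,
    f_start: float = 1.0e-6,
    f_max: float = 1.0,
    f_min: float = 0.0,
) -> float:
    """Multiplier applied to optimizer.lr at `step` (single-cycle sketch)."""
    if step < warm_up_steps:
        # Linear ramp from f_start to f_max.
        return f_start + (f_max - f_start) * step / warm_up_steps
    # Cosine decay from f_max toward f_min over the rest of the cycle.
    decay_steps = max(cycle_length - warm_up_steps, 1)
    progress = min((step - warm_up_steps) / decay_steps, 1.0)
    return f_min + 0.5 * (f_max - f_min) * (1.0 + math.cos(math.pi * progress))


base_lr = 2.0e-4
lr_at_start = base_lr * lr_multiplier(0)
lr_at_peak = base_lr * lr_multiplier(100)
lr_at_end = base_lr * lr_multiplier(20000)
```

The effective LR starts near zero, reaches the base learning rate when warmup ends, and decays to lr × f_min at the end of the cycle.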

[trainer]

| field | default | description |
| --- | --- | --- |
| distributed_parallelism | "fsdp" | Distributed strategy. "fsdp" is the only supported value today. |
| grad_accum_iter | 1 | Micro-batches accumulated before each optimizer.step(). Effective global batch = grad_accum_iter × per-rank batch × world_size. |
| logging_iter | 50 | Console / W&B log frequency (in optimizer steps). |
| max_iter | 500 | Total number of optimizer steps the run will execute. |

[trainer.callbacks.compile_tokenizer] (VFM only)

Lazy torch.compile of the VAE tokenizer once shapes stabilize. VLM skips this — no tokenizer to compile.

| field | default | description |
| --- | --- | --- |
| enabled | true | Master switch for the callback. |
| compile_after_iterations | 3 | Wait this many training iterations after start before triggering the compile (lets one-shot init / dataloader settle). |
| warmup_resolutions | null | Resolutions to "prime" the compile cache with. The callback runs the tokenizer once per listed resolution so the compiled graph for each is ready before training hits it. null = use whatever resolutions the tokenizer's encode_chunk_frames knows. |

[trainer.callbacks.grad_clip]

| field | default | description |
| --- | --- | --- |
| clip_norm | 1.0 | Maximum global L2 norm of the gradient. Steps with a larger norm are rescaled so ‖grad‖ ≤ clip_norm. |
| force_finite | true | When true, replace NaN/Inf grads with zero before the step (treats them as no-op rather than crashing). VFM default true; VLM default false. |
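The two knobs can be illustrated on a flat list of floats standing in for gradient tensors. This sketch assumes that non-finite values are zeroed before the norm is computed; clip_gradients is a hypothetical helper, not the callback's implementation.

```python
import math


def clip_gradients(
    grads: list[float], clip_norm: float = 1.0, force_finite: bool = True
) -> tuple[list[float], float]:
    """Return (clipped gradients, global L2 norm measured before clipping)."""
    if force_finite:
        # Treat NaN/Inf entries as a no-op contribution.
        grads = [g if math.isfinite(g) else 0.0 for g in grads]
    total_norm = math.sqrt(sum(g * g for g in grads))
    if total_norm > clip_norm:
        scale = clip_norm / total_norm
        grads = [g * scale for g in grads]
    return grads, total_norm


# Norm 5.0 is rescaled down to clip_norm = 1.0.
clipped, norm = clip_gradients([3.0, 4.0])

# A NaN entry is replaced by zero; the remaining norm (0.5) needs no rescaling.
cleaned, cleaned_norm = clip_gradients([float("nan"), 0.5])
```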

[checkpoint]

Resume + save policy. Lands at config.checkpoint.*.

| field | default | description |
| --- | --- | --- |
| keys_to_skip_loading | [] | Substring blocklist applied at load time. Any tensor whose FQN contains one of these substrings is skipped (kept at fresh init). Used to mask EMA / LoRA / action tensors when warm-starting from a base checkpoint that doesn't have them. |
| load_path | "???" (MISSING) | Path to the checkpoint directory to load. The MISSING sentinel is skipped from the override list, so the user must provide a real path at runtime, typically via env interpolation "${oc.env:BASE_CHECKPOINT_PATH}" in the TOML, or a CLI tail override. |
| save_iter | 100 | Save a new checkpoint every N optimizer steps. |
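The substring blocklist behind keys_to_skip_loading amounts to a filter over fully qualified tensor names. A minimal sketch with invented key names (the real loader and its FQNs are not shown in this reference):

```python
def filter_state_dict(state_dict: dict, keys_to_skip_loading: list[str]) -> dict:
    """Drop every entry whose name contains one of the blocklisted substrings."""
    return {
        name: tensor
        for name, tensor in state_dict.items()
        if not any(skip in name for skip in keys_to_skip_loading)
    }


# Hypothetical checkpoint contents; values stand in for tensors.
checkpoint = {
    "net.blocks.0.attn.weight": "tensor-a",
    "net.blocks.0.attn.lora_A.weight": "tensor-b",
    "net_ema.blocks.0.attn.weight": "tensor-c",
}

to_load = filter_state_dict(checkpoint, ["lora_", "net_ema"])
```

Entries that are filtered out are never copied into the model, so the corresponding parameters stay at their fresh initialization.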

[dataloader_train]

Top-level dataloader scalars only. The dataloader's class (LazyCall) and full pipeline wiring (datasets, packers, …) stay in the experiment Python — they vary too much between VFM IterativeJointDataLoader, PackingDataLoader, and VLM DataPackerDataLoader to model uniformly.

| field | default | description |
| --- | --- | --- |
| max_samples_per_batch | null | Cap on samples per micro-batch. Remapped to max_batch_size on the VLM DataPackerDataLoader. null = no per-count cap (the packer's token budget is what limits batch size). |
| max_sequence_length | null | Cap on tokens per packed sequence. Remapped to max_tokens on the VLM DataPackerDataLoader. null = no per-token cap. |
| seed | 42 | Dataloader RNG seed. VFM only; skipped on VLM (DataPackerDataLoader has no seed ctor kwarg). |

[custom] (free-form escape hatch)

[custom] lets a project carry its own config (dataset paths, sampling ratios, …) in the same TOML as the framework knobs. The framework never looks inside it — it's the one section exempt from the extra="forbid" typo guard (every other section still rejects unknown keys).

How it works:

  • Arbitrary nested content passes through verbatim — scalars, sub-tables ([custom.a.b]), arrays-of-tables ([[custom.items]]).
  • It does not go through Hydra. After load_config finishes, the table is attached as a plain dict via config.custom = raw.get("custom", {}) (or {} when absent — reading config.custom is always safe).
  • So values must be concrete: ${custom} interpolation is not supported, and config.custom is not part of config.to_dict() / serialized config dumps.

[custom]
your_custom_files = "custom_value"

Read it directly to wire your own pipeline:

project_cfg = TrainingDatasetConfig.model_validate(config.custom)

Cross-cutting behaviors

"???" (MISSING) sentinel

A handful of fields default to the literal string "???", the OmegaConf MISSING sentinel. build_hydra_overrides recognizes this value and emits no override for the corresponding key (see _emit_with_remap in toml_config_helper.py). The effect: if the TOML doesn't explicitly set the field, the value falls through to whatever the experiment Python (or its Hydra base config) sets, rather than emitting key='' and overwriting the inherited value with an empty string.

Fields with this pattern today:

  • [checkpoint].load_path
  • [model.backbone].model_name
  • [model.backbone].safetensors_path
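The filtering rule can be sketched in a few lines. This is not the body of _emit_with_remap, only an illustration of the behavior described above; emit_overrides and the sample values are hypothetical.

```python
MISSING = "???"  # the literal string OmegaConf uses for a mandatory missing value


def emit_overrides(flat: dict[str, object]) -> list[str]:
    """Turn dotted-key/value pairs into overrides, dropping keys left at MISSING."""
    return [f"{key}={value}" for key, value in flat.items() if value != MISSING]


# load_path was not set in the TOML, so it still holds the schema default.
flat = {
    "checkpoint.load_path": MISSING,
    "checkpoint.save_iter": 100,
}

overrides = emit_overrides(flat)
```

No checkpoint.load_path override is produced, so the value set by the experiment Python survives.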

Env interpolation

Recipe TOMLs typically interpolate paths from the environment so the same TOML works across filesystems:

[checkpoint]
load_path = "${oc.env:BASE_CHECKPOINT_PATH}"

[model.tokenizer]
vae_path = "${oc.env:WAN_VAE_PATH}"

DATASET_PATH follows the same convention but is consumed inside the experiment-SKU Python (cosmos_framework/configs/base/experiment/sft/<recipe>.py), not in the TOML.
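Since an unset variable only fails once the interpolation is resolved, a small pre-flight check can be handy. The helper below is a hypothetical convenience, not part of the framework: it scans TOML text for ${oc.env:NAME} references and reports the names absent from a given environment mapping (references that carry an inline default are flagged as well).

```python
import re

# Captures NAME from "${oc.env:NAME}" and from "${oc.env:NAME,default}".
ENV_REF = re.compile(r"\$\{oc\.env:([A-Za-z_][A-Za-z0-9_]*)")


def unset_env_refs(toml_text: str, environ: dict[str, str]) -> list[str]:
    """Return the sorted variable names referenced in the TOML but absent from `environ`."""
    referenced = set(ENV_REF.findall(toml_text))
    return sorted(name for name in referenced if name not in environ)


recipe_text = """
[checkpoint]
load_path = "${oc.env:BASE_CHECKPOINT_PATH}"

[model.tokenizer]
vae_path = "${oc.env:WAN_VAE_PATH}"
"""

# In practice pass dict(os.environ); a literal dict keeps the example deterministic.
missing = unset_env_refs(recipe_text, {"WAN_VAE_PATH": "/models/Wan2.2_VAE.pth"})
```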

VFM ↔ VLM path remaps

The same TOML key lands at different Hydra paths depending on [job].task:

| TOML path | VFM (task="vfm") Hydra path | VLM (task="vlm") Hydra path |
| --- | --- | --- |
| model.<X> | model.config.<X> | model.config.<X> |
| model.parallelism.* | model.config.parallelism.* | model.config.parallelism.* |
| model.compile.* | model.config.compile.* | model.config.compile.* |
| model.activation_checkpointing.* | model.config.activation_checkpointing.* | model.config.activation_checkpointing.* |
| model.precision | model.config.precision | model.config.precision |
| model.attn_implementation | (skipped: VLM-only) | model.config.policy.attn_implementation |
| model.backbone.* | (skipped: VLM-only) | model.config.policy.backbone.* |
| model.ema.* | model.config.ema.* | model.config.ema.* |
| model.tokenizer.* | model.config.tokenizer.* | (skipped: VFM-only) |
| model.{max_num_tokens_after_packing, joint_attn_implementation, lora_*} | passes through | (skipped: VFM-only) |
| dataloader_train.max_samples_per_batch | passes through | dataloader_train.max_batch_size |
| dataloader_train.max_sequence_length | passes through | dataloader_train.max_tokens |
| dataloader_train.seed | passes through | (skipped: VLM has no seed kwarg) |
| optimizer.eps, scheduler.verbosity_interval, trainer.callbacks.compile_tokenizer.* | passes through | (skipped: VLM has no analog) |

Authoritative source: PATH_REMAPS in toml_config_helper.py.
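One way to picture the remap step is a per-task lookup that returns a new path, the same path, or "skip". The table layout below is an assumption made for illustration: the real PATH_REMAPS structure is not shown in this reference, and it also handles whole sub-trees, which this exact-key sketch does not.

```python
from typing import Optional

# Hypothetical layout: task -> {toml_path: hydra_path, or None to skip the key}.
EXAMPLE_REMAPS: dict[str, dict[str, Optional[str]]] = {
    "vfm": {},
    "vlm": {
        "dataloader_train.max_samples_per_batch": "dataloader_train.max_batch_size",
        "dataloader_train.max_sequence_length": "dataloader_train.max_tokens",
        "dataloader_train.seed": None,
        "optimizer.eps": None,
    },
}


def remap(task: str, toml_path: str) -> Optional[str]:
    """Return the Hydra path for `toml_path`, or None when the key is skipped."""
    table = EXAMPLE_REMAPS[task]
    if toml_path in table:
        return table[toml_path]
    return toml_path  # plain pass-through needs no entry


vlm_tokens = remap("vlm", "dataloader_train.max_sequence_length")
vlm_seed = remap("vlm", "dataloader_train.seed")
vfm_seed = remap("vfm", "dataloader_train.seed")
```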

Out-of-schema knobs (Hydra tail overrides)

A few useful knobs aren't currently modeled by SFTExperimentConfig because they're either niche or experiment-specific. Pass them as trailing key.path=value positionals after --:

| key | purpose | used by |
| --- | --- | --- |
| data_setting.max_tokens | VLM token-packing cap (the VLM analogue of [model].max_num_tokens_after_packing). | launch_sft_llava_ov.sh (when the launcher overrides the default) |

Loading flow

load_experiment_from_toml(toml_path, extra_overrides) (in sft_config.py) is the end-to-end loader. It:

  1. Reads the TOML with tomllib.
  2. Validates the parsed dict against SFTExperimentConfig (raises ValidationError on unknown keys).
  3. Picks the base config from [job].task: TASK_TO_BASE_CONFIG["vfm"|"vlm"].
  4. Calls build_hydra_overrides(raw) to produce a ["--", "experiment=<name>", "k.p=v", …] list with per-task remaps applied and MISSING values filtered. [custom] is skipped here (it is injected verbatim in step 7, not per-leaf-remapped).
  5. Appends extra_overrides (CLI tail) so they take precedence over the TOML.
  6. Calls cosmos_framework.utils.config.load_config(base_config_path, overrides). This imports the base config module and runs make_config(), which registers every config group and imports every experiment SKU's cs.store(group="experiment", …). Then override(config, overrides) has Hydra compose resolve the experiment=<name> selector against ConfigStore and apply the dotted-path overrides.
  7. Injects [custom] after loading: config.custom = raw.get("custom", {}). This runs after Hydra resolution, so it lands as a plain dict (no ${custom} interpolation; not part of serialized config dumps).

The returned Config is ready for launch().
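Steps 4, 5, and 7 can be condensed into a short sketch: flatten the nested TOML dict into dotted keys, lead with the experiment selector, leave [custom] out, and append the CLI tail. Per-task remaps and MISSING filtering are deliberately omitted, and flatten and assemble_overrides are hypothetical names rather than the framework's functions.

```python
def flatten(section: dict, prefix: str = "") -> dict:
    """Flatten nested tables into {'a.b.c': value} pairs, preserving order."""
    flat = {}
    for key, value in section.items():
        path = f"{prefix}.{key}" if prefix else key
        if isinstance(value, dict):
            flat.update(flatten(value, path))
        else:
            flat[path] = value
    return flat


def assemble_overrides(raw: dict, extra_overrides: list[str]) -> list[str]:
    """Build the override list: selector first, TOML leaves next, CLI tail last."""
    job = dict(raw.get("job", {}))
    experiment = job.pop("experiment", "")
    job.pop("task", None)  # meta-field: picks the base config, not an override
    sections = {"job": job}
    sections.update({k: v for k, v in raw.items() if k not in ("job", "custom")})
    overrides = ["--", f"experiment={experiment}"]
    overrides += [f"{k}={v}" for k, v in flatten(sections).items()]
    return overrides + list(extra_overrides)


raw = {
    "job": {"task": "vfm", "experiment": "vision_sft_nano", "name": "run1"},
    "trainer": {"max_iter": 500, "callbacks": {"grad_clip": {"clip_norm": 1.0}}},
    "custom": {"dataset_root": "/data/my_project"},
}

overrides = assemble_overrides(raw, ["trainer.max_iter=2000"])
custom = raw.get("custom", {})  # attached to the config after loading, untouched
```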

Extending the schema

To surface a new knob in the TOML:

  1. Add a Field(default=…, description="…") line to the relevant sub-model in cosmos_framework/configs/toml_config/sft_config.py. Pick a sensible default; if the field should fall through to the experiment Python's value when omitted, use "???".
  2. (Per-task wiring only) If the new key needs to land at a different Hydra path on VFM vs VLM, or should be skipped on one task, add an entry to PATH_REMAPS in cosmos_framework/configs/toml_config/toml_config_helper.py. Plain pass-through doesn't need a remap.
  3. (Optional) Add the field to one of the example TOMLs under examples/toml/sft_config/ so users have a working reference.

extra="forbid" on every sub-model means forgetting step 1 will make any TOML that uses the new key fail validation with a clear error, so the schema can't silently diverge from real usage.