TorchSpec
May 11, 2026
TorchSpec is a torch-native speculative decoding training framework. We introduce a disaggregated approach to training speculative decoding draft models. Inference and training are fully decoupled, and hidden states stream directly from inference engine groups to distributed training workers via the Mooncake store, so each side can scale independently.
TorchSpec currently includes training flows and examples for:
- Kimi-K2.5
- MiniMax-M2.5
- Qwen3-Coder-Next
Released Models
Draft models trained with TorchSpec, available on the LightSeek Foundation Hugging Face organization:
- lightseekorg/kimi-k2.5-eagle3
- lightseekorg/kimi-k2.5-eagle3-mla
- lightseekorg/kimi-k2.6-eagle3
- lightseekorg/kimi-k2.6-eagle3-mla
Blogs
- PyTorch blog: TorchSpec: Speculative Decoding Training at Scale
- Release blog: TorchSpec: Speculative Decoding Training at Scale
Table of Contents
- Architecture Overview
- Inference Backend Support
- Quick Start
- Setup
- Examples
- Training Modes
- Checkpoint Conversion
- Metrics Reporting
- Troubleshooting
Architecture Overview
TorchSpec is built around a disaggregated training pipeline:
- Inference engines run the target model and generate its hidden states.
- Mooncake store transfers tensors between inference and training without materializing them on disk.
- Training workers consume streamed hidden states to train speculative decoding draft models.
This separation keeps the training side focused on optimization, while the inference side scales independently to maximize hidden-state generation throughput.
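The data flow above can be pictured as a producer/consumer pipeline. What follows is a minimal single-process sketch, not TorchSpec code. An in-memory dictionary stands in for the Mooncake store, and `ToyStore`, `inference_worker`, and `training_worker` are hypothetical names. Producers write hidden states into the store under a key and publish only that small key. The trainer then fetches the tensor by key, so hidden states never pass through disk.

```python
import queue
import threading


class ToyStore:
    """Hypothetical in-memory stand-in for a distributed tensor store (not Mooncake's API)."""

    def __init__(self):
        self._data = {}
        self._lock = threading.Lock()

    def put(self, key, value):
        with self._lock:
            self._data[key] = value

    def get(self, key):
        # Pop so each hidden-state payload is consumed exactly once.
        with self._lock:
            return self._data.pop(key)


def inference_worker(worker_id, prompts, store, key_queue):
    """Hypothetical producer: fake a target-model forward and publish results by key."""
    for i, prompt in enumerate(prompts):
        hidden_states = [float(len(token)) for token in prompt.split()]  # fake features
        key = f"w{worker_id}-sample{i}"
        store.put(key, hidden_states)  # payload goes to the store first
        key_queue.put(key)  # only the small key travels over the control channel


def training_worker(store, key_queue, num_samples):
    """Hypothetical consumer: pull hidden states as they arrive."""
    consumed = []
    for _ in range(num_samples):
        key = key_queue.get()  # blocks until some producer publishes a key
        consumed.append((key, store.get(key)))
    return consumed


if __name__ == "__main__":
    store, keys = ToyStore(), queue.Queue()
    shards = [["hello world", "speculative decoding"], ["draft model"]]
    producers = [
        threading.Thread(target=inference_worker, args=(wid, shard, store, keys))
        for wid, shard in enumerate(shards)
    ]
    for t in producers:
        t.start()
    batches = training_worker(store, keys, num_samples=3)
    for t in producers:
        t.join()
    print(sorted(k for k, _ in batches))  # ['w0-sample0', 'w0-sample1', 'w1-sample0']
```

In this sketch the key queue decouples the two sides. Adding more producer threads raises throughput without changing the consumer.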
Inference Backend Support
TorchSpec streams hidden states from inference engines into training workers.
| Backend | Support Tier | Status |
|---|---|---|
| vLLM | First-class | Available |
| TokenSpeed | First-class | In progress |
| SGLang | Best community effort | Available |
| HuggingFace Transformers | Best community effort | Available |
Quick Start
Train an Eagle3 draft model for Qwen3-8B on a single node with 4 GPUs (2 for training and 2 for inference):
```bash
./examples/qwen3-8b-single-node/run.sh
```
Override config values directly from the CLI:
```bash
./examples/qwen3-8b-single-node/run.sh training.learning_rate=5e-5 training.num_train_steps=500
```
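Overrides such as training.learning_rate=5e-5 address a nested config key through a dotted path. The sketch below shows one way such key=value arguments could be merged into a nested configuration. It assumes a plain-dict config and simple scalar parsing. `apply_overrides` and `parse_scalar` are hypothetical helpers, and TorchSpec's actual config loader may behave differently. Libraries such as OmegaConf provide a similar dotlist merge via `OmegaConf.from_dotlist`.

```python
import copy


def parse_scalar(text):
    """Hypothetical: interpret a CLI value as bool, int, or float when possible, else str."""
    lowered = text.lower()
    if lowered in ("true", "false"):
        return lowered == "true"
    for cast in (int, float):
        try:
            return cast(text)
        except ValueError:
            pass
    return text


def apply_overrides(config, overrides):
    """Hypothetical: return a copy of `config` with each 'a.b.c=value' override applied."""
    merged = copy.deepcopy(config)
    for item in overrides:
        dotted_key, sep, raw_value = item.partition("=")
        if not sep:
            raise ValueError(f"override must look like key=value, got {item!r}")
        *parents, leaf = dotted_key.split(".")
        node = merged
        for name in parents:
            node = node.setdefault(name, {})  # create intermediate sections if missing
        node[leaf] = parse_scalar(raw_value)
    return merged


# Illustrative defaults only; these are not TorchSpec's real default values.
base = {"training": {"learning_rate": 1e-4, "num_train_steps": 1000}}
cli_args = ["training.learning_rate=5e-5", "training.num_train_steps=500"]
print(apply_overrides(base, cli_args))
# {'training': {'learning_rate': 5e-05, 'num_train_steps': 500}}
```

The sketch deep-copies the base config, so the YAML defaults stay untouched and each run's effective config can be logged separately.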
Setup
Quick Setup
```bash
# Install with vLLM
./tools/build_conda.sh 1 vllm
micromamba activate torchspec

# Or install with SGLang
./tools/build_conda.sh
micromamba activate torchspec
```
To install into your current environment instead:
```bash
./tools/build_conda.sh current sglang  # or 'vllm' or 'both'
```
Optional: install Flash Attention support:
```bash
pip install -e ".[fa]"
```
Backend-Specific Usage
vLLM
```bash
./examples/qwen3-8b-single-node/run.sh --config configs/vllm_qwen3_8b.yaml
```
SGLang
```bash
./examples/qwen3-8b-single-node/run.sh
```
TorchSpec uses vLLM's Worker Extension mechanism to hook into the model forward pass and capture hidden states directly inside worker processes, which avoids RPC serialization overhead during extraction. For SGLang, TorchSpec patches the SGLang codebase to enable hidden-state extraction.
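A toy model can show the idea of capturing hidden states inside the worker process instead of sending them back through RPC. This plain-Python sketch is not vLLM's Worker Extension API. `ToyModel` and `install_capture` are hypothetical, and the captured layer indices are arbitrary. In real PyTorch code, the analogous mechanism is `torch.nn.Module.register_forward_hook`.

```python
class ToyModel:
    """Hypothetical stand-in for a target model: a stack of per-layer functions."""

    def __init__(self, layers):
        self.layers = list(layers)

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


def install_capture(model, layer_ids):
    """Hypothetical: wrap selected layers so their outputs are recorded during forward()."""
    captured = {}
    for idx in layer_ids:
        original = model.layers[idx]

        # Default arguments bind the current idx/original, avoiding the late-binding closure bug.
        def wrapped(x, idx=idx, original=original):
            out = original(x)
            captured[idx] = out  # stays in this process; nothing is serialized
            return out

        model.layers[idx] = wrapped
    return captured


model = ToyModel([lambda x: x + 1, lambda x: x * 2, lambda x: x - 3])
hidden = install_capture(model, layer_ids=[0, 2])
output = model.forward(5)
print(output, hidden)  # 9 {0: 6, 2: 9}
```

The forward pass itself is unchanged. The wrappers only record intermediate outputs as a side effect, which is what makes in-process extraction cheap.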
Examples
| Example | Backend | Model |
|---|---|---|
| hf-quickstart | HuggingFace | Qwen3-8B |
| qwen3-8b-single-node | Inference engine | Qwen3-8B |
| kimi-k25-2node-h200 | Inference engine | Kimi-K2.5 |
| kimi-k25-3node-h100 | Inference engine | Kimi-K2.5 |
| minimax-m25-5node-h200 | Inference engine | MiniMax-M2.5 |
See examples/README.md for more details about each example.
Training Modes
Resume vs. Continual Training
Both modes use training.load_path, but they restore different states:
| Goal | training.load_path | training.continual_training | What gets restored |
|---|---|---|---|
| Resume an interrupted run | Required | false (default) | Model, optimizer, LR scheduler, RNG, and step metadata |
| Start a new run from existing weights | Required | true | Model weights only |
Resume the same run:
```yaml
training:
  load_path: /path/to/old_run/checkpoints
  output_dir: /path/to/old_run
```
Start a new run from existing weights:
```yaml
training:
  load_path: /path/to/old_run/checkpoints
  continual_training: true
  learning_rate: 1e-5
  warmup_ratio: 0.01
  num_epochs: 1
  output_dir: /path/to/new_run
```
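The resume-versus-continual table reduces to a small decision rule. The sketch below assumes a config dict shaped like the YAML examples above. `states_to_restore` and the state labels are hypothetical names for the components listed in the table, not TorchSpec identifiers.

```python
# Hypothetical labels for the state components listed in the resume/continual table.
FULL_RESUME = ("model", "optimizer", "lr_scheduler", "rng", "step_metadata")
WEIGHTS_ONLY = ("model",)


def states_to_restore(training_cfg):
    """Hypothetical: map training.load_path / training.continual_training to restored state."""
    if not training_cfg.get("load_path"):
        return ()  # no checkpoint given: nothing to load
    if training_cfg.get("continual_training", False):
        return WEIGHTS_ONLY  # new run seeded from existing weights
    return FULL_RESUME  # continue the interrupted run where it stopped


resume_cfg = {"load_path": "/path/to/old_run/checkpoints", "output_dir": "/path/to/old_run"}
continual_cfg = {
    "load_path": "/path/to/old_run/checkpoints",
    "continual_training": True,
    "learning_rate": 1e-5,
    "output_dir": "/path/to/new_run",
}
print(states_to_restore(resume_cfg))     # ('model', 'optimizer', 'lr_scheduler', 'rng', 'step_metadata')
print(states_to_restore(continual_cfg))  # ('model',)
```

Continual training does not restore the optimizer or scheduler, so presumably the new run's learning_rate and warmup_ratio take effect from step zero.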
Checkpoint Conversion
Convert an FSDP checkpoint to HuggingFace format:
```bash
python tools/convert_to_hf.py --input-dir ./outputs/my_experiment/iter_0010000/
```
Vocabulary pruning, which reduces the draft model's lm_head to a smaller token set and emits d2t and t2d mappings, can be applied either during training or at conversion time.
- Pre-pruning: set draft_vocab_size in your training config. The checkpoint already contains the pruned lm_head and d2t/t2d buffers, so the basic conversion command is enough.
- Post-pruning: train with the full vocabulary, then pass --prune-vocab at conversion time together with a representative dataset to compute token frequencies.
```bash
python tools/convert_to_hf.py \
  --input-dir ./outputs/my_experiment/iter_0010000/ \
  --prune-vocab \
  --dataset-path Aeala/ShareGPT_Vicuna_unfiltered \
  --draft-vocab-size 32000 \
  --tokenizer Qwen/Qwen3-8B \
  --chat-template qwen \
  --prompt-key conversations
```
Pass --cache-dir ./cache to reuse the tokenized dataset cache from training.
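Frequency-based pruning keeps the draft_vocab_size most frequent target tokens seen in the dataset. The sketch below illustrates the general idea under several assumptions:

- Samples are already tokenized.
- Frequency ties are broken by token id.
- d2t maps each draft index to an absolute target token id.
- t2d is a boolean mask over the target vocabulary.

`build_vocab_mapping` is hypothetical, and TorchSpec's actual buffer layout may differ. For example, some EAGLE-3 implementations store d2t as an offset (target_id - draft_id) rather than as an absolute id.

```python
from collections import Counter


def build_vocab_mapping(tokenized_samples, target_vocab_size, draft_vocab_size):
    """Hypothetical: pick the most frequent target tokens and build draft<->target lookups.

    Returns (d2t, t2d), under this sketch's assumptions:
      d2t[i] -> absolute target token id for draft index i
      t2d[j] -> True if target token j is kept in the draft vocabulary
    """
    if draft_vocab_size > target_vocab_size:
        raise ValueError("draft vocabulary cannot be larger than the target vocabulary")
    counts = Counter(tok for sample in tokenized_samples for tok in sample)
    # Sort by descending frequency, then by token id so ties are deterministic.
    ranked = sorted(counts, key=lambda tok: (-counts[tok], tok))
    kept = ranked[:draft_vocab_size]
    if len(kept) < draft_vocab_size:
        # Sketch choice: pad with unseen ids so the draft head keeps a fixed size.
        unseen = (t for t in range(target_vocab_size) if t not in counts)
        kept += [next(unseen) for _ in range(draft_vocab_size - len(kept))]
    d2t = sorted(kept)
    kept_set = set(d2t)
    t2d = [tok in kept_set for tok in range(target_vocab_size)]
    return d2t, t2d


samples = [[5, 3, 5, 9], [3, 5, 1], [5, 7]]  # toy token ids
d2t, t2d = build_vocab_mapping(samples, target_vocab_size=10, draft_vocab_size=3)
print(d2t)       # [1, 3, 5]
print(sum(t2d))  # 3
```

In this toy run, token 5 appears 4 times and token 3 appears twice. Tokens 1, 7, and 9 tie at one occurrence each, and the id tie-break selects 1. This also shows why a representative dataset matters: the kept set is only as good as the frequency estimate.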
Metrics Reporting
W&B logging is disabled by default with report_to: none. To enable it, set report_to: wandb in your config and provide your API key.
Troubleshooting
Set TORCHSPEC_LOG_LEVEL=DEBUG for more verbose logging when diagnosing issues:
```bash
TORCHSPEC_LOG_LEVEL=DEBUG ./examples/qwen3-8b-single-node/run.sh
```
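An environment variable like TORCHSPEC_LOG_LEVEL typically drives Python's standard logging module. The sketch below shows one way to do that. The INFO fallback, the `configure_logging` helper, and the "demo" logger name are assumptions for illustration, not TorchSpec's actual implementation.

```python
import logging
import os


def configure_logging(env_var="TORCHSPEC_LOG_LEVEL", default="INFO"):
    """Hypothetical: set the root logger level from an environment variable."""
    name = os.environ.get(env_var, default).upper()
    level = getattr(logging, name, None)
    if not isinstance(level, int):
        # Unknown names such as "VERBOSE" fall back to the default level.
        level = getattr(logging, default.upper())
    # basicConfig only installs a stderr handler the first time it is called.
    logging.basicConfig(format="%(asctime)s %(levelname)s %(name)s: %(message)s")
    logging.getLogger().setLevel(level)
    return level


if __name__ == "__main__":
    os.environ.setdefault("TORCHSPEC_LOG_LEVEL", "DEBUG")
    configure_logging()
    logging.getLogger("demo").debug("visible only at DEBUG level")
```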
Mooncake SEGFAULT
The current Mooncake release has a bug that causes a segmentation fault (SEGFAULT) on TCP-only hosts. As a workaround, set MC_STORE_MEMCPY=0 until the upstream issue is fixed.
Per-Rank File Logging
Set TORCHSPEC_LOG_DIR to an absolute path on a shared filesystem (e.g., NFS) to enable per-rank log files for every Ray actor on both the training and inference sides:
```bash
export TORCHSPEC_LOG_DIR=/my_project/running_logs
```
This creates a structured directory with one file per actor, organized by role and node:
```
running_logs/
  training/
    10.0.0.1/
      training_g0_rank0_20260301_080012.log
      training_g0_rank1_20260301_080012.log
    10.0.0.2/
      training_g0_rank2_20260301_080013.log
  inference/
    10.0.0.1/
      inference_g0_rank0_20260301_080014.log
    10.0.0.2/
      inference_g0_rank1_20260301_080015.log
```
The path must be absolute and writable from all nodes. If TORCHSPEC_LOG_DIR is unset or not writable, per-rank file logging stays disabled and Ray falls back to stdout/stderr capture.
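The file names in the listing follow a {role}_g{group}_rank{rank}_{YYYYMMDD_HHMMSS}.log pattern under {log_dir}/{role}/{node_ip}/. The sketch below builds such a path and applies the absolute-path and writability checks described above. The pattern is inferred from the example listing. `per_rank_log_path` is hypothetical, and reading "g0" as a group index is an assumption.

```python
import os
import tempfile
from datetime import datetime


def per_rank_log_path(role, node_ip, group, rank, when=None, env=os.environ):
    """Hypothetical: return a per-rank log path, or None if file logging should stay off."""
    log_dir = env.get("TORCHSPEC_LOG_DIR")
    if not log_dir or not os.path.isabs(log_dir):
        return None  # unset or relative: keep Ray's stdout/stderr capture
    node_dir = os.path.join(log_dir, role, node_ip)
    try:
        os.makedirs(node_dir, exist_ok=True)
    except OSError:
        return None  # cannot create the directory from this node
    if not os.access(node_dir, os.W_OK):
        return None  # directory exists but is not writable
    stamp = (when or datetime.now()).strftime("%Y%m%d_%H%M%S")
    return os.path.join(node_dir, f"{role}_g{group}_rank{rank}_{stamp}.log")


if __name__ == "__main__":
    root = tempfile.mkdtemp()  # stands in for a shared NFS directory
    path = per_rank_log_path(
        "training", "10.0.0.1", group=0, rank=1,
        when=datetime(2026, 3, 1, 8, 0, 12), env={"TORCHSPEC_LOG_DIR": root},
    )
    print(os.path.relpath(path, root))  # on POSIX: training/10.0.0.1/training_g0_rank1_20260301_080012.log
```

Returning None rather than raising lets a caller fall back to console logging. This mirrors the behavior described above for an unset or unwritable directory.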
| Issue | Reference |
|---|---|
| Stuck or failing distributed runs, Ray actor errors | docs/debugging_ray_jobs.md |
| Ray cluster setup, actor hierarchy, placement groups | docs/ray.md |
| Pipeline bottlenecks, slow steps, throughput analysis | docs/performance_metrics.md |