March 20, 2026
# Think-at-Hard

Selective Latent Iterations to Improve Reasoning Language Models
Feel free to star the repo or cite the paper if you find it interesting.

```bibtex
@article{fu2025tah,
  title={Think-at-Hard: Selective Latent Iterations to Improve Reasoning Language Models},
  author={Tianyu Fu and Yichen You and Zekai Chen and Guohao Dai and Huazhong Yang and Yu Wang},
  journal={arXiv preprint arXiv:2510.08577},
  year={2025},
}
```
## News

- [2025/11] We released the TaH-plus-1.7B checkpoint. The model is fine-tuned from Qwen3-1.7B-Base on 100K samples from the OpenR1 dataset and is capable of QA, math, and coding.
- [2025/11] Our paper was featured as the #2 Paper of the Day on Hugging Face Daily Papers.
## Environment Setup

Create a new environment:

```shell
conda create -n tah python=3.10
conda activate tah
```

Install the package:

```shell
pip install -e .
```

For training and evaluation, install the additional dependencies:

```shell
pip install -e ".[training,evaluation]"
```

For code-generation evaluation, also install `evalplus`.

## Run an example for TaH

```shell
python script/playground/inference_example.py
```
This script demonstrates TaH's selective latent iteration mechanism, with color-coded output showing the iteration count for each token.
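To make the mechanism concrete, here is a minimal toy sketch of selective latent iteration, deliberately decoupled from the repo's actual API: every token gets one standard forward pass, and only tokens flagged as "hard" by a decider receive extra latent iterations, up to a cap. All names below are hypothetical illustrations, not functions from the `tah` package.

```python
# Toy sketch of selective latent iteration (hypothetical, not the repo's API).
# At each position, a decider marks the token "hard" or "easy"; hard tokens
# get extra latent iterations before moving on, easy tokens pass through once.

def decode_with_selective_iterations(tokens, is_hard, refine, max_iter=2):
    """Return (outputs, per-token iteration counts) for a toy token stream.

    tokens   : list of token values
    is_hard  : predicate playing the role of the iteration decider
    refine   : one latent refinement step applied to hard tokens
    max_iter : total forward passes allowed per token (TaH uses 2)
    """
    outputs, iters = [], []
    for tok in tokens:
        n_iter = 1  # every token gets one standard forward pass
        while n_iter < max_iter and is_hard(tok):
            tok = refine(tok)  # extra latent iteration at the same position
            n_iter += 1
        outputs.append(tok)
        iters.append(n_iter)
    return outputs, iters

# Example: pretend "hard" tokens are odd numbers; refinement rounds them up.
out, iters = decode_with_selective_iterations(
    [2, 3, 4, 7], is_hard=lambda t: t % 2 == 1, refine=lambda t: t + 1)
print(out)    # [2, 4, 4, 8]
print(iters)  # [1, 2, 1, 2]
```

The per-token `iters` list corresponds to the color-coded iteration counts the example script prints.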
## Run evaluation

### Evaluate a TaH model

```shell
python script/evaluation/eval.py \
  --eval_config ./script/recipes/qwen3_1.7/eval_tah.yaml \
  --model_path nics-efc/TaH-plus-1.7B \
  --dataset_name gsm8k \
  --backend tah \
  --job_nums 8 \
  --tp_size_per_job 1
```
Key parameters:

- `--eval_config`: Path to the evaluation config file
- `--model_path`: Path to the model
- `--dataset_name`: Dataset name (supports `gsm8k`, `math500`, `aime24`, etc.; detailed configs are in `tah/evaluate/eval_configs/dataset_configs.json`)
- `--backend`: Inference backend (`tah` for TaH)
- `--job_nums`: Number of parallel jobs
- `--tp_size_per_job`: Tensor-parallel size per job
### Evaluate a standard baseline model

```shell
python script/evaluation/eval.py \
  --eval_config ./script/recipes/qwen3_1.7/eval_base.yaml \
  --model_path nics-efc/Standard-1.7B \
  --dataset_name gsm8k \
  --backend hf \
  --job_nums 8 \
  --tp_size_per_job 1
```

Similar to TaH evaluation, but using `--backend hf` or `--backend sglang`.
## Train your own TaH model

Training a TaH model consists of three stages:

### Step 0: Prepare model and data

#### 1. Prepare training data
Use a reference model to generate hard token labels for the training and validation data:
```shell
### step 0
# download the default subset of OpenR1-Math-220k
python script/preparation/download.py

# filter and split
python script/preparation/filter_split.py

# label the hard tokens
python script/preparation/label.py \
  --num_gpu 8 \
  --dataset_path ./data/initial_data/openr1-math/train.jsonl \
  --test_model_list Qwen/Qwen3-1.7B \
  --output_path ./data/processed_data/openr1-math/1_7/train \
  --max_input_length 10000

python script/preparation/label.py \
  --num_gpu 8 \
  --dataset_path ./data/initial_data/openr1-math/eval.jsonl \
  --test_model_list Qwen/Qwen3-1.7B \
  --output_path ./data/processed_data/openr1-math/1_7/eval \
  --max_input_length 10000
```
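The idea behind hard-token labeling can be sketched as follows: a position counts as "hard" when the reference model's greedy next-token prediction under teacher forcing disagrees with the ground-truth continuation (the `mismatch` field used later in training). The function below is a hypothetical toy stand-in for `label.py`, not its actual implementation.

```python
# Toy sketch of mismatch labeling (hypothetical stand-in for label.py).
# A position is labeled hard (1) when the reference model's greedy
# next-token prediction disagrees with the ground-truth token there.

def mismatch_labels(ground_truth, predict_next):
    """Return one 0/1 label per position: 1 = hard (prediction mismatched).

    ground_truth : list of token ids for the target sequence
    predict_next : fn(prefix) -> predicted next token id (reference model)
    """
    labels = []
    for i, tok in enumerate(ground_truth):
        pred = predict_next(ground_truth[:i])  # teacher-forced prefix
        labels.append(int(pred != tok))
    return labels

# Example with a fake "model" that always predicts previous token + 1.
seq = [5, 6, 9, 10]
labels = mismatch_labels(seq, lambda prefix: (prefix[-1] + 1) if prefix else 5)
print(labels)  # [0, 0, 1, 0]
```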
#### 2. (Optional) Prepare pruned model

For the TaH version, prune one layer from the base model to match the parameter count of the standard baseline (skip this step for the TaH+ version):

```shell
### step 0
python script/preparation/prune.py \
  --model Qwen/Qwen3-1.7B-Base \
  --dataset ./data/processed_data/openr1-math/1_7/eval \
  --output ./model/qwen3_1.7_base_pruned \
  --num_prune 1
```
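As a rough mental model of one-layer pruning, one can score each decoder layer by how much the held-out loss degrades when it is removed, and drop the least damaging layer. This selection criterion is an assumption for illustration; `prune.py` may use a different importance metric.

```python
# Toy sketch of one-layer pruning (hypothetical criterion; prune.py's actual
# importance metric may differ). Drop the layer whose removal changes a
# held-out evaluation loss the least.

def prune_one_layer(layers, loss_without):
    """layers: list of layer objects; loss_without(i) -> eval loss with layer i removed."""
    # pick the layer whose removal hurts the evaluation loss least
    victim = min(range(len(layers)), key=loss_without)
    return [l for i, l in enumerate(layers) if i != victim], victim

# Example: pretend per-layer removal losses are precomputed.
losses = {0: 2.4, 1: 2.1, 2: 3.0, 3: 2.6}
pruned, dropped = prune_one_layer(["l0", "l1", "l2", "l3"], losses.__getitem__)
print(dropped, pruned)  # 1 ['l0', 'l2', 'l3']
```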
### Step 1: Train with Fixed Iteration Labels

The first stage uses fixed iteration labels for training:

```shell
### step 1
python -m accelerate.commands.launch \
  --config_file ./script/recipes/accelerate_configs/zero2.yaml \
  --num_processes 8 \
  ./script/train/SFT_TaH.py \
  --config ./script/recipes/qwen3_1.7/sft_tah_step1.yaml
```
Key configurations in Step 1 (`sft_tah_step1.yaml`):

- `max_iter: 2`: Maximum number of iterations
- `iter_decider: "FixedLabelIterDecider"`: Use fixed labels to decide iterations
- `iter_label_generator: "FixedIterLabelGenerator"`: Generate labels from the `mismatch` field in the data
- `input_updater: "AdditiveUpdater"`: Use the additive updater for input updates
- `adapter: "lora"`: Use a LoRA adapter for the deeper iteration
- `train_loss: "NextTokenPredLoss"`: Next-token prediction loss
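The interaction between the fixed labels and the additive update can be sketched in a few lines: only positions the mismatch labels mark as hard get a second-iteration input, formed here by adding the previous pass's hidden state to the original input. This is a hypothetical toy, one plausible reading of "additive"; the repo's `AdditiveUpdater` may combine different tensors.

```python
# Toy sketch of building second-iteration inputs under fixed labels
# (hypothetical; names and the exact additive combination are assumptions).
# Only positions labeled hard by the mismatch labels get a second pass.

def second_iter_inputs(embeddings, hidden_states, hard_labels):
    """Return {position: second-pass input vector} for hard positions only."""
    second_pass = {}
    for i, hard in enumerate(hard_labels):
        if hard:  # FixedLabelIterDecider-style: iterate exactly where labels say
            # additive-style input: first-pass input plus its hidden state
            second_pass[i] = [e + h for e, h in
                              zip(embeddings[i], hidden_states[i])]
    return second_pass

emb = [[1.0, 0.0], [0.5, 0.25]]
hid = [[0.2, 0.3], [0.25, 0.5]]
print(second_iter_inputs(emb, hid, [0, 1]))  # {1: [0.75, 0.75]}
```

The second pass would then run through the LoRA-adapted weights named by `adapter: "lora"`, while easy positions skip it entirely.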
### Step 2: Train Iteration Decider

The second stage trains the iteration decider:

```shell
### step 2
python -m accelerate.commands.launch \
  --config_file ./script/recipes/accelerate_configs/zero2.yaml \
  --num_processes 8 \
  ./script/train/SFT_TaH.py \
  --config ./script/recipes/qwen3_1.7/sft_tah_step2.yaml
```
Key configurations in Step 2 (`sft_tah_step2.yaml`):

- `tah_model_path`: Load the model trained in Step 1
- `iter_decider: "MLPIterDecider"`: Use the MLP decider to automatically determine iterations
- `train_loss: "IterDeciderLoss"`: Iteration-decider loss function
- `freeze_component: [model.simple_base_model]`: Freeze the model backbone
After two-stage training, the model can automatically decide when to perform latent reasoning iterations.
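Conceptually, the learned decider is a small classifier over the hidden state: it outputs a probability that the current token needs another latent iteration, thresholded at inference time. The sketch below is a self-contained toy with assumed shapes and threshold, not the repo's `MLPIterDecider`.

```python
# Toy sketch of an MLP iteration decider (hypothetical shapes and threshold;
# the repo's MLPIterDecider may differ). It maps a hidden state to a
# "needs another iteration" probability and thresholds it.
import math

def mlp_decider(hidden, w1, w2, threshold=0.5):
    """hidden: list[float]; w1: hidden-layer weight rows; w2: output weights."""
    # one hidden layer with ReLU, then a sigmoid over a scalar logit
    acts = [max(0.0, sum(h * w for h, w in zip(hidden, row))) for row in w1]
    logit = sum(a * w for a, w in zip(acts, w2))
    prob = 1.0 / (1.0 + math.exp(-logit))
    return prob > threshold, prob

# Example with tiny hand-picked weights: this hidden state triggers iteration.
decide, p = mlp_decider([1.0, -0.5], w1=[[2.0, 0.0], [0.0, 2.0]], w2=[1.0, 1.0])
print(decide)  # True
```

In Step 2 only this head (and the loss it feeds) is trained, since `freeze_component` keeps the backbone fixed.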
## Understand the Code

### Code Structure

```
TaH/
├── tah/                        # Core package
│   ├── model/                  # Core model components
│   ├── train/                  # Training components
│   ├── evaluate/               # Evaluation utilities
│   └── utils/                  # General utilities
├── bash/                       # Bash scripts for training and evaluation
├── script/                     # Execution scripts
│   ├── analysis/               # Analysis scripts
│   ├── evaluation/             # Evaluation scripts
│   ├── preparation/            # Preparation for training
│   │   ├── label.py            # Data labeling (generate mismatch labels)
│   │   └── prune.py            # Model pruning
│   ├── playground/             # Some examples
│   └── recipes/                # Configuration files
│       ├── qwen3_0.6/          # Qwen3-0.6B-Base configs
│       ├── qwen3_1.7/          # Qwen3-1.7B-Base configs
│       └── accelerate_configs/ # Distributed training configs
└── pyproject.toml              # Project configuration
```
## Future Work
- Support more inference backends (e.g., SGLang)
- Optimize iteration decision strategies
- Integrate TaH with online distillation or RL
- Support training for larger models
## Related Projects

Explore more efficient LLM projects from us:

| Project | Description |
| --- | --- |
| R2R | Token-level routing for reasoning LLMs |
| C2C | Communicate through KV-Cache between LLMs |
| FrF | Efficient video token reduction for LVLMs |
| MoA | Mixture of sparse attention for LLMs |