Learning to Focus: Causal Attention Distillation via Gradient-Guided Token Pruning


Code & data for the paper Learning to Focus: Causal Attention Distillation via Gradient-Guided Token Pruning [NeurIPS 2025]. We propose Learning to Focus (LeaF), a two-stage framework that treats distracting patterns as spurious confounders in LLM reasoning.

  1. Confounding Token Detection
    LeaF identifies confounding tokens through gradient-based comparisons between a high-capacity teacher and a student model.

    It then generates counterfactual samples by span pruning, i.e., removing contiguous spans of the detected confounding tokens from each instruction (a minimal sketch follows this list).

  2. Causal Attention Distillation
    LeaF introduces a hybrid distillation loss that minimizes two KL divergences: one on the original sample (standard distillation) and one on the counterfactual sample (counterfactual distillation).
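
Below is a minimal, illustrative sketch of Stage 1. It assumes gradient-norm saliency over input embeddings, a shared tokenizer between teacher and student, and a simple top-k divergence criterion; the exact scoring lives in `src/Confounding_token_detection/`, and every helper name here is ours. The default `ratio=0.10` merely echoes the released data file `Distill_NuminaMATH_llama_1b_misleading_0.10.json`.

```python
import torch

def token_saliency(model, tokenizer, instruction, response):
    """Gradient-norm saliency of the response loss w.r.t. each input token."""
    enc = tokenizer(instruction + response, return_tensors="pt")
    embeds = model.get_input_embeddings()(enc.input_ids).detach().requires_grad_(True)
    labels = enc.input_ids.clone()
    n_instr = len(tokenizer(instruction).input_ids)
    labels[:, :n_instr] = -100  # supervise only the response span
    model(inputs_embeds=embeds, labels=labels).loss.backward()
    return embeds.grad.norm(dim=-1).squeeze(0)  # one score per token

def detect_confounders(teacher, student, tokenizer, instruction, response, ratio=0.10):
    """Flag the instruction tokens whose student saliency most exceeds the
    teacher's (an assumed divergence criterion, not the paper's exact one)."""
    gap = (token_saliency(student, tokenizer, instruction, response)
           - token_saliency(teacher, tokenizer, instruction, response))
    n_instr = len(tokenizer(instruction).input_ids)
    k = max(1, int(ratio * n_instr))
    return torch.topk(gap[:n_instr], k).indices.sort().values  # positions to prune
```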


News

  • [2025.10.24] We updated the camera-ready version on arXiv and added LeaF experiments on multi-hop QA tasks.
  • [2025.09.28] 🎉 Our paper was accepted at NeurIPS 2025.
  • [2025.08.04] We uploaded the code for confounding token detection and causal attention distillation.
  • [2025.06.09] We released our paper on arXiv.

Evaluation Results

| Model | GSM8K | MATH | Olympiad-Bench | MathBench Avg. | HumanEval+ | LeetCode | LiveCodeBench | CodeBench Avg. |
|---|---|---|---|---|---|---|---|---|
| **Teacher Model** | | | | | | | | |
| LLaMA3.3-70B-Instruct | 95.60 | 70.40 | 36.50 | 67.50 | 78.05 | 53.90 | 45.02 | 58.99 |
| Qwen2.5-72B-Instruct | 95.45 | 73.80 | 41.25 | 70.17 | 81.71 | 69.40 | 54.42 | 68.51 |
| **LLaMA3.2-1B-Instruct** | | | | | | | | |
| Instruct Model (Pre-KD) | 44.88 | 24.20 | 5.79 | 24.96 | 29.27 | 7.22 | 9.68 | 15.39 |
| KD w/o Mask | 56.79 | 33.40 | 8.90 | 33.03 | 32.32 | 6.11 | 13.74 | 17.39 |
| LeaF (Instr Mask) | 57.70 | 35.40 | 10.09 | 34.40 | 39.02 | 6.67 | 13.60 | 19.76 |
| LeaF (Instr & Resp Mask) | 58.98 | 35.20 | 9.94 | 34.71 | 39.63 | 7.22 | 12.48 | 19.77 |
| **LLaMA3.2-3B-Instruct** | | | | | | | | |
| Instruct Model (Pre-KD) | 76.88 | 42.80 | 13.20 | 44.29 | 48.78 | 13.89 | 20.34 | 27.67 |
| KD w/o Mask | 82.87 | 49.00 | 18.99 | 50.29 | 54.88 | 16.67 | 24.12 | 31.89 |
| LeaF (Instr Mask) | 83.09 | 51.80 | 20.77 | 51.88 | 55.49 | 19.44 | 25.39 | 33.44 |
| LeaF (Instr & Resp Mask) | 84.69 | 52.40 | 22.55 | 53.21 | 56.10 | 21.67 | 25.81 | 34.53 |
| **Qwen2.5-Math-1.5B** | | | | | | | | |
| Base Model (Pre-KD) | 65.20 | 41.40 | 21.96 | 42.85 | 35.37 | 6.67 | 1.26 | 14.43 |
| KD w/o Mask | 82.18 | 67.80 | 31.16 | 60.38 | 41.46 | 7.78 | 10.10 | 19.78 |
| LeaF (Instr Mask) | 84.69 | 68.60 | 32.79 | 62.03 | 42.68 | 9.94 | 10.80 | 20.97 |
| LeaF (Instr & Resp Mask) | 85.29 | 70.60 | 31.75 | 62.54 | 43.29 | 9.94 | 13.04 | 21.92 |

Installation

pip install -r requirements.txt

Repository Structure

Code_for_LeaF/
├── data
│   ├── llama_1b_instruct_level
│   │   └── Distill_NuminaMATH_llama_1b_misleading_0.10.json
│   ├── llama_1b_response_level
│   │   └── Distill_NuminaMath_llama_1b_misleading_step_0.075.json
│   └── llama_1b_gradient_data
│       └── Numina_train_data_llama_1b_1.2w.json
├── scripts
│   ├── Confounding_token_detection
│   │   ├── gradient_comparison
│   │   │   ├── run_instruct_level_llama.sh
│   │   │   └── run_instruct_level_qwen.sh
│   │   └── remove_confounding_tokens
│   │       ├── run_code.sh
│   │       └── run_math.sh
│   └── Causal_attention_distillation
│       ├── instruct_level_distillation
│       │   ├── run_llama3.2_1b_instruct.sh
│       │   ├── run_llama3.2_3b_instruct.sh
│       │   └── run_qwen2.5_1.5b_math.sh
│       └── response_level_distillation
│           ├── run_llama3.2_1b_response.sh
│           ├── run_llama3.2_3b_response.sh
│           └── run_qwen2.5_1.5b_response.sh
├── src
│   ├── Causal_attention_distillation
│   │   ├── LeaF_instruct_level
│   │   │   └── distill_instruct_level.py
│   │   └── LeaF_response_level
│   │       └── distill_response_level.py
│   └── Confounding_token_detection
│       ├── run_gradients_llama_instruct_level.py
│       ├── run_gradients_llama_response_level.py
│       ├── run_gradients_qwen_instruct_level.py
│       ├── run_gradients_qwen_response_level.py
│       ├── remove_confounding_tokens_code.py
│       └── remove_confounding_tokens_math.py
├── requirements.txt
└── README.md


Usage

1. Gradient-Based Comparison

Compare teacher and student gradients to detect confounding (misleading) tokens.

  • Instruction-Level

    bash scripts/Confounding_token_detection/gradient_comparison/run_instruct_level_llama.sh  
    bash scripts/Confounding_token_detection/gradient_comparison/run_instruct_level_qwen.sh
    
  • Response-Level

    bash scripts/Confounding_token_detection/gradient_comparison/run_response_level_llama.sh  
    bash scripts/Confounding_token_detection/gradient_comparison/run_response_level_qwen.sh
    

2. Remove Confounding Tokens

Prune the identified confounding spans from the training corpus to build counterfactual samples (a sketch of the pruning step follows the commands below).

  • Code Corpus

    bash scripts/Confounding_token_detection/remove_confounding_tokens/run_code.sh
    
  • Math Corpus

    bash scripts/Confounding_token_detection/remove_confounding_tokens/run_math.sh
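
For intuition, a hypothetical sketch of the pruning step: nearby flagged positions are merged into contiguous spans, and every token a span covers is dropped, yielding the counterfactual instruction. The real logic is in `remove_confounding_tokens_math.py` / `remove_confounding_tokens_code.py`; both helpers below (including the `max_gap` merging heuristic) are ours.

```python
def merge_into_spans(positions, max_gap=1):
    """Merge nearby flagged token positions into contiguous (start, end) spans."""
    spans, start, prev = [], None, None
    for p in sorted(int(p) for p in positions):
        if start is None:
            start = prev = p
        elif p - prev <= max_gap:  # close enough: extend the current span
            prev = p
        else:
            spans.append((start, prev))
            start = prev = p
    if start is not None:
        spans.append((start, prev))
    return spans

def prune_spans(token_ids, spans):
    """Drop every token covered by a span, yielding the counterfactual input."""
    covered = {i for s, e in spans for i in range(s, e + 1)}
    return [t for i, t in enumerate(token_ids) if i not in covered]
```

For example, `merge_into_spans([3, 4, 7])` returns `[(3, 4), (7, 7)]`, and `prune_spans` then deletes positions 3, 4, and 7.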
    

3. Causal Attention Distillation

Train the student with the hybrid distillation loss, aligning its attention with the teacher's on both original and counterfactual samples (a loss sketch follows the commands below).

  • Instruction-Level Distillation

    bash scripts/Causal_attention_distillation/instruct_level_distillation/run_llama3.2_1b_instruct.sh  
    bash scripts/Causal_attention_distillation/instruct_level_distillation/run_llama3.2_3b_instruct.sh  
    bash scripts/Causal_attention_distillation/instruct_level_distillation/run_qwen2.5_1.5b_math.sh
    
  • Instruction + Response-Level Distillation

    bash scripts/Causal_attention_distillation/response_level_distillation/run_llama3.2_1b_response.sh  
    bash scripts/Causal_attention_distillation/response_level_distillation/run_llama3.2_3b_response.sh  
    bash scripts/Causal_attention_distillation/response_level_distillation/run_qwen2.5_1.5b_response.sh
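
Schematically, these scripts minimize KL(teacher ‖ student) on the original sample plus a weighted KL on its counterfactual. A minimal sketch of that objective, with `alpha` (mixing weight) and `tau` (temperature) as hypothetical knobs; see `distill_instruct_level.py` / `distill_response_level.py` for the exact loss:

```python
import torch.nn.functional as F

def hybrid_kd_loss(s_logits, t_logits, cf_s_logits, cf_t_logits, alpha=0.5, tau=1.0):
    """Standard KD on the original sample + counterfactual KD on the pruned one."""
    def kd(student, teacher):
        # KL(teacher || student) over the vocabulary, averaged across positions.
        return F.kl_div(F.log_softmax(student / tau, dim=-1),
                        F.softmax(teacher / tau, dim=-1),
                        reduction="batchmean") * tau ** 2
    return kd(s_logits, t_logits) + alpha * kd(cf_s_logits, cf_t_logits)
```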
    

Acknowledgments

Our code builds mainly on alpaca-lora, AceCoder, and Step-DPO. We sincerely thank the authors for open-sourcing their work!

Citation

If you find our work helpful, please cite it as:

@article{guo2025learning,
  title={Learning to Focus: Causal Attention Distillation via Gradient-Guided Token Pruning},
  author={Guo, Yiju and Yang, Wenkai and Sun, Zexu and Ding, Ning and Liu, Zhiyuan and Lin, Yankai},
  journal={arXiv preprint arXiv:2506.07851},
  year={2025}
}