Pre-Training Curriculum for Multi-Token Prediction in Language Models

May 28, 2025

This repository contains the training scripts used for the experiments in the paper "Pre-Training Curriculum for Multi-Token Prediction in Language Models". The paper was accepted to the ACL 2025 main conference.

📄 Paper Abstract

Multi-token prediction (MTP) is a recently proposed pre-training objective for language models. Rather than predicting only the next token (NTP), MTP predicts the next k tokens at each prediction step, using multiple prediction heads. MTP has shown promise in improving downstream performance, inference speed, and training efficiency, particularly for large models. However, prior work has shown that smaller language models (SLMs) struggle with the MTP objective.

To address this, we propose a curriculum learning strategy for MTP training, exploring two variants:

  • Forward Curriculum: Gradually increases the complexity of the pre-training objective from NTP to MTP.
  • Reverse Curriculum: Starts with MTP and gradually simplifies to NTP.

[Figure: mtp_illustr-2]

Our experiments show that the forward curriculum enables SLMs to better leverage the MTP objective, improving downstream NTP performance and generative output quality while retaining the benefits of self-speculative decoding. The reverse curriculum achieves stronger NTP performance and output quality but fails to provide self-speculative decoding benefits.

Usage

The provided training scripts integrate multi-token prediction and curriculum learning into Hugging Face's implementation of an LLM training pipeline, specifically for the Llama model family. The goal was to facilitate easy and flexible experimentation rather than to prioritize implementation efficiency. Notably, the forward/backward pass is not performed sequentially per prediction head, so the memory-efficient scheme proposed by Gloeckle et al. (2024) is not implemented here.
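To make the objective concrete, the following is a minimal, dependency-free sketch of how MTP targets and a combined loss could be computed for k heads. It is not taken from the repository's code: the names `mtp_targets`, `cross_entropy`, and `mtp_loss` are hypothetical, the `-100` ignore marker simply mirrors a common PyTorch convention, and averaging over all valid (head, position) pairs is an assumed weighting.

```python
import math

IGNORE = -100  # assumed ignore marker for positions without a target


def mtp_targets(tokens, k):
    """Build targets for k heads: head i at position t predicts tokens[t + i + 1]."""
    n = len(tokens)
    return [
        [tokens[t + i + 1] if t + i + 1 < n else IGNORE for t in range(n)]
        for i in range(k)
    ]


def cross_entropy(logits, target):
    """Negative log-likelihood of `target` under softmax(logits)."""
    m = max(logits)  # subtract the max for numerical stability
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_z - logits[target]


def mtp_loss(head_logits, tokens):
    """Average cross-entropy over all valid (head, position) pairs.

    head_logits[i][t] is the logit vector produced by head i at position t.
    The equal weighting of heads is an assumption of this sketch.
    """
    k = len(head_logits)
    targets = mtp_targets(tokens, k)
    total, count = 0.0, 0
    for i in range(k):
        for t, target in enumerate(targets[i]):
            if target == IGNORE:
                continue  # no token exists that far ahead
            total += cross_entropy(head_logits[i][t], target)
            count += 1
    return total / count if count else 0.0
```

With k = 1 this reduces to the ordinary next-token objective, which is why NTP can be treated as the simplest stage of an MTP curriculum.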

Prerequisites

  • Python 3.11+

Install the necessary dependencies via:

pip install -r requirements.txt

Training

To start training, use the provided train.sh script, which wraps train.py. This script allows you to specify various training configurations, including the curriculum strategy (forward or reverse), the number of prediction heads, and other hyperparameters.
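As an illustration of what a curriculum strategy means in practice, the sketch below maps a training step to the number of prediction heads that contribute to the loss. It is an assumption-laden example, not the schedule used by train.py: the function name `active_heads`, its parameters, and the split into equally long stages are all hypothetical.

```python
def active_heads(step, total_steps, max_heads, strategy="forward"):
    """Return how many prediction heads are active at a given training step.

    Assumption: training is divided into `max_heads` stages of equal length.
    "forward" grows from 1 head (NTP) to `max_heads` (full MTP);
    "reverse" shrinks from `max_heads` down to 1.
    """
    if strategy not in ("forward", "reverse"):
        raise ValueError(f"unknown strategy: {strategy!r}")
    if not 0 <= step < total_steps:
        raise ValueError("step must satisfy 0 <= step < total_steps")
    # Index of the current stage, in the range [0, max_heads - 1].
    stage = min(step * max_heads // total_steps, max_heads - 1)
    return stage + 1 if strategy == "forward" else max_heads - stage
```

For example, with 100 steps and 4 heads, the forward variant uses 1 head during steps 0 to 24 and all 4 heads from step 75 onward, while the reverse variant runs through the same stages in the opposite order.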

Inference

An example inference script is provided in inference_example.py. This script demonstrates how batch inference can be done using blockwise parallel decoding. Adjust the script as needed to suit your specific use case.
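The idea behind blockwise parallel decoding can be sketched as follows: the extra heads draft several tokens at once, the next-token head verifies them, and the longest matching prefix is kept. This is a simplified, hypothetical illustration and not the code in inference_example.py. The callables `propose` and `next_token` stand in for the model, and verification is written as sequential calls for clarity, whereas a real implementation would verify the whole block in one batched forward pass.

```python
def blockwise_decode(propose, next_token, prompt, max_new_tokens):
    """Greedy decoding with draft-and-verify steps.

    propose(prefix)    -> list of draft tokens, one per prediction head (assumed API)
    next_token(prefix) -> greedy token from the next-token head (assumed API)
    Returns the generated sequence and the number of decoding steps used.
    """
    out = list(prompt)
    steps = 0
    while len(out) - len(prompt) < max_new_tokens:
        draft = propose(out)
        steps += 1
        accepted = []
        for tok in draft:
            expected = next_token(out + accepted)
            if tok != expected:
                # First mismatch: keep the verified token and end the block.
                accepted.append(expected)
                break
            accepted.append(tok)
        if not accepted:
            # Guard against an empty draft so every step makes progress.
            accepted.append(next_token(out))
        out.extend(accepted)
    return out[: len(prompt) + max_new_tokens], steps
```

Because every accepted token is checked against the next-token head, the output matches plain greedy decoding; the gain is that a step can emit several tokens when the drafts are correct.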

๐Ÿ“ Repository Structure

  • model/: Directory containing model definitions and utilities.
  • train.py: Main training script.
  • train.sh: Shell script to facilitate training with various configurations.
  • inference_example.py: Example script for model inference.
  • requirements.txt: List of required Python packages.