EasyR1 on AWS Trainium

April 10, 2026

This is a port of EasyR1 to AWS Trainium (trn1 instances) using NeuronX Distributed (NxD) for training and vLLM + optimum-neuron for rollout inference. The original GPU-based FSDP training and vLLM-CUDA rollout are replaced with Neuron-native equivalents while preserving the same GRPO algorithm and Ray-based orchestration.

What Changed

New: NxD Training Workers (verl/workers/nxd/)

| File | Role |
| --- | --- |
| nxd_workers.py | NxDWorker: training worker (actor, ref, critic) using NeuronX Distributed for tensor-parallel training on Neuron cores. NxDRolloutOnlyWorker: separate rollout worker hosting the vLLM + optimum-neuron inference engine. |
| nxd_rollout.py | NxDRollout: wraps vLLM's LLM engine (with optimum-neuron as the Neuron platform plugin), manages the engine lifecycle (destroy-rebuild per step to avoid NRT memory leaks), and handles generation + logprob extraction. |
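
The destroy-rebuild lifecycle mentioned for NxDRollout can be pictured with a small sketch. This is not the repository's code: `EngineHandle`, `build_engine`, and `shutdown()` are hypothetical stand-ins for constructing and tearing down the vLLM engine. The only point is that each step gets a fresh engine and the old one is released before the next is built.

```python
import gc
from contextlib import contextmanager
from typing import Callable, Iterator, Protocol


class EngineHandle(Protocol):
    """Hypothetical minimal interface for an inference engine."""

    def generate(self, prompts: list[str]) -> list[str]: ...

    def shutdown(self) -> None: ...


@contextmanager
def fresh_engine(build_engine: Callable[[], EngineHandle]) -> Iterator[EngineHandle]:
    """Build an engine for one step and always tear it down afterwards."""
    engine = build_engine()
    try:
        yield engine
    finally:
        engine.shutdown()  # release whatever the engine holds
        del engine
        gc.collect()       # drop lingering Python references before the rebuild


def rollout_steps(build_engine, batches):
    """Run one generation pass per step, rebuilding the engine each time."""
    outputs = []
    for prompts in batches:
        with fresh_engine(build_engine) as engine:
            outputs.append(engine.generate(prompts))
    return outputs
```

Wrapping the lifecycle in a context manager means the teardown also runs when generation raises, so a failed step does not leave a stale engine behind.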

New: Neuron Utility Modules (verl/utils/)

| File | Role |
| --- | --- |
| neuron_logprob.py | Computes log-probabilities from TP-sharded logits on Neuron/XLA, avoiding NRT_FAILURE with torch.distributed.all_reduce in backward graphs. |
| nxd_weight_converter.py | Converts NxD training checkpoint shards to vLLM-compatible format for rollout weight sync. |
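
To make the idea behind neuron_logprob.py concrete, here is a framework-free sketch of the arithmetic for a token's log-probability when the vocabulary dimension of the logits is split across tensor-parallel ranks. It is an illustration only: the function name and the list-of-shards representation are assumptions, and the real module operates on Neuron/XLA tensors rather than Python lists.

```python
import math


def sharded_token_logprob(logit_shards: list[list[float]], target: int) -> float:
    """Log-probability of `target` given vocab-sharded logits for one position.

    logit_shards[r] holds the contiguous vocabulary slice owned by rank r.
    """
    if target < 0:
        raise IndexError("target id must be non-negative")

    # Step 1: global max for numerical stability (a max-reduce across ranks).
    global_max = max(max(shard) for shard in logit_shards)

    # Step 2: each rank sums exp(logit - max) over its slice; the partial
    # sums are then added together (a sum-reduce across ranks).
    total = sum(
        sum(math.exp(z - global_max) for z in shard) for shard in logit_shards
    )

    # Step 3: only the rank that owns `target` contributes the target logit.
    offset = 0
    for shard in logit_shards:
        if target < offset + len(shard):
            target_logit = shard[target - offset]
            break
        offset += len(shard)
    else:
        raise IndexError("target id outside the sharded vocabulary")

    # log softmax(z)[target] = z_target - max - log(sum(exp(z - max)))
    return target_logit - global_max - math.log(total)
```

The result equals the log-softmax over the full, unsharded vocabulary; sharding only changes where the max, the sum, and the target lookup are computed.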

Modified: Core Framework

| File | Change |
| --- | --- |
| verl/trainer/ray_trainer.py | Added separate rollout worker pool support (separate_rollout_workers), Trainium-aware core allocation, and weight sync via file copy + engine rebuild. |
| verl/workers/actor/dp_actor.py | Added XLA/Neuron detection and 3-phase detach+mark_step graph splitting in update_policy() to keep XLA compilation graphs small. |
| verl/single_controller/ray/base.py | Environment isolation between training and rollout worker pools (prevents NEURON_COMPILED_ARTIFACTS leaking to training workers). |
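
The environment isolation between worker pools can be sketched as two small helpers that derive per-pool environments from a shared base. The helper names and the idea of passing the result to each pool are assumptions for illustration; only the variable name NEURON_COMPILED_ARTIFACTS comes from the description above.

```python
from typing import Iterable, Mapping

# Variables that only the rollout pool should see.
ROLLOUT_ONLY_VARS = ("NEURON_COMPILED_ARTIFACTS",)


def training_env(
    base_env: Mapping[str, str],
    blocked: Iterable[str] = ROLLOUT_ONLY_VARS,
) -> dict[str, str]:
    """Return a copy of base_env without rollout-only variables."""
    blocked_set = set(blocked)
    return {k: v for k, v in base_env.items() if k not in blocked_set}


def rollout_env(base_env: Mapping[str, str], compiled_dir: str) -> dict[str, str]:
    """Return a copy of base_env that points at the compiled rollout model."""
    env = dict(base_env)
    env["NEURON_COMPILED_ARTIFACTS"] = compiled_dir
    return env
```

Both helpers return new dictionaries, so building one pool's environment never mutates the other's.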

New: Scripts and Config

| File | Role |
| --- | --- |
| examples/config_trainium.yaml | Full training config for Trainium (Qwen3-4B, GRPO, TP=8). |
| examples/run_qwen3_4b_trainium.sh | Launch script: sets up env vars, starts Ray, runs training. |
| scripts/export_qwen3_neuron_rollout.py | Exports HuggingFace model to Neuron-compiled NEFF format for vLLM rollout. |
| scripts/compile_neuron_tp8_bs64.sh | Wrapper script for NEFF compilation with TP=8, batch_size=64. |

Prerequisites

  • Instance: trn1.32xlarge (32 NeuronCores, 512 GB HBM)
  • Neuron SDK: 2.21+ (neuronx-cc, torch-neuronx, neuronx-distributed)
  • OS: Amazon Linux 2 or Ubuntu 22.04 with Neuron drivers installed

Setup

1. Install Dependencies

pip install -r requirements_trainium.txt
pip install -e .

Requires Neuron SDK 2.21+ pre-installed. See requirements_trainium.txt for pinned versions.

2. Download Base Model

python scripts/get_qwen3_4b_base.py  # Downloads to /data/models/qwen3_4b_base

Or manually download Qwen/Qwen3-4B to a local path.

3. Compile Rollout NEFF

The rollout engine (vLLM + optimum-neuron) requires a pre-compiled NEFF (Neuron Executable File Format) model. Compilation takes ~30–60 minutes:

bash scripts/compile_neuron_tp8_bs64.sh

This produces a Neuron-compiled model at /data/models/qwen3_4b_neuron_tp8_bs64/ with:

  • model.pt — compiled NEFF
  • checkpoint/weights/tp*_sharded_checkpoint.safetensors — TP-sharded weights
  • neuron_config.json — compilation metadata

Update rollout.model_path in config_trainium.yaml to point to the output directory.
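
Before pointing the config at the output directory, it can help to confirm that the three artifacts listed above are actually present. The following is a minimal sketch; the function name is hypothetical and it checks only for file presence, not for validity of the compiled model.

```python
from pathlib import Path


def missing_rollout_artifacts(model_dir: str) -> list[str]:
    """Return the expected artifacts that are absent from model_dir."""
    root = Path(model_dir)
    missing = []
    if not (root / "model.pt").is_file():
        missing.append("model.pt")
    if not (root / "neuron_config.json").is_file():
        missing.append("neuron_config.json")
    weights = root / "checkpoint" / "weights"
    # At least one TP shard must exist; the exact count depends on the TP degree.
    if not list(weights.glob("tp*_sharded_checkpoint.safetensors")):
        missing.append("checkpoint/weights/tp*_sharded_checkpoint.safetensors")
    return missing
```

An empty list means all three kinds of artifact were found, for example when called as `missing_rollout_artifacts("/data/models/qwen3_4b_neuron_tp8_bs64")` after a successful compile.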

Training

Quick Start

bash examples/run_qwen3_4b_trainium.sh

The script will:

  1. Activate the qwenrl conda environment
  2. Start a Ray head node with neuron_cores resources
  3. Launch GRPO training via python -m verl.trainer.main config=examples/config_trainium.yaml
  4. Log to running_log/ and optionally to Weights & Biases

First Run

The first training step triggers the Neuron compiler (neuronx-cc), which compiles multiple NEFF graphs for the training forward and backward passes. This is a one-time cost of ~60–90 minutes. Compiled NEFFs are cached in /var/tmp/neuron-compile-cache/ and reused in subsequent runs.

Configuration

All training parameters are in examples/config_trainium.yaml. Key sections:

Model & Optimizer

worker:
  actor:
    tensor_parallel_size: 8      # TP=8 (max for Qwen3-4B, kv_heads=8)
    model:
      model_path: /data/models/qwen3_4b_base
      enable_gradient_checkpointing: true  # Required to fit in HBM
    optim:
      lr: 1.0e-5
      strategy: adamw_bf16        # BF16 optimizer states

Rollout (vLLM + optimum-neuron)

worker:
  rollout:
    model_path: /data/models/qwen3_4b_neuron_tp8_bs64  # Pre-compiled NEFF
    tensor_parallel_size: 8
    n: 4                          # Responses per prompt
    neuron_override_config:
      tp_degree: 8
      batch_size: 64              # Must match compiled NEFF
      seq_len: 4096               # Must match compiled NEFF
      on_device_sampling: true    # Faster rollout (actor recomputes logprobs)

Important: tp_degree, batch_size, and seq_len in neuron_override_config must exactly match the values used during NEFF compilation. Mismatches cause runtime errors.
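
A mismatch is easier to catch before launch than from a runtime error. The sketch below compares the three values that must agree. It assumes the compile-time values are available as a flat dictionary with the same key names; the layout of neuron_config.json is not described here, so loading it into that shape is left to the caller.

```python
# Keys that must agree between neuron_override_config and the compiled NEFF.
MUST_MATCH = ("tp_degree", "batch_size", "seq_len")


def config_mismatches(override: dict, compiled: dict, keys=MUST_MATCH) -> dict:
    """Map each differing key to (override_value, compiled_value)."""
    return {
        k: (override.get(k), compiled.get(k))
        for k in keys
        if override.get(k) != compiled.get(k)
    }


override = {"tp_degree": 8, "batch_size": 64, "seq_len": 4096}
compiled = {"tp_degree": 8, "batch_size": 32, "seq_len": 4096}  # hypothetical values
problems = config_mismatches(override, compiled)
if problems:
    print("neuron_override_config does not match the compiled NEFF:", problems)
```

Keys outside `MUST_MATCH`, such as on_device_sampling, are ignored by this check.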

Trainer (Core Allocation)

trainer:
  max_steps: 50
  separate_rollout_workers: true   # Required for Trainium
  training_core_ratio: 0.5        # 50% cores for training, 50% for rollout
  save_checkpoint_path: /data/models/qwen3_4b_trainium_v2

On a trn1.32xlarge (32 cores) with training_core_ratio set to 0.5:

  • Training pool: 16 cores → TP=8 × DP=2
  • Rollout pool: 16 cores → 2 rollout workers × TP=8
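
The split above is plain arithmetic, which the following sketch reproduces. The function name and the rounding and divisibility rules are assumptions; the trainer's actual allocation logic may differ in edge cases.

```python
def split_cores(total_cores: int, training_core_ratio: float, tp: int) -> dict[str, int]:
    """Divide NeuronCores between the training and rollout pools."""
    training_cores = int(total_cores * training_core_ratio)
    rollout_cores = total_cores - training_cores
    if training_cores % tp or rollout_cores % tp:
        raise ValueError("each pool must hold a whole number of TP groups")
    return {
        "training_cores": training_cores,
        "data_parallel": training_cores // tp,     # DP replicas of a TP group
        "rollout_cores": rollout_cores,
        "rollout_workers": rollout_cores // tp,    # one TP group per worker
    }


print(split_cores(total_cores=32, training_core_ratio=0.5, tp=8))
```

With 32 cores, a ratio of 0.5, and TP=8, this gives 16 training cores (DP=2) and 16 rollout cores (2 workers), matching the figures listed above.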

Data

data:
  train_files: hiyouga/math12k@train
  max_prompt_length: 2048
  max_response_length: 1024
  rollout_batch_size: 64
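
These data settings interact with the rollout settings shown earlier. The sketch below works through the numbers under two assumptions that are not stated in this README: that the compiled seq_len has to cover prompt plus response tokens, and that each prompt in a rollout batch yields n responses. The function name is hypothetical.

```python
def rollout_budget(
    max_prompt_length: int,
    max_response_length: int,
    seq_len: int,
    rollout_batch_size: int,
    n: int,
) -> dict:
    """Summarize the per-sequence token budget and sequences per rollout batch."""
    tokens_per_sequence = max_prompt_length + max_response_length
    return {
        "tokens_per_sequence": tokens_per_sequence,
        "fits_seq_len": tokens_per_sequence <= seq_len,   # assumed constraint
        "sequences_per_step": rollout_batch_size * n,     # assumed: n per prompt
    }


print(rollout_budget(2048, 1024, seq_len=4096, rollout_batch_size=64, n=4))
```

With the values from this config, a sequence needs at most 3072 tokens, which fits within 4096, and one rollout batch produces 256 sequences.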


Monitoring

Training logs are saved to running_log/. To watch live:

tail -f running_log/run_qwen3_4b_trainium_*.log

For W&B logging, set WANDB_API_KEY in run_qwen3_4b_trainium.sh or your environment.

Acknowledgements