# EasyR1 on AWS Trainium
April 10, 2026
This is a port of EasyR1 to AWS Trainium (trn1 instances) using NeuronX Distributed (NxD) for training and vLLM + optimum-neuron for rollout inference. The original GPU-based FSDP training and vLLM-CUDA rollout are replaced with Neuron-native equivalents while preserving the same GRPO algorithm and Ray-based orchestration.
## What Changed

### New: NxD Training Workers (`verl/workers/nxd/`)

| File | Role |
|---|---|
| `nxd_workers.py` | `NxDWorker`: training worker (actor, ref, critic) using NeuronX Distributed for tensor-parallel training on Neuron cores. `NxDRolloutOnlyWorker`: separate rollout worker hosting the vLLM + optimum-neuron inference engine. |
| `nxd_rollout.py` | `NxDRollout`: wraps vLLM's `LLM` engine (with optimum-neuron as the Neuron platform plugin), manages the engine lifecycle (destroy and rebuild each step to avoid NRT memory leaks), and handles generation and logprob extraction. |
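The destroy-and-rebuild lifecycle can be pictured with the minimal Python sketch below. It is not the `NxDRollout` code: the `Engine` protocol, the `make_engine` factory, and the `RebuildPerStepRollout` class are hypothetical stand-ins, and no vLLM API is used.

```python
from dataclasses import dataclass
from typing import Callable, List, Protocol


class Engine(Protocol):
    """Hypothetical minimal engine interface; not vLLM's actual API."""

    def generate(self, prompts: List[str]) -> List[str]: ...

    def shutdown(self) -> None: ...


@dataclass
class RebuildPerStepRollout:
    """Create a fresh engine for every rollout step and always tear it down.

    Because each step loads a new engine from `weights_dir`, copying updated
    weights into that directory between steps is enough to sync them, and
    (assuming shutdown() frees it) runtime memory held by the previous engine
    is released before the next engine is built.
    """

    make_engine: Callable[[str], Engine]  # hypothetical factory: weights dir -> engine
    weights_dir: str
    builds: int = 0  # number of engines created so far

    def generate(self, prompts: List[str]) -> List[str]:
        engine = self.make_engine(self.weights_dir)
        self.builds += 1
        try:
            return engine.generate(prompts)
        finally:
            # Runs on success and on error, so no engine outlives its step.
            engine.shutdown()
```

Any object with `generate()` and `shutdown()` can serve as the engine in a dry run. The `try`/`finally` is the design point: teardown happens even when a generation step raises.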
### New: Neuron Utility Modules (`verl/utils/`)

| File | Role |
|---|---|
| `neuron_logprob.py` | Computes log-probabilities from TP-sharded logits on Neuron/XLA, avoiding the `NRT_FAILURE` that `torch.distributed.all_reduce` causes inside backward graphs. |
| `nxd_weight_converter.py` | Converts NxD training checkpoint shards to a vLLM-compatible format for rollout weight sync. |
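For readers unfamiliar with computing log-probabilities over a vocabulary split across TP ranks, the NumPy sketch below shows the standard vocab-parallel decomposition: a global max, a global sum of exponentials, and the target logit contributed by whichever rank owns it. Ranks are simulated as a Python list, with reductions over that list standing in for collective ops. It illustrates the math only; it is not how `neuron_logprob.py` is implemented and says nothing about how that module avoids `all_reduce` in backward graphs.

```python
import numpy as np


def sharded_token_logprobs(logit_shards, targets):
    """Log-probabilities of `targets` from vocab-sharded logits.

    logit_shards: list of arrays [num_tokens, vocab_shard], one per simulated TP
                  rank, ordered so that concatenating them gives the full vocab.
    targets:      int array [num_tokens] of global vocab ids.
    """
    targets = np.asarray(targets)
    rows = np.arange(targets.shape[0])

    # 1) Global max per token (stands in for an all-reduce MAX), for stability.
    global_max = np.max([s.max(axis=-1) for s in logit_shards], axis=0)

    # 2) Global sum of exp(logit - max) (stands in for an all-reduce SUM).
    sum_exp = sum(np.exp(s - global_max[:, None]).sum(axis=-1) for s in logit_shards)

    # 3) Only the rank owning a target id contributes its logit (all-reduce SUM).
    target_logit = np.zeros(targets.shape[0], dtype=logit_shards[0].dtype)
    offset = 0
    for s in logit_shards:
        width = s.shape[-1]
        owned = (targets >= offset) & (targets < offset + width)
        local_ids = np.where(owned, targets - offset, 0)  # safe index when not owned
        target_logit += np.where(owned, s[rows, local_ids], 0.0)
        offset += width

    # log p(target) = logit_target - max - log(sum(exp(logit - max)))
    return target_logit - global_max - np.log(sum_exp)
```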
### Modified: Core Framework

| File | Change |
|---|---|
| `verl/trainer/ray_trainer.py` | Added support for a separate rollout worker pool (`separate_rollout_workers`), Trainium-aware core allocation, and weight sync via file copy + engine rebuild. |
| `verl/workers/actor/dp_actor.py` | Added XLA/Neuron detection and 3-phase detach + `mark_step` graph splitting in `update_policy()` to keep XLA compilation graphs small. |
| `verl/single_controller/ray/base.py` | Environment isolation between the training and rollout worker pools (prevents `NEURON_COMPILED_ARTIFACTS` from leaking into training workers). |
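The environment isolation in `base.py` can be illustrated with a small pure-Python sketch that builds separate environment dictionaries for the two pools. The function `split_worker_envs` and its signature are hypothetical; only `NEURON_COMPILED_ARTIFACTS` comes from this README.

```python
from typing import Dict, Iterable, Mapping

# Only NEURON_COMPILED_ARTIFACTS is named in this README; extend as needed.
ROLLOUT_ONLY_VARS = ("NEURON_COMPILED_ARTIFACTS",)


def split_worker_envs(
    base_env: Mapping[str, str],
    rollout_extra: Mapping[str, str],
    rollout_only: Iterable[str] = ROLLOUT_ONLY_VARS,
) -> Dict[str, Dict[str, str]]:
    """Return separate environment dicts for the training and rollout pools."""
    blocked = set(rollout_only)
    # Training workers never see rollout-only variables, even if the driver has them set.
    training = {k: v for k, v in base_env.items() if k not in blocked}
    # Rollout workers get the base environment plus their own extras (e.g. the NEFF path).
    rollout = {**base_env, **rollout_extra}
    return {"training": training, "rollout": rollout}
```

Each resulting dict could be passed to the corresponding Ray actors via `runtime_env={"env_vars": ...}`, Ray's standard mechanism for per-actor environment variables; whether `base.py` does it this way is not stated here.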
### New: Scripts and Config

| File | Role |
|---|---|
| `examples/config_trainium.yaml` | Full training config for Trainium (Qwen3-4B, GRPO, TP=8). |
| `examples/run_qwen3_4b_trainium.sh` | Launch script: sets up env vars, starts Ray, runs training. |
| `scripts/export_qwen3_neuron_rollout.py` | Exports a HuggingFace model to Neuron-compiled NEFF format for vLLM rollout. |
| `scripts/compile_neuron_tp8_bs64.sh` | Wrapper script for NEFF compilation with TP=8, batch_size=64. |
## Prerequisites

- Instance: `trn1.32xlarge` (32 NeuronCores, 512 GB HBM)
- Neuron SDK: 2.21+ (`neuronx-cc`, `torch-neuronx`, `neuronx-distributed`)
- OS: Amazon Linux 2 or Ubuntu 22.04 with Neuron drivers installed
## Setup

### 1. Install Dependencies

```bash
pip install -r requirements_trainium.txt
pip install -e .
```

Requires Neuron SDK 2.21+ pre-installed. See `requirements_trainium.txt` for pinned versions.
### 2. Download Base Model

```bash
python scripts/get_qwen3_4b_base.py  # Downloads to /data/models/qwen3_4b_base
```

Or manually download `Qwen/Qwen3-4B` to a local path.
### 3. Compile Rollout NEFF

The rollout engine (vLLM + optimum-neuron) requires a pre-compiled NEFF (Neuron Executable File Format) model. Compilation takes ~30–60 minutes:

```bash
bash scripts/compile_neuron_tp8_bs64.sh
```

This produces a Neuron-compiled model at `/data/models/qwen3_4b_neuron_tp8_bs64/` containing:

- `model.pt`: compiled NEFF
- `checkpoint/weights/tp*_sharded_checkpoint.safetensors`: TP-sharded weights
- `neuron_config.json`: compilation metadata

Update `rollout.model_path` in `config_trainium.yaml` to point to the output directory.
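Before pointing the config at the output, a quick check that the directory has the layout listed above can save a failed launch. The helper below is hypothetical (not part of the repo) and assumes one `tp*_sharded_checkpoint.safetensors` file per TP rank.

```python
import glob
import os
from typing import List


def missing_neff_artifacts(model_dir: str, tp_degree: int = 8) -> List[str]:
    """Describe expected rollout artifacts that are absent from `model_dir`.

    Checks for model.pt and neuron_config.json at the top level, and for
    exactly `tp_degree` weight shards (one-per-rank is an assumption).
    """
    missing = [
        name
        for name in ("model.pt", "neuron_config.json")
        if not os.path.isfile(os.path.join(model_dir, name))
    ]
    pattern = os.path.join(model_dir, "checkpoint", "weights", "tp*_sharded_checkpoint.safetensors")
    shards = glob.glob(pattern)
    if len(shards) != tp_degree:
        missing.append(f"{tp_degree} weight shards (found {len(shards)})")
    return missing
```

An empty list means the expected files are present; it does not validate their contents.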
## Training

### Quick Start

```bash
bash examples/run_qwen3_4b_trainium.sh
```

The script will:

- Activate the `qwenrl` conda environment
- Start a Ray head node with `neuron_cores` resources
- Launch GRPO training via `python -m verl.trainer.main config=examples/config_trainium.yaml`
- Log to `running_log/` and optionally to Weights & Biases
### First Run

The first training step triggers the Neuron compiler (`neuronx-cc`) to compile multiple NEFF graphs for the training forward and backward passes. This is a one-time cost of ~60–90 minutes. Compiled NEFFs are cached in `/var/tmp/neuron-compile-cache/` and reused in subsequent runs.
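To see whether the compile cache is already warm before launching, a small script can count cached NEFFs. This is a hypothetical helper that assumes compiled graphs are stored as `*.neff` files somewhere under `/var/tmp/neuron-compile-cache/`; the exact cache layout depends on the Neuron SDK version.

```python
import os
from typing import Tuple

DEFAULT_CACHE = "/var/tmp/neuron-compile-cache"


def neff_cache_summary(cache_dir: str = DEFAULT_CACHE) -> Tuple[int, int]:
    """Return (number of *.neff files, total bytes of all files) under cache_dir.

    A missing directory yields (0, 0) because os.walk simply produces nothing.
    """
    neff_count, total_bytes = 0, 0
    for root, _dirs, files in os.walk(cache_dir):
        for name in files:
            total_bytes += os.path.getsize(os.path.join(root, name))
            if name.endswith(".neff"):
                neff_count += 1
    return neff_count, total_bytes
```

A count of zero suggests the first step will pay the full compile cost described above.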
## Configuration

All training parameters are in `examples/config_trainium.yaml`. Key sections:

### Model & Optimizer

```yaml
worker:
  actor:
    tensor_parallel_size: 8  # TP=8 (max for Qwen3-4B, kv_heads=8)
    model:
      model_path: /data/models/qwen3_4b_base
      enable_gradient_checkpointing: true  # Required to fit in HBM
    optim:
      lr: 1.0e-5
      strategy: adamw_bf16  # BF16 optimizer states
```
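A rough sizing check helps explain the TP limit and the memory comments above. The sketch assumes Qwen3-4B has 32 query heads and 8 KV heads, about 4.0B parameters, no KV-head replication, and that weights, gradients, and both AdamW moments are all held in BF16 with no FP32 master copy; verify these against the model's `config.json` and the optimizer implementation.

```python
from typing import List

# Approximate figures (assumptions, not read from the config).
NUM_ATTENTION_HEADS = 32     # Qwen3-4B query heads
NUM_KV_HEADS = 8             # Qwen3-4B KV heads, per the config comment above
NUM_PARAMS = 4.0e9           # ~4B parameters
HBM_PER_CORE_GB = 512 / 32   # trn1.32xlarge: 512 GB HBM over 32 NeuronCores


def valid_tp_degrees(num_heads: int, num_kv_heads: int) -> List[int]:
    """TP degrees that split both query and KV heads evenly (no KV replication)."""
    return [
        tp
        for tp in range(1, num_kv_heads + 1)
        if num_kv_heads % tp == 0 and num_heads % tp == 0
    ]


def static_gb_per_core(num_params: float, tp: int, bytes_per_param: int = 8) -> float:
    """BF16 weights + grads + two AdamW moments (2 B each = 8 B/param), split over
    `tp` cores. Ignores activations, temporary buffers, and compiled graphs."""
    return num_params * bytes_per_param / tp / 1e9
```

Under these assumptions, static training state is about 4 GB per core out of roughly 16 GB, leaving the rest for activations and runtime buffers, which is where gradient checkpointing matters. Treat this as an estimate, not a measurement.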
### Rollout (vLLM + optimum-neuron)

```yaml
rollout:
  model_path: /data/models/qwen3_4b_neuron_tp8_bs64  # Pre-compiled NEFF
  tensor_parallel_size: 8
  n: 4  # Responses per prompt
  neuron_override_config:
    tp_degree: 8
    batch_size: 64  # Must match compiled NEFF
    seq_len: 4096  # Must match compiled NEFF
    on_device_sampling: true  # Faster rollout (actor recomputes logprobs)
```
Important: `tp_degree`, `batch_size`, and `seq_len` in `neuron_override_config` must exactly match the values used during NEFF compilation. Mismatches cause runtime errors.
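A pre-launch check can catch such mismatches early. This sketch compares the three fields against `neuron_config.json` in the compiled directory. It assumes that file stores them as top-level keys named exactly `tp_degree`, `batch_size`, and `seq_len`, which may not hold for every optimum-neuron version; the helper name is hypothetical.

```python
import json
from typing import Dict, Mapping, Tuple

MUST_MATCH = ("tp_degree", "batch_size", "seq_len")


def neff_config_mismatches(
    neuron_config_path: str, override: Mapping[str, object]
) -> Dict[str, Tuple[object, object]]:
    """Return {key: (override_value, compiled_value)} for every key that differs."""
    with open(neuron_config_path) as f:
        compiled = json.load(f)
    return {
        key: (override.get(key), compiled.get(key))
        for key in MUST_MATCH
        if override.get(key) != compiled.get(key)
    }
```

An empty result means the override agrees with the compiled metadata on these three fields.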
### Trainer (Core Allocation)

```yaml
trainer:
  max_steps: 50
  separate_rollout_workers: true  # Required for Trainium
  training_core_ratio: 0.5  # 50% of cores for training, 50% for rollout
  save_checkpoint_path: /data/models/qwen3_4b_trainium_v2
```
On a trn1.32xlarge (32 cores) with `training_core_ratio: 0.5`:

- Training pool: 16 cores → TP=8 × DP=2
- Rollout pool: 16 cores → 2 rollout workers × TP=8
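The same split written as a small calculation that generalizes to other ratios. This is a sketch, not the allocation code in `ray_trainer.py`: the `split_cores` helper, its truncation of fractional cores, and the divisibility rule are assumptions.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class CoreSplit:
    training_cores: int
    training_dp: int      # data-parallel replicas = training cores / training TP
    rollout_cores: int
    rollout_workers: int  # rollout engines = rollout cores / rollout TP


def split_cores(total_cores: int, training_core_ratio: float, train_tp: int, rollout_tp: int) -> CoreSplit:
    """Split NeuronCores between the training and rollout pools.

    Fractional cores are truncated with int(); each pool must then be a whole
    multiple of its TP degree, otherwise ValueError is raised.
    """
    training_cores = int(total_cores * training_core_ratio)
    rollout_cores = total_cores - training_cores
    if training_cores % train_tp or rollout_cores % rollout_tp:
        raise ValueError("each pool must be a multiple of its TP degree")
    return CoreSplit(
        training_cores, training_cores // train_tp, rollout_cores, rollout_cores // rollout_tp
    )
```

For example, `split_cores(32, 0.5, train_tp=8, rollout_tp=8)` reproduces the 16/16 split above, while a ratio such as 0.4 would leave 12 training cores and be rejected.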
### Data

```yaml
data:
  train_files: hiyouga/math12k@train
  max_prompt_length: 2048
  max_response_length: 1024
  rollout_batch_size: 64
```
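These values interact with the rollout settings above. The arithmetic below is a hypothetical sketch that assumes a full prompt plus response must fit inside the compiled `seq_len`, that each prompt expands to `n` sequences, and that sequences are divided evenly across rollout workers and run in fixed NEFF-sized batches.

```python
def rollout_plan(
    max_prompt_length: int = 2048,
    max_response_length: int = 1024,
    seq_len: int = 4096,           # neuron_override_config.seq_len
    rollout_batch_size: int = 64,  # prompts per training step
    n: int = 4,                    # responses per prompt
    neff_batch_size: int = 64,     # neuron_override_config.batch_size
    rollout_workers: int = 2,
) -> dict:
    # Assumed constraint: the longest prompt + response must fit the compiled context.
    if max_prompt_length + max_response_length > seq_len:
        raise ValueError("prompt + response length exceeds the compiled seq_len")
    sequences = rollout_batch_size * n
    # Assumed: even split over workers, then fixed-size NEFF batches (last one padded).
    per_worker = -(-sequences // rollout_workers)           # ceiling division
    batches_per_worker = -(-per_worker // neff_batch_size)  # ceiling division
    return {
        "sequences_per_step": sequences,
        "sequences_per_worker": per_worker,
        "neff_batches_per_worker": batches_per_worker,
    }
```

With the values shown in this README, this gives 256 sequences per step, 128 per rollout worker, or two full batches of 64 each.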
## Monitoring
Training logs are saved to `running_log/`. To watch live:
```bash
tail -f running_log/run_qwen3_4b_trainium_*.log
```

For W&B logging, set `WANDB_API_KEY` in `run_qwen3_4b_trainium.sh` or in your environment.
## Acknowledgements

- EasyR1 / veRL: the upstream RLHF framework
- NeuronX Distributed: tensor-parallel training on Neuron
- optimum-neuron: HuggingFace integration for Neuron
- vLLM: inference engine with Neuron backend