Running Simply on Google Cloud TPUs
March 11, 2026
This guide walks you through running Simply experiments on Google Cloud TPU VMs, from initial setup through monitoring and collecting results. It covers both single-host and multi-host configurations.
Prerequisites
- A GCP project with TPU quota (check IAM & Admin > Quotas)
- Billing enabled on the project
- gcloud CLI installed and authenticated (gcloud auth login)
- The Simply codebase cloned locally
TPU Types
| Type | Hosts | Chips | Use case |
|---|---|---|---|
| v5litepod-1 | 1 | 1 | Smoke tests, tiny models |
| v5litepod-8 | 2 | 8 | Small RL runs |
| v5litepod-16 | 4 | 16 | Full RL training (e.g. Gemma 2B) |
1. One-Time GCloud Setup
Set your project ID and preferred zone as shell variables:
PROJECT=your-project-id
ZONE=us-central1-a
BUCKET=gs://${PROJECT}-simply
Enable APIs
gcloud services enable tpu.googleapis.com compute.googleapis.com --project=$PROJECT
VPC Network
If your project doesn't already have a default VPC:
gcloud compute networks create default \
--project=$PROJECT --subnet-mode=auto
gcloud compute networks subnets update default \
--region=us-central1 \
--enable-private-ip-google-access \
--project=$PROJECT
Cloud NAT
If your VMs use internal-only IPs (no external IP), they need Cloud NAT to reach the internet for pip installs and downloading assets:
gcloud compute routers create simply-router \
--region=us-central1 \
--network=default \
--project=$PROJECT
gcloud compute routers nats create simply-nat \
--router=simply-router \
--region=us-central1 \
--auto-allocate-nat-external-ips \
--nat-all-subnet-ip-ranges \
--project=$PROJECT
Firewall Rules
Allow SSH access:
gcloud compute firewall-rules create allow-ssh \
--network=default \
--allow=tcp:22,icmp \
--project=$PROJECT
Service Account Permissions
The default compute service account needs roles for TPU management and GCS access:
SA="$(gcloud iam service-accounts list \
--project=$PROJECT \
--filter='email:compute@developer.gserviceaccount.com' \
--format='value(email)')"
for ROLE in roles/tpu.admin \
roles/compute.instanceAdmin.v1 \
roles/iam.serviceAccountUser \
roles/storage.admin; do
gcloud projects add-iam-policy-binding $PROJECT \
--member="serviceAccount:$SA" --role="$ROLE"
done
GCS Bucket
Create a bucket for code, assets, and experiment results:
gcloud storage buckets create $BUCKET \
--location=us-central1 --project=$PROJECT
2. Preparing Code and Assets
Upload Code
Package and upload the Simply codebase to GCS:
cd /path/to/simply
tar --exclude='.git' --exclude='__pycache__' \
-czf /tmp/simply.tar.gz .
gcloud storage cp /tmp/simply.tar.gz $BUCKET/code/
Upload Model Checkpoints
Model checkpoints are large (several GB). Download them locally first, then upload to GCS:
# Download locally
python setup/setup_assets.py
# Upload to GCS (example for Gemma 2B)
gcloud storage cp -r ~/.cache/simply/models/GEMMA-2.0-2B-PT-ORBAX \
$BUCKET/models/
gcloud storage cp -r ~/.cache/simply/vocabs/ $BUCKET/vocabs/
gcloud storage cp -r ~/.cache/simply/datasets/ $BUCKET/datasets/
3. Creating a TPU VM
Single-Host (v5litepod-1)
TPU_NAME=simply-test
gcloud compute tpus tpu-vm create $TPU_NAME \
--zone=$ZONE \
--accelerator-type=v5litepod-1 \
--version=tpu-ubuntu2204-base \
--project=$PROJECT \
--preemptible
Multi-Host (v5litepod-8, v5litepod-16, etc.)
Same command, just change --accelerator-type:
TPU_NAME=simply-pod
gcloud compute tpus tpu-vm create $TPU_NAME \
--zone=$ZONE \
--accelerator-type=v5litepod-16 \
--version=tpu-ubuntu2204-base \
--project=$PROJECT \
--preemptible
Multi-host creates multiple worker VMs (e.g. v5litepod-16 = 4 workers with 4 chips each).
Preemptible vs On-Demand
Use --preemptible for lower cost, or omit the flag to get on-demand capacity.
Preemptible VMs can be reclaimed at any time; see Preemption Handling for
retry strategies.
4. Setting Up the TPU VM
SSH into the VM
gcloud compute tpus tpu-vm ssh $TPU_NAME \
--zone=$ZONE --project=$PROJECT \
--worker=0
Install Python 3.11
TPU VMs ship with Python 3.10, but Simply requires 3.11+ (uses
typing.Self):
sudo apt-get update
sudo apt-get install -y software-properties-common
sudo add-apt-repository -y ppa:deadsnakes/ppa
sudo apt-get install -y python3.11 python3.11-venv python3.11-dev
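Download Code from GCS
Define BUCKET on the VM first (same value as in section 1); shell variables
from your workstation don't carry over into the SSH session:
BUCKET=gs://your-project-id-simply
gcloud storage cp $BUCKET/code/simply.tar.gz /tmp/
mkdir -p /tmp/simply && cd /tmp/simply
tar xzf /tmp/simply.tar.gz
Virtual Environment and Dependencies
Install from /tmp/simply so pip can find requirements.txt:
python3.11 -m venv /tmp/simply_venv
source /tmp/simply_venv/bin/activate
pip install -U 'jax[tpu]' \
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
pip install -r requirements.txt
pip install google-cloud-storage # for TensorBoard gs:// support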
Set Asset Paths
Simply loads models, datasets, and vocabs via epath which supports
GCS paths natively. Point the environment variables directly at your
GCS bucket -- no need to download assets locally:
export SIMPLY_MODELS=$BUCKET/models/
export SIMPLY_DATASETS=$BUCKET/datasets/
export SIMPLY_VOCABS=$BUCKET/vocabs/
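If BUCKET is empty in the VM's shell, the exports above silently turn into paths like /models/, and the failure only shows up once training tries to load assets. A small pre-flight check can catch that early. This is a hypothetical helper, not part of Simply, and it assumes you want all three assets read from GCS as in this guide (epath also accepts local paths, which this check would report).

```python
import os

# The asset variables exported before launching simply.main.
ASSET_VARS = ("SIMPLY_MODELS", "SIMPLY_DATASETS", "SIMPLY_VOCABS")


def check_asset_env(env=None):
    """Return a list of problems with the SIMPLY_* asset variables.

    Hypothetical pre-flight check: it assumes every asset should be read
    from GCS, so any value that is not a gs:// URL is reported. A value
    such as "/models/" usually means $BUCKET was empty at export time.
    """
    env = os.environ if env is None else env
    problems = []
    for name in ASSET_VARS:
        value = env.get(name, "")
        if not value:
            problems.append(f"{name} is not set")
        elif not value.startswith("gs://"):
            problems.append(f"{name}={value!r} is not a gs:// path")
    return problems


if __name__ == "__main__":
    # Print one warning per problem found in the current environment.
    for issue in check_asset_env():
        print("WARNING:", issue)
```

Run it with the venv active, right after the exports and before launching simply.main.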
5. Running Experiments
Single-Host
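SSH in and run directly. Shell variables from your workstation aren't
available on the VM, so set BUCKET in the VM's shell as well:
BUCKET=gs://your-project-id-simply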
cd /tmp/simply
source /tmp/simply_venv/bin/activate
export SIMPLY_MODELS=$BUCKET/models/
export SIMPLY_DATASETS=$BUCKET/datasets/
export SIMPLY_VOCABS=$BUCKET/vocabs/
python3 -m simply.main \
--experiment_config lm_test \
--experiment_dir /tmp/exp_1 \
--alsologtostderr
Multi-Host
For multi-host pods (v5litepod-8+), the command must run on all
workers simultaneously. Simply's main.py calls
jax.distributed.initialize() at startup, which coordinates across
workers.
Step 1: Warm up SSH keys (required before --worker=all):
NUM_WORKERS=4 # v5litepod-16 has 4 workers
for w in $(seq 0 $((NUM_WORKERS - 1))); do
gcloud compute tpus tpu-vm ssh $TPU_NAME \
--zone=$ZONE --project=$PROJECT \
--worker=$w \
--command="echo 'Worker $w SSH OK'" 2>&1 || true
sleep 2
done
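The loop above hardcodes NUM_WORKERS=4 for a v5litepod-16. To avoid a mismatch when you switch accelerator types, you can count workers from the TPU's description instead: the Cloud TPU node resource lists one networkEndpoints entry per worker VM. The sketch below assumes that field layout in the JSON printed by gcloud compute tpus tpu-vm describe --format=json; both helper names are hypothetical.

```python
import json
import subprocess


def count_workers(describe_json: str) -> int:
    """Count worker VMs in a TPU node description.

    Assumes the JSON printed by `gcloud compute tpus tpu-vm describe
    --format=json`, where each worker appears once in networkEndpoints.
    """
    node = json.loads(describe_json)
    return len(node.get("networkEndpoints", []))


def describe_tpu(name: str, zone: str, project: str) -> str:
    """Return the TPU node description as JSON text (requires the gcloud CLI)."""
    result = subprocess.run(
        ["gcloud", "compute", "tpus", "tpu-vm", "describe", name,
         f"--zone={zone}", f"--project={project}", "--format=json"],
        check=True,
        capture_output=True,
        text=True,
    )
    return result.stdout
```

For example, count_workers(describe_tpu("simply-pod", "us-central1-a", "your-project-id")) would be expected to return 4 for a v5litepod-16, matching the TPU Types table.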
Step 2: Run on all workers:
gcloud compute tpus tpu-vm ssh $TPU_NAME \
--zone=$ZONE --project=$PROJECT \
--worker=all \
--command="
cd /tmp/simply
source /tmp/simply_venv/bin/activate
export SIMPLY_MODELS=$BUCKET/models/
export SIMPLY_DATASETS=$BUCKET/datasets/
export SIMPLY_VOCABS=$BUCKET/vocabs/
python3 -m simply.main \
--experiment_config gemma2_2b_gsm8k_2k_rl_16 \
--experiment_dir $BUCKET/experiments/my_exp \
--alsologtostderr
"
Using Config Files
Instead of registered config names, you can pass a JSON config file:
python3 -m simply.main \
--experiment_config_path /path/to/config.json \
--experiment_dir /tmp/exp_1 \
--alsologtostderr
Experiment Directory
You can use either a local path or a GCS path for --experiment_dir:
- Local path (/tmp/exp_1): Fast writes, but data is lost if the VM is preempted. Upload results to GCS manually after training.
- GCS path (gs://my-bucket/experiments/exp_1): Checkpoints and TensorBoard logs are saved directly to GCS and survive preemption. Required for multi-host checkpointing (each host has its own local filesystem, so Orbax cannot coordinate checkpoint saves to a local path).
For multi-host or preemptible runs, prefer a GCS experiment directory:
python3 -m simply.main \
--experiment_config gemma2_2b_gsm8k_2k_rl_16 \
--experiment_dir gs://my-bucket/experiments/my_exp \
--alsologtostderr
If using a local path, upload results to GCS after training:
gcloud storage cp -r /tmp/exp_1 $BUCKET/experiments/
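The rules above can be written down as a small pre-launch check. This is a hypothetical helper that only encodes this guide's advice (use GCS for multi-host or preemptible runs); it does not inspect anything Simply or Orbax does at runtime.

```python
def check_experiment_dir(path: str, num_hosts: int, preemptible: bool) -> list[str]:
    """Return notes on an --experiment_dir choice (hypothetical helper).

    Encodes this guide's advice only: multi-host Orbax checkpointing needs a
    shared (GCS) directory, and local results on a preemptible VM can be lost.
    """
    notes = []
    on_gcs = path.startswith("gs://")
    if num_hosts > 1 and not on_gcs:
        notes.append("multi-host checkpointing needs a gs:// experiment_dir")
    if preemptible and not on_gcs:
        notes.append("results in a local dir are lost if the VM is preempted")
    return notes


# A local directory on a preemptible v5litepod-16 triggers both notes.
print(check_experiment_dir("/tmp/exp_1", num_hosts=4, preemptible=True))
```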
6. Example: Gemma 2B GSM8K RL
This example trains Gemma 2B on GSM8K using RL (GRPO) on a
v5litepod-16. The experiment config gemma2_2b_gsm8k_2k_rl_16
(defined in simply/config_lib.py) sets:
- 2000 training steps
- LinearWarmupConstant(value=1e-7) learning rate
- grad_accum_steps=2 to avoid OOM on logprobs
- Checkpoints every 20 steps
TPU_NAME=simply-pod
NUM_WORKERS=4
# Create TPU
gcloud compute tpus tpu-vm create $TPU_NAME \
--zone=$ZONE --accelerator-type=v5litepod-16 \
--version=tpu-ubuntu2204-base \
--project=$PROJECT --preemptible
# Warm up SSH keys
for w in $(seq 0 $((NUM_WORKERS - 1))); do
gcloud compute tpus tpu-vm ssh $TPU_NAME \
--zone=$ZONE --project=$PROJECT \
--worker=$w \
--command="echo 'Worker $w OK'" 2>&1 || true
sleep 2
done
# Setup all workers (run on all)
gcloud compute tpus tpu-vm ssh $TPU_NAME \
--zone=$ZONE --project=$PROJECT \
--worker=all \
--command="
sudo apt-get update -qq
sudo apt-get install -y -qq software-properties-common
sudo add-apt-repository -y ppa:deadsnakes/ppa
sudo apt-get install -y -qq python3.11 python3.11-venv python3.11-dev
python3.11 -m venv /tmp/simply_venv
source /tmp/simply_venv/bin/activate
pip install -q -U 'jax[tpu]' \
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
gcloud storage cp $BUCKET/code/simply.tar.gz /tmp/
mkdir -p /tmp/simply && cd /tmp/simply
tar xzf /tmp/simply.tar.gz
pip install -q -r requirements.txt
pip install -q google-cloud-storage
"
# Run experiment (GCS for assets and experiment dir)
gcloud compute tpus tpu-vm ssh $TPU_NAME \
--zone=$ZONE --project=$PROJECT \
--worker=all \
--ssh-flag="-o ServerAliveInterval=30" \
--command="
cd /tmp/simply
source /tmp/simply_venv/bin/activate
export SIMPLY_MODELS=$BUCKET/models/
export SIMPLY_DATASETS=$BUCKET/datasets/
export SIMPLY_VOCABS=$BUCKET/vocabs/
python3 -m simply.main \
--experiment_config gemma2_2b_gsm8k_2k_rl_16 \
--experiment_dir $BUCKET/experiments/gemma2b_gsm8k \
--alsologtostderr 2>&1
"
7. Common Gotchas
jax.distributed.initialize() Required for Multi-Host
Without this call before any JAX operations, each host only sees its
local chips and the experiment will silently hang. Simply's main.py
already includes this call, but if you write custom scripts, add it
before any jax.* calls:
import jax
jax.distributed.initialize()
grad_accum_steps for OOM
The RL training loop materializes full logits tensors during
compute_logprobs_fn: shape bf16[batch/chips, seq_len, vocab_size].
For Gemma 2B (vocab_size=256128), this is ~4 GB per microbatch.
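Setting grad_accum_steps=2 halves the microbatch size, and with it the
size of this tensor; higher values shrink it further. The accumulated
gradient is mathematically equivalent to the full-batch gradient, up to
floating-point rounding.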
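To see where the ~4 GB figure comes from, multiply out the tensor shape at 2 bytes per bf16 element. The per-chip batch of 4 and sequence length of 2048 below are illustrative assumptions chosen to land near that figure; the real values come from your experiment config.

```python
BYTES_PER_BF16 = 2
GEMMA_2B_VOCAB = 256128


def logits_bytes(batch_per_chip: int, seq_len: int, vocab_size: int,
                 grad_accum_steps: int = 1) -> int:
    """Bytes in one bf16[microbatch, seq_len, vocab_size] logits tensor per chip."""
    if batch_per_chip % grad_accum_steps:
        raise ValueError("batch_per_chip must be divisible by grad_accum_steps")
    microbatch = batch_per_chip // grad_accum_steps
    return microbatch * seq_len * vocab_size * BYTES_PER_BF16


# Assumed shapes for illustration: 4 sequences per chip, 2048 tokens each.
for steps in (1, 2, 4):
    gb = logits_bytes(4, 2048, GEMMA_2B_VOCAB, steps) / 1e9
    print(f"grad_accum_steps={steps}: {gb:.1f} GB")
```

With those assumed shapes it prints 4.2 GB, 2.1 GB, and 1.0 GB for grad_accum_steps of 1, 2, and 4.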
SSH Key Warmup for Multi-Host
--worker=all can fail if SSH keys haven't been exchanged with each
worker. Always warm up keys first by SSHing into each worker
individually (see the multi-host example above).
--worker=all Buffers Output
--worker=all buffers ALL output from ALL workers until the command
completes. For long-running training, this means you see nothing
until it finishes (or is preempted). SSH into individual workers for
real-time monitoring (see Monitoring below).
Multi-Host Checkpoints Require Shared Filesystem
On multi-host TPU pods, each host has its own local /tmp. Orbax
checkpoints require all hosts to coordinate directory creation, which
fails on local paths. Use a GCS path as --experiment_dir for
multi-host runs, or set should_save_ckpt=False in the config if
you don't need checkpoints.
8. Monitoring
Single-Worker SSH Probe
SSH into a specific worker to check if training is running:
gcloud compute tpus tpu-vm ssh $TPU_NAME \
--zone=$ZONE --project=$PROJECT \
--worker=0 \
--command="ps aux | grep 'simply.main' | grep -v grep"
TensorBoard
If using a GCS experiment directory, you can view TensorBoard logs directly:
tensorboard --logdir gs://my-bucket/experiments/my_exp
If you trained with a local experiment directory and uploaded it to GCS afterwards, copy the logs to your machine first:
gcloud storage cp -r $BUCKET/experiments/my_exp /tmp/
tensorboard --logdir /tmp/my_exp
Key Metrics for RL Experiments
- accuracy - fraction of correct answers
- pass_at_k - fraction of questions with at least 1 correct answer out of num_samples_per_example samples
- entropy - token-level entropy (should decrease during RL)
- learning_rate - verify it's not decaying to 0
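A minimal sketch of how accuracy and pass_at_k relate, assuming accuracy is averaged over all sampled answers and pass_at_k over questions; Simply's own metric code may aggregate differently (for example, per batch).

```python
def rl_eval_metrics(correct: list[list[bool]]) -> dict[str, float]:
    """accuracy and pass_at_k from per-sample correctness flags.

    correct[i][j] is True when sample j for question i was judged correct;
    each inner list holds num_samples_per_example samples.
    """
    total_samples = sum(len(samples) for samples in correct)
    correct_samples = sum(sum(samples) for samples in correct)
    solved_questions = sum(any(samples) for samples in correct)
    return {
        "accuracy": correct_samples / total_samples,
        "pass_at_k": solved_questions / len(correct),
    }


metrics = rl_eval_metrics([
    [True, False, False, False],   # 1 of 4 samples correct
    [False, False, False, False],  # unsolved
    [True, True, False, True],     # 3 of 4 samples correct
])
print(metrics)  # accuracy = 4/12, pass_at_k = 2/3
```

pass_at_k is always at least as high as accuracy, so a large gap between them means the model solves many questions only occasionally.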
9. Preemption Handling
Preemptible TPU VMs can be reclaimed at any time. Use a bastion VM with a retry loop to automatically recreate the TPU and resume training.
Bastion VM Pattern
A bastion VM is a lightweight VM (e.g. e2-small) that runs a startup script to manage the TPU lifecycle. It creates the TPU, sets it up, runs the experiment, and retries on preemption.
Save the following as bastion_retry.sh, replacing the variables at
the top with your own values:
#!/bin/bash
# bastion_retry.sh - Startup script for a bastion VM
TPU_NAME=simply-pod
ZONE=us-central1-a
PROJECT=your-project-id
BUCKET=gs://your-bucket-name
ACCEL_TYPE=v5litepod-16
MAX_ATTEMPTS=10
EXPERIMENT_CONFIG=gemma2_2b_gsm8k_2k_rl_16
EXPERIMENT_DIR=$BUCKET/experiments/my_experiment
NUM_WORKERS=4
SETUP_CMD="
sudo apt-get update -qq
sudo apt-get install -y -qq software-properties-common
sudo add-apt-repository -y ppa:deadsnakes/ppa
sudo apt-get install -y -qq python3.11 python3.11-venv python3.11-dev
python3.11 -m venv /tmp/simply_venv
source /tmp/simply_venv/bin/activate
pip install -q -U 'jax[tpu]' \
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
gcloud storage cp $BUCKET/code/simply.tar.gz /tmp/
mkdir -p /tmp/simply && cd /tmp/simply
tar xzf /tmp/simply.tar.gz
pip install -q -r /tmp/simply/requirements.txt
pip install -q google-cloud-storage
"
RUN_CMD="
cd /tmp/simply
source /tmp/simply_venv/bin/activate
export SIMPLY_MODELS=$BUCKET/models/
export SIMPLY_DATASETS=$BUCKET/datasets/
export SIMPLY_VOCABS=$BUCKET/vocabs/
python3 -m simply.main \
--experiment_config $EXPERIMENT_CONFIG \
--experiment_dir $EXPERIMENT_DIR \
--alsologtostderr 2>&1
"
for attempt in $(seq 1 $MAX_ATTEMPTS); do
echo "=== Attempt $attempt/$MAX_ATTEMPTS ==="
# Create TPU
echo "Creating TPU $TPU_NAME..."
gcloud compute tpus tpu-vm create $TPU_NAME \
--zone=$ZONE --accelerator-type=$ACCEL_TYPE \
--version=tpu-ubuntu2204-base \
--project=$PROJECT --preemptible \
2>&1 || { echo "Create failed, retrying..."; sleep 60; continue; }
# Warm up SSH keys
for w in $(seq 0 $((NUM_WORKERS - 1))); do
gcloud compute tpus tpu-vm ssh $TPU_NAME \
--zone=$ZONE --project=$PROJECT \
--worker=$w \
--command="echo 'Worker $w OK'" 2>&1 || true
sleep 2
done
# Setup all workers
echo "Setting up workers..."
gcloud compute tpus tpu-vm ssh $TPU_NAME \
--zone=$ZONE --project=$PROJECT \
--worker=all \
--ssh-flag="-o ServerAliveInterval=30" \
--command="$SETUP_CMD" 2>&1
# Run experiment
echo "Starting experiment..."
gcloud compute tpus tpu-vm ssh $TPU_NAME \
--zone=$ZONE --project=$PROJECT \
--worker=all \
--ssh-flag="-o ServerAliveInterval=30" \
--command="$RUN_CMD" 2>&1
EXIT_CODE=$?
# Cleanup TPU
gcloud compute tpus tpu-vm delete $TPU_NAME \
--zone=$ZONE --project=$PROJECT --quiet 2>&1
if [ $EXIT_CODE -eq 0 ]; then
echo "=== Experiment completed successfully ==="
break
fi
echo "Attempt $attempt failed (exit code $EXIT_CODE). Retrying..."
sleep 60
done
Deploy the bastion VM:
gcloud compute instances create bastion \
--zone=$ZONE --machine-type=e2-small \
--project=$PROJECT \
--network=default --scopes=cloud-platform \
--metadata-from-file=startup-script=bastion_retry.sh
Monitor via serial port output:
gcloud compute instances get-serial-port-output bastion \
--zone=$ZONE --project=$PROJECT
Because the experiment directory is on GCS, checkpoints survive preemption. When the bastion recreates the TPU and restarts the experiment, training resumes from the latest checkpoint automatically.
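How much work a preemption costs depends on the checkpoint interval. A quick way to reason about it, assuming a checkpoint lands at every multiple of the interval and the most recent one finished writing before the VM was reclaimed (the real step accounting is up to Simply's checkpoint manager):

```python
def resume_point(preempted_at_step: int, ckpt_every: int) -> tuple[int, int]:
    """Return (step of latest checkpoint, steps to redo) after a preemption.

    Assumes a checkpoint is written at every multiple of ckpt_every and that
    the latest one finished before the VM was reclaimed; step 0 stands for
    "no checkpoint yet, start from scratch".
    """
    latest = (preempted_at_step // ckpt_every) * ckpt_every
    return latest, preempted_at_step - latest


# The example config checkpoints every 20 steps.
print(resume_point(537, 20))  # (520, 17)
```

With the example config's 20-step interval, the worst case under these assumptions is redoing 19 steps per preemption; a shorter interval trades more checkpoint writes for less lost work.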
10. Cleanup
# Delete TPU VM
gcloud compute tpus tpu-vm delete $TPU_NAME \
--zone=$ZONE --project=$PROJECT --quiet
# Delete bastion VM (if used)
gcloud compute instances delete bastion \
--zone=$ZONE --project=$PROJECT --quiet
The GCS bucket, VPC, NAT, and firewall rules persist across experiments and don't need to be recreated.
Running on GKE with XPK
As an alternative to managing TPU VMs directly, you can run Simply on GKE (Google Kubernetes Engine) clusters with TPU node pools using XPK. XPK handles Docker image building, job scheduling, and multi-host coordination automatically.
Prerequisites
- A GKE cluster with a TPU node pool already provisioned
- XPK installed (pip install xpk)
- Docker installed and authenticated to push to GCR/Artifact Registry
- kubectl configured for your cluster (gcloud container clusters get-credentials ...)
Setting Environment Variables
Set these once per shell session (or in your shell profile). Replace the values with your own project, cluster, and bucket:
export PROJECT=your-gcp-project-id
export CLUSTER=your-gke-cluster-name
export ZONE=us-central1
export TPUTYPE=v4-8
export BUCKET=your-gcs-bucket-name # bare bucket name, without the gs:// prefix
Building the Docker Image
Simply provides a Dockerfile at scripts/Dockerfile.simply that
pre-installs JAX with TPU support and all Simply dependencies:
cd /path/to/simply
# Build the image
docker build -f scripts/Dockerfile.simply \
-t gcr.io/$PROJECT/simply-jax-tpu:latest .
# Push to your project's container registry
docker push gcr.io/$PROJECT/simply-jax-tpu:latest
The Dockerfile installs dependencies in a separate layer for fast
rebuilds. When the workload starts, the launch script runs
uv pip install --system . inside the container to install Simply
itself from the source tree copied by XPK's --script-dir flag.
Launching a Workload with a Registered Config
The simplest way to launch is with a registered config name. The
lm_test_gke_training config is designed for GKE testing -- it
uses a small model with no checkpoint loading:
./scripts/launch_gke.sh \
--config lm_test_gke_training \
--project $PROJECT \
--cluster $CLUSTER \
--zone $ZONE \
--tpu-type $TPUTYPE \
--image gcr.io/$PROJECT/simply-jax-tpu:latest
To preview the XPK command without submitting, add --dry-run.
Common Options
| Flag | Env Variable | Default | Description |
|---|---|---|---|
| --zone ZONE | SIMPLY_XPK_ZONE | us-central1 | GCP zone/region |
| --tpu-type TYPE | SIMPLY_XPK_TPU_TYPE | v4-8 | TPU accelerator type |
| --num-slices N | SIMPLY_XPK_NUM_SLICES | 1 | Number of slices |
| --priority PRI | SIMPLY_XPK_PRIORITY | medium | Workload priority |
| --name NAME | | auto | Custom workload name |
| --spot | | (default) | Use spot instances |
| --on-demand | | | Use on-demand instances |
| --dry-run | | | Print the xpk command only |
Profiling with XProf
To collect XProf traces, pass --profile and set a GCS bucket for
trace storage:
export SIMPLY_XPK_GCS_BUCKET=gs://my-bucket/profiles
./scripts/launch_gke.sh --config lm_test_gke_training --profile \
--project $PROJECT --cluster $CLUSTER \
--image gcr.io/$PROJECT/simply-jax-tpu:latest
This sets JAX_PROFILER_LOG_DIR inside the container and saves
worker logs to the GCS bucket. By default, profiling starts after
5 warmup steps and captures 3 steps (configurable via
--profile-warmup and --profile-steps).
Managing Workloads
List, monitor, and delete workloads:
# List all simply-* workloads on the cluster
./scripts/launch_gke.sh \
--project $PROJECT --cluster $CLUSTER --zone $ZONE \
--list
# Stream logs from a running workload
./scripts/launch_gke.sh \
--project $PROJECT --cluster $CLUSTER --zone $ZONE \
--logs simply-lm-test-gke-training-0311
# Delete a workload
./scripts/launch_gke.sh \
--project $PROJECT --cluster $CLUSTER --zone $ZONE \
--delete simply-lm-test-gke-training-0311
You can also use kubectl directly for lower-level diagnostics:
# List pods for a workload
kubectl get pods \
-l "jobset.sigs.k8s.io/jobset-name=simply-lm-test-gke-training-0311"
# Check container logs (replace POD_NAME with actual pod name)
kubectl logs POD_NAME --all-containers 2>&1 | tail -50
# Describe a pod for event details (image pull errors, scheduling)
kubectl describe pod POD_NAME
Docker Access
The launch script checks for Docker access and falls back to
sg docker if the current user isn't in the docker group. If
Docker is not accessible at all, run:
sudo usermod -aG docker $USER && newgrp docker
GKE Troubleshooting
Container can't find local files
XPK workloads run inside containers with their own filesystem.
Local paths like /tmp/config.json on your machine are not
accessible. Upload config files and assets to GCS and use
gs:// paths.
ModuleNotFoundError for a Python package
If a package is missing in the container, add it to
scripts/Dockerfile.simply, rebuild, push, and relaunch. Common
packages that may be needed depending on your data pipeline:
Found incomplete checkpoint / Orbax validation error
Orbax uses commit_success.txt marker files to validate
checkpoints on GCS. HuggingFace-hosted checkpoints don't include
this file. Create it manually:
touch /tmp/commit_success.txt
gcloud storage cp /tmp/commit_success.txt \
gs://$BUCKET/path/to/checkpoint/1/commit_success.txt
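If a downloaded checkpoint has several step directories, marking each one by hand gets tedious. The sketch below writes an empty commit_success.txt into every numeric step directory of a local copy, which you would then upload with gcloud storage cp -r. It is a hypothetical helper; for gs:// paths you could try etils' epath.Path, which Simply already uses and which mirrors this pathlib interface, but that substitution is untested here.

```python
from pathlib import Path


def add_commit_markers(ckpt_root, marker: str = "commit_success.txt") -> list[Path]:
    """Write an empty marker file into each numeric step directory.

    ckpt_root is a local checkpoint directory whose children are step
    directories such as 1/, 2/, ...; existing markers are left alone.
    Returns the paths of the markers that were created.
    """
    created = []
    for step_dir in sorted(Path(ckpt_root).iterdir()):
        if step_dir.is_dir() and step_dir.name.isdigit():
            marker_path = step_dir / marker
            if not marker_path.exists():
                marker_path.write_text("")
                created.append(marker_path)
    return created
```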
FileNotFoundError: tokenizer_config.json
The launch script downloads Qwen3 tokenizer files at runtime.
If you use a different tokenizer, you may need to add a similar
download step to launch_gke.sh or pre-bake the tokenizer files
into the Docker image.
Future Work
- GPU VMs -- A100/H100 setup