May 11, 2026
Learning Methodologies for Autoregressive Neural Emulators.
Installation • Documentation • Quickstart • Background • Features • Citation
Convenience abstractions using Optax to train neural networks to
autoregressively emulate time-dependent problems. Trainax takes care of
trajectory subsampling and offers a wide range of training methodologies
(varying the unrolling length and optionally including differentiable physics).
Installation
pip install trainax
Requires Python 3.10+ and JAX 0.4.13+. 👉 JAX install guide.
Documentation
The documentation is available at fkoehler.site/trainax.
Quickstart
Train a linear convolution with kernel size 2 (no bias) to become an emulator for the 1D advection problem.
import jax
import jax.numpy as jnp
import equinox as eqx
import optax  # pip install optax
import trainax as tx

CFL = -0.75

# Reference data for the 1D periodic advection problem
ref_data = tx.sample_data.advection_1d_periodic(
    cfl=CFL,
    key=jax.random.PRNGKey(0),
)

# Linear convolution with two weights and no bias: the model to be trained
linear_conv_kernel_2 = eqx.nn.Conv1d(
    1, 1, 2,
    padding="SAME", padding_mode="CIRCULAR", use_bias=False,
    key=jax.random.PRNGKey(73)
)

# Three supervised trainers that differ only in the number of rollout steps
sup_1_trainer, sup_5_trainer, sup_20_trainer = (
    tx.trainer.SupervisedTrainer(
        ref_data,
        num_rollout_steps=r,
        optimizer=optax.adam(1e-2),
        num_training_steps=1000,
        batch_size=32,
    )
    for r in (1, 5, 20)
)

# Train the same initial model with each trainer
sup_1_conv, sup_1_loss_history = sup_1_trainer(
    linear_conv_kernel_2, key=jax.random.PRNGKey(42)
)
sup_5_conv, sup_5_loss_history = sup_5_trainer(
    linear_conv_kernel_2, key=jax.random.PRNGKey(42)
)
sup_20_conv, sup_20_loss_history = sup_20_trainer(
    linear_conv_kernel_2, key=jax.random.PRNGKey(42)
)

# Distance of the learned weights from the numerical FOU stencil
FOU_STENCIL = jnp.array([1+CFL, -CFL])
print(jnp.linalg.norm(sup_1_conv.weight - FOU_STENCIL))  # 0.033
print(jnp.linalg.norm(sup_5_conv.weight - FOU_STENCIL))  # 0.025
print(jnp.linalg.norm(sup_20_conv.weight - FOU_STENCIL))  # 0.017
Increasing the supervised unrolling steps during training makes the learned stencil come closer to the numerical FOU stencil.
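To see why the FOU stencil is a natural target, the sketch below applies it directly as a time stepper in plain JAX. It assumes that FOU stands for first-order upwind and that the two stencil entries weight a cell and its right neighbor (the upwind side for a negative CFL); how `eqx.nn.Conv1d` aligns a size-2 kernel under `"SAME"` padding is not verified here. `fou_step` and `rollout` are illustrative names, not part of Trainax.

```python
import jax
import jax.numpy as jnp

CFL = -0.75
FOU_STENCIL = jnp.array([1 + CFL, -CFL])  # assumed weights for (u_i, u_{i+1})


def fou_step(u, stencil=FOU_STENCIL):
    """One periodic first-order upwind step for a negative CFL number."""
    # jnp.roll(u, -1)[i] == u[i + 1], with periodic wrap-around
    return stencil[0] * u + stencil[1] * jnp.roll(u, -1)


def rollout(stepper, u_0, num_steps):
    """Apply `stepper` autoregressively and return all states after u_0."""
    def scan_fn(u, _):
        u_next = stepper(u)
        return u_next, u_next

    _, trj = jax.lax.scan(scan_fn, u_0, None, length=num_steps)
    return trj


x = jnp.linspace(0.0, 1.0, 64, endpoint=False)
u_0 = jnp.sin(2 * jnp.pi * x)
trj = rollout(fou_step, u_0, 20)  # shape (20, 64)
```

Because the two stencil entries sum to one, each step preserves the mean of the state. Under the alignment assumption above, a learned kernel flattened to two entries could be passed as `stencil` to compare its rollout against the reference step.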
Background
After the discretization of space and time, the simulation of a time-dependent partial differential equation amounts to the repeated application of a simulation operator $\mathcal{P}_h$. Here, we are interested in imitating/emulating this physical/numerical operator with a neural network $f_\theta$. This repository is concerned with an abstract implementation of all the ways we can frame a learning problem to inject "knowledge" from $\mathcal{P}_h$ into $f_\theta$.
Assume we have a distribution of initial conditions $\mathcal{Q}$ from which we sample $S$ initial states, $u^{[0]} \sim \mathcal{Q}$. Then, we can save them in an array of shape $(S, C, *N)$ (with $C$ channels and an arbitrary number of spatial axes of dimension $N$) and repeatedly apply $\mathcal{P}_h$ for $T$ steps to obtain the training trajectory of shape $(S, T+1, C, *N)$.
For a one-step supervised learning task, we substack the training trajectory into windows of size $2$ and merge the two leading axes into a data array of shape $(S \cdot T, 2, C, *N)$ that can be used in a supervised learning scenario
$$ L(\theta) = \mathbb{E}_{(u^{[0]}, u^{[1]})} \left[ l\left( f_\theta(u^{[0]}), u^{[1]} \right) \right] $$
where $l$ is a time-level loss and the expectation runs over all windows. In the easiest case, $l$ is the mean squared error (MSE).
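The data preparation described above can be sketched in plain JAX as follows. This is a minimal illustration, not Trainax's implementation: `reference_step` is a hypothetical stand-in for the reference simulator (a periodic upwind stencil), and `make_trajectories`, `substack_windows`, and `one_step_loss` are names made up for this example.

```python
import jax
import jax.numpy as jnp


def reference_step(u):
    # Hypothetical reference simulator: periodic upwind advection stencil
    return 0.25 * u + 0.75 * jnp.roll(u, -1, axis=-1)


def make_trajectories(u_0, num_steps):
    """Shape (S, C, N) -> (S, T+1, C, N), initial state included."""
    def scan_fn(u, _):
        u_next = reference_step(u)
        return u_next, u_next

    _, trj = jax.lax.scan(scan_fn, u_0, None, length=num_steps)  # (T, S, C, N)
    trj = jnp.concatenate([u_0[None], trj], axis=0)  # (T+1, S, C, N)
    return jnp.swapaxes(trj, 0, 1)  # (S, T+1, C, N)


def substack_windows(trj, window=2):
    """Shape (S, T+1, C, N) -> (S * (T + 2 - window), window, C, N)."""
    num_windows = trj.shape[1] - window + 1
    windows = jnp.stack(
        [trj[:, i : i + window] for i in range(num_windows)], axis=1
    )  # (S, num_windows, window, C, N)
    # Merge the sample axis and the window-index axis into one batch axis
    return windows.reshape(-1, window, *trj.shape[2:])


def one_step_loss(emulator, windows):
    """MSE between emulator(u^[t]) and u^[t+1], averaged over all windows."""
    prediction = jax.vmap(emulator)(windows[:, 0])
    return jnp.mean((prediction - windows[:, 1]) ** 2)


S, T, C, N = 4, 10, 1, 32
u_0 = jax.random.normal(jax.random.PRNGKey(0), (S, C, N))
trj = make_trajectories(u_0, T)  # (4, 11, 1, 32)
windows = substack_windows(trj)  # (40, 2, 1, 32)
loss_exact = one_step_loss(reference_step, windows)  # reference emulates itself
loss_identity = one_step_loss(lambda u: u, windows)  # "do nothing" emulator
```

A perfect emulator drives the one-step loss to (numerically) zero, while the identity map does not, which is the signal a one-step supervised trainer optimizes.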
Trainax supports far more than one-step supervised learning, e.g.,
training with unrolled steps, including the reference simulator in
training, training on residuum conditions instead of resolved reference states,
and cutting or modifying the gradient flow.
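As a rough illustration of unrolled training and of cutting the gradient flow, the sketch below sums per-step MSEs along an autoregressive rollout. It is an assumption-laden toy, not Trainax's API: `emulator` is a hypothetical two-weight periodic stencil standing in for a neural network, and `jax.lax.stop_gradient` on the carried state is just one simple way to cut gradients.

```python
import jax
import jax.numpy as jnp


def emulator(params, u):
    # Hypothetical two-weight periodic stencil standing in for a neural network
    return params[0] * u + params[1] * jnp.roll(u, -1)


def unrolled_supervised_loss(params, window, cut_gradient=False):
    """Sum of MSEs along an autoregressive rollout.

    `window` has shape (R + 1, N): an initial state followed by R reference states.
    """
    def scan_fn(u, target):
        u_next = emulator(params, u)
        step_loss = jnp.mean((u_next - target) ** 2)
        if cut_gradient:
            # Assumption: block backpropagation through the state that is
            # handed to the next rollout step
            u_next = jax.lax.stop_gradient(u_next)
        return u_next, step_loss

    _, step_losses = jax.lax.scan(scan_fn, window[0], window[1:])
    return jnp.sum(step_losses)


# Reference window produced by the exact stencil (0.25, 0.75), five steps
true_params = jnp.array([0.25, 0.75])
u = jnp.sin(2 * jnp.pi * jnp.linspace(0.0, 1.0, 32, endpoint=False))
states = [u]
for _ in range(5):
    states.append(emulator(true_params, states[-1]))
window = jnp.stack(states)  # (6, 32)

params = jnp.array([0.5, 0.5])
loss, grad_full = jax.value_and_grad(unrolled_supervised_loss)(params, window)
grad_cut = jax.grad(unrolled_supervised_loss)(params, window, cut_gradient=True)
```

With the gradient cut, each step only contributes its direct dependence on the parameters, so `grad_cut` differs from `grad_full` even though both variants evaluate the same loss value.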
Features
- Wide collection of unrolled training methodologies:
  - Supervised
  - Diverted Chain
  - Mix Chain
  - Residuum
- Based on JAX:
  - One of the best Automatic Differentiation engines (forward & reverse)
  - Automatic vectorization
  - Backend-agnostic code (run on CPU, GPU, and TPU)
- Built on top of and compatible with Equinox
- Batch-Parallel Training
- Collection of Callbacks
- Composability
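To give a flavor of how a diverted-chain setup differs from supervised unrolling, here is a hedged sketch of one reading of the idea with branch length one: at every rollout step the target comes from applying the reference simulator to the emulator's own current state. This is not Trainax's implementation. `reference_step`, `emulator`, `diverted_chain_loss`, and `batched_mean_loss` are hypothetical names, and the reference simulator is assumed to be differentiable. The batch axis is added with `jax.vmap`, in line with the automatic vectorization listed above.

```python
import jax
import jax.numpy as jnp


def reference_step(u):
    # Hypothetical differentiable reference simulator (periodic upwind stencil)
    return 0.25 * u + 0.75 * jnp.roll(u, -1)


def emulator(params, u):
    # Hypothetical two-weight stencil standing in for a neural network
    return params[0] * u + params[1] * jnp.roll(u, -1)


def diverted_chain_loss(params, u_0, num_rollout_steps):
    """Compare each emulator step with a reference step from the same input."""
    def scan_fn(u, _):
        prediction = emulator(params, u)
        target = reference_step(u)  # branches off the emulator's own trajectory
        return prediction, jnp.mean((prediction - target) ** 2)

    _, step_losses = jax.lax.scan(scan_fn, u_0, None, length=num_rollout_steps)
    return jnp.sum(step_losses)


def batched_mean_loss(params, u_0_batch, num_rollout_steps=3):
    # The single-sample loss is vectorized over the leading batch axis
    per_sample = jax.vmap(
        lambda u_0: diverted_chain_loss(params, u_0, num_rollout_steps)
    )(u_0_batch)
    return jnp.mean(per_sample)


u_0_batch = jax.random.normal(jax.random.PRNGKey(0), (8, 32))  # (batch, N)
params = jnp.array([0.5, 0.5])
mean_loss = batched_mean_loss(params, u_0_batch)
grads = jax.grad(batched_mean_loss)(params, u_0_batch)
```

In this sketch no precomputed reference trajectory is needed, only initial conditions, and gradients also flow through `reference_step`, which is why the reference simulator has to be differentiable.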
Citation
This package was developed as part of the APEBench paper (arxiv.org/abs/2411.00180) (accepted at NeurIPS 2024). If you find it useful for your research, please consider citing it:
@article{koehler2024apebench,
title={Apebench: A benchmark for autoregressive neural emulators of pdes},
author={Koehler, Felix and Niedermayr, Simon and Westermann, R{\"u}diger and Thuerey, Nils},
journal={Advances in Neural Information Processing Systems},
volume={37},
pages={120252--120310},
year={2024}
}
(Feel free to also give the project a star on GitHub if you like it.)
Here you can find the APEBench benchmark suite.
Funding
The main author (Felix Koehler) is a PhD student in the group of Prof. Thuerey at TUM and his research is funded by the Munich Center for Machine Learning.
License
MIT, see here
fkoehler.site · GitHub @ceyron · X @felix_m_koehler · LinkedIn Felix Köhler