Orbax - Checkpointing for JAX Models
May 29, 2026
Orbax provides common checkpointing and persistence utilities for JAX users.
Documentation
Refer to our full documentation at https://orbax.readthedocs.io.
Installation
Orbax is available on PyPI as separate domain-specific packages:
Checkpointing
Install from PyPI:
pip install orbax-checkpoint
Or install the latest version directly from GitHub at HEAD:
pip install 'git+https://github.com/google/orbax/#subdirectory=checkpoint'
Exporting
Install from PyPI:
pip install orbax-export
Or install the latest version directly from GitHub at HEAD:
pip install 'git+https://github.com/google/orbax/#subdirectory=export'
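To confirm which Orbax packages ended up in the current environment after running the commands above, a small check using only the Python standard library might look like the following. This is a minimal sketch. The helper `installed_orbax_versions` is hypothetical, and the only assumption is that the distribution names match the PyPI names in the install commands.

```python
from importlib import metadata

# Distribution names as published on PyPI (taken from the install commands above).
ORBAX_PACKAGES = ("orbax-checkpoint", "orbax-export")


def installed_orbax_versions():
    """Return {package_name: version string, or None if not installed}."""
    versions = {}
    for name in ORBAX_PACKAGES:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None  # This package is not installed here.
    return versions


if __name__ == "__main__":
    for name, version in installed_orbax_versions().items():
        print(f"{name}: {version or 'not installed'}")
```

Installing from GitHub at HEAD typically reports a development version string rather than a released one.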
Quickstart
import jax
from orbax.checkpoint import v1 as ocp
# Define your pytree state (e.g. weights, optimizer state)
state = {'a': jax.numpy.ones(2), 'b': 42}
# Save the state
ocp.save('/tmp/my_checkpoint', state)
# Restore the state
restored_state = ocp.load('/tmp/my_checkpoint')
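The `state` in the Quickstart is a pytree: a nested container (here a dict) whose leaves are arrays or scalars. As a rough, library-independent illustration of how such a structure can be addressed leaf by leaf, the sketch below flattens nested dicts, lists, and tuples into a mapping from key paths to leaves and then rebuilds the tree from a template. `flatten_pytree` and `unflatten_pytree` are hypothetical helpers, not part of Orbax or JAX (JAX's real pytree utilities live in `jax.tree_util`), and plain Python floats stand in for `jax.numpy` arrays so the example runs without JAX.

```python
from typing import Any, Dict, Tuple

# A key path: the sequence of dict keys / sequence indices leading to a leaf.
KeyPath = Tuple[Any, ...]


def flatten_pytree(tree: Any, prefix: KeyPath = ()) -> Dict[KeyPath, Any]:
    """Map each leaf's key path to the leaf (dicts, lists, tuples are containers)."""
    if isinstance(tree, dict):
        children = tree.items()
    elif isinstance(tree, (list, tuple)):
        children = enumerate(tree)
    else:
        return {prefix: tree}  # Anything else (array, number, string) is a leaf.
    flat: Dict[KeyPath, Any] = {}
    for key, child in children:
        flat.update(flatten_pytree(child, prefix + (key,)))
    return flat


def unflatten_pytree(template: Any, flat: Dict[KeyPath, Any], prefix: KeyPath = ()) -> Any:
    """Rebuild a tree shaped like `template`, taking each leaf from `flat`."""
    if isinstance(template, dict):
        return {k: unflatten_pytree(v, flat, prefix + (k,)) for k, v in template.items()}
    if isinstance(template, (list, tuple)):
        # Preserves plain list vs. tuple; namedtuples are out of scope for this sketch.
        return type(template)(
            unflatten_pytree(v, flat, prefix + (i,)) for i, v in enumerate(template)
        )
    return flat[prefix]


state = {"a": [1.0, 1.0], "b": 42, "opt": {"step": 3, "mu": (0.1, 0.2)}}
flat = flatten_pytree(state)
# flat == {('a', 0): 1.0, ('a', 1): 1.0, ('b',): 42,
#          ('opt', 'step'): 3, ('opt', 'mu', 0): 0.1, ('opt', 'mu', 1): 0.2}
assert unflatten_pytree(state, flat) == state
```

Rebuilding from a template is a design choice of this sketch: the container structure does not have to be stored alongside the leaves, only the key paths.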
Orbax includes a checkpointing library oriented towards JAX users that supports the features different frameworks require, including asynchronous checkpointing, standard and custom types, and flexible storage formats. We aim to provide a highly customizable, composable API that gives diverse use cases maximum flexibility.
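Asynchronous checkpointing means the training loop does not block while a checkpoint is written. The sketch below illustrates the general idea with the standard library only: snapshot the state, write it on a background thread, and commit it with an atomic rename so a crash mid-write never leaves a half-written file in place. `AsyncCheckpointerSketch` and `load` are hypothetical names, JSON stands in for a real storage format, and this is not a description of how Orbax implements the feature.

```python
import copy
import json
import os
import tempfile
import threading
from typing import Any, Optional


class AsyncCheckpointerSketch:
    """Hypothetical: snapshot the state, then write it on a background thread."""

    def __init__(self) -> None:
        self._thread: Optional[threading.Thread] = None
        self._error: Optional[BaseException] = None

    def save(self, path: str, state: Any) -> None:
        # Finish any in-flight save first so at most one write runs at a time.
        self.wait()
        # Deep-copy now so later updates by the caller do not change what is saved.
        snapshot = copy.deepcopy(state)
        self._thread = threading.Thread(target=self._write, args=(path, snapshot))
        self._thread.start()  # Returns immediately; the write proceeds in the background.

    def wait(self) -> None:
        """Block until the pending save (if any) completes; re-raise its error."""
        if self._thread is not None:
            self._thread.join()
            self._thread = None
        if self._error is not None:
            error, self._error = self._error, None
            raise error

    def _write(self, path: str, snapshot: Any) -> None:
        # Write to a temporary file in the target directory, then rename it into
        # place so readers never see a partially written checkpoint.
        directory = os.path.dirname(os.path.abspath(path))
        fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
        try:
            with os.fdopen(fd, "w") as f:
                json.dump(snapshot, f)  # JSON stands in for a real storage format.
            os.replace(tmp_path, path)  # Atomic rename within one filesystem.
        except BaseException as e:  # Surface the failure to the caller via wait().
            self._error = e
            if os.path.exists(tmp_path):
                os.remove(tmp_path)


def load(path: str) -> Any:
    """Read back a checkpoint written by AsyncCheckpointerSketch."""
    with open(path) as f:
        return json.load(f)
```

Calling `wait()` before reading a checkpoint, or before the process exits, ensures the last save has actually reached disk. A production library also has to deal with device arrays and coordination across multiple hosts, which this sketch ignores.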
Support
Please report any issues or request support using our GitHub issue tracker (https://github.com/google/orbax/issues).
You can also email orbax-dev@google.com directly for help or with any questions about Orbax.
Citing Orbax
Our paper is available on arXiv.
If you use Orbax in your research, please cite:
@misc{gaffney2026orbaxdistributedcheckpointingjax,
title={Orbax: Distributed Checkpointing with JAX},
author={Colin Gaffney and Shutong Li and Daniel Ng and Anastasia Petrushkina and Niket Kumar and Adam Cogdell and Mridul Sahu and Yaning Liang and Nikhil Bansal and Justin Pan and Angel Mau and Abhishek Agrawal and Marco Berlot and Ruoxin Sang and Kiranbir Sodhia and Rakesh Iyer},
year={2026},
eprint={2605.23066},
archivePrefix={arXiv},
primaryClass={cs.DC},
url={https://arxiv.org/abs/2605.23066},
}
Existing Users
Orbax Checkpointing is used extensively across JAX machine learning frameworks and model implementations.
Google Projects
- Flax (Google's flexible and expressive neural network library for JAX)
- Gemma (Open foundation models by Google DeepMind)
- Kauldron (Google Research training and evaluation framework)
- PaxML (Google's high-performance framework for training large-scale JAX models)
- T5X (Google's JAX framework for high-performance sequence models)
- MaxText (Google's high-performance, scalable JAX LLM implementation)
- MaxDiffusion (Stable Diffusion JAX training library optimized for Cloud TPUs)
- Tunix (Google's JAX-native library for LLM post-training)
- Numerous Google-internal ML frameworks