JAX on ROCm
March 13, 2026
This directory provides setup instructions and the files needed to build, test, and run JAX with ROCm support, including Docker environments suitable for both runtime and CI workflows. The sections below describe several ways to use or build JAX on ROCm.
0. Install via pip (JAX extras)
JAX supports ROCm via pip extras:
> pip install --upgrade "jax[rocm7-local]"
ROCm fixes via post-releases
ROCm-specific fixes are shipped as post-releases of the ROCm plugin/PJRT
packages (for example, jax-rocm7-plugin==0.9.1.post1). Upgrading
jax[rocm7-local] will pick up the newest compatible post-release
available from your configured package indexes.
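This works because of how PEP 440 orders versions: a post-release such as 0.9.1.post1 sorts after 0.9.1 and before 0.9.2, so pip treats it as the newest candidate. The sketch below illustrates that ordering with a simplified key that only understands X.Y.Z and X.Y.Z.postN strings (the packaging library's Version class handles the full PEP 440 grammar), and reads the installed plugin version with importlib.metadata. The distribution name jax-rocm7-plugin comes from the example above; the helper names are hypothetical.

```python
import re
from importlib import metadata
from typing import Optional, Tuple

# Simplified: only "X.Y.Z" or "X.Y.Z.postN" (full PEP 440 is handled by packaging.version).
_VERSION_RE = re.compile(r"^(\d+)\.(\d+)\.(\d+)(?:\.post(\d+))?$")

def version_key(version: str) -> Tuple[Tuple[int, int, int], int]:
    """Sort key in which X.Y.Z.postN ranks after X.Y.Z and before X.Y.(Z+1)."""
    match = _VERSION_RE.match(version)
    if match is None:
        raise ValueError(f"unsupported version string: {version!r}")
    major, minor, patch, post = match.groups()
    # A release without a post segment gets -1 so every .postN sorts after it.
    return (int(major), int(minor), int(patch)), (int(post) if post is not None else -1)

def installed_version(distribution: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return metadata.version(distribution)
    except metadata.PackageNotFoundError:
        return None

if __name__ == "__main__":
    candidates = ["0.9.1", "0.9.1.post1", "0.9.0.post3"]
    print("newest candidate:", max(candidates, key=version_key))
    print("installed plugin:", installed_version("jax-rocm7-plugin"))
```

Running installed_version("jax-rocm7-plugin") before and after an upgrade is a quick way to confirm that a post-release was actually picked up.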
Important: ROCm must already be installed (for now)
Until ROCm wheels are distributed via TheRock, the jax[rocm7-local] extra
installs the JAX ROCm plugin packages but does not install ROCm itself.
You must run in an environment where ROCm is already installed (for example,
a ROCm Docker container).
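Because the extra does not bring ROCm with it, a pre-flight check before installing can save a confusing failure later. This is a minimal sketch, assuming ROCm lives under the directory named by a ROCM_PATH environment variable or under the default /opt/rocm (the prefix used elsewhere in this guide); the lookup order and the helper names are assumptions, not part of the JAX or ROCm tooling.

```python
import os
import shutil
from typing import Iterable, Optional

def find_rocm_root(candidates: Optional[Iterable[str]] = None) -> Optional[str]:
    """Return the first candidate directory that exists, or None."""
    if candidates is None:
        # Assumed lookup order: $ROCM_PATH if set, then the default /opt/rocm.
        candidates = [p for p in (os.environ.get("ROCM_PATH"), "/opt/rocm") if p]
    for path in candidates:
        if os.path.isdir(path):
            return path
    return None

def rocm_smi_on_path() -> bool:
    """True if the rocm-smi tool is reachable through PATH."""
    return shutil.which("rocm-smi") is not None

if __name__ == "__main__":
    root = find_rocm_root()
    print("ROCm root:", root or "not found (install ROCm or use a ROCm container)")
    print("rocm-smi on PATH:", rocm_smi_on_path())
```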
1. Using Prebuilt Docker Images
The ROCm JAX team provides prebuilt Docker images, which are the simplest way to use JAX on ROCm. These images are available on Docker Hub and come with JAX preconfigured for ROCm.
To pull the latest ROCm JAX Docker image, run:
> docker pull rocm/jax:latest
Once the image is downloaded, launch a container using the following command:
> docker run -it -d --network=host --device=/dev/kfd --device=/dev/dri --ipc=host --shm-size 64G --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v $(pwd):/jax_dir --name rocm_jax rocm/jax:latest /bin/bash
> docker attach rocm_jax
Notes:
- The --shm-size parameter allocates shared memory for the container. Adjust it based on your system's resources if needed.
- Replace $(pwd) with the absolute path to the directory you want to mount inside the container.
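For CI scripts it can be convenient to assemble the same docker run invocation programmatically. The sketch below reproduces the flags from the command above and resolves the mount directory to an absolute path, as the notes above require; the function name and its defaults are illustrative and not part of any ROCm tooling.

```python
import os
from typing import List

def docker_run_cmd(image: str = "rocm/jax:latest", host_dir: str = ".",
                   name: str = "rocm_jax", shm_size: str = "64G") -> List[str]:
    """Build the detached `docker run` command used in this guide."""
    return [
        "docker", "run", "-it", "-d",
        "--network=host",
        "--device=/dev/kfd", "--device=/dev/dri",  # expose the GPU device nodes
        "--ipc=host", "--shm-size", shm_size,
        "--group-add", "video",
        "--cap-add=SYS_PTRACE", "--security-opt", "seccomp=unconfined",
        "-v", f"{os.path.abspath(host_dir)}:/jax_dir",  # mount source must be absolute
        "--name", name,
        image, "/bin/bash",
    ]

if __name__ == "__main__":
    cmd = docker_run_cmd()
    print(" ".join(cmd))
    # To start it: subprocess.run(cmd, check=True), then `docker attach rocm_jax`.
```

Passing image="rocm/dev-ubuntu-24.04:7.0.2-complete" yields the command used in section 2.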
For older versions, see the periodically pushed Docker images on the ROCm JAX DockerHub page.
Testing your ROCm environment with JAX:
After launching the container, test whether JAX detects ROCm devices as expected:
> python -c "import jax; print(jax.devices())"
[RocmDevice(id=0), RocmDevice(id=1), RocmDevice(id=2), RocmDevice(id=3), RocmDevice(id=4), RocmDevice(id=5), RocmDevice(id=6), RocmDevice(id=7)]
If the setup is successful, the output should list all available ROCm devices.
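In CI it helps to turn this check into a pass/fail step. A minimal sketch, assuming the printed list looks like the RocmDevice(id=N) entries shown above (the repr may change between JAX versions): it runs the same one-liner in a subprocess and extracts the device ids. The helper names and the expected-count argument are hypothetical.

```python
import re
import subprocess
import sys
from typing import List

_DEVICE_RE = re.compile(r"RocmDevice\(id=(\d+)\)")

def parse_rocm_device_ids(output: str) -> List[int]:
    """Extract the ids from a printed jax.devices() list."""
    return [int(device_id) for device_id in _DEVICE_RE.findall(output)]

def check_rocm_devices(expected: int) -> List[int]:
    """Run the verification one-liner and fail unless `expected` ROCm devices appear."""
    result = subprocess.run(
        [sys.executable, "-c", "import jax; print(jax.devices())"],
        capture_output=True, text=True, check=True,
    )
    ids = parse_rocm_device_ids(result.stdout)
    if len(ids) != expected:
        raise RuntimeError(
            f"expected {expected} ROCm devices, found {len(ids)}: {result.stdout.strip()}")
    return ids
```

On the eight-GPU node shown above, check_rocm_devices(8) would be expected to return the ids 0 through 7.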
2. Using a ROCm Docker Image and Installing JAX
If you prefer to use the ROCm Ubuntu image or already have a ROCm Ubuntu container, follow these steps to install JAX in the container.
Step 1: Pull the ROCm Ubuntu Docker Image
For example, use the following command to pull the ROCm Ubuntu image:
> docker pull rocm/dev-ubuntu-24.04:7.0.2-complete
Step 2: Launch the Docker Container
After pulling the image, launch a container using this command:
> docker run -it -d --network=host --device=/dev/kfd --device=/dev/dri --ipc=host --shm-size 64G --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v $(pwd):/jax_dir --name rocm_jax rocm/dev-ubuntu-24.04:7.0.2-complete /bin/bash
> docker attach rocm_jax
Step 3: Install the Latest Version of JAX
Install the required version of JAX and the ROCm plugins using pip, following the instructions for the latest release. For example, on a system with Python 3.12, run the following to install JAX 0.6.2:
> pip3 install \
https://github.com/ROCm/rocm-jax/releases/download/rocm-jax-v0.6.0/jax_rocm7_pjrt-0.6.0-py3-none-manylinux_2_28_x86_64.whl \
https://github.com/ROCm/rocm-jax/releases/download/rocm-jax-v0.6.0/jax_rocm7_plugin-0.6.0-cp312-cp312-manylinux_2_28_x86_64.whl \
https://github.com/ROCm/rocm-jax/releases/download/rocm-jax-v0.6.0/jaxlib-0.6.2-cp312-cp312-manylinux2014_x86_64.whl \
jax==0.6.2
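The three wheel URLs above follow a regular pattern: the release tag, the plugin or jaxlib version, and a CPython tag such as cp312. The sketch below derives them for the running interpreter, assuming later releases keep the same file naming; the platform tags (manylinux_2_28 and manylinux2014) are hard-coded from the v0.6.0 example, so check the release page before relying on them. Function names are hypothetical.

```python
import sys
from typing import List, Optional, Sequence

# Naming copied from the rocm-jax-v0.6.0 release above; later releases may differ.
RELEASE_BASE = "https://github.com/ROCm/rocm-jax/releases/download"

def python_tag(version_info: Sequence[int] = sys.version_info) -> str:
    """CPython wheel tag for an interpreter version, e.g. (3, 12) -> 'cp312'."""
    return f"cp{version_info[0]}{version_info[1]}"

def rocm_jax_wheel_urls(plugin_version: str, jax_version: str,
                        py_tag: Optional[str] = None) -> List[str]:
    """Return the pjrt, plugin, and jaxlib wheel URLs for one rocm-jax release."""
    tag = py_tag or python_tag()
    base = f"{RELEASE_BASE}/rocm-jax-v{plugin_version}"
    return [
        f"{base}/jax_rocm7_pjrt-{plugin_version}-py3-none-manylinux_2_28_x86_64.whl",
        f"{base}/jax_rocm7_plugin-{plugin_version}-{tag}-{tag}-manylinux_2_28_x86_64.whl",
        f"{base}/jaxlib-{jax_version}-{tag}-{tag}-manylinux2014_x86_64.whl",
    ]

def pip_install_args(plugin_version: str, jax_version: str,
                     py_tag: Optional[str] = None) -> List[str]:
    """Argument list equivalent to the pip3 install command above."""
    return ["pip3", "install",
            *rocm_jax_wheel_urls(plugin_version, jax_version, py_tag),
            f"jax=={jax_version}"]

if __name__ == "__main__":
    print(" ".join(pip_install_args("0.6.0", "0.6.2")))
```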
Step 4: Verify the Installed JAX Version
Check that the expected versions of JAX and its ROCm plugins are installed:
> pip3 freeze | grep jax
jax==0.6.2
jax-rocm7-pjrt @ https://github.com/ROCm/rocm-jax/releases/download/rocm-jax-v0.6.0/jax_rocm7_pjrt-0.6.0-py3-none-manylinux_2_28_x86_64.whl#sha256=b20b6820d4701a8edd83509dcbc8dc4fb712f40eab873668ae0dd17f5194c2d6
jax-rocm7-plugin @ https://github.com/ROCm/rocm-jax/releases/download/rocm-jax-v0.6.0/jax_rocm7_plugin-0.6.0-cp312-cp312-manylinux_2_28_x86_64.whl#sha256=cfecc2865ed450f996608b13af04189a2f9c1328ed896d71be0872d0e7d78389
jaxlib @ https://github.com/ROCm/rocm-jax/releases/download/rocm-jax-v0.6.0/jaxlib-0.6.2-cp312-cp312-manylinux2014_x86_64.whl#sha256=739fc2ebe28399f551a5c6daf529baae1637546a9a2a93789e3afd7ef0444e66
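Reading the freeze output by eye works for one machine; in CI it is easier to compare versions programmatically. A minimal sketch: for direct-URL installs like the ones above, the version is recovered from the wheel filename, whose second dash-separated field is the version under the standard wheel naming convention. The function names are hypothetical.

```python
from typing import Dict, Mapping, Optional, Tuple

def parse_pip_freeze(text: str) -> Dict[str, str]:
    """Map package names to versions from `pip freeze` output."""
    versions: Dict[str, str] = {}
    for raw in text.splitlines():
        line = raw.strip()
        if not line or line.startswith(("#", "-e ")):
            continue  # skip blanks, comments, and editable installs
        if " @ " in line:
            # Direct-URL install: {name}-{version}-{python}-{abi}-{platform}.whl
            name, url = line.split(" @ ", 1)
            filename = url.split("#", 1)[0].rsplit("/", 1)[-1]
            if not filename.endswith(".whl"):
                continue  # only wheel URLs carry a parseable version here
            version = filename.split("-")[1]
        elif "==" in line:
            name, version = line.split("==", 1)
        else:
            continue
        versions[name.strip().lower()] = version.strip()
    return versions

def version_mismatches(installed: Mapping[str, str],
                       expected: Mapping[str, str]) -> Dict[str, Tuple[Optional[str], str]]:
    """Return {name: (installed, expected)} for every package that does not match."""
    return {name: (installed.get(name), want)
            for name, want in expected.items()
            if installed.get(name) != want}
```

Feed parse_pip_freeze the stdout of `python3 -m pip freeze` and pass the result to version_mismatches together with the versions you intended to install; an empty dict means everything matches.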
Step 5: Set the LLVM_PATH Environment Variable
Explicitly set the LLVM_PATH environment variable (this helps XLA find ld.lld at runtime):
> export LLVM_PATH=/opt/rocm/llvm
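To confirm the linker is actually where XLA will look, a quick check can help. This sketch assumes ROCm's LLVM ships the linker at $LLVM_PATH/bin/ld.lld and falls back to searching PATH; both the layout assumption and the helper name are illustrative.

```python
import os
import shutil
from typing import Optional

def find_ld_lld(llvm_path: Optional[str] = None) -> Optional[str]:
    """Locate ld.lld, preferring <LLVM_PATH>/bin over the general PATH."""
    # Assumption: ROCm's LLVM places the linker at <LLVM_PATH>/bin/ld.lld.
    root = llvm_path or os.environ.get("LLVM_PATH", "/opt/rocm/llvm")
    candidate = os.path.join(root, "bin", "ld.lld")
    if os.path.isfile(candidate) and os.access(candidate, os.X_OK):
        return candidate
    return shutil.which("ld.lld")  # None if no ld.lld is found anywhere

if __name__ == "__main__":
    print("ld.lld:", find_ld_lld() or "not found; check LLVM_PATH")
```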
Step 6: Verify the Installation of ROCm JAX
Run the following command to verify that ROCm JAX is installed correctly:
> python3 -c "import jax; print(jax.devices())"
[RocmDevice(id=0), RocmDevice(id=1), RocmDevice(id=2), RocmDevice(id=3), RocmDevice(id=4), RocmDevice(id=5), RocmDevice(id=6), RocmDevice(id=7)]
> python3 -c "import jax.numpy as jnp; x = jnp.arange(5); print(x)"
[0 1 2 3 4]
3. Install JAX on Bare Metal or in a Custom Container
Follow these steps if you prefer to install ROCm manually on your host system or in a custom container.
Installing ROCm Libraries Manually
Step 1: Install ROCm
Follow the ROCm installation guide to install ROCm on your system.
Once installed, verify ROCm installation using:
> rocm-smi
============================================ ROCm System Management Interface ============================================
====================================================== Concise Info ======================================================
Device Node IDs Temp Power Partitions SCLK MCLK Fan Perf PwrCap VRAM% GPU%
(DID, GUID) (Junction) (Socket) (Mem, Compute, ID)
==========================================================================================================================
0 2 0x74a1, 28851 43.0°C 142.0W NPS1, SPX, 0 133Mhz 900Mhz 0% auto 750.0W 0% 0%
1 3 0x74a1, 23018 37.0°C 137.0W NPS1, SPX, 0 134Mhz 900Mhz 0% auto 750.0W 0% 0%
2 4 0x74a1, 29122 44.0°C 140.0W NPS1, SPX, 0 134Mhz 900Mhz 0% auto 750.0W 0% 0%
3 5 0x74a1, 22683 38.0°C 138.0W NPS1, SPX, 0 133Mhz 900Mhz 0% auto 750.0W 0% 0%
4 6 0x74a1, 53458 42.0°C 143.0W NPS1, SPX, 0 133Mhz 900Mhz 0% auto 750.0W 0% 0%
5 7 0x74a1, 63883 39.0°C 138.0W NPS1, SPX, 0 134Mhz 900Mhz 0% auto 750.0W 0% 0%
6 8 0x74a1, 53667 42.0°C 140.0W NPS1, SPX, 0 134Mhz 900Mhz 0% auto 750.0W 0% 0%
7 9 0x74a1, 63738 38.0°C 135.0W NPS1, SPX, 0 133Mhz 900Mhz 0% auto 750.0W 0% 0%
==========================================================================================================================
================================================== End of ROCm SMI Log ===================================================
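To compare the GPUs that rocm-smi reports with what JAX later sees, the concise table can be parsed. This sketch assumes the row layout shown above (device index, node id, then a hexadecimal device ID); the table format varies across rocm-smi versions, so treat the regular expression and the helper names as assumptions to adjust.

```python
import re
import subprocess
from typing import List

# Assumed row layout of the concise table: device index, node id, hex device ID, ...
_ROW_RE = re.compile(r"^\s*(\d+)\s+\d+\s+0x[0-9a-fA-F]+")

def gpu_indices_from_rocm_smi(text: str) -> List[int]:
    """Return the device indices listed in rocm-smi's concise table."""
    indices = []
    for line in text.splitlines():
        match = _ROW_RE.match(line)
        if match:
            indices.append(int(match.group(1)))
    return indices

def rocm_smi_gpu_indices() -> List[int]:
    """Run rocm-smi with no arguments and parse its default output."""
    out = subprocess.run(["rocm-smi"], capture_output=True, text=True, check=True).stdout
    return gpu_indices_from_rocm_smi(out)
```

Once JAX is installed, comparing len(rocm_smi_gpu_indices()) with the number of devices jax.devices() returns is a quick sanity check.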
Step 2: Install the Latest Version of JAX
Install the required version of JAX and the ROCm plugins using pip, following the instructions for the latest release. For example, on a system with Python 3.12, run the following to install JAX 0.6.2:
> pip3 install \
https://github.com/ROCm/rocm-jax/releases/download/rocm-jax-v0.6.0/jax_rocm7_pjrt-0.6.0-py3-none-manylinux_2_28_x86_64.whl \
https://github.com/ROCm/rocm-jax/releases/download/rocm-jax-v0.6.0/jax_rocm7_plugin-0.6.0-cp312-cp312-manylinux_2_28_x86_64.whl \
https://github.com/ROCm/rocm-jax/releases/download/rocm-jax-v0.6.0/jaxlib-0.6.2-cp312-cp312-manylinux2014_x86_64.whl \
jax==0.6.2
Step 3: Verify the Installed JAX Version
Check that the expected versions of JAX and its ROCm plugins are installed:
> pip3 freeze | grep jax
jax==0.6.2
jax-rocm7-pjrt @ https://github.com/ROCm/rocm-jax/releases/download/rocm-jax-v0.6.0/jax_rocm7_pjrt-0.6.0-py3-none-manylinux_2_28_x86_64.whl#sha256=b20b6820d4701a8edd83509dcbc8dc4fb712f40eab873668ae0dd17f5194c2d6
jax-rocm7-plugin @ https://github.com/ROCm/rocm-jax/releases/download/rocm-jax-v0.6.0/jax_rocm7_plugin-0.6.0-cp312-cp312-manylinux_2_28_x86_64.whl#sha256=cfecc2865ed450f996608b13af04189a2f9c1328ed896d71be0872d0e7d78389
jaxlib @ https://github.com/ROCm/rocm-jax/releases/download/rocm-jax-v0.6.0/jaxlib-0.6.2-cp312-cp312-manylinux2014_x86_64.whl#sha256=739fc2ebe28399f551a5c6daf529baae1637546a9a2a93789e3afd7ef0444e66
Step 4: Set the LLVM_PATH Environment Variable
Explicitly set the LLVM_PATH environment variable (this helps XLA find ld.lld at runtime):
> export LLVM_PATH=/opt/rocm/llvm
Step 5: Verify the Installation of ROCm JAX
Run the following command to verify that ROCm JAX is installed correctly:
> python3 -c "import jax; print(jax.devices())"
[RocmDevice(id=0), RocmDevice(id=1), RocmDevice(id=2), RocmDevice(id=3), RocmDevice(id=4), RocmDevice(id=5), RocmDevice(id=6), RocmDevice(id=7)]
> python3 -c "import jax.numpy as jnp; x = jnp.arange(5); print(x)"
[0 1 2 3 4]
4. Build ROCm JAX from Source
Follow these steps to build JAX with ROCm support from source:
Step 1: Build the ROCm-Specific Wheels from rocm-jax
Clone the rocm-jax repository for the desired branch:
> git clone https://github.com/ROCm/rocm-jax.git -b <branch_name>
> cd rocm-jax
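From the rocm-jax directory, run the following, with PYTHON_VERSION and ROCM_VERSION set to the Python and ROCm versions you are targeting: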
> python3 build/ci_build \
--python-version $PYTHON_VERSION \
--rocm_version $ROCM_VERSION \
dist_wheels
> pip3 install jax_rocm_plugin/wheelhouse/*.whl
The build will produce two wheels:
- jax-rocm-plugin (ROCm-specific plugin)
- jax-rocm-pjrt (ROCm-specific runtime)
Detailed build instructions can be found here.
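One pitfall with `pip3 install jax_rocm_plugin/wheelhouse/*.whl`: if the build produced no wheels, the shell passes the unmatched pattern to pip literally, which yields a less obvious error. A minimal sketch of a stricter install step, assuming the wheelhouse path from the command above; the function name is hypothetical.

```python
import glob
import os
import sys
from typing import List

def wheelhouse_install_cmd(wheelhouse: str = "jax_rocm_plugin/wheelhouse") -> List[str]:
    """Build a pip command for every wheel in the wheelhouse, failing if there are none."""
    wheels = sorted(glob.glob(os.path.join(wheelhouse, "*.whl")))
    if not wheels:
        # A shell would hand the unmatched "*.whl" to pip literally; fail early instead.
        raise FileNotFoundError(f"no .whl files in {wheelhouse!r}; did dist_wheels succeed?")
    return [sys.executable, "-m", "pip", "install", *wheels]
```

From the rocm-jax directory, run it with subprocess.run(wheelhouse_install_cmd(), check=True). Using sys.executable with -m pip also guarantees the wheels land in the same interpreter that runs the script.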
Step 2: Build jaxlib from the JAX Repository
Clone the ROCm-specific fork of JAX for the desired branch:
> git clone https://github.com/ROCm/jax -b <branch_name>
> cd jax
Run the following command to build the jaxlib wheel:
> python3 ./build/build.py build --wheels=jaxlib \
--rocm_version=7 --rocm_path=/opt/rocm-[version]
This generates the jaxlib wheel in the dist/ directory. jaxlib is a device-agnostic library.
Step 3: Install the Custom JAX Build
Install JAX from the cloned source tree, then install the jaxlib wheel from dist/:
> python3 setup.py develop --user && python3 -m pip install dist/*.whl
Simplified Build Script
For a streamlined process, consider using the jax/build/rocm/dev_build_rocm.py script.