December 11, 2025
SciPy-like differential evolution for JAX
Fully jitted optimization of any JAX-compatible function. Serial and parallel execution on CPU, GPU, and TPU.
Installation
pip install mutax
Quick start
import jax.numpy as jnp
from mutax import differential_evolution
def cost_function(xs):
    return jnp.sum(xs**2)
bounds = [(-5, 5)] * 10 # 10-dimensional problem with bounds for each dimension
result = differential_evolution(cost_function, bounds)
print("Best solution:", result.x)
print("Objective value:", result.fun)
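The quick-start objective is an ordinary JAX function, so it can be checked for `jit` and `vmap` compatibility before it is handed to the optimizer. The sketch below does this with plain JAX only. The population size of 32 and the uniform sampling are illustrative assumptions, not a description of how mutax evaluates candidates internally.

```python
import jax
import jax.numpy as jnp


def cost_function(xs):
    # Sphere function: global minimum of 0 at the origin.
    return jnp.sum(xs**2)


# Same bounds as the quick start, as an array of shape (10, 2).
bounds = jnp.array([(-5.0, 5.0)] * 10)
lower, upper = bounds[:, 0], bounds[:, 1]

# Hypothetical candidate population sampled uniformly inside the bounds
# (assumption: 32 candidates, chosen only for illustration).
key = jax.random.PRNGKey(0)
population = jax.random.uniform(key, (32, 10), minval=lower, maxval=upper)

# jit + vmap: a single compiled call evaluates every candidate.
batched_cost = jax.jit(jax.vmap(cost_function))
costs = batched_cost(population)

print("Cost array shape:", costs.shape)
print("Best sampled cost:", costs.min())
```

If this runs without tracer or concretization errors, the function avoids Python-side control flow on traced values, which is what a fully jitted optimizer requires of its objective.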
Documentation
The documentation is available at Read the Docs.