KANX: Fast Implementation (Approximation) of Kolmogorov-Arnold Network in JAX
May 20, 2024
Work in progress
Introduction
KANX is a fast Kolmogorov-Arnold Network (KAN) implementation in JAX, based on fast-kan and built with equinox.
The original implementation of KAN is pykan.
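For orientation, here is a minimal sketch of the approximation that fast-kan uses: the B-spline basis of the original KAN is replaced by Gaussian radial basis functions on a fixed uniform grid, with a SiLU-activated linear branch added to the result. This is not KANX's actual API. The function names (`init_layer`, `rbf_basis`, `kan_layer`), the parameter layout, and the defaults (grid range `[-2, 2]`, 8 grid points) are assumptions chosen for illustration, and the sketch uses plain JAX rather than equinox to stay self-contained.

```python
import jax
import jax.numpy as jnp


def init_layer(key, in_dim, out_dim, num_grids=8):
    """Create parameters for one RBF-based KAN layer (hypothetical helper)."""
    k_spline, k_base = jax.random.split(key)
    return {
        # One weight per (input feature, grid point) pair and output unit.
        "spline_w": 0.1 * jax.random.normal(k_spline, (in_dim * num_grids, out_dim)),
        # Weights of the SiLU-activated linear branch.
        "base_w": jax.random.normal(k_base, (in_dim, out_dim)) / jnp.sqrt(in_dim),
    }


def rbf_basis(x, grid_min=-2.0, grid_max=2.0, num_grids=8):
    """Evaluate Gaussian RBFs centered on a uniform grid.

    x has shape (batch, in_dim); the result has shape (batch, in_dim, num_grids).
    """
    centers = jnp.linspace(grid_min, grid_max, num_grids)
    width = (grid_max - grid_min) / (num_grids - 1)  # spacing between centers
    return jnp.exp(-(((x[..., None] - centers) / width) ** 2))


def kan_layer(params, x):
    """Apply one layer: linear map over the RBF features plus a SiLU branch."""
    in_dim = x.shape[-1]
    num_grids = params["spline_w"].shape[0] // in_dim
    basis = rbf_basis(x, num_grids=num_grids)
    spline_out = basis.reshape(x.shape[0], -1) @ params["spline_w"]
    base_out = jax.nn.silu(x) @ params["base_w"]
    return spline_out + base_out


# Usage: a batch of 2 inputs with 4 features mapped to 3 outputs.
params = init_layer(jax.random.PRNGKey(0), in_dim=4, out_dim=3)
x = jnp.ones((2, 4))
y = kan_layer(params, x)
print(y.shape)
```

Because the basis functions are closed-form Gaussians on a fixed grid, the whole layer reduces to elementwise operations and two matrix multiplications, which is what makes this approximation cheap to evaluate and easy to `jax.jit`.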
Installation
pip install .
pip install -r requirements.txt
Example
KANX comes with an example on MNIST:
python examples/train_mnist.py
Benchmark
We tested the implementation on MNIST and report the following wall-clock times for 3000 epochs:
| Hardware | Wall time (s) |
|---|---|
| CPU (i5-1135G7) | 130.51 |
| CPU (i9-12900K) | 67.85 |
| GPU (RTX 3070 Ti) | 13.55 |
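The README does not say how these timings were taken. As a general note for anyone reproducing them, JAX dispatches work asynchronously, so a wall-clock measurement should wait for results with `jax.block_until_ready` and exclude the first (compiling) call. The sketch below shows one way to do that. `time_fn` and `step` are hypothetical names and are not part of KANX.

```python
import time

import jax
import jax.numpy as jnp


def time_fn(fn, *args, repeats=10):
    """Return the mean wall-clock time in seconds per call of fn(*args)."""
    # Warm-up call: triggers JIT compilation so it is not included in the timing.
    jax.block_until_ready(fn(*args))
    start = time.perf_counter()
    for _ in range(repeats):
        # Block until the device has finished; otherwise only dispatch is timed.
        jax.block_until_ready(fn(*args))
    return (time.perf_counter() - start) / repeats


@jax.jit
def step(x):
    # Stand-in workload; replace with a real training step.
    return jnp.tanh(x @ x.T).sum()


elapsed = time_fn(step, jnp.ones((256, 256)))
print(f"{elapsed:.6f} s per call on {jax.devices()[0].platform}")
```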
Plots from the GPU experiment:
More experiments to come...