KANX: Fast Implementation (Approximation) of Kolmogorov-Arnold Network in JAX

May 20, 2024

Work in progress

Introduction

KANX is a fast Kolmogorov-Arnold Network (KAN) implementation in JAX. It is based on fast-kan and built with Equinox.

The original implementation of KAN is pykan.
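
To make the "approximation" in the title concrete, here is a minimal sketch of one KAN-style layer in which each learned univariate function is a weighted sum of Gaussian radial basis functions on a fixed grid. It assumes the fast-kan idea of replacing B-splines with Gaussian bumps. It is not KANX's actual code: the names `init_layer`, `rbf_basis` and `kan_layer`, the parameter layout, and the grid defaults are all hypothetical, and plain JAX functions are used instead of Equinox modules to keep the block self-contained.

```python
import jax
import jax.numpy as jnp


def init_layer(key, in_dim, out_dim, num_grids=8):
    """Create parameters for one RBF-based KAN layer (hypothetical layout)."""
    # One weight per (output, input, basis function) triple.
    scale = 1.0 / jnp.sqrt(in_dim * num_grids)
    return {"w": scale * jax.random.normal(key, (out_dim, in_dim, num_grids))}


def rbf_basis(x, grid_min=-2.0, grid_max=2.0, num_grids=8):
    """Evaluate Gaussian bumps centred on a uniform grid for every input."""
    centers = jnp.linspace(grid_min, grid_max, num_grids)
    width = (grid_max - grid_min) / (num_grids - 1)
    # x has shape (in_dim,); the result has shape (in_dim, num_grids).
    return jnp.exp(-(((x[:, None] - centers) / width) ** 2))


def kan_layer(params, x):
    """Sum the learned univariate functions over inputs: y_j = sum_i phi_ji(x_i)."""
    basis = rbf_basis(x, num_grids=params["w"].shape[-1])
    return jnp.einsum("oig,ig->o", params["w"], basis)


key = jax.random.PRNGKey(0)
params = init_layer(key, in_dim=4, out_dim=3)
batch = jnp.ones((5, 4))
# vmap applies the single-example layer across the batch dimension.
out = jax.vmap(kan_layer, in_axes=(None, 0))(params, batch)
print(out.shape)
```

Because the basis functions are fixed and only the mixing weights are learned, the whole layer reduces to an elementwise exponential followed by a tensor contraction, which is the kind of computation JAX can compile efficiently.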

Installation

pip install .
pip install -r requirements.txt

Example

KANX comes with an example on MNIST:

python examples/train_mnist.py

Benchmark

We tested the implementation on MNIST and report the following wall times for 3000 epochs:

Architecture         Wall time (sec)
CPU (i5-1135G7)      130.51
CPU (i9-12900K)      67.85
GPU (RTX 3070 Ti)    13.55
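
For readers who want to collect comparable numbers on their own hardware, the following is a hedged sketch of one way to measure wall time for JAX code. It is not the script used for the figures above. The `step` function is a hypothetical stand-in for a training step, and `timed_run` is an illustrative helper. The points it demonstrates are a warm-up call to exclude compilation time and `block_until_ready()` to account for JAX's asynchronous dispatch.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def step(w, x):
    # Stand-in for one training step: a single gradient-descent update
    # on a toy least-squares loss (hypothetical, not the MNIST model).
    loss_fn = lambda w: jnp.mean((x @ w) ** 2)
    return w - 0.1 * jax.grad(loss_fn)(w)


def timed_run(w, x, num_steps):
    # Warm-up call so that JIT compilation is not included in the timing.
    step(w, x).block_until_ready()
    start = time.perf_counter()
    for _ in range(num_steps):
        w = step(w, x)
    # JAX dispatches work asynchronously; wait for the final result
    # before stopping the clock.
    w.block_until_ready()
    return w, time.perf_counter() - start


x = jax.random.normal(jax.random.PRNGKey(0), (64, 8))
w0 = jnp.ones((8,))
w, seconds = timed_run(w0, x, num_steps=100)
print(f"{seconds:.4f} s on {jax.devices()[0].platform}")
```

Without the final `block_until_ready()`, the timer could stop before the device has finished computing, which would understate the wall time, especially on a GPU.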

Plots from the GPU experiment:

[Figure: mlp_kan_compare]

More experiments to come...