Divergence Measures for Probability Distributions in NumKong

April 2, 2026 · View on GitHub

NumKong implements divergence functions between discrete probability distributions: Kullback-Leibler divergence measures the information lost when one distribution approximates another, while Jensen-Shannon distance provides a symmetric and bounded alternative. These are used in variational inference, topic modeling, and distribution comparison tasks.

Kullback-Leibler divergence from PP to QQ:

KLD(PQ)=i=0n1P(i)log2P(i)Q(i)\text{KLD}(P \| Q) = \sum_{i=0}^{n-1} P(i) \log_2 \frac{P(i)}{Q(i)}

Jensen-Shannon distance is the square root of the symmetrized KLD through a mixture:

JSD(P,Q)=12KLD(PM)+12KLD(QM)\text{JSD}(P, Q) = \frac{1}{2} \text{KLD}(P \| M) + \frac{1}{2} \text{KLD}(Q \| M)

where M=P+Q2M = \frac{P + Q}{2}, yielding the distance:

dJS(P,Q)=JSD(P,Q)d_{JS}(P, Q) = \sqrt{\text{JSD}(P, Q)}

Unlike the raw divergence, dJSd_{JS} is a true metric satisfying the triangle inequality.

Reformulating as Python pseudocode:

import numpy as np

def kld(p: np.ndarray, q: np.ndarray) -> float:
    mask = p > 0
    return np.sum(p[mask] * np.log2(p[mask] / q[mask]))

def jsd(p: np.ndarray, q: np.ndarray) -> float:
    m = (p + q) / 2
    return np.sqrt((kld(p, m) + kld(q, m)) / 2)

Use Cases

Kullback-Leibler divergence is widely used in variational inference (ELBO objective), knowledge distillation between neural networks, information gain in decision trees, and measuring fit between a model and observed data.

Jensen-Shannon distance is commonly used in microbiome community comparison (enterotyping), where its metric property enables clustering with standard algorithms. It also appears in distribution drift detection, topic model evaluation, and as the theoretical foundation of the original GAN objective — though in practice GAN training uses proxy losses rather than computing JSD directly.

Input & Output Types

Input TypeOutput TypeDescription
f64f6464-bit IEEE 754 double precision
f32f3232-bit IEEE 754 single precision
f16f3216-bit IEEE 754 half precision, widened output
bf16f3216-bit brain float, widened output

Optimizations

SIMD Log2 Approximation

nk_kld_f32_skylake, nk_jsd_f32_skylake use VGETEXP and VGETMANT to decompose floating-point values into exponent and mantissa components, then apply a polynomial approximation to the mantissa to compute log2\log_2. The pipeline on Skylake is:

exponent = VGETEXPPS(x)
mantissa = VGETMANTPS(x, normalize_to_[1,2)) - 1
log2(x) ≈ exponent + polynomial(mantissa)

VGETEXP extracts the unbiased exponent as a float, while VGETMANT normalizes the mantissa to [1,2)[1, 2). A degree-4 minimax polynomial over the normalized mantissa completes the approximation. These instructions handle subnormals correctly without extra integer bit manipulation.

nk_kld_f32_neon, nk_jsd_f32_neon, nk_kld_f16_haswell, nk_jsd_f16_haswell use integer bit extraction instead:

exponent = (reinterpret_as_int(x) >> 23) - 127
mantissa = reinterpret_as_float((reinterpret_as_int(x) & 0x7FFFFF) | 0x3F800000) - 1
log2(x) ≈ exponent + c₁·m + c₂·m² + c₃·m³ + c₄·m⁴ + c₅·m⁵

This approach reinterprets the float as an integer, shifts out the mantissa bits to obtain the exponent, then masks and recombines to produce a normalized mantissa in [1,2)[1, 2). It works on any ISA with integer-float reinterpretation and avoids the need for specialized exponent/mantissa instructions.

Kahan Compensated Summation for Float64

nk_kld_f64_haswell, nk_jsd_f64_haswell use Kahan compensated summation to maintain a running correction term alongside the accumulator. The Kahan update for each divergence term is:

compensated_term = divergence_term - correction
tentative_sum    = accumulator + compensated_term
correction       = (tentative_sum - accumulator) - compensated_term
accumulator      = tentative_sum

After each P(i)log2(P(i)/Q(i))P(i) \log_2(P(i) / Q(i)) term is computed, correction captures the low-order bits lost in the addition, and the next iteration subtracts this correction from the new term before adding it to the accumulator. This keeps the accumulated error bounded by O(1)O(1) ULP regardless of vector length, rather than the O(n)O(n) ULP growth of naive summation.

Performance

The following performance tables are produced by manually re-running nk_test and nk_bench included internal tools to measure both accuracy and throughput at different input shapes. The input size is controlled by the NK_DENSE_DIMENSIONS environment variable and set to 256, 1024, and 4096 elements. The throughput is measured in GB/s as the number of input bytes per second. The published tables below summarize mean ULP (units in last place) across all test pairs — the average number of representable floating-point values between the computed result and the exact answer. The current nk_test family also reports max/mean absolute and relative divergence error for detailed inspection. Each kernel runs for at least 20 seconds per configuration. Benchmark threads are pinned to specific cores; on machines with heterogeneous core types (e.g., Apple P/E cores), only the fastest cores are used. Workloads that significantly degrade CPU frequencies (Intel AMX, Apple SME) run in separate passes to avoid affecting throughput measurements of other kernels.

Intel Sapphire Rapids

Native

Kernel25610244096
f64░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░
nk_kld_f64_serial0.693 gb/s, 5.65K ulp0.699 gb/s, 24.5K ulp0.753 gb/s, 98.9K ulp
nk_jsd_f64_serial0.324 gb/s, 0.5 ulp0.349 gb/s, 0.3 ulp0.391 gb/s, 0.6 ulp
nk_kld_f64_haswell5.34 gb/s, 5.64K ulp5.59 gb/s, 24.6K ulp5.76 gb/s, 99.1K ulp
nk_jsd_f64_haswell3.03 gb/s, 1.7 ulp3.05 gb/s, 1.4 ulp3.25 gb/s, 1.2 ulp
nk_kld_f64_skylake7.01 gb/s, 5.64K ulp6.85 gb/s, 24.4K ulp6.86 gb/s, 98.9K ulp
nk_jsd_f64_skylake3.66 gb/s, 1.6 ulp3.85 gb/s, 1.4 ulp4.30 gb/s, 1.2 ulp
f32░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░
nk_kld_f32_serial0.528 gb/s, 1.04K ulp0.516 gb/s, 4.54K ulp0.527 gb/s, 18.2K ulp
nk_jsd_f32_serial0.273 gb/s, 0.4 ulp0.272 gb/s, 0.4 ulp0.268 gb/s, 4.5 ulp
nk_kld_f32_skylake11.8 gb/s, 1.04K ulp10.4 gb/s, 4.55K ulp8.73 gb/s, 18.3K ulp
nk_jsd_f32_skylake6.25 gb/s, 6.6 ulp5.96 gb/s, 7.0 ulp6.05 gb/s, 11.1 ulp
bf16░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░
nk_kld_bf16_serial0.138 gb/s, 1.04K ulp0.142 gb/s, 4.53K ulp0.136 gb/s, 18.3K ulp
nk_jsd_bf16_serial0.0857 gb/s, 1.5 ulp0.0842 gb/s, 3.4 ulp0.0841 gb/s, 10.7 ulp
f16░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░
nk_kld_f16_serial0.166 gb/s, 1.05K ulp0.163 gb/s, 4.53K ulp0.163 gb/s, 18.2K ulp
nk_jsd_f16_serial0.151 gb/s, 1.5 ulp0.148 gb/s, 2.3 ulp0.152 gb/s, 9.4 ulp
nk_kld_f16_haswell6.99 gb/s, 1.05K ulp6.09 gb/s, 4.54K ulp6.97 gb/s, 18.2K ulp
nk_jsd_f16_haswell2.81 gb/s, 6.4 ulp2.79 gb/s, 6.8 ulp2.72 gb/s, 11.5 ulp
nk_kld_f16_skylake6.16 gb/s, 1.05K ulp5.65 gb/s, 4.54K ulp5.78 gb/s, 18.3K ulp
nk_jsd_f16_skylake3.51 gb/s, 6.5 ulp3.22 gb/s, 6.9 ulp3.35 gb/s, 11.4 ulp

WASM

Measured with Wasmtime v42 (Cranelift backend).

Kernel25610244096
f64░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░
nk_kld_f64_serial0.239 gb/s, 5.64K ulp0.223 gb/s, 24.6K ulp0.13 gb/s, 99.6K ulp
nk_jsd_f64_serial0.315 gb/s, 0.5 ulp0.402 gb/s, 0.3 ulp0.29 gb/s, 0.5 ulp
f32░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░
nk_kld_f32_serial0.302 gb/s, 1.04K ulp0.342 gb/s, 4.52K ulp0.277 gb/s, 18.3K ulp
nk_jsd_f32_serial0.152 gb/s, 0.4 ulp0.164 gb/s, 0.4 ulp0.160 gb/s, 4.7 ulp
bf16░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░
nk_kld_bf16_serial0.139 gb/s, 1.05K ulp0.143 gb/s, 4.53K ulp0.150 gb/s, 18.3K ulp
nk_jsd_bf16_serial0.0867 gb/s, 1.5 ulp0.0775 gb/s, 3.1 ulp0.0679 gb/s, 9.8 ulp
f16░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░
nk_kld_f16_serial0.118 gb/s, 1.04K ulp0.127 gb/s, 4.53K ulp0.111 gb/s, 18.3K ulp
nk_jsd_f16_serial0.0748 gb/s, 1.4 ulp0.0681 gb/s, 2.6 ulp0.0857 gb/s, 9.7 ulp

Apple M5

Native

Kernel25610244096
f64░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░
nk_kld_f64_serial3.22 gb/s, 5.6K ulp3.36 gb/s, 25K ulp3.32 gb/s, 99K ulp
nk_jsd_f64_serial2.06 gb/s, 0.4 ulp2.17 gb/s, 0.4 ulp2.17 gb/s, 0.5 ulp
f32░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░
nk_kld_f32_serial9.26 gb/s, 1.0K ulp8.73 gb/s, 4.5K ulp9.10 gb/s, 18K ulp
nk_jsd_f32_serial2.08 gb/s, 0.4 ulp2.16 gb/s, 0.4 ulp2.13 gb/s, 4.6 ulp
nk_kld_f32_neon19.0 gb/s, 1.0K ulp17.4 gb/s, 4.5K ulp18.1 gb/s, 18K ulp
nk_jsd_f32_neon9.75 gb/s, 15 ulp9.32 gb/s, 14 ulp9.62 gb/s, 9.9 ulp
bf16░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░
nk_kld_bf16_serial4.58 gb/s, 1.0K ulp4.47 gb/s, 4.5K ulp4.65 gb/s, 18K ulp
nk_jsd_bf16_serial1.08 gb/s, 1.4 ulp1.07 gb/s, 2.9 ulp1.09 gb/s, 9.7 ulp
f16░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░
nk_kld_f16_serial4.63 gb/s, 1.0K ulp4.45 gb/s, 4.5K ulp4.55 gb/s, 18K ulp
nk_jsd_f16_serial1.03 gb/s, 1.4 ulp0.962 gb/s, 2.7 ulp0.976 gb/s, 8.7 ulp
nk_kld_f16_neonhalf10.2 gb/s, 1.0K ulp9.67 gb/s, 4.5K ulp9.99 gb/s, 18K ulp
nk_jsd_f16_neonhalf5.00 gb/s, 15 ulp4.79 gb/s, 14 ulp4.94 gb/s, 9.9 ulp