CUDA/Torch KD-Tree K-Nearest Neighbor Operator
March 17, 2026
This repository implements a KD-Tree on CUDA with an interface for PyTorch. It is a port of a previous TensorFlow implementation called tf_kdtree.
The KD-Tree is always built on the CPU, but is automatically transferred to the GPU when torch operations run there. A query finds the k nearest neighbors of each queried point in logarithmic time (on average), so the implementation is best suited for repeated nearest neighbor queries in a static point cloud.
The algorithms' dimensions are currently defined through template parameters and must be known at compile-time. The present version compiles the library for the dimensionalities 1, 2, 3. See Compiling additional dimensions for instructions on how to compile additional dimensions.
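To illustrate the kind of search described above, here is a minimal pure-Python sketch of a KD-Tree build and k-nearest-neighbor query. It is not the library's CUDA implementation; the names Node, build, sq_dist and query are hypothetical and chosen only for this example. Like the library, the sketch reports squared distances.

```python
import heapq
from typing import List, Optional, Sequence, Tuple

Point = Sequence[float]


class Node:
    """One KD-tree node: a point index, the split axis, and two subtrees."""

    def __init__(self, index: int, axis: int,
                 left: "Optional[Node]", right: "Optional[Node]") -> None:
        self.index = index
        self.axis = axis
        self.left = left
        self.right = right


def build(points: List[Point], indices: Optional[List[int]] = None,
          depth: int = 0) -> Optional[Node]:
    """Recursively split the points at the median, cycling through the axes."""
    if indices is None:
        indices = list(range(len(points)))
    if not indices:
        return None
    axis = depth % len(points[0])
    indices = sorted(indices, key=lambda i: points[i][axis])
    mid = len(indices) // 2
    return Node(indices[mid], axis,
                build(points, indices[:mid], depth + 1),
                build(points, indices[mid + 1:], depth + 1))


def sq_dist(a: Point, b: Point) -> float:
    """Squared Euclidean distance between two points."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def query(root: Optional[Node], points: List[Point], q: Point,
          k: int) -> List[Tuple[float, int]]:
    """Return up to k (squared distance, index) pairs, closest first."""
    heap: List[Tuple[float, int]] = []  # max-heap via negated distances

    def visit(node: Optional[Node]) -> None:
        if node is None:
            return
        d = sq_dist(points[node.index], q)
        if len(heap) < k:
            heapq.heappush(heap, (-d, node.index))
        elif d < -heap[0][0]:
            heapq.heapreplace(heap, (-d, node.index))
        diff = q[node.axis] - points[node.index][node.axis]
        near, far = (node.left, node.right) if diff < 0 else (node.right, node.left)
        visit(near)
        # Only cross the splitting plane if a closer point could lie beyond it.
        if len(heap) < k or diff * diff < -heap[0][0]:
            visit(far)

    visit(root)
    return sorted((-neg_d, i) for neg_d, i in heap)


points = [(2.0, 3.0), (5.0, 4.0), (9.0, 6.0), (4.0, 7.0), (8.0, 1.0), (7.0, 2.0)]
root = build(points)
print(query(root, points, (9.0, 2.0), k=2))  # [(2.0, 4), (4.0, 5)]
```

The pruning test in visit is what lets a query skip whole subtrees; without it the search would degrade to a linear scan of all points.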
Usage Examples
from torch_kdtree import build_kd_tree
import torch
from scipy.spatial import KDTree #Reference implementation
import numpy as np
#Dimensionality of the points and KD-Tree
d = 3
#Specify the device on which we will operate
#Uses CUDA when available, otherwise falls back to CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#Create some random point clouds
points_ref = (torch.randn(size=(1000, d), dtype=torch.float32, device=device) * 1e3).requires_grad_()
points_query = (torch.randn(size=(100, d), dtype=torch.float32, device=device) * 1e3).requires_grad_()
#Create the KD-Tree on the GPU and the reference implementation
torch_kdtree = build_kd_tree(points_ref)
kdtree = KDTree(points_ref.detach().cpu().numpy())
#Search for the 5 nearest neighbors of each point in points_query
k = 5
dists, inds = torch_kdtree.query(points_query, nr_nns_searches=k)
dists_ref, inds_ref = kdtree.query(points_query.detach().cpu().numpy(), k=k)
#Test for correctness
#Note that the torch_kdtree distances are squared
assert(np.all(inds.cpu().numpy() == inds_ref))
assert(np.allclose(torch.sqrt(dists).detach().cpu().numpy(), dists_ref, atol=1e-5))
For batched workloads with one KD-tree per sample, you can build and query all trees with tensor-first shapes [B, N, D] and [B, M, D]:
from torch_kdtree import build_kd_tree_batched
import torch
from scipy.spatial import KDTree #Reference implementation
import numpy as np
B, N, M, D = 8, 10000, 100, 3
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Reference and query points are batched in the first dimension
points_ref_batched = torch.randn(B, N, D, dtype=torch.float32, device=device) * 1e3
points_query_batched = torch.randn(B, M, D, dtype=torch.float32, device=device) * 1e3
batched_kdtree = build_kd_tree_batched(points_ref_batched, device=device)
ref_kdtrees = [KDTree(p.detach().cpu().numpy()) for p in points_ref_batched]
#Search for the 5 nearest neighbors of each point, for each batch in points_query
k = 5
dists, inds = batched_kdtree.query(points_query_batched, nr_nns_searches=k)
dists_ref, inds_ref = zip(*[kdtree.query(p.detach().cpu().numpy(), k=k) for p, kdtree in zip(points_query_batched, ref_kdtrees)])
dists_ref = np.stack(dists_ref, axis=0)
inds_ref = np.stack(inds_ref, axis=0)
#Test for correctness
#Note that the torch_kdtree distances are squared
assert(np.all(inds.cpu().numpy() == inds_ref))
assert(np.allclose(torch.sqrt(dists).detach().cpu().numpy(), dists_ref, atol=1e-5))
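Continuing the first (non-batched) example, we can also compute the gradient w.r.t. both point clouds. The following check compares the gradient w.r.t. the query points against its closed-form expression.
(0.5 * torch.sum(dists)).backward()
grad = points_query.grad
#Closed-form gradient of 0.5 * sum(squared distances) w.r.t. the query points
grad_comp = torch.sum((points_query[:, None] - points_ref[inds]), dim=-2)
print(torch.allclose(grad, grad_comp)) #Should print True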
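The closed-form expression used in this check follows from differentiating the loss directly. As a short derivation under the assumption that the neighbor indices are treated as constants: write q_i for the query points, p_m for the reference points, and n_{ij} for the index of the j-th neighbor of q_i (the entry inds[i, j]). Because dists holds squared distances:

```latex
\begin{aligned}
L &= \frac{1}{2} \sum_{i} \sum_{j=1}^{k} \lVert q_i - p_{n_{ij}} \rVert^2, \\
\frac{\partial L}{\partial q_i} &= \sum_{j=1}^{k} \left( q_i - p_{n_{ij}} \right), \\
\frac{\partial L}{\partial p_m} &= - \sum_{(i,j)\,:\,n_{ij} = m} \left( q_i - p_m \right).
\end{aligned}
```

The second line is what grad_comp evaluates. The third line is the corresponding expression for the reference points, where each reference point accumulates a contribution from every query that selected it as a neighbor.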
Installation
Prerequisites
- Python
- NumPy (installed with setuptools)
- Torch (installed with setuptools)
- CUDA
- g++, or Visual Studio x64 (MacOSX is untested)
- CMake
Build Instructions
Clone the repository and fetch the submodule pybind11:
git clone https://github.com/thomgrand/torch_kdtree
cd torch_kdtree
git submodule init
git submodule update
The easiest way of installing the library is using setuptools:
pip install .
Tests
After installation, you can run python -m pytest . inside the folder tests to verify that the library has been installed correctly.
Benchmark
We compared the implementation to scipy.spatial.KDTree (run with workers=-1 to use all available CPU cores). Note that the benchmarks include neither the time to build the KD-Trees nor the transfer to the GPU. Times greater than 1 second are not shown.
Test machine specs: AMD Ryzen Threadripper 3970X (32 cores, 3.7 GHz), 128 GB of RAM and an NVIDIA RTX 3090 GPU.

To run the benchmark on your computer, run python benchmark/benchmark.py. This creates benchmark_results.npz, which can be converted to a figure using python benchmark/plot_benchmark.py (requires matplotlib).
Compiling additional dimensions
The dimensionality of the KD-Tree is fixed at compile time, meaning that every dimensionality to be queried must be known when the library is built. By default, the library is compiled for d in [1, 2, 3]. You can support additional dimensionalities by adding new template instantiations in three places.
To add dimensionality 8, for example, add the following code snippets:
src/interface.cpp (line 115)
KDTREE_INSTANTIATION(float, 8, false, "KDTreeCPU8DF");
KDTREE_INSTANTIATION(double, 8, false, "KDTreeCPU8D");
KDTREE_INSTANTIATION(float, 8, true, "KDTreeGPU8DF");
KDTREE_INSTANTIATION(double, 8, true, "KDTreeGPU8D");
In src/kdtree_g.cu (line 476) and src/kdtree.cpp (line 221), you add the same code:
KDTREE_INSTANTIATION(float, 8);
KDTREE_INSTANTIATION(double, 8);
This will instantiate the template functions for float and double types both on the CPU and GPU.
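If you need several extra dimensionalities, the lines above can be generated rather than typed by hand. The following is a small sketch, not part of the repository: instantiation_lines is a hypothetical helper, and the naming pattern of the binding strings (for example "KDTreeCPU8DF") is inferred from the dimension-8 example above and is an assumption for other dimensions. You still have to paste the output into the three source files yourself.

```python
from typing import Dict, List


def instantiation_lines(dim: int) -> Dict[str, List[str]]:
    """Return the template instantiation lines to add for one dimensionality."""
    interface = []
    for on_gpu in (False, True):
        for dtype, suffix in (("float", "F"), ("double", "")):
            device = "GPU" if on_gpu else "CPU"
            # Assumed naming pattern, copied from the dimension-8 example.
            name = f"KDTree{device}{dim}D{suffix}"
            flag = "true" if on_gpu else "false"
            interface.append(
                f'KDTREE_INSTANTIATION({dtype}, {dim}, {flag}, "{name}");')
    backend = [f"KDTREE_INSTANTIATION({dtype}, {dim});"
               for dtype in ("float", "double")]
    return {
        "src/interface.cpp": interface,
        "src/kdtree_g.cu": backend,
        "src/kdtree.cpp": backend,
    }


# Print the lines for dimensionality 8, grouped by target file.
for path, lines in instantiation_lines(8).items():
    print(path)
    for line in lines:
        print("    " + line)
```

After adding the generated lines, rebuild the library (pip install .) so the new instantiations are compiled.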
Limitations
- No multi-GPU support
- Int32 KNN indexing inside the library
- Data must be cast to contiguous arrays before processing (automatically done by the library)
- No in-place updates of the KD-Tree. If you modify the point-cloud, you will have to create a new KD-Tree.
Acknowledgements
If this work helps you in your research, please consider acknowledging the GitHub repository or citing the paper from which the library originated.
@article{grandits_geasi_2021,
title = {{GEASI}: {Geodesic}-based earliest activation sites identification in cardiac models},
volume = {37},
issn = {2040-7947},
shorttitle = {{GEASI}},
url = {https://onlinelibrary.wiley.com/doi/abs/10.1002/cnm.3505},
doi = {10.1002/cnm.3505},
language = {en},
number = {8},
urldate = {2021-08-12},
journal = {International Journal for Numerical Methods in Biomedical Engineering},
author = {Grandits, Thomas and Effland, Alexander and Pock, Thomas and Krause, Rolf and Plank, Gernot and Pezzuto, Simone},
year = {2021},
keywords = {eikonal equation, cardiac model personalization, earliest activation sites, Hamilton–Jacobi formulation, inverse ECG problem, topological gradient},
pages = {e3505}
}