TorchPairwise

May 20, 2025

This package provides highly-efficient pairwise metrics for PyTorch.

Highlights

torchpairwise is a collection of general-purpose pairwise metric functions that behave similarly to torch.cdist (which only implements the L_p distance). In contrast, we offer many more metrics ported from other packages such as scipy.spatial.distance and sklearn.metrics.pairwise. For task-specific metrics (e.g. for evaluating classification, regression, clustering, ...), you are in the wrong place; please head to the TorchMetrics repo.
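To make the terminology concrete: a pairwise metric takes two collections of vectors, x1 with m rows and x2 with n rows, and returns an m × n matrix whose entry (i, j) is the metric between x1[i] and x2[j]. The stdlib-only sketch below spells out the L_p (Minkowski) case that torch.cdist covers. The helper names are hypothetical and only illustrate the definition; they are not torchpairwise's batched C++/CUDA kernels.

```python
import math
from typing import List, Sequence


def minkowski(u: Sequence[float], v: Sequence[float], p: float = 2.0) -> float:
    """L_p distance between two equal-length vectors (p = inf gives Chebyshev)."""
    diffs = [abs(a - b) for a, b in zip(u, v)]
    if math.isinf(p):
        return max(diffs)
    return sum(d ** p for d in diffs) ** (1.0 / p)


def pairwise_lp(x1, x2, p: float = 2.0) -> List[List[float]]:
    """Hypothetical helper: D[i][j] = ||x1[i] - x2[j]||_p, shape len(x1) x len(x2)."""
    return [[minkowski(a, b, p) for b in x2] for a in x1]


x1 = [[0.0, 0.0], [1.0, 1.0], [3.0, 4.0]]  # m = 3 rows
x2 = [[0.0, 0.0], [3.0, 4.0]]              # n = 2 rows
D = pairwise_lp(x1, x2)
print(len(D), len(D[0]))  # 3 2
print(D[2])               # [5.0, 0.0]
```

Every metric in the tables below keeps this m × n output contract; only the per-pair function changes.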

Our metrics are written with torch's C++ API. The main differences are that they:

  • are all differentiable (except some boolean distances), with backward formulas manually derived, implemented, and verified with torch.autograd.gradcheck.
  • are batched and can exploit GPU parallelization.
  • can be integrated seamlessly within PyTorch-based projects; all functions are torch.jit.script-able.
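torch.autograd.gradcheck validates a hand-written backward formula by comparing it against numerical (finite-difference) gradients. The stdlib sketch below applies the same principle to the Euclidean distance d(x, y) = ||x - y||_2, whose analytical gradient with respect to x is (x - y) / d. It illustrates the verification idea only; it is not torchpairwise's test code, and the function names are hypothetical.

```python
import math


def euclidean(x, y):
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(x, y)))


def grad_euclidean_x(x, y):
    """Analytical gradient of ||x - y||_2 w.r.t. x (undefined when x == y)."""
    d = euclidean(x, y)
    return [(a - b) / d for a, b in zip(x, y)]


def numerical_grad(f, x, eps=1e-6):
    """Central finite differences: (f(x + eps*e_k) - f(x - eps*e_k)) / (2*eps)."""
    grad = []
    for k in range(len(x)):
        x_plus = list(x)
        x_plus[k] += eps
        x_minus = list(x)
        x_minus[k] -= eps
        grad.append((f(x_plus) - f(x_minus)) / (2 * eps))
    return grad


def check_gradient(x, y, atol=1e-5):
    """True if the analytical and numerical gradients agree within atol."""
    analytic = grad_euclidean_x(x, y)
    numeric = numerical_grad(lambda v: euclidean(v, y), x)
    return all(abs(a - n) <= atol for a, n in zip(analytic, numeric))


print(check_gradient([1.0, 2.0, 3.0], [0.5, -1.0, 2.0]))  # True
```

Like gradcheck, this kind of check is only meaningful in double precision (Python floats are 64-bit) and away from non-differentiable points such as x == y.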

List of pairwise distance metrics

| torchpairwise ops | Equivalents in other libraries | Differentiable |
|---|---|---|
| euclidean_distances | sklearn.metrics.pairwise.euclidean_distances | ✔️ |
| haversine_distances | sklearn.metrics.pairwise.haversine_distances | ✔️ |
| manhattan_distances | sklearn.metrics.pairwise.manhattan_distances | ✔️ |
| cosine_distances | sklearn.metrics.pairwise.cosine_distances | ✔️ |
| l1_distances | (Alias of manhattan_distances) | ✔️ |
| l2_distances | (Alias of euclidean_distances) | ✔️ |
| lp_distances | (Alias of minkowski_distances) | ✔️ |
| linf_distances | (Alias of chebyshev_distances) | ✔️ |
| directed_hausdorff_distances | scipy.spatial.distance.directed_hausdorff [1] | ✔️ |
| minkowski_distances | scipy.spatial.distance.minkowski [1] | ✔️ |
| wminkowski_distances | scipy.spatial.distance.wminkowski [1] | ✔️ |
| sqeuclidean_distances | scipy.spatial.distance.sqeuclidean [1] | ✔️ |
| correlation_distances | scipy.spatial.distance.correlation [1] | ✔️ |
| hamming_distances | scipy.spatial.distance.hamming [1][2] | |
| jaccard_distances | scipy.spatial.distance.jaccard [1][2] | |
| kulsinski_distances | scipy.spatial.distance.kulsinski [1][2] | |
| kulczynski1_distances | scipy.spatial.distance.kulczynski1 [1][2] | |
| seuclidean_distances | scipy.spatial.distance.seuclidean [1] | ✔️ |
| cityblock_distances | scipy.spatial.distance.cityblock [1] (Alias of manhattan_distances) | ✔️ |
| mahalanobis_distances | scipy.spatial.distance.mahalanobis [1] | ✔️ |
| chebyshev_distances | scipy.spatial.distance.chebyshev [1] | ✔️ |
| braycurtis_distances | scipy.spatial.distance.braycurtis [1] | ✔️ |
| canberra_distances | scipy.spatial.distance.canberra [1] | ✔️ |
| jensenshannon_distances | scipy.spatial.distance.jensenshannon [1] | ✔️ |
| yule_distances | scipy.spatial.distance.yule [1][2] | |
| dice_distances | scipy.spatial.distance.dice [1][2] | |
| rogerstanimoto_distances | scipy.spatial.distance.rogerstanimoto [1][2] | |
| russellrao_distances | scipy.spatial.distance.russellrao [1][2] | |
| sokalmichener_distances | scipy.spatial.distance.sokalmichener [1][2] | |
| sokalsneath_distances | scipy.spatial.distance.sokalsneath [1][2] | |
| snr_distances | pytorch_metric_learning.distances.SNRDistance [1] | ✔️ |
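The boolean distances marked [2] are defined in SciPy's documentation through four agreement counts between two boolean vectors: c_TT (both True), c_TF, c_FT, and c_FF. The plain-Python sketch below writes out a few of those documented formulas so the table entries are easier to interpret. It follows SciPy's definitions and is not torchpairwise's implementation; the function names are hypothetical, and the jaccard convention of returning 0 when both vectors are all False mirrors SciPy's.

```python
import math
from typing import Sequence, Tuple


def counts(u: Sequence[bool], v: Sequence[bool]) -> Tuple[int, int, int, int]:
    """Return (c_TT, c_TF, c_FT, c_FF) for two equal-length boolean vectors."""
    tt = sum(a and b for a, b in zip(u, v))
    tf = sum(a and not b for a, b in zip(u, v))
    ft = sum((not a) and b for a, b in zip(u, v))
    ff = sum((not a) and (not b) for a, b in zip(u, v))
    return tt, tf, ft, ff


def hamming(u, v):
    """Fraction of positions where u and v differ."""
    _, tf, ft, _ = counts(u, v)
    return (tf + ft) / len(u)


def jaccard(u, v):
    tt, tf, ft, _ = counts(u, v)
    denom = tt + tf + ft
    return (tf + ft) / denom if denom else 0.0  # both all-False -> 0


def dice(u, v):
    tt, tf, ft, _ = counts(u, v)
    denom = 2 * tt + tf + ft
    return (tf + ft) / denom if denom else math.nan  # 0/0 is undefined


def russellrao(u, v):
    tt, _, _, _ = counts(u, v)
    return (len(u) - tt) / len(u)


def rogerstanimoto(u, v):
    tt, tf, ft, ff = counts(u, v)
    r = 2 * (tf + ft)
    return r / (tt + ff + r)


u = [True, True, False, False, True]
v = [True, False, True, False, False]
print(counts(u, v))                              # (1, 2, 1, 1)
print(hamming(u, v), jaccard(u, v), dice(u, v))  # 0.6 0.75 0.6
```

The pairwise form is again a double loop, e.g. [[jaccard(a, b) for b in X2] for a in X1]. Because these formulas are built from counts, they have no useful gradient, which is why the corresponding rows lack a ✔️.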

Other pairwise metrics or kernel functions

These metrics are typically used to compute kernel matrices for machine learning algorithms.

| torchpairwise ops | Equivalents in other libraries | Differentiable |
|---|---|---|
| linear_kernel | sklearn.metrics.pairwise.linear_kernel | ✔️ |
| polynomial_kernel | sklearn.metrics.pairwise.polynomial_kernel | ✔️ |
| sigmoid_kernel | sklearn.metrics.pairwise.sigmoid_kernel | ✔️ |
| rbf_kernel | sklearn.metrics.pairwise.rbf_kernel | ✔️ |
| laplacian_kernel | sklearn.metrics.pairwise.laplacian_kernel | ✔️ |
| cosine_similarity | sklearn.metrics.pairwise.cosine_similarity | ✔️ |
| additive_chi2_kernel | sklearn.metrics.pairwise.additive_chi2_kernel | ✔️ |
| chi2_kernel | sklearn.metrics.pairwise.chi2_kernel | ✔️ |
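scikit-learn documents these kernels per pair of vectors, e.g. k(x, y) = exp(-gamma ||x - y||^2) for rbf_kernel, with gamma defaulting to 1 / n_features. The sketch below writes out a few of those definitions in plain Python, following scikit-learn's documented formulas and defaults. It is an assumption that torchpairwise uses identical defaults, so check its docstrings; the function names here simply mirror the table and are not the library's code.

```python
import math
from typing import Optional, Sequence

Vector = Sequence[float]


def _dot(x: Vector, y: Vector) -> float:
    return sum(a * b for a, b in zip(x, y))


def linear_kernel(x: Vector, y: Vector) -> float:
    return _dot(x, y)


def polynomial_kernel(x, y, degree: int = 3, gamma: Optional[float] = None, coef0: float = 1.0):
    gamma = 1.0 / len(x) if gamma is None else gamma  # scikit-learn default: 1 / n_features
    return (gamma * _dot(x, y) + coef0) ** degree


def rbf_kernel(x, y, gamma: Optional[float] = None):
    gamma = 1.0 / len(x) if gamma is None else gamma
    return math.exp(-gamma * sum((a - b) ** 2 for a, b in zip(x, y)))


def laplacian_kernel(x, y, gamma: Optional[float] = None):
    gamma = 1.0 / len(x) if gamma is None else gamma
    return math.exp(-gamma * sum(abs(a - b) for a, b in zip(x, y)))


def cosine_similarity(x, y):
    nx, ny = math.sqrt(_dot(x, x)), math.sqrt(_dot(y, y))
    if nx == 0.0 or ny == 0.0:
        return 0.0  # scikit-learn leaves zero vectors unnormalized, giving 0
    return _dot(x, y) / (nx * ny)


def cosine_distance(x, y):
    """The distance in the first table is 1 - cosine similarity."""
    return 1.0 - cosine_similarity(x, y)


x, y = [1.0, 0.0], [0.0, 1.0]
print(round(rbf_kernel(x, y), 4))  # 0.3679, i.e. exp(-1) with gamma = 1/2
print(polynomial_kernel(x, y))     # 1.0
print(cosine_distance(x, y))       # 1.0
```

A full kernel matrix is then K[i][j] = k(X[i], Y[j]), the same m × n layout as the distance functions.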

Custom cdist and pdist

Furthermore, we provide a convenient wrapper function analogous to torch.cdist, except that it takes a string argument metric: str = "minkowski" as its third argument to select the metric, and extra metric-specific arguments are passed as keyword arguments.

```python
import torch
import torchpairwise

# directed_hausdorff_distances compares point sets: x1 holds 10 sets of 6 points
# and x2 holds 8 sets of 5 points, all in 3D.
x1 = torch.rand(10, 6, 3)
x2 = torch.rand(8, 5, 3)

generator = torch.Generator().manual_seed(1)
output = torchpairwise.cdist(x1, x2,
                             metric="directed_hausdorff",
                             shuffle=True,  # keyword arguments specific to directed_hausdorff
                             generator=generator)
```
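The wrapper's behaviour (dispatch on a metric name, forward the remaining keyword arguments to that metric) can be sketched in plain Python. The registry _METRICS and both helpers are hypothetical illustrations, not torchpairwise internals. The directed Hausdorff distance follows the definition SciPy documents, h(A, B) = max over a in A of min over b in B of ||a - b||, which is asymmetric; the shuffle/generator options are omitted here.

```python
import math
from typing import Callable, Dict, List, Sequence

Point = Sequence[float]


def directed_hausdorff(a: Sequence[Point], b: Sequence[Point]) -> float:
    """Largest distance from a point of set a to its nearest neighbour in set b."""
    return max(min(math.dist(p, q) for q in b) for p in a)


def minkowski(u: Point, v: Point, p: float = 2.0) -> float:
    return sum(abs(x - y) ** p for x, y in zip(u, v)) ** (1.0 / p)


# Hypothetical registry mapping metric names to per-pair functions.
_METRICS: Dict[str, Callable[..., float]] = {
    "minkowski": minkowski,
    "directed_hausdorff": directed_hausdorff,
}


def cdist(x1, x2, metric: str = "minkowski", **kwargs) -> List[List[float]]:
    """D[i][j] = metric(x1[i], x2[j], **kwargs); kwargs go to the chosen metric."""
    try:
        fn = _METRICS[metric]
    except KeyError:
        raise ValueError(f"unknown metric {metric!r}") from None
    return [[fn(a, b, **kwargs) for b in x2] for a in x1]


A = [[0.0, 0.0], [1.0, 0.0]]  # a set of two 2D points
B = [[0.0, 0.0]]              # a set of one 2D point
print(cdist([A, B], [A, B], metric="directed_hausdorff"))  # [[0.0, 1.0], [0.0, 0.0]]
print(cdist([[0.0, 0.0]], [[3.0, 4.0]], p=1.0))            # [[7.0]]
```

Note the asymmetry in the first output: h(A, B) = 1.0 while h(B, A) = 0.0, which is why the metric is called "directed".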

Note that the metrics in the second table are currently not accepted as metric keys by cdist because they are not distances. We plan a similar wrapper for pdist (which is equivalent to calling cdist(x1, x1) but avoids storing duplicated entries). However, that requires a complete overhaul of the existing C++/CUDA kernels and won't be available soon.
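For a symmetric distance with a zero diagonal, cdist(x, x) holds every off-diagonal value twice, so a pdist-style function only needs the m(m - 1)/2 entries above the diagonal. The sketch below shows the condensed, row-major upper-triangle layout that scipy.spatial.distance.pdist uses and how to index into it. It illustrates the storage scheme only and says nothing about how the planned torchpairwise kernels will work; the helper names are hypothetical.

```python
import math
from typing import Callable, List, Sequence


def condensed_index(i: int, j: int, m: int) -> int:
    """Position of pair (i, j) in a condensed vector of length m*(m-1)//2."""
    if i == j:
        raise ValueError("diagonal entries are not stored")
    if i > j:  # symmetric distance: (i, j) and (j, i) share one slot
        i, j = j, i
    return m * i - i * (i + 1) // 2 + (j - i - 1)


def pdist(x: Sequence[Sequence[float]], dist: Callable) -> List[float]:
    """Distances for all pairs i < j, in row-major upper-triangle order."""
    m = len(x)
    return [dist(x[i], x[j]) for i in range(m) for j in range(i + 1, m)]


x = [[0.0, 0.0], [3.0, 4.0], [6.0, 8.0], [0.0, 1.0]]
d = pdist(x, math.dist)
print(len(d))                        # 6 entries instead of 4 * 4 = 16
print(d[condensed_index(2, 1, 4)])   # 5.0, the distance between x[1] and x[2]
```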

Future Improvements

  • Add more metrics (contact me or create a feature request issue).
  • Add memory-efficient argkmin for retrieving pairwise neighbors' distances and indices without storing the whole pairwise distance matrix.
  • Add an equivalence of torch.pdist with metric: str = "minkowski" argument.
  • (Unlikely) Support sparse layouts.
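The argkmin item aims to return each query's k nearest neighbours without materializing the full pairwise distance matrix. As a rough sketch of that idea only (not the planned kernel, whose design is not described here), a bounded heap per query keeps memory at O(k) instead of O(len(x2)). The function name and signature are hypothetical.

```python
import heapq
import math
from typing import Callable, List, Sequence, Tuple


def argkmin(x1: Sequence, x2: Sequence, k: int,
            dist: Callable = math.dist) -> Tuple[List[List[float]], List[List[int]]]:
    """For each row of x1, the k smallest distances into x2 and their indices."""
    dists, idxs = [], []
    for a in x1:
        # nsmallest consumes the generator lazily and keeps only k candidates,
        # so the full len(x1) x len(x2) distance matrix is never stored.
        # Ties on distance are broken by the smaller index j.
        best = heapq.nsmallest(k, ((dist(a, b), j) for j, b in enumerate(x2)))
        dists.append([d for d, _ in best])
        idxs.append([j for _, j in best])
    return dists, idxs


queries = [[0.0], [10.0]]
database = [[1.0], [2.0], [9.0], [20.0]]
print(argkmin(queries, database, k=2))  # ([[1.0, 2.0], [1.0, 8.0]], [[0, 1], [2, 1]])
```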

Requirements

  • torch>=2.7.0,<2.8.0 (torch>=1.9.0 if compiled from source)

Notes:

Since torch extensions are not forward compatible, I have to fix a maximum version for the PyPI package and regularly update it on GitHub (but I am not always available). If you use a different version of torch or your platform is not supported, please follow the instructions to install from source.

Installation

From PyPI:

To install prebuilt wheels of torchpairwise, simply run:

```shell
pip install torchpairwise
```

Note that the Linux and Windows wheels on PyPI are compiled with torch==2.7.0 and CUDA 12.8. Version checking is non-strict: only a warning is raised if the CUDA versions of torch and torchpairwise do not match.
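To see whether the prebuilt wheel is likely to fit your environment, you can compare torch.__version__ against the supported range (torch>=2.7.0,<2.8.0) and torch.version.cuda against the CUDA version the wheel was built with. The sketch below mimics a non-strict check of that kind using plain strings, so it runs without torch installed. The helper names are hypothetical, and this is not torchpairwise's own checking code.

```python
import re
import warnings
from typing import Optional, Tuple


def parse_version(v: str) -> Tuple[int, ...]:
    """'2.7.0+cu128' -> (2, 7, 0); the local suffix after '+' is ignored."""
    base = v.split("+", 1)[0]
    return tuple(int(part) for part in re.findall(r"\d+", base)[:3])


def in_supported_range(torch_version: str, lower: str = "2.7.0", upper: str = "2.8.0") -> bool:
    """True if lower <= torch_version < upper (the PyPI wheel's range)."""
    return parse_version(lower) <= parse_version(torch_version) < parse_version(upper)


def warn_on_cuda_mismatch(torch_cuda: Optional[str], built_cuda: Optional[str]) -> bool:
    """Non-strict check: emit a warning instead of raising when CUDA versions differ."""
    if torch_cuda != built_cuda:
        warnings.warn(f"CUDA mismatch: torch uses {torch_cuda}, "
                      f"extension was built with {built_cuda}")
        return False
    return True


# In practice you would pass torch.__version__ and torch.version.cuda here.
print(in_supported_range("2.7.1+cu128"))  # True
print(in_supported_range("2.8.0"))        # False
```

If the range check fails or the CUDA versions differ, installing from source is the safer route.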

From Source:

Make sure your machine has a C++17-compliant compiler and a CUDA compiler installed, then clone the repository and run:

```shell
pip install .
```

Usage

The basic use case is straightforward if you are familiar with sklearn.metrics.pairwise and scipy.spatial.distance:

scikit-learn:

```python
import numpy as np
import sklearn.metrics.pairwise as sklearn_pairwise

x1 = np.random.rand(10, 5)
x2 = np.random.rand(12, 5)

output = sklearn_pairwise.cosine_similarity(x1, x2)
print(output)
```

TorchPairwise equivalent:

```python
import torch
import torchpairwise

x1 = torch.rand(10, 5, device='cuda')
x2 = torch.rand(12, 5, device='cuda')

output = torchpairwise.cosine_similarity(x1, x2)
print(output)
```

SciPy:

```python
import numpy as np
import scipy.spatial.distance as distance

x1 = np.random.binomial(
    1, p=0.6, size=(10, 5)).astype(np.bool_)
x2 = np.random.binomial(
    1, p=0.7, size=(12, 5)).astype(np.bool_)

output = distance.cdist(x1, x2, metric='jaccard')
print(output)
```

TorchPairwise equivalent:

```python
import torch
import torchpairwise

x1 = torch.bernoulli(
    torch.full((10, 5), fill_value=0.6, device='cuda')).to(torch.bool)
x2 = torch.bernoulli(
    torch.full((12, 5), fill_value=0.7, device='cuda')).to(torch.bool)

output = torchpairwise.jaccard_distances(x1, x2)
print(output)
```

Please check the tests folder, where more examples will be added.

Citation

```bibtex
@software{hoang_nhat_tran_2025_15470239,
  author       = {Hoang-Nhat Tran},
  title        = {inspiros/torchpairwise: v0.3.0},
  month        = may,
  year         = 2025,
  publisher    = {Zenodo},
  version      = {v0.3.0},
  doi          = {10.5281/zenodo.15470239},
  url          = {https://doi.org/10.5281/zenodo.15470239},
  swhid        = {swh:1:dir:c4402d590533d1ec43616b8b90e369efbe2bbb76;origin=https://doi.org/10.5281/zenodo.14699363;visit=swh:1:snp:a66a7b521e2ad19609f3f7671eb8cbe462bf076f;anchor=swh:1:rel:e528b96cf29f913b6f9648dfafd5145163f90263;path=inspiros-torchpairwise-0ff57b0},
}
```

License

The code is released under the MIT license. See LICENSE.txt for details.

Footnotes

  1. These metrics are not pairwise in SciPy, but their pairwise form can be computed by calling scipy.spatial.distance.cdist(x1, x2, metric="[metric_name_or_callable]").

  2. These are boolean distances. hamming_distances can also be applied to floating-point inputs, but it involves element-wise comparison.