
How to calculate Batch Pairwise Distance in PyTorch efficiently

I have tensors X of shape BxNxD and Y of shape BxMxD.

I want to compute the pairwise distances for each element in the batch, i.e. I want a BxMxN tensor.

How do I do this?

There is some discussion on this topic here: https://github.com/pytorch/pytorch/issues/9406, but I can't follow it: there are many implementation details and no actual solution is highlighted.

A naive approach would be to use the answer for non-batched pairwise distances as discussed here: https://discuss.pytorch.org/t/efficient-distance-matrix-computation/9065, i.e.

import torch
import numpy as np

B = 32
N = 128
M = 256
D = 3

X = torch.from_numpy(np.random.normal(size=(B, N, D)))
Y = torch.from_numpy(np.random.normal(size=(B, M, D)))


def pairwise_distances(x, y=None):
    x_norm = (x**2).sum(1).view(-1, 1)
    if y is not None:
        y_t = torch.transpose(y, 0, 1)
        y_norm = (y**2).sum(1).view(1, -1)
    else:
        y_t = torch.transpose(x, 0, 1)
        y_norm = x_norm.view(1, -1)

    dist = x_norm + y_norm - 2.0 * torch.mm(x, y_t)
    return torch.clamp(dist, 0.0, np.inf)


out = []
for b in range(B):
    out.append(pairwise_distances(X[b], Y[b]))
print(torch.stack(out).shape)

How can I do this without looping over B? Thanks

JoSauderGH asked Mar 12 '19 16:03


1 Answer

I had a similar issue and spent some time looking for the easiest and fastest solution. You can now compute batched distances with PyTorch's torch.cdist, which gives you a BxMxN tensor:

torch.cdist(Y, X)

It also works well if you just want the distances between each pair of rows of two matrices (the non-batched case).
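To check this against the loop from the question, here is a minimal sketch. Two things to keep in mind: torch.cdist returns Euclidean (p=2) distances by default, while the function in the question returns squared distances, so the cdist result is squared before comparing. Also, the loop in the question produces a BxNxM tensor, which matches torch.cdist(X, Y); torch.cdist(Y, X) is its transpose, BxMxN. The helper names pairwise_sq_distances_loop and batched_sq_distances are made up for this illustration, and the second one is just the question's formula rewritten with torch.bmm, not something taken from the linked threads.

```python
import torch

B, N, M, D = 32, 128, 256, 3
torch.manual_seed(0)
X = torch.randn(B, N, D, dtype=torch.float64)
Y = torch.randn(B, M, D, dtype=torch.float64)


def pairwise_sq_distances_loop(x, y):
    # Reference: the question's formula applied per batch element (hypothetical helper).
    out = []
    for b in range(x.shape[0]):
        x_norm = (x[b] ** 2).sum(1).view(-1, 1)   # N x 1
        y_norm = (y[b] ** 2).sum(1).view(1, -1)   # 1 x M
        dist = x_norm + y_norm - 2.0 * torch.mm(x[b], y[b].t())
        out.append(torch.clamp(dist, min=0.0))
    return torch.stack(out)                       # B x N x M


def batched_sq_distances(x, y):
    # Same expansion |x|^2 + |y|^2 - 2 x.y, but with bmm instead of a Python loop.
    x_norm = (x ** 2).sum(-1).unsqueeze(-1)       # B x N x 1
    y_norm = (y ** 2).sum(-1).unsqueeze(-2)       # B x 1 x M
    dist = x_norm + y_norm - 2.0 * torch.bmm(x, y.transpose(1, 2))
    return torch.clamp(dist, min=0.0)             # B x N x M


ref = pairwise_sq_distances_loop(X, Y)
batched = batched_sq_distances(X, Y)
d_xy = torch.cdist(X, Y)   # B x N x M, Euclidean distances
d_yx = torch.cdist(Y, X)   # B x M x N, transpose of d_xy

print(ref.shape, d_xy.shape, d_yx.shape)
print(torch.allclose(d_xy ** 2, ref, atol=1e-8))
print(torch.allclose(batched, ref))
```

If you need the squared distances, squaring the cdist output is usually enough; if you need exactly the question's formula (for example to avoid the square root in the backward pass), the bmm version keeps everything batched.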

dinarkino answered Nov 02 '22 09:11