I have tensors X of shape BxNxD and Y of shape BxMxD.
I want to compute the pairwise distances for each element in the batch, i.e. I want a BxMxN tensor.
How do I do this?
There is some discussion on this topic here: https://github.com/pytorch/pytorch/issues/9406, but I don't understand it: it covers many implementation details without highlighting an actual solution.
A naive approach is to reuse the answer for non-batched pairwise distances discussed here: https://discuss.pytorch.org/t/efficient-distance-matrix-computation/9065, and loop over the batch:
```python
import torch
import numpy as np

B = 32
N = 128
M = 256
D = 3

X = torch.from_numpy(np.random.normal(size=(B, N, D)))
Y = torch.from_numpy(np.random.normal(size=(B, M, D)))

def pairwise_distances(x, y=None):
    # Squared Euclidean distances via ||x||^2 + ||y||^2 - 2 x.y
    x_norm = (x**2).sum(1).view(-1, 1)
    if y is not None:
        y_t = torch.transpose(y, 0, 1)
        y_norm = (y**2).sum(1).view(1, -1)
    else:
        y_t = torch.transpose(x, 0, 1)
        y_norm = x_norm.view(1, -1)
    dist = x_norm + y_norm - 2.0 * torch.mm(x, y_t)
    # Clamp small negative values caused by floating-point error
    return torch.clamp(dist, 0.0, np.inf)

out = []
for b in range(B):
    out.append(pairwise_distances(X[b], Y[b]))
print(torch.stack(out).shape)  # torch.Size([32, 128, 256]), i.e. B x N x M
```
How can I do this without looping over B? Thanks
I had a similar issue and spent some time finding the easiest and fastest solution. You can now compute batched distances with PyTorch's torch.cdist, which gives you a BxMxN tensor of Euclidean distances:
```python
torch.cdist(Y, X)  # shape B x M x N
```
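A quick sanity check, sketched with small hypothetical sizes and random data: torch.cdist(Y, X) has shape BxMxN, and because the question's pairwise_distances returns squared distances in BxNxM order, transposing the last two dimensions and squaring the cdist result should match the loop up to floating-point error. The loop body is inlined here so the snippet runs on its own; the last lines show the plain 2-D case.

```python
import torch

# Hypothetical small sizes; float64 like torch.from_numpy in the question
B, N, M, D = 4, 5, 7, 3
X = torch.randn(B, N, D, dtype=torch.float64)
Y = torch.randn(B, M, D, dtype=torch.float64)

# Euclidean (p=2) distances, shape B x M x N
d_yx = torch.cdist(Y, X)

# The question's loop (inlined): squared distances in B x N x M order
loop_out = torch.stack([
    torch.clamp((X[b] ** 2).sum(1).view(-1, 1)
                + (Y[b] ** 2).sum(1).view(1, -1)
                - 2.0 * X[b] @ Y[b].T, min=0.0)
    for b in range(B)
])

print(d_yx.shape)  # torch.Size([4, 7, 5])
# Transpose to B x N x M and square to compare with the loop
print(torch.allclose(d_yx.transpose(1, 2) ** 2, loop_out))

# 2-D inputs: rows of A (P x D) vs rows of C (R x D) give a P x R matrix
A = torch.randn(6, D, dtype=torch.float64)
C = torch.randn(2, D, dtype=torch.float64)
print(torch.cdist(A, C).shape)  # torch.Size([6, 2])
```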
It also works if you just want the distances between each pair of rows of two matrices.
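If you would rather keep the squared-distance expansion from the question (for example, to get squared distances directly), a loop-free sketch swaps torch.mm for torch.bmm and sums the norms over the last dimension. The helper name batched_sq_distances is hypothetical; the result is BxNxM, and the final line compares it against the squared cdist output.

```python
import torch

def batched_sq_distances(x, y):
    """Hypothetical helper: squared Euclidean distances between rows of
    x (B x N x D) and y (B x M x D), returned as B x N x M, with no loop over B."""
    x_norm = (x ** 2).sum(-1).unsqueeze(-1)   # B x N x 1
    y_norm = (y ** 2).sum(-1).unsqueeze(-2)   # B x 1 x M
    # torch.bmm multiplies the B matrix pairs in one call: B x N x M
    cross = torch.bmm(x, y.transpose(1, 2))
    dist = x_norm + y_norm - 2.0 * cross
    # Rounding can make tiny distances slightly negative, so clamp at zero
    return dist.clamp(min=0.0)

B, N, M, D = 32, 128, 256, 3
X = torch.randn(B, N, D, dtype=torch.float64)
Y = torch.randn(B, M, D, dtype=torch.float64)

sq = batched_sq_distances(X, Y)
print(sq.shape)  # torch.Size([32, 128, 256])
# cdist(Y, X) is B x M x N and not squared, so transpose sq and square cdist
print(torch.allclose(sq.transpose(1, 2), torch.cdist(Y, X) ** 2))
```

Broadcasting the B x N x 1 and B x 1 x M norm tensors against the B x N x M product is what lets a single expression replace the per-batch loop.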