
Pairwise similarity matrix between a set of vectors in PyTorch

Let's suppose that we have a 3D PyTorch tensor, where the first dimension represents the batch_size, as follows:

import torch
import torch.nn as nn
x = torch.randn(32, 100, 25)

That is, for each i, x[i] is a set of 100 25-dimensional vectors. I would like to compute the similarity (e.g., the cosine similarity -- but in general any such pairwise distance/similarity matrix) of these vectors for each batch item.

That is, for each x[t] I need to compute a [100, 100] matrix containing the pairwise similarities of its vectors. More specifically, the (i, j)-th element of this matrix should contain the similarity (or the distance) between the i-th and the j-th row of the 100x25 matrix x[t], for all t = 1, ..., batch_size.

If I use torch.nn.CosineSimilarity(), then no matter which dim I choose, the result is [100, 25] (dim=0), [32, 25] (dim=1), or [32, 100] (dim=2), whereas I need a tensor of size [32, 100, 100]. I would expect torch.nn.CosineSimilarity() to work this way (since, at least to me, it looks more intuitive), but it doesn't.
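
The shapes reported above can be reproduced with a short check. This is only a sketch, and it assumes the module is called with x as both inputs; the point is that the module reduces over dim and never forms pairs of vectors within a batch item.

```python
import torch
import torch.nn as nn

x = torch.randn(32, 100, 25)

# CosineSimilarity compares its two inputs position by position and
# reduces over `dim`, so the output is the input shape with `dim` removed.
shapes = {d: tuple(nn.CosineSimilarity(dim=d)(x, x).shape) for d in range(3)}
print(shapes)
```

None of the three choices produces a [32, 100, 100] result, because each one only removes a single dimension from the input shape.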

Could that be done using something like below?

torch.matmul(x, x.permute(0, 2, 1))

I guess that this could give a distance matrix, but what if I need an arbitrary pairwise operation? Should I build this operation using the above?

Or should I perhaps repeat x so that I can use the built-in torch.nn.CosineSimilarity()?
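
For reference, the matmul expression above computes raw dot products between the vectors, not cosine similarities. A minimal sketch of how it could be adapted, assuming the vectors are L2-normalized first with torch.nn.functional.normalize (the names x_unit, sim and ref are only illustrative):

```python
import torch
import torch.nn.functional as F

x = torch.randn(32, 100, 25)

# L2-normalize each 25-dimensional vector so that dot products become cosines.
x_unit = F.normalize(x, p=2, dim=-1)

# Batched Gram matrix: [32, 100, 25] @ [32, 25, 100] -> [32, 100, 100].
sim = torch.matmul(x_unit, x_unit.transpose(1, 2))

# Cross-check against F.cosine_similarity applied to broadcast inputs.
ref = F.cosine_similarity(x[:, :, None, :], x[:, None, :, :], dim=-1)
print(sim.shape, torch.allclose(sim, ref, atol=1e-5))
```

This form avoids the intermediate [32, 100, 100, 25] broadcast, but it only covers similarities that can be written as a dot product of transformed vectors.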

Thank you.

Asked by nullgeppetto, Dec 23 '22 19:12


1 Answer

The documentation implies that the shapes of the inputs to cosine_similarity must be equal, but this is not the case. Internally PyTorch broadcasts via torch.mul, so inserting a dimension with a slice (or torch.unsqueeze) gives you the desired result. This is not optimal, because the upper and lower triangles duplicate both computation and memory, but it's simple:

import torch
from torch.nn import functional as F
from scipy.spatial import distance

# compute once in pytorch
x = torch.randn(32, 100, 25)
y = F.cosine_similarity(x[..., None, :, :], x[..., :, None, :], dim=-1)

assert y.shape == torch.Size([32, 100, 100])

# test against scipy by iterating over each batch element
z = []
for i in range(x.shape[0]):
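    xi = x[i].numpy()  # named xi to avoid shadowing the built-in `slice`
    z.append(torch.tensor(distance.cdist(xi, xi, metric='cosine'), dtype=torch.float32))

# convert similarity to distance and ensure they're reasonably close;
# the absolute tolerance absorbs float32 rounding on the near-zero diagonal
assert torch.allclose(torch.stack(z), 1.0 - y, atol=1e-5)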
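
The same broadcasting pattern extends to an arbitrary pairwise operation. The following is a rough sketch, not a definitive implementation: pairwise is a hypothetical helper introduced here for illustration, and the Manhattan (L1) distance is just one example of a custom operation. The result is compared against torch.cdist, which provides batched p-norm distances directly.

```python
import torch


def pairwise(x, op):
    """Apply op to every pair of rows in each batch item (hypothetical helper).

    x has shape [B, N, D]. op receives two broadcastable tensors and is
    expected to reduce the last dimension, so the result is [B, N, N].
    """
    a = x[:, :, None, :]  # [B, N, 1, D]
    b = x[:, None, :, :]  # [B, 1, N, D]
    return op(a, b)


x = torch.randn(32, 100, 25)

# Manhattan (L1) distance as an example of a custom pairwise operation.
l1 = pairwise(x, lambda a, b: (a - b).abs().sum(dim=-1))

# Reference result from the built-in batched distance function.
ref = torch.cdist(x, x, p=1)
print(l1.shape, torch.allclose(l1, ref, atol=1e-3))
```

As with the cosine version, the broadcast creates a [B, N, N, D] intermediate, so for large N a dedicated function such as torch.cdist may be the better choice when the metric is one it supports.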

Answered by Nathan Howell, May 10 '23 13:05