Finding the top k matches in PyTorch

I'm using the following code to find the top-k matches with PyTorch:

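import torch

def find_top(self, x, y, n_neighbors, unit_vectors=False, cuda=False):
    if not unit_vectors:
        # __to_unit_torch__ is the asker's own helper (not shown in the question)
        x = __to_unit_torch__(x, cuda=cuda)
        y = __to_unit_torch__(y, cuda=cuda)
    with torch.no_grad():
        # cosine distance for unit rows; shape is (x.size(0), y.size(0))
        d = 1. - torch.matmul(x, y.transpose(0, 1))
        # n_neighbors smallest distances per row, searched along the y dimension
        values, indices = torch.topk(d, n_neighbors, dim=1, largest=False, sorted=True)
        return indices.cpu().numpy()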

Unfortunately, it is throwing the following error:

values, indices = torch.topk(d, n_neighbors, dim=1, largest=False, sorted=True)
RuntimeError: invalid argument 5: k not in range for dimension at /pytorch/aten/src/THC/generic/THCTensorTopK.cu:23

The size of d is (1793, 1). What am I missing?
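For reference, a minimal CPU reproduction of this error using random data with the shape described above. The feature dimension of 64 and the value k=5 are arbitrary assumptions, and `x @ y.T` is used as shorthand for `torch.matmul(x, y.transpose(0, 1))`:

```python
import torch

x = torch.randn(1793, 64)  # 1793 query vectors (feature size 64 is arbitrary)
y = torch.randn(1, 64)     # a single candidate vector
d = 1.0 - x @ y.T          # one column per row of y, so shape (1793, 1)
print(d.shape)             # torch.Size([1793, 1])

raised = False
try:
    # asks for 5 entries per row, but each row of d holds only 1
    torch.topk(d, 5, dim=1, largest=False, sorted=True)
except RuntimeError as err:
    raised = True
    print("RuntimeError:", err)
```

Recent PyTorch versions word the message differently from the 2020 CUDA message above, but the cause is the same check on k.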

asked Aug 18 '20 at 07:08 by user1241241



1 Answer

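This error occurs when the k you pass to torch.topk is larger than the size of the tensor along the dimension you select. Here d has shape (1793, 1) and you call topk with dim=1, which holds only one element per row, so any n_neighbors greater than 1 fails. Reduce k to at most d.size(1) and it will run fine. Note that d = 1 - x·yᵀ has one column per row of y, so a shape of (1793, 1) means y contains a single vector; if you meant to search among the 1793 vectors instead, take topk along dim=0 (or swap x and y).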
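A minimal sketch of capping k at the size of the searched dimension, plus the alternative of searching along dim=0 when the single y vector is really the query. The helper name `topk_smallest_capped` is hypothetical (it is not part of PyTorch), and the random data only mimics the (1793, 1) shape from the question:

```python
import torch

def topk_smallest_capped(d: torch.Tensor, k: int, dim: int = 1):
    """Hypothetical helper: k smallest entries along `dim`, with k capped at d.size(dim)."""
    k = min(k, d.size(dim))  # topk requires 0 <= k <= d.size(dim)
    return torch.topk(d, k, dim=dim, largest=False, sorted=True)

d = 1.0 - torch.randn(1793, 64) @ torch.randn(1, 64).T  # shape (1793, 1)

# Option 1: cap k; each row has only one candidate, so one index per row comes back
values, indices = topk_smallest_capped(d, 5, dim=1)
print(indices.shape)  # torch.Size([1793, 1])

# Option 2: the 5 rows of x closest to the single y vector, searched along dim=0
values0, indices0 = torch.topk(d, 5, dim=0, largest=False, sorted=True)
print(indices0.shape)  # torch.Size([5, 1])
```

Capping silently returns fewer neighbors than requested, so in practice you may prefer to raise an explicit error or log a warning when n_neighbors exceeds d.size(1).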

answered Oct 06 '22 at 08:10 by iacob
