 

Quickly find indices that have values larger than a threshold in Numpy/PyTorch

Task

Given a numpy or pytorch matrix, find the indices of cells that have values that are larger than a given threshold.

My implementation

# abs_cosine is the matrix
# sim_vec will hold the wanted (row, column) index pairs

sim_vec = []
for m in range(abs_cosine.shape[0]):
    for n in range(abs_cosine.shape[1]):
        # exclude diagonal cells
        if m != n and abs_cosine[m][n] >= threshold:
            sim_vec.append((m, n))

Concerns

Speed. All other computations are built on PyTorch, so using NumPy is already a compromise: it moves the computation from the GPU to the CPU. Pure Python for loops make the whole process even worse (already 5 times slower on a small data set). Is it possible to do the whole computation in NumPy (or PyTorch) without any for loops?

An improvement I can think of (but got stuck...)

bool_cosine = abs_cosine > threshold

which returns a boolean matrix of True and False values, but I cannot find a way to quickly retrieve the indices of the True cells.

GabrielChu asked Apr 26 '18 14:04


1 Answer

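The following is for PyTorch (it stays entirely on the GPU when abs_cosine is a GPU tensor):

import torch

# abs_cosine should be a Tensor of shape (m, m); threshold is a float
# mask is True everywhere except on the diagonal and is created on the same
# device as abs_cosine, so nothing is copied between GPU and CPU
mask = ~torch.eye(abs_cosine.size(0), dtype=torch.bool, device=abs_cosine.device)
sim_vec = torch.nonzero((abs_cosine >= threshold) & mask)

# sim_vec is a tensor of shape (k, 2) where the first column is the row index and the second is the column index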

The following works in NumPy:

import numpy as np

mask = ~np.eye(abs_cosine.shape[0], dtype=bool)  # True everywhere except the diagonal
sim_vec = np.nonzero((abs_cosine >= threshold) & mask)
# sim_vec is a tuple of two arrays: the first holds the row indices and the second the column indices
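To check that the vectorized version really matches the loop from the question, here is a minimal sketch assuming a square NumPy matrix; the helper names `loop_indices`, `vectorized_indices` and `upper_triangle_indices` are hypothetical. `np.argwhere` returns a (k, 2) array of `[row, col]` pairs (the same layout as `torch.nonzero`) in row-major order, which is the order the nested loop visits cells, so turning its rows into tuples should reproduce the question's `sim_vec` exactly. Since an absolute cosine-similarity matrix is symmetric, each pair is also reported as `(n, m)`; the last helper keeps only pairs with row < col.

```python
import numpy as np


def loop_indices(abs_cosine, threshold):
    """The question's nested loop, kept as a reference implementation."""
    sim_vec = []
    for m in range(abs_cosine.shape[0]):
        for n in range(abs_cosine.shape[1]):
            if m != n and abs_cosine[m][n] >= threshold:
                sim_vec.append((m, n))
    return sim_vec


def vectorized_indices(abs_cosine, threshold):
    """Same (row, col) pairs as loop_indices, without Python-level loops."""
    off_diag = ~np.eye(abs_cosine.shape[0], dtype=bool)
    # argwhere yields [row, col] rows in row-major order, matching the loop
    pairs = np.argwhere((abs_cosine >= threshold) & off_diag)
    return [tuple(p) for p in pairs.tolist()]


def upper_triangle_indices(abs_cosine, threshold):
    """For a symmetric matrix, keep each pair once (row < col)."""
    # k=1 drops the diagonal and everything below it
    upper = np.triu(np.ones(abs_cosine.shape, dtype=bool), k=1)
    return np.argwhere((abs_cosine >= threshold) & upper)


# Usage on a small random cosine-similarity matrix (illustrative data only)
rng = np.random.default_rng(0)
x = rng.standard_normal((6, 4))
x /= np.linalg.norm(x, axis=1, keepdims=True)  # unit-length rows
abs_cosine = np.abs(x @ x.T)
threshold = 0.5

assert vectorized_indices(abs_cosine, threshold) == loop_indices(abs_cosine, threshold)
```

Returning a `(k, 2)` array from `np.argwhere` is usually more convenient than a list of tuples for further NumPy work; the tuple conversion above only exists to compare against the loop's output.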
layog answered Oct 31 '22 00:10