 

Get indices of top N values in 2D numpy ndarray or numpy matrix

I have an array of N-dimensional vectors.

import numpy as np
data = np.array([[5, 6, 1], [2, 0, 8], [4, 9, 3]])

In [1]: data
Out[1]:
array([[5, 6, 1],
       [2, 0, 8],
       [4, 9, 3]])

I'm using sklearn's pairwise_distances function to compute a matrix of distance values. Note that this matrix is symmetric about the diagonal.

from sklearn.metrics import pairwise_distances
dists = pairwise_distances(data)
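pairwise_distances uses the Euclidean metric by default, so the same matrix can be reproduced with plain NumPy broadcasting. This is only a sketch for checking the numbers and the symmetry, not how sklearn computes it internally:

```python
import numpy as np

data = np.array([[5, 6, 1], [2, 0, 8], [4, 9, 3]])

# Broadcast (3, 1, 3) against (1, 3, 3) to get every pairwise difference
# vector, then take the Euclidean norm along the last axis.
diffs = data[:, np.newaxis, :] - data[np.newaxis, :, :]
dists = np.sqrt((diffs ** 2).sum(axis=-1))

print(dists)
# Symmetric about the diagonal, with zeros on the diagonal
# (each vector's distance to itself).
print(np.allclose(dists, dists.T))
```

For example, the [0, 1] entry is sqrt(3**2 + 6**2 + 7**2) = sqrt(94), which is about 9.6954.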

In [2]: dists
Out[2]:
array([[  0.        ,   9.69535971,   3.74165739],
       [  9.69535971,   0.        ,  10.48808848],
       [  3.74165739,  10.48808848,   0.        ]])

I need the indices corresponding to the top N values in this matrix dists, because these indices correspond to the pairs of vectors in data that are farthest apart.

I have tried np.argmax(np.max(dists, axis=1)) to get the row that contains the maximum value, and np.argmax(np.max(dists, axis=0)) to get the column that contains it, but note that:

In [3]: np.argmax(np.max(dists, axis=1))
Out[3]: 1

In [4]: np.argmax(np.max(dists, axis=0))
Out[4]: 1

and:

In [5]: dists[1, 1]
Out[5]: 0.0

Because the matrix is symmetric about the diagonal, the maximum value (10.48808848) appears twice, at [1, 2] and [2, 1]. argmax returns the first index holding the maximum, so both calls return 1, and combining them lands on the diagonal cell [1, 1] instead of on a cell that actually holds the maximum.
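For the single largest value (N = 1), one way around this is to call np.argmax without an axis argument, which searches the flattened matrix and returns one flat index, and then convert that index back to a (row, column) pair with np.unravel_index. A minimal sketch, hard-coding the dists values printed above:

```python
import numpy as np

dists = np.array([[0.0, 9.69535971, 3.74165739],
                  [9.69535971, 0.0, 10.48808848],
                  [3.74165739, 10.48808848, 0.0]])

# With no axis argument, argmax works on the flattened (row-major) array
# and returns a single flat index.
flat_idx = np.argmax(dists)

# unravel_index maps that flat index back to (row, column) coordinates.
row, col = np.unravel_index(flat_idx, dists.shape)
print(row, col, dists[row, col])
```

Here the flat index is 5, which maps to (1, 2): the first of the two symmetric positions holding the maximum, rather than a diagonal cell.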

At this point I'm sure I could write some more code to find the values I'm looking for, but surely there is an easier way to do what I'm trying to do. So I have two questions that are more or less equivalent:

How can I find the indices corresponding to the top N values in a matrix, or, how can I find the vectors with the top N pairwise distances from an array of vectors?

asked Dec 23 '22 20:12 by vaer-k


1 Answer

I'd ravel (flatten the matrix), argsort the flat values, and then unravel the winning flat indices back into (row, column) pairs. I'm not claiming this is the best way, only that it's the first way that occurred to me, and I'll probably delete it in shame after someone posts something more obvious. :-)

That said (choosing the top 2 values, arbitrarily), I first zero out the lower triangle so that each pair of vectors is counted only once:

In [73]: dists = sklearn.metrics.pairwise_distances(data)

In [74]: dists[np.tril_indices_from(dists, -1)] = 0

In [75]: dists
Out[75]: 
array([[  0.        ,   9.69535971,   3.74165739],
       [  0.        ,   0.        ,  10.48808848],
       [  0.        ,   0.        ,   0.        ]])
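The assignment in In [74] overwrites the lower triangle of dists in place. If the full matrix is still needed afterwards, np.triu with k=1 returns a copy in which the main diagonal and everything below it are zero, matching Out[75]. A minimal sketch, hard-coding the question's dists instead of calling sklearn:

```python
import numpy as np

# The full symmetric distance matrix from the question (values as printed there).
dists = np.array([[0.0, 9.69535971, 3.74165739],
                  [9.69535971, 0.0, 10.48808848],
                  [3.74165739, 10.48808848, 0.0]])

# np.triu(..., k=1) returns a NEW array that keeps only the entries strictly
# above the main diagonal and zeroes the rest, so each unordered pair (i, j)
# with i < j appears exactly once, and dists itself is left unchanged.
upper = np.triu(dists, k=1)

# Same ravel -> argsort -> unravel step as in the answer, applied to the copy.
ii = np.unravel_index(np.argsort(upper.ravel())[-2:], upper.shape)
print(ii)
print(upper[ii])
```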

In [76]: ii = np.unravel_index(np.argsort(dists.ravel())[-2:], dists.shape)

In [77]: ii
Out[77]: (array([0, 1]), array([1, 2]))

In [78]: dists[ii]
Out[78]: array([  9.69535971,  10.48808848])
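Note that argsort sorts in ascending order, so the last entry of ii is the single largest distance. The same idea can be wrapped in a small reusable function. In the sketch below, top_n_pairs is a hypothetical helper name; it reads only the strictly upper-triangular entries via np.triu_indices_from, so dists is never modified, and it returns the pairs from largest to smallest:

```python
import numpy as np


def top_n_pairs(dists, n):
    """Hypothetical helper: the n largest off-diagonal entries of a symmetric
    matrix, as (rows, cols, values) sorted from largest to smallest.
    Each unordered pair (i, j) with i < j is counted once."""
    # Row/column indices of every entry strictly above the main diagonal.
    rows, cols = np.triu_indices_from(dists, k=1)
    values = dists[rows, cols]

    # argsort is ascending, so reverse it and keep the first n positions.
    order = np.argsort(values)[::-1][:n]
    return rows[order], cols[order], values[order]


dists = np.array([[0.0, 9.69535971, 3.74165739],
                  [9.69535971, 0.0, 10.48808848],
                  [3.74165739, 10.48808848, 0.0]])

rows, cols, values = top_n_pairs(dists, 2)
for i, j, d in zip(rows, cols, values):
    print(f"data[{i}] <-> data[{j}]: {d:.4f}")
```

For very large matrices, np.argpartition can replace the full argsort to pick out the n largest values without sorting every entry; the selected entries then come back in no particular order and would need a final sort if the ranking matters.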
answered Dec 26 '22 09:12 by DSM