 

How to get the N maximum values per row in a numpy ndarray?

Tags:

python

numpy

We know how to do it when N = 1:

import numpy as np

m = np.arange(15).reshape(3, 5)
m[np.arange(len(m)), m.argmax(axis=1)]    # array([ 4,  9, 14])
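For N = 1 there are simpler spellings of the same thing. A small sketch: max along axis 1 gives the values directly, and np.take_along_axis (NumPy 1.15+) is the indexing idiom that also generalizes to N > 1 once you have an array of column indices per row.

```python
import numpy as np

m = np.arange(15).reshape(3, 5)

# Row-wise maximum directly
row_max = m.max(axis=1)

# Same result via argmax + take_along_axis; argmax drops axis 1,
# so re-add it with [:, None] to get shape (3, 1), then flatten back
idx = m.argmax(axis=1)[:, None]
row_max_alt = np.take_along_axis(m, idx, axis=1).ravel()

print(row_max.tolist())      # [4, 9, 14]
print(row_max_alt.tolist())  # [4, 9, 14]
```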

What is the best way to get the top N values per row when N > 1 (say, 5)?

asked May 09 '16 21:05 by PSNR


3 Answers

Doing a partial sort using np.partition can be much cheaper than a full sort:

gen = np.random.RandomState(0)
x = gen.permutation(100)

# full sort
print(np.sort(x)[-10:])
# [90 91 92 93 94 95 96 97 98 99]

# partial sort such that the largest 10 items are in the last 10 indices
print(np.partition(x, -10)[-10:])
# [90 91 93 92 94 96 98 95 97 99]

If you need the largest N items to be sorted, you can call np.sort on the last N elements in your partially sorted array:

print(np.sort(np.partition(x, -10)[-10:]))
# [90 91 92 93 94 95 96 97 98 99]

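This can still be much faster than a full sort of the whole array when the array is large: np.partition uses introselect by default, which runs in linear time, whereas a full sort is O(n log n), and the final np.sort only touches N elements.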
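The partition-then-sort recipe can be wrapped in a small helper. This is a sketch; top_n is a hypothetical name, and it returns the values largest-first by reversing the ascending sort.

```python
import numpy as np

def top_n(x, n):
    """Return the n largest values of a 1-D array, largest first (hypothetical helper)."""
    x = np.asarray(x)
    if not 1 <= n <= x.size:
        raise ValueError("n must be between 1 and x.size")
    # Partial sort: the n largest values end up in the last n slots, in arbitrary order
    largest = np.partition(x, -n)[-n:]
    # Sort only those n values, then reverse for descending order
    return np.sort(largest)[::-1]

gen = np.random.RandomState(0)
x = gen.permutation(100)
print(top_n(x, 5).tolist())  # [99, 98, 97, 96, 95]
```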


To sort across each row of a two-dimensional array, use the axis= argument of np.partition and/or np.sort:

y = np.array([gen.permutation(100) for _ in range(5)])  # 5 rows, each an independent shuffle of 0..99

# partial sort, followed by a full sort of the last 10 elements in each row
print(np.sort(np.partition(y, -10, axis=1)[:, -10:], axis=1))
# [[90 91 92 93 94 95 96 97 98 99]
#  [90 91 92 93 94 95 96 97 98 99]
#  [90 91 92 93 94 95 96 97 98 99]
#  [90 91 92 93 94 95 96 97 98 99]
#  [90 91 92 93 94 95 96 97 98 99]]
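If you also need to know where the largest values were, np.argpartition works like np.partition but returns indices, which np.take_along_axis (NumPy 1.15+) turns back into values. A sketch under the assumption of a 2-D input and 1 <= n <= number of columns; top_n_per_row is a hypothetical helper name:

```python
import numpy as np

def top_n_per_row(a, n):
    """Values and column indices of the n largest entries in each row, largest first.

    Hypothetical helper; assumes a 2-D array and 1 <= n <= a.shape[1].
    """
    a = np.asarray(a)
    # Column indices of the n largest entries per row (unordered within those n)
    idx = np.argpartition(a, -n, axis=1)[:, -n:]
    vals = np.take_along_axis(a, idx, axis=1)
    # Order the n candidates in each row from largest to smallest
    order = np.argsort(vals, axis=1)[:, ::-1]
    idx = np.take_along_axis(idx, order, axis=1)
    vals = np.take_along_axis(vals, order, axis=1)
    return vals, idx

m = np.arange(15).reshape(3, 5)
vals, idx = top_n_per_row(m, 2)
print(vals.tolist())  # [[4, 3], [9, 8], [14, 13]]
print(idx.tolist())   # [[4, 3], [4, 3], [4, 3]]
```

Only the n selected entries per row are argsorted, so the extra ordering step stays cheap when n is small.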

Benchmarks:

In [1]: %%timeit x = np.random.permutation(10000000)
   ...: np.sort(x)[-10:]
   ...:
1 loop, best of 3: 958 ms per loop

In [2]: %%timeit x = np.random.permutation(10000000)
   ...: np.partition(x, -10)[-10:]
   ...:
10 loops, best of 3: 41.3 ms per loop

In [3]: %%timeit x = np.random.permutation(10000000)
   ...: np.sort(np.partition(x, -10)[-10:])
   ...:
10 loops, best of 3: 78.8 ms per loop
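Outside IPython, a comparable measurement can be sketched with the standard-library timeit module. The array size and repeat counts here are arbitrary choices, and absolute timings depend on the machine and NumPy version, so none are shown.

```python
import timeit

import numpy as np

x = np.random.permutation(1_000_000)

# Candidate ways to get the 10 largest values;
# the partition-only variant leaves them in arbitrary order
candidates = {
    "full sort": lambda: np.sort(x)[-10:],
    "partition": lambda: np.partition(x, -10)[-10:],
    "partition + sort": lambda: np.sort(np.partition(x, -10)[-10:]),
}

# The two sorted variants must agree before the timings mean anything
assert np.array_equal(candidates["full sort"](), candidates["partition + sort"]())

for name, fn in candidates.items():
    # Best of 3 repeats, 3 calls each, reported per call
    best = min(timeit.repeat(fn, number=3, repeat=3)) / 3
    print(f"{name:>16}: {best * 1e3:.1f} ms per call")
```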
answered Dec 08 '22 01:12 by ali_m


Why not simply sort each row and keep the last N columns?

np.sort(m, axis=1)[:, -N:]    # N largest per row, in ascending order
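Applied to the array from the question, with N = 2 picked just for illustration, a sketch of this full-sort approach. np.sort sorts along the last axis by default, so for a 2-D array axis=1 can be written explicitly or omitted, and [:, ::-1] flips the result to largest-first:

```python
import numpy as np

m = np.arange(15).reshape(3, 5)
N = 2  # illustrative choice

# Sort each row ascending, keep the last N columns, then reverse to largest-first
top = np.sort(m, axis=1)[:, -N:][:, ::-1]
print(top.tolist())  # [[4, 3], [9, 8], [14, 13]]
```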
answered Dec 08 '22 01:12 by Jules


np.partition, np.sort, np.argsort, etc. all take an axis parameter.

Let's shuffle some values:

In [161]: A=np.arange(24)

In [162]: np.random.shuffle(A)

In [163]: A=A.reshape(4,6)

In [164]: A
Out[164]: 
array([[ 1,  2,  4, 19, 12, 11],
       [20,  5, 13, 21, 22,  3],
       [10,  6, 16, 18, 17,  8],
       [23,  9,  7,  0, 14, 15]])

Partition:

In [165]: A.partition(4,axis=1)

In [166]: A
Out[166]: 
array([[ 2,  1,  4, 11, 12, 19],
       [ 5,  3, 13, 20, 21, 22],
       [ 6,  8, 10, 16, 17, 18],
       [14,  7,  9,  0, 15, 23]])

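With kth=4, the element at index 4 of each row ends up in its sorted position, everything before it is no larger and everything after it is no smaller. So the 4 smallest values of each row come first and the 2 largest come last; slice to get an array of the 2 largest: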

In [167]: A[:,-2:]
Out[167]: 
array([[12, 19],
       [21, 22],
       [17, 18],
       [15, 23]])
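The choice kth=4 depends on the row length: for rows of n_cols elements, partitioning at n_cols - N leaves the N largest values in the last N columns, and NumPy also accepts the negative index -N for the same position. A small sketch (the random data is arbitrary) checking that both spellings select the same values as a full sort:

```python
import numpy as np

rng = np.random.RandomState(0)
A = rng.permutation(24).reshape(4, 6)
N = 2
n_cols = A.shape[1]

# kth as a positive index (n_cols - N) and as a negative index (-N)
pos = np.partition(A, n_cols - N, axis=1)[:, -N:]
neg = np.partition(A, -N, axis=1)[:, -N:]

# Order within the last N columns is not guaranteed in general, so compare after sorting
print(np.array_equal(np.sort(pos, axis=1), np.sort(neg, axis=1)))        # True
print(np.array_equal(np.sort(pos, axis=1), np.sort(A, axis=1)[:, -N:]))  # True
```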

A full sort is probably slower, but on an array this small it hardly matters. It also lets you pick any N afterwards, whereas partition only guarantees the split at the kth position you chose.

In [169]: A.sort(axis=1)

In [170]: A
Out[170]: 
array([[ 1,  2,  4, 11, 12, 19],
       [ 3,  5, 13, 20, 21, 22],
       [ 6,  8, 10, 16, 17, 18],
       [ 0,  7,  9, 14, 15, 23]])
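If the column positions are needed as well as the values, np.argsort along the same axis gives them. A sketch using the sort-then-slice idea with N = 2; since A.partition and A.sort modify A in place, it rebuilds the shuffled array as displayed at Out[164]:

```python
import numpy as np

# The shuffled array as displayed at Out[164], before the in-place partition/sort
A = np.array([[ 1,  2,  4, 19, 12, 11],
              [20,  5, 13, 21, 22,  3],
              [10,  6, 16, 18, 17,  8],
              [23,  9,  7,  0, 14, 15]])
N = 2

# Column indices that would sort each row; keep the last N and flip to largest-first
idx = np.argsort(A, axis=1)[:, -N:][:, ::-1]
vals = np.take_along_axis(A, idx, axis=1)

print(idx.tolist())   # [[3, 4], [4, 3], [3, 4], [0, 5]]
print(vals.tolist())  # [[19, 12], [22, 21], [18, 17], [23, 15]]
```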
answered Dec 08 '22 01:12 by hpaulj