 

Make all elements zero except the n largest elements in each row of a NumPy 2D array

Let's say I have a 2D numpy array like:

[[1, 8, 3, 4, 5],
 [10,13,13,12,15]]

I want to set all the elements to zero except the n largest elements in each row (for example, only the maximum, or the 2 largest). With n = 2, the output I'd like to get is:

[[0,  8,  0, 0,  5],
 [0, 13, 13, 0, 15]]

How can I do this?

Asked Oct 16 '22 14:10 by user2955839


1 Answer

With a single line of code and no loops:

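import numpy as np

a = np.asarray([[1, 8, 3, 4, 5],
                [10, 13, 13, 12, 15]])
n_max = 2  # keep the n_max largest elements in each row

# keep values >= the n_max-th largest value of their row, zero the rest
a * (a >= np.sort(a, axis=1)[:, [-n_max]]).astype(int)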

Output:

array([[ 0,  8,  0,  0,  5],
       [ 0, 13, 13,  0, 15]])
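To see why the one-liner works, here is the same expression split into named intermediate steps (the names `threshold`, `mask` and `result` are only for illustration). The key detail is indexing with the list `[-n_max]` rather than the scalar `-n_max`: it keeps a column of shape (rows, 1), which broadcasts against `a` row by row.

```python
import numpy as np

a = np.asarray([[1, 8, 3, 4, 5],
                [10, 13, 13, 12, 15]])
n_max = 2

# Sort each row ascending; column -n_max then holds the n_max-th largest value.
# The list index [-n_max] keeps a 2-D result of shape (2, 1).
threshold = np.sort(a, axis=1)[:, [-n_max]]   # [[5], [13]]

# Boolean mask of elements at or above their row's threshold.
mask = a >= threshold

# Multiplying by the 0/1 mask zeroes every other element.
result = a * mask.astype(int)
```

Note that this keeps every value that ties with the n_max-th largest one: in the second row the threshold is 13, and both 13s survive, so three elements remain nonzero. That matches the output the question asks for.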

To do the same along columns (keeping the n largest elements in each column):

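import numpy as np

a = np.asarray([[1, 8, 3, 4, 5],
                [10, 13, 13, 12, 15],
                [6, 9, 12, 7, 7],
                [3, 8, 5, 22, 2]])
n_max = 2  # keep the n_max largest elements in each column

# same idea, but sort along axis 0 and take row -n_max as the threshold
a * (a >= np.sort(a, axis=0)[[-n_max], :]).astype(int)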

Output:

array([[ 0,  0,  0,  0,  0],
       [10, 13, 13, 12, 15],
       [ 6,  9, 12,  0,  7],
       [ 0,  0,  0, 22,  0]])
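Both variants can be folded into one function that takes the axis as a parameter. The sketch below uses a hypothetical helper name, `keep_top_n`, which is not part of the original answer. It also swaps two details: `np.partition` replaces the full `np.sort` (it only guarantees that the element at index `kth` is in its sorted position, which is all the threshold needs), and `np.where` replaces multiplying by an integer mask. Ties are handled the same way as in the one-liners.

```python
import numpy as np

def keep_top_n(a, n, axis=-1):
    """Zero every element smaller than the n-th largest value along `axis`.

    Values tied with the n-th largest are kept, so a row or column can
    end up with more than n nonzero entries.
    """
    a = np.asarray(a)
    size = a.shape[axis]
    if not 1 <= n <= size:
        raise ValueError("n must be between 1 and the length of the axis")
    # In ascending sorted order, the n-th largest value sits at index size - n.
    kth = size - n
    # Partition so that index kth holds that value, then take it with a
    # list index so the axis keeps length 1 and broadcasts against `a`.
    threshold = np.take(np.partition(a, kth, axis=axis), [kth], axis=axis)
    return np.where(a >= threshold, a, 0)
```

With `axis=1` this reproduces the row example, and with `axis=0` the column example.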
Answered Nov 15 '22 06:11 by Marco Cerliani