
How to nullify all entries except for argmax?

Suppose I have a matrix / array / list like a=[1,2,3,4,5], and I want to nullify (set to 0) every entry except the max, so the result would be a=[0,0,0,0,5].

I'm currently using b = [val if idx == np.argmax(a) else 0 for idx,val in enumerate(a)], but is there a better (and faster) way, especially for arrays with more than one dimension?
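The comprehension in the question calls np.argmax(a) once per element, so it does O(n²) work for a length-n input. A minimal sketch, assuming a 1-D input, that computes the argmax a single time, next to a vectorised NumPy equivalent (both helper names are hypothetical, chosen for illustration):

```python
import numpy as np

def keep_argmax_listcomp(a):
    """Same result as the question's comprehension, but argmax is computed once."""
    i = int(np.argmax(a))  # index of the first maximum, computed a single time
    return [val if idx == i else 0 for idx, val in enumerate(a)]

def keep_argmax_numpy(a):
    """Vectorised equivalent: start from zeros and copy the single max across."""
    a = np.asarray(a)
    b = np.zeros_like(a)
    i = np.argmax(a)
    b[i] = a[i]
    return b

print(keep_argmax_listcomp([1, 2, 3, 4, 5]))  # [0, 0, 0, 0, 5]
print(keep_argmax_numpy([1, 2, 3, 4, 5]))     # [0 0 0 0 5]
```

Both keep only the first maximum when the max value is repeated, matching the behaviour of the original comprehension.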

asked Dec 14 '25 10:12 by CIsForCookies


2 Answers

You can use numpy for an in-place solution. Note that the method below keeps every entry equal to the max value, so if the max occurs more than once, all of those entries survive; everything else is set to 0.

import numpy as np

a = np.array([1,2,3,4,5])

a[a != a.max()] = 0  # boolean mask selects every entry that is not the max

# array([0, 0, 0, 0, 5])
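If you would rather not modify a in place, the three-argument form of np.where builds a new array instead. A short sketch (the input values are illustrative); like the in-place version, it keeps every entry equal to the max, so ties all survive:

```python
import numpy as np

a = np.array([1, 5, 3, 5])

# np.where(cond, x, y) takes x where cond is True, otherwise y
b = np.where(a == a.max(), a, 0)

print(b)  # [0 5 0 5]
print(a)  # [1 5 3 5]  (original array unchanged)
```

The same expression works unchanged for N-D arrays, using the global max of the whole array.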

To keep only a single maximum (the first occurrence), see @cᴏʟᴅsᴘᴇᴇᴅ's solution.

answered Dec 16 '25 00:12 by jpp


Rather than masking the original array in place, you can create an array of zeros and copy the max value(s) into the right position(s).

1-D (optimised) Solution

Setup: convert a to a 1-D NumPy array, a = np.array([1,2,3,4,5]).

  1. To keep just one instance of the max (np.argmax returns the first index when the max is repeated)

    b = np.zeros_like(a)
    i = np.argmax(a)
    b[i] = a[i]
    
  2. To keep all instances of the max

    b = np.zeros_like(a)
    m = a == a.max()
    b[m] = a[m]
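
The two variants only differ when the maximum is repeated. A quick check with a tied maximum (the input values are illustrative):

```python
import numpy as np

a = np.array([1, 5, 3, 5])

# Variant 1: keep only the first occurrence of the max
b1 = np.zeros_like(a)
i = np.argmax(a)  # argmax returns the first index on ties
b1[i] = a[i]

# Variant 2: keep every occurrence of the max
b2 = np.zeros_like(a)
m = a == a.max()
b2[m] = a[m]

print(b1)  # [0 5 0 0]
print(b2)  # [0 5 0 5]
```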
    

N-D solution

import numpy as np

np.random.seed(0)
a = np.random.randn(5, 5)

b = np.zeros_like(a)
m = a == a.max(1, keepdims=True)
b[m] = a[m]

b
array([[0.        , 0.        , 0.        , 2.2408932 , 0.        ],
       [0.        , 0.95008842, 0.        , 0.        , 0.        ],
       [0.        , 1.45427351, 0.        , 0.        , 0.        ],
       [0.        , 1.49407907, 0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        , 0.        , 2.26975462]])

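This keeps every occurrence of each row's max (ties included). For the max over the whole array instead, use a == a.max() without the axis argument.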
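The N-D mask keeps ties. If you want exactly one max per row, or a single max over the whole array, argmax can be combined with np.take_along_axis / np.put_along_axis or np.unravel_index. A sketch, assuming a 2-D array for the per-row case; the function names and sample values are hypothetical:

```python
import numpy as np

def keep_row_argmax(a):
    """Keep only the first max in each row of a 2-D array; zero everything else."""
    b = np.zeros_like(a)
    idx = a.argmax(axis=1)[:, None]            # shape (n_rows, 1): column index per row
    vals = np.take_along_axis(a, idx, axis=1)  # the max value of each row
    np.put_along_axis(b, idx, vals, axis=1)    # write them back at the same positions
    return b

def keep_global_argmax(a):
    """Keep only the first max of the whole array (any number of dims)."""
    b = np.zeros_like(a)
    i = np.unravel_index(a.argmax(), a.shape)  # flat index -> N-D index tuple
    b[i] = a[i]
    return b

a = np.array([[1., 3., 3.],
              [7., 2., 0.]])
print(keep_row_argmax(a))
# [[0. 3. 0.]
#  [7. 0. 0.]]
print(keep_global_argmax(a))
# [[0. 0. 0.]
#  [7. 0. 0.]]
```

Row 0 has a tied max of 3.0; only its first occurrence is kept, unlike the mask-based version.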

answered Dec 15 '25 22:12 by cs95

