 

pythonic way for axis-wise winner-take-all in numpy

Tags:

python

numpy

I am wondering what the most concise and Pythonic way is to keep only the maximum element in each row of a 2D NumPy array while setting all other elements to zero. Example:

Given the following NumPy array:

import numpy as np
a = np.array([[1, 8, 3, 6],
              [5, 5, 60, 1],
              [63, 9, 9, 23]])

I want the answer to be:

b = np.array([[0, 8, 0, 0],
              [0, 0, 60, 0],
              [63, 0, 0, 0]])

I can think of several ways to solve this, but what interests me is whether there are built-in NumPy functions that do it quickly.

Thank you in advance

asked Feb 09 '16 11:02 by Alan_AI




1 Answer

You can use a.max (equivalently np.max) to take the maximum along each row (axis=1), then use np.where to zero out the non-maximal elements:

np.where(a == a.max(axis=1, keepdims=True), a, 0)

The keepdims=True argument keeps the singleton dimension after taking the max (i.e. so that a.max(1, keepdims=True).shape == (3, 1)), which simplifies broadcasting it against a.
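Below is a minimal, self-contained sketch that wraps the one-liner above in a helper and runs it on the question's data. The function names winner_take_all and winner_take_all_single are hypothetical and used only for illustration. Because the mask is a == row_max, every element that ties for a row's maximum is kept. The second helper is an alternative that is not part of the original answer. It assumes you want exactly one survivor per row. It uses argmax (which returns the first maximal index) together with np.take_along_axis and np.put_along_axis, which are available in NumPy 1.15 and later.

```python
import numpy as np

def winner_take_all(a):
    # Row-wise maximum; keepdims=True keeps shape (n_rows, 1) so it broadcasts against a
    row_max = a.max(axis=1, keepdims=True)
    # Keep elements equal to their row's maximum (all ties survive), zero the rest
    return np.where(a == row_max, a, 0)

def winner_take_all_single(a):
    # Hypothetical alternative: keep only the first maximum per row, as chosen by argmax
    idx = a.argmax(axis=1)[:, None]  # shape (n_rows, 1)
    out = np.zeros_like(a)
    # Copy the winning value of each row into the all-zero output
    np.put_along_axis(out, idx, np.take_along_axis(a, idx, axis=1), axis=1)
    return out

a = np.array([[1, 8, 3, 6],
              [5, 5, 60, 1],
              [63, 9, 9, 23]])

print(a.max(axis=1, keepdims=True).shape)  # (3, 1)
print(winner_take_all(a))

# A row with a tied maximum shows how the two helpers differ
t = np.array([[7, 7, 1]])
print(winner_take_all(t))         # both 7s are kept
print(winner_take_all_single(t))  # only the first 7 is kept
```

The two helpers give the same result on the question's array because no row has a tied maximum. They differ only when a row has a tie.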

answered Oct 06 '22 00:10 by ali_m