 

Numpy: change max in each row to 1, all other numbers to 0

Tags: python, numpy

I'm trying to implement a numpy function that replaces the max in each row of a 2D array with 1, and all other numbers with zero:

    >>> import numpy as np
    >>> a = np.array([[0, 1],
    ...               [2, 3],
    ...               [4, 5],
    ...               [6, 7],
    ...               [9, 8]])
    >>> b = some_function(a)
    >>> print(b)
    [[0. 1.]
     [0. 1.]
     [0. 1.]
     [0. 1.]
     [1. 0.]]

What I've tried so far

    def some_function(x):
        a = np.zeros(x.shape)
        a[:, np.argmax(x, axis=1)] = 1
        return a

    >>> b = some_function(a)
    >>> print(b)
    [[1. 1.]
     [1. 1.]
     [1. 1.]
     [1. 1.]
     [1. 1.]]
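The attempt above goes wrong because `a[:, idx]` with a 1-D index array selects whole columns: the `:` picks every row, so each column named in `idx` is filled in all rows. For this input, `np.argmax(x, axis=1)` is `[1, 1, 1, 1, 0]`, which names both columns, so every entry becomes 1. A small sketch reproducing this (the variable names are just for illustration):

```python
import numpy as np

x = np.array([[0, 1], [2, 3], [4, 5], [6, 7], [9, 8]])
cols = np.argmax(x, axis=1)  # column index of each row's maximum
print(cols)  # [1 1 1 1 0]

wrong = np.zeros(x.shape)
# `:` selects every row, so each column listed in `cols` is set in all rows
wrong[:, cols] = 1
print(wrong)  # all ones, since columns 0 and 1 both appear in `cols`
```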
asked Nov 30 '13 00:11 by MikeRand




2 Answers

Method #1, tweaking yours:

    >>> a = np.array([[0, 1], [2, 3], [4, 5], [6, 7], [9, 8]])
    >>> b = np.zeros_like(a)
    >>> b[np.arange(len(a)), a.argmax(1)] = 1
    >>> b
    array([[0, 1],
           [0, 1],
           [0, 1],
           [0, 1],
           [1, 0]])

[Actually, range will work just fine; I wrote arange out of habit.]
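Method #1 works because two integer index arrays are paired element by element: row `i` from `np.arange(len(a))` is matched with column `a.argmax(1)[i]`, so exactly one entry per row is set. A minimal sketch wrapping it as a reusable function; the name `row_max_to_one` is hypothetical, not a NumPy API:

```python
import numpy as np

def row_max_to_one(x):
    """Hypothetical helper wrapping Method #1: 1 at each row's argmax, 0 elsewhere."""
    x = np.asarray(x)
    out = np.zeros_like(x)              # same shape and dtype as the input
    rows = np.arange(x.shape[0])        # row indices 0, 1, ..., n-1
    out[rows, x.argmax(axis=1)] = 1     # pairs row i with column argmax(x[i])
    return out

a = np.array([[0, 1], [2, 3], [4, 5], [6, 7], [9, 8]])
print(row_max_to_one(a))
```

Because `argmax` returns the first index when several entries tie, this always produces exactly one 1 per row, unlike Method #2, which marks every tied entry. Since `zeros_like` keeps the input dtype, pass a float array if you want float output as in the question.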

Method #2, using max instead of argmax to handle the case where multiple elements reach the maximum value:

    >>> a = np.array([[0, 1], [2, 2], [4, 3]])
    >>> (a == a.max(axis=1)[:,None]).astype(int)
    array([[0, 1],
           [1, 1],
           [1, 0]])
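In Method #2 the `[:, None]` turns the row maxima from shape `(n,)` into `(n, 1)` so each one broadcasts across its own row. Without it, shapes `(3, 2)` and `(3,)` would fail to broadcast, and for a square array the comparison would silently use the wrong row's maximum. Passing `keepdims=True` to `max` produces the `(n, 1)` shape directly; a sketch:

```python
import numpy as np

a = np.array([[0, 1], [2, 2], [4, 3]])
row_max = a.max(axis=1, keepdims=True)  # shape (3, 1) instead of (3,)
print(row_max.shape)  # (3, 1)

# Broadcasting compares each element with its own row's maximum;
# tied maxima (row 1 here) all become 1.
b = (a == row_max).astype(int)
print(b)
```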
answered Sep 29 '22 07:09 by DSM


I prefer using numpy.where like so:

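    # compare each element with its own row's max: max -> 1, everything else -> 0
    b = np.where(a == a.max(axis=1, keepdims=True), 1, 0)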
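Note that `np.max(a)` without an `axis` returns the single maximum of the whole array, so comparing against it marks only where the global maximum occurs, not each row's maximum. Comparing against `a.max(axis=1, keepdims=True)` instead gives the per-row result the question asks for. A sketch contrasting the two on the question's input:

```python
import numpy as np

a = np.array([[0, 1], [2, 3], [4, 5], [6, 7], [9, 8]])

global_mask = np.where(a == np.max(a), 1, 0)                  # compares with 9 only
row_mask = np.where(a == a.max(axis=1, keepdims=True), 1, 0)  # compares with each row's max

print(global_mask)  # a single 1, at row 4, column 0
print(row_mask)
```

The one-argument form `np.where(cond)` returns a tuple of index arrays, so `a[np.where(cond)] = 1` assigns in place but leaves the other entries untouched; the three-argument form `np.where(cond, 1, 0)` builds a new array with both values filled in.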
answered Sep 29 '22 06:09 by Cyclone