
Numpy, change max value in each row to 1 without changing others

Tags: python, numpy

I'm trying to change the maximum value of each row to 1 and leave the other values unchanged.

Each value is between 0 and 1.

I want to change this

>>> a = np.array([[0.5, 0.2, 0.1], 
...               [0.6, 0.3, 0.8], 
...               [0.3, 0.4, 0.2]])

into this

>>> new_a = np.array([[1, 0.2, 0.1],
...                   [0.6, 0.3, 1],
...                   [0.3, 1, 0.2]])

Is there a good solution for this, maybe using np.where (and without a for loop)?

Jeight An asked Dec 22 '22 15:12


1 Answer

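Use np.argmax with integer array indexing. np.arange(len(a)) supplies the row indices and np.argmax(a, axis=1) supplies the column of each row's maximum, so the assignment touches exactly one element per row and modifies a in place: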

>>> a[np.arange(len(a)), np.argmax(a, axis=1)] = 1
>>> a
array([[1. , 0.2, 0.1],
       [0.6, 0.3, 1. ],
       [0.3, 1. , 0.2]])
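
Since the question asks about np.where, here is a minimal sketch of that route, assuming the same array a as in the question. The names row_max and new_a are only illustrative. It compares each element with its row maximum and builds a new array instead of assigning in place:

```python
import numpy as np

a = np.array([[0.5, 0.2, 0.1],
              [0.6, 0.3, 0.8],
              [0.3, 0.4, 0.2]])

# Row-wise maximum; keepdims=True gives shape (3, 1) so that the
# comparison below broadcasts across each row.
row_max = a.max(axis=1, keepdims=True)

# Where an element equals its row maximum, use 1; otherwise keep the element.
new_a = np.where(a == row_max, 1, a)

print(new_a)
# [[1.  0.2 0.1]
#  [0.6 0.3 1. ]
#  [0.3 1.  0.2]]
```

Two differences from the np.argmax approach are worth keeping in mind: np.where returns a new array and leaves a untouched, and if a row contains its maximum more than once, every tied element becomes 1, whereas np.argmax only picks the first occurrence.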
U12-Forward answered Dec 28 '22 11:12