
python numpy argmax to max in multidimensional array

I have the following code:

import numpy as np
sample = np.random.random((10,10,3))
argmax_indices = np.argmax(sample, axis=2)

That is, I take the argmax along axis=2, which gives me a (10,10) array of indices. Now I want to set the elements at those indices to 0, so I need to index the sample array with them. I tried:

max_values = sample[argmax_indices]

but it doesn't return the (10,10) array of maxima that I expect. I want something like

max_values = sample[argmax_indices]
sample[argmax_indices] = 0

To validate the result, I check that max_values - np.max(sample, axis=2) gives a zero matrix of shape (10,10). Any help will be appreciated.

Rayyan Riaz asked Feb 28 '17 21:02


1 Answer

Here's one approach, using advanced indexing with open grids from np.ogrid for the first two axes -

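m,n = sample.shape[:2]
I,J = np.ogrid[:m,:n]  # open grids of shape (m,1) and (1,n); they broadcast to (m,n)
max_values = sample[I,J, argmax_indices]  # picks one element per (i,j) position
sample[I,J, argmax_indices] = 0  # sets each of those maxima to zero in place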
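
For context on why the attempt in the question fails: a single integer index array is applied to the first axis only, so sample[argmax_indices] selects a whole (10, 3) slice for every entry instead of one element per (i, j) position. The sketch below uses the question's shapes; the seeded generator is an assumption added only so the run is reproducible.

```python
import numpy as np

rng = np.random.default_rng(0)  # seed is an assumption, for reproducibility only
sample = rng.random((10, 10, 3))
argmax_indices = np.argmax(sample, axis=2)  # shape (10, 10), values in {0, 1, 2}

# A lone index array indexes axis 0 only, so every entry of
# argmax_indices pulls out an entire (10, 3) slice of sample.
wrong = sample[argmax_indices]
print(wrong.shape)       # (10, 10, 10, 3)

# One index array per axis selects a single element per (i, j).
m, n = sample.shape[:2]
I, J = np.ogrid[:m, :n]
max_values = sample[I, J, argmax_indices]
print(max_values.shape)  # (10, 10)
```

The three index arrays have shapes (10, 1), (1, 10) and (10, 10), which broadcast together to the (10, 10) result.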

Sample step-by-step run

1) Sample input array :

In [261]: a = np.random.randint(0,9,(2,2,3))

In [262]: a
Out[262]: 
array([[[8, 4, 6],
        [7, 6, 2]],

       [[1, 8, 1],
        [4, 6, 4]]])

2) Get the argmax indices along axis=2 :

In [263]: idx = a.argmax(axis=2)

3) Get the shape and arrays for indexing into first two dims :

In [264]: m,n = a.shape[:2]

In [265]: I,J = np.ogrid[:m,:n]
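
As a small aside on what step 3 produces, the sketch below inspects the two open grids for the same 2 x 2 case. np.broadcast_arrays is used here only to make the implicit broadcasting visible; it is not part of the approach itself.

```python
import numpy as np

m, n = 2, 2
I, J = np.ogrid[:m, :n]
print(I.shape, J.shape)  # (2, 1) (1, 2)

# Broadcasting I against J enumerates every (row, column) pair,
# which is what advanced indexing does implicitly with a[I, J, idx].
rows, cols = np.broadcast_arrays(I, J)
print(rows)
print(cols)
```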

4) Index with I, J and idx (advanced-indexing) to extract the max values :

In [267]: max_values = a[I,J,idx]

In [268]: max_values
Out[268]: 
array([[8, 7],
       [8, 6]])

5) Verify that we are getting an all zeros array after subtracting np.max(a,axis=2) from max_values :

In [306]: max_values - np.max(a, axis=2)
Out[306]: 
array([[0, 0],
       [0, 0]])

6) Use the same advanced-indexing to set those positions to zero, then verify visually :

In [269]: a[I,J,idx] = 0

In [270]: a
Out[270]: 
array([[[0, 4, 6], # <=== Compare this against the original version
        [0, 6, 2]],

       [[1, 0, 1],
        [4, 0, 4]]])
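
As a possible alternative, not part of the approach above: NumPy 1.15 and later provide np.take_along_axis and np.put_along_axis, which do the same gather and scatter without building I and J. A minimal sketch on the same sample array, assuming such a NumPy version is installed:

```python
import numpy as np

a = np.array([[[8, 4, 6],
               [7, 6, 2]],
              [[1, 8, 1],
               [4, 6, 4]]])

# The index array must have the same number of dimensions as a,
# so add a trailing axis of length 1: shape (2, 2, 1).
idx = np.expand_dims(a.argmax(axis=2), axis=2)

# Gather the maxima, then drop the length-1 axis again.
max_values = np.take_along_axis(a, idx, axis=2).squeeze(axis=2)
print(max_values)

# Scatter zeros into the same positions; this modifies a in place.
np.put_along_axis(a, idx, 0, axis=2)
print(a)
```

Both printed arrays should match the outputs of steps 4 and 6 above.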
Divakar answered Sep 19 '22 21:09