Numpy: how to use argmax results to get the actual max? [duplicate]

Tags:

python

numpy

Suppose I have a 3D array:

>>> a
array([[[7, 0],
        [3, 6]],

       [[2, 4],
        [5, 1]]])

I can get its argmax along axis=1 using

>>> m = np.argmax(a, axis=1)
>>> m
array([[0, 1],
       [1, 0]])

How can I use m as an index into a, so that the results are equivalent to simply using max?

>>> a.max(axis=1)
array([[7, 6],
       [5, 4]])

(This is useful when m is applied to other arrays of the same shape)

asked Oct 20 '17 00:10 by MWB



2 Answers

You can do this with advanced indexing and numpy broadcasting:

m = np.argmax(a, axis=1)
a[np.arange(a.shape[0])[:,None], m, np.arange(a.shape[2])]

#array([[7, 6],
#       [5, 4]])
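As a minimal sketch of the use case mentioned in the question, the same index tuple depends only on `m` and the shape of `a`, so it can be reused on any array of that shape. Here `b` is a made-up array chosen only for illustration:

```python
import numpy as np

a = np.array([[[7, 0], [3, 6]],
              [[2, 4], [5, 1]]])
# b is a hypothetical second array with the same shape as a
b = np.array([[[10, 11], [12, 13]],
              [[14, 15], [16, 17]]])

m = np.argmax(a, axis=1)  # shape (2, 2)

# Index tuple built only from m and the shape of a
index = (np.arange(a.shape[0])[:, None], m, np.arange(a.shape[2]))

print(a[index])  # same values as a.max(axis=1)
print(b[index])  # values of b at the positions where a is maximal along axis 1
```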

Step by step, starting from the argmax along axis 1:

m = np.argmax(a, axis=1)

Create the index arrays for the 1st, 2nd and 3rd dimensions:

ind1, ind2, ind3 = np.arange(a.shape[0])[:,None], m, np.arange(a.shape[2])
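To make the dimension mismatch concrete, the shapes can be inspected directly. This small sketch reuses the example `a`; `np.broadcast(...).shape` reports the common shape without materializing the broadcast arrays:

```python
import numpy as np

a = np.array([[[7, 0], [3, 6]],
              [[2, 4], [5, 1]]])
m = np.argmax(a, axis=1)

ind1, ind2, ind3 = np.arange(a.shape[0])[:, None], m, np.arange(a.shape[2])

# Shapes before broadcasting: (2, 1), (2, 2) and (2,)
for name, x in [("ind1", ind1), ("ind2", ind2), ("ind3", ind3)]:
    print(name, x.shape)

# Common shape they broadcast to, which is also the shape of the result
print(np.broadcast(ind1, ind2, ind3).shape)
```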

Because the three arrays have different shapes, (2, 1), (2, 2) and (2,), they broadcast against each other to the common shape (2, 2), so each one effectively becomes:

for x in np.broadcast_arrays(ind1, ind2, ind3):
    print(x, '\n')

#[[0 0]
# [1 1]] 

#[[0 1]
# [1 0]] 

#[[0 1]
# [0 1]] 

And since all indices are integer arrays, this triggers advanced indexing: the elements at (0, 0, 0), (0, 1, 1), (1, 1, 0) and (1, 0, 1) are picked up, each index being formed by taking the element at the same position from each of the three broadcast arrays.
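A quick way to check which elements get picked is to zip the broadcast index arrays together. This sketch assumes the same example `a` and looks each coordinate up one at a time, which should match the vectorized result:

```python
import numpy as np

a = np.array([[[7, 0], [3, 6]],
              [[2, 4], [5, 1]]])
m = np.argmax(a, axis=1)
ind1, ind2, ind3 = np.arange(a.shape[0])[:, None], m, np.arange(a.shape[2])

# Pair up corresponding elements of the three broadcast index arrays
broadcast = np.broadcast_arrays(ind1, ind2, ind3)
coords = list(zip(*(x.ravel().tolist() for x in broadcast)))
print(coords)  # [(0, 0, 0), (0, 1, 1), (1, 1, 0), (1, 0, 1)]

# Looking each coordinate up individually gives the same values as a[ind1, ind2, ind3]
values = [int(a[c]) for c in coords]
print(values)  # [7, 6, 5, 4]
```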

answered Oct 07 '22 06:10 by Psidom


You can use np.ogrid to create a grid over all axes of your array except the one you reduced. Then insert the argmax result at the position of that axis and index your array with the result:

>>> import numpy as np
>>> a = np.array([[[7, 0], [3, 6]], [[2, 4], [5, 1]]])
>>> axis = 1

>>> # Create the grid
>>> idx = list(np.ogrid[[slice(a.shape[ax]) for ax in range(a.ndim) if ax != axis]])
>>> argmaxes = np.argmax(a, axis=axis)
>>> idx.insert(axis, argmaxes)

>>> # Index the original array with the grid (as a tuple; recent NumPy rejects a list here)
>>> a[tuple(idx)]
array([[7, 6],
       [5, 4]])

Note that this doesn't work for axis=None or when you reduce over multiple axes.
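The same idea can be wrapped in a small helper for any single integer axis. `take_at_argmax` is a hypothetical name, not a NumPy function, and `b` is a made-up array; the sketch cross-checks the helper against `np.take_along_axis`, which NumPy provides since version 1.15:

```python
import numpy as np


def take_at_argmax(arr, argmaxes, axis):
    """Hypothetical helper (not part of NumPy): return arr's values at the
    positions given by `argmaxes`, which came from np.argmax(..., axis=axis).

    Follows the np.ogrid approach above; `axis` must be a single int.
    """
    if axis < 0:
        axis += arr.ndim  # list.insert needs a non-negative position here
    # Open grids over every axis except the reduced one
    idx = list(np.ogrid[tuple(slice(arr.shape[ax])
                              for ax in range(arr.ndim) if ax != axis)])
    idx.insert(axis, argmaxes)
    # Index with a tuple: recent NumPy no longer treats a list as a multi-axis index
    return arr[tuple(idx)]


a = np.array([[[7, 0], [3, 6]],
              [[2, 4], [5, 1]]])
b = np.array([[[10, 11], [12, 13]],
              [[14, 15], [16, 17]]])  # made-up array with the same shape as a

for axis in range(a.ndim):
    m = np.argmax(a, axis=axis)
    picked = take_at_argmax(a, m, axis)
    assert np.array_equal(picked, a.max(axis=axis))
    # Cross-check with np.take_along_axis, which expects m to keep the reduced axis
    builtin = np.take_along_axis(a, np.expand_dims(m, axis), axis=axis).squeeze(axis)
    assert np.array_equal(picked, builtin)

# Values of b where a is maximal along axis 1
print(take_at_argmax(b, np.argmax(a, axis=1), axis=1))
```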

answered Oct 07 '22 06:10 by MSeifert