numpy: what is the logic of the argmin() and argmax() functions?

By adding the axis argument, NumPy looks at the rows and columns individually. When it's not given, the array a is flattened into a single 1D array.

axis=0 means that the operation is performed down the columns of a 2D array a in turn.

For example, np.argmin(a, axis=0) returns the index of the minimum value in each of the four columns. The column minima are 1, 2, 3 and 4, found in rows 0, 0, 2 and 2. The row indices are marked in the comments:

>>> a
array([[ 1,  2,  4,  7],  # 0
       [ 9, 88,  6, 45],  # 1
       [ 9, 76,  3,  4]]) # 2

>>> np.argmin(a, axis=0)
array([0, 0, 2, 2])
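To make "down the columns" concrete, here is a minimal sketch that recomputes np.argmin(a, axis=0) with plain Python loops over the columns. The helper name argmin_down_columns is hypothetical and only for illustration. It assumes a 2D array and, like NumPy, keeps the first index when values tie.

```python
import numpy as np

a = np.array([[1, 2, 4, 7],
              [9, 88, 6, 45],
              [9, 76, 3, 4]])

def argmin_down_columns(arr):
    """Hypothetical helper: row index of the minimum in each column of a 2D array."""
    n_rows, n_cols = arr.shape
    result = []
    for col in range(n_cols):
        best_row = 0
        for row in range(1, n_rows):
            # strict '<' keeps the first occurrence on ties, as NumPy does
            if arr[row, col] < arr[best_row, col]:
                best_row = row
        result.append(best_row)
    return np.array(result)

print(argmin_down_columns(a))   # [0 0 2 2]
print(np.argmin(a, axis=0))     # [0 0 2 2]
```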

On the other hand, axis=1 means that the operation is performed across the rows of a.

That means np.argmin(a, axis=1) returns [0, 2, 2], one index for each of the three rows of a. The minimum of the first row is at index 0, and the minima of the second and third rows are at index 2:

>>> a
#        0   1   2   3
array([[ 1,  2,  4,  7],
       [ 9, 88,  6, 45],
       [ 9, 76,  3,  4]])

>>> np.argmin(a, axis=1)
array([0, 2, 2])
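Because argmin only returns positions, a common next step is to use those positions to pull the values back out. Here is a minimal sketch using np.take_along_axis (available since NumPy 1.15). The index array needs an extra axis so that it has the same number of dimensions as a:

```python
import numpy as np

a = np.array([[1, 2, 4, 7],
              [9, 88, 6, 45],
              [9, 76, 3, 4]])

idx = np.argmin(a, axis=1)   # one column index per row: [0 2 2]
# take_along_axis needs idx to have as many dimensions as a,
# so reshape it from (3,) to (3, 1)
row_mins = np.take_along_axis(a, idx[:, np.newaxis], axis=1)

print(row_mins.ravel())      # [1 6 3]
print(a.min(axis=1))         # [1 6 3]
```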

The np.argmax function by default works along the flattened array, unless you specify an axis. To see what is happening you can use flatten explicitly:

>>> np.argmax(a)
5

>>> a.flatten()
array([ 1,  2,  4,  7,  9, 88,  6, 45,  9, 76,  3,  4])
#       0   1   2   3   4   5

I've numbered the indices under the array above to make it clearer: the largest element, 88, is at index 5. Note that indices are numbered from zero in NumPy.
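Without an axis, the search runs over the elements in row-major (C) order, which is the same order a.flatten() produces. Here is a minimal plain-Python sketch of that behaviour, assuming a 2D array. The name flat_argmax is hypothetical:

```python
import numpy as np

a = np.array([[1, 2, 4, 7],
              [9, 88, 6, 45],
              [9, 76, 3, 4]])

def flat_argmax(arr):
    """Hypothetical helper: index of the first maximum in the row-major flattened array."""
    values = arr.flatten().tolist()   # [1, 2, 4, 7, 9, 88, 6, 45, 9, 76, 3, 4]
    best = 0
    for i, v in enumerate(values):
        if v > values[best]:          # strict '>' keeps the first maximum
            best = i
    return best

print(flat_argmax(a))               # 5
print(a.flatten()[np.argmax(a)])    # 88, the largest element
```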

When you specify an axis, it also works as expected:

>>> np.argmax(a, axis=0)
array([1, 1, 1, 1])

This tells you that, for each column, the largest value is in row 1 (the second row) when searching along axis=0 (down the rows). You can see this more clearly if you change your data a bit:

>>> a = np.array([[100, 2, 4, 7], [9, 88, 6, 45], [9, 76, 3, 100]])
>>> a
array([[100,   2,   4,   7],
       [  9,  88,   6,  45],
       [  9,  76,   3, 100]])

>>> np.argmax(a, axis=0)
array([0, 1, 1, 2])

As you can see, it now finds the maximum in row 0 for column 0, in row 1 for columns 1 and 2, and in row 2 for column 3 (all indices count from zero).
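These examples skip over one detail. In column 0 of the original array, 9 appears in both row 1 and row 2, yet argmax reports row 1. The NumPy documentation states that when the extreme value occurs more than once, the index of the first occurrence is returned. Here is a small check on a hypothetical array b with deliberate ties:

```python
import numpy as np

# Hypothetical data with deliberate ties
b = np.array([[5, 1, 7],
              [5, 3, 7],
              [2, 3, 0]])

# Column 0: 5 is in rows 0 and 1 -> first occurrence, row 0
# Column 1: 3 is in rows 1 and 2 -> row 1
# Column 2: 7 is in rows 0 and 1 -> row 0
print(np.argmax(b, axis=0))   # [0 1 0]

# Flattened: the maximum 7 sits at flat indices 2 and 5 -> 2
print(np.argmax(b))           # 2
```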

There is a useful guide to numpy indexing in the documentation.


As a side note: if you want to find the coordinates of your maximum value in the full array, you can use

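>>> a = np.array([[1, 2, 4, 7], [9, 88, 6, 45], [9, 76, 3, 4]])
>>> a
array([[ 1,  2,  4,  7],
       [ 9, 88,  6, 45],
       [ 9, 76,  3,  4]])
>>> i = int(np.argmax(a))                # 5, index into the flattened array
>>> c = (i // len(a[0]), i % len(a[0]))  # use //, since / returns a float in Python 3
>>> c
(1, 1)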
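NumPy also provides a built-in for this conversion. np.unravel_index turns a flat index into a tuple of coordinates for a given shape, and unlike the manual arithmetic above it also works for arrays with more than two dimensions. Here is a minimal sketch:

```python
import numpy as np

a = np.array([[1, 2, 4, 7],
              [9, 88, 6, 45],
              [9, 76, 3, 4]])

flat = np.argmax(a)                          # 5, index into the flattened array
row, col = np.unravel_index(flat, a.shape)   # convert to (row, column)
print(int(row), int(col))                    # 1 1
print(a[row, col])                           # 88
```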

Read the comments for clarification:

import numpy as np

a = np.array([[1, 2, 4, 7], [9, 88, 6, 45], [9, 76, 3, 4]])

# index of the max value in the flattened array
print(np.argmax(a))           # 5

# index of the max value in each column (axis=0, down the rows)
print(np.argmax(a, axis=0))   # [1 1 1 1]

# index of the max value in each row (axis=1, across the columns)
print(np.argmax(a, axis=1))   # [3 1 1]

# index of the min value in the flattened array
print(np.argmin(a))           # 0

# index of the min value in each column (axis=0)
print(np.argmin(a, axis=0))   # [0 0 2 2]

# index of the min value in each row (axis=1)
print(np.argmin(a, axis=1))   # [0 2 2]
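The same calls work on arrays with more than two dimensions. A useful rule of thumb is that the result has the shape of the input with the chosen axis removed. Here is a minimal sketch on a hypothetical 3D array built with np.arange:

```python
import numpy as np

# Hypothetical 3D array of shape (2, 3, 4) holding the distinct values 0..23
c = np.arange(24).reshape(2, 3, 4)

for axis in range(c.ndim):
    idx = np.argmax(c, axis=axis)
    print(axis, idx.shape)
# 0 (3, 4)
# 1 (2, 4)
# 2 (2, 3)

# axis=-1 means the last axis, so it matches axis=2
print(np.array_equal(np.argmax(c, axis=-1), np.argmax(c, axis=2)))  # True
```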

The axis argument of argmax refers to the axis along which the array is sliced into 1D vectors.

In other words, np.argmin(a, axis=0) gives the same result as np.apply_along_axis(np.argmin, 0, a). It finds the location of the minimum in each of the 1D slices taken along axis 0.
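This equivalence can be checked directly. The minimal sketch below compares the built-in axis argument with np.apply_along_axis and with an explicit loop over the column slices a[:, j]. The two alternatives are for illustration only, because apply_along_axis loops in Python and is generally slower than the built-in:

```python
import numpy as np

a = np.array([[1, 2, 4, 7],
              [9, 88, 6, 45],
              [9, 76, 3, 4]])

builtin = np.argmin(a, axis=0)
applied = np.apply_along_axis(np.argmin, 0, a)
# explicit version: a[:, j] is the 1D slice running along axis 0 for column j
looped = np.array([np.argmin(a[:, j]) for j in range(a.shape[1])])

print(builtin, applied, looped)   # [0 0 2 2] [0 0 2 2] [0 0 2 2]
```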

Therefore, in your example, np.argmin(a, axis=0) is [0, 0, 2, 2], which corresponds to the values [1, 2, 3, 4] in the respective columns.