
How to make numpy.argmax return all occurrences of the maximum?

Tags:

python

max

numpy

I'm trying to find a function that returns all occurrences of the maximum in a given list.

However, numpy.argmax returns only the first occurrence it finds. For instance:

from numpy import argmax

values = [7, 6, 5, 7, 6, 7, 6, 6, 6, 4, 5, 6]
winner = argmax(values)

print(winner)

This prints only index 0, but I want all indices: 0, 3, 5.

Asked Jul 10 '13 by Marieke_W


People also ask

Can argmax return multiple values?

Not for ties. np.argmax() returns the indices of the maximum values along an axis, but when the maximum occurs more than once, only the index of its first occurrence is returned.
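A short illustration of this tie-breaking rule, using made-up sample arrays:

```python
import numpy as np

# Ties: argmax reports only the first index holding the maximum value.
flat = np.array([7, 6, 5, 7, 6, 7])
print(np.argmax(flat))  # 0

# Along an axis you get one index per row, but still only the first tie in each row.
grid = np.array([[1, 3, 3],
                 [5, 5, 2]])
print(np.argmax(grid, axis=1))  # [1 0]
```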

What does argmax return?

In mathematics, argmax ("arguments of the maxima") denotes the argument or arguments at which a target function attains its maximum value. NumPy's np.argmax returns the index of that maximum within an array.
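Written as a set, the mathematical definition includes every maximizing argument, which is exactly what the question asks for; np.argmax returns only one element of that set. Treating the question's list as a function of its index, $f(i) = a_i$:

```latex
\arg\max_{x \in D} f(x) = \{\, x \in D : f(x) \ge f(y) \text{ for all } y \in D \,\}

% For a = [7, 6, 5, 7, 6, 7, 6, 6, 6, 4, 5, 6] with D = \{0, \dots, 11\}:
\arg\max_{i \in D} a_i = \{0, 3, 5\}
```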

What is the difference between np.max and np.argmax?

np.max (also available as np.amax) returns the maximum value of a NumPy array, whereas np.argmax returns the index of that maximum value. Note that np.maximum is a different function: it compares two arrays element-wise and returns the larger value at each position.
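A small comparison of the three functions on the question's data (the second pair of arrays is made up for illustration):

```python
import numpy as np

values = np.array([7, 6, 5, 7, 6, 7, 6, 6, 6, 4, 5, 6])

print(np.max(values))     # 7  -> the largest value itself
print(np.argmax(values))  # 0  -> the index of its first occurrence

# Indexing with the argmax result recovers the maximum value.
print(values[np.argmax(values)] == np.max(values))  # True

# np.maximum is different again: it compares two arrays element by element.
print(np.maximum(np.array([1, 5, 3]), np.array([4, 2, 3])))  # [4 5 3]
```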


3 Answers

As the documentation of np.argmax says: "In case of multiple occurrences of the maximum values, the indices corresponding to the first occurrence are returned.", so you need another strategy.

One option is to combine np.argwhere with np.amax:

>>> import numpy as np
>>> listy = np.array([7, 6, 5, 7, 6, 7, 6, 6, 6, 4, 5, 6])
>>> winner = np.argwhere(listy == np.amax(listy))
>>> print(winner)
[[0]
 [3]
 [5]]
>>> print(winner.flatten().tolist())  # if you want it as a list
[0, 3, 5]
Answered Oct 19 '22 by jabaldonedo
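The np.argwhere approach from the answer above also works for multi-dimensional arrays: it returns one row per match with one column per dimension, so flattening the result would mix up row and column indices. A sketch with a made-up 2-D array:

```python
import numpy as np

grid = np.array([[1, 9, 4],
                 [9, 3, 9]])

# For a 2-D array, argwhere returns one (row, column) pair per match.
positions = np.argwhere(grid == grid.max())
print(positions.tolist())  # [[0, 1], [1, 0], [1, 2]]
```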


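In case it matters, the following algorithm makes a single pass over the data, whereas the approaches above make two (one to find the maximum, another to find where it occurs). Both are O(n), and because this loop runs in pure Python, the vectorized NumPy versions are usually faster in practice on large arrays.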

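def allmax(a):
    """Return the indices of all occurrences of the maximum in a single pass."""
    if len(a) == 0:
        return []
    all_ = [0]    # indices of the current maximum
    max_ = a[0]   # current maximum value
    for i in range(1, len(a)):
        if a[i] > max_:
            # New, strictly larger maximum: restart the index list.
            all_ = [i]
            max_ = a[i]
        elif a[i] == max_:
            # Another occurrence of the current maximum.
            all_.append(i)
    return all_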
Answered Oct 19 '22 by EricRobertBrewer
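Since the allmax loop above only needs to see each element once, the same idea extends to inputs that cannot be indexed or traversed twice, such as generators. A sketch under that assumption; allmax_iter is a hypothetical name, not part of NumPy:

```python
def allmax_iter(iterable):
    """Single-pass variant that accepts any iterable, e.g. a generator (hypothetical helper)."""
    all_ = []
    max_ = None
    for i, value in enumerate(iterable):
        if not all_ or value > max_:
            # First element, or a new strictly larger maximum: restart the index list.
            all_ = [i]
            max_ = value
        elif value == max_:
            # Another occurrence of the current maximum.
            all_.append(i)
    return all_


readings = (x for x in [7, 6, 5, 7, 6, 7, 6, 6, 6, 4, 5, 6])
print(allmax_iter(readings))  # [0, 3, 5]
print(allmax_iter([]))        # []
```

The `not all_` check runs first, so the first element is never compared against the `None` placeholder.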


It is even simpler to use np.flatnonzero, which returns the indices of the nonzero (True) elements of the flattened array:

>>> import numpy as np
>>> your_list = np.asarray([7, 6, 5, 7, 6, 7, 6, 6, 6, 4, 5, 6])
>>> winners = np.flatnonzero(your_list == np.max(your_list))
>>> winners
array([0, 3, 5])

If you want a list:

>>> winners.tolist()
[0, 3, 5]
Answered Oct 19 '22 by Lightspark
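One caveat for the np.max and np.flatnonzero approach above: np.max raises a ValueError on an empty array, whereas the loop-based allmax returns []. A minimal wrapper that guards against this; argmax_all is a hypothetical helper name, not a NumPy function:

```python
import numpy as np


def argmax_all(values):
    """Return the indices of every occurrence of the maximum (hypothetical helper)."""
    arr = np.asarray(values)
    if arr.size == 0:
        # np.max raises ValueError on an empty array, so handle it up front.
        return []
    return np.flatnonzero(arr == arr.max()).tolist()


print(argmax_all([7, 6, 5, 7, 6, 7, 6, 6, 6, 4, 5, 6]))  # [0, 3, 5]
print(argmax_all([]))                                    # []
```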