cumulative argmax of a numpy array

Consider the array a:

import numpy as np

np.random.seed([3, 1415])
a = np.random.randint(0, 10, (10, 2))
a

array([[0, 2],
       [7, 3],
       [8, 7],
       [0, 6],
       [8, 6],
       [0, 2],
       [0, 4],
       [9, 7],
       [3, 2],
       [4, 3]])

What is a vectorized way to get the cumulative argmax?

array([[0, 0],  <-- both columns start with their max at row 0
       [1, 1],  <-- 7 > 0, so 1st col = 1; 3 > 2, so 2nd col = 1
       [2, 2],  <-- 8 > 7, so 1st col = 2; 7 > 3, so 2nd col = 2
       [2, 2],  <-- 0 < 8 and 6 < 7, so both columns stay the same
       [2, 2],  <-- 8 only ties the 1st col max; a tie keeps the earlier position
       [2, 2],
       [2, 2],
       [7, 2],  <-- 9 is the new max of the 1st col, so its argmax is now 7; 7 only ties the 2nd col max
       [7, 2],
       [7, 2]])
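To pin down the rule behind the expected output, here is a minimal plain-Python sketch for a single column. cumargmax_1d is a hypothetical helper (not part of NumPy), and the two columns are typed in from the array a shown above. It uses a strict > so that a tie keeps the earlier position, which matches what np.argmax does.

```python
def cumargmax_1d(values):
    """Hypothetical reference: running argmax of a 1-D sequence."""
    out = []
    best_idx = 0
    for i, v in enumerate(values):
        if v > values[best_idx]:  # strict '>' so ties keep the earlier index
            best_idx = i
        out.append(best_idx)
    return out


col0 = [0, 7, 8, 0, 8, 0, 0, 9, 3, 4]  # first column of a
col1 = [2, 3, 7, 6, 6, 2, 4, 7, 2, 3]  # second column of a
print(cumargmax_1d(col0))  # [0, 1, 2, 2, 2, 2, 2, 7, 7, 7]
print(cumargmax_1d(col1))  # [0, 1, 2, 2, 2, 2, 2, 2, 2, 2]
```

Stacking the two lists as columns gives exactly the expected array.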

Here is a non-vectorized way to do it.

Notice that as the window expands, argmax is applied to the whole growing window, so row i holds the position of the maximum among rows 0 through i.

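import pandas as pd

# raw=True hands each window to np.argmax as a plain ndarray instead of a Series
pd.DataFrame(a).expanding().apply(np.argmax, raw=True).astype(int).values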

array([[0, 0],
       [1, 1],
       [2, 2],
       [2, 2],
       [2, 2],
       [2, 2],
       [2, 2],
       [7, 2],
       [7, 2],
       [7, 2]])
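The expanding-window idea does not depend on pandas. Row i of the result is just the argmax over rows 0 through i, so a minimal NumPy-only sketch of the same approach looks like this (cumargmax_expanding is a hypothetical name, and a is typed in literally from the question). Like the pandas version, it rescans the whole window for every row, so its cost grows quadratically with the number of rows.

```python
import numpy as np


def cumargmax_expanding(a):
    # Row i is the argmax of the growing window a[:i + 1] along axis 0.
    # The window is rescanned every time: O(n**2) work per column.
    a = np.asarray(a)
    return np.array([a[:i + 1].argmax(axis=0) for i in range(a.shape[0])])


a = np.array([[0, 2], [7, 3], [8, 7], [0, 6], [8, 6],
              [0, 2], [0, 4], [9, 7], [3, 2], [4, 3]])
print(cumargmax_expanding(a).tolist())
# [[0, 0], [1, 1], [2, 2], [2, 2], [2, 2], [2, 2], [2, 2], [7, 2], [7, 2], [7, 2]]
```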
asked Nov 18 '16 at 08:11 by piRSquared



1 Answer

Here's a vectorized pure NumPy solution that performs pretty snappily:

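import numpy as np

def cumargmax(a):
    # running maximum down each column
    m = np.maximum.accumulate(a)
    # candidate for every cell: its own row index
    x = np.repeat(np.arange(a.shape[0])[:, None], a.shape[1], axis=1)
    # keep the index only where the running max strictly increased (ties keep the earlier row)
    x[1:] *= m[:-1] < m[1:]
    # carry the most recent increase downward
    np.maximum.accumulate(x, axis=0, out=x)
    return x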

Then we have:

>>> cumargmax(a)
array([[0, 0],
       [1, 1],
       [2, 2],
       [2, 2],
       [2, 2],
       [2, 2],
       [2, 2],
       [7, 2],
       [7, 2],
       [7, 2]])
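To see why these few lines work, here is a sketch that traces the same steps on the first column of a alone (typed in literally from the question). The running max only rises at rows 1, 2 and 7; zeroing every other row index and then taking a running max of the indices carries the latest rise forward.

```python
import numpy as np

col = np.array([0, 7, 8, 0, 8, 0, 0, 9, 3, 4])  # first column of a

m = np.maximum.accumulate(col)   # running max
print(m)                         # [0 7 8 8 8 8 8 9 9 9]

rises = m[:-1] < m[1:]           # True where the running max strictly grows
x = np.arange(col.size)
x[1:] *= rises                   # keep the row index only at those rows
print(x)                         # [0 1 2 0 0 0 0 7 0 0]

np.maximum.accumulate(x, out=x)  # carry the latest rise forward
print(x)                         # [0 1 2 2 2 2 2 7 7 7]
```

Because row indices only ever increase, the running max of x is always the most recent row where a new maximum appeared.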

Some quick testing on arrays with thousands to millions of values suggests that this is roughly 10 to 50 times faster than looping at the Python level (either implicitly or explicitly).
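A minimal sketch for checking this kind of speedup yourself, assuming only NumPy and the standard timeit module. cumargmax is copied from above; cumargmax_rows is a hypothetical explicit row-by-row loop used as the baseline, and the array shape and value range are arbitrary choices. The script first confirms that both functions agree, then prints timings.

```python
import timeit

import numpy as np


def cumargmax(a):
    # vectorized version from the answer above
    m = np.maximum.accumulate(a)
    x = np.repeat(np.arange(a.shape[0])[:, None], a.shape[1], axis=1)
    x[1:] *= m[:-1] < m[1:]
    np.maximum.accumulate(x, axis=0, out=x)
    return x


def cumargmax_rows(a):
    # hypothetical baseline: explicit Python-level loop over rows
    out = np.empty(a.shape, dtype=np.intp)
    best = a[0].copy()                        # running max per column
    idx = np.zeros(a.shape[1], dtype=np.intp)  # row index of that max
    for i, row in enumerate(a):
        better = row > best                   # strict '>' so ties keep the earlier row
        best[better] = row[better]
        idx[better] = i
        out[i] = idx
    return out


rng = np.random.default_rng(0)
a = rng.integers(0, 1000, size=(20_000, 2))
assert np.array_equal(cumargmax(a), cumargmax_rows(a))

for f in (cumargmax, cumargmax_rows):
    t = timeit.timeit(lambda: f(a), number=5) / 5
    print(f"{f.__name__}: {t * 1e3:.2f} ms per call")
```

The timings depend on the machine, array shape and dtype, so treat them as a rough check of the claim rather than a reproduction of it.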

answered Sep 18 '22 at 21:09 by Alex Riley