
Why does accumulate work for numpy.maximum but not numpy.argmax?

The two functions look closely related, so I would expect that what works for one also works for the other. Why does accumulate work for numpy.maximum but not for numpy.argmax?

EDIT: A natural follow-up question: how does one build an efficient argmax accumulate in the most Pythonic, NumPy-idiomatic way?

Asked by RAY, Oct 24 '25 15:10


2 Answers

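Because max is an associative binary operation, but argmax is not:

  • max(a, max(b, c)) == max(max(a, b), c)
  • argmax(a, argmax(b, c)) != argmax(argmax(a, b), c)

argmax returns a position rather than a value, so its result cannot be fed back in as an operand. A pairwise fold along the array, which is what accumulate performs, has nothing consistent to combine.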
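A quick check makes this concrete. In NumPy, accumulate is a method of binary ufuncs such as np.maximum, whereas np.argmax is an ordinary function that reduces an array to an index. The snippet below is only an illustrative sketch; the sample array and variable names are assumptions chosen for the demonstration.

```python
import numpy as np

a = np.array([4, 6, 5, 1, 4, 4, 2, 0, 8, 4])  # illustrative sample data

# np.maximum is a binary ufunc, so it carries reduce/accumulate methods.
is_ufunc_max = isinstance(np.maximum, np.ufunc)

# np.argmax is a plain function, not a ufunc, and has no accumulate method.
is_ufunc_argmax = isinstance(np.argmax, np.ufunc)
has_accumulate = hasattr(np.argmax, "accumulate")

# The pairwise fold works for maximum because each step combines two values.
running_max = np.maximum.accumulate(a)

print(is_ufunc_max, is_ufunc_argmax, has_accumulate)
print(running_max)
```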
Answered by Eric, Oct 27 '25 05:10


Is this the kind of argmax accumulate you want?

sample array:

In [135]: a
Out[135]: array([4, 6, 5, 1, 4, 4, 2, 0, 8, 4])

The running maximum that you already have:

In [136]: am=np.maximum.accumulate(a)    
In [137]: am
Out[137]: array([4, 6, 6, 6, 6, 6, 6, 6, 8, 8], dtype=int32)

In [138]: a1=np.zeros_like(a)

Identify the positions where a equals the running maximum am, which is where am jumps (a repeated maximum matches as well). np.diff would also work:

In [139]: ind=np.nonzero(a==am)[0]

In [140]: ind
Out[140]: array([0, 1, 8], dtype=int32)

In [141]: a1[ind]=ind    
In [142]: a1
Out[142]: array([0, 1, 0, 0, 0, 0, 0, 0, 8, 0])

In [143]: np.maximum.accumulate(a1)
Out[143]: array([0, 1, 1, 1, 1, 1, 1, 1, 8, 8], dtype=int32)
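The steps above can be collected into one function. This is a minimal sketch that assumes a 1-D numeric array; the name argmax_accumulate is a hypothetical helper, not a NumPy function, and the brute-force loop is there only to cross-check the result on the sample array.

```python
import numpy as np

def argmax_accumulate(a):
    """Running argmax of a 1-D array (hypothetical helper, not part of NumPy)."""
    a = np.asarray(a)
    am = np.maximum.accumulate(a)           # running maximum
    ind = np.nonzero(a == am)[0]            # positions that equal the running max
    a1 = np.zeros(a.shape, dtype=np.intp)   # index array, zero everywhere else
    a1[ind] = ind
    return np.maximum.accumulate(a1)        # carry the latest index forward

a = np.array([4, 6, 5, 1, 4, 4, 2, 0, 8, 4])
result = argmax_accumulate(a)

# Brute-force reference: argmax of every prefix, O(n^2), for checking only.
expected = np.array([np.argmax(a[:i + 1]) for i in range(len(a))])

print(result)
print(np.array_equal(result, expected))
```

One caveat: because a == am also matches a repeated maximum, ties resolve to the latest occurrence here, whereas np.argmax reports the first. For [4, 6, 6] this function returns [0, 1, 2], while the prefix-wise np.argmax gives [0, 1, 1].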

An alternative way of finding ind is to look for the jumps in am:

In [149]: ind=np.nonzero(np.diff(am))

In [150]: ind = np.concatenate([[0],ind[0]+1])

In [151]: ind
Out[151]: array([0, 1, 8])
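The np.diff variant can be wrapped the same way. The sketch below assumes a non-empty 1-D array, and argmax_accumulate_first is a hypothetical name. Since it records an index only where the running maximum strictly increases, ties keep the first occurrence, which matches what np.argmax does on each prefix.

```python
import numpy as np

def argmax_accumulate_first(a):
    """Running argmax via jumps in the running max (hypothetical helper).

    Assumes a non-empty 1-D array; ties keep the first occurrence.
    """
    a = np.asarray(a)
    am = np.maximum.accumulate(a)              # running maximum
    jumps = np.nonzero(np.diff(am))[0] + 1     # positions where am increases
    ind = np.concatenate([[0], jumps])         # position 0 always starts a run
    a1 = np.zeros(a.shape, dtype=np.intp)
    a1[ind] = ind
    return np.maximum.accumulate(a1)           # carry the latest index forward

a = np.array([4, 6, 5, 1, 4, 4, 2, 0, 8, 4])
result_first = argmax_accumulate_first(a)

# An input with repeated maxima, checked against prefix-wise np.argmax.
b = np.array([4, 6, 6, 2, 6, 9])
result_ties = argmax_accumulate_first(b)
expected_ties = np.array([np.argmax(b[:i + 1]) for i in range(len(b))])

print(result_first)
print(result_ties)
```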
Answered by hpaulj, Oct 27 '25 06:10


