Faster alternatives to numpy.argmax/argmin which is slow

Tags:

python

numpy

I am using argmin and argmax a lot in Python.

Unfortunately, these functions are very slow.

I have done some searching around, and the best I can find is here:

http://lemire.me/blog/archives/2008/12/17/fast-argmax-in-python/

def fastest_argmax(array):
    array = list( array )  # copy into a plain Python list
    return array.index(max(array))  # position of the first maximum
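For reference, here is a small sketch checking that this list-based trick picks the same index as NumPy, including on ties: both `list.index` and `np.argmax` return the first occurrence. The argmin mirror image, `fastest_argmin`, is a hypothetical name built the same way and is not from the linked post.

```python
import numpy as np


def fastest_argmax(array):
    # Copy into a plain list, then let the C-implemented max() and
    # list.index() do the work; index() returns the FIRST position found.
    array = list(array)
    return array.index(max(array))


def fastest_argmin(array):
    # Hypothetical mirror image for argmin, built the same way.
    array = list(array)
    return array.index(min(array))


x = np.array([3.0, 7.5, 1.2, 7.5, 1.2])
print(fastest_argmax(x), np.argmax(x))  # 1 1 (first of the two 7.5s)
print(fastest_argmin(x), np.argmin(x))  # 2 2 (first of the two 1.2s)
```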

Unfortunately, this solution is still only half as fast as np.max, and I think I should be able to find something as fast as np.max.

x = np.random.randn(10)
%timeit np.argmax( x )
10000 loops, best of 3: 21.8 us per loop

%timeit fastest_argmax( x )    
10000 loops, best of 3: 20.8 us per loop

Note that I am applying this to a pandas DataFrame groupby, e.g.:

%timeit grp2[ 'ODDS' ].agg( [ fastest_argmax ] )
100 loops, best of 3: 8.8 ms per loop

%timeit grp2[ 'ODDS' ].agg( [ np.argmax ] )
100 loops, best of 3: 11.6 ms per loop

Where grp2[ 'ODDS' ].head() looks like this:

EVENT_ID   SELECTION_ID        
104601100  4367029       682508    3.05
                         682509    3.15
                         682510    3.25
                         682511    3.35
           5319660       682512    2.04
                         682513    2.08
                         682514    2.10
                         682515    2.12
                         682516    2.14
           5510310       682520    4.10
                         682521    4.40
                         682522    4.50
                         682523    4.80
                         682524    5.30
           5559264       682526    5.00
                         682527    5.30
                         682528    5.40
                         682529    5.50
                         682530    5.60
           5585869       682533    1.96
                         682534    1.97
                         682535    1.98
                         682536    2.02
                         682537    2.04
           6064546       682540    3.00
                         682541    2.74
                         682542    2.76
                         682543    2.96
                         682544    3.05
104601200  4916112       682548    2.64
                         682549    2.68
                         682550    2.70
                         682551    2.72
                         682552    2.74
           5315859       682557    2.90
                         682558    2.92
                         682559    3.05
                         682560    3.10
                         682561    3.15
           5356995       682564    2.42
                         682565    2.44
                         682566    2.48
                         682567    2.50
                         682568    2.52
           5465225       682573    1.85
                         682574    1.89
                         682575    1.91
                         682576    1.93
                         682577    1.94
           5773661       682588    5.00
                         682589    4.40
                         682590    4.90
                         682591    5.10
           6013187       682592    5.00
                         682593    4.20
                         682594    4.30
                         682595    4.40
                         682596    4.60
104606300  2489827       683438    4.00
                         683439    3.90
                         683440    3.95
                         683441    4.30
                         683442    4.40
           3602724       683446    2.16
                         683447    2.32
Name: ODDS, Length: 65, dtype: float64
asked Nov 08 '14 11:11 by Ginger



3 Answers

It turns out that np.argmax is blazingly fast, but only with native NumPy arrays. With foreign data (e.g. a Python list), almost all the time is spent on conversion:

In [194]: print platform.architecture()
('64bit', 'WindowsPE')

In [5]: x = np.random.rand(10000)
In [57]: l=list(x)
In [123]: timeit numpy.argmax(x)
100000 loops, best of 3: 6.55 us per loop
In [122]: timeit numpy.argmax(l)
1000 loops, best of 3: 729 us per loop
In [134]: timeit numpy.array(l)
1000 loops, best of 3: 716 us per loop
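A minimal sketch of what this implies in practice, assuming the data arrives as a Python list: convert it to an ndarray once, outside the hot path, and call np.argmax on the array from then on. It uses the standard timeit module instead of the IPython magic; the absolute numbers will of course vary by machine.

```python
import timeit

import numpy as np

x = np.random.rand(10000)
l = list(x)  # a plain Python list of NumPy floats ("foreign" data)

# Pay the list -> ndarray conversion once, outside the code being timed.
arr = np.asarray(l)

n = 100
t_list = timeit.timeit(lambda: np.argmax(l), number=n) / n   # converts on every call
t_arr = timeit.timeit(lambda: np.argmax(arr), number=n) / n  # no conversion needed

print(f"argmax on list:    {t_list * 1e6:.1f} us per call")
print(f"argmax on ndarray: {t_arr * 1e6:.1f} us per call")
```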

Your function is inefficient because it first converts everything to a list and then iterates through it twice (effectively three iterations plus the cost of constructing the list).

I was going to suggest something like this that only iterates once:

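def imax(seq):
    """Return the index of the first maximum of seq, iterating only once."""
    it = iter(seq)
    im = 0
    try:
        m = next(it)  # Python 3 spelling of it.next()
    except StopIteration:
        raise ValueError("the sequence is empty") from None
    for i, e in enumerate(it, start=1):
        if e > m:
            m = e
            im = i
    return im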

But your version turns out to be faster: it iterates several times, but does so in C rather than in Python code. C is just that much faster, even considering that a great deal of time is spent on conversion as well:

In [158]: timeit imax(x)
1000 loops, best of 3: 883 us per loop
In [159]: timeit fastest_argmax(x)
1000 loops, best of 3: 575 us per loop

In [174]: timeit list(x)
1000 loops, best of 3: 316 us per loop
In [175]: timeit max(l)
1000 loops, best of 3: 256 us per loop
In [181]: timeit l.index(0.99991619010758348)  #the greatest number in my case, at index 92
100000 loops, best of 3: 2.69 us per loop
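If a single pass is still wanted, one hedged alternative (not part of this answer; the name imax_enumerate is hypothetical) is to let the built-in max() drive the loop over enumerate(), with operator.itemgetter as the key. max() returns the first of several equal maxima, so ties resolve the same way as in imax. It still calls the key once per element, so whether it beats fastest_argmax on a given machine would have to be timed.

```python
from operator import itemgetter


def imax_enumerate(seq):
    """Index of the first maximum of any iterable, in a single pass.

    Hypothetical helper: max() runs the loop in C, but itemgetter(1)
    is still called once per (index, value) pair.
    """
    try:
        index, _ = max(enumerate(seq), key=itemgetter(1))
    except ValueError:  # max() of an empty iterable
        raise ValueError("the sequence is empty") from None
    return index


print(imax_enumerate([6, 5, 9, 1, 9]))              # 2 (first of the two 9s)
print(imax_enumerate(v for v in (0.5, 0.1, 0.7)))  # 2, generators work too
```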

So the key to speeding this up further is knowing which format the data in your sequence is natively stored in, so that you can skip the conversion step or use (or write) a function native to that format.
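As an illustration of that idea, here is a sketch of a dispatcher that checks the input's format before choosing a path. The name argmax_native is hypothetical; the duck-typed to_numpy() check is an assumption meant to cover containers such as pandas Series, which do provide that method.

```python
import numpy as np


def argmax_native(data):
    """Hypothetical helper: choose an argmax path from the input's native format."""
    # Containers such as pandas Series expose to_numpy(); unwrap them first
    # so that the NumPy path below applies.
    if hasattr(data, "to_numpy"):
        data = data.to_numpy()
    if isinstance(data, np.ndarray):
        # Native NumPy data: no conversion, np.argmax runs in C.
        return int(np.argmax(data))
    # Anything else (list, tuple, generator, ...): one list copy, then the
    # C-implemented max() and list.index(), as in fastest_argmax.
    values = list(data)
    return values.index(max(values))


print(argmax_native(np.array([0.3, 0.9, 0.1])))  # 1
print(argmax_native([0.3, 0.9, 0.1]))            # 1
print(argmax_native(v for v in (4, 9, 9)))       # 1 (first of the tied 9s)
```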

By the way, you're likely to get some speedup by passing a single function, aggregate(max_fn) (agg is an alias of aggregate), instead of a list, agg([max_fn]).
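To make the difference concrete, here is a small sketch on made-up data loosely shaped like the ODDS series (the values and index labels are invented for illustration). Passing a list makes pandas build a DataFrame with one column per function, while passing the function itself returns a plain Series. Any speed difference is not shown here and would need timing on the real data.

```python
import pandas as pd


def fastest_argmax(array):
    array = list(array)
    return array.index(max(array))


# Invented data with a two-level index, loosely shaped like grp2['ODDS'].
odds = pd.Series(
    [3.05, 3.15, 3.25, 2.04, 2.10, 2.08],
    index=pd.MultiIndex.from_tuples(
        [("A", 1), ("A", 2), ("A", 3), ("B", 4), ("B", 5), ("B", 6)],
        names=["SELECTION_ID", "ROW"],
    ),
    name="ODDS",
)
grp = odds.groupby(level="SELECTION_ID")

as_list = grp.agg([fastest_argmax])      # DataFrame, one column per function
as_func = grp.aggregate(fastest_argmax)  # plain Series, one value per group
print(as_list)
print(as_func)
```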

answered Sep 19 '22 13:09 by ivan_pozdeev



For those who came for a short NumPy-free snippet that returns the index of the first minimum value:

def argmin(a):
    return min(range(len(a)), key=lambda x: a[x])
a = [6, 5, 4, 1, 1, 3, 2]
argmin(a)  # returns 3
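The same pattern gives argmax. The sketch below also swaps the lambda for a.__getitem__, which looks up the same values. Both min() and max() return the first of several equal items, so ties resolve to the earliest index. Note that this approach needs a sequence supporting len() and indexing, so a generator would have to be turned into a list first.

```python
def argmin(a):
    # Index of the first minimum: min() keeps the first of equal keys.
    return min(range(len(a)), key=a.__getitem__)


def argmax(a):
    # Index of the first maximum: max() also keeps the first of equal keys.
    return max(range(len(a)), key=a.__getitem__)


a = [6, 5, 4, 1, 1, 3, 2]
print(argmin(a))  # 3 (first of the two 1s)
print(argmax(a))  # 0
```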
answered Sep 19 '22 13:09 by Alex Lamson



Can you post some code? Here is the result on my PC:

x = np.random.rand(10000)
%timeit np.max(x)
%timeit np.argmax(x)

output:

100000 loops, best of 3: 7.43 µs per loop
100000 loops, best of 3: 11.5 µs per loop
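Outside IPython, roughly the same comparison can be run with the standard timeit module. This sketch mirrors "best of 3" by taking the minimum of three repeats and reports time per call. The loop count is fixed here rather than chosen automatically as %timeit does, so the numbers are only indicative.

```python
import timeit

import numpy as np

x = np.random.rand(10000)

# Mirror "%timeit": best of 3 repeats, reported per call.
n = 1000
t_max = min(timeit.repeat(lambda: np.max(x), repeat=3, number=n)) / n
t_argmax = min(timeit.repeat(lambda: np.argmax(x), repeat=3, number=n)) / n

print(f"np.max:    {t_max * 1e6:.2f} us per call")
print(f"np.argmax: {t_argmax * 1e6:.2f} us per call")
```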
answered Sep 19 '22 13:09 by HYRY
