
Wrapping np.ndarray's __pow__ method

I was just revisiting some of my code to improve its performance and stumbled over something strange:

a = np.linspace(10,1000,1000000).reshape(1000,1000)

%timeit np.square(a)
100 loops, best of 3: 8.07 ms per loop

%timeit a*a
100 loops, best of 3: 8.18 ms per loop

%timeit a**2
100 loops, best of 3: 8.32 ms per loop

OK, the power operator (**) seems to carry a little overhead, but otherwise the three look identical (I guess NumPy handles that internally). But then it got strange:

%timeit np.power(a, 2)
10 loops, best of 3: 121 ms per loop

That is not a problem in itself, but it seems inconsistent to have a shortcut for the magic __pow__ and none for the ufunc. Then I got curious, since I use third powers a lot:

%timeit a*a*a
100 loops, best of 3: 18.1 ms per loop

%timeit a**3
10 loops, best of 3: 121 ms per loop

%timeit np.power(a, 3)
10 loops, best of 3: 121 ms per loop

There seems to be no "shortcut" for the third power: the ufunc and the magic __pow__ perform the same.

That is unfortunate, since I want one consistent way of writing powers in my code, and I'm not sure how to wrap NumPy's __pow__.

So, to get to the point, my question is:

Is there a way to wrap NumPy's __pow__ method? I want a consistent way of writing powers in my script, not a**2 in one place and power(a, 3) in another. Simply writing a**3 and having it redirected to my power function would be preferred (but for that I would need to somehow wrap ndarray's __pow__, wouldn't I?). Currently I am using a shortcut function, but it is not pretty (I even have to handle the exponent == 2 case explicitly, since np.power is slow there):
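For anyone who wants to reproduce the comparison outside IPython, here is a minimal sketch using the standard timeit module. The smaller array and the `candidates` and `results` names are my own choices for illustration, and the numbers will differ across machines and NumPy versions.

```python
import timeit

import numpy as np

# Smaller than the 1000x1000 array above so the script finishes quickly.
a = np.linspace(10, 1000, 250000).reshape(500, 500)

# The three ways of cubing the array that are compared above.
candidates = {
    "a*a*a": lambda: a * a * a,
    "a**3": lambda: a ** 3,
    "np.power(a, 3)": lambda: np.power(a, 3),
}

results = {}
for label, func in candidates.items():
    # Best of 3 repeats of 10 calls each, converted to milliseconds per call.
    best = min(timeit.repeat(func, number=10, repeat=3)) / 10
    results[label] = best * 1e3
    print("%-16s %.3f ms per loop" % (label, results[label]))
```

Taking the minimum over the repeats mirrors the "best of 3" figure that %timeit reports.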

def power(array, exponent):
    if exponent == 2:  # catch this, or it falls through to the slow np.power(array, exponent)
        return np.square(array)
    if exponent == 3:
        return array * array * array
    # As soon as np.cbrt is available, catch the exponent 4/3 here too
    return np.power(array, exponent)

%timeit power(a, 3)
100 loops, best of 3: 17.8 ms per loop
%timeit a**3
10 loops, best of 3: 121 ms per loop

I am using NumPy v1.9.3 and I do not want to subclass np.ndarray just for wrapping the __pow__ method. :-)

EDIT: I rewrote the part where I get to my question. To clarify it: I am not asking about why NumPy does it the way it does - that is just to explain why I ask the question.

MSeifert asked Oct 02 '15 09:10


2 Answers

This is a good catch. I too wonder why it behaves that way. But to answer the question concisely, I would just do:

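from functools import reduce  # a builtin on Python 2, must be imported on Python 3

def mypower(array, exponent):  # exponent must be a positive integer
    return reduce(lambda x, y: x * y, [array for _ in range(exponent)])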


%timeit mypower(a,2)
100 loops, best of 3: 3.68 ms per loop

%timeit mypower(a,3)
100 loops, best of 3: 8.09 ms per loop

%timeit mypower(a,4)
100 loops, best of 3: 12.6 ms per loop

Obviously the cost grows with the exponent, but for small exponents it avoids the more than 10x slowdown seen with a**3 above.

Note that this differs from the original NumPy implementation, which is not limited to a scalar exponent and also accepts an array of exponents as the second argument (check it out here).
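Since the reduce version performs exponent - 1 multiplications, larger integer exponents can be handled with fewer multiplications by repeated squaring. The following is only a sketch of that swapped-in technique, not something NumPy does for you: `int_power` is a hypothetical name, and anything that is not a non-negative Python or NumPy integer is simply passed on to np.power.

```python
import numpy as np


def int_power(array, exponent):
    """Raise array to a power, using repeated squaring for integer exponents."""
    is_int = isinstance(exponent, (int, np.integer)) and not isinstance(exponent, bool)
    if not is_int or exponent < 0:
        # Floats, arrays of exponents and negative powers use the generic ufunc.
        return np.power(array, exponent)
    if exponent == 0:
        return np.ones_like(array)
    if exponent == 1:
        return array.copy()

    exponent = int(exponent)
    result = None
    base = array
    while exponent:
        if exponent & 1:
            # This bit is set, so the current power of two contributes a factor.
            result = base if result is None else result * base
        exponent >>= 1
        if exponent:
            base = base * base  # a, a**2, a**4, a**8, ...
    return result
```

For example, int_power(a, 8) needs three multiplications where the reduce version needs seven. Whether that pays off in practice depends on the array size and the exponent, so it is worth timing on your own data.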

Overloading the operator

The way to do what you want is to subclass ndarray and use views. See the following example:

from functools import reduce  # a builtin on Python 2, must be imported on Python 3

import numexpr
import numpy as np

class MyArray(np.ndarray):
    def __pow__(self, other):
        return reduce(lambda x, y: x * y, [self for _ in range(other)])

class NumExprArray(np.ndarray):
    def __pow__(self, other):
        return numexpr.evaluate("self**%f" % other)
        # Passing the exponent as a variable adds overhead (as much as 4x slower):
        # return numexpr.evaluate("self**other")

a = np.linspace(10,1000,1000000).reshape(1000,1000).view(MyArray)
na = np.linspace(10,1000,1000000).reshape(1000,1000).view(NumExprArray)

%timeit a**2
1000 loops, best of 3: 1.2 ms per loop

%timeit na**2
1000 loops, best of 3: 1.14 ms per loop

%timeit a**3
100 loops, best of 3: 4.69 ms per loop

%timeit na**3
100 loops, best of 3: 2.36 ms per loop

%timeit a**4
100 loops, best of 3: 6.59 ms per loop

%timeit na**4
100 loops, best of 3: 2.4 ms per loop

For more information on this method please follow this link. Another way would be to use a custom infix operator, but that hurts readability. As one can see, numexpr should be the way to go.
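One caveat with MyArray above: range(other) raises a TypeError for a float exponent, and reduce fails on the empty list produced by an exponent of 0. A hedged sketch of a subclass that only shortcuts the squares and cubes from the question and defers everything else to ndarray's own __pow__ could look like this (`FastPowArray` is a made-up name):

```python
import numpy as np


class FastPowArray(np.ndarray):
    """ndarray view whose ** operator shortcuts the exponents 2 and 3."""

    def __pow__(self, exponent):
        is_int = isinstance(exponent, (int, np.integer)) and not isinstance(exponent, bool)
        if is_int and exponent == 2:
            return np.square(self)
        if is_int and exponent == 3:
            return self * self * self
        # Floats, arrays of exponents and all other integers keep NumPy's behavior.
        return super().__pow__(exponent)


# Viewing an existing array as the subclass does not copy its data.
a = np.linspace(10, 1000, 10000).reshape(100, 100).view(FastPowArray)
cubed = a ** 3
root = a ** 0.5
```

Results stay instances of the subclass, so chained expressions such as (a ** 3) ** 2 keep using the shortcut.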

rll answered Nov 11 '22 20:11


If I read the source correctly, when NumPy evaluates a power it checks whether the numerical value of the exponent is one of the special cases (-1, 0, 0.5, 1, and 2). If so, the operation is done with a dedicated routine. All other exponent values are considered "general" and are fed into the generic power function, which may be slow (especially if the exponent is promoted to a floating-point type, but I'm not sure whether that is the case with a ** 3).
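Whatever path NumPy takes internally, the non-negative special cases listed above should agree numerically with the corresponding dedicated functions. The check below is only a sanity sketch and says nothing about speed; the `checks` dictionary is my own construction.

```python
import numpy as np

a = np.linspace(10, 1000, 10000).reshape(100, 100)

# Compare the power operator with the dedicated function for each special exponent.
checks = {
    "a ** 0 vs np.ones_like(a)": np.array_equal(a ** 0, np.ones_like(a)),
    "a ** 0.5 vs np.sqrt(a)": np.allclose(a ** 0.5, np.sqrt(a)),
    "a ** 1 vs a": np.array_equal(a ** 1, a),
    "a ** 2 vs np.square(a)": np.allclose(a ** 2, np.square(a)),
}

for label, ok in checks.items():
    print("%-28s %s" % (label, ok))
```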

Cong Ma answered Nov 11 '22 22:11