
Numerical Python - how do I make this a ufunc?

Tags: python, numpy

I'm new to NumPy and may not be searching properly, so I'll take my lumps if this is a common question.

I'm working on a problem where I need to calculate log(n!) for relatively large numbers, i.e. too large to calculate the factorial first, so I've written the following function:

def log_fact(n):
    x = 0
    for i in range(1,n+1):
        x += log(i)
    return x

Now the problem is that I want to use this as part of the function passed to curve_fit:

def logfactfunc(x, a, b, c):
    return a*log_fact(x) + b*x + c

from scipy.optimize import curve_fit

curve_fit(logfactfunc, x, y)

However, this produces the following error:

File "./fit2.py", line 16, in log_fact
    for i in range(1,n+1):
TypeError: only length-1 arrays can be converted to Python scalars

A little searching suggested using numpy.frompyfunc() to convert this to a ufunc, but that also raises an error:

curve_fit(np.frompyfunc(logfactfunc, 1, 1), data[k].step, data[k].sieve)

TypeError: <ufunc 'logfactfunc (vectorized)'> is not a Python function

I also tried this:

def logfactfunc(x, a, b, c):
    return a*np.frompyfunc(log_fact, 1, 1)(x) + b*x + c

File "./fit2.py", line 30, in logfactfunc
    return a*np.frompyfunc(log_fact, 1, 1)(x) + b*x + c
TypeError: unsupported operand type(s) for +: 'numpy.ndarray' and 'numpy.float64'

Any ideas on how I can get my log_fact() function to work inside the function passed to curve_fit()?

Thanks!

CoAstroGeek asked Dec 20 '22 19:12


2 Answers

Your log_fact function is being called with arrays as input parameters, which is what's throwing off your method. A possible way of vectorizing your code would be the following:

import numpy as np

def log_fact(n):
    n = np.asarray(n)
    m = np.max(n)
    # cumsum of log(1), ..., log(m): entry k-1 holds log(k!)
    return np.take(np.cumsum(np.log(np.arange(1, m+1))), n-1)

Taking it for a test ride:

>>> log_fact(3)
1.791759469228055
>>> log_fact([10, 15, 23])
array([ 15.10441257,  27.89927138,  51.60667557])
>>> log_fact([[10, 15, 23], [14, 15, 8]])
array([[ 15.10441257,  27.89927138,  51.60667557],
       [ 25.19122118,  27.89927138,  10.6046029 ]])

The only caveat with this approach is that it stores an array as long as the largest value you call it with. If your n gets into the billions, you'll probably break it. Other than that, it actually avoids repeated calculations if you call it with many values.
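Two practical details may matter when this lookup-table approach is used inside curve_fit: as far as I know, curve_fit converts array-like xdata to floating point before calling the model, and float arrays cannot be used as indices; also, n = 0 would index position -1 and return the wrong entry. The following is a minimal sketch of a variant that addresses both, assuming the inputs are non-negative and integer-valued. The name log_fact_table is hypothetical and not part of the original answer.

```python
import numpy as np

def log_fact_table(n):
    """log(n!) for non-negative, integer-valued n (scalar or array)."""
    # Round and cast so that float-valued input (e.g. 10.0) can be used as an index.
    n = np.rint(np.asarray(n)).astype(np.int64)
    m = int(n.max())
    # table[k] == log(k!); the leading 0.0 is log(0!) so that n = 0 works.
    table = np.concatenate(([0.0], np.cumsum(np.log(np.arange(1, m + 1)))))
    return table[n]
```

Like the original, this still allocates a table as long as the largest input value, so the same memory caveat applies.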

Jaime answered Jan 02 '23 02:01


Your log_fact function is closely related to the gammaln function, which is defined as a ufunc in scipy.special. Specifically, log_fact(n) == scipy.special.gammaln(n+1). For even modest values of n, this is significantly faster:

In [15]: %timeit log_fact(19)
10000 loops, best of 3: 24.4 us per loop
In [16]: %timeit scipy.special.gammaln(20)
1000000 loops, best of 3: 1.13 us per loop

Also, the running time of gammaln is independent of n, unlike log_fact.
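As a rough sketch of how this could be plugged into the fit from the question: because gammaln is a ufunc, the model function accepts the whole x array directly. The data below is synthetic and noise-free, made up purely for illustration, and the parameter values 2.0, -0.5 and 3.0 are arbitrary assumptions.

```python
import numpy as np
from scipy.optimize import curve_fit
from scipy.special import gammaln

def logfactfunc(x, a, b, c):
    # gammaln(x + 1) equals log(x!) for non-negative integer x and works elementwise.
    return a * gammaln(x + 1) + b * x + c

# Synthetic data for illustration only (not the asker's data).
x = np.arange(1, 50, dtype=float)
y = logfactfunc(x, 2.0, -0.5, 3.0)

# params holds the fitted (a, b, c); cov is the estimated covariance matrix.
params, cov = curve_fit(logfactfunc, x, y)
```

Since the model is linear in a, b and c, the fit should recover values close to the ones used to generate y.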

Ray answered Jan 02 '23 02:01