
Numba jit with scipy

I wanted to speed up a program I wrote with the help of numba's jit. However, jit seems to be incompatible with many scipy functions because they use try ... except ... structures that jit cannot handle. (Am I right about this?)

A relatively simple solution I came up with is to copy the scipy source code I need and delete the try/except parts. (I already know the code will not run into errors, so the try branch will always succeed anyway.)

However, I do not like this solution, and I am not sure it will work.

My code structure looks like the following:

import scipy.integrate as integrate
from scipy.optimize import curve_fit
from numba import jit

def fitfunction():
    ...

@jit
def function(x):
    # do some stuff
    try:
        fit_param, fit_cov = curve_fit(fitfunction, x, y, p0=(0,0,0), maxfev=500)
        for idx in some_list:
            integrated = integrate.quad(lambda x: fitfunction(fit_param), lower, upper)
    except:
        fit_param=(0,0,0)
        ...

Now this results in the following error:

LoweringError: Failed at object (object mode backend)

I assume this is because jit cannot handle try/except. It also does not work if I put jit only on the curve_fit and integrate.quad parts and keep my own try/except structure outside the jitted functions:

import scipy.integrate as integrate
from scipy.optimize import curve_fit
from numba import jit

def fitfunction():
    ...

@jit
def integral(lower, upper):
    return integrate.quad(lambda x: fitfunction(fit_param), lower, upper)

@jit
def fitting(x, y, pzero, max_fev):
    return curve_fit(fitfunction, x, y, p0=pzero, maxfev=max_fev)


def function(x):
    # do some stuff
    try:
        fit_param, fit_cov = fitting(x, y, (0,0,0), 500)
        for idx in some_list:
            integrated = integral(lower, upper)
    except:
        fit_param=(0,0,0)
        ...

Is there a way to use jit with scipy.integrate.quad and curve_fit without manually deleting all try except structures from the scipy code?

And would it even speed up the code?

Katermickie asked Mar 23 '19 19:03


2 Answers

Nowadays try and except work with numba. However, numba and scipy are still not compatible: scipy does call compiled C and Fortran, but it does so in a way that numba can't deal with.
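To illustrate the first point, here is a minimal sketch of try/except inside a jitted function. The function name safe_sqrt and the -1.0 fallback value are made up for illustration, and the plain-Python fallback decorator is only there so the snippet still runs when numba is not installed. As far as I know, numba only accepts a bare except or except Exception, so that is what the sketch uses.

```python
import math

try:
    from numba import njit
except ImportError:
    # Assumption for illustration: without numba, run as plain Python.
    def njit(func):
        return func


@njit
def safe_sqrt(x):
    # Hypothetical helper: returns -1.0 instead of failing on negative input.
    result = -1.0
    try:
        if x < 0.0:
            raise ValueError("negative input")
        result = math.sqrt(x)
    except Exception:
        # The exception object itself is not inspected here.
        result = -1.0
    return result


print(safe_sqrt(4.0))
print(safe_sqrt(-1.0))
```

This is the same pattern the numba example further down uses: raise inside the try block when a solver reports failure, and handle it in the except branch.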

Fortunately, there are alternatives to scipy that work well with numba! Below I use NumbaQuadpack and NumbaMinpack to do some curve fitting and integration similar to your example code. Disclaimer: I put together these packages. Below, I also give an equivalent implementation in scipy.

The Scipy implementation is ~18 times slower than the Scipy alternatives (NumbaQuadpack and NumbaMinpack).

Using Scipy alternatives (0.23 ms)

from NumbaQuadpack import quadpack_sig, dqags
from NumbaMinpack import minpack_sig, lmdif
import numpy as np
import numba as nb
import timeit
np.random.seed(0)

x = np.linspace(0,2*np.pi,100)
y = np.sin(x)+ np.random.rand(100)

@nb.jit
def fitfunction(x, A, B):
    return A*np.sin(B*x)

@nb.cfunc(minpack_sig)
def fitfunction_optimize(u_, fvec, args_):
    u = nb.carray(u_,(2,))
    args = nb.carray(args_,(200,))
    A, B = u
    x = args[:100]
    y = args[100:]
    for i in range(100):
        fvec[i] = fitfunction(x[i], A, B) - y[i] 
optimize_ptr = fitfunction_optimize.address

@nb.cfunc(quadpack_sig)
def fitfunction_integrate(x, data):
    A = data[0]
    B = data[1]
    return fitfunction(x, A, B)
integrate_ptr = fitfunction_integrate.address

@nb.njit
def fast_function():  
    try:
        neqs = 100
        u_init = np.array([2.0,.8],np.float64)
        args = np.append(x,y)
        fitparam, fvec, success, info = lmdif(optimize_ptr , u_init, neqs, args)
        if not success:
            raise Exception

        lower = 0.0
        uppers = np.linspace(np.pi,np.pi*2.0,200)
        solutions = np.empty(len(uppers))
        for i in range(len(uppers)):
            solutions[i], abserr, success = dqags(integrate_ptr, lower, uppers[i], data = fitparam)
            if not success:
                raise Exception
    except:
        print('doing something else')
        
fast_function()
iters = 1000
t_nb = timeit.Timer(fast_function).timeit(number=iters)/iters
print(t_nb)

Using Scipy (4.4 ms)

import scipy.integrate as integrate
from scipy.optimize import curve_fit
import numpy as np
import numba as nb
import timeit

np.random.seed(0)

x = np.linspace(0,2*np.pi,100)
y = np.sin(x)+ np.random.rand(100)

@nb.jit
def fitfunction(x, A, B):
    return A*np.sin(B*x)

def function():
    try:
        p0 = (2.0,.8)
        fit_param, fit_cov = curve_fit(fitfunction, x, y, p0=p0, maxfev=500)

        lower = 0.0
        uppers = np.linspace(np.pi,np.pi*2.0,200)
        solutions = np.empty(len(uppers))
        for i in range(len(uppers)):
            solutions[i], abserr = integrate.quad(fitfunction, lower, uppers[i], args = tuple(fit_param))
    except:
        print('do something else')

function()
iters = 1000
t_sp = timeit.Timer(function).timeit(number=iters)/iters
print(t_sp)

nicholaswogan answered Oct 17 '22 18:10


Numba is simply not a general-purpose tool for speeding up code. There is a class of problems that numba can solve much faster (especially loops over arrays and number crunching), but everything else is either (1) not supported or (2) only slightly faster, or even a lot slower.
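As a rough sketch of the kind of code that falls into the "much faster" class: a plain numeric loop with no library calls beyond math. The function trapezoid_sin is a made-up example (a trapezoid rule for A*sin(B*x)), not something from the question, and the fallback decorator is only there so the snippet runs without numba.

```python
import math

try:
    from numba import njit
except ImportError:
    # Assumption for illustration: without numba, run as plain Python.
    def njit(func):
        return func


@njit
def trapezoid_sin(a, b, lower, upper, steps):
    # Trapezoid rule for a*sin(b*x) on [lower, upper]: a tight scalar loop,
    # which is the sort of work numba compiles well.
    h = (upper - lower) / steps
    total = 0.5 * (a * math.sin(b * lower) + a * math.sin(b * upper))
    for i in range(1, steps):
        total += a * math.sin(b * (lower + i * h))
    return total * h


# The integral of sin(x) from 0 to pi is 2.
print(trapezoid_sin(1.0, 1.0, 0.0, math.pi, 2000))
```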

[...] would it even speed up the code?

SciPy is already a high-performance library, so in most cases I would expect numba to perform worse (or, rarely, slightly better). You could do some profiling to find out whether the bottleneck really is in the code that you jitted; if it is, you could get some improvement. But I suspect the bottleneck will be in the compiled code of SciPy, and that compiled code is probably already heavily optimized, so it is unlikely that you will find an implementation that can even match it.
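A minimal sketch of such a profiling run, using only the standard library's cProfile and pstats. The two functions are hypothetical stand-ins: python_loop for your own Python-level number crunching and compiled_call for a library call that already runs compiled code.

```python
import cProfile
import io
import math
import pstats


def python_loop(n):
    """Hypothetical stand-in for hand-written Python number crunching."""
    total = 0.0
    for i in range(n):
        total += math.sin(i) ** 2
    return total


def compiled_call(n):
    """Hypothetical stand-in for a call that already runs compiled code."""
    return sum(range(n))


def pipeline():
    return python_loop(200_000), compiled_call(200_000)


profiler = cProfile.Profile()
profiler.enable()
result = pipeline()
profiler.disable()

# Collect the five entries with the largest cumulative time.
stream = io.StringIO()
stats = pstats.Stats(profiler, stream=stream)
stats.sort_stats("cumulative").print_stats(5)
report = stream.getvalue()
print(report)
```

If the cumulative time is dominated by your own Python functions, jitting them may pay off; if it is dominated by calls into the library, it most likely will not.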

Is there a way to use jit with scipy.integrate.quad and curve_fit without manually deleting all try except structures from the scipy code?

As you correctly assumed, try and except are simply not supported by numba at this time.

2.6.1. Language

2.6.1.1. Constructs

Numba strives to support as much of the Python language as possible, but some language features are not available inside Numba-compiled functions. The following Python language features are not currently supported:

[...]

  • Exception handling (try .. except, try .. finally)

So the answer here is No.

MSeifert answered Oct 17 '22 17:10