 

Numba's jit fails to compile function that has another function as input

I am trying to numerically solve an ODE that admits discrete jumps. I am using the Euler method and was hoping that Numba's jit might help me speed up the process (right now the script takes 300 s to run, and I need to run it 200 times).

Here is my simplified first attempt:

import numpy as np
from numba import jit

dt = 1e-5
T = 1
x0 = 1
noiter = int(T / dt)
res = np.zeros(noiter)

def fdot(x, t):
    return -x + t / (x + 1) ** 2

def solve_my_ODE(res, fdot, x0, T, dt):
    res[0] = x0
    noiter = int(T / dt)
    for i in range(noiter - 1):
        res[i + 1] = res[i] + dt * fdot(res[i], i * dt)
        if res[i + 1] >= 2:
            res[i + 1] -= 2
    return res

%timeit fdot(x0, T)
%timeit solve_my_ODE(res, fdot, x0, T, dt)
    ->The slowest run took 8.38 times longer than the fastest. This could mean that an intermediate result is being cached 
    ->1000000 loops, best of 3: 465 ns per loop
    ->10 loops, best of 3: 122 ms per loop

@jit(nopython=True)
def fdot(x, t):
    return -x + t / (x + 1) ** 2
%timeit fdot(x0, T)
%timeit solve_my_ODE(res, fdot, x0, T, dt)
    ->The slowest run took 106695.67 times longer than the fastest. This could mean that an intermediate result is being cached 
    ->1000000 loops, best of 3: 240 ns per loop
    ->10 loops, best of 3: 99.3 ms per loop

@jit(nopython=True)
def solve_my_ODE(res, fdot, x0, T, dt):
    res[0] = x0
    noiter = int(T / dt)
    for i in range(noiter - 1):
        res[i + 1] = res[i] + dt * fdot(res[i], i * dt)
        if res[i + 1] >= 2:
            res[i + 1] -= 2
    return res
%timeit fdot(x0, T)
%timeit solve_my_ODE(res, fdot, x0, T, dt)
    ->The slowest run took 10.21 times longer than the fastest. This could mean that an intermediate result is being cached 
    ->1000000 loops, best of 3: 274 ns per loop
    ->TypingError                               Traceback (most recent call last)
<ipython-input-10-27199e82c72c> in <module>()
  1 get_ipython().magic('timeit fdot(x0, T)')
----> 2 get_ipython().magic('timeit solve_my_ODE(res, fdot, x0, T, dt)')

(...)


TypingError: Failed at nopython (nopython frontend)
Undeclared pyobject(float64, float64)
File "<ipython-input-9-112bd04325a4>", line 6

I don't understand why I got this error. My suspicion is that Numba does not recognize the argument fdot (which is a Python function that, by the way, is already compiled with Numba).

Since I am new to Numba, I have several questions:

  • What can I do to make Numba understand that the argument fdot is a function?
  • Using jit on the function fdot leads to "only" a 50% decrease. Should I expect more, or is this normal?
  • Does this script look like a reasonable way to simulate an ODE with discrete jumps? Mathematically, this is equivalent to solving an ODE with delta functions.

Numba version is 0.17

asked Apr 28 '15 at 17:04 by gota




1 Answer

You're right in thinking that Numba doesn't recognise fdot as a Numba-compiled function. I don't think you can make it recognise it as a function argument, but you can use this approach (using variable capture, so fdot is known when the function is built) to build an ODE solver:

def make_solver(f):
    @jit(nopython=True)
    def solve_my_ODE(res, x0, T, dt):
        res[0] = x0
        noiter = int(T / dt)
        for i in range(noiter - 1):
            res[i + 1] = res[i] + dt * f(res[i], i * dt)
            if res[i + 1] >= 2:
                res[i + 1] -= 2
        return res
    return solve_my_ODE

# call make_solver once for each function you want to make an ODE solver for;
# f must itself be jit-compiled, since nopython code can't call a plain Python function
fdot_solver = make_solver(fdot)
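For reference, here is a minimal self-contained sketch of how the closure version above might be used end to end. It assumes a Numba release that provides njit (shorthand for jit(nopython=True)); the try/except fallback is an addition, not part of Numba, so the snippet still runs (as slow plain Python) where Numba isn't installed. Note that fdot is decorated too, because the captured f has to be a jitted function.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Assumption: no-op stand-in for numba.njit so the sketch runs without Numba
    def njit(func=None, **kwargs):
        return func if func is not None else (lambda f: f)


@njit
def fdot(x, t):
    # right-hand side from the question: dx/dt = -x + t / (x + 1)**2
    return -x + t / (x + 1) ** 2


def make_solver(f):
    # f is captured from the enclosing scope, so it is not passed as an argument
    @njit
    def solve_my_ODE(res, x0, T, dt):
        res[0] = x0
        noiter = int(T / dt)
        for i in range(noiter - 1):
            # explicit Euler step
            res[i + 1] = res[i] + dt * f(res[i], i * dt)
            # discrete jump: subtract 2 once the state reaches 2
            if res[i + 1] >= 2:
                res[i + 1] -= 2
        return res
    return solve_my_ODE


dt, T, x0 = 1e-5, 1.0, 1.0
res = np.zeros(int(T / dt))
fdot_solver = make_solver(fdot)
out = fdot_solver(res, x0, T, dt)  # the first call also triggers compilation
```

Each call to make_solver builds (and later compiles) a new solver specialised to one right-hand side, which is why it's done once per function rather than once per solve.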

Here's an alternative version that doesn't require you to pass res in. Only the loop is accelerated, but since the loop is the slow part, that's all that matters.

def make_solver_2(f):
    @jit
    def solve_my_ODE(x0, T, dt):
        # this part ISN'T compiled in nopython mode
        noiter = int(T / dt)
        res = np.zeros(noiter)
        res[0] = x0
        # but Numba lifts the loop out and compiles it in nopython mode (so it's fast)
        for i in range(noiter - 1):
            res[i + 1] = res[i] + dt * f(res[i], i * dt)
            if res[i + 1] >= 2:
                res[i + 1] -= 2
        return res
    return solve_my_ODE

I prefer this version because it allocates the return value for you, so it's a little easier to use. That's a slight diversion from your real question, though.
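On Numba releases newer than the 0.17 used in the question, the original idea may work directly: to the best of my knowledge, later versions accept a jitted function as an ordinary argument to another jitted function, and support np.zeros in nopython mode, so the allocation from version 2 can be compiled as well. The sketch below assumes such a release (the exact versions that added these features aren't checked here) and uses the same plain-Python fallback if Numba is missing. Expect a separate compilation for each distinct function passed in.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Assumption: no-op stand-in for numba.njit so the sketch runs without Numba
    def njit(func=None, **kwargs):
        return func if func is not None else (lambda f: f)


@njit
def fdot(x, t):
    # right-hand side from the question
    return -x + t / (x + 1) ** 2


@njit
def solve_my_ODE(f, x0, T, dt):
    # f is a jitted function passed in as a normal argument
    noiter = int(T / dt)
    res = np.zeros(noiter)  # allocation inside compiled code
    res[0] = x0
    for i in range(noiter - 1):
        # explicit Euler step followed by the discrete jump
        res[i + 1] = res[i] + dt * f(res[i], i * dt)
        if res[i + 1] >= 2:
            res[i + 1] -= 2
    return res


out = solve_my_ODE(fdot, 1.0, 1.0, 1e-5)
```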

In terms of timings, I get (in seconds, for 20 repeated calls):

  • 6.90394687653 (for only fdot in numba)
  • 0.0584900379181 (for version 1)
  • 0.0640540122986 (for version 2 - i.e. it's slightly slower but a little easier to use)

Thus, it's roughly 100x faster: compiling the loop itself, rather than just fdot, is what makes the big difference!
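When reproducing timings like these, it helps to call the compiled function once before timing it: the first call includes Numba's compilation, which is most likely what the "slowest run took 106695.67 times longer" message in the question is picking up. Below is a rough sketch with time.perf_counter, assuming njit is available (with the same plain-Python fallback otherwise). The names euler_jumps and timed are hypothetical, and the right-hand side is inlined only to keep the example short.

```python
import time

import numpy as np

try:
    from numba import njit
except ImportError:
    # Assumption: no-op stand-in for numba.njit so the sketch runs without Numba
    def njit(func=None, **kwargs):
        return func if func is not None else (lambda f: f)


def euler_jumps(res, x0, T, dt):
    # Euler loop with the question's right-hand side inlined
    res[0] = x0
    noiter = int(T / dt)
    for i in range(noiter - 1):
        x = res[i]
        res[i + 1] = x + dt * (-x + i * dt / (x + 1) ** 2)
        if res[i + 1] >= 2:
            res[i + 1] -= 2
    return res


euler_jumps_jit = njit(euler_jumps)  # compiled lazily, on the first call


def timed(fn, *args):
    # wall-clock time of a single call
    start = time.perf_counter()
    out = fn(*args)
    return out, time.perf_counter() - start


dt, T, x0 = 1e-5, 1.0, 1.0
n = int(T / dt)

ref, t_py = timed(euler_jumps, np.zeros(n), x0, T, dt)
_, t_first = timed(euler_jumps_jit, np.zeros(n), x0, T, dt)  # includes compilation
out, t_warm = timed(euler_jumps_jit, np.zeros(n), x0, T, dt)  # compiled code only
print(f"python: {t_py:.4f} s, first jit call: {t_first:.4f} s, warm jit call: {t_warm:.4f} s")
```

With Numba installed, the warm call should be much faster than both the first call and the pure-Python run; the exact numbers will depend on the machine.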

Your third question: "Does this script look like a reasonable way to simulate an ODE with discrete jumps? Mathematically this is equivalent at solving an ODE with delta functions." I really don't know. Sorry!

answered Oct 22 '22 at 05:10 by DavidW