I am trying to numerically solve an ODE that admits discrete jumps. I am using the Euler method and was hoping that Numba's jit might help speed things up (right now the script takes 300 s to run and I need to run it 200 times).
Here is my simplified first attempt:
import numpy as np
from numba import jit
dt = 1e-5
T = 1
x0 = 1
noiter = int(T / dt)
res = np.zeros(noiter)
def fdot(x, t):
    return -x + t / (x + 1) ** 2
def solve_my_ODE(res, fdot, x0, T, dt):
    res[0] = x0
    noiter = int(T / dt)
    for i in range(noiter - 1):
        res[i + 1] = res[i] + dt * fdot(res[i], i * dt)
        if res[i + 1] >= 2:
            res[i + 1] -= 2
    return res
%timeit fdot(x0, T)
%timeit solve_my_ODE(res, fdot, x0, T, dt)
->The slowest run took 8.38 times longer than the fastest. This could mean that an intermediate result is being cached
->1000000 loops, best of 3: 465 ns per loop
->10 loops, best of 3: 122 ms per loop
@jit(nopython=True)
def fdot(x, t):
    return -x + t / (x + 1) ** 2
%timeit fdot(x0, T)
%timeit solve_my_ODE(res, fdot, x0, T, dt)
->The slowest run took 106695.67 times longer than the fastest. This could mean that an intermediate result is being cached
->1000000 loops, best of 3: 240 ns per loop
->10 loops, best of 3: 99.3 ms per loop
@jit(nopython=True)
def solve_my_ODE(res, fdot, x0, T, dt):
    res[0] = x0
    noiter = int(T / dt)
    for i in range(noiter - 1):
        res[i + 1] = res[i] + dt * fdot(res[i], i * dt)
        if res[i + 1] >= 2:
            res[i + 1] -= 2
    return res
%timeit fdot(x0, T)
%timeit solve_my_ODE(res, fdot, x0, T, dt)
->The slowest run took 10.21 times longer than the fastest. This could mean that an intermediate result is being cached
->1000000 loops, best of 3: 274 ns per loop
->TypingError Traceback (most recent call last)
<ipython-input-10-27199e82c72c> in <module>()
1 get_ipython().magic('timeit fdot(x0, T)')
----> 2 get_ipython().magic('timeit solve_my_ODE(res, fdot, x0, T, dt)')
(...)
TypingError: Failed at nopython (nopython frontend)
Undeclared pyobject(float64, float64)
File "<ipython-input-9-112bd04325a4>", line 6
I don't understand why I got this error. My suspicion is that Numba does not recognize the argument fdot (a Python function which, by the way, is already compiled with Numba).
Since I am so new to Numba I have several questions
Numba version is 0.17
You're right in thinking that numba doesn't recognise fdot as a numba compiled function. I don't think you can make it recognise it as a function argument, but you can use this approach (using variable capture so fdot is known when the function is built) to build an ODE solver:
def make_solver(f):
    @jit(nopython=True)
    def solve_my_ODE(res, x0, T, dt):
        res[0] = x0
        noiter = int(T / dt)
        for i in range(noiter - 1):
            res[i + 1] = res[i] + dt * f(res[i], i * dt)
            if res[i + 1] >= 2:
                res[i + 1] -= 2
        return res
    return solve_my_ODE
fdot_solver = make_solver(fdot)  # call this for each function you
                                 # want to make an ODE solver for
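To show how the pieces fit together, here is a minimal end-to-end sketch of calling the generated solver. It assumes the fdot from the question. The try/except fallback to a do-nothing decorator and the coarser step dt = 1e-3 are my own additions, there only so the snippet runs quickly and still works (slowly) where Numba isn't installed.

```python
import numpy as np

try:
    from numba import jit
except ImportError:
    # Assumption: fall back to plain Python if Numba is unavailable.
    def jit(*args, **kwargs):
        if len(args) == 1 and callable(args[0]) and not kwargs:
            return args[0]        # used as bare @jit
        return lambda func: func  # used as @jit(nopython=True)


@jit(nopython=True)
def fdot(x, t):
    return -x + t / (x + 1) ** 2


def make_solver(f):
    # f is captured by the closure, so it is known when the solver is compiled
    @jit(nopython=True)
    def solve_my_ODE(res, x0, T, dt):
        res[0] = x0
        noiter = int(T / dt)
        for i in range(noiter - 1):
            res[i + 1] = res[i] + dt * f(res[i], i * dt)
            if res[i + 1] >= 2:
                res[i + 1] -= 2
        return res
    return solve_my_ODE


fdot_solver = make_solver(fdot)

T = 1.0
dt = 1e-3  # coarser than the question's 1e-5, just to keep the demo quick
res = np.zeros(int(T / dt))
out = fdot_solver(res, 1.0, T, dt)
print(out[:3])
```

Note that fdot is no longer passed to the solver at call time. Only the array and the scalar parameters are.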
Here's an alternate version which doesn't require you to pass res to it. Only the loop is accelerated, but since that's the slow bit that's the only important bit.
def make_solver_2(f):
    @jit
    def solve_my_ODE(x0, T, dt):
        # this bit ISN'T in no python mode
        noiter = int(T / dt)
        res = np.zeros(noiter)
        res[0] = x0
        # but the loop is nopython (so fast)
        for i in range(noiter - 1):
            res[i + 1] = res[i] + dt * f(res[i], i * dt)
            if res[i + 1] >= 2:
                res[i + 1] -= 2
        return res
    return solve_my_ODE
I prefer this version because it allocates the return value for you, so it's a little easier to use. That's a slight diversion from your real question though.
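As a quick sanity check of the jump logic, here is a small sketch that uses this second style with a right-hand side that is guaranteed to reach the threshold. The function growth is hypothetical (it isn't from the question), and so is the fallback decorator. I've also used nopython=True throughout, on the assumption that your Numba release supports np.zeros in nopython mode (recent ones do, as far as I know).

```python
import numpy as np

try:
    from numba import jit
except ImportError:
    # Assumption: fall back to plain Python if Numba is unavailable.
    def jit(*args, **kwargs):
        if len(args) == 1 and callable(args[0]) and not kwargs:
            return args[0]
        return lambda func: func


@jit(nopython=True)
def growth(x, t):
    # hypothetical right-hand side: constant growth, so x must hit 2 eventually
    return 1.0


def make_solver_2(f):
    @jit(nopython=True)
    def solve_my_ODE(x0, T, dt):
        noiter = int(T / dt)
        res = np.zeros(noiter)
        res[0] = x0
        for i in range(noiter - 1):
            res[i + 1] = res[i] + dt * f(res[i], i * dt)
            if res[i + 1] >= 2:
                res[i + 1] -= 2  # the discrete jump
        return res
    return solve_my_ODE


growth_solver = make_solver_2(growth)
res = growth_solver(1.5, 1.0, 1e-3)

# indices i where the solution drops between step i and i + 1, i.e. a jump
jumps = np.flatnonzero(np.diff(res) < 0)
print("jump after index:", jumps)
```

Starting from 1.5 with unit growth, the solution should cross 2 about halfway through and be reset once, so jumps should hold a single index near the middle of the array.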
In terms of timings I get (in seconds, for 20 iterations):
Thus, it's roughly 100x faster - accelerating the loop makes a big difference!
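If you want to reproduce that kind of comparison yourself, something along these lines should work. It's only a sketch: the helper name euler_with_jumps, the fallback decorator, and the step size dt = 1e-4 are my own choices, and the numbers you get will depend on your machine and Numba version.

```python
import timeit
import numpy as np

try:
    from numba import jit
except ImportError:
    # Assumption: fall back to plain Python if Numba is unavailable
    # (in which case both timings will be about the same).
    def jit(*args, **kwargs):
        if len(args) == 1 and callable(args[0]) and not kwargs:
            return args[0]
        return lambda func: func


def fdot_py(x, t):
    return -x + t / (x + 1) ** 2


fdot_jit = jit(nopython=True)(fdot_py)


def euler_with_jumps(f, use_jit):
    # hypothetical helper: builds the same solver, compiled or not
    def solve(x0, T, dt):
        noiter = int(T / dt)
        res = np.zeros(noiter)
        res[0] = x0
        for i in range(noiter - 1):
            res[i + 1] = res[i] + dt * f(res[i], i * dt)
            if res[i + 1] >= 2:
                res[i + 1] -= 2
        return res
    return jit(nopython=True)(solve) if use_jit else solve


slow = euler_with_jumps(fdot_py, use_jit=False)
fast = euler_with_jumps(fdot_jit, use_jit=True)

args = (1.0, 1.0, 1e-4)
fast(*args)  # warm-up call so compilation time is not included in the timing

t_slow = timeit.timeit(lambda: slow(*args), number=5)
t_fast = timeit.timeit(lambda: fast(*args), number=5)
print("pure Python: %.4f s, jitted: %.4f s" % (t_slow, t_fast))
```

The warm-up call matters: the first call to a jitted function includes compilation, which is why %timeit warns about the slowest run being much slower than the fastest.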
Your third question: "Does this script look like a reasonable way to simulate an ODE with discrete jumps? Mathematically this is equivalent to solving an ODE with delta functions." I really don't know. Sorry!