Using lambda functions in RK4 algorithm

There are two ways of implementing the classical Runge-Kutta scheme in Python shown here: the first uses lambda functions, the second does not.

Which one is going to be faster and why exactly?

asked Jan 04 '17 by Victor Pira


3 Answers

I adapted the code in the given link, and used cProfile to compare both techniques:

import numpy as np
import cProfile as cP

def theory(t):
    # Exact solution of y' = t*sqrt(y), y(0) = 1, kept as a reference.
    return (t**2 + 4.)**2 / 16.

def f(x, y):
    # Right-hand side of the ODE: dy/dx = x*sqrt(y).
    return x * np.sqrt(y)

def RK4(f):
    # Build one RK4 step out of nested, immediately-applied lambdas;
    # each lambda binds one intermediate increment (dy1..dy4).
    return lambda t, y, dt: (
            lambda dy1: (
            lambda dy2: (
            lambda dy3: (
            lambda dy4: (dy1 + 2*dy2 + 2*dy3 + dy4)/6
                      )( dt * f( t + dt  , y + dy3   ) )
                      )( dt * f( t + dt/2, y + dy2/2 ) )
                      )( dt * f( t + dt/2, y + dy1/2 ) )
                      )( dt * f( t       , y         ) )


def test_RK4(dy=f, x0=0., y0=1., x1=10, n=10):
    vx = np.empty(n+1)
    vy = np.empty(n+1)
    dy = RK4(f=dy)
    dx = (x1 - x0) / float(n)
    vx[0] = x = x0
    vy[0] = y = y0
    i = 1
    while i <= n:
        vx[i], vy[i] = x + dx, y + dy(x, y, dx)
        x, y = vx[i], vy[i]
        i += 1
    return vx, vy


def rk4_step(dy, x, y, dx):
    # One classical RK4 step written with plain local variables.
    k1 = dx * dy(x, y)
    k2 = dx * dy(x + 0.5 * dx, y + 0.5 * k1)
    k3 = dx * dy(x + 0.5 * dx, y + 0.5 * k2)
    k4 = dx * dy(x + dx, y + k3)
    return x + dx, y + (k1 + 2 * k2 + 2 * k3 + k4) / 6.


def test_rk4(dy=f, x0=0., y0=1., x1=10, n=10):
    vx = np.empty(n+1)
    vy = np.empty(n+1)
    dx = (x1 - x0) / float(n)
    vx[0] = x = x0
    vy[0] = y = y0
    i = 1
    while i <= n:
        vx[i], vy[i] = rk4_step(dy=dy, x=x, y=y, dx=dx)
        x, y = vx[i], vy[i]
        i += 1
    return vx, vy

cP.run("test_RK4(n=10000)")
cP.run("test_rk4(n=10000)")

And got:

         90006 function calls in 0.095 seconds

   Ordered by: standard name

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000    0.095    0.095 <string>:1(<module>)
    40000    0.036    0.000    0.036    0.000 untitled1.py:13(f)
        1    0.000    0.000    0.000    0.000 untitled1.py:16(RK4)
    10000    0.008    0.000    0.086    0.000 untitled1.py:17(<lambda>)
    10000    0.012    0.000    0.069    0.000 untitled1.py:18(<lambda>)
    10000    0.012    0.000    0.048    0.000 untitled1.py:19(<lambda>)
    10000    0.009    0.000    0.027    0.000 untitled1.py:20(<lambda>)
    10000    0.009    0.000    0.009    0.000 untitled1.py:21(<lambda>)
        1    0.009    0.009    0.095    0.095 untitled1.py:28(test_RK4)
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}
        2    0.000    0.000    0.000    0.000 {numpy.core.multiarray.empty}


         50005 function calls in 0.064 seconds

   Ordered by: standard name

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000    0.064    0.064 <string>:1(<module>)
    40000    0.032    0.000    0.032    0.000 untitled1.py:13(f)
    10000    0.026    0.000    0.058    0.000 untitled1.py:43(rk4_step)
        1    0.006    0.006    0.064    0.064 untitled1.py:51(test_rk4)
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}
        2    0.000    0.000    0.000    0.000 {numpy.core.multiarray.empty}

So I would say it is the function-call overhead in the "lambda" implementation that makes it slower: it needs 90,006 function calls (0.095 s) against 50,005 calls (0.064 s) for the plain version.
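
A quick wall-clock comparison with timeit (just a sketch, reusing the functions defined above) is another way to check the same thing without the profiler's own overhead:

import timeit

# Time both drivers on the same grid; absolute numbers depend on the machine,
# only the relative ordering of the two versions is of interest here.
print(timeit.timeit("test_RK4(n=10000)", globals=globals(), number=10))
print(timeit.timeit("test_rk4(n=10000)", globals=globals(), number=10))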

Nevertheless, beware that I appear to have lost some precision somewhere: the results of the two versions agree with each other, but they are further off than the ones in the example:

>>> vx, vy = test_rk4()
>>> vy
array([   1.        ,    1.56110667,    3.99324757, ...,  288.78174798,
        451.27952013,  675.64427775])
>>> vx, vy = test_RK4()
>>> vy
array([   1.        ,    1.56110667,    3.99324757, ...,  288.78174798,
        451.27952013,  675.64427775])
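
To quantify that drift, one could compare the numerical result against the exact solution (a quick sketch reusing the theory() helper defined above):

vx, vy = test_rk4()
err = np.abs(vy - theory(vx))
print(err.max())  # largest absolute deviation from the exact solution on the grid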

answered Oct 28 '22 by berna1111


If you pre-process the code with the Coconut transpiler, which implements tail-call optimization, the two versions become equivalent in speed (both about as fast as the faster version is without pre-processing), so you can use whichever style is more convenient for you.

# Save berna1111's code as rk4.coco; no modifications necessary.
$ coconut --target 3 rk4.coco && python3 rk4.py
         50007 function calls in 0.055 seconds

   Ordered by: standard name

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000    0.097    0.097 <string>:1(<module>)
    40000    0.038    0.000    0.038    0.000 rk4.py:243(f)
        1    0.000    0.000    0.000    0.000 rk4.py:246(RK4)
    10000    0.007    0.000    0.088    0.000 rk4.py:247(<lambda>)
        1    0.010    0.010    0.097    0.097 rk4.py:250(test_RK4)
        1    0.000    0.000    0.097    0.097 {built-in method builtins.exec}
        2    0.000    0.000    0.000    0.000 {built-in method numpy.core.multiarray.empty}
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}


         50006 function calls in 0.057 seconds

   Ordered by: standard name

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000    0.057    0.057 <string>:1(<module>)
    40000    0.030    0.000    0.030    0.000 rk4.py:243(f)
    10000    0.019    0.000    0.049    0.000 rk4.py:265(rk4_step)
        1    0.007    0.007    0.057    0.057 rk4.py:273(test_rk4)
        1    0.000    0.000    0.057    0.057 {built-in method builtins.exec}
        2    0.000    0.000    0.000    0.000 {built-in method numpy.core.multiarray.empty}
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}

answered Oct 28 '22 by matt2000


Both @berna1111 and @matt2000 are correct. The lambda version incurs extra overhead from the additional function calls. Tail-call optimization converts those tail calls into a while loop (in effect, automatically turning the lambda version into the loop-based version), which eliminates the function-call overhead.
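
As a generic illustration of what that optimization does (a hypothetical example, not the RK4 code itself): a call in tail position carries all remaining work in its arguments, so an optimizer can replace the recursion with a loop that reuses a single frame.

# Tail-recursive form: the recursive call is the last thing the function does.
def total_recursive(n, acc=0):
    if n == 0:
        return acc
    return total_recursive(n - 1, acc + n)   # tail call

# What tail-call optimization effectively produces: the same logic as a loop,
# with no new stack frame (and no Python-level call) per step.
def total_iterative(n, acc=0):
    while n != 0:
        n, acc = n - 1, acc + n
    return acc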

See https://stackoverflow.com/a/13592002/7421639 for why Python doesn't do this optimization automatically, which is why you have to use a tool like Coconut to perform a pre-processing pass.

answered Oct 28 '22 by vdovydaitis3