
How to efficiently pass function through?

Motivation

Take a look at the following picture.

[Image: the red, blue, and green curves, with the dominating curve drawn in black]

Given are the red, blue, and green curves. At each point on the x axis I would like to find the dominating curve, shown as the black graph in the picture. Because the red, green, and blue curves are increasing and become constant after a while, this boils down to finding the dominating curve at the far right-hand side and then moving towards the left-hand side, finding all intersection points and updating the dominating curve at each one.

This problem has to be solved T times, with one final twist: the blue, green, and red curves of the next iteration are constructed from the dominating solution of the previous iteration plus some varying parameters. In the picture above, for example, the solution is the black function. That function is used to generate new blue, green, and red curves, the problem of finding the dominating one starts again for these new curves, and so on.
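The right-to-left search described above can be sketched roughly as follows. This is a minimal illustration and not the code from this question: the helper `upper_envelope`, the tolerance `eps`, and the three linear test curves are hypothetical names and values chosen for the example. It also assumes that each pair of curves crosses at most once on the interval, since `brentq` returns only one root per bracket.

```python
import numpy as np
from scipy.optimize import brentq

def upper_envelope(funcs, lo, hi, eps=1e-9):
    """Walk from hi towards lo and record where the dominating curve changes.

    Returns (breakpoints, indices): funcs[indices[k]] dominates on
    [breakpoints[k], breakpoints[k + 1]].
    """
    current = int(np.argmax([f(hi) for f in funcs]))
    right = hi
    breakpoints, indices = [hi], [current]
    while True:
        dom = funcs[current]
        best_x, best_i = None, None
        for i, f in enumerate(funcs):
            if i == current:
                continue
            gap = lambda x, f=f: dom(x) - f(x)
            a, b = lo, right - eps  # step slightly left of the last crossing
            if a >= b or gap(a) * gap(b) > 0:
                continue  # no sign change, so no bracketed crossing
            x = brentq(gap, a, b)
            if best_x is None or x > best_x:
                best_x, best_i = x, i
        if best_x is None:
            break
        breakpoints.insert(0, best_x)
        indices.insert(0, best_i)
        current, right = best_i, best_x
    breakpoints.insert(0, lo)
    return breakpoints, indices

# Hypothetical test curves: x, 0.5*x + 1, and the constant 1.5 on [0, 4].
curves = [lambda x: x, lambda x: 0.5 * x + 1.0, lambda x: 1.5]
breakpoints, indices = upper_envelope(curves, 0.0, 4.0)
```

Here the constant curve dominates on the left, the middle curve in between, and the steepest curve on the right.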

Question in a nutshell
In each iteration I start at the fixed far right-hand side and evaluate all three functions to see which one dominates. These evaluations take longer and longer from one iteration to the next. My feeling is that I am not passing the old dominating function in an optimal way when constructing the new blue, green, and red curves. The reason I suspect this: an earlier version gave me a maximum recursion depth error. Other parts of the code that need the value of the current dominating function (which is essentially either the green, red, or blue curve) also take longer and longer with each iteration.

Over 5 iterations, the time needed just to evaluate the functions at a single point on the far right-hand side grows:

The results were produced via

test = A(5, 120000, 100000) 

And then running

test.find_all_intersections()

>>> test.find_all_intersections()
iteration 4
to compute function values it took
0.0102479457855
iteration 3
to compute function values it took
0.0134601593018
iteration 2
to compute function values it took
0.0294270515442
iteration 1
to compute function values it took
0.109843969345
iteration 0
to compute function values it took
0.823768854141

I would like to know why this is the case and whether it can be programmed more efficiently.

Detailed Code explanation

I quickly summarize the most important functions. The complete code can be found further below. If there are any other questions regarding the code I'm more than happy to elaborate / clarify.

  1. Method u: For the recurring task of generating a new batch of the green, red, and blue curves above we need the old dominating one. u is the initialization used in the very first iteration.

  2. Method _function_template: This generates versions of the green, blue, and red curves by using different parameters. It returns a function of a single input.

  3. Method evalv: This is the core function that generates the blue, green, and red versions every time. It takes three parameters that vary in each iteration: vfunction, which is the dominating function from the previous step, and m and s, which are two parameters (floats) affecting the shape of the resulting curve. The other parameters are the same in each iteration. The code contains sample values for m and s for each iteration. For the more geeky ones: it approximates an integral where m and s are the expected mean and standard deviation of the underlying normal distribution. The approximation is done via Gauss-Hermite nodes / weights.

  4. Method find_all_intersections: This is the core method that finds the dominating function in each iteration. It constructs the dominating function as a piecewise concatenation of the blue, green, and red curves. This is achieved via the function piecewise.
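As a small standalone illustration of the Gauss-Hermite approximation mentioned in point 3, the following sketch estimates the expectation of f(X) for X normally distributed with a given mean and standard deviation. The helper name `gauss_hermite_expectation` and the test values are hypothetical and not part of the code below. In evalv the nodes are additionally shifted by 1 + m and multiplied by w, which this sketch leaves out.

```python
import math
import numpy as np

def gauss_hermite_expectation(f, mean, std, n_nodes=10):
    """Approximate E[f(X)] for X ~ N(mean, std**2) with Gauss-Hermite quadrature."""
    nodes, weights = np.polynomial.hermite.hermgauss(n_nodes)
    # Change of variables x = mean + sqrt(2) * std * node turns the normal
    # density into the exp(-t**2) weight that the Hermite rule integrates.
    x = mean + math.sqrt(2) * std * nodes
    return np.sum(weights * f(x)) / math.sqrt(math.pi)

# E[X**2] = mean**2 + std**2, so this should be close to 0.0016 + 0.0004 = 0.002.
approx = gauss_hermite_expectation(lambda x: x**2, mean=0.04, std=0.02)
```

With 10 nodes the rule is exact for polynomials up to degree 19, so the quadratic test function is reproduced up to rounding error.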

Here is the complete code

import numpy as np
import pandas as pd
from scipy.optimize import brentq
import multiprocessing as mp
import pathos as pt
import timeit
import math
class A(object):
    def u(self, w):
        _w = np.asarray(w).copy()
        _w[_w >= 120000] = 120000
        _p = np.maximum(0, 100000 - _w)
        return _w - 1000*_p**2

    def __init__(self, T, upper_bound, lower_bound):
        self.T = T
        self.upper_bound = upper_bound
        self.lower_bound = lower_bound

    def _function_template(self, *args):
        def _f(x):
            return self.evalv(x, *args)
        return _f

    def evalv(self, w, c, vfunction, g, m, s, gauss_weights, gauss_nodes):
        _A = np.tile(1 + m + math.sqrt(2) * s * gauss_nodes, (np.size(w), 1))
        _W = (_A.T * w).T
        _W = gauss_weights * vfunction(np.ravel(_W)).reshape(np.size(w),
                                                             len(gauss_nodes))
        evalue = g*1/math.sqrt(math.pi)*np.sum(_W, axis=1)
        return c + evalue

    def find_all_intersections(self):

        # the Gauss-Hermite weights and nodes for integration
        # and additional parameters used for function generation

        gauss = np.polynomial.hermite.hermgauss(10)
        gauss_nodes = gauss[0]
        gauss_weights = gauss[1]
        r = np.asarray([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
                        1., 1., 1., 1., 1., 1., 1., 1., 1.])
        m = [[0.038063407778193614, 0.08475713587463352, 0.15420895520972322],
             [0.038212567720998125, 0.08509661835487026, 0.15484578903763624],
             [0.03836174909668277, 0.08543620707856969, 0.15548297423808233],
             [0.038212567720998125, 0.08509661835487026, 0.15484578903763624],
             [0.038063407778193614, 0.08475713587463352, 0.15420895520972322],
             [0.038063407778193614, 0.08475713587463352, 0.15420895520972322],
             [0.03836174909668277, 0.08543620707856969, 0.15548297423808233],
             [0.038212567720998125, 0.08509661835487026, 0.15484578903763624],
             [0.038212567720998125, 0.08509661835487026, 0.15484578903763624],
             [0.038212567720998125, 0.08509661835487026, 0.15484578903763624],
             [0.038063407778193614, 0.08475713587463352, 0.15420895520972322],
             [0.038212567720998125, 0.08509661835487026, 0.15484578903763624],
             [0.038212567720998125, 0.08509661835487026, 0.15484578903763624],
             [0.038212567720998125, 0.08509661835487026, 0.15484578903763624],
             [0.03836174909668277, 0.08543620707856969, 0.15548297423808233],
             [0.038063407778193614, 0.08475713587463352, 0.15420895520972322],
             [0.038063407778193614, 0.08475713587463352, 0.15420895520972322],
             [0.038212567720998125, 0.08509661835487026, 0.15484578903763624],
             [0.03836174909668277, 0.08543620707856969, 0.15548297423808233],
             [0.038212567720998125, 0.08509661835487026, 0.15484578903763624],
             [0.038212567720998125, 0.08509661835487026, 0.15484578903763624]]

        s = [[0.01945441966324046, 0.04690600929081242, 0.200125178687699],
             [0.019491796104351332, 0.04699612658674578, 0.20050966545654142],
             [0.019529101011406914, 0.04708607140891122, 0.20089341636351565],
             [0.019491796104351332, 0.04699612658674578, 0.20050966545654142],
             [0.01945441966324046, 0.04690600929081242, 0.200125178687699],
             [0.01945441966324046, 0.04690600929081242, 0.200125178687699],
             [0.019529101011406914, 0.04708607140891122, 0.20089341636351565],
             [0.019491796104351332, 0.04699612658674578, 0.20050966545654142],
             [0.019491796104351332, 0.04699612658674578, 0.20050966545654142],
             [0.019491796104351332, 0.04699612658674578, 0.20050966545654142],
             [0.01945441966324046, 0.04690600929081242, 0.200125178687699],
             [0.019491796104351332, 0.04699612658674578, 0.20050966545654142],
             [0.019491796104351332, 0.04699612658674578, 0.20050966545654142],
             [0.019491796104351332, 0.04699612658674578, 0.20050966545654142],
             [0.019529101011406914, 0.04708607140891122, 0.20089341636351565],
             [0.01945441966324046, 0.04690600929081242, 0.200125178687699],
             [0.01945441966324046, 0.04690600929081242, 0.200125178687699],
             [0.019491796104351332, 0.04699612658674578, 0.20050966545654142],
             [0.019529101011406914, 0.04708607140891122, 0.20089341636351565],
             [0.019491796104351332, 0.04699612658674578, 0.20050966545654142],
             [0.019491796104351332, 0.04699612658674578, 0.20050966545654142]]

        self.solution = []

        n_cpu = mp.cpu_count()
        pool = pt.multiprocessing.ProcessPool(n_cpu)

        # this function is used for multiprocessing
        def call_f(f, x):
            return f(x)

        # this function takes differences for getting cross points
        def _diff(f_dom, f_other):
            def h(x):
                return f_dom(x) - f_other(x)
            return h

        # finds the crossing point of two functions (root of their difference)
        def find_roots(F, u_bound, l_bound):
            try:
                sol = brentq(F, a=l_bound,
                             b=u_bound)
                if np.absolute(sol - u_bound) > 1:
                    return sol
                else:
                    return l_bound
            except ValueError:
                return l_bound

        # piecewise function
        def piecewise(l_comp, l_f):
            def f(x):
                _ind_f = np.digitize(x, l_comp) - 1
                if np.isscalar(x):
                    return l_f[_ind_f](x)
                else:
                    return np.asarray([l_f[_ind_f[i]](x[i])
                                       for i in range(0, len(x))]).ravel()
            return f

        _u = self.u

        for t in range(self.T-1, -1, -1):
            print('iteration' + ' ' + str(t))

            l_bound, u_bound = 0.5*self.lower_bound, self.upper_bound
            l_ordered_functions = []
            l_roots = []
            l_solution = []

            # build all function variations

            l_functions = [self._function_template(0, _u, r[t], m[t][i], s[t][i],
                                                   gauss_weights, gauss_nodes)
                           for i in range(0, len(m[t]))]

            # get the best solution for the upper bound on the very
            # right hand side of wealth interval

            array_functions = np.asarray(l_functions)
            start_time = timeit.default_timer()
            functions_values = pool.map(call_f, array_functions.tolist(),
                                        len(m[t]) * [u_bound])
            elapsed = timeit.default_timer() - start_time
            print('to compute function values it took')
            print(elapsed)

            ind = np.argmax(functions_values)
            cross_points = len(m[t]) * [u_bound]
            l_roots.insert(0, u_bound)
            max_m = m[t][ind]
            l_solution.insert(0, max_m)

            # move from the upper bound towards the lower bound
            # and find the dominating solution by exploring all cross
            # points.

            test = True

            while test:
                l_ordered_functions.insert(0, array_functions[ind])
                current_max = l_ordered_functions[0]

                l_c_max = len(m[t]) * [current_max]
                l_u_cross = len(m[t]) * [cross_points[ind]]

                # Find new cross points on the smaller interval

                diff = pool.map(_diff, l_c_max, array_functions.tolist())
                cross_points = pool.map(find_roots, diff,
                                        l_u_cross, len(m[t]) * [l_bound])

                # update the solution, cross points and current
                # dominating function.

                ind = np.argmax(cross_points)
                l_roots.insert(0, cross_points[ind])
                max_m = m[t][ind]
                l_solution.insert(0, max_m)

                if cross_points[ind] <= l_bound:
                    test = False

            l_ordered_functions.insert(0, l_functions[0])
            l_roots.insert(0, 0)
            l_roots[-1] = np.inf

            l_comp = l_roots[:]
            l_f = l_ordered_functions[:]

            # build piecewise function which is used for next
            # iteration.

            _u = piecewise(l_comp, l_f)
            _sol = pd.DataFrame(data=l_solution,
                                index=np.asarray(l_roots)[0:-1])
            self.solution.insert(0, _sol)
        return self.solution
asked Jan 14 '18 11:01 by math



2 Answers

Let's start by changing the code to output the current iteration:

_u = self.u
for t in range(0, self.T):
    print(t)
    lparams = np.random.randint(self.a, self.b, 6).reshape(3, 2).tolist()
    functions = [self._function_template(_u, *lparams[i])
                 for i in range(0, 3)]
    # evaluate functions
    pairs = list(itertools.combinations(functions, 2))
    fval = [F(diff(*pairs[i]), self.a, self.b) for i in range(0, 3)]
    ind = np.sort(np.unique(np.random.randint(self.a, self.b, 10)))
    _u = _temp(ind, np.asarray(functions)[ind % 3])

Looking into the line causing the behaviour,

fval = [F(diff(*pairs[i]), self.a, self.b) for i in range(0, 3)]

functions of interest would be F and diff. The latter being straightforward, the former:

def F(f, a, b):
    try:
        brentq(f, a=a, b=b)
    except ValueError:
        pass

Hmm, swallowing exceptions, let's see what happens if we:

def F(f, a, b):
    brentq(f, a=a, b=b)

Immediately, for the first function and on the first iteration, an error is thrown:

ValueError: f(a) and f(b) must have different signs

Looking at the docs this is a prerequisite of the root finding function brentq. Let's change the definition once more to monitor this condition on each iteration.

def F(f, a, b):
    try:
        brentq(f, a=a, b=b)
    except ValueError as e:
        print(e)

The output is

i
f(a) and f(b) must have different signs
f(a) and f(b) must have different signs
f(a) and f(b) must have different signs

for i ranging from 0 to 57. In other words, the first time the function F does any real work is at i=58, and it keeps doing so for higher values of i.

Conclusion: it takes longer for these higher values, because:

  1. the root is never calculated for the lower values
  2. the number of calculations grows linearly for i>58
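Instead of catching the ValueError, the bracketing condition can also be checked explicitly before calling brentq. The following is only a sketch; `safe_root` is a hypothetical helper name, not something from the question's code.

```python
from scipy.optimize import brentq

def safe_root(f, a, b):
    """Return a root of f in [a, b], or None if the endpoints do not bracket one."""
    fa, fb = f(a), f(b)
    if fa * fb > 0:
        # Same sign at both ends: brentq would raise ValueError here.
        return None
    return brentq(f, a, b)

root = safe_root(lambda x: x - 1.0, 0.0, 3.0)          # sign change, root near 1
no_root = safe_root(lambda x: x * x + 1.0, 0.0, 3.0)   # no sign change
```

This makes it visible in the calling code whether any root finding actually took place.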
answered Oct 17 '22 00:10 by Funk


Your code is really far too complex to explain your problem - strive for something simpler. Sometimes you have to write code just to demonstrate the problem.

I'm taking a stab, based purely on your description rather than your code (although I ran the code and verified). Here's your problem:

method evalv: This is the core function to generate the blue, green and red versions every time. It takes three varying parameters each iteration: vfunction which is the dominating function from the previous step, m and s which are two parameters (floats) affecting the shape of the resulting curve.

Your vfunction parameter is more complex on each iteration. You are passing a nested function built up over previous iterations, which causes a recursive execution. Each iteration increases the depth of the recursive call.
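A toy sketch of how the cost of such nesting can grow: here every level calls the previous level once per node, which is my reading of what happens when piecewise calls the stored functions element by element and evalv expands each point into 10 quadrature nodes. The names `count_base_calls`, `base` and `wrap` are hypothetical, and 3 nodes are used instead of 10 to keep the numbers small.

```python
def count_base_calls(depth, n_nodes=3):
    """Count how often the innermost function runs for one call of the outermost."""
    counter = {"calls": 0}

    def base(x):
        counter["calls"] += 1
        return x

    def wrap(prev):
        # Each level evaluates the previous level once per node.
        def f(x):
            return sum(prev(x + k) for k in range(n_nodes)) / n_nodes
        return f

    f = base
    for _ in range(depth):
        f = wrap(f)
    f(1.0)
    return counter["calls"]

counts = [count_base_calls(d) for d in range(1, 5)]  # 3, 9, 27, 81
```

Under this assumption the work per evaluation is multiplied by the number of nodes with every iteration, which is at least consistent with the rapidly growing timings in the question.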

How can you avoid this? There's no easy or built-in way. The simplest answer is - assuming the inputs to these functions are consistent - to store the functional results (i.e. the numbers) rather than the functions themselves. You can do this as long as you have a finite number of known inputs.
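One way this could look, as a sketch under assumptions that are mine rather than the question's: the function is evaluated once on a fixed grid and then replaced by a lookup with linear interpolation. `tabulate` and the grid are hypothetical, interpolation introduces an approximation error between grid points, and `np.interp` returns the endpoint values outside the grid.

```python
import numpy as np

def tabulate(func, grid):
    """Evaluate func once on grid and return a cheap interpolating stand-in."""
    values = np.array([func(x) for x in grid], dtype=float)

    def table(x):
        # Lookup cost no longer depends on how deeply func was nested.
        return np.interp(x, grid, values)

    return table

grid = np.linspace(0.0, 4.0, 5)
cheap = tabulate(lambda x: 2.0 * x + 1.0, grid)
value = cheap(1.5)                      # linear function, so exactly 4.0
ends = cheap(np.array([0.0, 4.0]))      # values at the grid endpoints
```

Passing the tabulated function instead of the nested one into the next iteration keeps the cost per iteration roughly constant, at the price of having to choose a grid that is fine enough.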

If the inputs to the underlying functions aren't consistent then there's no shortcut. You need to repeatedly evaluate those underlying functions. I see that you're doing some piecewise splicing of the underlying functions - you can test whether the cost of doing so exceeds the cost of simply taking the max of each of the underlying functions.
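For comparison, taking the max directly could be sketched like this, assuming the underlying functions accept NumPy arrays. `pointwise_max` and the three test curves are hypothetical.

```python
import numpy as np

def pointwise_max(funcs, x):
    """Return the maximum over funcs at each x and the index of the maximiser."""
    x = np.atleast_1d(np.asarray(x, dtype=float))
    # broadcast_to lets constant functions return a scalar.
    values = np.stack([np.broadcast_to(f(x), x.shape) for f in funcs])
    return values.max(axis=0), values.argmax(axis=0)

curves = [lambda x: x, lambda x: 0.5 * x + 1.0, lambda x: 1.5]
best, which = pointwise_max(curves, [0.5, 1.5, 3.0])
```

This evaluates every function at every point, so whether it beats the piecewise splicing depends on how expensive the individual functions are.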

The test that I ran (10 iterations) took a few seconds. I don't see that as a problem.

answered Oct 16 '22 23:10 by Kirk Broadhurst