
Closed-form solution for finding a root

Suppose I have a Pandas Series s whose values are all greater than or equal to 0 and sum to 1. I need to subtract a constant from all values such that the sum of the new Series is equal to 0.6. The catch is that no value may end up below zero: any value that would go negative after the subtraction is set to zero instead.

In mathematical terms, given a series of values x_i, I want to find k such that

$$\sum_i \max(x_i - k,\ 0) = 0.6$$

MCVE

import pandas as pd
import numpy as np
from string import ascii_uppercase

np.random.seed([3, 141592653])
s = np.power(
    1000, pd.Series(
        np.random.rand(10),
        list(ascii_uppercase[:10])
    )
).pipe(lambda s: s / s.sum())

s

A    0.001352
B    0.163135
C    0.088365
D    0.010904
E    0.007615
F    0.407947
G    0.005856
H    0.198381
I    0.027455
J    0.088989
dtype: float64

The sum is 1

s.sum()

0.99999999999999989

What I've tried

I can use Newton's method (among others) from SciPy's optimize module.

from scipy.optimize import newton

def f(k):
    return s.sub(k).clip(0).sum() - .6

Finding the root of this function will give me the k I need

initial_guess = .1
k = newton(f, x0=initial_guess)

Then subtract this from s

new_s = s.sub(k).clip(0)
new_s

A    0.000000
B    0.093772
C    0.019002
D    0.000000
E    0.000000
F    0.338583
G    0.000000
H    0.129017
I    0.000000
J    0.019626
dtype: float64

And the new sum is

new_s.sum()

0.60000000000000009

Question

Can we find k without resorting to using a solver?

asked Mar 16 '18 06:03 by piRSquared



2 Answers

I was not expecting newton to carry the day. But on large arrays, it does.

numba.njit

Inspired by Thierry's answer.
Using a loop on a sorted array, compiled with Numba's JIT.

import numpy as np
from numba import njit


@njit
def find_k_numba(a, t):
    a = np.sort(a)
    m = len(a)
    s = a.sum()
    to_remove = s - t

    if to_remove <= 0:
        k = 0
    else:
        for i, x in enumerate(a):
            k = to_remove / (m - i)
            if k < x:
                break
            else:
                to_remove -= x
    return k
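
For readers without Numba, the same walk over the sorted values can be traced in plain Python. This is only a sketch: the name find_k_loop and the sample values are made up for illustration and are not part of the answer above.

```python
def find_k_loop(values, target):
    """Return k such that sum(max(x - k, 0) for x in values) equals target."""
    a = sorted(values)
    m = len(a)
    to_remove = sum(a) - target
    if to_remove <= 0:
        return 0.0  # already at or below the target; nothing to subtract
    k = 0.0
    for i, x in enumerate(a):
        # Spread what is left to remove evenly over the m - i largest values.
        k = to_remove / (m - i)
        if k < x:
            break  # every remaining value stays positive after subtracting k
        # x would be clipped to zero, so it contributes only x to the removal.
        to_remove -= x
    return k


values = [0.05, 0.15, 0.30, 0.50]  # hypothetical data summing to 1
k = find_k_loop(values, 0.6)
clipped = [max(x - k, 0.0) for x in values]
print(k, sum(clipped))
```

With these values the smallest entry is clipped to zero and the other three each lose k = 0.35 / 3.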

numpy

Inspired by Paul's Answer
NumPy doing the heavy lifting.

import numpy as np

def find_k_numpy(a, t):
    a = np.sort(a)
    m = len(a)
    s = a.sum()
    to_remove = s - t

    if to_remove <= 0:
        k = 0
    else:
        c = a.cumsum()
        n = np.arange(m)[::-1]
        b = n * a + c
        i = np.searchsorted(b, to_remove)
        k = a[i] + (to_remove - b[i]) / (m - i)
    return k
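
In the function above, b[i] is the total that would be removed if k were exactly a[i]: every value up to and including a[i] is wiped out (that is c[i]), and each of the m - 1 - i larger values loses a[i]. Because b is non-decreasing, a binary search finds the segment that contains to_remove, and inside that segment the removed total is linear in k with slope m - i. The same steps can be sketched with the standard library alone, assuming a non-empty input and 0 <= target. The name find_k_bisect is illustrative, with itertools.accumulate and bisect_left standing in for cumsum and searchsorted.

```python
from bisect import bisect_left
from itertools import accumulate


def find_k_bisect(values, target):
    """Sketch of the sorted/cumulative-sum approach; assumes 0 <= target."""
    a = sorted(values)
    m = len(a)
    c = list(accumulate(a))        # running totals, like numpy's cumsum
    to_remove = c[-1] - target
    if to_remove <= 0:
        return 0.0
    # b[i]: total removed if k were exactly a[i]
    b = [(m - 1 - i) * a[i] + c[i] for i in range(m)]
    i = bisect_left(b, to_remove)  # first i with b[i] >= to_remove
    # Step back from a[i] along a line of slope (m - i).
    return a[i] + (to_remove - b[i]) / (m - i)
```

For the hypothetical input [0.05, 0.15, 0.30, 0.50] with target 0.6, b is roughly [0.2, 0.5, 0.8, 1.0], the search lands on i = 1, and the result is 0.15 - 0.1 / 3.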

scipy.optimize.newton

My method via Newton

import numpy as np
from scipy.optimize import newton


def find_k_newton(a, t):
    s = a.sum()
    if s <= t:
        k = 0
    else:
        def f(k_):
            return np.clip(a - k_, 0, a.max()).sum() - t

        k = newton(f, (s - t) / len(a))

    return k
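
When no fprime argument is given, scipy.optimize.newton falls back to the secant method. The function here is piecewise linear, so its slope is known exactly: wherever it is differentiable, the slope at k is minus the number of values above k. A hand-rolled Newton iteration that uses this slope might look like the sketch below. The name find_k_newton_exact, the tolerance, and the iteration cap are assumptions made for illustration, and the inputs are assumed to be non-negative with target > 0.

```python
def find_k_newton_exact(values, target, tol=1e-12, max_iter=100):
    """Newton iteration with the exact slope of the clipped-sum function."""
    total = sum(values)
    if total <= target:
        return 0.0
    k = 0.0  # f(0) = total - target > 0, so we start to the left of the root
    for _ in range(max_iter):
        above = [x for x in values if x > k]
        # f(k) = sum(max(x - k, 0)) - target, written using only values above k
        f = sum(above) - len(above) * k - target
        fprime = -len(above)  # slope of f just to the right of k
        if fprime == 0 or abs(f) <= tol:
            break
        k -= f / fprime  # Newton step
    return k
```

Started from k = 0, each step simplifies to k = (sum of values above k - target) / (number of values above k). That appears to be essentially the update f_direct in the other answer performs when it repeatedly drops values at or below the current candidate.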

Time Trials

import pandas as pd
from timeit import timeit


res = pd.DataFrame(
    np.nan, [10, 30, 100, 300, 1000, 3000, 10000, 30000],
    'find_k_newton find_k_numpy find_k_numba'.split()
)

for i in res.index:
    a = np.random.rand(i)
    t = a.sum() * .6
    for j in res.columns:
        stmt = f'{j}(a, t)'
        setp = f'from __main__ import {j}, a, t'
        res.at[i, j] = timeit(stmt, setp, number=200)

Results

res.plot(loglog=True)

[Plot: log-log timings of find_k_newton, find_k_numpy and find_k_numba against array size]

res.div(res.min(axis=1), axis=0)

       find_k_newton  find_k_numpy  find_k_numba
10         57.265421     17.552150      1.000000
30         29.221947      9.420263      1.000000
100        16.920463      5.294890      1.000000
300        10.761341      3.037060      1.000000
1000        1.455159      1.033066      1.000000
3000        1.000000      2.076484      2.550152
10000       1.000000      3.798906      4.421955
30000       1.000000      5.551422      6.784594
answered Sep 28 '22 08:09 by piRSquared


Updated: Three different implementations - interestingly, the least sophisticated scales best.

import numpy as np

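def f_sort(A, target=0.6):
    N = len(A)  # use the array's own length instead of relying on the global N
    B = np.sort(A)
    # C[i] is the total removed when k == B[i]
    C = np.cumsum(np.r_[B[0], np.diff(B)] * np.arange(N, 0, -1))
    idx = np.searchsorted(C, 1 - target)
    return B[idx] + (1 - target - C[idx]) / (N-idx)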

def f_partition(A, target=0.6):
    target, l = 1 - target, len(A)
    while len(A) > 1:
        m = len(A) // 2
        A = np.partition(A, m-1)
        ls = A[:m].sum()
        if ls + A[m-1] * (l-m) > target:
            A = A[:m]
        else:
            l -= m
            target -= ls
            A = A[m:]
    return target / l

def f_direct(A, target=0.6):
    target = 1 - target
    while True:
        gt = A > target / len(A)
        if np.all(gt):
            return target / len(A)
        target -= A[~gt].sum()
        A = A[gt]

N = 10
A = np.random.random(N)
A /= A.sum()

print(f_sort(A), np.clip(A-f_sort(A), 0, None).sum())
print(f_partition(A), np.clip(A-f_partition(A), 0, None).sum())
print(f_direct(A), np.clip(A-f_direct(A), 0, None).sum())

from timeit import timeit
kwds = dict(globals=globals(), number=1000)

N = 100000
A = np.random.random(N)
A /= A.sum()

print(timeit('f_sort(A)', **kwds))
print(timeit('f_partition(A)', **kwds))
print(timeit('f_direct(A)', **kwds))

Sample run:

0.04813686999999732 0.5999999999999999
0.048136869999997306 0.6000000000000001
0.048136869999997306 0.6000000000000001
8.38109541599988
2.1064437470049597
1.2743922089866828
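
All three functions compute the amount to remove as 1 - target, so they rely on A summing to 1. Scaling every value and the target by the same factor scales k by that factor too, so one way to reuse them on un-normalised data is to divide by the total first and multiply the result back. The sketch below does this around a list-based port of f_direct. Both f_direct_list and solve_any_total are hypothetical names introduced here, and a positive total with 0 < target < total is assumed.

```python
def f_direct_list(A, target=0.6):
    """List-based port of f_direct; assumes sum(A) == 1 and 0 < target < 1."""
    remaining = 1 - target
    A = list(A)
    while True:
        k = remaining / len(A)
        kept = [x for x in A if x > k]
        if len(kept) == len(A):
            return k  # every surviving value exceeds k, so k is final
        # Values at or below k are clipped to zero and removed in full.
        remaining -= sum(x for x in A if x <= k)
        A = kept


def solve_any_total(values, target, solver=f_direct_list):
    """Normalise to a unit total, solve, then scale k back."""
    total = sum(values)
    unit = [x / total for x in values]
    return total * solver(unit, target / total)


k = solve_any_total([1, 3, 6, 10], 12)  # hypothetical data summing to 20
print(k, sum(max(x - k, 0) for x in [1, 3, 6, 10]))
```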
answered Sep 28 '22 09:09 by Paul Panzer