
pandas iteratively update column values

I have a pandas Series like the following:

a = pd.Series([a1, a2, a3, a4, ...])

and I want to create another pandas Series based on the following rule:

b = pd.Series([a1, a2 + a1**0.8, a3 + (a2 + a1**0.8)**0.8, a4 + (a3 + (a2 + a1**0.8)**0.8)**0.8, ...])

This is doable with iteration, but I have a large dataset (millions of records) and must perform the operation thousands of times (for optimization purposes), so it needs to be very fast. Is there a way to do this using pandas or NumPy built-in functions?
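Written as a recurrence, the rule above is b[0] = a[0] and b[i] = a[i] + b[i-1]**0.8 for i >= 1. A minimal sketch of the straightforward loop, useful mainly as a slow reference to check faster versions against (the name `epow_reference` and the `p` parameter are illustrative, not from the question):

```python
def epow_reference(values, p=0.8):
    """Plain-Python version of b[0] = a[0], b[i] = a[i] + b[i-1]**p."""
    out = []
    prev = None
    for i, x in enumerate(values):
        # The first element is copied; every later one adds the previous
        # result raised to the power p.
        prev = x if i == 0 else x + prev ** p
        out.append(prev)
    return out


b = epow_reference([1, 2, 3, 4])
print(b)
```

Each element depends on the previous result, so the loop cannot be split into independent per-element operations.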

Xuejun Zhao asked Jul 20 '18 19:07


1 Answer

Rather than fight against the fundamentally iterative nature of your problem, you could use numba to JIT-compile the simplest iterative version:

import numba
import numpy as np

@numba.jit(nopython=True)
def epow(vec, p):
    # out[0] = vec[0]; out[i] = vec[i] + out[i-1]**p
    out = np.zeros(len(vec))
    out[0] = vec[0]
    for i in range(1, len(vec)):
        out[i] = vec[i] + out[i-1]**p
    return out

which gives me

In [148]: a1, a2, a3, a4 = range(1, 5)

In [149]: a1, a2+a1**0.8, a3 + (a2 + a1**0.8)**0.8, a4 + (a3 + (a2 + a1**0.8)**0.8)**0.8
Out[149]: (1, 3.0, 5.408224685280692, 7.858724574530816)

In [150]: epow(pd.Series([a1, a2, a3, a4]).values, 0.8)
Out[150]: array([1.        , 3.        , 5.40822469, 7.85872457])

and for longer Series:

In [151]: s = pd.Series(np.arange(2*10**6))

In [152]: %time epow(s.values, 0.8)
CPU times: user 512 ms, sys: 20 ms, total: 532 ms
Wall time: 531 ms
Out[152]: 
array([0.00000000e+00, 1.00000000e+00, 3.00000000e+00, ...,
       2.11487244e+06, 2.11487348e+06, 2.11487453e+06])
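One practical note: numba compiles a jitted function the first time it is called with a given argument type, so timings are best taken after a warm-up call. Below is a rough sketch of that, plus a wrapper that puts the result back into a Series with the original index. The wrapper name `epow_series`, the fallback decorator for environments without numba, and the small example index are all assumptions made for illustration.

```python
import time

import numpy as np
import pandas as pd

try:
    import numba
    jit = numba.njit
except ImportError:
    # Assumption: fall back to the uncompiled function so the sketch still runs.
    def jit(func):
        return func


@jit
def epow(vec, p):
    out = np.zeros(len(vec))
    out[0] = vec[0]
    for i in range(1, len(vec)):
        out[i] = vec[i] + out[i - 1] ** p
    return out


def epow_series(s, p=0.8):
    """Apply epow to a Series and keep its index and name (hypothetical helper)."""
    values = s.to_numpy(dtype=np.float64)
    return pd.Series(epow(values, p), index=s.index, name=s.name)


s = pd.Series(np.arange(1, 5), index=list("wxyz"))
epow_series(s)  # warm-up call: triggers compilation when numba is available

start = time.perf_counter()
b = epow_series(s)
elapsed = time.perf_counter() - start
print(b)
print(f"second call took {elapsed:.6f} s")
```

Converting to float64 up front means integer and float inputs share one compiled signature.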
DSM answered Oct 14 '22 01:10