
Pandas sum of next n rows

Tags: python, pandas

I have a DataFrame that looks like the following:

             ds         y
0    2017-02-07  0.154941
1    2017-02-08  0.110595
2    2017-02-09  0.044022
3    2017-02-10  0.283902
4    2017-02-11  0.121570
5    2017-02-12  0.000000
6    2017-02-13  0.020265
7    2017-02-14  0.053577
8    2017-02-15  0.080842
9    2017-02-16  0.022043

I am now trying to create a new column 'next_3' that is the sum of the y values for the next 3 days following the current day.

I am currently attempting this with:

df['next_3'] = df['y'].shift(-3).rolling(3).sum()

which produces this:

           ds         y    next_3
0  2017-02-07  0.154941       NaN
1  2017-02-08  0.110595       NaN
2  2017-02-09  0.044022  0.405472
3  2017-02-10  0.283902  0.141836
4  2017-02-11  0.121570  0.073842
5  2017-02-12  0.000000  0.154685
6  2017-02-13  0.020265  0.156462
7  2017-02-14  0.053577       NaN
8  2017-02-15  0.080842       NaN
9  2017-02-16  0.022043       NaN

I understand why the last 3 rows have NaN values since the next 3 rows aren't available, but why do the first 2 rows have NaN values when these values can be calculated?

How can I correct my shift().rolling().sum() call so that the first two rows are also calculated?

Asked by KOB on Dec 18 '18 at 10:12



1 Answer

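The first two rows are NaN because rolling(3) uses a trailing window and, by default, requires 3 observations in it; after the shift, rows 0 and 1 have only 1 and 2 observations behind them. Passing min_periods=1 fills those rows, but they then hold partial sums (row 0 is only y[3], row 1 is y[3] + y[4]), not the full total for the next 3 days: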

df['next_3'] = df['y'].shift(-3).rolling(3, min_periods=1).sum()
print(df)
           ds         y    next_3
0  2017-02-07  0.154941  0.283902
1  2017-02-08  0.110595  0.405472
2  2017-02-09  0.044022  0.405472
3  2017-02-10  0.283902  0.141835
4  2017-02-11  0.121570  0.073842
5  2017-02-12  0.000000  0.154684
6  2017-02-13  0.020265  0.156462
7  2017-02-14  0.053577  0.102885
8  2017-02-15  0.080842  0.022043
9  2017-02-16  0.022043       NaN
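
To see where the NaN values come from, it can help to inspect the shifted series directly. The following is an illustrative sketch, not part of the original answer: it rebuilds the y column from the question, and the names shifted, strict, relaxed, and n_obs are made up for this example.

```python
import pandas as pd

# y values copied from the question
y = pd.Series([0.154941, 0.110595, 0.044022, 0.283902, 0.121570,
               0.000000, 0.020265, 0.053577, 0.080842, 0.022043])

# After shift(-3), row i holds y[i+3]; the last 3 rows become NaN.
shifted = y.shift(-3)

# rolling(3) looks backwards: the window at row i covers rows i-2..i.
strict = shifted.rolling(3).sum()                  # needs 3 values per window
relaxed = shifted.rolling(3, min_periods=1).sum()  # accepts 1 or more values

# How many real (non-NaN) values each trailing window actually contains
n_obs = shifted.notna().astype(int).rolling(3, min_periods=1).sum()

print(pd.DataFrame({"shifted": shifted, "strict": strict,
                    "relaxed": relaxed, "n_obs": n_obs}))
```

Every row where n_obs is below 3 is NaN in strict. In relaxed, those same rows are filled with sums over fewer than 3 values, which is worth keeping in mind if a partial total would be misleading for your use case.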

Or roll first and then shift, which gives the full sum of the next 3 rows wherever all 3 exist:

df['next_3'] = df['y'].rolling(3).sum().shift(-3)
print(df)
           ds         y    next_3
0  2017-02-07  0.154941  0.438519
1  2017-02-08  0.110595  0.449494
2  2017-02-09  0.044022  0.405472
3  2017-02-10  0.283902  0.141835
4  2017-02-11  0.121570  0.073842
5  2017-02-12  0.000000  0.154684
6  2017-02-13  0.020265  0.156462
7  2017-02-14  0.053577       NaN
8  2017-02-15  0.080842       NaN
9  2017-02-16  0.022043       NaN
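
As a sanity check, the roll-then-shift result can be compared against an explicit sum of y[i+1], y[i+2], and y[i+3]. The sketch below is a hedged illustration rather than part of the original answer. It assumes one row per consecutive day (the window is counted in rows, not dates), the column name next_3_fwd and the variable expected are made up for the example, and the FixedForwardWindowIndexer alternative assumes a reasonably recent pandas (the class was added around version 1.1).

```python
import numpy as np
import pandas as pd
from pandas.api.indexers import FixedForwardWindowIndexer

# Data copied from the question
df = pd.DataFrame({
    "ds": pd.date_range("2017-02-07", periods=10, freq="D"),
    "y": [0.154941, 0.110595, 0.044022, 0.283902, 0.121570,
          0.000000, 0.020265, 0.053577, 0.080842, 0.022043],
})

# Trailing 3-row sum, then moved up 3 rows so row i holds y[i+1..i+3]
df["next_3"] = df["y"].rolling(3).sum().shift(-3)

# Reference values: explicit sum of the 3 rows after row i, NaN if incomplete
n = len(df)
expected = [df["y"].iloc[i + 1:i + 4].sum() if i + 3 < n else np.nan
            for i in range(n)]

# Alternative: a forward-looking window over the series shifted up by one row
indexer = FixedForwardWindowIndexer(window_size=3)
df["next_3_fwd"] = df["y"].shift(-1).rolling(indexer, min_periods=3).sum()

assert np.allclose(df["next_3"], expected, equal_nan=True)
assert np.allclose(df["next_3"], df["next_3_fwd"], equal_nan=True)
print(df)
```

Both variants leave the last 3 rows as NaN, since fewer than 3 following rows exist there. If the dates could have gaps, a row-based window would no longer mean "next 3 days", and a time-based approach would be needed instead.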
Answered by jezrael on Oct 12 '22 at 21:10