import numpy as np
import pandas as pd

# 10 IDs ("id 01" .. "id 10"), 10 consecutive days each, C = 10..109
df = pd.DataFrame({
    "A": np.array(["id %02d" % i for i in range(1, 11)]).repeat(10),
    "B": pd.date_range("2018-01-01", periods=100).strftime("%Y-%m-%d"),
    "C": list(range(10, 110)),
})
# (A, B) becomes a MultiIndex; C is the only remaining column
df = df.groupby(["A", "B"]).sum()
df["D"] = df["C"].shift(1).rolling(2).mean()
df
This code generates a D column whose rolling window runs straight across the ID boundaries.
I want the rolling logic to start over for every new ID. Right now, ID 02 is using the last two values from ID 01 to calculate the mean.
How can this be achieved?
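To see the leak concretely, here is a minimal sketch that rebuilds the same data (with np.repeat and a dict literal instead of dict(zip(...)), which should give the same frame) and looks at the first row of id 02:

```python
import numpy as np
import pandas as pd

# Same data as the question: 10 IDs x 10 days, C = 10..109, indexed by (A, B)
df = pd.DataFrame({
    "A": np.repeat([f"id {i:02d}" for i in range(1, 11)], 10),
    "B": pd.date_range("2018-01-01", periods=100).strftime("%Y-%m-%d"),
    "C": np.arange(10, 110),
}).groupby(["A", "B"]).sum()

# The question's version: shift and rolling both run over the whole column
df["D"] = df["C"].shift(1).rolling(2).mean()

# First row of id 02: the window holds C = 18 and 19, the last two values of id 01
first_of_id02 = df.loc[("id 02", "2018-01-11"), "D"]
print(first_of_id02)  # 18.5
```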
I believe you need groupby:
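# A is an index level (not a column) after groupby(["A", "B"]).sum();
# transform keeps the per-ID rolling result aligned with the (A, B) index
df["D"] = df["C"].shift(1).groupby(level="A").transform(lambda s: s.rolling(2).mean())
print(df.head(20))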
                   C     D
A     B
id 01 2018-01-01  10   NaN
      2018-01-02  11   NaN
      2018-01-03  12  10.5
      2018-01-04  13  11.5
      2018-01-05  14  12.5
      2018-01-06  15  13.5
      2018-01-07  16  14.5
      2018-01-08  17  15.5
      2018-01-09  18  16.5
      2018-01-10  19  17.5
id 02 2018-01-11  20   NaN
      2018-01-12  21  19.5
      2018-01-13  22  20.5
      2018-01-14  23  21.5
      2018-01-15  24  22.5
      2018-01-16  25  23.5
      2018-01-17  26  24.5
      2018-01-18  27  25.5
      2018-01-19  28  26.5
      2018-01-20  29  27.5
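The 19.5 in the second row of id 02 deserves a closer look: the shift runs over the whole column before the grouping, so the first shifted value of id 02 is still C from the last row of id 01. A small sketch to check this, written with groupby(level="A").transform (an assumption here, used so the result stays aligned with the (A, B) index):

```python
import numpy as np
import pandas as pd

# Same data as the question, indexed by (A, B)
df = pd.DataFrame({
    "A": np.repeat([f"id {i:02d}" for i in range(1, 11)], 10),
    "B": pd.date_range("2018-01-01", periods=100).strftime("%Y-%m-%d"),
    "C": np.arange(10, 110),
}).groupby(["A", "B"]).sum()

# Shift over the whole column, then a rolling mean inside each ID
shifted = df["C"].shift(1)
df["D"] = shifted.groupby(level="A").transform(lambda s: s.rolling(2).mean())

# The shift happened before grouping, so id 02 starts with id 01's last C (19)
carried_over = shifted.loc[("id 02", "2018-01-11")]
# ...and the second row of id 02 therefore averages 19 and 20
second_of_id02 = df.loc[("id 02", "2018-01-12"), "D"]
print(carried_over, second_of_id02)  # 19.0 19.5
```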
Or:
df["D"] = df["C"].groupby(level="A").shift(1).rolling(2).mean()
print(df.head(20))
                   C     D
A     B
id 01 2018-01-01  10   NaN
      2018-01-02  11   NaN
      2018-01-03  12  10.5
      2018-01-04  13  11.5
      2018-01-05  14  12.5
      2018-01-06  15  13.5
      2018-01-07  16  14.5
      2018-01-08  17  15.5
      2018-01-09  18  16.5
      2018-01-10  19  17.5
id 02 2018-01-11  20   NaN
      2018-01-12  21   NaN
      2018-01-13  22  20.5
      2018-01-14  23  21.5
      2018-01-15  24  22.5
      2018-01-16  25  23.5
      2018-01-17  26  24.5
      2018-01-18  27  25.5
      2018-01-19  28  26.5
      2018-01-20  29  27.5
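This version works because the grouped shift puts a NaN in the first row of every ID, and rolling(2) uses min_periods equal to the window size by default, so any window that touches that NaN returns NaN. A sketch of that reasoning; it also shows that lowering min_periods (a parameter this answer does not use) lets id 01 leak into id 02 again:

```python
import numpy as np
import pandas as pd

# Same data as the question, indexed by (A, B)
df = pd.DataFrame({
    "A": np.repeat([f"id {i:02d}" for i in range(1, 11)], 10),
    "B": pd.date_range("2018-01-01", periods=100).strftime("%Y-%m-%d"),
    "C": np.arange(10, 110),
}).groupby(["A", "B"]).sum()

# Shift inside each ID, then roll over the whole column
shifted = df["C"].groupby(level="A").shift(1)
d_default = shifted.rolling(2).mean()               # min_periods defaults to 2
d_loose = shifted.rolling(2, min_periods=1).mean()  # accepts a single value

# The grouped shift leaves NaN at the start of id 02, which blocks the window
print(shifted.loc[("id 02", "2018-01-11")])    # nan
print(d_default.loc[("id 02", "2018-01-11")])  # nan
# With min_periods=1 the window falls back to id 01's last shifted value (18)
print(d_loose.loc[("id 02", "2018-01-11")])    # 18.0
```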
While the accepted answer by @jezrael works correctly for positive shifts, it gives partially incorrect results for negative shifts. Please check the following:
# Assumed starting point: A, B and C as ordinary columns
# (this answer sets the (A, B) index itself below)
df = df[["C"]].reset_index()
df['D'] = df["C"].groupby(df['A']).shift(1).rolling(2).mean()
df['E'] = df["C"].groupby(df['A']).rolling(2).mean().shift(1).values
df['F'] = df["C"].groupby(df['A']).shift(-1).rolling(2).mean()
df['G'] = df["C"].groupby(df['A']).rolling(2).mean().shift(-1).values
df.set_index(['A', 'B'], inplace=True)
print(df.head(20))
                   C     D     E     F     G
A     B
id 01 2018-01-01  10   NaN   NaN   NaN  10.5
      2018-01-02  11   NaN   NaN  11.5  11.5
      2018-01-03  12  10.5  10.5  12.5  12.5
      2018-01-04  13  11.5  11.5  13.5  13.5
      2018-01-05  14  12.5  12.5  14.5  14.5
      2018-01-06  15  13.5  13.5  15.5  15.5
      2018-01-07  16  14.5  14.5  16.5  16.5
      2018-01-08  17  15.5  15.5  17.5  17.5
      2018-01-09  18  16.5  16.5  18.5  18.5
      2018-01-10  19  17.5  17.5   NaN   NaN
id 02 2018-01-11  20   NaN  18.5   NaN  20.5
      2018-01-12  21   NaN   NaN  21.5  21.5
      2018-01-13  22  20.5  20.5  22.5  22.5
      2018-01-14  23  21.5  21.5  23.5  23.5
      2018-01-15  24  22.5  22.5  24.5  24.5
      2018-01-16  25  23.5  23.5  25.5  25.5
      2018-01-17  26  24.5  24.5  26.5  26.5
      2018-01-18  27  25.5  25.5  27.5  27.5
      2018-01-19  28  26.5  26.5  28.5  28.5
      2018-01-20  29  27.5  27.5   NaN   NaN
Note that columns D and E are computed for .shift(1), and columns F and G are computed for .shift(-1). Column E is incorrect, since the first value of id 02 (18.5) is the mean of the last two values of id 01. Column F is incorrect, since the first value of both id 01 and id 02 is NaN instead of 10.5 and 20.5. Columns D and G give correct results.
So the full answer should look like this. If the shift period is positive, use the following:
df['D'] = df["C"].groupby(df['A']).shift(1).rolling(2).mean()
If the shift period is negative, use the following:
df['G'] = df["C"].groupby(df['A']).rolling(2).mean().shift(-1).values
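Another way to avoid choosing between the two forms is to run both steps (rolling, then shift) inside each ID, so nothing can cross an ID boundary whatever the sign of the shift. A minimal sketch, assuming the (A, B) MultiIndex from the question; the helper name grouped_rolling_then_shift is hypothetical:

```python
import numpy as np
import pandas as pd

# Same data as the question, indexed by (A, B)
df = pd.DataFrame({
    "A": np.repeat([f"id {i:02d}" for i in range(1, 11)], 10),
    "B": pd.date_range("2018-01-01", periods=100).strftime("%Y-%m-%d"),
    "C": np.arange(10, 110),
}).groupby(["A", "B"]).sum()


def grouped_rolling_then_shift(s, level, window, periods):
    """Hypothetical helper: rolling mean, then shift, both within each group.

    transform runs the lambda on each group separately and returns a
    result aligned with the original index of s.
    """
    return s.groupby(level=level).transform(
        lambda g: g.rolling(window).mean().shift(periods)
    )


df["D"] = grouped_rolling_then_shift(df["C"], "A", window=2, periods=1)
df["G"] = grouped_rolling_then_shift(df["C"], "A", window=2, periods=-1)
print(df.head(20))
```

For shift(1) and shift(-1) this should reproduce columns D and G above. It also covers a case the .values trick in G does not: G relies on the first rolling value of the next ID being NaN, so it is only safe while the shift is shorter than the window. With rolling(2) and shift(-2), for example, G-style code would give the last row of id 01 the value 20.5 from id 02, while the per-group version leaves it NaN.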
Hope it helps!