I have a pandas DataFrame with three columns and a datetime index:
```
date        px_last     200dma      50dma
2014-12-24  2081.88  1953.16760  2019.2726
2014-12-26  2088.77  1954.37975  2023.7982
2014-12-29  2090.57  1955.62695  2028.3544
2014-12-30  2080.35  1956.73455  2032.2262
2014-12-31  2058.90  1957.66780  2035.3240
```
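To make the question reproducible, the five sample rows above can be loaded as a DataFrame with a DatetimeIndex named `date`. This is only a convenience sketch using the printed values; in all five rows the 50dma is above the 200dma, so this slice would be drawn entirely in green.

```python
import pandas as pd

# Rebuild the five sample rows shown above, indexed by date
df = pd.DataFrame(
    {
        "px_last": [2081.88, 2088.77, 2090.57, 2080.35, 2058.90],
        "200dma": [1953.16760, 1954.37975, 1955.62695, 1956.73455, 1957.66780],
        "50dma": [2019.2726, 2023.7982, 2028.3544, 2032.2262, 2035.3240],
    },
    index=pd.to_datetime(
        ["2014-12-24", "2014-12-26", "2014-12-29", "2014-12-30", "2014-12-31"]
    ),
)
df.index.name = "date"

# True on days when the 50-day average is above the 200-day average
above = df["50dma"] > df["200dma"]
print(above.all())  # True
```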
I would like to make a time series plot of the 'px_last' column, colored green on days when the 50dma is above the 200dma and red on days when the 50dma is below the 200dma. I have seen this example, but can't seem to make it work for my case: http://matplotlib.org/examples/pylab_examples/multicolored_line.html
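For comparison, here is a minimal sketch of how the LineCollection approach from that example might be adapted to a date-indexed frame. It assumes simulated random-walk prices in place of real data and computes the moving averages with the `.rolling()` method. The extra step needed for dates is converting the index to Matplotlib date numbers with `matplotlib.dates.date2num`, because a LineCollection takes plain numeric coordinates. Each short segment between consecutive days is colored by the regime on its first day.

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
from matplotlib.collections import LineCollection

# Simulated data in the same shape as the question (a random walk, not real prices)
np.random.seed(1234)
idx = pd.date_range("2010-01-01", periods=1000, freq="B")
df = pd.DataFrame({"px_last": 100 + np.random.randn(1000).cumsum()}, index=idx)
df["50dma"] = df["px_last"].rolling(window=50).mean()
df["200dma"] = df["px_last"].rolling(window=200).mean()
df = df.dropna()

# LineCollection needs numeric x values, so convert dates to Matplotlib date numbers
x = mdates.date2num(df.index.to_pydatetime())
y = df["px_last"].to_numpy()

# One segment per pair of consecutive points: shape (N - 1, 2, 2)
points = np.column_stack([x, y]).reshape(-1, 1, 2)
segments = np.concatenate([points[:-1], points[1:]], axis=1)

# Color each segment by whether the 50dma is above the 200dma at its start
above = (df["50dma"] > df["200dma"]).to_numpy()[:-1]
colors = np.where(above, "g", "r").tolist()

fig, ax = plt.subplots()
lc = LineCollection(segments, colors=colors, linewidths=2)
ax.add_collection(lc)
ax.autoscale_view()  # add_collection does not rescale the view by itself
ax.xaxis_date()      # show the numeric x-axis as dates
plt.show()
```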
Here is an example that does it without matplotlib.collections.LineCollection. The idea is to first identify the crossover points and then plot each run between crossovers separately, grouping the rows with groupby.
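The crossover-detection step can be seen in isolation on a short, made-up label series (+1 when the 50dma is above the 200dma, -1 otherwise). Multiplying each label by the previous one gives a negative number exactly where the sign flips, and a cumulative sum of those flips assigns one id to each uninterrupted run.

```python
import pandas as pd

# Made-up regime labels: +1 when 50dma > 200dma, -1 otherwise
label = pd.Series([-1, -1, 1, 1, 1, -1, -1, 1])

# A crossover is where label * previous_label < 0;
# the first row's shifted value is NaN, which compares as False
crossed = label.shift() * label < 0

# Running count of crossovers: one id per uninterrupted run
segment_id = crossed.cumsum()
print(segment_id.tolist())  # [0, 0, 1, 1, 1, 2, 2, 3]

# Each group now holds a single label value
for seg, group in label.groupby(segment_id):
    print(seg, group.tolist())
```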
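```python
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# simulate data
# =============================
np.random.seed(1234)
df = pd.DataFrame({'px_last': 100 + np.random.randn(1000).cumsum()},
                  index=pd.date_range('2010-01-01', periods=1000, freq='B'))
# pd.rolling_mean has been removed from pandas; use the .rolling() method
df['50dma'] = df['px_last'].rolling(window=50).mean()
df['200dma'] = df['px_last'].rolling(window=200).mean()
# +1 when the 50-day MA is above the 200-day MA, -1 otherwise
df['label'] = np.where(df['50dma'] > df['200dma'], 1, -1)

# plot
# =============================
# drop the leading rows where the 200-day MA is not yet defined
df = df.dropna(axis=0, how='any')
fig, ax = plt.subplots()

def plot_func(group):
    # each group is one uninterrupted run with a single label
    color = 'r' if (group['label'] < 0).all() else 'g'
    lw = 2.0
    ax.plot(group.index, group.px_last, c=color, linewidth=lw)

# a new segment starts wherever the label flips sign (a crossover)
segment_id = (df['label'].shift() * df['label'] < 0).cumsum()
# iterate over the groups directly; groupby.apply is meant for computing
# results, not for side effects such as plotting
for _, group in df.groupby(segment_id):
    plot_func(group)

# add ma lines
ax.plot(df.index, df['50dma'], 'k--', label='MA-50')
ax.plot(df.index, df['200dma'], 'b--', label='MA-200')
ax.legend(loc='best')
plt.show()
```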
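One side effect of plotting each group separately is that the line from the last day of one run to the first day of the next is never drawn, so a small gap appears at every crossover. A sketch of one way to close it, replacing the groupby with explicit run boundaries found via `np.diff` and `np.flatnonzero` (same simulated data as above), is to extend each piece by one row into the next run so that adjacent pieces share an endpoint:

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Same simulated data as above (a random walk, not real prices)
np.random.seed(1234)
df = pd.DataFrame({"px_last": 100 + np.random.randn(1000).cumsum()},
                  index=pd.date_range("2010-01-01", periods=1000, freq="B"))
df["50dma"] = df["px_last"].rolling(window=50).mean()
df["200dma"] = df["px_last"].rolling(window=200).mean()
df["label"] = np.where(df["50dma"] > df["200dma"], 1, -1)
df = df.dropna()

labels = df["label"].to_numpy()
# Row positions where a run starts: row 0 plus every row whose label differs from the previous one
starts = np.flatnonzero(np.r_[True, np.diff(labels) != 0])
# Each piece ends on the first row of the next run, so adjacent pieces touch
ends = np.r_[starts[1:], len(df) - 1]

fig, ax = plt.subplots()
for s, e in zip(starts, ends):
    piece = df.iloc[s:e + 1]
    color = "g" if labels[s] > 0 else "r"
    ax.plot(piece.index, piece["px_last"], color=color, linewidth=2)

ax.plot(df.index, df["50dma"], "k--", label="MA-50")
ax.plot(df.index, df["200dma"], "b--", label="MA-200")
ax.legend(loc="best")
plt.show()
```

The crossover day is drawn twice (once as the end of one piece, once as the start of the next), which is what makes the colored pieces join up.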