Logo Questions Linux Laravel Mysql Ubuntu Git Menu
 

How to update artists in a scrollable matplotlib multiplot

I'm trying to create a scrollable multiplot based on the answer to this question: Creating a scrollable multiplot with python's pylab

Lines created using ax.plot() update correctly; however, I'm unable to figure out how to update artists created using vlines() and fill_between().

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.widgets import Slider

#create dataframes
# 100 random frames of 30 rows each:
#   col1 - noisy signal around 10
#   col2 - three contiguous runs of the levels 5, 8 and 7 (in that order)
#          with random run lengths
#   col3 - random integers in [0, 3]
dfs={}
for x in range(100):
    col1=np.random.normal(10,0.5,30)
    # Split 31 slots randomly among the three levels. Each rounding error is at
    # most 0.5, so the rounded counts sum to 30..32 and truncating to 30 rows
    # is always safe. np.repeat requires integer counts (NumPy rejects float
    # repeat counts), hence the explicit cast.
    counts=np.round(np.random.dirichlet(np.ones(3),size=1)*31)[0].astype(int)
    col2=np.repeat([5,8,7],counts)[:30]
    col3=np.random.randint(4,size=30)
    dfs[x]=pd.DataFrame({'col1':col1,'col2':col2,'col3':col3})

#create figure,axis,subplot
# A single full-size axes; the bottom margin leaves room for the slider.
fig = plt.figure()
gs = gridspec.GridSpec(
    1, 1,
    hspace=0, wspace=0,
    left=0.1, bottom=0.1,
)
ax = plt.subplot(gs[0])
ax.set_ylim(0, 12)

#slider
frame = 0
axframe = plt.axes([0.13, 0.02, 0.75, 0.03])
sframe = Slider(axframe, 'frame', 0, 99, valinit=0, valfmt='%d')

#plots
(ln1,) = ax.plot(dfs[0].index, dfs[0]['col1'])
(ln2,) = ax.plot(dfs[0].index, dfs[0]['col2'], c='black')

#artists
# Translucent bands (+/- 0.5) around each col2 level, drawn once for frame 0.
# The returned artists are not stored anywhere, so they cannot be modified later.
for band_level, band_color in ((5, 'r'), (8, 'b'), (7, 'g')):
    ax.fill_between(
        dfs[0].index,
        y1=dfs[0]['col2'] - 0.5,
        y2=dfs[0]['col2'] + 0.5,
        where=dfs[0]['col2'] == band_level,
        facecolor=band_color,
        edgecolors='none',
        alpha=0.5,
    )
ax.vlines(x=dfs[0]['col3'].index, ymin=0, ymax=dfs[0]['col3'], color='black')

#update plots
def update(val):
    """Slider callback: show the two line plots for the selected frame.

    Only ``ln1`` and ``ln2`` are refreshed; the fill_between bands and the
    vertical lines keep showing frame 0.
    """
    selected = np.floor(sframe.val)
    current = dfs[selected]
    for line, column in ((ln1, 'col1'), (ln2, 'col2')):
        line.set_ydata(current[column])
    ax.set_title('Frame ' + str(int(selected)))
    plt.draw()

#connect callback to slider
sframe.on_changed(update)
plt.show()

This is what it looks like at the moment: [screenshot]

I can't apply the same approach as for plot(), since the following produces an error message:

# ax.plot() returns a list of Line2D objects, so "ln1, = ..." unpacks it.
# fill_between() returns a single PolyCollection, so the same unpacking fails:
ln3,=ax.fill_between(dfs[0].index,y1=dfs[0]['col2']-0.5,y2=dfs[0]['col2']+0.5,where=dfs[0]['col2']==5,facecolor='r',edgecolors='none',alpha=0.5)
TypeError: 'PolyCollection' object is not iterable

This is what it's meant to look like on each frame: [screenshot]

like image 681
themachinist Avatar asked Nov 20 '15 13:11

themachinist


1 Answers

fill_between returns a single PolyCollection rather than a list of artists. ax.plot, by contrast, returns a list of Line2D objects, which is why the `ln1,=` unpacking works for plot but raises a TypeError here. A PolyCollection expects a list (or several lists) of vertices upon creation. Unfortunately, I haven't found a way to retrieve the vertices that were used to create the given PolyCollection. In your case, though, it is easy enough to create the PolyCollection directly (thereby avoiding fill_between) and then update its vertices whenever the frame changes.

Below a version of your code that does what you are after:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.widgets import Slider

from matplotlib.collections import PolyCollection

#create dataframes
# 100 random frames of 30 rows each:
#   col1 - noisy signal around 10
#   col2 - three contiguous runs of the levels 5, 8 and 7 (in that order)
#          with random run lengths
#   col3 - random integers in [0, 3]
dfs={}
for x in range(100):
    col1=np.random.normal(10,0.5,30)
    # Split 31 slots randomly among the three levels. Each rounding error is at
    # most 0.5, so the rounded counts sum to 30..32 and truncating to 30 rows
    # is always safe. np.repeat requires integer counts (NumPy rejects float
    # repeat counts), hence the explicit cast.
    counts=np.round(np.random.dirichlet(np.ones(3),size=1)*31)[0].astype(int)
    col2=np.repeat([5,8,7],counts)[:30]
    col3=np.random.randint(4,size=30)
    dfs[x]=pd.DataFrame({'col1':col1,'col2':col2,'col3':col3})

#create figure,axis,subplot
# One axes filling the figure, with a bottom margin reserved for the slider.
fig = plt.figure()
gs = gridspec.GridSpec(
    1, 1,
    hspace=0, wspace=0,
    left=0.1, bottom=0.1,
)
ax = plt.subplot(gs[0])
ax.set_ylim(0, 12)

#slider
frame = 0
axframe = plt.axes([0.13, 0.02, 0.75, 0.03])
sframe = Slider(
    axframe, 'frame', 0, 99,
    valinit=0,
    valfmt='%d',
)

#plots
# plot() returns a list of lines; unpack the single Line2D of each call.
(ln1,) = ax.plot(dfs[0].index, dfs[0]['col1'])
(ln2,) = ax.plot(dfs[0].index, dfs[0]['col2'], c='black')

##additional code to update the PolyCollections
# col2 levels highlighted by the red, blue and green bands respectively.
val_r = 5
val_b = 8
val_g = 7


def band_vertices(xs, ys, value, half_height=0.5):
    """Return rectangle vertices for each contiguous run where ``ys == value``.

    This mirrors ``fill_between(xs, ys - h, ys + h, where=ys == value)`` for
    data that is constant inside each matching run. Each run of consecutive
    matching samples yields one rectangle. The rectangle spans the run's x
    extent and ``value +/- half_height`` in y.

    Parameters
    ----------
    xs, ys : array-like
        Sample positions and sample values (same length).
    value : number
        Level to highlight.
    half_height : float, optional
        Half of the band's vertical extent (default 0.5).

    Returns
    -------
    list of numpy.ndarray
        One (4, 2) array per run, with corners ordered
        (x0, lo), (x1, lo), (x1, hi), (x0, hi). The list is empty when no
        sample equals *value*.
    """
    xs = np.asarray(xs)
    mask = np.asarray(ys) == value
    if not mask.any():
        return []
    # Pad with zeros so every run has a rising edge (its start index) and a
    # falling edge (one past its last index) in the diff.
    edges = np.flatnonzero(np.diff(np.concatenate(([0], mask.astype(int), [0]))))
    lo, hi = value - half_height, value + half_height
    verts = []
    for start, stop in zip(edges[::2], edges[1::2]):
        run = xs[start:stop]
        x0, x1 = run.min(), run.max()
        verts.append(np.array([[x0, lo], [x1, lo], [x1, hi], [x0, hi]]))
    return verts


def update_collection(collection, value, frame=0):
    """Replace *collection*'s polygons with the ``col2 == value`` bands of ``dfs[frame]``.

    Parameters
    ----------
    collection : PolyCollection
        Collection updated in place via ``set_verts``.
    value : number
        Level of ``col2`` to highlight.
    frame : int, optional
        Key into the module-level ``dfs`` dict (default 0).
    """
    df = dfs[frame]
    # An empty list simply leaves the collection with no polygons for frames
    # in which the level does not occur.
    collection.set_verts(band_vertices(df.index, df['col2'], value))

#artists
# The question's fill_between() calls are replaced by PolyCollections, whose
# vertices update_collection() can replace on every frame change.

##ax.fill_between(dfs[0].index,y1=dfs[0]['col2']-0.5,y2=dfs[0]['col2']+0.5,where=dfs[0]['col2']==5,facecolor='r',edgecolors='none',alpha=0.5)
reds = PolyCollection([], facecolors=['r'], edgecolors='none', alpha=0.5)
ax.add_collection(reds)
update_collection(reds, val_r)

##ax.fill_between(dfs[0].index,y1=dfs[0]['col2']-0.5,y2=dfs[0]['col2']+0.5,where=dfs[0]['col2']==8,facecolor='b',edgecolors='none',alpha=0.5)
blues = PolyCollection([], facecolors=['b'], edgecolors='none', alpha=0.5)
ax.add_collection(blues)
update_collection(blues, val_b)

##ax.fill_between(dfs[0].index,y1=dfs[0]['col2']-0.5,y2=dfs[0]['col2']+0.5,where=dfs[0]['col2']==7,facecolor='g',edgecolors='none',alpha=0.5)
greens = PolyCollection([], facecolors=['g'], edgecolors='none', alpha=0.5)
ax.add_collection(greens)
update_collection(greens, val_g)

# Keep the LineCollection returned by vlines() so its segments can be replaced
# when the frame changes (otherwise the lines would stay at frame 0).
vlines_coll = ax.vlines(x=dfs[0]['col3'].index, ymin=0, ymax=dfs[0]['col3'], color='black')


def update_vlines(frame=0):
    """Redraw the vertical lines from ``col3`` of ``dfs[frame]``.

    Each line runs from y=0 to the row's ``col3`` value, matching the
    ``ax.vlines(..., ymin=0, ...)`` call above.
    """
    df = dfs[frame]
    vlines_coll.set_segments(
        [[(x, 0), (x, y)] for x, y in zip(df.index, df['col3'])]
    )


#update plots
def update(val):
    """Slider callback: redraw every artist for the frame selected on ``sframe``.

    Parameters
    ----------
    val : float
        Value passed by ``Slider.on_changed``; ``sframe.val`` is read instead.
    """
    # Integer key into dfs; the slider reports a float.
    frame = int(np.floor(sframe.val))
    df = dfs[frame]
    ln1.set_ydata(df['col1'])
    ln2.set_ydata(df['col2'])
    ax.set_title('Frame ' + str(frame))

    ##updating the PolyCollections and the vertical lines:
    update_collection(reds, val_r, frame)
    update_collection(blues, val_b, frame)
    update_collection(greens, val_g, frame)
    update_vlines(frame)

    # Request a redraw once control returns to the GUI event loop, rather than
    # forcing a full draw from inside the widget callback.
    fig.canvas.draw_idle()


#connect callback to slider
sframe.on_changed(update)
plt.show()

Each of the three PolyCollections (reds, blues, and greens) has only four vertices (the corners of its rectangle). These vertices are determined from the given data, which is done in update_collection. The result looks like this:

example result of given code

Tested in Python 3.5

like image 73
Thomas Kühn Avatar answered Nov 14 '22 23:11

Thomas Kühn