matplotlib sharex with colorbar not working

I have two subplots, a scatter plot and a bar plot, for which I would like a shared x axis. The scatter plot has a colorbar. sharex doesn't seem to work here, because the x axes of the two plots do not line up. My code:

import matplotlib.pyplot as plt

fig, (ax, ax2) = plt.subplots(2,1, gridspec_kw = {'height_ratios':[13,2]},figsize=(15,12), sharex=True)

df_plotdata.plot(kind='scatter', ax=ax, x='index_cancer', y='index_g', s=df_plotdata['freq1']*50, c=df_plotdata['freq2'], cmap=cmap)

df2.plot(ax=ax2, x='index_cancer', y='freq', kind = 'bar')
Preethi asked Oct 11 '17 18:10


2 Answers

sharex means that the axes limits are the same and that the axes are synchronized. It does not mean that the axes are aligned on top of each other in the figure; whether they are depends on how you create the colorbar.
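A minimal sketch of what sharex does and does not do, using plain Matplotlib (the Agg backend is an assumption so it runs without a display): setting the limits on one axes updates the other, while each axes keeps its own position in the figure.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend, no window needed
import matplotlib.pyplot as plt

fig, (ax, ax2) = plt.subplots(2, 1, sharex=True)

# Limits are synchronized: changing them on `ax` also changes them on `ax2`.
ax.set_xlim(0, 5)
print(ax2.get_xlim())

# Positions are independent: each axes has its own bounding box in the figure.
print(ax.get_position().bounds)
print(ax2.get_position().bounds)
```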

The colorbar created by the pandas scatter plot is, just like any standard colorbar in Matplotlib, created by taking away part of the space of the axes it belongs to. Hence the scatter axes ends up narrower than the other axes in the grid.
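To see the effect, the sketch below compares the width of the scatter axes before and after adding a colorbar. It uses plain Matplotlib with made-up random data standing in for the pandas scatter plot, assuming (as stated above) that pandas creates its colorbar the standard way.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend
import matplotlib.pyplot as plt
import numpy as np

fig, (ax, ax2) = plt.subplots(2, 1, sharex=True)
width_before = ax.get_position().width

# Made-up data standing in for the question's scatter plot.
rng = np.random.default_rng(0)
x = np.arange(10)
sc = ax.scatter(x, rng.random(10), c=rng.random(10), cmap="PuRd")

# A standard colorbar takes its space away from the axes passed as `ax`.
fig.colorbar(sc, ax=ax)

width_after = ax.get_position().width
print(width_before, width_after, ax2.get_position().width)
```

Only the scatter axes shrinks; the bar axes keeps its original width, which is exactly the misalignment in the question.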

Options you have include:

  • Shrink the other axes of the grid by the same amount as the scatter plot axes.
    This can be done by reading the position of the first axes with ax.get_position() and setting the position of the second axes accordingly with ax.set_position():

    import matplotlib.pyplot as plt
    import pandas as pd
    import numpy as np
    import itertools as it
    
    xy = list( it.product( range(10), range(10) ) )
    df = pd.DataFrame( xy, columns=['x','y'] )
    df['score'] = np.random.random( 100 )
    
    kw = {'height_ratios':[13,2]}
    fig, (ax,ax2) = plt.subplots(2,1,  gridspec_kw=kw, sharex=True)
    
    df.plot(kind='scatter', x='x',  y='y', c='score', s=100, cmap="PuRd",
              ax=ax, colorbar=True)
    df.groupby("x").mean().plot(kind = 'bar', y='score',ax=ax2, legend=False)
    
    ax2.legend(bbox_to_anchor=(1.03,0),loc=3)
    
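    # the colorbar has already shrunk `ax`; copy its horizontal extent
    # (x0 and width) onto `ax2`, keeping ax2's own vertical extent
    pos = ax.get_position()
    pos2 = ax2.get_position()
    ax2.set_position([pos.x0,pos2.y0,pos.width,pos2.height])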
    
    plt.show()
    


  • Create a grid including an axes for the colorbar.
    In this case you can create a 2 by 2 grid and place the colorbar in its upper right axes. This requires passing the scatter plot's collection to fig.colorbar() and specifying the axes the colorbar should live in (cax):

    fig.colorbar(ax.collections[0], cax=cax)       
    

    Then turn off the lower right axes, which is not needed (aux.axis("off")). If you still need the x axes to be shared, use ax2.sharex(ax) (Matplotlib 3.3+); the older ax2.get_shared_x_axes().join(ax, ax2) is deprecated since Matplotlib 3.6.

    import matplotlib.pyplot as plt
    import pandas as pd
    import numpy as np
    import itertools as it
    
    
    xy = list( it.product( range(10), range(10) ) )
    df = pd.DataFrame( xy, columns=['x','y'] )
    df['score'] = np.random.random( 100 )
    
    kw = {'height_ratios':[13,2], "width_ratios":[95,5]}
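    fig, ((ax, cax),(ax2,aux)) = plt.subplots(2,2,  gridspec_kw=kw)
    # share the x axis of the bar axes with the scatter axes
    # (Axes.sharex needs Matplotlib >= 3.3; it replaces the deprecated
    # ax2.get_shared_x_axes().join(ax, ax2))
    ax2.sharex(ax)
    
    df.plot(kind='scatter', x='x',  y='y', c='score', s=80, cmap="PuRd",
              ax=ax, colorbar=False)
    df.groupby("x").mean().plot(kind = 'bar', y='score',ax=ax2, legend=False)
    
    # draw the colorbar into the dedicated upper right axes
    fig.colorbar(ax.collections[0], cax=cax, label="score")
    aux.axis("off")  # hide the unused lower right axes
    ax2.legend(bbox_to_anchor=(1.03,0),loc=3)
    # hide the x tick labels of the top axes
    ax.tick_params(axis="x", labelbottom=0)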
    ax.tick_params(axis="x", labelbottom=0)
    ax.set_xlabel("")
    
    plt.show()
    

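Another possibility, which neither answer uses, is to pass both axes to fig.colorbar() as a list: Matplotlib then takes the colorbar's space from all listed axes by the same amount, so their left and right edges stay aligned. A minimal sketch with plain Matplotlib and made-up data (note that the colorbar then spans the height of both axes, which may or may not be wanted):

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend
import matplotlib.pyplot as plt
import numpy as np

fig, (ax, ax2) = plt.subplots(2, 1, sharex=True, gridspec_kw={"height_ratios": [13, 2]})

# Made-up data standing in for the scatter and bar plots.
rng = np.random.default_rng(0)
x = np.arange(10)
sc = ax.scatter(x, rng.random(10), c=rng.random(10), cmap="PuRd")
ax2.bar(x, rng.random(10))

# Passing a list of axes shrinks every listed axes by the same amount.
fig.colorbar(sc, ax=[ax, ax2])

p1, p2 = ax.get_position(), ax2.get_position()
print(p1.x0, p1.width)
print(p2.x0, p2.width)
```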

ImportanceOfBeingErnest answered Oct 29 '22 17:10


Based on ImportanceOfBeingErnest's answer, the following two functions align one axes with another:

def align_axis_x(ax, ax_target):
    """Make x-axis of `ax` aligned with `ax_target` in figure"""
    posn_old, posn_target = ax.get_position(), ax_target.get_position()
    ax.set_position([posn_target.x0, posn_old.y0, posn_target.width, posn_old.height])

def align_axis_y(ax, ax_target):
    """Make y-axis of `ax` aligned with `ax_target` in figure"""
    posn_old, posn_target = ax.get_position(), ax_target.get_position()
    ax.set_position([posn_old.x0, posn_target.y0, posn_old.width, posn_target.height])
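A short usage sketch for align_axis_x, with plain Matplotlib and made-up data standing in for the question's DataFrames (the function is copied from above so the block runs on its own). Call it after the colorbar has been added; layout calls such as tight_layout() may recompute the axes positions, so run those first.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend
import matplotlib.pyplot as plt
import numpy as np

def align_axis_x(ax, ax_target):
    """Make x-axis of `ax` aligned with `ax_target` in figure"""
    posn_old, posn_target = ax.get_position(), ax_target.get_position()
    ax.set_position([posn_target.x0, posn_old.y0, posn_target.width, posn_old.height])

fig, (ax, ax2) = plt.subplots(2, 1, sharex=True, gridspec_kw={"height_ratios": [13, 2]})

# Made-up data standing in for the question's scatter and bar plots.
rng = np.random.default_rng(1)
x = np.arange(10)
sc = ax.scatter(x, rng.random(10), c=rng.random(10), cmap="PuRd")
fig.colorbar(sc, ax=ax)   # shrinks `ax` only
ax2.bar(x, rng.random(10))

# Give the bar axes the same horizontal extent as the (shrunk) scatter axes.
align_axis_x(ax2, ax)
```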
leifdenby answered Oct 29 '22 18:10