
label matplotlib imshow axes with strings

I want to create multiple imshows via plt.subplots. The axes of each imshow should be labeled with strings, not with numbers (these being correlation matrices representing correlations between categories).

From the documentation (at the very bottom), I figured out that plt.yticks() returns what I want, but I cannot seem to set the labels with it. ax.yticks(...) does not work either.

I also found the docs about the ticker locator and formatter, but I am not sure whether or how they could be useful here.

import os

import matplotlib.pyplot as plt
import numpy as np

A = np.random.random((3,3))
B = np.random.random((3,3))+1
C = np.random.random((3,3))+2
D = np.random.random((3,3))+3

lbls = ['la', 'le', 'li']

fig, axar = plt.subplots(2,2)
fig.subplots_adjust(right=0.8)
cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])   

ar_plts = [A, B, C, D]

for i,ax in enumerate(axar.flat):
    im = ax.imshow(ar_plts[i]
                    , interpolation='nearest'
                    , origin='lower')
    ax.grid(False)
    plt.yticks(np.arange(len(lbls)), lbls)

fig.colorbar(im, cax=cbar_ax)

fig_path = r"blah/blub"
fig_name = "matrices.png"
fig_fobj = os.path.join(fig_path, fig_name)
fig.savefig(fig_fobj)
Claus asked Feb 27 '15 09:02


1 Answer

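ax.set_xticks (and ax.set_yticks) sets the tick positions; to set the text shown at those positions you need ax.set_xticklabels (and ax.set_yticklabels). plt.xticks and plt.yticks can set both, but they act on the current axes only, not on the ax of each loop iteration, which is why the call inside your loop does not label every subplot. This code worked for me: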

import os

import matplotlib.pyplot as plt
import numpy as np

A = np.random.random((3,3))
B = np.random.random((3,3))+1
C = np.random.random((3,3))+2
D = np.random.random((3,3))+3

lbls = ['la', 'le', 'li']

fig, axar = plt.subplots(2,2)
fig.subplots_adjust(right=0.8)
cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])   

ar_plts = [A, B, C, D]

for i,ax in enumerate(axar.flat):
    im = ax.imshow(ar_plts[i]
                    , interpolation='nearest'
                    , origin='lower')
    ax.grid(False)
    ax.set_yticks([0,1,2])
    ax.set_xticks([0,1,2])

    ax.set_xticklabels(lbls)
    ax.set_yticklabels(lbls)

fig.colorbar(im, cax=cbar_ax)

fig_path = r"blah/blub"
fig_name = "matrices.png"
fig_fobj = os.path.join(fig_path, fig_name)
fig.savefig(fig_fobj)
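
For comparison, here is a minimal sketch of the pyplot route. It assumes you would rather keep plt.xticks / plt.yticks: since they apply to the current axes, each subplot is made current with plt.sca(ax) first. The Agg backend and the final list `labels_per_axes` are my additions, there only so the sketch runs without a display and the result can be inspected.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend so this runs without a display
import matplotlib.pyplot as plt
import numpy as np

lbls = ['la', 'le', 'li']
ticks = np.arange(len(lbls))

fig, axar = plt.subplots(2, 2)
for ax in axar.flat:
    ax.imshow(np.random.random((3, 3)), interpolation='nearest', origin='lower')
    plt.sca(ax)              # make this subplot the current axes
    plt.xticks(ticks, lbls)  # positions and labels now apply to `ax`
    plt.yticks(ticks, lbls)

fig.canvas.draw()  # make sure the tick label text is up to date before reading it

# Collect the y tick labels of every subplot to check that all four were set.
labels_per_axes = [[t.get_text() for t in ax.get_yticklabels()] for ax in axar.flat]
```

The object-oriented calls (ax.set_xticks, ax.set_xticklabels) are usually easier to follow in a loop over subplots, because nothing depends on which axes happens to be current.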

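Be careful with a single colorbar for more than one plot. It is built from im, which after the loop refers to the last image only, so it shows the right values only for your last plot. If it should be correct for all plots, every imshow call needs the same color limits: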

im = ax.imshow(ar_plts[i],
             interpolation='nearest',
             origin='lower',
             vmin=0.0,vmax=1.0)

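Here I assumed that the smallest value in your data is 0.0 and the largest is 1.0. For the example arrays above, which span roughly 0 to 4, you would use vmin=0.0 and vmax=4.0 instead.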
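
The shared limits can also be derived from the data instead of being hard-coded. The following is only a sketch under that assumption: the names `vmin`, `vmax` and `images` are mine, the arrays mimic the ones in the question, and the Agg backend is selected just so it runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend so this runs without a display
import matplotlib.pyplot as plt
import numpy as np

# Four matrices with offsets 0..3, like A, B, C, D in the question.
ar_plts = [np.random.random((3, 3)) + k for k in range(4)]

# One pair of color limits covering all matrices.
vmin = min(a.min() for a in ar_plts)
vmax = max(a.max() for a in ar_plts)

fig, axar = plt.subplots(2, 2)
images = []
for ax, data in zip(axar.flat, ar_plts):
    im = ax.imshow(data, interpolation='nearest', origin='lower',
                   vmin=vmin, vmax=vmax)
    images.append(im)

# Every image uses the same limits, so any of them can drive the colorbar.
fig.colorbar(images[-1], ax=list(axar.flat))
```

Because all four images share the same limits, the one colorbar is valid for every subplot regardless of which image it is attached to.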

plonser answered Oct 14 '22 05:10