I want to create multiple imshow plots via plt.subplots. The axes of each imshow should be labeled with strings, not numbers (the plots are correlation matrices showing correlations between categories).
I figured out from the documentation (very bottom) that plt.yticks() returns what I want, but I can't seem to set the labels on each subplot. ax.yticks(...) does not work either.
I found the docs about the ticker locators and formatters, but I am not sure whether or how they could be useful here. My current code:
import os
import numpy as np
import matplotlib.pyplot as plt

A = np.random.random((3,3))
B = np.random.random((3,3))+1
C = np.random.random((3,3))+2
D = np.random.random((3,3))+3
lbls = ['la', 'le', 'li']
fig, axar = plt.subplots(2,2)
fig.subplots_adjust(right=0.8)
cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])
ar_plts = [A, B, C, D]
for i,ax in enumerate(axar.flat):
    im = ax.imshow(ar_plts[i],
                   interpolation='nearest',
                   origin='lower')
    ax.grid(False)
    # this is the call that does not label the individual subplots
    plt.yticks(np.arange(len(lbls)), lbls)

fig.colorbar(im, cax=cbar_ax)
fig_path = r"blah/blub"  # placeholder output directory, must exist
fig_name = "matrices.png"
fig_fobj = os.path.join(fig_path, fig_name)
fig.savefig(fig_fobj)
plt.xticks can set tick positions and labels, but it acts only on the current Axes, so inside a plt.subplots loop it does not label each subplot. On an individual Axes, ax.set_xticks sets the tick positions (the numbers) and ax.set_xticklabels sets the text shown at those positions (the same for y).
This code worked for me:
import os
import numpy as np
import matplotlib.pyplot as plt

A = np.random.random((3,3))
B = np.random.random((3,3))+1
C = np.random.random((3,3))+2
D = np.random.random((3,3))+3
lbls = ['la', 'le', 'li']
fig, axar = plt.subplots(2,2)
fig.subplots_adjust(right=0.8)
cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])
ar_plts = [A, B, C, D]
for i,ax in enumerate(axar.flat):
    im = ax.imshow(ar_plts[i],
                   interpolation='nearest',
                   origin='lower')
    ax.grid(False)
    # fix the tick positions first, then attach the string labels to them
    ax.set_yticks([0,1,2])
    ax.set_xticks([0,1,2])
    ax.set_xticklabels(lbls)
    ax.set_yticklabels(lbls)

fig.colorbar(im, cax=cbar_ax)
fig_path = r"blah/blub"  # placeholder output directory, must exist
fig_name = "matrices.png"
fig_fobj = os.path.join(fig_path, fig_name)
fig.savefig(fig_fobj)
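If you want to confirm that every subplot really received the string labels, you can read them back from each Axes. The following is only a minimal sketch: the Agg backend and the names ticks, x_texts and y_texts are my own choices for illustration, not part of the code above.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so this runs without a display
import matplotlib.pyplot as plt
import numpy as np

lbls = ['la', 'le', 'li']
ticks = np.arange(len(lbls))  # one tick per category: 0, 1, 2

fig, axar = plt.subplots(2, 2)
for k, ax in enumerate(axar.flat):
    ax.imshow(np.random.random((3, 3)) + k,
              interpolation='nearest',
              origin='lower')
    # positions first, then the text shown at those positions
    ax.set_xticks(ticks)
    ax.set_xticklabels(lbls)
    ax.set_yticks(ticks)
    ax.set_yticklabels(lbls)

# Read the label text back from every subplot
x_texts = [[t.get_text() for t in ax.get_xticklabels()] for ax in axar.flat]
y_texts = [[t.get_text() for t in ax.get_yticklabels()] for ax in axar.flat]
print(x_texts)
print(y_texts)
```

Each inner list should match lbls, one list per subplot.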
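Be careful when using one colorbar for more than one plot. Each imshow call scales its colors to its own data, and the colorbar is built from im, which is the last image, so it shows the right values only for your last plot. If it should be correct for all plots, give every imshow the same color limits: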
im = ax.imshow(ar_plts[i],
               interpolation='nearest',
               origin='lower',
               vmin=0.0, vmax=1.0)
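Here I assumed that the smallest value in your data is 0.0 and the largest is 1.0. Use the actual minimum and maximum over all plotted arrays instead; the example arrays A to D above span roughly 0 to 4.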
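Rather than hard-coding the limits, you could compute them from the data. This is a rough sketch under my own assumptions: the seeded generator, the Agg backend and the names vmin, vmax and ims are illustrative and do not come from the question.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so this runs without a display
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)  # seeded only to make the sketch repeatable
ar_plts = [rng.random((3, 3)) + k for k in range(4)]  # stand-ins for A, B, C, D

# Global color limits over all arrays
vmin = min(a.min() for a in ar_plts)
vmax = max(a.max() for a in ar_plts)

fig, axar = plt.subplots(2, 2)
fig.subplots_adjust(right=0.8)
cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])

ims = []
for data, ax in zip(ar_plts, axar.flat):
    im = ax.imshow(data,
                   interpolation='nearest',
                   origin='lower',
                   vmin=vmin, vmax=vmax)  # same limits for every subplot
    ims.append(im)

# Every image shares the same limits, so one colorbar describes all of them
fig.colorbar(ims[-1], cax=cbar_ax)
```

With shared limits it no longer matters which image is passed to fig.colorbar.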