Plot two seaborn heatmap graphs side by side

I'm trying to plot two seaborn heatmaps side by side, as has been done successfully for other plot types in previous questions. The only difference I can see is that heatmaps seem to cause a problem. The code that reproduces the issue is:

import numpy as np; np.random.seed(0)
import matplotlib.pyplot as plt
import seaborn as sns

uniform_data = np.random.rand(10, 12)
uniform_data2 = np.random.rand(100, 120)

fig, ax = plt.subplots(1, 2)

ax = sns.heatmap(uniform_data)
ax = sns.heatmap(uniform_data2)

Which produces the below

[image]

asked Jun 20 '18 12:06 by Jake Bourne


2 Answers

You just have to pass each subplot's Axes to sns.heatmap through the ax parameter:

fig, (ax1, ax2) = plt.subplots(1,2)
sns.heatmap(uniform_data, ax=ax1)
sns.heatmap(uniform_data2, ax=ax2)
plt.show()
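
To see why the original code misbehaves, here is a minimal sketch. It assumes a headless Agg backend and passes cbar=False to keep the Axes count simple; both are choices made for this illustration, not part of the question. When ax is omitted, sns.heatmap draws on matplotlib's current Axes, which after plt.subplots(1, 2) is the last subplot created, so both heatmaps land on the right-hand panel and the left one stays empty:

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

rng = np.random.default_rng(0)
small = rng.random((10, 12))
large = rng.random((100, 120))

fig, axes = plt.subplots(1, 2)

# No ax= argument: seaborn falls back to the current Axes,
# which is the last subplot that plt.subplots created.
first = sns.heatmap(small, cbar=False)
second = sns.heatmap(large, cbar=False)

# Both calls return the Axes they drew on.
print(first is axes[1], second is axes[1])

# Each heatmap adds one mesh collection to its Axes.
print(len(axes[0].collections), len(axes[1].collections))
```

Reassigning ax to the return value, as the question does, also discards the array of subplots, which is why passing ax=ax1 and ax=ax2 explicitly fixes both problems.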
answered Nov 14 '22 00:11 by J. Doe


You have created an array of axes using fig, ax = plt.subplots(1,2). You are then overwriting that array with the result of sns.heatmap. Instead, you want to specify which axes you want to plot to using the ax= argument of sns.heatmap:

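import numpy as np; np.random.seed(0)
import matplotlib.pyplot as plt
import seaborn as sns

uniform_data = np.random.rand(10, 12)
uniform_data2 = np.random.rand(100, 120)

# ax is an array holding the two subplot Axes
fig, ax = plt.subplots(1, 2)

# draw each heatmap on its own Axes instead of overwriting ax
sns.heatmap(uniform_data, ax=ax[0])
sns.heatmap(uniform_data2, ax=ax[1])

plt.show()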

Which gives:

[image]
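
The same pattern extends to any number of heatmaps. A minimal sketch follows; heatmaps_side_by_side and its width_per_plot and height parameters are hypothetical names chosen for this example, and the Agg backend is assumed so it runs without a display:

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns


def heatmaps_side_by_side(datasets, width_per_plot=5, height=4):
    """Draw one heatmap per dataset in a single row (hypothetical helper)."""
    # squeeze=False keeps the result two-dimensional even for one dataset
    fig, axes = plt.subplots(
        1, len(datasets),
        figsize=(width_per_plot * len(datasets), height),
        squeeze=False,
    )
    for ax, data in zip(axes[0], datasets):
        sns.heatmap(data, ax=ax)  # each heatmap gets its own Axes and colorbar
    fig.tight_layout()
    return fig, axes[0]


rng = np.random.default_rng(0)
fig, axes = heatmaps_side_by_side([rng.random((10, 12)), rng.random((100, 120))])
```

Widening the figure with figsize gives each panel and its colorbar room; call plt.show() or fig.savefig(...) afterwards to view the result.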

answered Nov 14 '22 00:11 by DavidG