How to plot multiple figures in a row using seaborn

I have a dataframe df that looks like this:

df.head()
id        feedback        nlp_model        similarity_score
0xijh4    1               tfidf            0.36
0sdnj7    -1              lda              0.89
kjh458    1               doc2vec          0.78
....

I want to plot similarity_score versus feedback as a boxplot using seaborn, separately for each unique value in the nlp_model column: tfidf, lda, doc2vec. My code for this is as follows:

import matplotlib.pyplot as plt
import seaborn as sns

fig, ax = plt.subplots(figsize=(10,8))
ax = sns.boxplot(x="feedback", y="similarity_score", data=df[df.nlp_model=='tfidf'])
ax = sns.swarmplot(x="feedback", y="similarity_score", data=df[df.nlp_model=='tfidf'], color="0.25")

fig, ax = plt.subplots(figsize=(10,8))
ax = sns.boxplot(x="feedback", y="similarity_score", data=df[df.nlp_model=='lda'])
ax = sns.swarmplot(x="feedback", y="similarity_score", data=df[df.nlp_model=='lda'], color="0.25")

fig, ax = plt.subplots(figsize=(10,8))
ax = sns.boxplot(x="feedback", y="similarity_score", data=df[df.nlp_model=='doc2vec'])
ax = sns.swarmplot(x="feedback", y="similarity_score", data=df[df.nlp_model=='doc2vec'], color="0.25")

plt.show()

The problem is that this creates three separate figures, one on top of the other.

How can I generate these same plots side by side in a single row, with the "Similarity Score" y-axis label on the left-most plot only and a "Feedback" x-axis label directly below each plot?

asked Jan 11 '18 10:01 by PyRsquared




1 Answer

You are creating a new figure each time you plot, so you can remove all but one of the calls to plt.subplots().
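To see why three stacked plots appear, here is a minimal sketch that counts the open figures after calling plt.subplots() three times. It assumes the non-interactive Agg backend so it runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend, so no window is opened
import matplotlib.pyplot as plt

plt.close("all")  # start from a clean state

# Each call to plt.subplots() creates a brand-new figure with its own axes
for _ in range(3):
    fig, ax = plt.subplots(figsize=(10, 8))

n_figures = len(plt.get_fignums())
print(n_figures)  # 3
```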

Seaborn's swarmplot() and boxplot() accept an ax argument, i.e. you can tell them which axes to plot to. Therefore, create your figure, subplots and axes using:

fig, (ax1, ax2, ax3) = plt.subplots(1, 3)
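The sketch below shows what plt.subplots(1, 3) returns. You get a single figure whose three axes sit side by side in one row, ordered left to right. It again assumes the Agg backend.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend
import matplotlib.pyplot as plt

plt.close("all")

# One row, three columns: a single figure holding three Axes
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))

# Left edge of each Axes in figure coordinates (0 = left, 1 = right)
lefts = [ax.get_position().x0 for ax in (ax1, ax2, ax3)]

print(len(plt.get_fignums()), len(fig.axes))  # 1 3
print(lefts == sorted(lefts))  # True: the axes run left to right
```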

Then you can do something like:

sns.boxplot(x="x_vals", y="y_vals", data=some_data, ax=ax1)
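Here is an illustration on made-up toy data. The values are hypothetical and only mirror the question's column names. Passing ax=ax1 draws the boxplot onto that axes only, and seaborn returns the same Axes object, so the other two panels stay empty.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Hypothetical toy data with the same column names as the question's df
some_data = pd.DataFrame({
    "feedback": [1, -1, 1, -1, 1, -1],
    "similarity_score": [0.36, 0.89, 0.42, 0.71, 0.55, 0.63],
})

fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))

# ax=ax1 tells seaborn which Axes to draw on; that same Axes is returned
returned = sns.boxplot(x="feedback", y="similarity_score", data=some_data, ax=ax1)

print(returned is ax1)  # True
print(repr(ax1.get_xlabel()), repr(ax2.get_xlabel()))  # 'feedback' ''
```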

You can then manipulate each set of axes as you see fit, for example by removing the y-axis label on all but the left-most subplot:

import matplotlib.pyplot as plt
import seaborn as sns

fig, (ax1, ax2, ax3) = plt.subplots(1,3,figsize=(10,8))

sns.boxplot(x="feedback", y="similarity_score", data=df[df.nlp_model=='tfidf'], ax=ax1)
sns.swarmplot(x="feedback", y="similarity_score", data=df[df.nlp_model=='tfidf'], color="0.25", ax=ax1)

sns.boxplot(x="feedback", y="similarity_score", data=df[df.nlp_model=='lda'], ax=ax2)
sns.swarmplot(x="feedback", y="similarity_score", data=df[df.nlp_model=='lda'], color="0.25", ax=ax2)

ax2.set_ylabel("")  # remove y label, but keep ticks

sns.boxplot(x="feedback", y="similarity_score", data=df[df.nlp_model=='doc2vec'], ax=ax3)
sns.swarmplot(x="feedback", y="similarity_score", data=df[df.nlp_model=='doc2vec'], color="0.25", ax=ax3)

ax3.set_ylabel("")  # remove y label, but keep ticks

plt.show()
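The three repeated blocks can also be written as a loop. This variant is not part of the original answer and rests on several assumptions:

- it uses hypothetical toy data in place of the real df;
- it passes sharey=True so all panels share one similarity scale;
- it sets the "Feedback" and "Similarity Score" labels the question asks for.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend; drop this and call plt.show() when working interactively
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Hypothetical toy data standing in for the question's df (same column names)
df = pd.DataFrame({
    "feedback": [1, -1, 1, -1] * 3,
    "nlp_model": ["tfidf"] * 4 + ["lda"] * 4 + ["doc2vec"] * 4,
    "similarity_score": [0.36, 0.89, 0.41, 0.70,
                         0.89, 0.22, 0.64, 0.31,
                         0.78, 0.55, 0.66, 0.12],
})

models = ["tfidf", "lda", "doc2vec"]

# One row of panels; sharey=True puts all of them on the same y scale
fig, axes = plt.subplots(1, len(models), figsize=(15, 5), sharey=True)

for ax, model in zip(axes, models):
    subset = df[df.nlp_model == model]
    sns.boxplot(x="feedback", y="similarity_score", data=subset, ax=ax)
    sns.swarmplot(x="feedback", y="similarity_score", data=subset, color="0.25", ax=ax)
    ax.set_title(model)
    ax.set_xlabel("Feedback")  # x-axis label below every panel
    ax.set_ylabel("")          # cleared here; only the left-most panel gets one

axes[0].set_ylabel("Similarity Score")
fig.tight_layout()
```

With sharey=True, matplotlib also hides the y tick labels on the inner panels. That differs slightly from the answer's code above, where the ticks are kept and only the label is cleared.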
answered Oct 12 '22 14:10 by DavidG