 

Seaborn and matplotlib control legend in subplots

I have been experimenting with plt.legend(), ax.legend() and the legend that seaborn creates itself, and I think I'm missing something.

My first question: could someone please explain how these fit together and how they work? If I have subplots, which one takes precedence over which? For example, can I set a general definition (e.g. put the legend of every subplot in this location) and then override it for specific subplots (e.g. with ax.legend())?

My second question is practical and shows my problems. Let's use the seaborn tips dataset to illustrate them:

import seaborn as sns
import matplotlib.pyplot as plt
tips = sns.load_dataset("tips")

# define sizes for labels, ticks, text, ...
# as defined here https://stackoverflow.com/questions/3899980/how-to-change-the-font-size-on-a-matplotlib-plot
SMALL_SIZE = 10
MEDIUM_SIZE = 14
BIGGER_SIZE = 18

plt.rc('font', size=SMALL_SIZE)          # controls default text sizes
plt.rc('axes', titlesize=SMALL_SIZE)     # fontsize of the axes title
plt.rc('axes', labelsize=BIGGER_SIZE)    # fontsize of the x and y labels
plt.rc('xtick', labelsize=MEDIUM_SIZE)    # fontsize of the tick labels
plt.rc('ytick', labelsize=MEDIUM_SIZE)    # fontsize of the tick labels
plt.rc('legend', fontsize=SMALL_SIZE)    # legend fontsize
plt.rc('figure', titlesize=BIGGER_SIZE)  # fontsize of the figure title


# create figure
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(nrows=2, ncols=2, figsize=(16, 12))
ylim = (0, 1)

sns.boxplot(x='day', y='tip', hue="sex",
            data=tips, palette="Set2", ax=ax1)
sns.swarmplot(x='day', y='tip', hue="sex",
              data=tips, palette="Set2", ax=ax2)
ax2.legend(loc='upper right')

sns.boxplot(x='day', y='total_bill', hue="sex",
            data=tips, palette="Set2", ax=ax3)
sns.swarmplot(x='day', y='total_bill', hue="sex",
              data=tips, palette="Set2", ax=ax4)


plt.suptitle('Smokers')
plt.legend(loc='upper right')

plt.savefig('test.png', dpi=150)

Plot example.

If I use only seaborn, I get a legend like the ones in Subplots 1 and 3: it has the 'hue' label as its title and follows the font sizes I defined. However, I can't control its location; seaborn picks a default (compare Subplots 1 and 3). If I use ax.legend(), as in Subplot 2, I can modify that specific subplot, but I lose the seaborn 'hue' title (notice that "sex" disappears) and the legend no longer follows my font definitions. If I use plt.legend(), it only affects the subplot before it (Subplot 4 in this case).

How can I unify all of this, e.g. with one definition for all subplots, or by controlling the seaborn default? To state the goal clearly: how can I get a legend like the one in Subplot 1, where the labels come automatically from the data (but I can still change them), while the location, font size, etc. are the same for all subplots (e.g. upper right, font size 10, ...)?
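A minimal sketch of how the three legend calls relate, using plain matplotlib lines as a stand-in for seaborn's artists (the data and the "Male"/"Female" labels are made up, and the Agg backend is assumed so it runs headless). plt.legend() is pyplot's shortcut for plt.gca().legend(), so it only touches the current Axes; rcParams supply defaults for any legend created afterwards, and keywords passed to one ax.legend() call override those defaults for that Axes only.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt

# Global default for every legend created from now on (rcParams level).
plt.rc("legend", loc="upper right", fontsize=10)

fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(10, 4))
for ax in (ax1, ax2):
    # Stand-in data: any artist with a label can appear in a legend.
    ax.plot([0, 1], [0, 1], label="Male")
    ax.plot([0, 1], [1, 0], label="Female")

# plt.legend() is plt.gca().legend(): it only acts on the current Axes.
plt.sca(ax1)
leg1 = plt.legend()  # no keywords, so legend.loc and legend.fontsize apply

# Keywords given to ax.legend() override the rc defaults for this Axes only.
leg2 = ax2.legend(loc="lower left", fontsize=14)

fig.canvas.draw()  # lay out the legends so their positions can be inspected
```

In the question's code, plt.legend() therefore ends up on whichever Axes pyplot considers current, which there is Subplot 4.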

Thank you for your help and explanation.

My Work asked Sep 12 '25 19:09


1 Answer

Here seaborn always creates its legend with the keyword loc="best"; this is hard-coded in the source code. You could change the source code, e.g. in this line, and replace that call by a plain ax.legend(). Without an explicit loc, matplotlib falls back to the rc parameter legend.loc, so setting it in your code like

plt.rc('legend', loc="upper right")

would give the desired output.
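A minimal sketch of that rc route with plain matplotlib lines standing in for seaborn's output (made-up data; Agg backend assumed). plt.rc_context applies one legend definition to every legend created inside the with block, but only when the legend call does not pass its own loc, which is why seaborn's hard-coded loc="best" would have to be removed first for this to affect seaborn's legends.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt

# One shared definition for every legend created inside the block.
shared_legend_rc = {"legend.loc": "upper right", "legend.fontsize": 10}

with plt.rc_context(shared_legend_rc):
    fig, axes = plt.subplots(2, 2, figsize=(8, 6))
    for ax in axes.flat:
        ax.plot([0, 1], [0, 1], label="Male")
        ax.plot([0, 1], [0.2, 0.8], label="Female")
        # No loc or fontsize here, so the rc values above are used.
        ax.legend(title="sex")

fig.canvas.draw()  # lay out the legends so their positions can be inspected
```

The rc values are read when each legend is created, so the legends keep their upper-right placement even though the figure is drawn after the with block ends.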

The only other option is to create the legend manually, as you do in the second subplot, and pass the title and font size yourself:

ax2.legend(loc="upper right", title="sex", title_fontsize="x-large")
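A sketch of the manual route applied to every subplot at once. The helper restyle_legend is hypothetical (not part of seaborn or matplotlib): it rebuilds each legend from ax.get_legend_handles_labels(), so the labels still come from the plotted data, keeps an existing legend title, and applies one shared set of options. Zero-size labeled patches plus a loc="best" legend titled "sex" stand in for what seaborn draws; this is an assumption about seaborn's output, not its actual code.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle


def restyle_legend(ax, title=None, **legend_kwargs):
    """Hypothetical helper: rebuild the legend of `ax` from its labeled
    artists, applying one shared set of legend options."""
    handles, labels = ax.get_legend_handles_labels()
    if not handles:
        return None  # nothing labeled on this Axes, so no legend to draw
    old = ax.get_legend()
    if title is None and old is not None:
        # Reuse the existing title (the hue name, e.g. "sex").
        title = old.get_title().get_text() or None
    return ax.legend(handles, labels, title=title, **legend_kwargs)


fig, axes = plt.subplots(2, 2, figsize=(8, 6))
for ax in axes.flat:
    # Stand-in for seaborn's output: labeled proxy patches and a legend
    # placed with loc="best" and titled with the hue variable.
    for level, color in [("Male", "C0"), ("Female", "C1")]:
        ax.add_patch(Rectangle((0, 0), 0, 0, color=color, label=level))
    ax.legend(loc="best", title="sex")

# One definition, applied to every subplot.
shared = dict(loc="upper right", fontsize=10, title_fontsize="x-large")
legends = [restyle_legend(ax, **shared) for ax in axes.flat]
```

With seaborn, the same loop would run over ax1 to ax4 after the sns.boxplot/sns.swarmplot calls; whether seaborn's legend entries are retrievable through ax.get_legend_handles_labels() may depend on the seaborn version, so treat that as an assumption to check.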
ImportanceOfBeingErnest answered Sep 14 '25 12:09