Make seaborn show a colorbar instead of a legend when using hue in a bar plot?

Let's say I want to make a bar plot where the hue of the bars represents some continuous quantity, e.g.:

import seaborn as sns
titanic = sns.load_dataset("titanic")
g = titanic.groupby('pclass')
survival_rates = g['survived'].mean()
n = g.size()
ax = sns.barplot(x=n.index, y=n,
                 hue=survival_rates, palette='Reds',
                 dodge=False)
ax.set_ylabel('n passengers')

(Image: bar plot drawn by seaborn)

The legend here is kind of silly, and it gets even worse the more bars I plot. What would make most sense is a colorbar, like the one sns.heatmap draws. Is there a way to make seaborn do this?

Asked by Coquelicot, Apr 10 '18 at 19:04




2 Answers

The other answer is a bit hacky. A cleaner solution, which doesn't create a plot only to delete it afterwards, is to build a ScalarMappable manually and pass it to the colorbar.

import matplotlib.pyplot as plt
import seaborn as sns
titanic = sns.load_dataset("titanic")
g = titanic.groupby('pclass')
survival_rates = g['survived'].mean()
n = g.size()

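# Map the survival-rate range onto the 'Reds' colormap for the colorbar
norm = plt.Normalize(survival_rates.min(), survival_rates.max())
sm = plt.cm.ScalarMappable(cmap="Reds", norm=norm)
sm.set_array([])  # only required by Matplotlib < 3.1; harmless in later versions

ax = sns.barplot(x=n.index, y=n, hue=survival_rates, palette='Reds',
                 dodge=False)

ax.set_ylabel('n passengers')
ax.get_legend().remove()  # drop the per-value legend
# Pass ax= so the colorbar takes its space from the bar plot's Axes;
# newer Matplotlib versions warn or raise without it, because sm is not in any Axes
ax.figure.colorbar(sm, ax=ax)

plt.show()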
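
Depending on the seaborn version, the palette seaborn builds for hue may sample 'Reds' at evenly spaced positions rather than at norm(value), so the bar colors are not guaranteed to line up exactly with the colorbar. A minimal sketch of one way to make them agree, assuming it is acceptable to draw the bars with plain Matplotlib ax.bar instead of sns.barplot: every bar is colored with the same norm and colormap that feed the colorbar. The arrays classes, heights and rates are hypothetical stand-ins for n.index, n and survival_rates:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch also runs headless
import matplotlib.pyplot as plt
import numpy as np

# Hypothetical stand-ins for n.index, n and survival_rates
classes = np.array([1, 2, 3])
heights = np.array([10, 20, 30])
rates = np.array([0.2, 0.5, 0.9])

# One norm and one colormap drive both the bars and the colorbar
norm = plt.Normalize(vmin=rates.min(), vmax=rates.max())
cmap = plt.cm.Reds
bar_colors = cmap(norm(rates))  # (3, 4) array of RGBA colors

fig, ax = plt.subplots()
bars = ax.bar(classes, heights, color=bar_colors)
ax.set_xticks(classes)
ax.set_xlabel("pclass")
ax.set_ylabel("n passengers")

# The colorbar reuses exactly the same norm and cmap as the bars
sm = plt.cm.ScalarMappable(norm=norm, cmap=cmap)
cbar = fig.colorbar(sm, ax=ax, label="survival rate")
# With an interactive backend, call plt.show() here to display the figure
```

Because each bar gets cmap(norm(value)), the lowest value lands on the palest end of 'Reds', exactly where the colorbar places it.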

Answered by ImportanceOfBeingErnest, Oct 19 '22 at 10:10


You can try this:

import matplotlib.pyplot as plt
import seaborn as sns
titanic = sns.load_dataset("titanic")
g = titanic.groupby('pclass')
survival_rates = g['survived'].mean()
n = g.size()

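# Throwaway scatter: only used to get a mappable whose norm spans the data range
plot = plt.scatter(n.index, n, c=survival_rates, cmap='Reds')
plt.clf()  # wipe the scatter; `plot` keeps its cmap and norm
ax = plt.gca()  # fresh Axes for the bar plot
# ax= is needed in newer Matplotlib, since the scatter's Axes no longer exists
plt.colorbar(plot, ax=ax)
ax = sns.barplot(x=n.index, y=n, hue=survival_rates, palette='Reds', dodge=False, ax=ax)
ax.set_ylabel('n passengers')
ax.legend_.remove()  # drop the per-value legend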

Output: (image: bar plot with a colorbar)
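
The throwaway scatter works because passing c= makes Matplotlib autoscale the collection's norm to the minimum and maximum of the color values, so its colorbar spans the same range as the explicit plt.Normalize(survival_rates.min(), survival_rates.max()) used in the other answer. A small check of that, on hypothetical values standing in for survival_rates:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the check runs headless
import matplotlib.pyplot as plt
import numpy as np

# Hypothetical color values standing in for survival_rates
rates = np.array([0.2, 0.5, 0.9])

fig, ax = plt.subplots()
# Positions are irrelevant: the scatter only exists to produce a mappable
throwaway = ax.scatter([1, 2, 3], [0, 0, 0], c=rates, cmap="Reds")

# The norm that scatter set up from c= ...
auto_range = (throwaway.norm.vmin, throwaway.norm.vmax)
# ... and the one built explicitly in the ScalarMappable approach
explicit = plt.Normalize(rates.min(), rates.max())
explicit_range = (explicit.vmin, explicit.vmax)

print(auto_range == explicit_range)  # True
plt.close(fig)  # the figure itself is not needed
```

Since only the scatter's norm and colormap matter, the figure can be cleared right after the scatter is created.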

Answered by Scott Boston, Oct 19 '22 at 08:10