
Plotting correlation heatmaps with Seaborn FacetGrid

I am trying to create a single image with heatmaps showing the feature correlations for each label separately. With seaborn I can create a heatmap for a single class like so:

grouped = df.groupby('target')
sns.heatmap(grouped.get_group('Class_1').corr())

And I get this, which makes sense:

[Image: Class_1 correlation heatmap]

But then I try to plot all the labels at once like so:

g = sns.FacetGrid(df, col='target')
g.map(lambda grp: sns.heatmap(grp.corr()))

And sadly I get this, which makes no sense to me:

[Image: Failing attempt to plot all classes]

fakedrake asked Apr 12 '15 17:04


2 Answers

Turns out you can do it pretty concisely with just seaborn if you use map_dataframe instead of map:

g = sns.FacetGrid(df, col='target')
g.map_dataframe(lambda data, color: sns.heatmap(data.corr(), linewidths=0))
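For anyone who wants to try this without the original data, here is a minimal self-contained sketch. The DataFrame, its column names (`f1` to `f4`), and the helper `corr_heatmap` are made up for illustration. It assumes that `map_dataframe` hands each facet's rows to the function as the `data` keyword along with a `color` keyword, and it drops the non-numeric `target` column before calling `corr`, since recent pandas versions may refuse to correlate string columns.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch also runs headless

import numpy as np
import pandas as pd
import seaborn as sns

# Hypothetical stand-in for the question's data: three labels, four numeric features.
rng = np.random.default_rng(0)
df = pd.DataFrame(rng.normal(size=(150, 4)), columns=["f1", "f2", "f3", "f4"])
df["target"] = np.repeat(["Class_1", "Class_2", "Class_3"], 50)


def corr_heatmap(data, **kwargs):
    # `data` holds only the rows of the current facet; `color` arrives in
    # kwargs and is ignored because heatmap has no use for it.
    corr = data.select_dtypes("number").corr()
    sns.heatmap(corr, linewidths=0)  # draws on the facet's axes, which is current


g = sns.FacetGrid(df, col="target")
g.map_dataframe(corr_heatmap)
```

A named function is used here instead of the lambda only so there is room for comments; the lambda form above does the same job.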

@mwaskom points out in his comment that it might be a good idea to explicitly set the limits of the colormap so that the different facets can be more directly compared. The documentation describes relevant heatmap parameters:

vmin, vmax : floats, optional

Values to anchor the colormap, otherwise they are inferred from the data and other keyword arguments.
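As a rough sketch of that suggestion, the limits can be passed straight through `map_dataframe` as keyword arguments. The data and the helper name `corr_heatmap` are again invented for illustration, and fixing the limits at -1 and 1 is an assumption that suits correlation matrices; other data would need other limits.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch also runs headless

import numpy as np
import pandas as pd
import seaborn as sns

# Hypothetical data: three labels, four numeric features.
rng = np.random.default_rng(1)
df = pd.DataFrame(rng.normal(size=(150, 4)), columns=["f1", "f2", "f3", "f4"])
df["target"] = np.repeat(["Class_1", "Class_2", "Class_3"], 50)


def corr_heatmap(data, **kwargs):
    kwargs.pop("color", None)  # facet color is not a valid heatmap argument
    # Remaining kwargs (vmin, vmax, cmap, cbar) are forwarded to heatmap.
    sns.heatmap(data.select_dtypes("number").corr(), **kwargs)


g = sns.FacetGrid(df, col="target")
# Same color scale in every facet, so the panels are directly comparable.
g.map_dataframe(corr_heatmap, vmin=-1, vmax=1, cmap="coolwarm", cbar=False)
```

With identical limits in every panel, the per-facet colorbars carry no extra information, which is why `cbar=False` is used in this sketch.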

fakedrake answered Oct 12 '22 23:10


Without FacetGrid, but making a corr heatmap for each group in a column:

import pandas as pd
import seaborn as sns
from numpy.random import randint
import matplotlib.pyplot as plt


df = pd.DataFrame(randint(0,10,(200,12)),columns=list('abcdefghijkl'))
grouped = df.groupby('a')
rowlength = (grouped.ngroups + 1) // 2 # integer number of columns; rounds up for an odd number of groups
fig, axs = plt.subplots(figsize=(9,4), nrows=2, ncols=rowlength)

targets = zip(grouped.groups.keys(), axs.flatten())
for i, (key, ax) in enumerate(targets):
    sns.heatmap(grouped.get_group(key).corr(), ax=ax,
                xticklabels=(i >= rowlength),
                yticklabels=(i%rowlength==0),
                cbar=False) # Use cbar_ax into single side axis
    ax.set_title('a=%d'%key)
plt.show()

Maybe there's a way to set up a lambda to correctly pass the data from the g.facet_data() generator through corr before going to heatmap.

cphlewis answered Oct 13 '22 00:10