How to add column next to Seaborn heat map

Given the code below, which produces a heat map, how can I get column "D" (the total column) to display as a column to the right of the heat map with no color, just the total value aligned with each row? I'm also trying to move the column labels to the top. I don't mind that the labels on the left are horizontal, as this does not occur with my actual data.

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
# %matplotlib inline  (IPython/Jupyter magic; leave it out of a plain .py script)
df = pd.DataFrame(
      {'A' : ['A', 'A', 'B', 'B','C', 'C', 'D', 'D'],
       'B' : ['A', 'B', 'A', 'B','A', 'B', 'A', 'B'],
       'C' : [2, 4, 5, 2, 0, 3, 9, 1],
       'D' : [6, 6, 7, 7, 3, 3, 10, 10]})

df = df.pivot(index='A', columns='B', values='C')
fig, ax = plt.subplots(1, 1, figsize=(4, 6))

sns.heatmap(df, annot=True, linewidths=0, cbar=False, ax=ax)
plt.show()

Here's the desired result:

[image: desired result]

Thanks in advance!

asked Dec 15 '15 19:12 by Dance Party2





1 Answer

I think the cleanest way (although probably not the shortest) is to plot Total as one of the columns, then access the colors of the heatmap's cells and change the ones in that column to white.

The element responsible for color in a heatmap is a matplotlib.collections.QuadMesh. It holds the facecolor of every cell of the heatmap, one RGBA value per cell, stored row by row with the cells of each row going from left to right.
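To see how those colors are laid out, here is a minimal sketch using the question's data (the pivoted frame is called table here, a name chosen only for illustration). It draws the heatmap with a Total column, finds the QuadMesh, and reshapes its facecolors into a (rows, columns, RGBA) grid so the Total column can be picked out with grid[:, -1]. The reshape is a hedge: depending on the matplotlib version, the colors may come back either as one row per cell or already arranged as a 2-D grid.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch also runs headless
import matplotlib.pyplot as plt
from matplotlib.collections import QuadMesh
import pandas as pd
import seaborn as sns

df = pd.DataFrame(
    {'A': ['A', 'A', 'B', 'B', 'C', 'C', 'D', 'D'],
     'B': ['A', 'B', 'A', 'B', 'A', 'B', 'A', 'B'],
     'C': [2, 4, 5, 2, 0, 3, 9, 1]})
table = df.pivot(index='A', columns='B', values='C')  # 'table' is an illustrative name
table['Total'] = table['A'] + table['B']

fig, ax = plt.subplots(figsize=(4, 6))
sns.heatmap(table, annot=True, cbar=False, ax=ax)
fig.canvas.draw()  # draw once so the colormap has been applied to every cell

# All heatmap cells are drawn by a single QuadMesh on the axes.
quadmesh = ax.findobj(QuadMesh)[0]

# Colors may be (cells, 4) or (rows, cols, 4); reshaping gives a grid either way.
nrows, ncols = table.shape
grid = quadmesh.get_facecolors().reshape(nrows, ncols, 4)

# RGBA colors of the Total column, one per row.
total_colors = grid[:, -1]
print(total_colors.shape)
```

Indexing the grid by column avoids hard-coding flat positions such as 2, 5, 8, 11, which only hold for a table with exactly three columns.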

You can modify some of these colors and pass them back to the QuadMesh before calling plt.show().

One slight problem: seaborn changes the text color of some annotations to keep them readable on a dark background, and those annotations become invisible once their cell is turned white. For now I set the color of all text to black; you will need to decide what works best for your plots.
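If recoloring every Text object on the axes is too broad, one option is to recolor only the annotations that sit in the Total column. This sketch assumes that seaborn centres each annotation in its cell, so in data coordinates the Total column's annotations have x = ncols - 0.5; the name table is illustrative.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch also runs headless
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

df = pd.DataFrame(
    {'A': ['A', 'A', 'B', 'B', 'C', 'C', 'D', 'D'],
     'B': ['A', 'B', 'A', 'B', 'A', 'B', 'A', 'B'],
     'C': [2, 4, 5, 2, 0, 3, 9, 1]})
table = df.pivot(index='A', columns='B', values='C')  # illustrative name
table['Total'] = table['A'] + table['B']

fig, ax = plt.subplots(figsize=(4, 6))
sns.heatmap(table, annot=True, cbar=False, ax=ax)

# Annotations are centred in their cells, so the last column's
# annotations share the x coordinate ncols - 0.5.
total_x = table.shape[1] - 0.5
for text in ax.texts:
    if text.get_position()[0] == total_x:
        text.set_color('black')  # only the Total annotations change color
```

Tick labels and the other annotations keep the colors seaborn chose for them.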

Finally, to put x axis ticks and label on top, use:

ax.xaxis.tick_top()
ax.xaxis.set_label_position('top') 

The final version of the code:

import matplotlib.pyplot as plt
from matplotlib.collections import QuadMesh
from matplotlib.text import Text

import seaborn as sns
import pandas as pd
import numpy as np
# %matplotlib inline  (IPython/Jupyter magic; leave it out of a plain .py script)

df = pd.DataFrame(
      {'A' : ['A', 'A', 'B', 'B','C', 'C', 'D', 'D'],
       'B' : ['A', 'B', 'A', 'B','A', 'B', 'A', 'B'],
       'C' : [2, 4, 5, 2, 0, 3, 9, 1],
       'D' : [6, 6, 7, 7, 3, 3, 10, 10]})

df = df.pivot(index='A', columns='B', values='C')

# create "Total" column
df['Total'] = df['A'] + df['B']

fig, ax = plt.subplots(1, 1, figsize=(4, 6))

sns.heatmap(df, annot=True, linewidths=0, cbar=False, ax=ax)

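# find your QuadMesh object and get its colors, one RGBA row per cell
quadmesh = ax.findobj(QuadMesh)[0]
facecolors = quadmesh.get_facecolors().reshape(-1, 4)

# make colors of the last (Total) column white:
# every ncols-th cell, starting at index ncols - 1
ncols = df.shape[1]
facecolors[np.arange(ncols - 1, len(facecolors), ncols)] = np.array([1, 1, 1, 1])

# set modified colors
quadmesh.set_facecolors(facecolors)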

# set color of all text to black
for i in ax.findobj(Text):
    i.set_color('black')

# move x ticks and label to the top
ax.xaxis.tick_top()
ax.xaxis.set_label_position('top') 

plt.show()

[image: final figure]

P.S. I am on Python 2.7, so some syntax adjustments might be required, though I cannot think of any.
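On recent matplotlib versions, the colors of a colormapped QuadMesh may be recomputed from the data each time the figure is drawn, in which case edits made through set_facecolors may not persist. A different technique that avoids touching the QuadMesh is to hide the Total column with seaborn's mask parameter and write the totals yourself with ax.text. A minimal sketch, assuming the question's data; table and mask are illustrative names.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch; drop it in Jupyter
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

df = pd.DataFrame(
    {'A': ['A', 'A', 'B', 'B', 'C', 'C', 'D', 'D'],
     'B': ['A', 'B', 'A', 'B', 'A', 'B', 'A', 'B'],
     'C': [2, 4, 5, 2, 0, 3, 9, 1]})
table = df.pivot(index='A', columns='B', values='C')  # illustrative name
table['Total'] = table.sum(axis=1)  # same values as column "D" in the question

# Boolean mask with the table's shape: True means "do not color this cell".
mask = np.zeros(table.shape, dtype=bool)
mask[:, -1] = True  # leave the Total column blank

fig, ax = plt.subplots(figsize=(4, 6))
sns.heatmap(table, mask=mask, annot=True, cbar=False, ax=ax)

# Write each total centred in the blank last column (rows are 0.5 apart from edges).
ncols = table.shape[1]
for row, total in enumerate(table['Total']):
    ax.text(ncols - 0.5, row + 0.5, str(total),
            ha='center', va='center', color='black')

# Move the x tick labels and axis label to the top.
ax.xaxis.tick_top()
ax.xaxis.set_label_position('top')
```

Because the masked cells are never colormapped, the Total column stays white regardless of how matplotlib handles facecolors, and the annotation color problem does not arise for it.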

answered Nov 13 '22 08:11 by Sergey Antopolskiy
