
Additional row colors in a seaborn clustermap

I am currently generating clustermaps in seaborn and labeling the row colors as below.

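import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# 50x4 matrix of 0/1 values and 50 labels in 0..5 (randint's upper bound is exclusive)
matrix = pd.DataFrame(np.random.randint(0, 2, size=(50, 4)))
labels = np.random.randint(0, 6, size=50)

# Map each distinct label to one palette color
lut = dict(zip(set(labels), sns.hls_palette(len(set(labels)), l=0.5, s=0.8)))
row_colors = pd.DataFrame(labels)[0].map(lut)

g = sns.clustermap(matrix, col_cluster=False, linewidths=0.1, cmap='coolwarm', row_colors=row_colors)
plt.show()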

I have a second annotation column, similar to the labels data, that I would also like to add to the plot. The seaborn API doesn't appear to support adding a second row_colors column, which is fine, but I am struggling to find a workaround using matplotlib to add this annotation column to the clustering.

If I cannot use seaborn for this and have to generate everything manually with matplotlib, that would be fine; I just can't figure that out either.

Thanks for your help!

[Image: clustermap example plot]

Anthony asked Jan 09 '18 17:01


2 Answers

The solution is below. The seaborn API does in fact allow this: pass a list of color sequences to row_colors, one per annotation column.

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# 50x4 matrix of 0/1 values (randint's upper bound is exclusive)
matrix = pd.DataFrame(np.random.randint(0, 2, size=(50, 4)))

# First annotation column: labels in 0..5, each mapped to one palette color
labels = np.random.randint(0, 6, size=50)
lut = dict(zip(set(labels), sns.hls_palette(len(set(labels)), l=0.5, s=0.8)))
row_colors = pd.DataFrame(labels)[0].map(lut)

# Create additional row_colors here
labels2 = np.random.randint(0, 2, size=50)
lut2 = dict(zip(set(labels2), sns.hls_palette(len(set(labels2)), l=0.5, s=0.8)))
row_colors2 = pd.DataFrame(labels2)[0].map(lut2)

# Pass both color sequences as a list to get two color strips
g = sns.clustermap(matrix, col_cluster=False, linewidths=0.1, cmap='coolwarm', row_colors=[row_colors, row_colors2])
plt.show()

This produces a clustermap with two additional row-color columns:

[Image: clustermap with two additional columns]
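If you have more than two annotation columns, repeating the lut/map lines gets tedious. Here is a minimal sketch of one way to factor that out, assuming a small helper (`to_row_colors` is a made-up name, not part of seaborn) and made-up annotation arrays. It uses `np.random.default_rng` for the random data and the Agg backend so it runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; drop this line for interactive use

import numpy as np
import pandas as pd
import seaborn as sns

rng = np.random.default_rng(0)
matrix = pd.DataFrame(rng.integers(0, 2, size=(50, 4)))

# Made-up annotation arrays, one per color strip.
annotations = [
    rng.integers(0, 6, size=50),
    rng.integers(0, 2, size=50),
    rng.integers(0, 3, size=50),
]


def to_row_colors(labels):
    """Hypothetical helper: map each distinct label to one hls palette color."""
    values = labels.tolist()
    levels = sorted(set(values))
    palette = sns.hls_palette(len(levels), l=0.5, s=0.8)
    lut = dict(zip(levels, palette))
    return [lut[value] for value in values]


# One list of colors per annotation, in the order the strips should appear.
row_colors = [to_row_colors(labels) for labels in annotations]

g = sns.clustermap(matrix, col_cluster=False, linewidths=0.1,
                   cmap="coolwarm", row_colors=row_colors)
```

Sorting the levels just makes the label-to-color pairing explicit. To see the result, call `plt.show()` in an interactive session or `g.savefig(...)` with a path of your choice.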

Anthony answered Dec 09 '22 03:12


I concatenated the row_colors Series into a single DataFrame with pandas and it worked! Please try this code:

import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd

sns.set(color_codes=True)

iris = sns.load_dataset("iris")
print(iris)
species = iris.pop("species")  # remove the label column, leaving only numeric data

# First annotation column: hex colors per species
lut1 = dict(zip(species.unique(), ['#ED2323', '#60FD00', '#808080']))
row_colors1 = species.map(lut1)

# Second annotation column: single-letter color codes per species
lut2 = dict(zip(species.unique(), "rbg"))
row_colors2 = species.map(lut2)

# Combine both Series into one DataFrame, one column per color strip
row_colors = pd.concat([row_colors1, row_colors2], axis=1)
print(row_colors)

g = sns.clustermap(iris, row_colors=row_colors, col_cluster=False, cmap="mako",
                   yticklabels=False, xticklabels=False)

plt.show()
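One detail of the concat approach: both mapped Series keep the name `species`, so the resulting DataFrame ends up with two columns of the same name. seaborn takes the labels for the color strips from the DataFrame's column names, so if you want distinct labels, a minimal sketch is to rename each Series before concatenating. The names `hex_colors` and `letter_colors` are made up, and the short Series below is only a stand-in for the iris `species` column.

```python
import pandas as pd

# Stand-in for iris.pop("species"); only a few rows are needed to show the idea.
species = pd.Series(
    ["setosa", "versicolor", "virginica", "setosa"], name="species"
)

lut1 = dict(zip(species.unique(), ["#ED2323", "#60FD00", "#808080"]))
lut2 = dict(zip(species.unique(), "rbg"))

# Rename each mapped Series so the DataFrame columns are distinct.
row_colors = pd.concat(
    [
        species.map(lut1).rename("hex_colors"),
        species.map(lut2).rename("letter_colors"),
    ],
    axis=1,
)

print(row_colors.columns.tolist())  # ['hex_colors', 'letter_colors']
```

The resulting `row_colors` can be passed to `sns.clustermap(..., row_colors=row_colors)` exactly as in the code above.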


Amy Annine answered Dec 09 '22 02:12
