Logo Questions Linux Laravel Mysql Ubuntu Git Menu
 

How to label Y ticklabels as group/category in seaborn clustermap?

I want to make a clustermap/heatmap of gene presence-absence data from patients, where the genes are grouped into categories (e.g. chemotaxis, endotoxin, etc.) and labelled accordingly. I haven't found any such option in the seaborn documentation. I know how to generate the heatmap; I just don't know how to label the yticks as categories. Here is a sample (unrelated to my work) of what I want to achieve:

heatmap

Here, the yticklabels January, February and March are given the group label Winter, and the other yticklabels are grouped and labelled in the same way.

like image 572
Ahmed Abdullah Avatar asked Nov 14 '19 10:11

Ahmed Abdullah


People also ask

How do I add labels to Seaborn heatmap?

The heatmap can show the exact value behind the color. To add a label to each cell, the annot parameter of the heatmap() function should be set to True.

How do you annotate a heatmap in Seaborn?

To annotate each cell of a heatmap, set annot=True in the heatmap() call.

How does Seaborn Clustermap work?

The .clustermap() method uses hierarchical clustering to order data by similarity. This reorganizes the rows and columns so that similar content is displayed next to each other, which makes the structure of the data easier to understand.

What is cluster map in Seaborn?

The function clustermap() in seaborn draws a hierarchically clustered heatmap. A clustered heatmap differs from an ordinary heatmap in the following ways: the heatmap cells are all clustered using a similarity algorithm, and dendrograms are drawn for the columns and the rows of the heatmap.


2 Answers

I've reproduced the example you gave in seaborn, adapting @Stein's answer from here.

import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
from itertools import groupby
import datetime
import seaborn as sns

def test_table():
    """Build a 12x10 demo frame of random integers indexed by (month, season).

    The rows carry a two-level MultiIndex: level ``first`` holds the month
    name and level ``second`` the season that month is grouped under. The
    columns are the years 1950-1959, and every cell is an independent
    ``np.random.randint(0, 50)`` draw.
    """
    month_names = []
    for month_number in range(1, 13):
        month_names.append(datetime.date(2008, month_number, 1).strftime('%B'))

    # The run lengths add up to 12, so every month gets exactly one season.
    season_runs = (('Winter', 3), ('Spring', 2), ('Summer', 3), ('Pre-Winter', 4))
    season_names = [name for name, count in season_runs for _ in range(count)]

    row_index = pd.MultiIndex.from_tuples(
        list(zip(month_names, season_names)), names=['first', 'second'])

    # One list of 12 draws per year, filled year by year.
    columns = {}
    for year in range(1950, 1960):
        columns[year] = [np.random.randint(0,50) for _ in range(12)]

    return pd.DataFrame(columns, index=row_index)

def add_line(ax, xpos, ypos):
    """Draw a short black horizontal separator in axes coordinates.

    NOTE: the parameter names are swapped relative to how they are used.
    ``xpos`` is the *vertical* position of the line and ``ypos`` is the
    *horizontal* position of its left end; the line extends 0.2 axes widths
    to the right of ``ypos``. Clipping is disabled so the line stays visible
    outside the axes rectangle.
    """
    left_end = ypos
    right_end = ypos+ .2
    separator = plt.Line2D(
        [left_end, right_end],
        [xpos, xpos],
        color='black',
        transform=ax.transAxes,
    )
    separator.set_clip_on(False)
    ax.add_line(separator)

def label_len(my_index,level):
    """Run-length encode one level of an index.

    Returns a list of ``(label, count)`` pairs, one per run of consecutive
    equal labels found in ``my_index.get_level_values(level)``. Labels that
    reappear after a different label start a new run.
    """
    runs = []
    for value, members in groupby(my_index.get_level_values(level)):
        runs.append((value, len(list(members))))
    return runs

def label_group_bar_table(ax, df):
    """Draw grouped row labels, one column per index level, left of ``ax``.

    For every level of ``df.index`` the runs of equal labels (as returned by
    ``label_len``) are walked from the top of the axes downwards. A separator
    line is drawn above each run and once more below the last run, and the
    label text is placed at the vertical middle of its run. All positions
    are in axes coordinates; each further level is shifted another 0.2 axes
    widths to the left.
    """
    index = df.index
    n_rows = index.size
    scale = 1./n_rows  # height of one row in axes coordinates
    xpos = -.2

    for level in range(index.nlevels):
        remaining = n_rows  # rows still below the current separator
        for label, span in label_len(index, level):
            add_line(ax, remaining*scale, xpos)
            remaining -= span
            text_y = (remaining + .5 * span)*scale
            ax.text(xpos+.1, text_y, label, ha='center', transform=ax.transAxes)
        # Close off the last run of this level.
        add_line(ax, remaining*scale , xpos)
        xpos -= .2

df = test_table()

fig = plt.figure(figsize = (10, 10))
ax = fig.add_subplot(111)
# Draw on the axes created above explicitly instead of relying on whatever
# pyplot currently considers the "current" axes.
sns.heatmap(df, ax=ax)

# Blank out the default tick labels and the axis label; the grouped labels
# drawn by label_group_bar_table replace them.
labels = ['' for item in ax.get_yticklabels()]
ax.set_yticklabels(labels)
ax.set_ylabel('')

label_group_bar_table(ax, df)
# The group labels are drawn to the LEFT of the axes (0.2 axes widths per
# index level), so it is the left margin that has to grow with the number of
# levels, not the bottom one.
fig.subplots_adjust(left=.2*df.index.nlevels)
plt.show()

Gives:

Hope that helps.

like image 177
CDJB Avatar answered Nov 08 '22 04:11

CDJB


I haven't tested this with seaborn yet, but the following works with vanilla matplotlib.

enter image description here

#!/usr/bin/env python
"""
Annotate a group of y-tick labels as such.
"""

import matplotlib.pyplot as plt
from matplotlib.transforms import TransformedBbox

def annotate_yranges(groups, ax=None, dx=-2):
    """
    Annotate a group of consecutive yticklabels with a group name.

    For every group, the vertical centres of its first and last tick label
    are looked up and handed to `set_yrange_label`, which draws the bracket
    and the group label to the left of the tick labels.

    Arguments:
    ----------
    groups : dict
        Mapping from group label to an ordered list of group members.
        Each member must match the text of an existing y-tick label.
        Groups without members are skipped.
    ax : matplotlib.axes object (default None)
        The axis instance to annotate. Defaults to the current axes.
    dx : float (default -2)
        Horizontal offset (in data coordinates) of the group label from the
        left edge of the tick labels; forwarded to `set_yrange_label`.

    Raises:
    -------
    KeyError
        If the first or last member of a group is not the text of a
        y-tick label on `ax`.
    """
    if ax is None:
        ax = plt.gca()

    label2obj = {ticklabel.get_text() : ticklabel for ticklabel in ax.get_yticklabels()}

    for group, members in groups.items():
        if not members:
            # An empty group has no range to bracket.
            continue

        first_bbox = _get_text_object_bbox(label2obj[members[0]], ax)
        last_bbox = _get_text_object_bbox(label2obj[members[-1]], ax)

        set_yrange_label(group,
                         first_bbox.y0 + first_bbox.height/2,
                         last_bbox.y0 + last_bbox.height/2,
                         # anchor at the left-most edge of the two tick labels
                         min(first_bbox.x0, last_bbox.x0),
                         dx,
                         ax=ax)


def set_yrange_label(label, ymin, ymax, x, dx=-0.5, ax=None, *args, **kwargs):
    """
    Annotate a y-range.

    Two annotations carrying the same label text are placed at the same
    position (the vertical middle of the range, offset horizontally by dx);
    one has its connector ending at (x, ymin), the other at (x, ymax).

    Arguments:
    ----------
    label : string
        The label.
    ymin, ymax : float, float
        The y-range in data coordinates.
    x : float
        The x position of the annotation arrow endpoints in data coordinates.
    dx : float (default -0.5)
        The offset from x at which the label is placed.
    ax : matplotlib.axes object (default None)
        The axis instance to annotate. Defaults to the current axes.
    *args, **kwargs :
        Passed through to both `ax.annotate` calls.
    """

    # Compare against None explicitly so that an axes object passed by the
    # caller is always honoured, whatever its truth value.
    if ax is None:
        ax = plt.gca()

    dy = ymax - ymin
    label_position = (x + dx, ymin + dy/2)
    props = dict(connectionstyle='angle, angleA=90, angleB=180, rad=0',
                 arrowstyle='-',
                 shrinkA=10,
                 shrinkB=10,
                 lw=1)

    # One annotation per end of the range; both share the label position.
    for y_end in (ymin, ymax):
        ax.annotate(label,
                    *args,
                    xy=(x, y_end),
                    xytext=label_position,
                    annotation_clip=False,
                    arrowprops=props,
                    **kwargs,
        )


def _get_text_object_bbox(text_obj, ax):
    """
    Return the bounding box of a text object in the data coordinates of ax.

    Arguments:
    ----------
    text_obj : matplotlib.text.Text object
        The text whose extent is measured (e.g. a tick label).
    ax : matplotlib.axes object
        The axis whose data coordinate system the box is expressed in.
    """
    # https://stackoverflow.com/a/35419796/2912349
    # The text extent is only known once the figure has been rendered.
    # Render it off-screen instead of toggling interactive mode and showing
    # the window: that changed global pyplot state (a later plt.show() would
    # no longer block) and relied on the backend-specific `canvas.renderer`
    # attribute.
    ax.get_figure().canvas.draw()
    # After a draw the renderer is cached, so none has to be passed in.
    bb = text_obj.get_window_extent()
    # Invert the data transform only after drawing, so that it reflects the
    # layout the pixel extent above was measured in.
    return TransformedBbox(bb, ax.transData.inverted())


if __name__ == '__main__':
    # Demo: a random 10x10 image whose rows 'a'..'j' are bracketed into
    # three labelled groups.

    import numpy as np

    ticklabels = 'abcdefghij'
    groups = {
        'abc' : ('a', 'b', 'c'),
        'def' : ('d', 'e', 'f'),
        'ghij' : ('g', 'h', 'i', 'j')
    }

    fig, ax = plt.subplots(1,1)
    # Leave extra room on the left for the group annotations.
    fig.subplots_adjust(left=0.3)

    ax.imshow(np.random.rand(10,10))

    # One tick per row, labelled with the corresponding letter.
    tick_positions = np.arange(len(ticklabels))
    ax.set_yticks(tick_positions)
    ax.set_yticklabels(ticklabels)

    annotate_yranges(groups)

    plt.show()
like image 6
Paul Brodersen Avatar answered Nov 08 '22 03:11

Paul Brodersen