
Heatmap with multiple colormaps by column

I have a dataframe where each column contains values considered "normal" if they fall within an interval, which is different for every column:

import pandas as pd

# The main df
df = pd.DataFrame({"A": [20, 10, 7, 39], 
                   "B": [1, 8, 12, 9], 
                   "C": [780, 800, 1200, 250]})

df_info holds the interval for each column of df: df_info["A"][0] is the lower bound for df["A"], df_info["A"][1] is the upper bound, and so on for the other columns.

df_info = pd.DataFrame({"A": [22, 35],
                        "B": [5, 10],
                        "C": [850, 900]})

Thanks to this SO answer I was able to create a custom heatmap that shows values below the range in blue, values above the range in red, and values within the range in white. Since each column has a different range, I first normalized every column so that its own range maps to the interval 0-1:

df_norm = pd.DataFrame()
for col in df:
    col_min = df_info[col][0]
    col_max = df_info[col][1]
    df_norm[col] = (df[col] - col_min) / (col_max - col_min)
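
The same normalization can also be written without the loop. This is only a sketch, assuming df_info always has the lower bounds in row 0 and the upper bounds in row 1 and shares its column names with df; the names lower and upper are introduced here for illustration:

```python
import pandas as pd

df = pd.DataFrame({"A": [20, 10, 7, 39],
                   "B": [1, 8, 12, 9],
                   "C": [780, 800, 1200, 250]})
df_info = pd.DataFrame({"A": [22, 35],
                        "B": [5, 10],
                        "C": [850, 900]})

# Row 0 holds the lower bounds, row 1 the upper bounds (assumed layout)
lower = df_info.iloc[0]
upper = df_info.iloc[1]

# pandas aligns the Series on the column labels, so each column of df
# is shifted and scaled by its own bounds
df_norm = (df - lower) / (upper - lower)

# Values inside a column's interval land in [0, 1];
# values below it are negative, values above it exceed 1
print(df_norm.round(2))
```

For example, the value 8 in column B (interval 5 to 10) becomes 0.6, while 12 becomes 1.4.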

And finally printed my heatmap

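import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap

vmin = df_norm.min().min()
vmax = df_norm.max().max()

# Positions of the range limits (0 and 1) on the colormap's 0..1 scale
norm_zero = (0 - vmin) / (vmax - vmin)
norm_one = (1 - vmin) / (vmax - vmin)
colors = [[0, 'darkblue'],
          [norm_zero, 'white'],
          [norm_one, 'white'],
          [1, 'darkred']]
cmap = LinearSegmentedColormap.from_list('', colors)
fig, ax = plt.subplots()

# Plot the normalized values; keep the Axes returned by plt.subplots()
sns.heatmap(data=df_norm,
            annot=True,
            annot_kws={'size': 'large'},
            mask=None,
            cmap=cmap,
            vmin=vmin,
            vmax=vmax,
            ax=ax)
ax.set_facecolor('white')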


In the example you can see that the third column has values much higher/lower compared to the 0-1 interval (and to the first column), so they "absorb" all the shades of red and blue.
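
To make the effect concrete, here is a quick check of each column's normalized extremes. It is a sketch that rebuilds df_norm from the data above; the name extremes is introduced only for this illustration:

```python
import pandas as pd

df = pd.DataFrame({"A": [20, 10, 7, 39],
                   "B": [1, 8, 12, 9],
                   "C": [780, 800, 1200, 250]})
df_info = pd.DataFrame({"A": [22, 35],
                        "B": [5, 10],
                        "C": [850, 900]})
df_norm = (df - df_info.iloc[0]) / (df_info.iloc[1] - df_info.iloc[0])

# One row per statistic, one column per df column
extremes = df_norm.agg(["min", "max"])
print(extremes.round(2))

# Column C spans -12 to 7, while A only spans about -1.15 to 1.31,
# so the global vmin/vmax are set entirely by column C
print(df_norm.min().min(), df_norm.max().max())
```

With a single global vmin and vmax, the largest value of column A (about 1.31) sits just past the white band, which is why it looks almost white next to column C.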

QUESTION: What I want is to use the full range of red/blue shades for each column, or at least to reduce the perceptual difference between (for example) the first and third columns.

I had thought of:

  1. creating a custom colormap whose normalization is performed per column
  2. using multiple colormaps, each one applied to a different column
  3. applying a normalization such as mpl.colors.LogNorm, but I'm not sure how to use it with my custom LinearSegmentedColormap
asked Sep 05 '25 01:09 by Leonardo



1 Answer

Using one mask per column, you could draw the heatmap column by column, each with its own colormap:

import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.cm import ScalarMappable

df = pd.DataFrame({"A": [20, 10, 7, 39],
                   "B": [1, 8, 12, 9],
                   "C": [780, 800, 1200, 250]})
df_info = pd.DataFrame({"A": [22, 35],
                        "B": [5, 10],
                        "C": [850, 900]})
df_norm = pd.DataFrame()
for col in df:
    col_min = df_info[col][0]
    col_max = df_info[col][1]
    df_norm[col] = (df[col] - col_min) / (col_max - col_min)

fig, ax = plt.subplots()

for col in df:
    vmin = df_norm[col].min()
    vmax = df_norm[col].max()

    norm_zero = (0 - vmin) / (vmax - vmin)
    norm_one = (1 - vmin) / (vmax - vmin)
    colors = [[0, 'darkblue'],
              [norm_zero, 'white'],
              [norm_one, 'white'],
              [1, 'darkred']]
    cmap = LinearSegmentedColormap.from_list('', colors)
    mask = df.copy()
    for col_m in mask:
        mask[col_m] = col != col_m

    sns.heatmap(data=df_norm,
                annot=df.to_numpy(), annot_kws={'size': 'large'}, fmt="g",
                mask=mask,
                cmap=cmap, vmin=vmin, vmax=vmax, cbar=False, ax=ax)

ax.set_facecolor('white')

colors = [[0, 'darkblue'],
          [1 / 3, 'white'],
          [2 / 3, 'white'],
          [1, 'darkred']]
cmap = LinearSegmentedColormap.from_list('', colors)
cbar = plt.colorbar(ScalarMappable(cmap=cmap), ax=ax, ticks=[0, 1 / 3, 2 / 3, 1])
cbar.ax.yaxis.set_ticklabels(['min\nlimit', 'min', 'max', 'max\nlimit'])
plt.tight_layout()
plt.show()
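
The mask is what restricts each sns.heatmap call to a single column: seaborn skips every cell where the mask is True. A minimal sketch of what such a mask looks like for column "B" follows; it is built here in a slightly different but equivalent way, and the variable names are only for illustration:

```python
import pandas as pd

df = pd.DataFrame({"A": [20, 10, 7, 39],
                   "B": [1, 8, 12, 9],
                   "C": [780, 800, 1200, 250]})

col = "B"  # the column to draw in this pass

# Start with every cell hidden (True), then reveal the current column
mask = pd.DataFrame(True, index=df.index, columns=df.columns)
mask[col] = False

print(mask)
```

Only the cells of column B are False, so only that column is drawn in this pass; the other passes fill in A and C on the same Axes.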

heatmap in columns
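
An alternative that avoids the loop over heatmaps is to rescale each normalized column piecewise onto one common 0-1 scale, so a single fixed colormap works for all columns. This is only a sketch, not part of the approach above: the helper to_common_scale is hypothetical, and it assumes every column has at least one value below its interval and one above it (true for this data), so that the breakpoints passed to np.interp are increasing.

```python
import numpy as np
import pandas as pd

df = pd.DataFrame({"A": [20, 10, 7, 39],
                   "B": [1, 8, 12, 9],
                   "C": [780, 800, 1200, 250]})
df_info = pd.DataFrame({"A": [22, 35],
                        "B": [5, 10],
                        "C": [850, 900]})
df_norm = (df - df_info.iloc[0]) / (df_info.iloc[1] - df_info.iloc[0])


def to_common_scale(col):
    """Map one normalized column onto [0, 1] with a fixed white band.

    Hypothetical helper: column minimum -> 0, lower limit (0) -> 1/3,
    upper limit (1) -> 2/3, column maximum -> 1, linear in between.
    Assumes col.min() < 0 and col.max() > 1.
    """
    xp = [col.min(), 0.0, 1.0, col.max()]
    fp = [0.0, 1 / 3, 2 / 3, 1.0]
    return pd.Series(np.interp(col.to_numpy(), xp, fp), index=col.index)


# Apply the rescaling column by column
df_scaled = df_norm.apply(to_common_scale)
print(df_scaled.round(2))
```

Passing df_scaled to a single sns.heatmap call with vmin=0, vmax=1, annot=df.to_numpy() and the fixed colormap used for the colorbar above (white between 1/3 and 2/3) should then give every column the full range of blue and red shades.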

answered Sep 07 '25 17:09 by JohanC
