
Using non-linear scale with Seaborn heatmap

Tags: python, seaborn

I'm trying to use a logarithmic scale for the heatmap below. I need one color scheme for values between roughly 0 and 30, and a different color scheme for larger values, which could be errors.

Tried a few different approaches and am still thoroughly confused. Appreciate the help.

Cheers!

Here is the current script I'm using.

import pandas as pd
import seaborn as sns

# Read the CSV file. The r prefix makes the path a raw string, so the backslashes are not treated as escapes.
read_occupancy = pd.read_csv(r'C:\Users\holborm\Desktop\Visualisation\dataaxisplotstuff.csv')
df = pd.DataFrame(read_occupancy)    # redundant: read_csv already returns a DataFrame


#create time and detector name axis

sns.heatmap(df.set_index('Row Labels').T, cmap='magma', linecolor='white', linewidths=.05)
sns.clustermap(df.set_index('Row Labels').T, cmap='magma', linecolor='white', linewidths=.05)

Update based on Question/Answer

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from matplotlib.colors import LogNorm
def mix_palette():
    palette = sns.color_palette("GnBu", 10)
    palette[9] = sns.color_palette("OrRd", 10)[9]
    return palette


def set_ax(iax):
    for text in iax.texts:
        if float(text.get_text()) < 30:
            text.set_text("")
    iax.figure.tight_layout()


def load_data(path):
    initial = pd.read_csv(path, delim_whitespace=True)
    columns = list(initial.columns.values)[1:]
    rows = []
    for values in initial.values:
        rng = values[0]
        for column, value in zip(columns, values[1:]):
            rows.append([rng, column, value])
    return pd.DataFrame(data=rows, columns=['range', 'label', 'quantity'])
data = load_data('dataaxisplotstuff.csv')
data = data.pivot("range", "label", "quantity")
mi, ma = data.values.min(), data.values.max()
ax = sns.heatmap(data, cmap=mix_palette(), annot=True, square=True, cbar_kws={'ticks': ticker.LogLocator(numticks=8)},
                 xticklabels=True, yticklabels=True, norm=LogNorm(vmin=mi, vmax=ma))
set_ax(ax)
plt.show()

I receive this error:

TypeError                                 Traceback (most recent call last)
<ipython-input-5-7466da1cd6c9> in <module>()
      1 data = load_data('dataaxisplotstuff.csv')
      2 data = data.pivot("range", "label", "quantity")
----> 3 mi, ma = data.values.min(), data.values.max()
      4 ax = sns.heatmap(data, cmap=mix_palette(), annot=True, square=True, cbar_kws={'ticks': ticker.LogLocator(numticks=8)},
      5                  xticklabels=True, yticklabels=True, norm=LogNorm(vmin=mi, vmax=ma))

~\AppData\Local\Continuum\anaconda3\lib\site-packages\numpy\core\_methods.py in _amin(a, axis, out, keepdims)
     27 
     28 def _amin(a, axis=None, out=None, keepdims=False):
---> 29     return umr_minimum(a, axis, None, out, keepdims)
     30 
     31 def _sum(a, axis=None, dtype=None, out=None, keepdims=False):

TypeError: '<=' not supported between instances of 'float' and 'str'
LeCoda asked Aug 13 '18 06:08



1 Answer

I'll give it a try. As I understand it, you want a heatmap on a logarithmic scale, with one color scheme for normal values and a different color for outliers. For this I'm going to use pandas, seaborn and matplotlib. The versions are pandas 0.22.0, matplotlib 2.2.2 and seaborn 0.9.0. First, some functions:

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from matplotlib.colors import LogNorm


def mix_palette():
    palette = sns.color_palette("GnBu", 10)
    palette[9] = sns.color_palette("OrRd", 10)[9]
    return palette


def set_ax(iax):
    iax.collections[0].colorbar.set_ticklabels(['10', '30'])
    for text in iax.texts:
        if float(text.get_text()) < 30:
            text.set_text("")
    iax.figure.tight_layout()


def load_data(path):
    initial = pd.read_csv(path, delim_whitespace=True)
    columns = list(initial.columns.values)[1:]
    rows = []
    for values in initial.values:
        rng = values[0]
        for column, value in zip(columns, values[1:]):
            rows.append([rng, column, value])
    return pd.DataFrame(data=rows, columns=['range', 'label', 'quantity'])

The function mix_palette creates a mixed palette, set_ax makes some adjustments to the figure, and load_data takes the path to a CSV file like the one in the example (using whitespace as the delimiter). The output of load_data is a DataFrame in the same long format as the flights dataset from seaborn: one row per cell, in the form (row_name, column_name, value). Now the plotting code:
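As a rough illustration of the reshaping that load_data performs, the sketch below builds the same (range, label, quantity) columns with pandas' melt and then pivots them back into a matrix. The small DataFrame and its values are made up for the example and stand in for the real CSV; this is an alternative way to get the same columns, not the code used above, and the row order may differ from what load_data returns.

```python
import pandas as pd

# Hypothetical stand-in for the whitespace-delimited CSV (values are invented).
wide = pd.DataFrame({
    "range": ["0-10", "10-20"],
    "A": [1.0, 3.0],
    "B": [2.0, 40.0],
})

# melt yields one (range, label, quantity) row per cell of the wide table.
tidy = wide.melt(id_vars="range", var_name="label", value_name="quantity")

# Pivoting the long table gives the matrix layout that sns.heatmap expects.
matrix = tidy.pivot(index="range", columns="label", values="quantity")

print(tidy)
print(matrix)
```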

data = load_data('data.csv')
data = data.pivot(index="range", columns="label", values="quantity")  # keyword arguments also work on newer pandas
mi, ma = data.values.min(), data.values.max()
ax = sns.heatmap(data, cmap=mix_palette(), annot=True, square=True, cbar_kws={'ticks': [10, 30]},
                 xticklabels=True, yticklabels=True, norm=LogNorm(vmin=mi, vmax=ma))
set_ax(ax)
plt.savefig('image.png', bbox_inches='tight')
plt.show()

The output is:

[Figure: output heatmap]

This plots the values near or above 30 in red and annotates cells of 30 or more with their numerical value, which makes them easier to spot. In more detail:

  • mix_palette takes ten colors from the seaborn palette "GnBu" and replaces the last one with the last color of "OrRd".
  • The first line of set_ax sets the tick labels of the colorbar (the bar on the side) to 10 and 30. The loop then blanks the annotation of every cell whose value is below 30. Finally, tight_layout() tightens the layout, which helps when the axis labels are long and would otherwise be cut off.
  • The cmap argument receives the palette; annot=True shows the cell values; square=True makes the cells of the heatmap square; 'ticks': [10, 30] in cbar_kws sets the tick positions on the colorbar; and norm=LogNorm(vmin=mi, vmax=ma) applies the logarithmic scale.
  • To save the plot, use plt.savefig('image.png', bbox_inches='tight'). Make sure you call it before plt.show().
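To see how LogNorm and a ten-color palette interact, here is a minimal sketch using only matplotlib. The range vmin=1, vmax=100 and the helper palette_index are assumptions made for illustration (they are not part of the code above); the helper mimics how a ten-entry listed colormap picks a color from a normalized value.

```python
from matplotlib.colors import LogNorm

# Assumed data range for the illustration; the answer uses the data's own min and max.
norm = LogNorm(vmin=1, vmax=100)
n_colors = 10  # the mixed palette has ten entries


def palette_index(value):
    """Hypothetical helper: index of the palette entry a value maps to."""
    fraction = float(norm(value))  # 0.0 at vmin, 1.0 at vmax, linear in log10(value)
    return min(int(fraction * n_colors), n_colors - 1)


for v in (1, 10, 30, 70, 100):
    print(v, round(float(norm(v)), 3), palette_index(v))
```

Only values that land in the last index get the "OrRd" color, so where the red band starts depends on the minimum and maximum of the data rather than on the number 30 itself. If the cutoff must sit exactly at 30, the range passed to LogNorm would need to be chosen accordingly.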
Dani Mesejo answered Oct 05 '22 06:10
