
How to plot heatmap for high-dimensional dataset?

I would greatly appreciate it if you could let me know how to plot a high-resolution heatmap for a large dataset with approximately 150 features.

My code is as follows:

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

XX = pd.read_csv('Financial Distress.csv')

y = np.array(XX['Financial Distress'].values.tolist())
y = np.array([0 if i > -0.50 else 1 for i in y])
# keep the feature columns (copy to avoid modifying a slice) and append the binary target
df = XX.iloc[:, 3:87].copy()
df["target_var"] = y.tolist()
target_var = ["target_var"]

fig, ax = plt.subplots(figsize=(8, 6))
correlation = df.select_dtypes(include=['float64',
                                             'int64']).iloc[:, 1:].corr()
sns.heatmap(correlation, ax=ax, vmax=1, square=True)
plt.xticks(rotation=90)
plt.yticks(rotation=360)
plt.title('Correlation matrix')
plt.tight_layout()
plt.show()
k = df.shape[1]  # number of variables for heatmap
fig, ax = plt.subplots(figsize=(9, 9))
corrmat = df.corr()
# Generate a mask for the upper triangle
mask = np.zeros_like(corrmat, dtype=bool)  # np.bool was removed from NumPy; use the builtin bool
mask[np.triu_indices_from(mask)] = True
cols = corrmat.nlargest(k, target_var)[target_var].index
cm = np.corrcoef(df[cols].values.T)
sns.set(font_scale=1.0)
hm = sns.heatmap(cm, mask=mask, cbar=True, annot=True,
                 square=True, fmt='.2f', annot_kws={'size': 7},
                 yticklabels=cols.values,
                 xticklabels=cols.values)
plt.xticks(rotation=90)
plt.yticks(rotation=360)
plt.title('Annotated heatmap matrix')
plt.tight_layout()
plt.show()

It works fine, but the heatmap plotted for a dataset with more than 40 features is too small.

Thanks in advance,

asked Jun 23 '18 03:06 by ebrahimi



1 Answer

Adjusting the figsize and dpi worked for me.

I adapted your code and doubled the size of the heatmap to 165 x 165. The rendering takes a while, but the PNG looks fine. My backend is "module://ipykernel.pylab.backend_inline".

As noted in my original answer, I'm pretty sure you forgot to close the figure object before creating a new one. Try plt.close("all") before fig, ax = plt.subplots() if you get weird effects.

import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

print(plt.get_backend())

# close any existing plots
plt.close("all")

df = pd.read_csv("Financial Distress.csv")
# select out the desired columns
df = df.iloc[:, 3:].select_dtypes(include=['float64','int64'])

# copy columns to double size of dataframe
df2 = df.copy()
df2.columns = "c_" + df2.columns
df3 = pd.concat([df, df2], axis=1)

# get the correlation coefficient between the different columns
corr = df3.iloc[:, 1:].corr()
# as_matrix() was removed in pandas 1.0; take a writable copy with to_numpy instead
arr_corr = corr.to_numpy(copy=True)
# mask out the top triangle
arr_corr[np.triu_indices_from(arr_corr)] = np.nan

fig, ax = plt.subplots(figsize=(24, 18))

hm = sns.heatmap(arr_corr, cbar=True, vmin=-0.5, vmax=0.5,
                 fmt='.2f', annot_kws={'size': 3}, annot=True, 
                 square=True, cmap=plt.cm.Blues)

ticks = np.arange(corr.shape[0]) + 0.5
ax.set_xticks(ticks)
ax.set_xticklabels(corr.columns, rotation=90, fontsize=8)
ax.set_yticks(ticks)
ax.set_yticklabels(corr.index, rotation=360, fontsize=8)

ax.set_title('correlation matrix')
plt.tight_layout()
plt.savefig("corr_matrix_incl_anno_double.png", dpi=300)

[Image: full figure]
[Image: zoom of the top-left section]
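As a rough sketch of why raising figsize and dpi helps, the arithmetic below estimates how many pixels each heatmap cell gets. The helper names heatmap_figsize and pixels_per_cell, the 0.15 inches-per-cell default, and the dpi of 100 for the smaller figure are illustrative assumptions, not part of seaborn or matplotlib. The estimate ignores margins, tick labels and the colorbar, so real cells come out somewhat smaller.

```python
def heatmap_figsize(n_features, inches_per_cell=0.15, min_side=6.0):
    """Return a square (width, height) in inches that grows with the feature count.

    inches_per_cell and min_side are hypothetical defaults; tune them to taste.
    """
    side = max(min_side, n_features * inches_per_cell)
    return (side, side)


def pixels_per_cell(figsize, dpi, n_features):
    """Upper bound on the side length of one cell, in pixels.

    With square=True the cells are limited by the shorter figure side.
    """
    width_in, height_in = figsize
    return min(width_in, height_in) * dpi / n_features


# Settings from the answer: 24 x 18 inches saved at 300 dpi, 165 x 165 cells.
large = pixels_per_cell((24, 18), dpi=300, n_features=165)  # roughly 33 px per cell

# A 9 x 9 inch figure at an assumed 100 dpi with about 150 features.
small = pixels_per_cell((9, 9), dpi=100, n_features=150)  # 6 px per cell

print(f"large figure: {large:.1f} px per cell")
print(f"small figure: {small:.1f} px per cell")
print("suggested figsize for 165 features:", heatmap_figsize(165))
```

A few pixels per cell leaves no room for a two-decimal annotation, whereas about 30 pixels does. The tuple returned by heatmap_figsize can be passed as figsize to plt.subplots, with the dpi given to plt.savefig as in the code above.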

answered Oct 13 '22 04:10 by Mark Teese