Pandas legend for scatter matrix

I have a pandas DataFrame with 3 classes and data points with n features.

The following code produces a scatter matrix of 4 of the features in the DataFrame, with histograms on the diagonal:

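import pandas as pd
import matplotlib.pyplot as plt

columns = ['n1','n2','n3','n4']
# dataframe and y_train (the class of each row) are defined elsewhere
grr = pd.plotting.scatter_matrix(
    dataframe[columns], c=y_train, figsize=(15,15), label=['B','N','O'], marker='.',
    hist_kwds={'bins':20}, s=10, alpha=.8, cmap='brg')
plt.legend()
plt.show()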

like this:

[Image: scatter matrix of this dataframe]

The problem I'm having is that plt.legend() doesn't seem to work: it shows no legend at all (or it's the tiny 'le8' barely visible in the first column of the second row...).

What I'd like to have is a single legend that just shows which color is which class.

I've tried the answers to all the suggested related questions, but none of them solves this. I also tried passing the labels to the legend function like this:

plt.legend(label=['B','N','O'], loc=1)

but to no avail.

What am I doing wrong?

asked May 05 '17 at 09:05 by NG.

1 Answer

The pandas scatter_matrix is a wrapper around several matplotlib scatter plots, and any extra arguments are passed on to the scatter function. However, scatter is usually meant to be used with a colormap rather than with a legend of discrete, labeled points, so there is no argument that creates such a legend automatically.
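To make this concrete, here is a minimal sketch on hypothetical toy data (three points, classes coded 0, 1, 2, off-screen Agg backend). A single `scatter` call produces a single collection, so a legend can show at most one entry for it, and the per-point colors exist only as colormap lookups of the normalized `c` values:

```python
import matplotlib
matplotlib.use("Agg")  # off-screen backend; nothing is displayed
import matplotlib.pyplot as plt
import numpy as np

# Hypothetical toy data: three points, one per class, classes coded 0, 1, 2
x = np.array([1.0, 2.0, 3.0])
y = np.array([3.0, 1.0, 2.0])
codes = np.array([0.0, 1.0, 2.0])

fig, ax = plt.subplots()
# One scatter call creates ONE PathCollection, however many classes c contains
sc = ax.scatter(x, y, c=codes, cmap="brg", label="all classes")

handles, labels = ax.get_legend_handles_labels()
print(len(handles), labels)  # 1 ['all classes']

# Colors come from normalizing c to [0, 1] (here 0, 0.5, 1)
# and looking the result up in the colormap
mapped = sc.to_rgba(codes)
print(np.allclose(mapped, plt.cm.brg(codes / 2.0)))  # True
plt.close(fig)
```

The `codes / 2.0` lookup is the same idea the `plt.cm.brg(i/2.)` handles in the code below rely on.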

I'm afraid you have to create the legend manually. To do so, you can recreate the scatter markers with matplotlib's plot function (using empty data) and pass them to the legend as handles.

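import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
# leave room on the right of the figure for the legend
plt.rcParams["figure.subplot.right"] = 0.8

# random example data: 4 features plus a class column "c" with values 1/3, 2/3, 1
v = np.random.rayleigh(size=(30, 5))
v[:, 4] = np.random.randint(1, 4, size=30) / 3.
dataframe = pd.DataFrame(v, columns=['n1', 'n2', 'n3', 'n4', "c"])

columns = ['n1', 'n2', 'n3', 'n4']
# pd.scatter_matrix was removed in pandas 1.0; pd.plotting.scatter_matrix replaces it
grr = pd.plotting.scatter_matrix(
    dataframe[columns], c=dataframe["c"], figsize=(7, 5), label=['B', 'N', 'O'], marker='.',
    hist_kwds={'bins': 20}, s=10, alpha=.8, cmap='brg')

# proxy handles: empty lines drawn with each class's marker and color;
# scatter's s is given in points**2, plot's markersize in points, hence np.sqrt(10)
handles = [plt.plot([], [], color=plt.cm.brg(i / 2.), ls="", marker=".",
                    markersize=np.sqrt(10))[0] for i in range(3)]
labels = ["Label A", "Label B", "Label C"]
plt.legend(handles, labels, loc=(1.02, 0))
plt.show()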
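If you have matplotlib 3.1 or later, `PathCollection.legend_elements()` can generate those proxy handles from the scatter itself, one per distinct `c` value, so the colors are guaranteed to match the plot. This is a substitute technique for the empty `plt.plot` calls above, not part of the original answer. The sketch assumes integer class codes 0, 1, 2 and uses the question's names B, N, O as labels; `fig.legend` attaches one legend to the whole figure instead of to the last subplot:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch; drop it and call plt.show() to see a window
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
df = pd.DataFrame(rng.rayleigh(size=(30, 4)), columns=["n1", "n2", "n3", "n4"])
classes = np.arange(30) % 3          # hypothetical class codes 0, 1, 2
class_names = ["B", "N", "O"]        # names from the question, in code order

axes = pd.plotting.scatter_matrix(df, c=classes, cmap="brg", marker=".", s=10,
                                  alpha=.8, figsize=(7, 5), hist_kwds={"bins": 20})

# Any off-diagonal panel holds the PathCollection created by scatter
collection = axes[0, 1].collections[0]
# One proxy handle per distinct value in c, colored through the same cmap/norm
handles, _ = collection.legend_elements(prop="colors")

fig = axes[0, 0].figure
fig.subplots_adjust(right=0.8)       # leave room for the legend
fig.legend(handles, class_names, loc="center right", title="class")
```

The handles come back in sorted order of the unique codes, so `class_names` must be listed in the same order as the codes 0, 1, 2.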


answered Oct 12 '22 at 18:10 by ImportanceOfBeingErnest