I have a pandas dataframe with 3 classes and datapoints of n features.
The following code produces a scatter matrix with histograms in the diagonal, of 4 of the features in the dataframe.
columns = ['n1','n2','n3','n4']
grr = pd.plotting.scatter_matrix(
    dataframe[columns], c=y_train, figsize=(15,15), label=['B','N','O'], marker='.',
    hist_kwds={'bins':20}, s=10, alpha=.8, cmap='brg')
plt.legend()
plt.show()
like this:
The problem I'm having is that plt.legend() doesn't seem to work: it shows no legend at all (or it's the tiny 'le8' barely visible in the first column of the second row...).
What I'd like to have is a single legend that just shows which color is which class.
I've looked through all the suggested questions, but none of them has a solution. I also tried passing the labels as a parameter to the legend function, like this:
plt.legend(label=['B','N','O'], loc=1)
but to no avail.
What am I doing wrong?
We can add a legend to a scatterplot colored by a variable by using the legend() function in Matplotlib. In legend(), we specify the title and the handles, which we obtain by extracting the legend elements from the plot. Our first attempt to add a legend did not work well: the legend showed the colors but not the variable names.
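As a rough illustration of that approach on a single scatter plot, here is a minimal sketch. It assumes matplotlib 3.1 or newer, where the object returned by scatter() has a legend_elements() method. The data, the class codes 0 to 2, and the title "Class" are made up for the example; only the names 'B', 'N', 'O' come from the question.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch also runs without a display
import matplotlib.pyplot as plt
import numpy as np

# Hypothetical data: 30 points with integer class codes 0, 1, 2
rng = np.random.default_rng(0)
x = rng.normal(size=30)
y = rng.normal(size=30)
classes = rng.integers(0, 3, size=30)
classes[:3] = [0, 1, 2]  # make sure every class occurs at least once

fig, ax = plt.subplots()
sc = ax.scatter(x, y, c=classes, cmap="brg", s=10)

# legend_elements() builds one proxy handle per unique value of c,
# together with default (numeric) labels
handles, default_labels = sc.legend_elements()

# Replace the numeric labels with readable class names
names = ["B", "N", "O"]
legend = ax.legend(handles, names, title="Class", loc="upper right")
```

Passing our own names instead of default_labels is what gets the class names into the legend; with the defaults, only the numeric codes would show up next to the colors.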
The variable names are written along the diagonal, from top left to bottom right, and each variable is plotted against each of the others. For example, the middle square in the first column is an individual scatterplot of Girth and Height, with Girth on the X-axis and Height on the Y-axis.
The pandas scatter_matrix is a wrapper for several matplotlib scatter plots. Arguments are passed on to the scatter function. However, scatter is usually meant to be used with a colormap and not with a legend of discrete labeled points, so there is no argument available to create a legend automatically.
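To see the "wrapper" part for yourself, you can inspect what scatter_matrix returns. This is only a small sketch on made-up random data, assuming a recent pandas where the function lives at pd.plotting.scatter_matrix:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch also runs without a display
import numpy as np
import pandas as pd

# Hypothetical data with the four feature names from the question
rng = np.random.default_rng(0)
df = pd.DataFrame(rng.normal(size=(30, 4)), columns=["n1", "n2", "n3", "n4"])

# Extra keyword arguments (s, marker, ...) are forwarded to each scatter call
axes = pd.plotting.scatter_matrix(df, s=10, marker=".")

print(axes.shape)                    # a 2D grid of Axes, one per pair of columns
print(len(axes[1, 0].collections))   # off-diagonal cell: holds one scatter collection
print(len(axes[0, 0].patches) > 0)   # diagonal cell: histogram bars instead
```

Each off-diagonal cell is an ordinary matplotlib Axes holding a single scatter collection, which is why no per-class labels exist for plt.legend() to pick up.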
I'm afraid you have to create the legend manually. To this end, you can create the dots from the scatter using matplotlib's plot function (with empty data) and add them as handles to the legend.
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Leave room on the right of the figure for the legend
plt.rcParams["figure.subplot.right"] = 0.8

# Dummy data: four features plus a class column "c" with three distinct values
v = np.random.rayleigh(size=(30, 5))
v[:, 4] = np.random.randint(1, 4, size=30) / 3.
dataframe = pd.DataFrame(v, columns=['n1', 'n2', 'n3', 'n4', 'c'])

columns = ['n1', 'n2', 'n3', 'n4']
# In newer pandas versions, scatter_matrix lives in pd.plotting
grr = pd.plotting.scatter_matrix(
    dataframe[columns], c=dataframe["c"], figsize=(7, 5), label=['B', 'N', 'O'], marker='.',
    hist_kwds={'bins': 20}, s=10, alpha=.8, cmap='brg')

# Empty plots serve as legend handles: one per class, colored like the scatter points
handles = [plt.plot([], [], color=plt.cm.brg(i / 2.), ls="", marker=".",
                    markersize=np.sqrt(10))[0] for i in range(3)]
labels = ["Label A", "Label B", "Label C"]
plt.legend(handles, labels, loc=(1.02, 0))
plt.show()
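If you need this for a different number of classes, the same idea can be wrapped in a small helper. The following is just a sketch, not part of pandas or matplotlib: class_legend_handles is a hypothetical name, and it assumes the class codes passed to c are evenly spaced (e.g. 0, 1, 2), so that the default color normalization maps them to evenly spaced positions on the colormap.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch also runs without a display
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import numpy as np
import pandas as pd


def class_legend_handles(class_names, cmap_name="brg", markersize=np.sqrt(10)):
    """Hypothetical helper: one empty Line2D per class, colored from the colormap.

    Assumes evenly spaced class codes, so class k of n sits at k / (n - 1).
    """
    cmap = plt.get_cmap(cmap_name)
    n = len(class_names)
    positions = np.linspace(0, 1, n) if n > 1 else [0.0]
    return [Line2D([], [], color=cmap(p), ls="", marker=".",
                   markersize=markersize, label=name)
            for p, name in zip(positions, class_names)]


# Made-up data: four features and integer class codes 0, 1, 2
rng = np.random.default_rng(0)
df = pd.DataFrame(rng.rayleigh(size=(30, 4)), columns=["n1", "n2", "n3", "n4"])
codes = rng.integers(0, 3, size=30)
codes[:3] = [0, 1, 2]  # make sure every class occurs at least once

axes = pd.plotting.scatter_matrix(
    df, c=codes, cmap="brg", marker=".", s=10, alpha=.8,
    figsize=(7, 5), hist_kwds={"bins": 20})

# Attach a single legend to the figure rather than to one of the subplots
fig = axes[0, 0].figure
handles = class_legend_handles(["B", "N", "O"])
legend = fig.legend(handles=handles, loc="center right")
```

Using fig.legend() instead of plt.legend() places the legend relative to the whole figure, so it does not depend on which subplot happens to be the current one. If your class codes are not evenly spaced, the colors would have to be computed from the actual values instead.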