
Python: How to create a legend using an example

This is from Chapter 2 in the book Machine Learning In Action and I am trying to make the plot pictured here:

plot

The author has posted the plot's code here, which I believe may be a bit hacky (he also mentions this code is sloppy since it is out of the book's scope).

Here is my attempt to re-create the plot:

First, the .txt file holding the data is as follows (source: "datingTestSet2.txt" in Ch.2 here):

40920   8.326976    0.953952    largeDoses
14488   7.153469    1.673904    smallDoses
26052   1.441871    0.805124    didntLike
75136   13.147394   0.428964    didntLike
38344   1.669788    0.134296    didntLike
...

Assume datingDataMat is a numpy.ndarray of shape (1000, 3), where column 0 is "Frequent Flier Miles Per Year", column 1 is "% Time Playing Video Games", and column 2 is "liters of ice cream consumed per week", as shown in the sample above.

Assume datingLabels is a list of ints 1, 2, or 3 meaning "Did Not Like", "Liked in Small Doses", and "Liked in Large Doses" respectively - associated with column 3 above.

Here is the code I have to create the plot (full details for file2matrix are at the end):

import numpy as np
import matplotlib.pyplot as plt

datingDataMat, datingLabels = file2matrix("datingTestSet2.txt")

fig = plt.figure()
ax = fig.add_subplot(111)
plt.xlabel("Freq flier miles")
plt.ylabel("% time video games")
# Not sure how to finish this: plt.legend([1, 2, 3], ["did not like", "small doses", "large doses"])
# Scale both the marker size and the marker color by the class label
plt.scatter(datingDataMat[:,0], datingDataMat[:,1], 15.0*np.array(datingLabels), 15.0*np.array(datingLabels))
plt.show()

The output is here:

[image: the resulting scatter plot]

My main concern is how to create this legend. Is there a way to do this without needing a direct handle to the points?

Next, I am curious whether I can find a way to switch the colors to match those of the plot. Is there a way to do this without having some kind of "handle" on the individual points?

Also, if interested, here is the file2matrix implementation:

def file2matrix(filename):
    # Read the file once; the with-block closes it afterwards
    with open(filename) as fr:
        lines = fr.readlines()
    returnMat = np.zeros((len(lines), 3))  # one row per line, three feature columns
    classLabelVector = []
    for index, line in enumerate(lines):
        listFromLine = line.strip().split('\t')
        returnMat[index,:] = listFromLine[0:3] # FFmiles/yr, % time gaming, L ice cream/wk
        classLabelVector.append(int(listFromLine[-1]))  # last field is the class label
    return returnMat, classLabelVector
asked Nov 27 '25 12:11 by modulitos


1 Answer

Here's an example that mimics the code you already have and shows the approach described in Saullo Castro's example: plot each class with its own scatter call and give each call a label for the legend. It also shows how to set the colors. For more information on the available colors, see the documentation at http://matplotlib.org/api/colors_api.html

It would also be worth looking at the scatter plot documentation at http://matplotlib.org/1.3.1/api/pyplot_api.html#matplotlib.pyplot.scatter

from numpy.random import rand, randint
from matplotlib import pyplot as plt
n = 1000
# Generate random data
data = rand(n, 2)
# Make a random array to mimic datingLabels
labels = randint(1, 4, n)
# Separate the data according to the labels
data_1 = data[labels==1]
data_2 = data[labels==2]
data_3 = data[labels==3]
# Plot each set of points separately
# 's' is the size parameter.
# 'c' is the color parameter.
# I have chosen the colors so that they match the plot shown.
# With each set of points, input the desired label for the legend.
plt.scatter(data_1[:,0], data_1[:,1], s=15, c='r', label="label 1")
plt.scatter(data_2[:,0], data_2[:,1], s=30, c='g', label="label 2")
plt.scatter(data_3[:,0], data_3[:,1], s=45, c='b', label="label 3")
# Put labels on the axes
plt.ylabel("ylabel")
plt.xlabel("xlabel")
# Place the Legend in the plot.
plt.gca().legend(loc="upper left")
# Display it.
plt.show()
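
The three near-identical scatter calls above could also be written as a loop over the label values. The following is only a sketch of that idea, not the book's code: the random arrays stand in for datingDataMat and datingLabels, the STYLES table and the scatter_by_label helper are names made up for illustration, and the Agg backend is selected only so that the snippet runs without a display. The legend text comes from the question.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend, no window is opened
import matplotlib.pyplot as plt
import numpy as np

# Hypothetical style table: label value -> (legend text, color, marker size)
STYLES = {
    1: ("Did Not Like", "r", 15),
    2: ("Liked in Small Doses", "g", 30),
    3: ("Liked in Large Doses", "b", 45),
}

def scatter_by_label(ax, data, labels):
    """Draw one scatter call per label value so each class gets a legend entry."""
    labels = np.asarray(labels)
    for value, (text, color, size) in STYLES.items():
        mask = labels == value  # boolean mask selecting the rows of this class
        ax.scatter(data[mask, 0], data[mask, 1], s=size, c=color, label=text)
    ax.legend(loc="upper left")
    return ax

# Stand-in data with the same layout as datingDataMat / datingLabels
rng = np.random.default_rng(0)
data = rng.random((1000, 3))
labels = rng.integers(1, 4, size=1000).tolist()

fig, ax = plt.subplots()
scatter_by_label(ax, data, labels)
ax.set_xlabel("Freq flier miles")
ax.set_ylabel("% time video games")
```

With the real data, scatter_by_label(ax, datingDataMat, datingLabels) should work the same way, since the mask only needs the labels to be comparable to 1, 2, and 3.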

The gray borders should become white if you use plt.savefig to save the figure to file instead of displaying it. Remember to run plt.clf() or plt.cla() after saving to file to clear the axes so you don't end up replotting the same data on top of itself over and over again.
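
A rough sketch of that save-then-clear pattern follows. The save_scatter helper, the file names, and the temporary output directory are all hypothetical, and the Agg backend is chosen only so that the snippet runs without a display.

```python
import os
import tempfile

import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend
import matplotlib.pyplot as plt
import numpy as np

def save_scatter(points, path):
    """Plot the points, write the figure to path, then clear the figure."""
    plt.scatter(points[:, 0], points[:, 1], s=15, c="r", label="label 1")
    plt.legend(loc="upper left")
    plt.savefig(path)
    plt.clf()  # clear the current figure so the next plot starts empty

rng = np.random.default_rng(0)
out_dir = tempfile.mkdtemp()  # hypothetical output location
for i in range(2):
    save_scatter(rng.random((50, 2)), os.path.join(out_dir, f"plot_{i}.png"))
```

Without the plt.clf() call, the second file would contain both sets of points, because pyplot keeps drawing into the same current figure.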

answered Nov 29 '25 02:11 by IanH


