How to make a scatter plot for clustering in Python

I am carrying out clustering and trying to plot the result. A dummy data set is:

data

import numpy as np

X = np.random.randn(10)
Y = np.random.randn(10)
Cluster = np.array([0, 1, 1, 1, 3, 2, 2, 3, 0, 2])    # Labels of cluster 0 to 3

cluster center

centers = np.random.randn(4, 2)    # 4 centers, each center is a 2D point

Question

I want to make a scatter plot to show the points in data and color the points based on the cluster labels.

Then I want to superimpose the center points on the same scatter plot, in another shape (e.g. 'X') and a fifth color (as there are 4 clusters).
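A minimal matplotlib sketch of both parts, assuming the X, Y, Cluster and centers arrays above and that the labels are exactly 0 to 3. It uses a ListedColormap with four colours so each label gets one fixed colour, and draws the centres with marker 'X' in a fifth colour. The helper name plot_clusters and the specific colour names are illustrative choices, not something from the question.

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; drop this line to get a window
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap

# Four cluster colours plus a fifth reserved for the centres (illustrative picks)
CLUSTER_COLORS = ["tab:blue", "tab:orange", "tab:green", "tab:purple"]
CENTER_COLOR = "red"


def plot_clusters(x, y, labels, centers, ax=None):
    """Hypothetical helper: colour points by label, overlay centres as 'X'."""
    if ax is None:
        _, ax = plt.subplots()
    n = len(CLUSTER_COLORS)
    cmap = ListedColormap(CLUSTER_COLORS)
    # vmin/vmax at half-integers put label k exactly in the k-th colour bin
    points = ax.scatter(x, y, c=labels, cmap=cmap, vmin=-0.5, vmax=n - 0.5, s=50)
    # Centres: a different marker and a colour not used by any cluster
    ax.scatter(centers[:, 0], centers[:, 1], c=CENTER_COLOR, marker="X", s=150)
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    return ax, points


X = np.random.randn(10)
Y = np.random.randn(10)
Cluster = np.array([0, 1, 1, 1, 3, 2, 2, 3, 0, 2])  # labels 0 to 3
centers = np.random.randn(4, 2)                     # 4 centres, each a 2D point

ax, points = plot_clusters(X, Y, Cluster, centers)
ax.figure.colorbar(points, ax=ax, ticks=range(4))   # one tick per cluster label
```

Call plt.show() (without the Agg line) or ax.figure.savefig(...) to see the result. Pinning vmin/vmax keeps the colour of each label stable even if some cluster happens to be empty in a given data set.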


Comment

  • I turned to seaborn 0.6.0 but found no API to accomplish the task.
  • ggplot by yhat could make the scatter plot look nice, but the second plot would replace the first one.
  • I got confused by the color and cmap parameters in matplotlib, so I wonder whether I could use seaborn or ggplot to do it.
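To make the color/cmap distinction concrete, here is a small sketch (assuming the same dummy labels, and "viridis" as an arbitrary colormap choice). With c, you pass numbers and matplotlib rescales them and looks them up in cmap; with color, you pass finished colours and no colormap is involved. Doing the lookup by hand in the second panel reproduces what the first panel does.

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # non-interactive backend for scripts/tests
import matplotlib.pyplot as plt

x = np.random.randn(10)
y = np.random.randn(10)
Cluster = np.array([0, 1, 1, 1, 3, 2, 2, 3, 0, 2])

fig, (ax1, ax2) = plt.subplots(1, 2, sharey=True)

# c + cmap: pass the labels as numbers; matplotlib rescales them to 0..1
# (by default from the data's min to max) and looks each one up in the colormap
mapped = ax1.scatter(x, y, c=Cluster, cmap="viridis")

# color: pass finished colours; no colormap or rescaling is involved.
# Doing the lookup by hand reproduces what the c + cmap call above does.
cmap = plt.get_cmap("viridis")
lo, hi = Cluster.min(), Cluster.max()
explicit = cmap((Cluster - lo) / (hi - lo))   # RGBA array, one row per point
manual = ax2.scatter(x, y, color=explicit)

# A single colour such as color="red" paints every point the same;
# a cmap passed alongside a single colour is ignored.
```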
asked Jun 30 '15 at 11:06 by Zelong


1 Answer

The first part of your question can be done by passing the Cluster array as the colours (the c argument of scatter) and adding a colorbar to show which colour belongs to which label. I only vaguely understood the second part of your question, but I believe this is what you are looking for.

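import numpy as np
import matplotlib.pyplot as plt

x = np.random.randn(10)
y = np.random.randn(10)
Cluster = np.array([0, 1, 1, 1, 3, 2, 2, 3, 0, 2])    # Labels of cluster 0 to 3
centers = np.random.randn(4, 2)                       # 4 centers, each a 2D point

fig = plt.figure()
ax = fig.add_subplot(111)
# Colour each point by its cluster label (mapped through the default colormap)
scatter = ax.scatter(x, y, c=Cluster, s=50)
# Draw all centres in one call, with a different marker and a fixed colour
ax.scatter(centers[:, 0], centers[:, 1], s=50, c='red', marker='+')
ax.set_xlabel('x')
ax.set_ylabel('y')
plt.colorbar(scatter)   # shows which colour corresponds to which label

plt.show()   # fig.show() can return immediately when run as a script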

which results in:

[Figure: scatter plot of the points coloured by cluster label, with the centres drawn as + markers and a colorbar]

where your "centres" are shown with the + marker. You can give the centres any colours you want in the same way as was done for x and y, by passing them through the c argument.
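For example, to give each centre the same colour as its cluster, one option is to pass c=np.arange(4) for the centres and share the colormap and colour limits with the points. This sketch assumes that row k of centers is the centre of cluster k (the question does not say so explicitly) and uses "viridis" as an arbitrary colormap.

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # non-interactive backend for scripts/tests
import matplotlib.pyplot as plt

x = np.random.randn(10)
y = np.random.randn(10)
Cluster = np.array([0, 1, 1, 1, 3, 2, 2, 3, 0, 2])
centers = np.random.randn(4, 2)   # assumption: row k is the centre of cluster k

fig, ax = plt.subplots()
# Share cmap and colour limits so label k gets the same colour in both calls
shared = dict(cmap="viridis", vmin=Cluster.min(), vmax=Cluster.max())
points = ax.scatter(x, y, c=Cluster, s=50, **shared)
centres = ax.scatter(centers[:, 0], centers[:, 1], c=np.arange(len(centers)),
                     marker="+", s=200, **shared)
fig.colorbar(points, ax=ax)
```

Without the shared vmin/vmax, each scatter call would rescale its own c values independently, so matching colours would only be a coincidence.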

answered Sep 23 '22 at 17:09 by Srivatsan