
Adding y=x to a matplotlib scatter plot if I haven't kept track of all the data points that went in

Here's some code that makes a scatter plot of several different series using matplotlib and then adds the line y=x:

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm

nseries = 10
colors = cm.rainbow(np.linspace(0, 1, nseries))

all_x = []
all_y = []
for i in range(nseries):
    x = np.random.random(12)+i/10.0
    y = np.random.random(12)+i/5.0
    plt.scatter(x, y, color=colors[i])
    all_x.extend(x)
    all_y.extend(y)

# Could I somehow do the next part (add identity_line) if I haven't been keeping track of all the x and y values I've seen?
identity_line = np.linspace(max(min(all_x), min(all_y)),
                            min(max(all_x), max(all_y)))
plt.plot(identity_line, identity_line, color="black", linestyle="dashed", linewidth=3.0)

plt.show()

In order to achieve this I've had to keep track of all the x and y values that went into the scatter plot so that I know where identity_line should start and end. Is there a way I can get y=x to show up even if I don't have a list of all the points that I plotted? I would think that something in matplotlib can give me a list of all the points after the fact, but I haven't been able to figure out how to get that list.
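Matplotlib does in fact keep those points. Each scatter call adds a PathCollection to the Axes (listed in ax.collections), and its get_offsets() method returns the plotted (x, y) pairs as an (N, 2) array. A minimal sketch of rebuilding the question's identity_line from them, assuming every collection on the Axes came from scatter (any other collection types would be swept in too); the helper name identity_from_offsets is hypothetical:

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # headless backend; drop this line for interactive use
import matplotlib.pyplot as plt


def identity_from_offsets(ax):
    """Hypothetical helper: (lo, hi) bounds for y=x, read back from ax.collections.

    Assumes every collection on `ax` was created by scatter().
    """
    # Each scatter() call stores its points as an (N, 2) array of offsets.
    pts = np.vstack([np.asarray(coll.get_offsets()) for coll in ax.collections])
    xs, ys = pts[:, 0], pts[:, 1]
    # Same bounds as the question's identity_line: the overlap of the x- and y-ranges.
    lo = max(xs.min(), ys.min())
    hi = min(xs.max(), ys.max())
    return lo, hi


np.random.seed(0)  # fixed seed so runs are reproducible
nseries = 10
fig, ax = plt.subplots()
for i in range(nseries):
    x = np.random.random(12) + i / 10.0
    y = np.random.random(12) + i / 5.0
    ax.scatter(x, y)  # no all_x / all_y bookkeeping

lo, hi = identity_from_offsets(ax)
identity_line = np.linspace(lo, hi)
ax.plot(identity_line, identity_line, color="black", linestyle="dashed", linewidth=3.0)
```

Since ax.plot adds a Line2D rather than a collection, the helper still sees only the scattered points if it is called again after the line is drawn.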

asked Aug 26 '14 03:08 by kuzzooroo



1 Answer

You don't need to keep track of the data yourself. The Axes object already knows the x- and y-limits that matplotlib autoscaled to fit everything you plotted, and those limits are all you need.

See below:

import numpy as np
import matplotlib.pyplot as plt

# random data
N = 37
x = np.random.normal(loc=3.5, scale=1.25, size=N)
y = np.random.normal(loc=3.4, scale=1.5, size=N)
c = x**2 + y**2

# now sort it just to make it look like it's related
x.sort()
y.sort()

fig, ax = plt.subplots()
ax.scatter(x, y, s=25, c=c, cmap=plt.cm.coolwarm, zorder=10)

Here's the good part:

lims = [
    np.min([ax.get_xlim(), ax.get_ylim()]),  # min of both axes
    np.max([ax.get_xlim(), ax.get_ylim()]),  # max of both axes
]

# now plot both limits against each other
ax.plot(lims, lims, 'k-', alpha=0.75, zorder=0)
ax.set_aspect('equal')
ax.set_xlim(lims)
ax.set_ylim(lims)
fig.savefig('/Users/paul/Desktop/so.png', dpi=300)
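The same limits-based approach can be wrapped in a small function and applied to the question's multi-series plot, so nothing like all_x / all_y has to be collected. A minimal sketch; the name add_identity is hypothetical and the headless Agg backend is only there so the example runs without a display:

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # headless backend; drop this line for interactive use
import matplotlib.pyplot as plt


def add_identity(ax, **line_kwargs):
    """Hypothetical helper: draw y=x spanning the Axes' current x- and y-limits."""
    xlim, ylim = ax.get_xlim(), ax.get_ylim()
    # Smallest and largest of all four limit values, as in the answer's lims.
    lims = [np.min([xlim, ylim]), np.max([xlim, ylim])]
    ax.plot(lims, lims, **line_kwargs)
    # Equal aspect plus identical limits keeps the line at 45 degrees.
    ax.set_aspect("equal")
    ax.set_xlim(lims)
    ax.set_ylim(lims)
    return lims


# The question's multi-series plot, without any point bookkeeping.
np.random.seed(0)  # fixed seed so runs are reproducible
nseries = 10
colors = plt.cm.rainbow(np.linspace(0, 1, nseries))
fig, ax = plt.subplots()
for i in range(nseries):
    x = np.random.random(12) + i / 10.0
    y = np.random.random(12) + i / 5.0
    ax.scatter(x, y, color=colors[i])

lims = add_identity(ax, color="black", linestyle="dashed", linewidth=3.0, zorder=0)
```

Reading the limits rather than the points also works for artists that are not collections; the trade-off is that the line then spans the small margin matplotlib adds around the data when autoscaling.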

Et voilà
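On matplotlib 3.3 or newer (check your installed version), Axes.axline draws an unbounded line through two points, so y=x can be added without reading either the data or the limits. A minimal sketch using the question's random series:

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # headless backend; drop this line for interactive use
import matplotlib.pyplot as plt

np.random.seed(0)  # fixed seed so runs are reproducible
fig, ax = plt.subplots()
for i in range(10):
    ax.scatter(np.random.random(12) + i / 10.0, np.random.random(12) + i / 5.0)

# axline passes through (0, 0) and (1, 1) and is extended to the edges of
# whatever view limits the Axes has when it is drawn.
line = ax.axline((0, 0), (1, 1), color="black", linestyle="dashed", linewidth=3.0)
ax.set_aspect("equal")
```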

answered Sep 19 '22 07:09 by Paul H