 

Is matplotlib's scatter plot slow for a large number of points?

I have a dataset with attributes x and y that can be plotted in the x-y plane.

Originally, I used this code:

import matplotlib.pyplot as plt

# df is a pandas DataFrame with columns 'x' and 'y'
df.plot(kind='scatter', x='x', y='y', alpha=0.10, s=2)
plt.gca().set_aspect('equal')

The code is pretty quick for a dataset of about 50,000 points.

Recently, I switched to a newer dataset with about 2,500,000 points, and the scatter plot became much slower.

Is this expected behavior, and is there anything I can do to improve the plotting speed?

asked Mar 07 '17 by cqcn1991

People also ask

Why does matplotlib take so long?

Re-drawing the figure is usually the bottleneck. Static elements such as the axes boundaries and tick labels don't need to be re-drawn on every update, and a figure with many subplots and many tick labels takes a long time to draw.
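As a rough illustration of skipping those re-draws, here is a minimal blitting sketch (my addition, assuming an interactive backend): the static background is drawn once and cached, and only the changing artist is re-drawn on each update.

import numpy as np
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
(line,) = ax.plot(np.random.rand(100), animated=True)
plt.show(block=False)
plt.pause(0.1)  # let the static parts (axes, ticks) render once
background = fig.canvas.copy_from_bbox(ax.bbox)

for _ in range(50):
    fig.canvas.restore_region(background)  # reuse the cached background
    line.set_ydata(np.random.rand(100))    # change only the data
    ax.draw_artist(line)                   # re-draw just this artist
    fig.canvas.blit(ax.bbox)               # copy the region to screen
    fig.canvas.flush_events()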

How many points can matplotlib handle?

As Jonathan Dursi's answer mentions, 20 million points is achievable with Matplotlib, but with some constraints (raster output, …).
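One such constraint in practice is rasterizing just the point cloud inside an otherwise vector figure. A minimal sketch (my illustration, with synthetic data):

import numpy as np
import matplotlib.pyplot as plt

x, y = np.random.rand(2, 2_000_000)  # synthetic data for illustration
fig, ax = plt.subplots()
# rasterized=True bakes only this artist into a bitmap when saving
# to a vector format, keeping the PDF small and quick to open.
ax.scatter(x, y, s=1, alpha=0.1, rasterized=True)
fig.savefig('scatter.pdf', dpi=150)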


2 Answers

Yes, it is. The reason is that a scatter plot of more than maybe a thousand points makes very little visual sense, so no one has bothered to optimise it. You will be better off using some other representation for your data:

  • A heatmap, if your points are distributed all over the place; make the heatmap cells pretty small (see the sketch after this list).
  • Some sort of curve that approximates the distribution, for example a regression of y on x. Be sure to provide confidence values or describe the distribution in some other way; for instance, building a box-and-whisker plot of y for every x (or range of x) and placing them on the same grid usually works pretty well for me.
  • A reduced dataset. @sascha suggests random sampling in the comments, and that's definitely a good idea (also shown in the sketch below); depending on your data, there may be a better way to choose representative points.
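Here is a minimal sketch of the heatmap and sampling suggestions (my code, not the answerer's; it assumes the asker's df with columns 'x' and 'y'):

import matplotlib.pyplot as plt

# Heatmap: bin the points into small cells instead of drawing one
# marker per point, so drawing cost no longer scales with N.
plt.hist2d(df['x'], df['y'], bins=500, cmap='hot')
plt.gca().set_aspect('equal')
plt.colorbar()
plt.show()

# Random sampling: plot a representative subset of the rows.
sample = df.sample(n=50_000, random_state=0)
sample.plot(kind='scatter', x='x', y='y', alpha=0.10, s=2)
plt.gca().set_aspect('equal')
plt.show()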
answered Oct 15 '22 by Synedraacus

I had the same problem with more than 300k 2D coordinates from a dimensionality reduction algorithm. The solution was to rasterize the coordinates into a 2D numpy array and visualize it as an image. The result was pretty good and also much faster:

import numpy as np
import matplotlib.pyplot as plt

def plot_to_buf(data, height=2800, width=2800, inc=0.3):
    # Accumulate the points into a (height+1, width+1) intensity buffer.
    xlims = (data[:, 0].min(), data[:, 0].max())
    ylims = (data[:, 1].min(), data[:, 1].max())
    dxl = xlims[1] - xlims[0]
    dyl = ylims[1] - ylims[0]

    print('xlims: (%f, %f)' % xlims)
    print('ylims: (%f, %f)' % ylims)

    buffer = np.zeros((height + 1, width + 1))
    for i, p in enumerate(data):
        print('\rloading: %03d%%' % (i / data.shape[0] * 100), end='')
        # Map the point to a pixel; flip y so larger y ends up on top.
        x0 = int(round(((p[0] - xlims[0]) / dxl) * width))
        y0 = int(round((1 - (p[1] - ylims[0]) / dyl) * height))
        # Each point adds `inc` of intensity, saturating at 1.0.
        buffer[y0, x0] = min(buffer[y0, x0] + inc, 1.0)
    return xlims, ylims, buffer

data = load_data()  # data.shape = (310216, 2) <<< your data here
xlims, ylims, I = plot_to_buf(data, height=2800, width=2800, inc=0.3)
ax_extent = list(xlims) + list(ylims)
plt.imshow(I,
           vmin=0,
           vmax=1,
           cmap=plt.get_cmap('hot'),
           interpolation='lanczos',
           aspect='auto',
           extent=ax_extent)
plt.grid(alpha=0.2)
plt.title('Latent space')
plt.colorbar()
plt.show()

[Resulting image: the 'Latent space' heatmap]

I hope this helps you.
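As an aside (my addition, not part of the original answer): for millions of points the per-point Python loop itself gets slow, and np.histogram2d can build an equivalent buffer in one vectorized call:

import numpy as np
import matplotlib.pyplot as plt

def plot_to_buf_vectorized(data, height=2800, width=2800, inc=0.3):
    # Bin all points at once instead of looping in Python.
    counts, xedges, yedges = np.histogram2d(
        data[:, 0], data[:, 1], bins=[width, height])
    # Transpose so rows index y; saturate at 1.0 like the loop above.
    buffer = np.minimum(counts.T * inc, 1.0)
    return (xedges[0], xedges[-1]), (yedges[0], yedges[-1]), buffer

xlims, ylims, I = plot_to_buf_vectorized(data)
plt.imshow(I, vmin=0, vmax=1, cmap='hot', origin='lower',
           extent=list(xlims) + list(ylims), aspect='auto')
plt.colorbar()
plt.show()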

answered Oct 15 '22 by Dmitry