pyplot: really slow creating heatmaps

I have a loop that executes its body about 200 times. In each iteration it does a sophisticated calculation, and then, for debugging, I want to produce a heatmap of an NxM matrix. But generating this heatmap is unbearably slow and significantly slows down an already slow algorithm.

My code is along these lines:

import numpy
import matplotlib.pyplot as plt

for i in range(200):
    matrix = complex_calculation()  # expensive step (not shown); returns a 300 x 600 NumPy array
    plt.set_cmap("gray")
    plt.imshow(matrix)
    plt.savefig("frame{0}.png".format(i))

The matrix is a NumPy array and is not huge: 300 x 600 doubles. If I update an on-screen plot instead of saving the figure, it is even slower.
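For scale, a quick check (assuming NumPy's default `float64` dtype for "doubles") shows how small that array is: 300 x 600 values at 8 bytes each is about 1.4 MB.

```python
import numpy as np

# A 300 x 600 array of doubles (float64, 8 bytes per element), the size described above.
matrix = np.zeros((300, 600))
print(matrix.dtype, matrix.size, matrix.nbytes)  # float64 180000 1440000
```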

Surely I must be abusing pyplot. (Matlab can do this, no problem.) How do I speed this up?

carl asked Jun 04 '10 05:06


1 Answer

Try putting plt.clf() in the loop to clear the current figure:

import matplotlib.pyplot as plt

for i in range(200):
    matrix = complex_calculation()
    plt.set_cmap("gray")
    plt.imshow(matrix)
    plt.savefig("frame{0}.png".format(i))
    plt.clf()  # clear the figure so images do not pile up across iterations

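Without it, each call to imshow adds another image to the same axes instead of replacing the previous one. The figure therefore holds more and more images: memory use keeps growing, and every savefig has to draw all of them, so each iteration is slower than the last.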
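The accumulation can be checked directly. Below is a minimal sketch, assuming the non-interactive Agg backend and a hypothetical `fake_calculation()` standing in for `complex_calculation()`; it passes `cmap="gray"` to `imshow` instead of calling `plt.set_cmap`. It counts how many images the axes hold after repeated `imshow` calls, with and without `plt.clf()`. The second function shows a different technique that is not part of the answer above: create the image once and swap in new pixels with `AxesImage.set_data`, so the figure is not rebuilt on every frame.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend: render to files, no window
import matplotlib.pyplot as plt
import numpy as np


def fake_calculation(rng):
    # Hypothetical stand-in for complex_calculation(): a 300 x 600 array of doubles.
    return rng.random((300, 600))


def count_images(n_frames, clear_each_frame):
    """Call imshow n_frames times; return how many images the axes hold on the last frame."""
    rng = np.random.default_rng(0)
    plt.figure()
    n_images = 0
    for _ in range(n_frames):
        plt.imshow(fake_calculation(rng), cmap="gray")
        n_images = len(plt.gca().images)  # images currently drawn on the axes
        if clear_each_frame:
            plt.clf()  # the fix from the answer: start the next frame on an empty figure
    plt.close("all")
    return n_images


def render_frames_reusing_image(n_frames):
    """Alternative technique: create one AxesImage and replace its data each frame."""
    rng = np.random.default_rng(0)
    fig, ax = plt.subplots()
    image = ax.imshow(fake_calculation(rng), cmap="gray")
    for i in range(n_frames):
        image.set_data(fake_calculation(rng))
        image.autoscale()  # recompute color limits in case the value range changed
        # fig.savefig("frame{0}.png".format(i)) would write each frame here
    n_images = len(ax.images)
    plt.close(fig)
    return n_images


print(count_images(5, clear_each_frame=False))  # 5: images pile up on one axes
print(count_images(5, clear_each_frame=True))   # 1: each frame starts clean
print(render_frames_reusing_image(5))           # 1: one image, data replaced in place
```

Note that `set_data` does not rescale the colormap by itself, which is why the sketch calls `autoscale()`; if every frame shares a known value range, setting `vmin`/`vmax` once in the initial `imshow` call is an alternative.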

unutbu answered Sep 29 '22 20:09