
Speed up NumPy Meshgrid Command

I am generating a meshgrid with NumPy, and it takes a lot of memory and quite a bit of time.

xi, yi = np.meshgrid(xi, yi)

I am generating the meshgrid at the same resolution as the underlying sitemap image, which is sometimes 3000 px on a side. It sometimes uses several gigabytes of memory and takes 10-15 seconds or more, especially once it starts writing to the page file.
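For scale, here is a rough sketch (not part of the original question) of how much memory the meshgrid output itself holds at 3000 x 3000 pixels. It compares the default dense output with NumPy's sparse=True option, which returns arrays that broadcast to the same grid. The helper name meshgrid_bytes is made up for illustration.

```python
import numpy as np


def meshgrid_bytes(width, height):
    """Bytes held by np.meshgrid output for a width x height pixel grid, dense vs. sparse."""
    xi = np.linspace(0, width, width)
    yi = np.linspace(0, height, height)

    # Default (dense): two full float64 arrays of shape (height, width)
    X, Y = np.meshgrid(xi, yi)
    # sparse=True: shapes (1, width) and (height, 1), which broadcast to the same grid
    Xs, Ys = np.meshgrid(xi, yi, sparse=True)

    return X.nbytes + Y.nbytes, Xs.nbytes + Ys.nbytes


dense, sparse = meshgrid_bytes(3000, 3000)
print(dense, sparse)  # 144000000 48000
```

So the dense grid is about 144 MB at that size, well short of several gigabytes, which hints that most of the memory and time goes into the later steps (interpolation and plotting). Whether mlab.griddata accepts the sparse form has not been checked here.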

My question is: can I speed this up without upgrading the server? Here is the full source code of the function in question:

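import math
import tempfile
import time

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.pyplot import savefig, clf
from matplotlib.mlab import griddata  # removed in Matplotlib 3.1
from PIL import Image
from flask import make_response, send_file
# db_session, Sites, SiteMaps, slugify and parseDate come from the application and are not shown

def generateContours(date_collected, substance_name, well_arr, site_id, sitemap_id, image, title_wildcard='', label_over_well=False, crop_contours=False, groundwater_contours=False, flow_lines=False, site_image_alpha=1, status_token=""):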
    #create empty arrays to fill up!
    x_values = []
    y_values = []
    z_values = []

    #iterate over wells and fill the arrays with well data
    for well in well_arr:
        x_values.append(well['xpos'])
        y_values.append(well['ypos'])
        z_values.append(well['value'])

    #initialize numpy array as required for interpolation functions
    x = np.array(x_values, dtype=float)
    y = np.array(y_values, dtype=float)
    z = np.array(z_values, dtype=float)

    #create a list of x, y coordinate tuples (list() keeps it a list on Python 3)
    points = list(zip(x, y))

    #create a grid on which to interpolate data
    start_time = time.time()
    xi, yi = np.linspace(0, image['width'], image['width']), np.linspace(0, image['height'], image['height'])

    xi, yi = np.meshgrid(xi, yi)

    #interpolate the data with the matplotlib.mlab griddata function (http://matplotlib.org/api/mlab_api.html#matplotlib.mlab.griddata)
    zi = griddata(x, y, z, xi, yi, interp='nn')

    #create a matplotlib figure and adjust the width and heights to output contours to a resolution very close to the original sitemap
    fig = plt.figure(figsize=(image['width']/72, image['height']/72))

    #create a single subplot, just takes over the whole figure if only one is specified
    ax = fig.add_subplot(111, frameon=False, xticks=[], yticks=[])

    #read the database image and save to a temporary variable
    im = Image.open(image['tmpfile'])

    #place the sitemap image on top of the figure
    ax.imshow(im, origin='upper', alpha=site_image_alpha)

    #figure out a good linewidth
    if image['width'] > 2000:
        linewidth = 3
    else:
        linewidth = 2

    #create the contours (options here http://cl.ly/2X0c311V2y01)
    kwargs = {}
    if groundwater_contours:
        kwargs['colors'] = 'b'

    CS = plt.contour(xi, yi, zi, linewidths=linewidth, **kwargs)
    for key, value in enumerate(CS.levels):
        if value == 0:
            CS.collections[key].remove()

    #add a streamplot
    if flow_lines:
        dy, dx = np.gradient(zi)
        plt.streamplot(xi, yi, dx, dy, color='c', density=1, arrowsize=3, arrowstyle='<-')

    #add labels to well locations
    label_kwargs = {}
    if label_over_well is True:
        label_kwargs['manual'] = points

    plt.clabel(CS, CS.levels[1::1], inline=5, fontsize=math.floor(image['width']/100), fmt="%.1f", **label_kwargs)

    #add scatterplot to show where well data was read
    scatter_size = math.floor(image['width']/20)
    plt.scatter(x, y, s=scatter_size, c='k', facecolors='none', marker=(5, 1))

    try:
        site_name = db_session.query(Sites).filter_by(site_id=site_id).first().title
    except Exception:
        site_name = "Site Map #%i" % site_id

    sitemap = SiteMaps.query.get(sitemap_id)
    if sitemap.title != 'Sitemap':
        sitemap_wildcard = " - " + sitemap.title
    else:
        sitemap_wildcard = ""

    if title_wildcard != '':
        filename_wildcard = "-" + slugify(title_wildcard)
        title_wildcard = " - " + title_wildcard
    else:
        filename_wildcard = ""
        title_wildcard = ""

    #add descriptive title to the top of the contours
    title_font_size = math.floor(image['width']/72)
    plt.title(parseDate(date_collected) + " - " + site_name + " " + substance_name + " Contour" + sitemap_wildcard + title_wildcard, fontsize=title_font_size)

    #generate a unique filename and save to a temp directory
    filename = slugify(site_name) + str(int(time.time())) + filename_wildcard + ".pdf"
    temp_dir = tempfile.gettempdir()
    tempFileObj = temp_dir + "/" + filename
    savefig(tempFileObj)  # bbox_inches='tight' tightens the white border

    #clears the matplotlib memory
    clf()

    #send the temporary file to the user
    resp = make_response(send_file(tempFileObj, mimetype='application/pdf', as_attachment=True, attachment_filename=filename))

    #set the users status token for javascript workaround to check if file is done being generated
    resp.set_cookie('status_token', status_token)

    return resp

Asked Aug 13 '13 at 04:08 by Nick Woodhams



1 Answer

If meshgrid is what's slowing you down, don't call it. According to the griddata docs:

xi and yi must describe a regular grid, can be either 1D or 2D, but must be monotonically increasing.

So your call to griddata should work just the same if you skip the call to meshgrid and do:

xi = np.linspace(0, image['width'], image['width'])
yi = np.linspace(0, image['height'], image['height'])
zi = griddata(x, y, z, xi, yi, interp='nn')
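The rest of the question's function also passes xi and yi to plt.contour and plt.streamplot, and both of those accept 1D coordinate vectors too, because the dense meshgrid output only repeats the 1D vectors. A minimal sketch of that, assuming a made-up smooth surface in place of the interpolated zi (the surface, the sizes and the function name are all hypothetical):

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # headless backend so no display is needed
import matplotlib.pyplot as plt


def contour_without_meshgrid(width=300, height=200):
    xi = np.linspace(0, width, width)
    yi = np.linspace(0, height, height)

    # The dense meshgrid output only repeats the 1D vectors:
    # every row of X equals xi and every column of Y equals yi
    X, Y = np.meshgrid(xi, yi)
    rows_match = bool(np.all(X == xi))
    cols_match = bool(np.all(Y == yi[:, np.newaxis]))

    # Hypothetical smooth surface standing in for the interpolated zi, shape (height, width)
    zi = np.sin(xi[np.newaxis, :] / 50.0) + np.cos(yi[:, np.newaxis] / 40.0)

    fig, ax = plt.subplots()
    # contour accepts 1D x, y when len(x) == zi.shape[1] and len(y) == zi.shape[0]
    cs = ax.contour(xi, yi, zi)
    # streamplot accepts evenly spaced, strictly increasing 1D x and y as well
    dy, dx = np.gradient(zi)
    ax.streamplot(xi, yi, dx, dy, density=1)
    plt.close(fig)

    return {"rows_match": rows_match, "cols_match": cols_match,
            "zi_shape": zi.shape, "n_levels": len(cs.levels)}


print(contour_without_meshgrid())
```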

That said, if your x and y vectors are large, the actual interpolation, i.e. the call to griddata, is probably going to take quite some time, as Delaunay triangulation is a computationally intensive operation. Are you sure your performance issues are coming from meshgrid and not from griddata?
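One way to answer that question is to time the two steps separately. The sketch below is an illustration under stated assumptions: it uses scipy.interpolate.griddata as a stand-in for matplotlib.mlab.griddata (which was removed in Matplotlib 3.1), and since SciPy has no natural-neighbour mode it uses method="linear", which is also Delaunay-based. The well positions and values are random, hypothetical data.

```python
import time

import numpy as np
from scipy.interpolate import griddata


def time_grid_steps(n_wells=40, width=600, height=400, seed=0):
    """Time the meshgrid step and the interpolation step separately.

    scipy.interpolate.griddata stands in for matplotlib.mlab.griddata (removed in
    Matplotlib 3.1). SciPy has no natural-neighbour mode, so 'linear' is used.
    """
    rng = np.random.default_rng(seed)
    # Hypothetical well positions and values scattered over the image
    x = rng.uniform(0, width, n_wells)
    y = rng.uniform(0, height, n_wells)
    z = rng.uniform(0, 10, n_wells)

    xi = np.linspace(0, width, width)
    yi = np.linspace(0, height, height)

    t0 = time.perf_counter()
    X, Y = np.meshgrid(xi, yi)
    t1 = time.perf_counter()
    # Triangulates the wells, then evaluates every pixel; NaN outside the wells' convex hull
    zi = griddata(np.column_stack([x, y]), z, (X, Y), method="linear")
    t2 = time.perf_counter()

    return {"meshgrid_s": t1 - t0, "griddata_s": t2 - t1, "zi_shape": zi.shape}


print(time_grid_steps())
```

Calling it with the real image size (for example width=3000, height=3000) and the real number of wells gives a rough idea of which step dominates on a given machine; no particular timings are claimed here.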

Answered Sep 19 '22 at 19:09 by Jaime
