
Speedup scipy griddata for multiple interpolations between two irregular grids

I have several values that are defined on the same irregular grid (x, y, z) that I want to interpolate onto a new grid (x1, y1, z1). i.e., I have f(x, y, z), g(x, y, z), h(x, y, z) and I want to calculate f(x1, y1, z1), g(x1, y1, z1), h(x1, y1, z1).

At the moment I am doing this using scipy.interpolate.griddata and it works well. However, because I have to perform each interpolation separately and there are many points, it is quite slow, with a great deal of duplication in the calculation (i.e. finding which points are closest, setting up the grids, etc.).

Is there a way to speed up the calculation and reduce the duplicated work? I.e. something along the lines of defining the two grids once, then changing only the values for the interpolation?

asked Jan 04 '14 01:01 by s_haskey


3 Answers

There are several things going on every time you make a call to scipy.interpolate.griddata:

  1. First, a call to sp.spatial.qhull.Delaunay is made to triangulate the irregular grid coordinates.
  2. Then, for each point in the new grid, the triangulation is searched to find the triangle in which it lies (actually the simplex, which in your 3D case is a tetrahedron).
  3. The barycentric coordinates of each new grid point with respect to the vertices of the enclosing simplex are computed.
  4. An interpolated value is computed for that grid point, using the barycentric coordinates and the values of the function at the vertices of the enclosing simplex.

The first three steps are identical for all your interpolations, so if you could store, for each new grid point, the indices of the vertices of the enclosing simplex and the weights for the interpolation, you would greatly reduce the amount of computation. This is unfortunately not easy to do directly with the functionality available, although it is indeed possible:

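import itertools

import numpy as np
import scipy.interpolate as spint
from scipy.spatial import Delaunay

def interp_weights(xyz, uvw):
    d = xyz.shape[1]  # number of spatial dimensions
    tri = Delaunay(xyz)
    # index of the simplex enclosing each new point (-1 if outside the hull)
    simplex = tri.find_simplex(uvw)
    vertices = np.take(tri.simplices, simplex, axis=0)
    # affine transform to barycentric coordinates for each enclosing simplex
    temp = np.take(tri.transform, simplex, axis=0)
    delta = uvw - temp[:, d]
    bary = np.einsum('njk,nk->nj', temp[:, :d, :], delta)
    return vertices, np.hstack((bary, 1 - bary.sum(axis=1, keepdims=True)))

def interpolate(values, vtx, wts):
    # weighted sum of the values at the vertices of each enclosing simplex
    return np.einsum('nj,nj->n', np.take(values, vtx), wts)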

The function interp_weights performs the first three steps listed above. The function interpolate then uses those calculated values to do step 4 very fast:

m, n, d = 35000, 3000, 3  # integers, since they are used as array sizes
# make sure no new grid point is extrapolated
bounding_cube = np.array(list(itertools.product([0, 1], repeat=d)))
xyz = np.vstack((bounding_cube,
                 np.random.rand(m - len(bounding_cube), d)))
f = np.random.rand(m)
g = np.random.rand(m)
uvw = np.random.rand(n, d)

In [2]: vtx, wts = interp_weights(xyz, uvw)

In [3]: np.allclose(interpolate(f, vtx, wts), spint.griddata(xyz, f, uvw))
Out[3]: True

In [4]: %timeit spint.griddata(xyz, f, uvw)
1 loops, best of 3: 2.81 s per loop

In [5]: %timeit interp_weights(xyz, uvw)
1 loops, best of 3: 2.79 s per loop

In [6]: %timeit interpolate(f, vtx, wts)
10000 loops, best of 3: 66.4 us per loop

In [7]: %timeit interpolate(g, vtx, wts)
10000 loops, best of 3: 67 us per loop

So first, it gives the same result as griddata, which is good. Second, setting up the interpolation, i.e. computing vtx and wts, takes roughly as long as a call to griddata. But third, you can now interpolate different values on the same grid in virtually no time.
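Since the question involves several quantities on the same grid, the cached weights can also be applied to all of them in a single call. The following is a minimal sketch, not part of the original answer: the helper name interpolate_many and the random test data are assumptions made for illustration, and the fields are assumed to be stacked as columns of one array.

```python
import itertools

import numpy as np
from scipy.interpolate import griddata
from scipy.spatial import Delaunay

def interp_weights(xyz, uvw):
    """Vertex indices and barycentric weights for each target point."""
    d = xyz.shape[1]
    tri = Delaunay(xyz)
    simplex = tri.find_simplex(uvw)
    vertices = np.take(tri.simplices, simplex, axis=0)
    temp = np.take(tri.transform, simplex, axis=0)
    delta = uvw - temp[:, d]
    bary = np.einsum('njk,nk->nj', temp[:, :d, :], delta)
    return vertices, np.hstack((bary, 1 - bary.sum(axis=1, keepdims=True)))

def interpolate_many(fields, vtx, wts):
    # hypothetical helper: fields has shape (m, k), one column per quantity;
    # the result has shape (n, k)
    return np.einsum('njk,nj->nk', fields[vtx], wts)

rng = np.random.default_rng(0)
d = 3
# include the cube corners so that no target point is extrapolated
corners = np.array(list(itertools.product([0.0, 1.0], repeat=d)))
xyz = np.vstack((corners, rng.random((500, d))))
uvw = rng.random((50, d))
fields = rng.random((len(xyz), 3))  # stand-ins for f, g, h

vtx, wts = interp_weights(xyz, uvw)
out = interpolate_many(fields, vtx, wts)
print(out.shape)
```

Each column of out should match a separate griddata call on the corresponding column of fields, while the triangulation and simplex search are done only once.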

The only thing that griddata does that is not handled here is assigning fill_value to points that have to be extrapolated. You could do that by checking for points for which at least one of the weights is negative, e.g.:

def interpolate(values, vtx, wts, fill_value=np.nan):
    ret = np.einsum('nj,nj->n', np.take(values, vtx), wts)
    ret[np.any(wts < 0, axis=1)] = fill_value
    return ret
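To see why a negative weight identifies extrapolated points, it may help to check the barycentric weights on a tiny 2D case. This is only an illustrative sketch with made-up points: the weights of each target sum to 1 and reproduce the target as a weighted average of the simplex vertices, and a target outside the convex hull (for which find_simplex returns -1) ends up with at least one negative weight.

```python
import numpy as np
from scipy.spatial import Delaunay

# unit square corners plus its centre (illustrative data)
pts = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, 0.5]])
tri = Delaunay(pts)
targets = np.array([[0.5, 0.2],   # inside the hull
                    [2.0, 2.0]])  # outside the hull

d = pts.shape[1]
simplex = tri.find_simplex(targets)   # -1 marks points outside the hull
vtx = tri.simplices[simplex]
T = tri.transform[simplex]
# the first d barycentric coordinates come from the stored affine transform,
# the last one follows from the constraint that they sum to 1
bary = np.einsum('njk,nk->nj', T[:, :d, :], targets - T[:, d])
wts = np.hstack((bary, 1 - bary.sum(axis=1, keepdims=True)))

# weighted average of the simplex vertices gives back each target point
recon = np.einsum('nj,njk->nk', wts, pts[vtx])
outside = np.any(wts < 0, axis=1)
print(simplex, outside)
```

Note that for an outside point the index -1 silently selects the last simplex, so the weights are computed relative to an unrelated simplex; that is why the fill_value mask is needed rather than optional.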

answered Oct 20 '22 01:10 by Jaime


Many thanks to Jaime for his solution (even if I don't really understand how the barycentric computation is done...).

Here is an example adapted from his code to the 2D case:

import scipy.interpolate as spint
from scipy.spatial import Delaunay
import numpy as np

def interp_weights(xy, uv, d=2):
    tri = Delaunay(xy)
    simplex = tri.find_simplex(uv)
    vertices = np.take(tri.simplices, simplex, axis=0)
    temp = np.take(tri.transform, simplex, axis=0)
    delta = uv - temp[:, d]
    bary = np.einsum('njk,nk->nj', temp[:, :d, :], delta)
    return vertices, np.hstack((bary, 1 - bary.sum(axis=1, keepdims=True)))

def interpolate(values, vtx, wts):
    return np.einsum('nj,nj->n', np.take(values, vtx), wts)

m, n = 101,201
mi, ni = 1001,2001

[Y,X]=np.meshgrid(np.linspace(0,1,n),np.linspace(0,2,m))
[Yi,Xi]=np.meshgrid(np.linspace(0,1,ni),np.linspace(0,2,mi))

xy=np.zeros([X.shape[0]*X.shape[1],2])
xy[:,0]=Y.flatten()
xy[:,1]=X.flatten()
uv=np.zeros([Xi.shape[0]*Xi.shape[1],2])
uv[:,0]=Yi.flatten()
uv[:,1]=Xi.flatten()

values=np.cos(2*X)*np.cos(2*Y)

#Computed once and for all !
vtx, wts = interp_weights(xy, uv)
valuesi=interpolate(values.flatten(), vtx, wts)
valuesi=valuesi.reshape(Xi.shape[0],Xi.shape[1])
print("interpolation error: ", np.mean(valuesi - np.cos(2*Xi)*np.cos(2*Yi)))
print("interpolation uncertainty: ", np.std(valuesi - np.cos(2*Xi)*np.cos(2*Yi)))

The same idea can be applied to image transformations such as image mapping, with a huge speed-up.

In that case you can't reuse the function definitions above, because the new coordinates change at every iteration, but you can still compute the triangulation once and for all.

import scipy.interpolate as spint
from scipy.spatial import Delaunay
import numpy as np
import time

# Definition of the fast interpolation process. Maybe the triangulation step could be removed.
def interp_tri(xy):
    tri = Delaunay(xy)
    return tri


def interpolate(values, tri,uv,d=2):
    simplex = tri.find_simplex(uv)
    vertices = np.take(tri.simplices, simplex, axis=0)
    temp = np.take(tri.transform, simplex, axis=0)
    delta = uv- temp[:, d]
    bary = np.einsum('njk,nk->nj', temp[:, :d, :], delta)  
    return np.einsum('nj,nj->n', np.take(values, vertices),  np.hstack((bary, 1.0 - bary.sum(axis=1, keepdims=True))))

m, n = 101,201
mi, ni = 101,201

[Y,X]=np.meshgrid(np.linspace(0,1,n),np.linspace(0,2,m))
[Yi,Xi]=np.meshgrid(np.linspace(0,1,ni),np.linspace(0,2,mi))

xy=np.zeros([X.shape[0]*X.shape[1],2])
xy[:,1]=Y.flatten()
xy[:,0]=X.flatten()
uv=np.zeros([Xi.shape[0]*Xi.shape[1],2])
# creation of a displacement field
uv[:,1]=0.5*Yi.flatten()+0.4
uv[:,0]=1.5*Xi.flatten()-0.7
values=np.zeros_like(X)
values[50:70,90:150]=100.

#Computed once and for all !
tri = interp_tri(xy)
t0=time.time()
for i in range(0,100):
  values_interp_Qhull=interpolate(values.flatten(),tri,uv,2).reshape(Xi.shape[0],Xi.shape[1])
t_q=(time.time()-t0)/100

t0=time.time()
values_interp_griddata=spint.griddata(xy,values.flatten(),uv,fill_value=0).reshape(values.shape[0],values.shape[1])
t_g=time.time()-t0

print("Speed-up:", t_g/t_q)
print("Mean error: ", (values_interp_Qhull - values_interp_griddata).mean())
print("Standard deviation: ", (values_interp_Qhull - values_interp_griddata).std())

On my laptop the speed-up is between 20x and 40x!

Hope this helps someone.

answered Oct 20 '22 01:10 by Jeff Witz


I had the same problem (griddata extremely slow, grid stays the same for many interpolations) and I liked the solution described here the best, mainly because it is very easy to understand and apply.

It uses LinearNDInterpolator, which accepts a precomputed Delaunay triangulation, so the triangulation needs to be computed only once. Copied from that post (all credit to xdze2):

from scipy.spatial import Delaunay
from scipy.interpolate import LinearNDInterpolator

tri = Delaunay(mesh1)  # Compute the triangulation

# Perform the interpolation with the given values:
interpolator = LinearNDInterpolator(tri, values_mesh1)
values_mesh2 = interpolator(mesh2)

That speeds up my computations by a factor of approximately 2.
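As a self-contained illustration of this approach, the sketch below reuses one triangulation for two fields. The mesh and field data are invented for the example. It also relies on LinearNDInterpolator accepting values with extra trailing dimensions, so several fields stacked as columns can be interpolated in one call.

```python
import numpy as np
from scipy.interpolate import LinearNDInterpolator
from scipy.spatial import Delaunay

rng = np.random.default_rng(1)
# source mesh: square corners plus random interior points (illustrative data)
corners = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
mesh1 = np.vstack((corners, rng.random((400, 2))))
mesh2 = rng.random((100, 2))  # target points, all inside the unit square

f = mesh1[:, 0] + 2.0 * mesh1[:, 1]   # a linear field, reproduced exactly
g = np.sin(3.0 * mesh1[:, 0])         # a second field on the same mesh

tri = Delaunay(mesh1)  # computed once

# one interpolator per field, sharing the triangulation
f2 = LinearNDInterpolator(tri, f)(mesh2)
g2 = LinearNDInterpolator(tri, g)(mesh2)

# or both fields at once, stacked as columns: result has shape (100, 2)
both = LinearNDInterpolator(tri, np.column_stack((f, g)))(mesh2)
print(both.shape)
```

Each call still repeats the simplex search for mesh2, which may explain why the gain here is more modest than with cached weights.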

answered Oct 20 '22 01:10 by Waterkant