
Interpolating a 3d array in Python expanded

My question expands on the answer seen here: Interpolating a 3d array in Python. How to avoid for loops?. The relevant solution code from there is below:

import numpy as np
from scipy.interpolate import interp1d
array = np.random.randint(0, 9, size=(100, 100, 100))
x = np.linspace(0, 100, 100)
x_new = np.linspace(0, 100, 1000)
new_array = interp1d(x, array, axis=0)(x_new)
new_array.shape # -> (1000, 100, 100)

The approach above works great when x_new is a constant 1-d array, but what if my x_new is not a constant 1-d array, and instead depends on the latitude/longitude index in another 3-d array? My x_new is of size 355x195x192 (time x lat x long), and right now I am double-for-looping through the latitude and longitude dimensions. Since x_new is different for each latitude/longitude pair, how can I avoid the looping seen below? My loop currently takes a couple of hours, unfortunately...

import numpy as np
from scipy import interpolate

x_new = (np.argsort(np.argsort(modell, 0), 0).astype(float) + 1) / np.size(modell, 0)
## x_new is shape 355x195x192
## pobs is shape 355x1
## prism_aligned_tmax_sorted is shape 355x195x192
interp_func = interpolate.interp1d(pobs, prism_aligned_tmax_sorted, axis=0)
tmaxmod = np.empty((355, 195, 192))
tmaxmod[:] = np.nan
for latt in range(0, 195):
    for lonn in range(0, 192):
        temp = interp_func(x_new[:, latt, lonn])
        tmaxmod[:, latt, lonn] = temp[:, latt, lonn]

Thanks for any and all assistance!

ajoros asked Apr 03 '17 11:04

1 Answer

I know how you can get rid of those loops, but you're not going to like it.

The problem is that this use of interp1d gives you essentially a matrix-valued function interpolated over a 1d domain, i.e. an F(x) function where for each scalar x you have a 2d-array-shaped F. The interpolation you're trying to do is rather this: creating an individual interpolator for each (lat,lon) pair of yours. This is more along the lines of F_(lat,lon)(x).
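To make the distinction concrete, here is a minimal sketch (with small made-up shapes, like the ones used further down): evaluating the matrix-valued interpolator at a single scalar already produces a full 2d slice:

```python
import numpy as np
from scipy.interpolate import interp1d

arr = np.random.rand(3, 4, 5)   # toy data: 3 samples of a (4,5) "lat x lon" grid
x = np.linspace(0, 10, 3)

F = interp1d(x, arr, axis=0)    # matrix-valued F(x)
print(F(2.5).shape)             # (4, 5): a full 2d array for one scalar query
```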

The reason this is a problem is that for your use case you're computing the matrix-valued F(x) for each of your query points, but then discarding all of the matrix elements except a single one (element [lat,lon] for a query point corresponding to this pair). So you're doing a lot of unnecessary work computing all those irrelevant function values, and I'm not sure there's a fundamentally more efficient way.

Your use case can be vectorized, but only by throwing a lot of auxiliary memory at it. The fact that your loops run for hours suggests that this will not really be feasible for your sizes, but I'll show the solution anyway. The idea is to turn your 3d array into a 2d one, do the interpolation with this shape, then take the diagonal elements along the effective 2d subspace of your interpolated result. You will still compute every irrelevant matrix element for every query point, but instead of looping over your arrays you'll be able to extract the relevant matrix elements with a single indexing step. The cost is a much larger auxiliary array, which will most likely not fit into your available RAM.

Anyway, here's the trick in action, comparing your current approach with the vectorized one:

import numpy as np
from scipy.interpolate import interp1d
arr = np.random.randint(0, 9, size=(3, 4, 5))
x = np.linspace(0, 10, 3)
x_new = np.random.rand(6,4,5)*10

## x is shape 3 
## arr is shape  3x4x5
## x_new is shape  6x4x5

# original, loopy approach
interp_func = interp1d(x, arr, axis=0)
res = np.empty((6, 4, 5))
for lat in range(res.shape[1]):
    for lon in range(res.shape[2]):
        temp = interp_func(x_new[:,lat,lon]) # shape (6,4,5) each iteration
        res[:,lat,lon] = temp[:,lat,lon]

# new, vectorized approach
arr2 = arr.reshape(arr.shape[0],-1) # shape (3,20)
interp_func2 = interp1d(x,arr2,axis=0)
x_new2 = x_new.reshape(x_new.shape[0],-1) # shape (6,20)
temp = interp_func2(x_new2) # shape (6,20,20): 20x larger than needed!
s = x_new2.shape[1] # 20, used for fancy indexing ranges
res2 = temp[:,range(s),range(s)].reshape(res.shape) # shape (6,20) -> (6,4,5)

The resulting res and res2 arrays are exactly the same, so the approach works. But as I said, for a query array of shape (nx,nlat,nlon) we need an auxiliary array of shape (nx,nlat*nlon,nlat*nlon), which will typically have enormous memory requirements.
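To put numbers on that for the shapes in the question (355 x 195 x 192, assuming float64 data):

```python
nx, nlat, nlon = 355, 195, 192
s = nlat * nlon                 # 37440 flattened spatial points
elems = nx * s * s              # elements of the auxiliary (nx, s, s) array
gigabytes = elems * 8 / 1e9     # float64 takes 8 bytes per element
print(f"{gigabytes:.0f} GB")    # roughly 3981 GB, i.e. about 4 TB
```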


The only straightforward alternative I can think of is performing your 1d interpolations one by one: defining nlat*nlon interpolators in a double loop. This will have much larger overhead from creating the interpolators, but on the other hand you won't do a bunch of unnecessary work computing interpolated array values which you then discard.
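A sketch of that alternative, using the same toy shapes as above (this is my illustration, not code from the question):

```python
import numpy as np
from scipy.interpolate import interp1d

arr = np.random.rand(3, 4, 5)
x = np.linspace(0, 10, 3)
x_new = np.random.rand(6, 4, 5) * 10

res = np.empty((6, 4, 5))
for lat in range(arr.shape[1]):
    for lon in range(arr.shape[2]):
        # one small 1d interpolator per (lat, lon) pair: more setup overhead,
        # but no wasted work evaluating the full 2d slice at every query point
        f = interp1d(x, arr[:, lat, lon])
        res[:, lat, lon] = f(x_new[:, lat, lon])
```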

Finally, depending on your use case you should consider using multivariate interpolation (I'm thinking interpolate.interpn or interpolate.griddata). Assuming that your data is smooth as a function of latitude and longitude as well, it might make sense to interpolate your complete dataset in higher dimensions. This way you create your interpolator once, and query it at exactly the points you need with no unnecessary fluff in your way.
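As a sketch of this idea: if your data lives on a regular (x, lat, lon) grid, scipy.interpolate.RegularGridInterpolator (the machinery behind interpolate.interpn) can evaluate all the scattered query points in one vectorized call. The grid coordinates below are made up for illustration:

```python
import numpy as np
from scipy.interpolate import RegularGridInterpolator

arr = np.random.rand(3, 4, 5)          # data on a regular (x, lat, lon) grid
x = np.linspace(0, 10, 3)
lat = np.arange(4.0)                   # hypothetical latitude coordinates
lon = np.arange(5.0)                   # hypothetical longitude coordinates
x_new = np.random.rand(6, 4, 5) * 10   # a different query x per (lat, lon)

interp = RegularGridInterpolator((x, lat, lon), arr)

# assemble all (x, lat, lon) query triples into one (n_queries, 3) array
LAT, LON = np.meshgrid(lat, lon, indexing='ij')
pts = np.column_stack([x_new.reshape(6, -1).ravel(),
                       np.tile(LAT.ravel(), 6),
                       np.tile(LON.ravel(), 6)])
res = interp(pts).reshape(6, 4, 5)     # one call, no per-pair loop
```

Since the queries here sit exactly on the lat/lon grid points, this reduces to the 1d interpolation along x for each column, but with a single vectorized evaluation.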


If you end up sticking with your current implementation, you can probably greatly improve performance by moving your interpolating axis to the last position. This way every vectorized operation acts on contiguous blocks of memory (assuming the default C memory order), and this fits well with the "collection of 1d arrays" philosophy. So you should do something along the lines of

arr_t = arr.transpose(1, 2, 0)                  # shape (4,5,3)
interp_func = interp1d(x, arr_t, axis=-1)
res_t = np.empty((4, 5, 6))
for lat in range(res_t.shape[0]):
    for lon in range(res_t.shape[1]):
        temp = interp_func(x_new[:, lat, lon])  # shape (4,5,6)
        res_t[lat, lon, :] = temp[lat, lon, :]

If you need to restore the original order, you can finally transpose back with res_t.transpose(2,0,1).
