Fast interpolation over 3D array for 3D origin x

This problem is similar to a previously answered question, "Fast interpolation over 3D array", but the answers there do not solve my problem.

I have a 4D array y with dimensions (time, altitude, latitude, longitude), i.e. y.shape = (nt, nalt, nlat, nlon). The altitude x varies with time, latitude and longitude, so x.shape = (nt, nalt, nlat, nlon) as well. I want to interpolate in altitude for every (time, latitude, longitude) point. The target altitudes new_x form a 1D array that does not vary with time, latitude or longitude.

I use numpy.interp (the same applies to scipy.interpolate.interp1d) and went through the answers in the earlier post, but I could not use them to get rid of the loops.

The best I can do is this:

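import numpy as np

# y and x are 4D ndarrays of shape (nt, nalt, nlat, nlon); x holds the altitude of each y value
# new_x is the 1D array of target altitudes; y_new gets shape (nt, len(new_x), nlat, nlon)
nt, nalt, nlat, nlon = y.shape
y_new = np.empty((nt, len(new_x), nlat, nlon))
for i in range(nlon):
    for j in range(nlat):
        for k in range(nt):
            # np.interp requires x[k, :, j, i] to be increasing
            y_new[k, :, j, i] = np.interp(new_x, x[k, :, j, i], y[k, :, j, i])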

These loops make the interpolation far too slow. Does anyone have a good idea for speeding it up? Any help would be highly appreciated.

Hao asked Nov 28 '25 at 14:11


1 Answer

Here is my solution using numba; it's about 3x faster.

First create the test data; each row of x needs to be in ascending order:

import numpy as np
rows = 200000
cols = 66
new_cols = 69
x = np.random.rand(rows, cols)
x.sort(axis=-1)
y = np.random.rand(rows, cols)
nx = np.random.rand(new_cols)
nx.sort() 

Run the 200000 interpolations (one np.interp call per row) in a plain Python loop:

%%time
ny = np.empty((x.shape[0], len(nx)))
for i in range(len(x)):
    ny[i] = np.interp(nx, x[i], y[i])

I use a merge method instead of binary search, because nx is sorted and its length is about the same as the length of one row of x (written len(x) below):

  • np.interp() uses binary search, so the time complexity is O(len(nx) * log2(len(x)))
  • the merge method walks through nx and x only once, so the time complexity is O(len(nx) + len(x))
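To make the two complexity figures concrete, here is a plain NumPy/Python sketch of both strategies for a single row. The function names `binary_search_interp` and `merge_interp` are hypothetical, and the sketch assumes `xp` is strictly increasing and `x` is sorted. Both are checked against `np.interp`, including the clamping to `fp[0]` / `fp[-1]` for points outside `[xp[0], xp[-1]]`. The pure-Python merge loop is slow; it only shows the access pattern that the numba code compiles.

```python
import numpy as np

def binary_search_interp(x, xp, fp):
    """One binary search per query point: O(len(x) * log2(len(xp)))."""
    # Index j of the interval [xp[j], xp[j+1]) holding each x, clipped so that
    # out-of-range points still pick a valid interval.
    j = np.searchsorted(xp, x, side="right") - 1
    j = np.clip(j, 0, len(xp) - 2)
    slope = (fp[j + 1] - fp[j]) / (xp[j + 1] - xp[j])
    f = fp[j] + slope * (x - xp[j])
    # Constant extrapolation outside the data, like np.interp.
    f[x <= xp[0]] = fp[0]
    f[x >= xp[-1]] = fp[-1]
    return f

def merge_interp(x, xp, fp):
    """Walk x and xp together like a merge: O(len(x) + len(xp))."""
    n, m = len(x), len(xp)
    f = np.empty(n)
    i = 0
    # Points at or left of xp[0] get fp[0].
    while i < n and x[i] <= xp[0]:
        f[i] = fp[0]
        i += 1
    j = 0
    while i < n:
        # Advance j until x[i] < xp[j+1]; since x is sorted, j never moves back.
        while j + 1 < m and x[i] >= xp[j + 1]:
            j += 1
        if j + 1 == m:
            break  # x[i] >= xp[-1]: everything from here on is right of the data
        f[i] = fp[j] + (fp[j + 1] - fp[j]) / (xp[j + 1] - xp[j]) * (x[i] - xp[j])
        i += 1
    # Points at or right of xp[-1] get fp[-1].
    f[i:] = fp[-1]
    return f

rng = np.random.default_rng(0)
xp = np.sort(rng.random(66))
fp = rng.random(66)
# -0.5 and 1.5 are guaranteed to lie outside [xp[0], xp[-1]]
x = np.sort(np.concatenate(([-0.5, 1.5], rng.random(67))))
ref = np.interp(x, xp, fp)
print(np.allclose(binary_search_interp(x, xp, fp), ref))
print(np.allclose(merge_interp(x, xp, fp), ref))
```

The binary-search version pays a log factor per query point, while the merge version advances two pointers monotonically, which is why it wins when the number of query points is comparable to the row length.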

Here is the numba code:

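import numba
import numpy as np

# Merge-style linear interpolation of one row. x (query points) and xp must both be
# sorted ascending; the result is written into the preallocated output array f.
# Values outside [xp[0], xp[-1]] are clamped to fp[0] / fp[-1], as in np.interp.
@numba.njit("void(f8[::1], f8[::1], f8[::1], f8[::1])")
def interp2(x, xp, fp, f):
    n = len(x)
    n2 = len(xp)
    j = 0
    i = 0
    # Points at or left of the first sample get fp[0] (bounds-checked on i).
    while i < n and x[i] <= xp[0]:
        f[i] = fp[0]
        i += 1

    slope = (fp[j+1] - fp[j])/(xp[j+1] - xp[j])
    while i < n:
        # x[i] lies in the current interval [xp[j], xp[j+1]): interpolate.
        if x[i] >= xp[j] and x[i] < xp[j+1]:
            f[i] = slope*(x[i] - xp[j]) + fp[j]
            i += 1
            continue
        # Otherwise move to the next interval; j never moves backwards.
        j += 1
        if j + 1 == n2:
            break
        slope = (fp[j+1] - fp[j])/(xp[j+1] - xp[j])

    # Points at or right of the last sample get fp[-1].
    while i < n:
        f[i] = fp[n2-1]
        i += 1

# Interpolate every row of the C-contiguous 2D arrays xp/fp onto the same x.
@numba.njit("f8[:, ::1](f8[::1], f8[:, ::1], f8[:, ::1])")
def multi_interp(x, xp, fp):
    nrows = xp.shape[0]
    f = np.empty((nrows, x.shape[0]))
    for i in range(nrows):
        interp2(x, xp[i, :], fp[i, :], f[i, :])
    return f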

Then call the numba function:

%%time
ny2 = multi_interp(nx, x, y)

To check the result:

np.allclose(ny, ny2)

On my PC, the timings are:

Python loop with np.interp: 3.41 s
numba version: 1.04 s

This method requires the axis to be interpolated to be the last axis of the array, and the arrays must be C-contiguous float64, as declared in the numba signatures.
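For the 4D arrays in the question, altitude is axis 1, so it has to be moved to the last axis and the other three axes flattened into rows before calling `multi_interp`; afterwards the result is reshaped back. The sketch below shows only that bookkeeping. `interp_along_alt` is a hypothetical helper name, and to keep the sketch runnable without numba it uses a plain `np.interp` loop over the flattened rows at the spot where you would call `multi_interp(new_x, xp2d, fp2d)`. It assumes altitude increases along axis 1 in every column.

```python
import numpy as np

def interp_along_alt(new_x, x, y):
    """Interpolate y(x) onto new_x along axis 1 of 4D arrays (nt, nalt, nlat, nlon)."""
    nt, nalt, nlat, nlon = y.shape
    # (nt, nalt, nlat, nlon) -> (nt, nlat, nlon, nalt) -> (nt*nlat*nlon, nalt),
    # as C-contiguous float64, the layout the numba signatures expect.
    xp2d = np.ascontiguousarray(np.moveaxis(x, 1, -1).reshape(-1, nalt), dtype=np.float64)
    fp2d = np.ascontiguousarray(np.moveaxis(y, 1, -1).reshape(-1, nalt), dtype=np.float64)
    new_x = np.ascontiguousarray(new_x, dtype=np.float64)
    # Stand-in for f2d = multi_interp(new_x, xp2d, fp2d): one np.interp call per row.
    f2d = np.empty((xp2d.shape[0], new_x.shape[0]))
    for r in range(xp2d.shape[0]):
        f2d[r] = np.interp(new_x, xp2d[r], fp2d[r])
    # Undo the flattening and move the new altitude axis back to position 1.
    return np.moveaxis(f2d.reshape(nt, nlat, nlon, new_x.shape[0]), -1, 1)

rng = np.random.default_rng(1)
nt, nalt, nlat, nlon = 3, 10, 4, 5
x = np.sort(rng.random((nt, nalt, nlat, nlon)), axis=1)  # altitude increases along axis 1
y = rng.random((nt, nalt, nlat, nlon))
new_x = np.linspace(0.0, 1.0, 12)
y_new = interp_along_alt(new_x, x, y)
print(y_new.shape)
```

Because `np.moveaxis` returns a non-contiguous view, the `reshape` makes a copy of each input, so this approach needs extra memory proportional to the size of x and y. The returned array is a view with altitude moved back to axis 1; call `np.ascontiguousarray` on it if a contiguous result is needed.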

HYRY answered Nov 30 '25 at 04:11


