This problem is similar to a previously answered question, "Fast interpolation over 3D array", but the answers there do not solve my problem.
I have a 4D array y with dimensions (time, altitude, latitude, longitude), i.e. y.shape = (nt, nalt, nlat, nlon). The altitude x varies with time, latitude and longitude, so x.shape = (nt, nalt, nlat, nlon) as well. I want to interpolate along altitude for every (time, latitude, longitude) point. The new altitude grid new_x should be 1D, i.e. the same for every time, latitude and longitude.
I use numpy.interp (the same applies to scipy.interpolate.interp1d) and have considered the answers in the former post, but I could not use them to reduce the loops.
The only approach I have is an explicit triple loop:
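import numpy as np

# y: 4D ndarray of shape (nt, nalt, nlat, nlon)
# x: 4D ndarray of altitudes, same shape as y; each x[k, :, j, i] must be increasing
# new_x: 1D ndarray of target altitudes
# y_new: 4D result array of shape (nt, len(new_x), nlat, nlon)
y_new = np.empty((nt, len(new_x), nlat, nlon))
for i in range(nlon):
    for j in range(nlat):
        for k in range(nt):
            y_new[k, :, j, i] = np.interp(new_x, x[k, :, j, i], y[k, :, j, i])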
These loops make the interpolation far too slow to compute. Does anyone have a better idea? Any help would be highly appreciated.
Here is my solution using numba; it's about 3x faster.
First, create the test data; each row of x needs to be in ascending order:
import numpy as np
rows = 200000
cols = 66
new_cols = 69
x = np.random.rand(rows, cols)
x.sort(axis=-1)
y = np.random.rand(rows, cols)
nx = np.random.rand(new_cols)
nx.sort()
Then run np.interp 200000 times (once per row) in plain NumPy:
%%time
ny = np.empty((x.shape[0], len(nx)))
for i in range(len(x)):
    ny[i] = np.interp(nx, x[i], y[i])
I use a merge method instead of binary search, because nx is sorted and its length is about the same as the length of each row of x.
np.interp() uses binary search, so its time complexity is O(len(nx) * log2(len(x))); the merge method is O(len(nx) + len(x)).
Here is the numba code:
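import numba
import numpy as np

# void signature: interp2 writes its result into f and returns nothing
@numba.njit("void(f8[::1], f8[::1], f8[::1], f8[::1])")
def interp2(x, xp, fp, f):
    # Linear interpolation of (xp, fp) at the sorted points x, written into f.
    # Walks both sorted arrays once (merge) instead of binary-searching each x.
    n = len(x)
    n2 = len(xp)
    j = 0
    i = 0
    # Points at or left of xp[0] get fp[0]; i < n guards against running off x
    while i < n and x[i] <= xp[0]:
        f[i] = fp[0]
        i += 1
    slope = (fp[j+1] - fp[j])/(xp[j+1] - xp[j])
    while i < n:
        # x[i] lies in the current interval [xp[j], xp[j+1]): interpolate
        if x[i] >= xp[j] and x[i] < xp[j+1]:
            f[i] = slope*(x[i] - xp[j]) + fp[j]
            i += 1
            continue
        # Otherwise advance to the next interval of xp
        j += 1
        if j + 1 == n2:
            break
        slope = (fp[j+1] - fp[j])/(xp[j+1] - xp[j])
    # Points at or right of xp[-1] get the last value
    while i < n:
        f[i] = fp[n2-1]
        i += 1

@numba.njit("f8[:, ::1](f8[::1], f8[:, ::1], f8[:, ::1])")
def multi_interp(x, xp, fp):
    # Interpolate every row of (xp, fp) at the shared points x
    nrows = xp.shape[0]
    f = np.empty((nrows, x.shape[0]))
    for i in range(nrows):
        interp2(x, xp[i, :], fp[i, :], f[i, :])
    return f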
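For comparison, a loop-free NumPy sketch is possible without numba. This is a swapped-in technique, not the merge method above: it uses one global binary search (np.searchsorted) after shifting every row into its own disjoint interval, so the flattened xp is globally sorted. The name batched_interp is hypothetical; the sketch assumes strictly increasing rows with at least two points each, and the large shifts cost a little floating-point precision in the search step.

```python
import numpy as np

def batched_interp(new_x, xp, fp):
    """Interpolate every row of (xp, fp) at the shared points new_x.

    Hypothetical helper. Assumes xp and fp are 2D float arrays of the same
    shape, each row of xp is strictly increasing with at least 2 points.
    Points outside a row's range are clamped to its end values, like np.interp.
    """
    new_x = np.asarray(new_x, dtype=float)
    nrows, ncols = xp.shape
    m = new_x.shape[0]
    rows = np.arange(nrows)[:, None]

    # Shift row r into [r*span, r*span + range] so the rows do not overlap
    # and a single searchsorted call on the flattened array covers all rows.
    lo = min(xp.min(), new_x.min())
    span = max(xp.max(), new_x.max()) - lo + 1.0
    shift = rows * span
    flat_xp = (xp - lo + shift).ravel()
    flat_x = (new_x[None, :] - lo + shift).ravel()
    pos = np.searchsorted(flat_xp, flat_x, side="right").reshape(nrows, m)

    # Global position minus the row offset = count of xp[r] <= x in that row.
    # Clipping to [1, ncols-1] picks a valid bracket [left, j] for every point.
    j = np.clip(pos - rows * ncols, 1, ncols - 1)
    left = j - 1
    x0, x1 = xp[rows, left], xp[rows, j]
    y0, y1 = fp[rows, left], fp[rows, j]

    # Linear interpolation on the unshifted values; clipping t clamps
    # points below xp[0] to fp[0] and points above xp[-1] to fp[-1].
    t = np.clip((new_x[None, :] - x0) / (x1 - x0), 0.0, 1.0)
    return y0 + t * (y1 - y0)
```

On the test data above it could be checked with np.allclose(ny, batched_interp(nx, x, y)). It allocates several temporaries of shape (rows, len(nx)), so its speed and memory use relative to the numba version would need to be measured.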
Then call the numba function:
%%time
ny2 = multi_interp(nx, x, y)
To check the result:
np.allclose(ny, ny2)
On my PC, the timings are:
Python (np.interp loop) version: 3.41 s
numba version: 1.04 s
This method requires the last axis of the arrays to be the axis being interpolated along; given the numba signatures, the arrays must also be C-contiguous float64.
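To apply a row-wise interpolator to the 4D arrays from the question, the altitude axis can be moved last and the other three axes flattened into rows. This is a minimal sketch, assuming x and y have shape (nt, nalt, nlat, nlon) and each x[k, :, j, i] is increasing; interp_altitude and its row_interp parameter are hypothetical names. Without row_interp it falls back to one np.interp call per row, which replaces the three nested loops with one; passing multi_interp should also work, since the arrays are made C-contiguous float64 as its signature requires.

```python
import numpy as np

def interp_altitude(new_x, x, y, row_interp=None):
    """Interpolate 4D (nt, nalt, nlat, nlon) arrays along axis 1 at new_x.

    Hypothetical helper. row_interp(nx, xp2d, fp2d) must interpolate each
    row of the 2D arrays xp2d/fp2d at nx and return (nrows, len(nx)).
    Returns an array of shape (nt, len(new_x), nlat, nlon).
    """
    if row_interp is None:
        # Fallback: one np.interp call per row (a single loop instead of three)
        def row_interp(nx, xp, fp):
            return np.array([np.interp(nx, a, b) for a, b in zip(xp, fp)])

    nt, nalt, nlat, nlon = x.shape
    # Move altitude to the last axis and flatten (time, lat, lon) into rows;
    # ascontiguousarray gives the C-contiguous float64 layout numba expects.
    xp2d = np.ascontiguousarray(np.moveaxis(x, 1, -1).reshape(-1, nalt), dtype=np.float64)
    fp2d = np.ascontiguousarray(np.moveaxis(y, 1, -1).reshape(-1, nalt), dtype=np.float64)
    nx = np.ascontiguousarray(new_x, dtype=np.float64)

    out2d = row_interp(nx, xp2d, fp2d)  # shape (nt*nlat*nlon, len(nx))
    # Undo the flattening, then move the new altitude axis back to position 1
    return np.moveaxis(out2d.reshape(nt, nlat, nlon, len(nx)), -1, 1)
```

For example, y_new = interp_altitude(new_x, x, y, row_interp=multi_interp) would use the numba function for the per-row work.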