I have a 3D array (z, y, x) with shape=(92, 4800, 4800), where each value along axis 0 represents a different point in time. The acquisition of values in the time domain failed in a few instances, causing some values to be np.NaN. In other instances no values have been acquired at all and all values along z are np.NaN.

What is the most efficient way to use linear interpolation to fill np.NaN along axis 0, disregarding instances where all values are np.NaN?

Here is a working example of what I'm doing, which employs the pandas wrapper to scipy.interpolate.interp1d. On the original dataset this takes around 2 seconds per slice, meaning the whole array is processed in 2.6 hours. The example dataset with reduced size takes around 9.5 seconds.
import numpy as np
import pandas as pd

# create example data, original is (92, 4800, 4800)
test_arr = np.random.randint(low=-10000, high=10000, size=(92, 480, 480))
test_arr[1:90:7, :, :] = -32768  # NaN fill value in original data
test_arr[:, 1:90:6, 1:90:8] = -32768

def interpolate_nan(arr, method="linear", limit=3):
    """return array interpolated along time-axis to fill missing values"""
    result = np.zeros_like(arr, dtype=np.int16)
    for i in range(arr.shape[1]):
        # slice along y axis, interpolate with pandas wrapper to interp1d
        line_stack = pd.DataFrame(data=arr[:, i, :], dtype=np.float32)
        line_stack.replace(to_replace=-32768, value=np.NaN, inplace=True)
        line_stack.interpolate(method=method, axis=0, inplace=True, limit=limit)
        line_stack.replace(to_replace=np.NaN, value=-32768, inplace=True)
        result[:, i, :] = line_stack.values.astype(np.int16)
    return result
Performance on my machine with the example dataset:
%timeit interpolate_nan(test_arr)
1 loops, best of 3: 9.51 s per loop
Edit:
I should clarify that the code produces my expected outcome. The question is: how can I optimize this process?
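To make the expected outcome concrete, here is a small illustration on a single pixel's time series (the numbers are invented for this example; it simply applies the same replace/interpolate/replace steps as interpolate_nan above):

import numpy as np
import pandas as pd

# one pixel's time series, with -32768 marking missing acquisitions
s = pd.Series([10, -32768, 30, -32768, -32768, -32768, -32768, 80], dtype=np.float32)
s = s.replace(-32768, np.nan).interpolate(method="linear", limit=3)
print(s.fillna(-32768).astype(np.int16).tolist())
# [10, 20, 30, 40, 50, 60, -32768, 80]
# the run of four missing values is only filled up to limit=3 consecutive NaNs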
I recently solved this problem for my particular use case with the help of numba and also did a little writeup on it.
import numpy as np
from numba import jit

@jit(nopython=True)
def interpolate_numba(arr, no_data=-32768):
    """return array interpolated along time-axis to fill missing values"""
    result = np.zeros_like(arr, dtype=np.int16)
    for x in range(arr.shape[2]):
        # slice along x axis
        for y in range(arr.shape[1]):
            # slice along y axis
            for z in range(arr.shape[0]):
                value = arr[z, y, x]
                if z == 0:  # don't interpolate first value
                    new_value = value
                elif z == len(arr[:, 0, 0]) - 1:  # don't interpolate last value
                    new_value = value
                elif value == no_data:  # interpolate
                    left = arr[z - 1, y, x]
                    right = arr[z + 1, y, x]
                    # look for valid neighbours
                    if left != no_data and right != no_data:  # left and right are valid
                        new_value = (left + right) / 2
                    elif left == no_data and z == 1:  # boundary condition left
                        new_value = value
                    elif right == no_data and z == len(arr[:, 0, 0]) - 2:  # boundary condition right
                        new_value = value
                    elif left == no_data and right != no_data:  # take second neighbour to the left
                        more_left = arr[z - 2, y, x]
                        if more_left == no_data:
                            new_value = value
                        else:
                            new_value = (more_left + right) / 2
                    elif left != no_data and right == no_data:  # take second neighbour to the right
                        more_right = arr[z + 2, y, x]
                        if more_right == no_data:
                            new_value = value
                        else:
                            new_value = (more_right + left) / 2
                    elif left == no_data and right == no_data:  # take second neighbour on both sides
                        more_left = arr[z - 2, y, x]
                        more_right = arr[z + 2, y, x]
                        if more_left != no_data and more_right != no_data:
                            new_value = (more_left + more_right) / 2
                        else:
                            new_value = value
                    else:
                        new_value = value
                else:
                    new_value = value
                result[z, y, x] = int(new_value)
    return result
This is about 20 times faster than my initial code.
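For reference, a minimal usage sketch on the reduced-size example data from the question (assuming the interpolate_numba function above; the no-data sentinel stays -32768, so no NaN replacement is needed beforehand):

# reduced-size example data, as in the question
test_arr = np.random.randint(low=-10000, high=10000, size=(92, 480, 480)).astype(np.int16)
test_arr[1:90:7, :, :] = -32768
test_arr[:, 1:90:6, 1:90:8] = -32768

filled = interpolate_numba(test_arr)  # the first call also pays the JIT compilation cost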
The questioner gave an excellent answer by taking advantage of numba. I really appreciate it, but I cannot entirely agree with the logic inside the interpolate_numba function. I do not think linear interpolation at a specific point means taking the average of its left and right neighbors. For illustration, take the array [1, nan, nan, 4, nan, 6]: the interpolate_numba function above would probably return [1, 2.5, 2.5, 4, 5, 6] (theoretical deduction only), whereas the pandas wrapper will surely return [1, 2, 3, 4, 5, 6]. Instead, I believe linear interpolation at a specific point means finding its nearest valid left and right neighbors, using their values and positions to determine a line (i.e. slope and intercept), and finally evaluating that line at the missing point. My code is shown below, after a quick check of this small example. To make things easy, I assume the input data is a 3-D array containing nan values. I stipulate that a missing first or last element takes the value of its nearest valid right or left neighbor (i.e. limit_direction='both' in pandas), and I do not limit the maximum number of consecutive interpolations (i.e. no limit in pandas).
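As a quick check of that small 1-D example, the pandas behaviour can be reproduced directly (the interpolate_numba result quoted above is only the theoretical deduction, not re-run here):

import numpy as np
import pandas as pd

s = pd.Series([1, np.nan, np.nan, 4, np.nan, 6])
print(s.interpolate(method="linear").tolist())
# [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]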
import numpy as np
from numba import jit

@jit(nopython=True)
def f(arr_3d):
    result = np.zeros_like(arr_3d)
    for i in range(arr_3d.shape[1]):
        for j in range(arr_3d.shape[2]):
            arr = arr_3d[:, i, j]
            # If all elements are nan then we cannot conduct linear interpolation.
            if np.sum(np.isnan(arr)) == arr.shape[0]:
                result[:, i, j] = arr
            else:
                # If the first element is nan, assign the value of its nearest valid right neighbor to it.
                if np.isnan(arr[0]):
                    arr[0] = arr[~np.isnan(arr)][0]
                # If the last element is nan, assign the value of its nearest valid left neighbor to it.
                if np.isnan(arr[-1]):
                    arr[-1] = arr[~np.isnan(arr)][-1]
                # If an element in the middle is nan, do linear interpolation using neighbor values.
                for k in range(arr.shape[0]):
                    if np.isnan(arr[k]):
                        x = k
                        x1 = x - 1
                        x2 = x + 1
                        # Find the left neighbor whose value is not nan.
                        while x1 >= 0:
                            if np.isnan(arr[x1]):
                                x1 = x1 - 1
                            else:
                                y1 = arr[x1]
                                break
                        # Find the right neighbor whose value is not nan.
                        while x2 < arr.shape[0]:
                            if np.isnan(arr[x2]):
                                x2 = x2 + 1
                            else:
                                y2 = arr[x2]
                                break
                        # Calculate the slope and intercept determined by the left and right neighbors.
                        slope = (y2 - y1) / (x2 - x1)
                        intercept = y1 - slope * x1
                        # Linear interpolation and assignment.
                        y = slope * x + intercept
                        arr[x] = y
                result[:, i, j] = arr
    return result
Initializing a 3-D array containing some nans, I have checked that my code gives the same answer as the pandas wrapper. Going through the pandas wrapper code is a little more confusing, since pandas can only handle 2-dimensional data.
Using my code
y1=np.ones((2,2))
y2=y1+1
y3=y2+np.nan
y4=y2+2
y5=y1+np.nan
y6=y4+2
y1[1,1]=np.nan
y2[0,0]=np.nan
y4[1,1]=np.nan
y6[1,1]=np.nan
y=np.stack((y1,y2,y3,y4,y5,y6),axis=0)
print(y)
print("="*10)
f(y)
Using pandas wrapper
import pandas as pd
y1=np.ones((2,2)).flatten()
y2=y1+1
y3=y2+np.nan
y4=y2+2
y5=y1+np.nan
y6=y4+2
y1[3]=np.nan
y2[0]=np.nan
y4[3]=np.nan
y6[3]=np.nan
y=pd.DataFrame(np.stack([y1,y2,y3,y4,y5,y6],axis=0))
y=y.interpolate(method='linear', limit_direction='both', axis=0)
y_numpy=y.to_numpy()
y_numpy.shape=((6,2,2))
print(np.stack([y1,y2,y3,y4,y5,y6],axis=0).reshape(6,2,2))
print("="*10)
print(y_numpy)
The output will be the same:
[[[ 1. 1.]
[ 1. nan]]
[[nan 2.]
[ 2. 2.]]
[[nan nan]
[nan nan]]
[[ 4. 4.]
[ 4. nan]]
[[nan nan]
[nan nan]]
[[ 6. 6.]
[ 6. nan]]]
==========
[[[1. 1.]
[1. 2.]]
[[2. 2.]
[2. 2.]]
[[3. 3.]
[3. 2.]]
[[4. 4.]
[4. 2.]]
[[5. 5.]
[5. 2.]]
[[6. 6.]
[6. 2.]]]
Using the test_arr data from the question, increased to its full size of (92, 4800, 4800), as input, I found that only approximately 40 s was needed to complete the interpolation!
test_arr = np.random.randint(low=-10000, high=10000, size=(92, 4800, 4800)).astype(np.float32)  # float dtype so NaN can be assigned
test_arr[1:90:7, :, :] = np.nan  # missing values, as in the original data
test_arr[2, :, :] = np.nan
test_arr[:, 1:479:6, 1:479:8] = np.nan
%time f(test_arr)
Output
CPU times: user 32.5 s, sys: 9.13 s, total: 41.6 s
Wall time: 41.6 s
This depends; you will have to take out a sheet of paper and calculate the error your overall statistics would incur if you didn't interpolate at all and just zero-filled these NaNs.

Other than that, I think your interpolation is over the top. Just find each NaN and interpolate linearly from the four adjacent values (i.e. average the values at (y ± 1, x) and (y, x ± 1)). This will limit your error well enough (calculate it yourself!), and you don't have to interpolate with whatever more complex method applies in your case (you didn't define method).

You can try to just pre-compute one "averaged" 4800x4800 matrix per z value -- this shouldn't take long -- by applying a cross-shaped kernel across the matrix (it's all very image-processing-like here). In the presence of NaNs, some of the averaged values will themselves be NaN (every averaged pixel that had a NaN in its neighbourhood), but you don't care: unless there are two adjacent NaNs, the cells you want to replace in the original matrix all get real-valued averages.

Then you just replace all the NaNs with the values from the averaged matrix, as in the sketch below. Compare the speed of that with the speed of a "manual" calculation of the neighbourhood average for every NaN you find.
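A minimal sketch of that idea for a single (y, x) slice, assuming plain numpy and a 4-neighbour mean; the function name and edge padding are illustrative, and NaNs that have a NaN neighbour stay NaN, as noted above:

import numpy as np

def fill_nan_with_neighbour_mean(slice_2d):
    """Replace NaNs with the mean of their 4 direct neighbours (cross-shaped kernel)."""
    padded = np.pad(slice_2d, 1, mode="edge")
    # stack the up/down/left/right neighbour of every pixel
    neighbours = np.stack([
        padded[:-2, 1:-1],   # up
        padded[2:, 1:-1],    # down
        padded[1:-1, :-2],   # left
        padded[1:-1, 2:],    # right
    ])
    averaged = neighbours.mean(axis=0)  # NaN wherever any neighbour is NaN
    return np.where(np.isnan(slice_2d), averaged, slice_2d)

# applied per z slice of the (92, 4800, 4800) stack:
# filled = np.stack([fill_nan_with_neighbour_mean(arr[z]) for z in range(arr.shape[0])])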