
Fast 1D linear np.NaN interpolation over large 3D array

I have a 3D array (z, y, x) with shape=(92, 4800, 4800) where each value along axis 0 represents a different point in time. The acquisition of values in the time domain failed in a few instances causing some values to be np.NaN. In other instances no values have been acquired and all values along z are np.NaN.

What is the most efficient way to use linear interpolation to fill np.NaN along axis 0 disregarding instances where all values are np.NaN?

Here is a working example of what I'm doing; it employs the pandas wrapper around scipy.interpolate.interp1d. This takes around 2 seconds per slice on the original dataset, meaning the whole array (4800 slices) is processed in about 2.6 hours. The reduced-size example dataset takes around 9.5 seconds.

import numpy as np
import pandas as pd

# create example data, original is (92, 4800, 4800)
test_arr = np.random.randint(low=-10000, high=10000, size=(92, 480, 480))
test_arr[1:90:7, :, :] = -32768  # NaN fill value in original data
test_arr[:, 1:90:6, 1:90:8] = -32768

def interpolate_nan(arr, method="linear", limit=3):
    """return array interpolated along time-axis to fill missing values"""
    result = np.zeros_like(arr, dtype=np.int16)

    for i in range(arr.shape[1]):
        # slice along y axis, interpolate with pandas wrapper to interp1d
        line_stack = pd.DataFrame(data=arr[:,i,:], dtype=np.float32)
        line_stack.replace(to_replace=-32768, value=np.NaN, inplace=True)
        line_stack.interpolate(method=method, axis=0, inplace=True, limit=limit)
        line_stack.replace(to_replace=np.NaN, value=-32768, inplace=True)
        result[:, i, :] = line_stack.values.astype(np.int16)
    return result

Performance on my machine with the example dataset:

%timeit interpolate_nan(test_arr)
1 loops, best of 3: 9.51 s per loop

Edit:

I should clarify that the code produces my expected outcome. The question is: how can I optimize this process?

asked Jun 18 '15 by Kersten

3 Answers

I recently solved this problem for my particular use case with the help of numba and also did a little writeup on it.

import numpy as np
from numba import jit

@jit(nopython=True)
def interpolate_numba(arr, no_data=-32768):
    """return array interpolated along time-axis to fill missing values"""
    result = np.zeros_like(arr, dtype=np.int16)

    for x in range(arr.shape[2]):
        # slice along x axis
        for y in range(arr.shape[1]):
            # slice along y axis
            for z in range(arr.shape[0]):
                value = arr[z,y,x]
                if z == 0:  # don't interpolate first value
                    new_value = value
                elif z == arr.shape[0] - 1:  # don't interpolate last value
                    new_value = value

                elif value == no_data:  # interpolate

                    left = arr[z-1,y,x]
                    right = arr[z+1,y,x]
                    # look for valid neighbours
                    if left != no_data and right != no_data:  # left and right are valid
                        new_value = (left + right) / 2

                    elif left == no_data and z == 1:  # boundary condition left
                        new_value = value
                    elif right == no_data and z == arr.shape[0] - 2:  # boundary condition right
                        new_value = value

                    elif left == no_data and right != no_data:  # take second neighbour to the left
                        more_left = arr[z-2,y,x]
                        if more_left == no_data:
                            new_value = value
                        else:
                            new_value = (more_left + right) / 2

                    elif left != no_data and right == no_data:  # take second neighbour to the right
                        more_right = arr[z+2,y,x]
                        if more_right == no_data:
                            new_value = value
                        else:
                            new_value = (more_right + left) / 2

                    elif left == no_data and right == no_data:  # take second neighbour on both sides
                        more_left = arr[z-2,y,x]
                        more_right = arr[z+2,y,x]
                        if more_left != no_data and more_right != no_data:
                            new_value = (more_left + more_right) / 2
                        else:
                            new_value = value
                    else:
                        new_value = value
                else:
                    new_value = value
                result[z,y,x] = int(new_value)
    return result

This is about 20 times faster than my initial code.
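
For reference, a minimal way to exercise the function (a sketch, reusing the example data from the question and assuming interpolate_numba from above is defined; the first call also pays numba's one-off JIT compilation cost):

import numpy as np

# example data as in the question, with -32768 as the no-data value
test_arr = np.random.randint(low=-10000, high=10000, size=(92, 480, 480)).astype(np.int16)
test_arr[1:90:7, :, :] = -32768
test_arr[:, 1:90:6, 1:90:8] = -32768

filled = interpolate_numba(test_arr)  # first call triggers JIT compilation
assert filled.shape == test_arr.shape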

answered Nov 08 '22 by Kersten


The questioner gave an excellent answer by taking advantage of numba. I really appreciate it, but I cannot entirely agree with the logic inside the interpolate_numba function. Linear interpolation at a given point does not mean taking the average of its left and right neighbors. For illustration, given the array [1, nan, nan, 4, nan, 6], the interpolate_numba function above will return [1, 2.5, 2.5, 4, 5, 6] (by theoretical deduction), whereas the pandas wrapper will return [1, 2, 3, 4, 5, 6]. Instead, linear interpolation at a given point should find its nearest valid left and right neighbors, use their values to determine a line (i.e. slope and intercept), and evaluate that line at the point. A quick pandas check of this claim appears below, followed by my code.

To keep things simple, I assume the input is a 3-D array containing nan values. I stipulate that the first and last elements are set equal to their nearest valid right and left neighbors respectively (i.e. limit_direction='both' in pandas), and I do not cap the number of consecutive interpolations (i.e. no limit in pandas).
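
A minimal 1-D snippet (pandas only) verifying the claimed behaviour:

import numpy as np
import pandas as pd

s = pd.Series([1, np.nan, np.nan, 4, np.nan, 6])
print(s.interpolate(method="linear").tolist())
# prints [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]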

import numpy as np
from numba import jit

@jit(nopython=True)
def f(arr_3d):
    result = np.zeros_like(arr_3d)
    for i in range(arr_3d.shape[1]):
        for j in range(arr_3d.shape[2]):
            # Work on a copy so the input array is not modified in place.
            arr = arr_3d[:, i, j].copy()
            # If all elements are nan then we cannot conduct linear interpolation.
            if np.sum(np.isnan(arr)) == arr.shape[0]:
                result[:, i, j] = arr
            else:
                # If the first element is nan, assign it the value of its nearest valid right neighbor.
                if np.isnan(arr[0]):
                    arr[0] = arr[~np.isnan(arr)][0]
                # If the last element is nan, assign it the value of its nearest valid left neighbor.
                if np.isnan(arr[-1]):
                    arr[-1] = arr[~np.isnan(arr)][-1]
                # For nan elements in the middle, do linear interpolation using neighbor values.
                for k in range(arr.shape[0]):
                    if np.isnan(arr[k]):
                        x = k
                        x1 = x - 1
                        x2 = x + 1
                        # Find the nearest left neighbor whose value is not nan.
                        while x1 >= 0:
                            if np.isnan(arr[x1]):
                                x1 = x1 - 1
                            else:
                                y1 = arr[x1]
                                break
                        # Find the nearest right neighbor whose value is not nan.
                        while x2 < arr.shape[0]:
                            if np.isnan(arr[x2]):
                                x2 = x2 + 1
                            else:
                                y2 = arr[x2]
                                break
                        # Slope and intercept of the line through the two neighbors.
                        slope = (y2 - y1) / (x2 - x1)
                        intercept = y1 - slope * x1
                        # Linear interpolation and assignment.
                        arr[x] = slope * x + intercept
                result[:, i, j] = arr
    return result

Initializing a 3-D array containing some nans, I have checked that my code gives the same answers as the pandas wrapper. Going through the pandas wrapper code is a little more confusing, since pandas can only address 2-dimensional data.

Using my code

y1=np.ones((2,2))
y2=y1+1
y3=y2+np.nan
y4=y2+2
y5=y1+np.nan
y6=y4+2
y1[1,1]=np.nan
y2[0,0]=np.nan
y4[1,1]=np.nan
y6[1,1]=np.nan
y=np.stack((y1,y2,y3,y4,y5,y6),axis=0)
print(y)
print("="*10)
print(f(y))

Using pandas wrapper

import pandas as pd
y1=np.ones((2,2)).flatten()
y2=y1+1
y3=y2+np.nan
y4=y2+2
y5=y1+np.nan
y6=y4+2
y1[3]=np.nan
y2[0]=np.nan
y4[3]=np.nan
y6[3]=np.nan
y=pd.DataFrame(np.stack([y1,y2,y3,y4,y5,y6],axis=0))
y=y.interpolate(method='linear', limit_direction='both', axis=0)
y_numpy=y.to_numpy()
y_numpy.shape = (6, 2, 2)
print(np.stack([y1,y2,y3,y4,y5,y6],axis=0).reshape(6,2,2))
print("="*10)
print(y_numpy)

Output will be the same

[[[ 1.  1.]
  [ 1. nan]]

 [[nan  2.]
  [ 2.  2.]]

 [[nan nan]
  [nan nan]]

 [[ 4.  4.]
  [ 4. nan]]

 [[nan nan]
  [nan nan]]

 [[ 6.  6.]
  [ 6. nan]]]
==========
[[[1. 1.]
  [1. 2.]]

 [[2. 2.]
  [2. 2.]]

 [[3. 3.]
  [3. 2.]]

 [[4. 4.]
  [4. 2.]]

 [[5. 5.]
  [5. 2.]]

 [[6. 6.]
  [6. 2.]]]

Using the test_arr data increased to its full size of (92, 4800, 4800) as input, I found that only approximately 40 s was needed to complete the interpolation!

test_arr = np.random.randint(low=-10000, high=10000, size=(92, 4800, 4800)).astype(np.float32)  # float dtype so the array can hold NaN
test_arr[1:90:7, :, :] = np.nan  # NaNs standing in for the no-data values in the original data
test_arr[2, :, :] = np.nan
test_arr[:, 1:479:6, 1:479:8] = np.nan
%time f(test_arr)

Output

CPU times: user 32.5 s, sys: 9.13 s, total: 41.6 s
Wall time: 41.6 s

answered Nov 08 '22 by Fei Yao


This depends; you will have to take out a sheet of paper and calculate the error your overall statistics would incur if you didn't interpolate and just zero-filled these NaNs.

Other than that, I think your interpolation is over the top. Just find each NaN and interpolate from the four adjacent values (that is, average the values at (y ± 1, x ± 1)); this will limit your error well enough (calculate it yourself!), and you don't have to interpolate with whatever complex method your case calls for (you didn't define method).

You can try to just pre-compute one "averaged" 4800x4800 matrix per z value (this shouldn't really take long) by applying a cross-shaped kernel across the matrix (it's all very image-processing-like here). Where there are NaNs, some of the averaged values will themselves be NaN (every averaged pixel that had a NaN in its neighborhood), but you don't care: unless there are two adjacent NaNs, the averaged values at the NaN cells you want to replace in the original matrix are all real-valued.

Then you just replace all the NaNs by the value in the averaged matrix.

Compare the speed of that with the speed of "manual" calculation of the neighborhood average for every NaN you find.
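
For what it's worth, here is a sketch of how such an averaged matrix could be built per z-slice (assuming scipy is available; fill_nan_with_neighbor_mean is an illustrative name, not code from any answer above):

import numpy as np
from scipy.ndimage import convolve

def fill_nan_with_neighbor_mean(slice_2d):
    """Replace each NaN with the mean of its valid (y +- 1, x +- 1) neighbors."""
    # cross-shaped kernel: the four direct neighbors, excluding the centre
    kernel = np.array([[0., 1., 0.],
                       [1., 0., 1.],
                       [0., 1., 0.]])
    valid = ~np.isnan(slice_2d)
    # per-cell sum of valid neighbor values and count of valid neighbors
    neighbor_sum = convolve(np.where(valid, slice_2d, 0.0), kernel, mode="constant", cval=0.0)
    neighbor_cnt = convolve(valid.astype(float), kernel, mode="constant", cval=0.0)
    with np.errstate(invalid="ignore", divide="ignore"):
        neighbor_mean = neighbor_sum / neighbor_cnt
    # overwrite only the NaN cells; cells with no valid neighbors stay NaN
    return np.where(valid, slice_2d, neighbor_mean)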

answered Nov 08 '22 by Marcus Müller