How to efficiently regrid a multi-spectral image?

Given a multi-spectral image with the following shape:

a = np.random.random([240, 320, 30])

where the last axis holds values at the following fractional wavelengths:

array([395.13, 408.62, 421.63, 434.71, 435.64, 453.39, 456.88, 471.48,
       484.23, 488.89, 497.88, 513.35, 521.38, 528.19, 539.76, 548.39,
       557.78, 568.06, 577.64, 590.22, 598.63, 613.13, 618.87, 632.75,
       637.5 , 647.47, 655.6 , 672.66, 681.88, 690.1 ])

What is the most efficient way (i.e. without iterating over every single wavelength) to regrid the data at integer wavelengths as follows:

array([400, 410, 420, 430, 440, 450, 460, 470, 480, 490, 500, 510, 520,
       530, 540, 550, 560, 570, 580, 590, 600, 610, 620, 630, 640, 650,
       660, 670, 680, 690])
    
Kel Solaar asked Oct 25 '25 10:10


2 Answers

It depends on the interpolation method and the physics that you deem appropriate.

From what you write, I would assume that the error along the spatial dimensions is negligible compared to the error in the wavelength. If that is the case, an N-dimensional interpolation is likely wrong, because each pixel's spectrum should be treated independently. Instead, what you need is a 1D interpolation along the wavelength axis, applied to every pixel.

The simplest (and fastest) form of interpolation is nearest neighbor. If the new wavelengths can be obtained by rounding the old ones with np.round(wavelengths, decimals=-1), the data is already interpolated and you only need to update the wavelength values.
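A minimal NumPy sketch of the nearest-neighbor case, assuming the wavelength arrays from the question. The first check confirms that rounding to the nearest multiple of 10 maps the measured wavelengths one-to-one onto the requested grid. nearest_last_axis is a hypothetical helper name for the general case: it picks, for every target wavelength, the closest measured band and fancy-indexes it for all pixels at once, so there is no per-pixel loop.

```python
import numpy as np

old_w = np.array([395.13, 408.62, 421.63, 434.71, 435.64, 453.39, 456.88, 471.48,
                  484.23, 488.89, 497.88, 513.35, 521.38, 528.19, 539.76, 548.39,
                  557.78, 568.06, 577.64, 590.22, 598.63, 613.13, 618.87, 632.75,
                  637.5, 647.47, 655.6, 672.66, 681.88, 690.1])
new_w = np.arange(400, 700, 10)

# Rounding to the nearest multiple of 10 lands exactly on the target grid here
print(np.array_equal(np.round(old_w, decimals=-1), new_w))  # True


def nearest_last_axis(arr, samples, new_samples):
    """Hypothetical helper: nearest-neighbor regridding along the last axis."""
    # (n_new, n_old) table of distances; argmin gives the closest band per target
    idx = np.abs(np.subtract.outer(new_samples, samples)).argmin(axis=1)
    # Fancy indexing selects those bands for every pixel in one step
    return arr[..., idx]


a = np.random.random([240, 320, 30])
b = nearest_last_axis(a, old_w, new_w)
```

For these particular wavelengths the nearest band for each target is the one that rounds to it, so b ends up equal to a; with other grids, idx may repeat or skip bands.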

If the new wavelengths are not simply the old ones rounded, or if nearest-neighbor interpolation is not what you want, you need a different approach, which at some point will involve looping over the pixels.

SciPy offers scipy.interpolate.interp1d(), which does exactly that in a vectorized fashion (i.e. the loop over the pixels is pushed out of Python and into compiled code) and supports a variety of interpolation methods.

For example, if samples holds the measured wavelengths, new_samples the new ones, and arr the stacked images with the last axis running across wavelengths:

import scipy.interpolate


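def interp_1d_last_sp(arr, samples, new_samples, kind="linear"):
    # One interpolator for all pixels, interpolating along the last (wavelength) axis
    interpolator = scipy.interpolate.interp1d(samples, arr, axis=-1, kind=kind, fill_value="extrapolate")
    # Result shape: arr.shape[:-1] + (len(new_samples),)
    return interpolator(new_samples)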
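As a usage sketch, assuming the question's wavelength arrays and a random cube, the kind argument of interp1d() switches the interpolation method while the call stays vectorized over all pixels. All requested wavelengths lie inside the measured range here, so no fill_value is passed:

```python
import numpy as np
import scipy.interpolate

old_w = np.array([395.13, 408.62, 421.63, 434.71, 435.64, 453.39, 456.88, 471.48,
                  484.23, 488.89, 497.88, 513.35, 521.38, 528.19, 539.76, 548.39,
                  557.78, 568.06, 577.64, 590.22, 598.63, 613.13, 618.87, 632.75,
                  637.5, 647.47, 655.6, 672.66, 681.88, 690.1])
new_w = np.arange(400, 700, 10)
a = np.random.random([240, 320, 30])

results = {}
for kind in ("nearest", "linear", "cubic"):
    # Same call for every method; only `kind` changes
    f = scipy.interpolate.interp1d(old_w, a, axis=-1, kind=kind)
    results[kind] = f(new_w)
    print(kind, results[kind].shape)
```

With kind="nearest" and these wavelengths, the result simply reproduces the original bands.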

Something similar can be computed manually with an explicit loop, at least for linear interpolation, using the much faster np.interp() function:

import numpy as np


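def interp_1d_last_np(arr, samples, new_samples):
    shape = arr.shape
    k = shape[-1]
    m = len(new_samples)
    # View the image as (pixels, wavelengths) and interpolate one pixel at a time
    arr = arr.reshape((-1, k))
    n = arr.shape[0]
    # The output has len(new_samples) bands, which need not be equal to k
    result = np.empty((n, m))
    for i in range(n):
        result[i, :] = np.interp(new_samples, samples, arr[i, :])
    return result.reshape(shape[:-1] + (m,))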

np.interp() does not support multidimensional input, hence the explicit loop over the pixels, but the loop can be accelerated with Numba:

import numba as nb


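@nb.njit(parallel=True)
def interp_1d_last_nb(arr, samples, new_samples):
    shape = arr.shape
    k = shape[-1]
    m = new_samples.size
    arr = arr.reshape((-1, k))
    n = arr.shape[0]
    # Allocate len(new_samples) output bands (may differ from k)
    result = np.empty((n, m))
    # prange spreads the independent per-pixel interpolations across threads
    for i in nb.prange(n):
        result[i, :] = np.interp(new_samples, samples, arr[i, :])
    return result.reshape(shape[:-1] + (m,))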

While they all produce similar results (they handle extrapolation slightly differently: interp1d() with fill_value="extrapolate" extrapolates linearly, whereas np.interp() holds the edge values), the timings differ:

np.random.seed(0)
a = np.random.random([240, 320, 30])
w_min = 400
w_max = 700
w_step = 10
w = np.arange(w_min, w_max, w_step)
nw = w_step * (np.random.random(w.size) - 0.5)


funcs = interp_1d_last_sp, interp_1d_last_np, interp_1d_last_nb
base = funcs[0](a, w + nw, w)
for func in funcs:
    res = func(a, w + nw, w)
    # the sum of the absolute difference with the non-interpolated array is reasonably the same
    is_good = np.isclose(np.sum(np.abs(base - a)), np.sum(np.abs(res - a)))
    print(f"{func.__name__:>12s}  {is_good!s:>5}  {np.sum(np.abs(res - a)):16.8f}  ", end="")
    %timeit -n 2 -r 2 func(a, w + nw, w)  # IPython/Jupyter magic, not plain Python
# interp_1d_last_sp   True   140136.05282911  90.8 ms ± 21.7 ms per loop (mean ± std. dev. of 2 runs, 2 loops each)
# interp_1d_last_np   True   140136.05282911  349 ms ± 14.2 ms per loop (mean ± std. dev. of 2 runs, 2 loops each)
# interp_1d_last_nb   True   140136.05282911  61.5 ms ± 1.83 ms per loop (mean ± std. dev. of 2 runs, 2 loops each)
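For the linear case, the per-pixel loop can also be avoided entirely with plain NumPy: np.searchsorted() finds the two measured bands bracketing each target wavelength, and fancy indexing blends those two bands for all pixels at once. This is a sketch under stated assumptions (samples strictly increasing; interp_1d_last_idx is a hypothetical name) and it was not part of the benchmark above, so no timing is claimed:

```python
import numpy as np


def interp_1d_last_idx(arr, samples, new_samples):
    """Hypothetical helper: linear interpolation along the last axis, no Python loop.

    Assumes `samples` is strictly increasing; targets outside its range are
    extrapolated linearly from the two outermost samples.
    """
    samples = np.asarray(samples, dtype=float)
    new_samples = np.asarray(new_samples, dtype=float)
    # Index of the left neighbor of each target, clipped so that i + 1 stays valid
    i = np.clip(np.searchsorted(samples, new_samples) - 1, 0, samples.size - 2)
    # Fractional position of each target between its two neighbors
    t = (new_samples - samples[i]) / (samples[i + 1] - samples[i])
    # Pick both bracketing bands for every pixel; `t` broadcasts along the last axis
    return arr[..., i] * (1.0 - t) + arr[..., i + 1] * t


old_w = np.array([395.13, 408.62, 421.63, 434.71, 435.64, 453.39, 456.88, 471.48,
                  484.23, 488.89, 497.88, 513.35, 521.38, 528.19, 539.76, 548.39,
                  557.78, 568.06, 577.64, 590.22, 598.63, 613.13, 618.87, 632.75,
                  637.5, 647.47, 655.6, 672.66, 681.88, 690.1])
new_w = np.arange(400, 700, 10)
a = np.random.random([240, 320, 30])
b = interp_1d_last_idx(a, old_w, new_w)

# Outside the sampled range it extrapolates, whereas np.interp() clamps
x = np.array([0.0, 1.0])
y = np.array([0.0, 10.0])
print(interp_1d_last_idx(y, x, [2.0]))  # [20.]
print(np.interp([2.0], x, y))           # [10.]
```

The extrapolation rule here matches interp1d(..., fill_value="extrapolate") rather than np.interp(); inside the measured range all of them agree.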

norok2 answered Oct 27 '25 00:10


Edited the solution to be fully vectorized, with no loops. It is probably much more memory-intensive, though.

Note that the vectorized version is slower than simply looping through the wavelengths: loops aren't always bad.

import numpy as np
import scipy.interpolate as interpolate

a = np.random.random([240, 320, 30])

wavelengths1 = np.array([395.13, 408.62, 421.63, 434.71, 435.64, 453.39, 456.88, 471.48,
       484.23, 488.89, 497.88, 513.35, 521.38, 528.19, 539.76, 548.39,
       557.78, 568.06, 577.64, 590.22, 598.63, 613.13, 618.87, 632.75,
       637.5 , 647.47, 655.6 , 672.66, 681.88, 690.1 ])
wavelengths2 = np.array([400, 410, 420, 430, 440, 450, 460, 470, 480, 490, 500, 510, 520,
       530, 540, 550, 560, 570, 580, 590, 600, 610, 620, 630, 640, 650,
       660, 670, 680, 690])

# Grid axes follow the axis order of `a`: rows (y), columns (x), wavelength
interp = interpolate.RegularGridInterpolator((np.arange(a.shape[0]), np.arange(a.shape[1]), wavelengths1), a)

# Create new grid based on the integer wavelengths desired above
X, Y = np.meshgrid(np.arange(a.shape[1]), np.arange(a.shape[0]))

X2 = np.repeat(X[:, :, None], len(wavelengths2), axis=-1)

Y2 = np.repeat(Y[:, :, None], len(wavelengths2), axis=-1)

# Query points in the same (y, x, wavelength) order as the grid axes
a2 = interp((Y2, X2, wavelengths2[None, None, :]))
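Because the query points lie exactly on the pixel grid, the spatial part of this trilinear interpolation is exact, so the result should reduce to an independent 1D linear interpolation along wavelength for every pixel. A small self-contained check of that, using a reduced random cube and the first five wavelengths from the question as stand-ins (an assumption made only to keep the reference loop cheap), with the grid axes in the same row, column, wavelength order as the array:

```python
import numpy as np
from scipy.interpolate import RegularGridInterpolator

rng = np.random.default_rng(0)
a = rng.random((6, 8, 5))  # small stand-in for the 240 x 320 x 30 cube
w_old = np.array([395.13, 408.62, 421.63, 434.71, 435.64])
w_new = np.array([400.0, 410.0, 420.0, 430.0])

# Grid axes in the same order as the axes of `a`: rows, columns, wavelength
interp = RegularGridInterpolator(
    (np.arange(a.shape[0]), np.arange(a.shape[1]), w_old), a)

# Query every pixel at every new wavelength; the coordinate tuple is broadcast
Y, X = np.meshgrid(np.arange(a.shape[0]), np.arange(a.shape[1]), indexing="ij")
a2 = interp((Y[:, :, None], X[:, :, None], w_new[None, None, :]))

# Reference: independent 1D linear interpolation for every pixel
ref = np.array([[np.interp(w_new, w_old, a[i, j]) for j in range(a.shape[1])]
                for i in range(a.shape[0])])
print(a2.shape, np.allclose(a2, ref))
```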
NanoBennett answered Oct 27 '25 00:10


