 

Efficient 2D cross correlation in Python?

I have two arrays of size (n, m, m), i.e. n images of size (m, m) each. I want to compute the cross correlation between each corresponding pair of images in the two arrays.

Example: n=1 -> corr2d([m,m]_1, [m,m]_2). My current approach uses a for loop in Python:

for i in range(len(X)):
    X_co = X[i,0,:,:]/(np.max(X[i,0,:,:]))
    X_x = X[i,1,:,:]/(np.max(X[i,1,:,:]))
    autocorr[i,0,:,:]=correlate2d(X_co, X_x, mode='same', boundary='fill', fillvalue=0)

Obviously this is very slow when the input contains many images, and it becomes a substantial part of the total run time if (m,m) << n.

The obvious optimization is to skip the loop and feed everything directly to the compiled correlation function. Currently I'm using scipy's correlate2d. I've looked around but haven't found any function that allows correlation along some axis or multiple inputs.

Any tips on how to make scipy's correlate2d work for this, or on alternatives?

CupinaCoffee asked Jul 23 '18 18:07


1 Answer

I decided to implement it via the FFT instead.

import numpy as np

def fft_xcorr2D(x):
    # x has shape (N, 2, rows, cols); fft2 and ifft2 act on the last two axes by default.
    # Zero-pad rows and columns because the FFT computes a circular (cyclic) correlation.
    # A pad of 34 takes 30x30 images to 64x64.
    x = np.fft.fft2(np.pad(x, ([0, 0], [0, 0], [0, 34], [0, 34]), mode='constant'))

    # Conjugate one band to get correlation rather than convolution (convolution theorem)
    x[:, 1, :, :] = np.conj(x[:, 1, :, :])

    # Multiply the two bands elementwise (axis 1), transform back,
    # then fftshift over rows and columns to move the zero lag to the centre
    corr = np.fft.fftshift(np.fft.ifft2(np.prod(x, axis=1)), axes=(-2, -1))

    # Remove the padding: for 30x30 inputs this leaves the 59x59 full correlation
    return np.abs(corr)[:, 3:-2, 3:-2]

Call via:

ts=fft_xcorr2D(X)

If anybody wants to use it: My input is a 4D array: (N, 2, #Rows, #Cols)

E.g. (500, 2, 30, 30): 500 images, 2 bands (polarizations, for example), of 30x30 pixels

If your input is different, adjust the padding (and the final slice that removes it) to match your image size. Check that your axis order is the same as mine; otherwise change the axes arguments of fft2, ifft2, np.prod and fftshift. I use fftshift to get the maximum value in the middle (otherwise it ends up in the corners), so be wary of that if it's not what you want.
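The padding and the final slice can also be derived from the image size instead of being hard-coded. Below is a minimal sketch of that idea, under a few assumptions that are not in the answer above: the two bands are passed as separate real-valued arrays of shape (n, rows, cols), the function name batched_xcorr2d_full is made up for illustration, and rfft2/irfft2 are swapped in for fft2/ifft2 because the inputs are assumed real.

```python
import numpy as np

def batched_xcorr2d_full(a, b):
    """Full 2D cross-correlation of a[i] with b[i] for every i, via the FFT.

    Hypothetical helper: a and b are real arrays of shape (n, rows, cols).
    Returns shape (n, 2*rows - 1, 2*cols - 1) with zero lag at (rows - 1, cols - 1).
    """
    rows, cols = a.shape[-2:]
    # Pad to 2*size - 1 so circular wrap-around cannot leak into the linear result.
    shape = (2 * rows - 1, 2 * cols - 1)
    fa = np.fft.rfft2(a, s=shape, axes=(-2, -1))
    fb = np.fft.rfft2(b, s=shape, axes=(-2, -1))
    # Conjugating one spectrum turns the product into a correlation.
    corr = np.fft.irfft2(fa * np.conj(fb), s=shape, axes=(-2, -1))
    # Reorder lags from (0..max, -max..-1) to (-max..max).
    return np.fft.fftshift(corr, axes=(-2, -1))

rng = np.random.default_rng(0)
a = rng.random((500, 30, 30))
b = rng.random((500, 30, 30))
out = batched_xcorr2d_full(a, b)
print(out.shape)
```

Because the padded size is odd here, no trimming is needed after the shift. To mimic mode='same', one could crop the central rows x cols window, starting at index (rows - 1) // 2 and (cols - 1) // 2; it is worth comparing against correlate2d on a few images before relying on that.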

Why is it the maximum value? Technically it doesn't have to be, but for my purpose it is. fftshift is used to get a correlation laid out the way you're used to seeing it; without it, the quadrants are turned "inside out". To see what I mean, remove the fftshift call (keep its first argument, drop the axes argument), call the function as before, and plot the result.
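The effect of fftshift can be seen on a small case. The sketch below is an illustration only, assuming a random 8x8 image and a circular autocorrelation (no padding), where the peak is guaranteed to be the zero-lag term:

```python
import numpy as np

rng = np.random.default_rng(0)
img = rng.random((8, 8))

# Circular autocorrelation: spectrum times its own conjugate, transformed back.
spec = np.fft.fft2(img)
raw = np.fft.ifft2(spec * np.conj(spec)).real

# fftshift moves the zero-lag term from the corner to the centre.
shifted = np.fft.fftshift(raw, axes=(-2, -1))

peak_raw = tuple(int(i) for i in np.unravel_index(np.argmax(raw), raw.shape))
peak_shifted = tuple(int(i) for i in np.unravel_index(np.argmax(shifted), shifted.shape))
print(peak_raw, peak_shifted)
```

Without the shift the peak sits at index (0, 0); with it, the peak moves to (4, 4), the centre of the 8x8 grid.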

Afterwards, it should be ready to use. I read in an old post that x.prod(axis=1) might be faster than np.prod(x,axis=1), but when I tried it I saw no improvement.

CupinaCoffee answered Nov 15 '22 00:11