 

Efficient 2D mean filter implementation that minimises redundant memory loads?

Suppose a general sliding-window algorithm that executes some function over a kernel, such as a mean filter (average filter) or the sum-of-absolute-differences algorithm in image processing. As the kernel slides to the next position, some reads from memory are redundant, because the data enclosed by the new kernel largely overlaps that of the previous one.

Let me explain with a practical example... Suppose you want to perform a mean filter on a large 2D matrix with a kernel (window) size of 3x3. The first position of the kernel (red in the image below) is centered at (1,1); the second position (green) is centered at (1,2). Notice how the yellow area is the overlap, and these values now need to be reloaded from memory.

[figure "meanfilter": a 3x3 kernel at (1,1) (red) and at (1,2) (green); the yellow region is their overlap]

My specific problem is a 3D mean filter so the overlap is even bigger (3^3-3^2 = 18 for 3D vs 3^2-3 = 6 for 2D).

I'm sure this is a common problem... does anyone know how algorithms like this are implemented efficiently, either to eliminate the redundant memory lookups or to exploit the spatial and temporal locality of the CPU cache on modern architectures (e.g. a 2-way set-associative cache)?

My specific problem in 3D takes only the mean from the nearest 6 neighbours (not the diagonal ones) and is implemented in C as follows:

/* ONE_SIXTH is assumed to be defined elsewhere, e.g. (1.0f / 6.0f).
   Loop over interior points only, so the +/- 1 accesses below stay
   in bounds (maxi/maxj/maxk are the last valid indices; boundary
   handling omitted). */
for( i = 1; i < maxi; i++ ) {
    for( j = 1; j < maxj; j++ ) {
        for( k = 1; k < maxk; k++ ) {
            filteredData[ i ][ j ][ k ] =
            ONE_SIXTH *
            (
             data[ i + 1 ][ j     ][ k     ] +
             data[ i - 1 ][ j     ][ k     ] +
             data[ i     ][ j + 1 ][ k     ] +
             data[ i     ][ j - 1 ][ k     ] +
             data[ i     ][ j     ][ k + 1 ] +
             data[ i     ][ j     ][ k - 1 ]
            );
        }
    }
}
asked May 07 '11 by lmirosevic


2 Answers

What you are doing is called Convolution. You convolve the multidimensional data with a smaller kernel of the same number of dimensions. It is a very common task, and there are plenty of libraries for it.

A fast solution (depending on the kernel size) is to calculate the convolution in the frequency domain: you calculate the (multidimensional) FFT of both the data and the kernel, multiply them pointwise, and calculate the inverse FFT. You will find libraries optimized to do just that; e.g. for Python there is scipy.signal.fftconvolve (the FFT approach) and scipy.ndimage.filters.convolve (direct convolution).
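To give a flavour of it in C, here is a minimal sketch of frequency-domain convolution using FFTW's real-to-complex interface. The function name fft_convolve and the convention that the kernel is already zero-padded to the image size are assumptions of this sketch, and the result is a circular convolution (no edge padding):

#include <fftw3.h>

/* Sketch only: convolve an nx-by-ny real image with a kernel that has
   already been zero-padded to the same nx-by-ny size (circular). */
void fft_convolve( int nx, int ny, double *data, double *kernel, double *out )
{
    int i, nc = nx * ( ny / 2 + 1 );           /* length of r2c output */
    fftw_complex *fd = fftw_alloc_complex( nc );
    fftw_complex *fk = fftw_alloc_complex( nc );

    fftw_plan pd = fftw_plan_dft_r2c_2d( nx, ny, data,   fd, FFTW_ESTIMATE );
    fftw_plan pk = fftw_plan_dft_r2c_2d( nx, ny, kernel, fk, FFTW_ESTIMATE );
    fftw_plan pi = fftw_plan_dft_c2r_2d( nx, ny, fd,    out, FFTW_ESTIMATE );

    fftw_execute( pd );                        /* forward FFT of the data   */
    fftw_execute( pk );                        /* forward FFT of the kernel */

    for( i = 0; i < nc; i++ ) {                /* pointwise complex product */
        double re = fd[ i ][ 0 ] * fk[ i ][ 0 ] - fd[ i ][ 1 ] * fk[ i ][ 1 ];
        double im = fd[ i ][ 0 ] * fk[ i ][ 1 ] + fd[ i ][ 1 ] * fk[ i ][ 0 ];
        fd[ i ][ 0 ] = re;
        fd[ i ][ 1 ] = im;
    }

    fftw_execute( pi );                        /* inverse FFT into out      */

    for( i = 0; i < nx * ny; i++ )             /* FFTW transforms are       */
        out[ i ] /= (double)( nx * ny );       /* unnormalised, so rescale  */

    fftw_destroy_plan( pd ); fftw_destroy_plan( pk ); fftw_destroy_plan( pi );
    fftw_free( fd ); fftw_free( fk );
}

(For a tiny kernel like 3x3 a direct implementation is usually faster; the FFT route pays off as the kernel grows.)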

Tiling is a common image processing technique to optimize low-level memory access. You allocate square tiles (or cubes) that fit well into the CPU cache. When you access neighbouring pixels they will be close together in memory most of the time. Looping over the whole array gets a bit tricky, though.
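As a sketch of the idea for the 2D case (TILE is a tuning parameter you would pick so that one tile's working set fits in cache; the MIN macro, the array names and the boundary handling are my own simplifications):

#define TILE 64
#define MIN( a, b ) ( (a) < (b) ? (a) : (b) )

/* Visit the image in cache-sized blocks instead of full rows. */
for( ii = 1; ii < M - 1; ii += TILE ) {
    for( jj = 1; jj < N - 1; jj += TILE ) {
        iend = MIN( ii + TILE, M - 1 );
        jend = MIN( jj + TILE, N - 1 );
        for( i = ii; i < iend; i++ ) {
            for( j = jj; j < jend; j++ ) {
                mean[ i ][ j ] =
                    ( a[ i - 1 ][ j - 1 ] + a[ i - 1 ][ j ] + a[ i - 1 ][ j + 1 ] +
                      a[ i     ][ j - 1 ] + a[ i     ][ j ] + a[ i     ][ j + 1 ] +
                      a[ i + 1 ][ j - 1 ] + a[ i + 1 ][ j ] + a[ i + 1 ][ j + 1 ] ) / 9;
            }
        }
    }
}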

For further reading I recommend the paper Why Modern CPUs Are Starving and What Can Be Done about It, which discusses this memory-blocking technique and points to numerical libraries that implement it.

And finally there is the integral image (summed-area table), which allows you to calculate the average of an arbitrary rectangle/cuboid with just a very small number of memory accesses.
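A minimal sketch of the 2D case, assuming an ( M + 1 ) x ( N + 1 ) table S whose first row and column are zero (S, a and the corner names are mine):

/* Build the summed-area table: S[ i ][ j ] is the sum of all a[ y ][ x ]
   with y < i and x < j. */
for( i = 1; i <= M; i++ )
    for( j = 1; j <= N; j++ )
        S[ i ][ j ] = a[ i - 1 ][ j - 1 ]
                    + S[ i - 1 ][ j ] + S[ i ][ j - 1 ] - S[ i - 1 ][ j - 1 ];

/* Sum over the box with inclusive corners (y0, x0) and (y1, x1):
   4 lookups regardless of the box size. */
boxSum = S[ y1 + 1 ][ x1 + 1 ] - S[ y0 ][ x1 + 1 ]
       - S[ y1 + 1 ][ x0 ] + S[ y0 ][ x0 ];
boxMean = boxSum / ( ( y1 - y0 + 1 ) * ( x1 - x0 + 1 ) );

The same construction extends to 3D, with an inclusion-exclusion sum over the 8 corners of the cuboid.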

answered Oct 01 '22 by maxy


For the case of a 2D mean filter I would maintain column totals, which can then be reused: for each iteration you only calculate one new column total and then sum the three column totals to get the mean. E.g. for a 3x3 mean:

for (i = 1; i < M - 1; ++i)
{
    // init first two column sums
    col0 = a[i - 1][0] + a[i][0] + a[i + 1][0];
    col1 = a[i - 1][1] + a[i][1] + a[i + 1][1];
    for (j = 1; j < N - 1; ++j)
    {
        // calc new col sum
        col2 = a[i - 1][j + 1] + a[i][j + 1] + a[i + 1][j + 1];
        // calc new mean
        mean[i][j] = (col0 + col1 + col2) / 9;
        // shuffle col sums
        col0 = col1;
        col1 = col2;
    }
}

This results in only 3 loads per point, rather than 9 as in the naive case, but is still not quite optimal.

You can optimise this further by processing two rows per iteration and maintaining overlapping column sums for rows i and i + 1.
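A sketch of that two-row variant (declarations and any leftover odd row are omitted): rows i and i + 1 share the two middle input rows, so each inner iteration does 4 loads for 2 output points, i.e. 2 loads per point:

for (i = 1; i + 2 < M; i += 2)
{
    // shared middle rows for output rows i and i + 1
    s0 = a[i][0] + a[i + 1][0];
    s1 = a[i][1] + a[i + 1][1];
    colA0 = a[i - 1][0] + s0;  colB0 = s0 + a[i + 2][0];
    colA1 = a[i - 1][1] + s1;  colB1 = s1 + a[i + 2][1];
    for (j = 1; j < N - 1; ++j)
    {
        // 4 new loads give the new column sum for both output rows
        s2 = a[i][j + 1] + a[i + 1][j + 1];
        colA2 = a[i - 1][j + 1] + s2;
        colB2 = s2 + a[i + 2][j + 1];
        mean[i][j]     = (colA0 + colA1 + colA2) / 9;
        mean[i + 1][j] = (colB0 + colB1 + colB2) / 9;
        // shuffle both sets of column sums
        colA0 = colA1;  colA1 = colA2;
        colB0 = colB1;  colB1 = colB2;
    }
}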

answered Oct 01 '22 by Paul R