 

Speeding up iteration over Numpy arrays

I am working on performing image processing using Numpy, specifically a running standard deviation stretch. This reads in X columns at a time, computes the standard deviation, and performs a percentage linear stretch. It then iterates to the next "group" of columns and performs the same operations. The input image is a 1 GB, 32-bit, single-band raster which is taking quite a long time (hours) to process. Below is the code.

I realize that I have 3 nested for loops, which is presumably where the bottleneck is occurring. If I process the image in "boxes", that is to say loading a [500, 500] array and iterating through the image, processing time is quite short. Unfortunately, camera error requires that I iterate in extremely long strips (52,000 x 4) (y, x) to avoid banding.

Any suggestions on speeding this up would be appreciated:

import numpy
from osgeo import gdal

def box(dataset, outdataset, sampleSize, n):

    quiet = 0
    sample = sampleSize
    #iterate over all of the bands
    for j in xrange(1, dataset.RasterCount + 1): #1 based counter

        band = dataset.GetRasterBand(j)
        NDV = band.GetNoDataValue()

        print "Processing band: " + str(j)       

        #define the interval at which blocks are created
        intervalY = int(band.YSize/1)    #full strip height: every row in one strip
        intervalX = int(band.XSize/2000) #to be changed to sampleSize when working

        #iterate through the rows
        scanBlockCounter = 0

        for i in xrange(0,band.YSize,intervalY):

            #If the next i is going to fail due to the edge of the image/array
            if i + (intervalY*2) < band.YSize:
                numberRows = intervalY
            else:
                numberRows = band.YSize - i

            for h in xrange(0,band.XSize, intervalX):

                if h + (intervalX*2) < band.XSize:
                    numberColumns = intervalX
                else:
                    numberColumns = band.XSize - h

                scanBlock = band.ReadAsArray(h,i,numberColumns, numberRows).astype(numpy.float)

                standardDeviation = numpy.std(scanBlock)
                mean = numpy.mean(scanBlock)

                newMin = mean - (standardDeviation * n)
                newMax = mean + (standardDeviation * n)

                outputBlock = ((scanBlock - newMin)/(newMax-newMin))*255
                outRaster = outdataset.GetRasterBand(j).WriteArray(outputBlock,h,i)#array, xOffset, yOffset


                scanBlockCounter = scanBlockCounter + 1
                #print str(scanBlockCounter) + ": " + str(scanBlock.shape) + str(h)+ ", " + str(intervalX)
                if numberColumns == band.XSize - h:
                    break

                #update progress line
                if not quiet:
                    gdal.TermProgress_nocb( float(h+1) / band.XSize ) #fraction of columns processed

Here is an update: Rather than use the profile module (I did not want to start wrapping small sections of the code into functions), I used a mix of print and exit statements to get a really rough idea of which lines were taking the most time. Luckily (and I do understand how lucky I was) one line was dragging everything down.

    outRaster = outdataset.GetRasterBand(j).WriteArray(outputBlock,h,i)#array, xOffset, yOffset
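
A slightly less crude version of the same trick is to bracket suspect statements with wall-clock timestamps. A self-contained toy version of that technique (the array and the block math are stand-ins for illustration, not the script's actual data):

import time
import numpy

a = numpy.random.rand(2000, 2000)

start = time.time()
result = ((a - a.mean()) / (4 * a.std())) * 255  #stand-in for one block's stretch math
print "elapsed: %.3f s" % (time.time() - start)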

It appears that GDAL is quite inefficient when opening the output file and writing out the array one small block at a time. With this in mind I decided to append my modified "outputBlock" arrays to a Python list, then write out larger chunks at once. Here is the segment that I changed:

The outputBlock was just modified ...

            #Add the block to a list
            outputArrayList.append(outputBlock)

            #Check how many blocks have accumulated and, if it is "time", write out the array
            if len(outputArrayList) >= (intervalX * writeSize) or finisher == 1:

                #Horizontally stack the list of 2-d blocks into a single 2-d array
                stacked = numpy.hstack(outputArrayList)

                #Write out the stacked array in one call
                outRaster = outdataset.GetRasterBand(j).WriteArray(stacked,xOffset,i)#array, xOffset, yOffset
                xOffset = xOffset + (intervalX*(intervalX * writeSize))

                #Cleanup to conserve memory
                outputArrayList = list()
                stacked = None
                finisher = 0

"finisher" is simply a flag that handles the edges. It took a bit of time to figure out how to build an array from the list: using numpy.array was creating a 3-d array (anyone care to explain why?) and WriteArray requires a 2-d array. Total processing time now varies from just under 2 minutes to 5 minutes. Any idea why that range of times might exist?
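
Here is a toy reproduction of what I saw (my arrays, not the real blocks):

>>> import numpy
>>> blocks = [numpy.ones((3, 2)), numpy.zeros((3, 2))]
>>> numpy.array(blocks).shape   # stacks along a new leading axis -- 3-d
(2, 3, 2)
>>> numpy.hstack(blocks).shape  # concatenates columns -- stays 2-d
(3, 4)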

Many thanks to everyone who posted! The next step is to really get into Numpy and learn about vectorization for additional optimization.

asked Jul 11 '11 by Jzl5325



2 Answers

One way to speed up operations over numpy data is to use vectorize. Essentially, vectorize takes a function f and creates a new function g that maps f over an array a. g is then called like so: g(a).

>>> sqrt_vec = numpy.vectorize(lambda x: x ** 0.5)
>>> sqrt_vec(numpy.arange(10))
array([ 0.        ,  1.        ,  1.41421356,  1.73205081,  2.        ,
        2.23606798,  2.44948974,  2.64575131,  2.82842712,  3.        ])

Without having the data you're working with available, I can't say for certain whether this will help, but perhaps you can rewrite the above as a set of functions that can be vectorized. Perhaps in this case you could vectorize over an array of indices into ReadAsArray(h,i,numberColumns, numberRows). Here's an example of the potential benefit:

>>> print setup1
import numpy
sqrt_vec = numpy.vectorize(lambda x: x ** 0.5)
>>> print setup2
import numpy
def sqrt_vec(a):
    r = numpy.zeros(len(a))
    for i in xrange(len(a)):
        r[i] = a[i] ** 0.5
    return r
>>> import timeit
>>> timeit.timeit(stmt='a = sqrt_vec(numpy.arange(1000000))', setup=setup1, number=1)
0.30318188667297363
>>> timeit.timeit(stmt='a = sqrt_vec(numpy.arange(1000000))', setup=setup2, number=1)
4.5400981903076172

A 15x speedup! Note also that numpy slicing handles the edges of ndarrays elegantly:

>>> a = numpy.arange(25).reshape((5, 5))
>>> a[3:7, 3:7]
array([[18, 19],
       [23, 24]])

So if you could get your ReadAsArray data into an ndarray you wouldn't have to do any edge-checking shenanigans.
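
For example, iterating over strips of an in-memory array needs no special edge cases at all; a minimal sketch with a toy array standing in for the raster (the shapes and the strip width of 4 are made up):

import numpy

data = numpy.arange(50).reshape((10, 5))   #stand-in for a full band held in memory
stripWidth = 4
for h in xrange(0, data.shape[1], stripWidth):
    strip = data[:, h:h + stripWidth]      #a slice past the edge is clipped, never padded
    print h, strip.shape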


Regarding your question about reshaping -- reshaping doesn't fundamentally alter the data at all. It just changes the "strides" by which numpy indexes the data. When you call the reshape method, the value returned is a new view into the data; the data isn't copied or altered at all, and the old view, with its old stride information, is left untouched.

>>> a = numpy.arange(25)
>>> b = a.reshape((5, 5))
>>> a
array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
       17, 18, 19, 20, 21, 22, 23, 24])
>>> b
array([[ 0,  1,  2,  3,  4],
       [ 5,  6,  7,  8,  9],
       [10, 11, 12, 13, 14],
       [15, 16, 17, 18, 19],
       [20, 21, 22, 23, 24]])
>>> a[5]
5
>>> b[1][0]
5
>>> a[5] = 4792
>>> b[1][0]
4792
>>> a.strides
(8,)
>>> b.strides
(40, 8)
answered Sep 23 '22 by senderle

Answered as requested.

If you are I/O bound, you should chunk your reads/writes. Try dumping ~500 MB of data to an ndarray, processing it all, writing it out, and then grabbing the next ~500 MB. Make sure to reuse the ndarray.
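
A rough sketch of that pattern, assuming dataset and outdataset are open GDAL datasets as in the question (the 5000-row chunk height is an arbitrary stand-in for whatever works out to ~500 MB):

import numpy

band = dataset.GetRasterBand(1)
outBand = outdataset.GetRasterBand(1)

chunkRows = 5000  #choose so that chunkRows * band.XSize * 4 bytes is ~500 MB
for y in xrange(0, band.YSize, chunkRows):
    rows = min(chunkRows, band.YSize - y)  #clip the last chunk at the bottom edge
    block = band.ReadAsArray(0, y, band.XSize, rows).astype(numpy.float32)
    #...process the whole chunk with array-wide operations...
    outBand.WriteArray(block, 0, y)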

answered Sep 23 '22 by matt