 

Speed up numpy.where for extracting integer segments?

I'm trying to work out how to speed up a Python function which uses numpy. The output I have received from line_profiler is below, and it shows that the vast majority of the time is spent on the line ind_y, ind_x = np.where(seg_image == i).

seg_image is an integer array which is the result of segmenting an image, thus finding the pixels where seg_image == i extracts a specific segmented object. I am looping through lots of these objects (in the code below I'm just looping through 5 for testing, but I'll actually be looping through over 20,000), and it takes a long time to run!

Is there any way the np.where call can be sped up? Or, alternatively, can the penultimate line (which also takes a good proportion of the time) be sped up?

The ideal solution would be to run the code on the whole array at once, rather than looping, but I don't think this is possible as there are side-effects to some of the functions I need to run (for example, dilating a segmented object can make it 'collide' with the next region and thus give incorrect results later on).

Does anyone have any ideas?

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
     5                                           def correct_hot(hot_image, seg_image):
     6         1       239810 239810.0      2.3      new_hot = hot_image.copy()
     7         1       572966 572966.0      5.5      sign = np.zeros_like(hot_image) + 1
     8         1        67565  67565.0      0.6      sign[:,:] = 1
     9         1      1257867 1257867.0     12.1      sign[hot_image > 0] = -1
    10                                           
    11         1          150    150.0      0.0      s_elem = np.ones((3, 3))
    12                                           
    13                                               #for i in xrange(1,seg_image.max()+1):
    14         6           57      9.5      0.0      for i in range(1,6):
    15         5      6092775 1218555.0     58.5          ind_y, ind_x = np.where(seg_image == i)
    16                                           
    17                                                   # Get the average HOT value of the object (really simple!)
    18         5         2408    481.6      0.0          obj_avg = hot_image[ind_y, ind_x].mean()
    19                                           
    20         5          333     66.6      0.0          miny = np.min(ind_y)
    21                                                   
    22         5          162     32.4      0.0          minx = np.min(ind_x)
    23                                                   
    24                                           
    25         5          369     73.8      0.0          new_ind_x = ind_x - minx + 3
    26         5          113     22.6      0.0          new_ind_y = ind_y - miny + 3
    27                                           
    28         5          211     42.2      0.0          maxy = np.max(new_ind_y)
    29         5          143     28.6      0.0          maxx = np.max(new_ind_x)
    30                                           
    31                                                   # 7 is + 1 to deal with the zero-based indexing, + 2 * 3 to deal with the 3 cell padding above
    32         5          217     43.4      0.0          obj = np.zeros( (maxy+7, maxx+7) )
    33                                           
    34         5          158     31.6      0.0          obj[new_ind_y, new_ind_x] = 1
    35                                           
    36         5         2482    496.4      0.0          dilated = ndimage.binary_dilation(obj, s_elem)
    37         5         1370    274.0      0.0          border = mahotas.borders(dilated)
    38                                           
    39         5          122     24.4      0.0          border = np.logical_and(border, dilated)
    40                                           
    41         5          355     71.0      0.0          border_ind_y, border_ind_x = np.where(border == 1)
    42         5          136     27.2      0.0          border_ind_y = border_ind_y + miny - 3
    43         5          123     24.6      0.0          border_ind_x = border_ind_x + minx - 3
    44                                           
    45         5          645    129.0      0.0          border_avg = hot_image[border_ind_y, border_ind_x].mean()
    46                                           
    47         5      2167729 433545.8     20.8          new_hot[seg_image == i] = (new_hot[ind_y, ind_x] + (sign[ind_y, ind_x] * np.abs(obj_avg - border_avg)))
    48         5        10179   2035.8      0.1          print obj_avg, border_avg
    49                                           
    50         1            4      4.0      0.0      return new_hot
asked Jul 08 '13 by robintw

2 Answers

EDIT: I have left my original answer at the bottom for the record, but I have actually looked into your code in more detail over lunch, and I think that using np.where is a big mistake:

In [63]: a = np.random.randint(100, size=(1000, 1000))

In [64]: %timeit a == 42
1000 loops, best of 3: 950 us per loop

In [65]: %timeit np.where(a == 42)
100 loops, best of 3: 7.55 ms per loop

You could get a boolean array (that you can use for indexing) in 1/8 of the time you need to get the actual coordinates of the points!!!
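For example, here is a minimal sketch of the swap (the array and target value are illustrative):

import numpy as np

a = np.random.randint(100, size=(1000, 1000))

# Coordinate-based indexing: pays for np.where up front
ind_y, ind_x = np.where(a == 42)
vals_where = a[ind_y, ind_x]

# Mask-based indexing: the comparison alone is enough
mask = a == 42
vals_mask = a[mask]

# Both extract the same pixels, in the same (row-major) order
assert np.array_equal(vals_where, vals_mask)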

There is of course the cropping of the features that you do, but ndimage has a find_objects function that returns enclosing slices, and appears to be very fast:

In [66]: %timeit ndimage.find_objects(a)
100 loops, best of 3: 11.5 ms per loop

This returns a list of tuples of slices enclosing all of your objects, in only 50% more time than it takes to find the indices of one single object.
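In case the return format is unfamiliar, here is a toy example (the labels and shapes are made up for illustration):

import numpy as np
from scipy import ndimage

seg = np.zeros((8, 8), dtype=int)
seg[1:3, 1:4] = 1  # object with label 1
seg[5:7, 4:8] = 2  # object with label 2

# One tuple of slices per label, in label order, so slices[j]
# is the bounding box of the pixels where seg == j + 1:
slices = ndimage.find_objects(seg)
# [(slice(1, 3, None), slice(1, 4, None)),
#  (slice(5, 7, None), slice(4, 8, None))]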

It may not work out of the box as I cannot test it right now, but I would restructure your code into something like the following:

import numpy as np
from scipy import ndimage
import mahotas

def correct_hot_bis(hot_image, seg_image):
    # Need this to not index out of bounds when computing border_avg
    hot_image_padded = np.pad(hot_image, 3, mode='constant',
                              constant_values=0)
    new_hot = hot_image.copy()
    sign = np.ones_like(hot_image, dtype=np.int8)
    sign[hot_image > 0] = -1
    s_elem = np.ones((3, 3))

    for j, slice_ in enumerate(ndimage.find_objects(seg_image)):
        hot_image_view = hot_image[slice_]
        seg_image_view = seg_image[slice_]
        new_shape = tuple(dim+6 for dim in hot_image_view.shape)
        new_slice = tuple(slice(dim.start,
                                dim.stop+6,
                                None) for dim in slice_)
        indices = seg_image_view == j+1

        obj_avg = hot_image_view[indices].mean()

        obj = np.zeros(new_shape)
        obj[3:-3, 3:-3][indices] = True

        dilated = ndimage.binary_dilation(obj, s_elem)
        border = mahotas.borders(dilated)
        border &= dilated

        border_avg = hot_image_padded[new_slice][border == 1].mean()

        new_hot[slice_][indices] += (sign[slice_][indices] *
                                     np.abs(obj_avg - border_avg))

    return new_hot
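
A note in case you adapt this: new_slice just re-expresses the bounding box in the padded image's coordinates (index i in hot_image is index i + 3 in hot_image_padded, so growing the stop by 6 extends the box by 3 pixels on each side), which is what lets the dilated border be sampled without indexing out of bounds.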

You would still need to figure out the collisions, but you could get about a 2x speed-up by computing all the indices simultaneously with an np.unique-based approach:

import numpy as np

a = np.random.randint(100, size=(1000, 1000))

def get_pos(arr):
    pos = []
    for j in xrange(100):
        pos.append(np.where(arr == j))
    return pos

def get_pos_bis(arr):
    unq, flat_idx = np.unique(arr, return_inverse=True)
    pos = np.argsort(flat_idx)
    counts = np.bincount(flat_idx)
    cum_counts = np.cumsum(counts)
    multi_dim_idx = np.unravel_index(pos, arr.shape)
    return zip(*(np.split(coords, cum_counts) for coords in multi_dim_idx))

In [33]: %timeit get_pos(a)
1 loops, best of 3: 766 ms per loop

In [34]: %timeit get_pos_bis(a)
1 loops, best of 3: 388 ms per loop

Note that the pixels for each object are returned in a different order, so you can't simply compare the returns of the two functions to assess equality. But they should both return the same pixels for each object.
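If you do want to check that they agree, one option (a sketch reusing a, get_pos and get_pos_bis from above) is to compare each object's pixels after sorting them into a canonical order:

pos_a = get_pos(a)
pos_b = get_pos_bis(a)
for (ya, xa), (yb, xb) in zip(pos_a, pos_b):
    # Flatten the 2D coordinates and sort before comparing
    flat_a = np.sort(ya * a.shape[1] + xa)
    flat_b = np.sort(yb * a.shape[1] + xb)
    assert np.array_equal(flat_a, flat_b)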

answered Oct 21 '22 by Jaime

One thing you could do to save a little bit of time is to cache the result of seg_image == i so that you don't need to compute it twice. You're computing it on lines 15 and 47; you could add seg_mask = seg_image == i and then reuse that result (it might also be good to separate that piece out for profiling purposes).
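A minimal sketch of that refactor (correct_hot_cached and the placeholder update are illustrative, not your full per-object logic):

import numpy as np

def correct_hot_cached(hot_image, seg_image, n_objects=5):
    new_hot = hot_image.copy()
    for i in range(1, n_objects + 1):
        seg_mask = seg_image == i         # computed once, reused below
        obj_avg = hot_image[seg_mask].mean()
        new_hot[seg_mask] = obj_avg       # placeholder for the real update
    return new_hot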

While there are some other minor things you could do to eke out a little more performance, the root issue is that you're using an O(M * N) algorithm, where M is the number of segments and N is the size of your image. It's not obvious to me from your code whether there is a faster algorithm that accomplishes the same thing, but that's the first place I'd look for a speedup.

answered Oct 21 '22 by Bi Rico