
Speed up computation for Distance Transform on Image in Python

I would like to find the distance transform of a binary image in the fastest way possible without using scipy's distance_transform_edt(). The image is 256 by 256. The reason I don't want to use scipy is that it is difficult to use inside TensorFlow: every time I want to call this function I need to start a new session, which takes a lot of time. So I would like to write a custom function that only uses numpy.

My approach is as follows: find the coordinates of all the ones and all the zeros in the image, compute the Euclidean distance between each zero pixel (a) and every one pixel (b), and set the value at each (a) position to the minimum distance to a (b) pixel. I do this for every 0 pixel. The resulting image has the same dimensions as the original binary map. My attempt is shown below.

I tried to do this as fast as possible using no loops, only vectorization, but my function still isn't as fast as the scipy implementation. When I timed the code, the assignment to the variable "a" takes the longest, but I do not know whether there is a way to speed it up.

If anyone has suggestions for different distance-transform algorithms, or can point me to other Python implementations, it would be much appreciated.

import numpy as np
import matplotlib.pyplot as plt

def get_dst_transform_img(og): # og is a numpy array of the original binary image
    ones_loc = np.where(og == 1)
    ones = np.asarray(ones_loc).T   # coords of all ones in og
    zeros_loc = np.where(og == 0)
    zeros = np.asarray(zeros_loc).T # coords of all zeros in og

    # squared distances via (a - b)^2 = a^2 - 2ab + b^2, broadcast over all pairs
    a = -2 * np.dot(zeros, ones.T)
    b = np.sum(np.square(ones), axis=1)
    c = np.sum(np.square(zeros), axis=1)[:, np.newaxis]
    dists = a + b + c
    dists = np.sqrt(dists.min(axis=1)) # min dist of each zero pixel to a one pixel

    x = og.shape[0]
    y = og.shape[1]
    dist_transform = np.zeros((x, y))
    dist_transform[zeros[:, 0], zeros[:, 1]] = dists

    plt.figure()
    plt.imshow(dist_transform)
    return dist_transform
asked Dec 08 '18 by user4500293


1 Answer

The implementation in the OP is a brute-force approach to the distance transform. This algorithm is O(n²), as it computes the distance from each background pixel to each foreground pixel. Furthermore, because of the way it is vectorized, it requires a lot of memory. On my computer it couldn't compute the distance transform of a 256x256 image without thrashing. Many other algorithms are described in the literature; below I'll discuss two O(n) algorithms.

Note: Typically, the distance transform is computed for object pixels (value 1) to the nearest background pixel (value 0). The code in the OP does the reverse, and so the code I've pasted below follows OP's convention, not the more common convention.
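
As a side note, if you want to sanity-check results outside of TensorFlow: scipy's distance_transform_edt computes, at each nonzero pixel, the distance to the nearest zero pixel, so inverting the input gives OP's convention. A minimal sketch, assuming og is the binary numpy array from the question:

from scipy.ndimage import distance_transform_edt

# distance_transform_edt measures, at each nonzero pixel, the distance to the
# nearest zero pixel; inverting og therefore gives distances at the zero
# pixels to the nearest one pixel, i.e. the question's convention.
reference = distance_transform_edt(og == 0)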


The easiest to implement, IMO, is the chamfer distance algorithm. This is a recursive algorithm that does two passes over the image: one left to right and top to bottom, and one right to left and bottom to top. In each pass, the distance computed for previous pixels is propagated. This algorithm can be implemented using integer distances or floating-point distances between neighbors. The latter yields smaller errors, of course, but in both cases the errors can be reduced significantly by increasing the number of neighbors queried in this propagation. The algorithm itself is old, but G. Borgefors analyzed it and proposed suitable neighbor distances (G. Borgefors, Distance Transformations in Digital Images, Computer Vision, Graphics, and Image Processing 34:344-371, 1986).

Here is an implementation using 3-4 distance (distance to edge-connected neighbors is 3, distance to vertex-connected neighbors is 4):

def chamfer_distance(img):
   w, h = img.shape
   dt = np.zeros((w,h), np.uint32)
   # Forward pass
   x = 0
   y = 0
   if img[x,y] == 0:
      dt[x,y] = 65535 # some large value
   for x in range(1, w):
      if img[x,y] == 0:
         dt[x,y] = 3 + dt[x-1,y]
   for y in range(1, h):
      x = 0
      if img[x,y] == 0:
         dt[x,y] = min(3 + dt[x,y-1], 4 + dt[x+1,y-1])
      for x in range(1, w-1):
         if img[x,y] == 0:
            dt[x,y] = min(4 + dt[x-1,y-1], 3 + dt[x,y-1], 4 + dt[x+1,y-1], 3 + dt[x-1,y])
      x = w-1
      if img[x,y] == 0:
         dt[x,y] = min(4 + dt[x-1,y-1], 3 + dt[x,y-1], 3 + dt[x-1,y])
   # Backward pass
   for x in range(w-2, -1, -1):
      y = h-1
      if img[x,y] == 0:
         dt[x,y] = min(dt[x,y], 3 + dt[x+1,y])
   for y in range(h-2, -1, -1):
      x = w-1
      if img[x,y] == 0:
         dt[x,y] = min(dt[x,y], 3 + dt[x,y+1], 4 + dt[x-1,y+1])
      for x in range(1, w-1):
         if img[x,y] == 0:
            dt[x,y] = min(dt[x,y], 4 + dt[x+1,y+1], 3 + dt[x,y+1], 4 + dt[x-1,y+1], 3 + dt[x+1,y])
      x = 0
      if img[x,y] == 0:
         dt[x,y] = min(dt[x,y], 4 + dt[x+1,y+1], 3 + dt[x,y+1], 3 + dt[x+1,y])
   return dt

Note that a lot of the complication here is to avoid indexing out of bounds, but still computing distances all the way to the edges of the image. If we simply skip the pixels around the border of the image, the code becomes much simpler.
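
For illustration, here is a minimal sketch of that simplified variant, which updates interior pixels only (the function name chamfer_distance_interior is just for this sketch; distances at and near the image border are not computed correctly):

def chamfer_distance_interior(img):
   # Same 3-4 chamfer weights as above, but only interior pixels are updated,
   # so no bounds checking is needed. Border pixels keep their initial value,
   # so distances near the border are only approximate.
   w, h = img.shape
   large = 65535 # some large value
   dt = np.where(img == 0, large, 0).astype(np.uint32)
   # Forward pass: top to bottom, left to right, causal neighbors only
   for y in range(1, h-1):
      for x in range(1, w-1):
         if img[x,y] == 0:
            dt[x,y] = min(dt[x,y], 4 + dt[x-1,y-1], 3 + dt[x,y-1], 4 + dt[x+1,y-1], 3 + dt[x-1,y])
   # Backward pass: bottom to top, right to left, anticausal neighbors only
   for y in range(h-2, 0, -1):
      for x in range(w-2, 0, -1):
         if img[x,y] == 0:
            dt[x,y] = min(dt[x,y], 4 + dt[x+1,y+1], 3 + dt[x,y+1], 4 + dt[x-1,y+1], 3 + dt[x+1,y])
   return dt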

Because it is a recursive algorithm, it is not possible to vectorize its implementation, so the Python code will not be very efficient. But implemented in C or the like, it is very fast and yields a fairly good approximation to the Euclidean distance.

OpenCV's cv.distanceTransform implements this algorithm.
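
A minimal usage sketch, assuming og is the binary image from the question; OpenCV measures distances at nonzero pixels to the nearest zero pixel, so the input is inverted here to match OP's convention:

import cv2
import numpy as np

# cv.distanceTransform computes, at each nonzero pixel, the distance to the
# nearest zero pixel, so og is inverted to obtain distances at the original
# zero pixels. maskSize=3 selects a 3x3 chamfer-style mask approximating the
# Euclidean distance.
inverted = (og == 0).astype(np.uint8)
dt = cv2.distanceTransform(inverted, cv2.DIST_L2, 3)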


Another very efficient algorithm computes the square of the distance transform. The square distance is separable (i.e. can be computed independently for each axis and added). This leads to an algorithm that is easy to parallelize. For each image row, the algorithm does a forward and a backward pass. For each column in the result, the algorithm then does another forward and backward pass. This process leads to an exact Euclidean distance transform.

This algorithm was first proposed by R. van den Boomgaard in his Ph.D. thesis in 1992. Unfortunately this went unnoticed. The algorithm was then again proposed by A. Meijster, J.B.T.M. Roerdink and W.H. Hesselink (A General Algorithm for Computing Distance Transforms in Linear Time, Mathematical Morphology and its Applications to Image and Signal Processing, pp 331-340, 2002), and again by P. Felzenszwalb and D. Huttenlocher (Distance transforms of sampled functions, Technical report, Cornell University, 2004).

This is the most efficient algorithm known, in part because it is the only one that can be easily and efficiently parallelized (computation on each image row, and later on each image column, is independent of other rows/columns).

Unfortunately I don't have a polished Python implementation of this one to share, but you can find implementations online; a rough sketch of the row/column scheme is given below. For example, OpenCV's cv.distanceTransform implements this algorithm, and so does DIPlib's dip.EuclideanDistanceTransform.
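
Here is a minimal sketch of the rows-then-columns scheme, based on the 1D lower-envelope pass described in the Felzenszwalb-Huttenlocher paper. The function names dt_1d and euclidean_distance_transform are just for this sketch, it follows the question's convention, and the pure-Python loops mean it will not be fast without further vectorization or compilation:

import numpy as np

INF = 1e20  # stands in for infinity, as in the Felzenszwalb-Huttenlocher paper

def dt_1d(f):
    # 1D squared distance transform of a sampled function f:
    # returns d with d[p] = min over q of ((p - q)^2 + f[q]),
    # computed via the lower envelope of parabolas.
    n = len(f)
    v = np.zeros(n, dtype=int) # locations of parabolas in the lower envelope
    z = np.zeros(n + 1)        # boundaries between parabolas
    k = 0
    v[0] = 0
    z[0] = -INF
    z[1] = INF
    for q in range(1, n):
        s = ((f[q] + q*q) - (f[v[k]] + v[k]*v[k])) / (2*q - 2*v[k])
        while s <= z[k]:
            k -= 1
            s = ((f[q] + q*q) - (f[v[k]] + v[k]*v[k])) / (2*q - 2*v[k])
        k += 1
        v[k] = q
        z[k] = s
        z[k+1] = INF
    d = np.zeros(n)
    k = 0
    for q in range(n):
        while z[k+1] < q:
            k += 1
        d[q] = (q - v[k])**2 + f[v[k]]
    return d

def euclidean_distance_transform(img):
    # Exact Euclidean distance transform in the question's convention:
    # at each 0 pixel, the distance to the nearest 1 pixel.
    f = np.where(img == 1, 0.0, INF)
    rows = np.apply_along_axis(dt_1d, 1, f)  # 1D pass along each row
    sq = np.apply_along_axis(dt_1d, 0, rows) # 1D pass along each column
    return np.sqrt(sq)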

answered Oct 12 '22 by Cris Luengo