 

Geodesic distance transform in Python

In Python there is the distance_transform_edt function in the scipy.ndimage.morphology module. I applied it to a simple case: computing the distance from a single cell in a masked NumPy array.

However, the function drops the mask of the array and computes, as expected, the Euclidean distance from each cell with a non-zero value to the reference cell, which holds the zero value.
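To make that concrete, here is a minimal sketch on a made-up 5x5 grid (the seed position and the `wall` array are hypothetical): distance_transform_edt returns, for every non-zero cell, the straight-line distance to the nearest zero cell, so an obstacle has no way of affecting the result.

```python
import numpy as np
from scipy.ndimage import distance_transform_edt

# Hypothetical 5x5 grid: the single zero cell at (2, 0) is the reference cell.
grid = np.ones((5, 5))
grid[2, 0] = 0

# Hypothetical obstacle in column 2 (rows 0-3). distance_transform_edt has no
# parameter for it, so it cannot influence the result.
wall = np.zeros((5, 5), dtype=bool)
wall[0:4, 2] = True

edt = distance_transform_edt(grid)
print(edt[2, 4])  # 4.0: the straight line from (2, 0) runs right through the wall
```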

Below is an example I gave in my blog post:

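import numpy as np
from numpy import ma, ones_like
from matplotlib.pyplot import imshow, colorbar  # replaces the %pylab namespace
from scipy.ndimage import distance_transform_edt  # scipy.ndimage.morphology is deprecated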
l = 100
x, y = np.indices((l, l))
center1 = (50, 20)
center2 = (28, 24)
center3 = (30, 50)
center4 = (60,48)
radius1, radius2, radius3, radius4 = 15, 12, 19, 12
circle1 = (x - center1[0])**2 + (y - center1[1])**2 < radius1**2
circle2 = (x - center2[0])**2 + (y - center2[1])**2 < radius2**2
circle3 = (x - center3[0])**2 + (y - center3[1])**2 < radius3**2
circle4 = (x - center4[0])**2 + (y - center4[1])**2 < radius4**2
# 4 circles
img = circle1 + circle2 + circle3 + circle4
mask = ~img.astype(bool)
img = img.astype(float)
m = ones_like(img)
m[center1] = 0
#imshow(distance_transform_edt(m), interpolation='nearest')
m = ma.masked_array(distance_transform_edt(m), mask)
imshow(m, interpolation='nearest')

[Figure: Euclidean distance transform]

However, I want to compute the geodesic distance transform, which takes the masked elements of the array into account: I do not want the Euclidean distance along a straight line that goes through masked elements.
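On a toy case the difference is easy to see. The sketch below is an illustration under assumptions, not the question's method: it treats every passable pixel as a graph node, joins 8-connected neighbours with edges of length 1 or sqrt(2) (the same step lengths as the Dijkstra code that follows), and hands the graph to SciPy's compiled scipy.sparse.csgraph.dijkstra. The helper name geodesic_dt_csgraph and the wall example are hypothetical.

```python
import numpy as np
from scipy.sparse import coo_matrix
from scipy.sparse.csgraph import dijkstra


def geodesic_dt_csgraph(passable, seed):
    """Geodesic distance from `seed` over the True cells of `passable`.

    Pixels are graph nodes; 8-connected neighbours are joined by edges of
    length 1 (axial) or sqrt(2) (diagonal). Unreachable cells come back as inf.
    """
    ny, nx = passable.shape
    idx = np.arange(ny * nx).reshape(ny, nx)  # flat node id of every pixel
    rows, cols, weights = [], [], []
    # Four "forward" offsets; each edge is also added in the reverse direction.
    for dy, dx in [(0, 1), (1, 0), (1, 1), (1, -1)]:
        src = (slice(0, ny - dy), slice(max(0, -dx), nx - max(0, dx)))
        dst = (slice(dy, ny), slice(max(0, dx), nx - max(0, -dx)))
        ok = passable[src] & passable[dst]  # both end pixels must be passable
        a, b = idx[src][ok], idx[dst][ok]
        w = np.full(a.size, np.hypot(dy, dx))
        rows += [a, b]
        cols += [b, a]
        weights += [w, w]
    n = ny * nx
    graph = coo_matrix(
        (np.concatenate(weights), (np.concatenate(rows), np.concatenate(cols))),
        shape=(n, n),
    ).tocsr()
    dist = dijkstra(graph, directed=True, indices=idx[seed]).reshape(ny, nx)
    dist[~passable] = np.inf  # obstacle pixels are never part of a path
    return dist


wall = np.zeros((5, 5), dtype=bool)
wall[0:4, 2] = True  # hypothetical wall: column 2, rows 0-3
d = geodesic_dt_csgraph(~wall, (2, 0))
print(round(d[2, 4], 3))  # 5.657, i.e. 4*sqrt(2): the path must go around the wall
```

The straight-line value for the same cell would be 4.0; the geodesic value is larger because the only way past the wall is through its open bottom cell.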

I used Dijkstra's algorithm to obtain the result I want. Below is the implementation I proposed:

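import numpy
from numpy import ma, sqrt, asarray, unravel_index

def geodesic_distance_transform(m):
    # Assumes the unmasked region does not touch the array border:
    # neighbour indices are not bounds-checked.
    mask = m.mask
    visit_mask = mask.copy()  # masked cells are treated as already visited
    m = m.filled(numpy.inf)
    m[m != 0] = numpy.inf  # the seed cell (value 0) starts at 0, all others at infinity
    # step lengths of the 8 neighbours, in the same order as `connectivity`
    distance_increments = numpy.asarray([sqrt(2), 1., sqrt(2), 1., 1., sqrt(2), 1., sqrt(2)])
    connectivity = [(i, j) for i in [-1, 0, 1] for j in [-1, 0, 1] if (not (i == j == 0))]
    cc = unravel_index(m.argmin(), m.shape)  # current cell
    while (~visit_mask).sum() > 0:
        # unvisited neighbours of the current cell and their step lengths
        neighbors = [tuple(e) for e in asarray(cc) - connectivity
                     if not visit_mask[tuple(e)]]
        tentative_distance = [distance_increments[i] for i, e in enumerate(asarray(cc) - connectivity)
                              if not visit_mask[tuple(e)]]
        # relax each neighbour's distance through the current cell
        for i, e in enumerate(neighbors):
            d = tentative_distance[i] + m[cc]
            if d < m[e]:
                m[e] = d
        visit_mask[cc] = True
        # next current cell: the unvisited cell with the smallest tentative distance
        m_mask = ma.masked_array(m, visit_mask)
        cc = unravel_index(m_mask.argmin(), m.shape)
    return m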

gdt = geodesic_distance_transform(m)
imshow(gdt, interpolation='nearest')
colorbar()


The function above works well, but it is too slow for the application I developed, which needs to compute the geodesic distance transform several times.
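Part of the cost is structural: every pass of the while loop calls sum() and argmin() over the whole array, so the loop does O(N) work per visited cell and roughly O(N²) in total. A common alternative (swapped in here, not the question's code) is to keep the tentative distances in a binary heap with Python's heapq, which makes each step logarithmic. A minimal sketch, assuming a boolean `passable` array and a (row, col) seed; the name geodesic_dt_heap is hypothetical:

```python
import heapq

import numpy as np

# The 8 neighbour offsets and their step lengths (1 or sqrt(2)),
# i.e. the same metric as the question's distance_increments.
OFFSETS = [(dy, dx) for dy in (-1, 0, 1) for dx in (-1, 0, 1) if (dy, dx) != (0, 0)]
STEPS = [float(np.hypot(dy, dx)) for dy, dx in OFFSETS]


def geodesic_dt_heap(passable, seed):
    """8-neighbour Dijkstra from `seed` over the True cells of `passable`.

    The next cell comes from a binary heap instead of an argmin over the
    whole array. Cells that cannot be reached simply stay at inf.
    """
    ny, nx = passable.shape
    dist = np.full((ny, nx), np.inf)
    dist[seed] = 0.0
    heap = [(0.0, tuple(seed))]
    while heap:
        d, (y, x) = heapq.heappop(heap)
        if d > dist[y, x]:
            continue  # stale entry: this cell was already reached more cheaply
        for (dy, dx), step in zip(OFFSETS, STEPS):
            v, u = y + dy, x + dx
            # explicit bounds check: no wrap-around or IndexError at the border
            if 0 <= v < ny and 0 <= u < nx and passable[v, u]:
                nd = d + step
                if nd < dist[v, u]:
                    dist[v, u] = nd
                    heapq.heappush(heap, (nd, (v, u)))
    return dist


wall = np.zeros((5, 5), dtype=bool)
wall[0:4, 2] = True  # hypothetical wall: column 2, rows 0-3
d = geodesic_dt_heap(~wall, (2, 0))
print(round(d[2, 4], 3))  # 5.657 (4*sqrt(2)): the path goes around the wall
```

With the question's arrays, a call such as geodesic_dt_heap(~m.mask, center1) should reproduce the same 8-neighbour distances on the unmasked cells, up to floating-point rounding.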

Below are the timings of the Euclidean distance transform and of the geodesic distance transform:

%timeit distance_transform_edt(m)
1000 loops, best of 3: 1.07 ms per loop

%timeit geodesic_distance_transform(m)
1 loops, best of 3: 702 ms per loop

How can I obtain a faster geodesic distance transform?

asked Jan 28 '15 by bougui




3 Answers

First of all, thumbs up for a very clear and well-written question.

There is a very good and fast implementation of a Fast Marching method called scikit-fmm to solve this kind of problem. You can find the details here: http://pythonhosted.org//scikit-fmm/
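For context, fast marching treats the distance T as the solution of the Eikonal equation |∇T| = 1 with T = 0 at the seed. On a grid, each cell is updated from the smallest known neighbour along each axis by solving a small quadratic, instead of adding fixed steps of 1 or sqrt(2). A sketch of the standard first-order 2-D update (my own illustration of the textbook scheme, not scikit-fmm's source; the function name is hypothetical):

```python
import math


def eikonal_update(a, b, h=1.0):
    """First-order upwind solution of (T - a)**2 + (T - b)**2 = h**2.

    a, b: smallest known neighbour value along each axis (inf if none).
    """
    if math.isinf(a) and math.isinf(b):
        return math.inf  # no known neighbour yet
    if abs(a - b) >= h:
        return min(a, b) + h  # information arrives along one axis only
    # both axes contribute: take the larger root of the quadratic
    return 0.5 * (a + b + math.sqrt(2 * h * h - (a - b) ** 2))


print(eikonal_update(0.0, math.inf))  # 1.0
print(eikonal_update(0.0, 0.0))       # about 0.7071, i.e. sqrt(2)/2
```

Because the update blends both axes, the front spreads in a roughly circular way rather than along eight fixed directions.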

Installing it might be the hardest part, but on Windows with Conda it's easy, since there is a 64-bit Conda package for Py27: https://binstar.org/jmargeta/scikit-fmm

From there on, just pass your masked array to it, as you do with your own function:

import skfmm

distance = skfmm.distance(m)

The result looks similar, and I think it is even slightly better. Your approach searches (apparently) in eight distinct directions, resulting in a somewhat octagon-shaped distance.
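The octagonal shape can be quantified. On an open grid, the shortest path built from steps of 1 and sqrt(2) between offsets (Δy, Δx) has length d₈ = sqrt(2)·min(|Δy|, |Δx|) + ||Δy| − |Δx||. This equals the Euclidean distance along the axes and diagonals and overestimates it in between, by a factor of up to sqrt(4 − 2·sqrt(2)) ≈ 1.082 near 22.5°. A small check (the helper name chamfer_8 is hypothetical):

```python
import numpy as np


def chamfer_8(dy, dx):
    """Length of the shortest 8-connected path (steps 1 and sqrt(2)) on an open grid:
    min(|dy|, |dx|) diagonal steps, then the remaining axial steps."""
    a, b = abs(dy), abs(dx)
    return np.sqrt(2) * min(a, b) + abs(a - b)


# Worked example: offset (1, 2) from the seed
print(chamfer_8(1, 2), np.hypot(1, 2))  # about 2.414 vs 2.236

# Overestimate relative to the Euclidean distance, sampled over directions
# between 0 and 45 degrees
ratios = [chamfer_8(k, 1000) / np.hypot(k, 1000) for k in range(1001)]
print(max(ratios))  # about 1.0824, near 22.5 degrees
```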


On my machine the scikit-fmm implementation is over 200x faster than your function.


answered Sep 23 '22 by Rutger Kassies


64-bit Windows binaries for scikit-fmm are now available from Christoph Gohlke.

http://www.lfd.uci.edu/~gohlke/pythonlibs/#scikit-fmm

answered Sep 24 '22 by Jason Furtney


A faster (about 10x) implementation that achieves the same result as your geodesic_distance_transform:

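import numpy
from functools import reduce  # reduce is not a builtin in Python 3
from scipy.ndimage import distance_transform_edt

def getMissingMask(slab):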

    nan_mask=numpy.where(numpy.isnan(slab),1,0)
    if not hasattr(slab,'mask'):
        mask_mask=numpy.zeros(slab.shape)
    else:
        if slab.mask.size==1 and slab.mask==False:
            mask_mask=numpy.zeros(slab.shape)
        else:
            mask_mask=numpy.where(slab.mask,1,0)
    mask=numpy.where(mask_mask+nan_mask>0,1,0)

    return mask

def geodesic(img,seed):

    seedy,seedx=seed
    mask=getMissingMask(img)

    #----Call distance_transform_edt if no missing----
    if mask.sum()==0:
        slab=numpy.ones(img.shape)
        slab[seedy,seedx]=0
        return distance_transform_edt(slab)

    target=(1-mask).sum()
    dist=numpy.ones(img.shape)*numpy.inf
    dist[seedy,seedx]=0

    def expandDir(img,direction):
        # Shift the array by one cell, then copy the original edge row/column
        # back into the vacated edge so numpy.roll does not wrap values around.
        if direction=='n':
            l1=img[0,:]
            img=numpy.roll(img,1,axis=0)
            img[0,:]=l1
        elif direction=='s':
            l1=img[-1,:]
            img=numpy.roll(img,-1,axis=0)
            img[-1,:]=l1
        elif direction=='e':
            l1=img[:,0]
            img=numpy.roll(img,1,axis=1)
            img[:,0]=l1
        elif direction=='w':
            l1=img[:,-1]
            img=numpy.roll(img,-1,axis=1)
            img[:,-1]=l1
        elif direction=='ne':
            img=expandDir(img,'n')
            img=expandDir(img,'e')
        elif direction=='nw':
            img=expandDir(img,'n')
            img=expandDir(img,'w')
        elif direction=='sw':
            img=expandDir(img,'s')
            img=expandDir(img,'w')
        elif direction=='se':
            img=expandDir(img,'s')
            img=expandDir(img,'e')

        return img

    def expandIter(img):
        sqrt2=numpy.sqrt(2)
        tmps=[]
        for dirii,dd in zip(['n','s','e','w','ne','nw','sw','se'],\
                [1,]*4+[sqrt2,]*4):
            tmpii=expandDir(img,dirii)+dd
            tmpii=numpy.minimum(tmpii,img)
            tmps.append(tmpii)
        img=reduce(lambda x,y:numpy.minimum(x,y),tmps)

        return img

    #----------------Iteratively expand----------------
    dist_old=dist
    while True:
        expand=expandIter(dist)
        dist=numpy.where(mask,dist,expand)
        nc=dist.size-len(numpy.where(dist==numpy.inf)[0])

        if nc>=target or numpy.all(dist_old==dist):
            break
        dist_old=dist

    return dist
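The roll-and-minimum loop above is a vectorized relaxation: every cell is updated at once from its 8 shifted neighbours until nothing changes, Bellman-Ford style. A compact sketch of the same idea, assuming a boolean `passable` array and a (row, col) seed (the names are hypothetical); it pads with inf via numpy.pad instead of using numpy.roll, so no separate edge fix is needed:

```python
import numpy as np

SQRT2 = float(np.sqrt(2.0))
# (dy, dx, step length) for the 8 neighbours
NEIGHBOURS = [(dy, dx, SQRT2 if dy and dx else 1.0)
              for dy in (-1, 0, 1) for dx in (-1, 0, 1) if dy or dx]


def geodesic_relax(passable, seed, max_iter=None):
    """Vectorised relaxation: repeat dist = min(dist, neighbour + step) until stable.

    passable: 2-D bool array; seed: (row, col). Padding with inf means nothing
    wraps around the border.
    """
    ny, nx = passable.shape
    dist = np.full((ny, nx), np.inf)
    dist[seed] = 0.0
    if max_iter is None:
        max_iter = ny * nx  # a shortest path never has more steps than cells
    for _ in range(max_iter):
        padded = np.pad(dist, 1, mode="constant", constant_values=np.inf)
        best = dist.copy()
        for dy, dx, step in NEIGHBOURS:
            # value of the neighbour at offset (dy, dx), for every cell at once
            shifted = padded[1 + dy:1 + dy + ny, 1 + dx:1 + dx + nx]
            np.minimum(best, shifted + step, out=best)
        best[~passable] = np.inf  # obstacle cells never get a finite distance
        if np.array_equal(best, dist):
            break  # no distance improved: converged
        dist = best
    return dist


wall = np.zeros((5, 5), dtype=bool)
wall[0:4, 2] = True  # hypothetical wall: column 2, rows 0-3
d = geodesic_relax(~wall, (2, 0))
print(round(d[2, 4], 3))  # 5.657 (4*sqrt(2)), same as 8-neighbour Dijkstra
```

Each iteration costs a handful of whole-array NumPy operations, and the number of iterations grows with the longest shortest path (in steps), so long winding regions need more passes.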

Also note that if the mask forms more than one connected region (e.g. adding another circle that does not touch the others), your function will fall into an endless loop.
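One way to guard against that (a sketch, not tested against the question's function) is to find the cells 8-connected to the seed with scipy.ndimage.label and mask everything else before running the loop. The helper name reachable_from and the two-region example are hypothetical:

```python
import numpy as np
from scipy.ndimage import label


def reachable_from(passable, seed):
    """Boolean array of the passable cells 8-connected to `seed`."""
    labels, _ = label(passable, structure=np.ones((3, 3), dtype=int))
    return labels == labels[seed]


# Hypothetical mask with two regions that do not touch, not even diagonally
passable = np.zeros((6, 6), dtype=bool)
passable[0:2, 0:2] = True  # region containing the seed
passable[4:6, 4:6] = True  # isolated region
reach = reachable_from(passable, (0, 0))
print(reach.sum(), passable.sum())  # 4 8: half of the passable cells are unreachable
```

In the question's setup, something like m = ma.masked_array(m.data, mask=m.mask | ~reachable_from(~m.mask, center1)) would leave only the seed's region unmasked; the other cells then come back as inf from filled(numpy.inf).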

UPDATE:

I found a Cython implementation of the Fast Sweeping method in this notebook, which can be used to achieve the same result as scikit-fmm, probably at comparable speed. One just needs to feed a binary flag matrix (with 1s as viable points, inf otherwise) as the cost to the GDT() function.
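The notebook's Cython code is not reproduced here. As a rough illustration of what the Fast Sweeping method does, here is a pure-Python sketch: Gauss-Seidel passes over the grid in the four diagonal orderings, each cell updated with the standard first-order Eikonal formula, repeated until a full round changes nothing. The function names and the boolean `passable` input (standing in for the 1/inf cost matrix) are my own assumptions, not the notebook's GDT() signature.

```python
import math

import numpy as np


def _eikonal_update(a, b, h):
    # first-order upwind solution of (T - a)^2 + (T - b)^2 = h^2
    if math.isinf(a) and math.isinf(b):
        return math.inf
    if abs(a - b) >= h:
        return min(a, b) + h
    return 0.5 * (a + b + math.sqrt(2 * h * h - (a - b) ** 2))


def fast_sweep(passable, seed, h=1.0, max_rounds=50):
    """Fast sweeping: Gauss-Seidel passes in the four diagonal orderings.

    passable: 2-D bool array; impassable cells keep T = inf (an infinite cost).
    Pure-Python loops, so this only illustrates the method; it is slow.
    """
    ny, nx = passable.shape
    T = np.full((ny, nx), np.inf)
    T[seed] = 0.0
    orders = [(range(ny), range(nx)),
              (range(ny - 1, -1, -1), range(nx)),
              (range(ny), range(nx - 1, -1, -1)),
              (range(ny - 1, -1, -1), range(nx - 1, -1, -1))]
    for _ in range(max_rounds):
        previous = T.copy()
        for rows, cols in orders:
            for i in rows:
                for j in cols:
                    if not passable[i, j] or (i, j) == tuple(seed):
                        continue
                    # smallest neighbour value along each axis (inf outside the grid)
                    a = min(T[i - 1, j] if i > 0 else math.inf,
                            T[i + 1, j] if i < ny - 1 else math.inf)
                    b = min(T[i, j - 1] if j > 0 else math.inf,
                            T[i, j + 1] if j < nx - 1 else math.inf)
                    T[i, j] = min(T[i, j], _eikonal_update(a, b, h))
        if np.array_equal(previous, T):
            break  # a full round changed nothing: converged
    return T


T = fast_sweep(np.ones((5, 5), dtype=bool), (0, 0))
print(T[0, 4])  # 4.0 along the axis
```

Like any first-order scheme, this overestimates somewhat along diagonals and around narrow gaps; a compiled version and a finer grid are what make the approach practical.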

answered Sep 22 '22 by Jason