 

Closest Pair Implementation Python

I am trying to implement the closest pair problem in Python using divide and conquer. Everything seems to work fine, except that for some inputs I get a wrong answer. My code is as follows:

def closestSplitPair(Px,Py,d):
    X = Px[len(Px)-1][0]
    Sy = [item for item in Py if item[0]>=X-d and item[0]<=X+d]
    best,p3,q3 = d,None,None
    for i in xrange(0,len(Sy)-2):
        for j in xrange(1,min(7,len(Sy)-1-i)):
            if dist(Sy[i],Sy[i+j]) < best:
                best = (Sy[i],Sy[i+j])
                p3,q3 = Sy[i],Sy[i+j]
    return (p3,q3,best)

I am calling the above function through a recursive function which is as follows:

def closestPair(Px,Py):
    """Px and Py are input arrays sorted according to
    their x and y coordinates respectively"""
    if len(Px) <= 3:
        return min_dist(Px)
    else:
        mid = len(Px)/2
        Qx = Px[:mid] ### x-sorted left side of P
        Qy = Py[:mid] ### y-sorted left side of P
        Rx = Px[mid:] ### x-sorted right side of P
        Ry = Py[mid:] ### y-sorted right side of P
        (p1,q1,d1) = closestPair(Qx,Qy)
        (p2,q2,d2) = closestPair(Rx,Ry)
        d = min(d1,d2)
        (p3,q3,d3) = closestSplitPair(Px,Py,d)
        return min((p1,q1,d1),(p2,q2,d2),(p3,q3,d3),key=lambda tup: tup[2])

where min_dist(P) is the brute-force implementation of the closest pair algorithm for a list P with 3 or fewer elements; it returns a tuple containing the closest pair of points and their distance.
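The question does not show `dist` or `min_dist`. Assuming points are `(x, y)` tuples and that `min_dist` returns `(p, q, distance)` as described above, a minimal Python 3 sketch of both helpers might look like this (the implementations are hypothetical, not the asker's code):

```python
from math import hypot


def dist(p, q):
    """Euclidean distance between two 2-D points given as (x, y) tuples."""
    return hypot(p[0] - q[0], p[1] - q[1])


def min_dist(P):
    """Brute-force closest pair for a small list P (2 or 3 points).

    Returns (p, q, distance), the tuple shape closestPair unpacks.
    Returns None if P has fewer than 2 points.
    """
    best = None
    for i in range(len(P)):
        for j in range(i + 1, len(P)):
            d = dist(P[i], P[j])
            if best is None or d < best[2]:
                best = (P[i], P[j], d)
    return best


print(min_dist([(0, 0), (3, 4), (10, 10)]))  # ((0, 0), (3, 4), 5.0)
```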

If my input is P = [(0,0),(7,6),(2,20),(12,5),(16,16),(5,8),(19,7),(14,22),(8,19),(7,29),(10,11),(1,13)], my output is ((5,8),(7,6),2.8284271), which is correct. But when my input is P = [(94, 5), (96, -79), (20, 73), (8, -50), (78, 2), (100, 63), (-14, -69), (99, -8), (-11, -7), (-78, -46)], the output I get is ((78, 2), (94, 5), 16.278820596099706), whereas the correct output should be ((94, 5), (99, -8), 13.92838827718412).
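A quick way to confirm the expected answer for the second input is to check all n(n-1)/2 pairs directly. This is a Python 3 sketch using `itertools.combinations`; the local `dist` helper is an assumption, since the question does not show its own:

```python
from itertools import combinations
from math import hypot

P = [(94, 5), (96, -79), (20, 73), (8, -50), (78, 2),
     (100, 63), (-14, -69), (99, -8), (-11, -7), (-78, -46)]


def dist(p, q):
    """Euclidean distance between two (x, y) tuples."""
    return hypot(p[0] - q[0], p[1] - q[1])


# Check every pair: O(n^2), but fine as a reference answer for small inputs.
p, q = min(combinations(P, 2), key=lambda pair: dist(*pair))
print(p, q, dist(p, q))  # (94, 5) (99, -8) and a distance of about 13.928
```

A brute-force reference like this is handy for testing the divide-and-conquer version against random inputs.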

maverick93 asked Jan 30 '15 at 14:01





2 Answers

You have two problems. First, you are forgetting to call dist when you update the best distance. But the main problem is that there is more than one recursive call happening, so when you find a closer split pair you can end up overwriting it with the default, best,p3,q3 = d,None,None. I passed the best pair from closest_pair as an argument to closest_split_pair so the value would not be overwritten.

def closest_split_pair(p_x, p_y, delta, best_pair): # <- a parameter
    ln_x = len(p_x)
    mx_x = p_x[ln_x // 2][0]  # x-coordinate of the dividing line
    # points within delta of the dividing line, kept in y-order
    s_y = [x for x in p_y if mx_x - delta <= x[0] <= mx_x + delta]
    best = delta
    for i in range(len(s_y) - 1):
        # only the next 7 points in y-order can be closer than delta
        for j in range(1, min(8, len(s_y) - i)):
            p, q = s_y[i], s_y[i + j]
            dst = dist(p, q)
            if dst < best:
                best_pair = p, q
                best = dst
    return best_pair

The end of closest_pair looks like the following:

    p_1, q_1 = closest_pair(srt_q_x, srt_q_y)
    p_2, q_2 = closest_pair(srt_r_x, srt_r_y)
    closest = min(dist(p_1, q_1), dist(p_2, q_2))
    # get min of both and then pass that as an arg to closest_split_pair
    mn = min((p_1, q_1), (p_2, q_2), key=lambda x: dist(x[0], x[1]))
    p_3, q_3 = closest_split_pair(p_x, p_y, closest,mn)
    # either return mn or we have a closer split pair
    return min(mn, (p_3, q_3), key=lambda x: dist(x[0], x[1]))

You also have some other logic issues: your slicing logic is not correct. I made some changes to your code, where brute is just a simple brute-force double loop:

def closestPair(Px, Py):
    if len(Px) <= 3:
        return brute(Px)

    mid = len(Px) // 2
    # get left and right half of Px
    q, r = Px[:mid], Px[mid:]
    # split Py into the same two halves while keeping its y-order
    # (membership test via a set assumes the points are distinct)
    left = set(q)
    Qx, Qy = q, [x for x in Py if x in left]
    Rx, Ry = r, [x for x in Py if x not in left]
    (p1, q1) = closestPair(Qx, Qy)
    (p2, q2) = closestPair(Rx, Ry)
    d = min(dist(p1, q1), dist(p2, q2))
    mn = min((p1, q1), (p2, q2), key=lambda x: dist(x[0], x[1]))
    (p3, q3) = closest_split_pair(Px, Py, d, mn)
    return min(mn, (p3, q3), key=lambda x: dist(x[0], x[1]))

I only worked through the algorithm today, so there are no doubt some improvements to be made, but this will get you the correct answer.
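For reference, here is a self-contained Python 3 sketch that puts the pieces of this answer together so it can be run end to end. The `dist`, `brute` and `solve` helpers are assumptions (the answer does not show them), and splitting `p_y` with a set assumes all input points are distinct:

```python
from math import hypot


def dist(p, q):
    return hypot(p[0] - q[0], p[1] - q[1])


def brute(P):
    # simple brute-force double loop; returns the closest pair (p, q)
    return min(((P[i], P[j]) for i in range(len(P)) for j in range(i + 1, len(P))),
               key=lambda pair: dist(*pair))


def closest_split_pair(p_x, p_y, delta, best_pair):
    mx_x = p_x[len(p_x) // 2][0]  # x-coordinate of the dividing line
    s_y = [p for p in p_y if mx_x - delta <= p[0] <= mx_x + delta]
    best = delta
    for i in range(len(s_y) - 1):
        # only the next 7 points in y-order can be closer than delta
        for j in range(1, min(8, len(s_y) - i)):
            p, q = s_y[i], s_y[i + j]
            dst = dist(p, q)
            if dst < best:
                best_pair, best = (p, q), dst
    return best_pair


def closest_pair(p_x, p_y):
    if len(p_x) <= 3:
        return brute(p_x)
    mid = len(p_x) // 2
    q_x, r_x = p_x[:mid], p_x[mid:]
    left = set(q_x)                          # assumes distinct points
    q_y = [p for p in p_y if p in left]      # y-order is preserved
    r_y = [p for p in p_y if p not in left]
    p1, q1 = closest_pair(q_x, q_y)
    p2, q2 = closest_pair(r_x, r_y)
    mn = min((p1, q1), (p2, q2), key=lambda pair: dist(*pair))
    p3, q3 = closest_split_pair(p_x, p_y, dist(*mn), mn)
    return min(mn, (p3, q3), key=lambda pair: dist(*pair))


def solve(points):
    # sort once by x and once by y; the recursion only filters these lists
    p_x = sorted(points)
    p_y = sorted(points, key=lambda p: (p[1], p[0]))
    p, q = closest_pair(p_x, p_y)
    return p, q, dist(p, q)


P = [(94, 5), (96, -79), (20, 73), (8, -50), (78, 2),
     (100, 63), (-14, -69), (99, -8), (-11, -7), (-78, -46)]
print(solve(P))  # the pair (94, 5) / (99, -8), distance about 13.928
```

Because the input is sorted only once and each recursion level does linear filtering, this keeps the O(n log n) running time of the divide-and-conquer approach.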

Padraic Cunningham answered Oct 04 '22 at 06:10



Here is a recursive divide-and-conquer Python implementation of the closest pair problem based on the heap data structure. It also handles negative coordinates. It can return the k closest pairs by popping k entries from the heap with heappop().

from __future__ import division
from collections import namedtuple
from random import randint
import math as m
import heapq as hq

def get_key(item):
    return(item[0])


def closest_point_problem(points):
    point = []
    heap = []
    pt = namedtuple('pt', 'x y')
    for i in range(len(points)):
        point.append(pt(points[i][0], points[i][1]))
    point = sorted(point, key=get_key)
    visited_index = []
    find_min(0, len(point) - 1, point, heap, visited_index)
    print(hq.heappop(heap))

def find_min(start, end, point, heap, visited_index):
    if len(point[start:end + 1]) & 1:
        mid = start + ((len(point[start:end + 1]) + 1) >> 1)
    else:
        mid = start + (len(point[start:end + 1]) >> 1)
        if start in visited_index:
            start = start + 1
        if end in visited_index:
            end = end - 1
    if len(point[start:end + 1]) > 3:
        if start < mid - 1:
            distance1 = m.sqrt((point[start].x - point[start + 1].x) ** 2 + (point[start].y - point[start + 1].y) ** 2)
            distance2 = m.sqrt((point[mid].x - point[mid - 1].x) ** 2 + (point[mid].y - point[mid - 1].y) ** 2)
            if distance1 < distance2:
                hq.heappush(heap, (distance1, ((point[start].x, point[start].y), (point[start + 1].x, point[start + 1].y))))
            else:
                hq.heappush(heap, (distance2, ((point[mid].x, point[mid].y), (point[mid - 1].x, point[mid - 1].y))))
            visited_index.append(start)
            visited_index.append(start + 1)
            visited_index.append(mid)
            visited_index.append(mid - 1)
            find_min(start, mid, point, heap, visited_index)
        if mid + 1 < end:
            distance1 = m.sqrt((point[mid].x - point[mid + 1].x) ** 2 + (point[mid].y - point[mid + 1].y) ** 2)
            distance2 = m.sqrt((point[end].x - point[end - 1].x) ** 2 + (point[end].y - point[end - 1].y) ** 2)
            if distance1 < distance2:
                hq.heappush(heap, (distance1, ((point[mid].x, point[mid].y), (point[mid + 1].x, point[mid + 1].y))))
            else:
                hq.heappush(heap, (distance2, ((point[end].x, point[end].y), (point[end - 1].x, point[end - 1].y))))
            visited_index.append(end)
            visited_index.append(end - 1)
            visited_index.append(mid)
            visited_index.append(mid + 1)
            find_min(mid, end, point, heap, visited_index)

x = []
num_points = 10
for i in range(num_points):
    x.append((randint(- num_points << 2, num_points << 2), randint(- num_points << 2, num_points << 2)))
closest_point_problem(x)

:)
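The heap idea in this answer (pop k entries to get the k closest pairs) can be shown on its own. The sketch below is a plain brute-force illustration, not the divide-and-conquer code above: it pushes every pair's distance onto a `heapq` heap and pops k of them. `k_closest_pairs` is a hypothetical name:

```python
import heapq
from itertools import combinations
from math import hypot


def k_closest_pairs(points, k):
    """Return the k closest pairs as (distance, (p, q)) tuples, nearest first.

    Brute force: every pair is pushed onto a min-heap keyed by distance,
    then k entries are popped (O(n^2 log n) overall).
    """
    heap = []
    for p, q in combinations(points, 2):
        heapq.heappush(heap, (hypot(p[0] - q[0], p[1] - q[1]), (p, q)))
    return [heapq.heappop(heap) for _ in range(min(k, len(heap)))]


pts = [(0, 0), (3, 4), (10, 10), (-1, 0)]
for d, pair in k_closest_pairs(pts, 2):
    print(d, pair)
# 1.0 ((0, 0), (-1, 0))
# 5.0 ((0, 0), (3, 4))
```

Because the heap holds every pair, popping k entries yields the k globally closest pairs; a heap that only receives some candidate pairs can only rank those candidates.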

Nikki Mino answered Oct 04 '22 at 06:10
