# Closest pair of points in Python (divide and conquer): the quick implementation

## Computing the minimum distance between two points on a 2D plane

Given two lists holding the x coordinates and the corresponding y coordinates of n points, find the minimum distance between any pair of these points.

Every battle with a hardcore algorithm has to start somewhere. I suggest reading Cormen et al., “Introduction to Algorithms”, 3rd edition (Section 33.4), but any decent book will do.

We start with a straightforward implementation of the divide-and-conquer approach to the closest pair of points problem.

Suppose our inputs are two lists of size n, x and y, which together describe the points (x1, y1) … (xn, yn), where n is the number of points.

First, let’s look at the following function:

```python
def solution(x, y):
    a = list(zip(x, y))  # This produces a list of (x, y) tuples
    ax = sorted(a, key=lambda p: p[0])  # Presorting x-wise
    ay = sorted(a, key=lambda p: p[1])  # Presorting y-wise
    p1, p2, mi = closest_pair(ax, ay)  # Recursive D&C function
    return mi
```

Here we address the concept of presorting. As noted in the book:

> Note that in order to attain the O(n lg n) time bound, we cannot afford to sort in each recursive call; if we did, the recurrence for the running time would be T(n) = 2T(n/2) + O(n lg n), whose solution is T(n) = O(n lg² n).

Therefore, sorting once, outside the function that is called recursively, keeps the overall running time at O(n lg n).
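To see where the extra logarithmic factor comes from, the two recurrences can be unrolled level by level. This is a sketch assuming n is a power of 2 and ignoring the constant work at the leaves. If every call sorted its own input, level k of the recursion would hold 2^k subproblems of size n/2^k, each paying about c·(n/2^k)·lg(n/2^k) for the sort:

$$
\sum_{k=0}^{\lg n - 1} 2^k \cdot c\,\frac{n}{2^k}\,\lg\frac{n}{2^k}
  = c\,n \sum_{k=0}^{\lg n - 1} (\lg n - k)
  = c\,n\,\frac{\lg n\,(\lg n + 1)}{2}
  = \Theta(n \lg^2 n).
$$

With presorting, each call does only linear work (splitting ay and scanning the strip), so every level costs c·n in total:

$$
\sum_{k=0}^{\lg n - 1} 2^k \cdot c\,\frac{n}{2^k} = c\,n \lg n = O(n \lg n).
$$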

Let’s look at the recursive function (with comments):

```python
def closest_pair(ax, ay):
    ln_ax = len(ax)  # Store the length once
    if ln_ax <= 3:
        return brute(ax)  # A call to bruteforce comparison
    mid = ln_ax // 2  # Division without remainder, need int
    Qx = ax[:mid]  # Two-part split
    Rx = ax[mid:]
    # Determine midpoint on x-axis
    midpoint = ax[mid][0]
    Qy = list()
    Ry = list()
    for x in ay:  # split ay into 2 arrays using midpoint
        # Strict '<' so that Qy holds exactly the points of Qx = ax[:mid]
        # (this assumes distinct x-coordinates)
        if x[0] < midpoint:
            Qy.append(x)
        else:
            Ry.append(x)
    # Call recursively both arrays after split
    (p1, q1, mi1) = closest_pair(Qx, Qy)
    (p2, q2, mi2) = closest_pair(Rx, Ry)
    # Determine smaller distance between points of 2 arrays
    if mi1 <= mi2:
        d = mi1
        mn = (p1, q1)
    else:
        d = mi2
        mn = (p2, q2)
    # Call function to account for points on the boundary
    (p3, q3, mi3) = closest_split_pair(ax, ay, d, mn)
    # Determine smallest distance for the array
    if d <= mi3:
        return mn[0], mn[1], d
    else:
        return p3, q3, mi3
```

The implementation above follows the book’s outline. However, while debugging the algorithm, I found a peculiar feature. Suppose we replace the midpoint split logic with:

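```python
qx = set(Qx)  # points of the left half, for fast membership tests
Qy = list()
Ry = list()
for x in ay:
    if x in qx:  # keeps ay's y-order; assumes no duplicate points
        Qy.append(x)
    else:
        Ry.append(x)
```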

the code would actually run a little faster. I won’t dive into the low-level details, but a curious reader should compare the speed of the comparison

`x[0] <= midpoint`

to

`x in qx`

on a set. That’s the only explanation I can think of.
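To try the comparison yourself, here is a small, self-contained timing sketch. The helpers `split_by_compare`, `split_by_set` and `make_points` are hypothetical names that mirror the two versions of the split; the comparison version uses a strict `<` against `ax[mid][0]` so that both produce the same halves when x-coordinates are distinct. Timings depend on the machine and Python version, so no numbers are claimed here.

```python
import random
import timeit


def split_by_compare(ax, ay, mid):
    """Split y-sorted points by comparing x with the median x (assumes distinct x)."""
    midpoint = ax[mid][0]
    qy, ry = [], []
    for p in ay:
        if p[0] < midpoint:
            qy.append(p)
        else:
            ry.append(p)
    return qy, ry


def split_by_set(ax, ay, mid):
    """Split y-sorted points by membership in the left half ax[:mid]."""
    qx = set(ax[:mid])
    qy, ry = [], []
    for p in ay:
        if p in qx:
            qy.append(p)
        else:
            ry.append(p)
    return qy, ry


def make_points(n, seed=1):
    rng = random.Random(seed)
    xs = rng.sample(range(-10**9, 10**9), n)  # distinct x-coordinates
    ys = [rng.randint(-10**9, 10**9) for _ in range(n)]
    return list(zip(xs, ys))


if __name__ == "__main__":
    pts = make_points(50_000)
    ax = sorted(pts, key=lambda p: p[0])
    ay = sorted(pts, key=lambda p: p[1])
    mid = len(ax) // 2
    # Both strategies must agree before their speed is compared.
    assert split_by_compare(ax, ay, mid) == split_by_set(ax, ay, mid)
    for fn in (split_by_compare, split_by_set):
        t = timeit.timeit(lambda: fn(ax, ay, mid), number=10)
        print(f"{fn.__name__}: {t:.3f} s for 10 runs")
```

Note that the set version pays for building `qx` on every call, which is part of the cost being compared.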

Now, this is where it gets interesting. First, the brute(ax) function:

```python
def brute(ax):
    mi = dist(ax[0], ax[1])
    p1 = ax[0]
    p2 = ax[1]
    ln_ax = len(ax)
    if ln_ax == 2:
        return p1, p2, mi
    for i in range(ln_ax - 1):
        for j in range(i + 1, ln_ax):
            if (i, j) != (0, 1):  # skip only the pair already measured above
                d = dist(ax[i], ax[j])
                if d < mi:  # Update min_dist and points
                    mi = d
                    p1, p2 = ax[i], ax[j]
    return p1, p2, mi
```

Let us discuss it briefly. Why initialize mi with the distance between the first two points of the list rather than with some arbitrary large number? That distance is needed anyway, so starting from it avoids a comparison against a placeholder value on each of the many calls to brute. That’s a win. Furthermore, if len(ax) == 2, we’re done and the result can be returned immediately.

The second important point concerns the ranges of the two loops, which matter in the three-point case (recall that brute is called only if len(ax) ≤ 3). Why doesn’t the i index run over all len(ax) points? Because each step compares two points, ax[i] and ax[j], and j ranges from i + 1 up to the last index, so the last point is always covered as some ax[j]. The bounds on j also guarantee that no pair is compared twice: dist(a[1], a[3]) and dist(a[3], a[1]) cannot both occur, and dist(a[2], a[2]) never occurs at all. This roughly halves the number of distance computations compared with two full loops of len(ax) iterations each: n(n - 1)/2 pairs instead of n².
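The loop bounds can be checked directly. The sketch below, using a hypothetical helper `loop_pairs`, collects the index pairs visited by `range(n - 1)` and `range(i + 1, n)` and compares them with `itertools.combinations`, which enumerates each unordered pair exactly once.

```python
from itertools import combinations


def loop_pairs(n):
    """Index pairs (i, j) visited by the two loops in brute()."""
    return [(i, j) for i in range(n - 1) for j in range(i + 1, n)]


for n in (2, 3, 4, 10):
    pairs = loop_pairs(n)
    # Each unordered pair appears exactly once, in the same order as combinations()
    assert pairs == list(combinations(range(n), 2))
    assert len(pairs) == n * (n - 1) // 2
    print(f"n={n}: {len(pairs)} pairs vs {n * n} for two full loops")

# For three points, brute() still has to measure these two pairs after (0, 1)
print([p for p in loop_pairs(3) if p != (0, 1)])  # [(0, 2), (1, 2)]
```

The last line is a reminder that the skip condition must exclude only the pair (0, 1); in particular, the pair (0, 2) still has to be measured.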

Back to our first point: the if condition inside the loops skips the pair (ax[0], ax[1]), whose distance was already computed before the loops.

The distance function (dist) is nothing special:

```python
import math


def dist(p1, p2):
    return math.sqrt((p1[0] - p2[0]) ** 2 + (p1[1] - p2[1]) ** 2)
```
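The standard library already covers this: `math.hypot` computes the Euclidean norm of two components, and `math.dist` (Python 3.8+) takes two points directly. A further common trick, not used in the code above, is to compare squared distances and take a single square root at the end, since squaring preserves the order of non-negative values. The names `dist_hypot` and `dist_sq` are hypothetical.

```python
import math


def dist(p1, p2):
    return math.sqrt((p1[0] - p2[0]) ** 2 + (p1[1] - p2[1]) ** 2)


def dist_hypot(p1, p2):
    return math.hypot(p1[0] - p2[0], p1[1] - p2[1])


def dist_sq(p1, p2):
    """Squared distance: enough for comparisons, no sqrt needed."""
    dx, dy = p1[0] - p2[0], p1[1] - p2[1]
    return dx * dx + dy * dy


p, q, r = (0, 0), (3, 4), (6, 0)
print(dist(p, q), dist_hypot(p, q), math.dist(p, q))  # 5.0 5.0 5.0 (math.dist needs 3.8+)
# Ordering by squared distance matches ordering by distance
assert (dist_sq(p, q) < dist_sq(p, r)) == (dist(p, q) < dist(p, r))
```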

Finally, one of the most interesting pieces: closest_split_pair, the function responsible for finding the closest pair of points that straddles the split line:

```python
def closest_split_pair(p_x, p_y, delta, best_pair):
    ln_x = len(p_x)  # store length - quicker
    mx_x = p_x[ln_x // 2][0]  # select midpoint on x-sorted array
    # Create a subarray of points not further than delta from
    # midpoint on x-sorted array
    s_y = [x for x in p_y if mx_x - delta <= x[0] <= mx_x + delta]
    best = delta  # assign best value to delta
    ln_y = len(s_y)  # store length of subarray for quickness
    for i in range(ln_y - 1):
        # j covers at most the 7 points that follow s_y[i]
        for j in range(i + 1, min(i + 8, ln_y)):
            p, q = s_y[i], s_y[j]
            dst = dist(p, q)
            if dst < best:
                best_pair = p, q
                best = dst
    return best_pair[0], best_pair[1], best
```

Again, the key lies in the ranges of the two loops. They follow the same ideas as in the brute function, with one important distinction: the upper bound on the j index is min(i + 8, ln_y), so j runs over at most the seven points that follow s_y[i], for reasons discussed in the correctness argument of Cormen et al. In short: it is enough to check only the seven points following each point in the s_y subarray, because every strip point that follows a given point in y-order by less than delta lies in a delta × 2·delta rectangle, and that rectangle can contain at most eight points (at most four in each delta × delta half, since points on the same side are at least delta apart). It is worth reading through the proof of correctness, since it explains this ‘trick’ far better; the trick makes the strip check linear in the size of s_y, which is where the large speed increase comes from.
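A common variant of the same idea, sketched here as an alternative rather than the code above, replaces the fixed window with an early exit: because `s_y` is sorted by y, the inner loop can stop as soon as the y-gap reaches the current best distance, since every later point is at least that far away. The packing argument behind the seven-point bound also limits how many points pass this test. The function name `strip_closest` is hypothetical, and `math.dist` needs Python 3.8+.

```python
import math


def strip_closest(s_y, best, best_pair):
    """Scan a y-sorted strip; stop the inner loop once the y-gap reaches best."""
    n = len(s_y)
    for i in range(n - 1):
        p = s_y[i]
        for j in range(i + 1, n):
            q = s_y[j]
            if q[1] - p[1] >= best:
                break  # all later points are at least this far away in y
            d = math.dist(p, q)
            if d < best:
                best, best_pair = d, (p, q)
    return best_pair[0], best_pair[1], best


strip = [(0, 0), (0, 1), (1, 5)]  # already sorted by y
print(strip_closest(strip, 3, ((9, 9), (9, 12))))  # ((0, 0), (0, 1), 1.0)
```

If nothing in the strip beats `best`, the incoming `best_pair` and `best` are returned unchanged, matching the contract of closest_split_pair.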

## P.S.: tips on debugging and testing

Unit tests are mandatory. I recommend the PyCharm IDE (Ctrl + Shift + T creates a unit test for a method). Reading up on stress testing is also advised.

I used the following code to generate a large random test case:

```python
import random


def test_case(length: int = 10000):
    lst1 = [random.randint(-10**9, 10**9) for _ in range(length)]
    lst2 = [random.randint(-10**9, 10**9) for _ in range(length)]
    return lst1, lst2
```
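Large random inputs exercise speed; for correctness, a stress test that compares the divide-and-conquer result against an O(n²) brute force on many small inputs tends to be more useful. Below is a minimal harness; the names `brute_force_min` and `stress_test` are hypothetical, and the small coordinate range is chosen on purpose so that ties in x and duplicate points show up often. It returns the first mismatching input, or `None`.

```python
import math
import random
from itertools import combinations


def brute_force_min(x, y):
    """O(n^2) reference: minimum distance over every unordered pair."""
    pts = list(zip(x, y))
    return min(math.dist(p, q) for p, q in combinations(pts, 2))


def stress_test(fast, slow=brute_force_min, trials=500, max_n=30, coord=20, seed=0):
    """Return (x, y, got, expected) for the first mismatch, or None."""
    rng = random.Random(seed)
    for _ in range(trials):
        n = rng.randint(2, max_n)
        x = [rng.randint(-coord, coord) for _ in range(n)]
        y = [rng.randint(-coord, coord) for _ in range(n)]
        got, expected = fast(x, y), slow(x, y)
        if not math.isclose(got, expected, abs_tol=1e-9):
            return x, y, got, expected
    return None
```

With the article's `solution` in scope, `stress_test(solution)` runs it against the reference; a non-`None` result is a small counterexample (coordinates, returned and expected distance) that can be debugged by hand.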

Running the test on this case initially took about 40 seconds on my Intel i3 (2 cores, 4 threads, ~2.3 GHz) with 8 GB RAM and an SSD (~450 MB/s read/write); after the optimizations mentioned above, it dropped to about 20–30 seconds.

Another great debugging tool was my friend’s library of convenient timers (which I posted to my GitHub after some changes):

It makes it easy to time functions using convenient wrappers, and usage examples are built into the code.
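The library itself is not reproduced here. As an illustration of the general idea only, a timing wrapper can be written as a decorator around `time.perf_counter`; the names `timed`, `TIMINGS` and `slow_square` are hypothetical.

```python
import functools
import time
from collections import defaultdict

TIMINGS = defaultdict(list)  # function name -> list of runtimes in seconds


def timed(func):
    """Record the wall-clock runtime of every call to func."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        start = time.perf_counter()
        try:
            return func(*args, **kwargs)
        finally:
            TIMINGS[func.__name__].append(time.perf_counter() - start)
    return wrapper


@timed
def slow_square(n):
    time.sleep(0.01)
    return n * n


slow_square(3)
slow_square(4)
print(len(TIMINGS["slow_square"]))  # 2
```

Decorating a recursive function such as closest_pair records every recursive call, and each recorded time includes the time spent in its children.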

I wrapped the functions described above, ran the test case and collected the printed runtimes into a JSON file. Later I loaded the results into an SQLite database and used aggregate functions to get the average runtime of each function. I repeated the procedure after adding optimizations and compared the percentage change between the average runtimes to see whether an optimization actually improved a specific function (the overall runtime can be compared simply by running the unit test described above). I designed this procedure for an in-depth look at the results; it is not necessary for everyday debugging.
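The aggregation step can be sketched with the standard-library `sqlite3` module. The table layout (`runs` with `version`, `func` and `seconds` columns) and the sample timings are made up purely for illustration; they are not measurements from this article. In practice the rows would come from the collected JSON file.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE runs (version TEXT, func TEXT, seconds REAL)")
# Illustrative numbers only, not real measurements
rows = [
    ("before", "brute", 0.010), ("before", "brute", 0.014),
    ("after", "brute", 0.006), ("after", "brute", 0.008),
]
conn.executemany("INSERT INTO runs VALUES (?, ?, ?)", rows)

# Average runtime per function for each version, plus the % change
query = """
SELECT b.func,
       b.avg_s AS before_avg,
       a.avg_s AS after_avg,
       100.0 * (a.avg_s - b.avg_s) / b.avg_s AS pct_change
FROM (SELECT func, AVG(seconds) AS avg_s FROM runs
      WHERE version = 'before' GROUP BY func) AS b
JOIN (SELECT func, AVG(seconds) AS avg_s FROM runs
      WHERE version = 'after' GROUP BY func) AS a
  ON a.func = b.func
"""
for func, before_avg, after_avg, pct in conn.execute(query):
    print(f"{func}: {before_avg:.4f} s -> {after_avg:.4f} s ({pct:+.1f}%)")
```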

Good luck and contact me for extra details on the algorithm or for other suggestions: andriy.lazorenko@gmail.com