Closest pair of points in Python (divide and conquer): the quick implementation

Computing the minimum distance between two points on a 2D plane

Given two lists holding the x and corresponding y coordinates of points, find the minimum distance between any pair of points.
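For reference, the quantity being computed can be written as a plain O(n²) search over all pairs. The sketch below is only a baseline for checking results, not part of the divide-and-conquer solution; the name closest_distance_naive is hypothetical, and math.dist requires Python 3.8 or later:

```python
import itertools
import math


def closest_distance_naive(x, y):
    """Smallest Euclidean distance between any two points (x[i], y[i]).

    Checks every unordered pair, so it runs in O(n^2) time.
    """
    points = list(zip(x, y))
    if len(points) < 2:
        raise ValueError("need at least two points")
    return min(math.dist(p, q) for p, q in itertools.combinations(points, 2))


# Points (0, 0), (3, 4) and (1, 1): the closest pair is (0, 0)-(1, 1),
# so the result is the square root of 2 (about 1.41421).
print(closest_distance_naive([0, 3, 1], [0, 4, 1]))
```

It is far too slow for large inputs, which is why the divide-and-conquer version is worth the effort, but it makes a handy oracle for small ones.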

Every battle with a hardcore algorithm has to start somewhere. I suggest reading Cormen et al., “Introduction to Algorithms”, 3rd edition (Section 33.4), but any decent book will do.

We start with a naive implementation of the divide-and-conquer approach to the closest pair of points problem.

Let us suppose that we have 2 lists of size n as our inputs, x and y, which correspond to the points (x1,y1) … (xn,yn), where n is the number of points.

First, let’s look at the following function:

def solution(x, y):
    a = list(zip(x, y))  # This produces a list of tuples
    ax = sorted(a, key=lambda p: p[0])  # Presorting x-wise
    ay = sorted(a, key=lambda p: p[1])  # Presorting y-wise
    p1, p2, mi = closest_pair(ax, ay)  # Recursive D&C function
    return mi

Here we address the concept of presorting. As noted in the book,

Note that in order to attain the O(n lg n) time bound, we cannot afford to sort in each recursive call; if we did, the recurrence for the running time would be T(n) = 2T(n/2) + O(n lg n), whose solution is T(n) = O(n lg² n).

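Therefore, presorting once, outside the function that is called recursively, means each call only has to split already sorted lists in linear time. The recurrence then becomes T(n) = 2T(n/2) + O(n), which gives the desired O(n lg n) bound.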
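Presorting pays off because scanning an already y-sorted list once and dealing each point to the left or right half keeps both halves y-sorted, so no recursive call needs to sort again. A minimal sketch of that split step, assuming distinct points (the helper name split_y_sorted is hypothetical; it uses the set-membership test discussed further down):

```python
def split_y_sorted(ay, left_points):
    """Split the y-sorted list ay into (Qy, Ry) in O(len(ay)) time.

    left_points holds the points of the left half (Qx). Because ay is
    scanned in order, both outputs stay sorted by y without another sort.
    Assumes the points are distinct.
    """
    left = set(left_points)
    qy, ry = [], []
    for p in ay:
        (qy if p in left else ry).append(p)
    return qy, ry


points = [(5, 1), (1, 4), (3, 2), (8, 0), (2, 9), (7, 3)]
ax = sorted(points, key=lambda p: p[0])
ay = sorted(points, key=lambda p: p[1])
qx, rx = ax[:len(ax) // 2], ax[len(ax) // 2:]
qy, ry = split_y_sorted(ay, qx)
print(qy)  # [(3, 2), (1, 4), (2, 9)]
print(ry)  # [(8, 0), (5, 1), (7, 3)]
```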

Let’s look at the recursive call (with the appropriate comments):

def closest_pair(ax, ay):
    ln_ax = len(ax)  # It's quicker to store the length in a variable
    if ln_ax <= 3:
        return brute(ax)  # A call to brute-force comparison
    mid = ln_ax // 2  # Integer division, we need an int index
    Qx = ax[:mid]  # Two-part split
    Rx = ax[mid:]
    # Determine midpoint on x-axis: x of the last point in Qx, so that
    # every point of Qx satisfies x[0] <= midpoint
    midpoint = ax[mid - 1][0]
    # If several points share this x value, some may still land in the
    # wrong half; the set-based split shown below avoids that
    Qy = list()
    Ry = list()
    for x in ay:  # split ay into 2 arrays using midpoint (keeps y order)
        if x[0] <= midpoint:
            Qy.append(x)
        else:
            Ry.append(x)
    # Call recursively both arrays after split
    (p1, q1, mi1) = closest_pair(Qx, Qy)
    (p2, q2, mi2) = closest_pair(Rx, Ry)
    # Determine smaller distance between points of 2 arrays
    if mi1 <= mi2:
        d = mi1
        mn = (p1, q1)
    else:
        d = mi2
        mn = (p2, q2)
    # Call function to account for points on the boundary
    (p3, q3, mi3) = closest_split_pair(ax, ay, d, mn)
    # Determine smallest distance for the array
    if d <= mi3:
        return mn[0], mn[1], d
    return p3, q3, mi3

The implementation above follows the book. However, while debugging the algorithm, I found a peculiar feature. Suppose we replace the midpoint split logic with a membership test:

qx = set(Qx)
Qy = list()
Ry = list()
for x in ay:
    if x in qx:
        Qy.append(x)
    else:
        Ry.append(x)

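Then the code actually runs a little faster. I won’t dive into the low-level details, though a curious reader should compare the speed of the comparison

x[0] <= midpoint

with the membership test

x in qx

on a set(). That’s the only reason I can think of. The membership test also has a correctness bonus: as long as the points are distinct, every point of Qx lands in Qy even when several points share the x-coordinate at the split.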
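If you want to try that comparison yourself, a rough timeit sketch like the one below runs both split loops on the same random data. The data size, seed, number of runs and function names are arbitrary choices for illustration, and the set construction is deliberately included in the membership timing. Results depend on the machine and Python version, so no numbers are claimed here:

```python
import random
import timeit


def make_points(n, seed=0):
    """Return n distinct random integer points (hypothetical test data)."""
    rng = random.Random(seed)
    points = set()
    while len(points) < n:
        points.add((rng.randint(-10**9, 10**9), rng.randint(-10**9, 10**9)))
    return list(points)


def split_by_coordinate(ax, ay):
    mid = len(ax) // 2
    midpoint = ax[mid - 1][0]  # x of the last point in the left half
    qy, ry = [], []
    for p in ay:
        if p[0] <= midpoint:
            qy.append(p)
        else:
            ry.append(p)
    return qy, ry


def split_by_membership(ax, ay):
    qx = set(ax[:len(ax) // 2])  # building the set is part of the cost
    qy, ry = [], []
    for p in ay:
        if p in qx:
            qy.append(p)
        else:
            ry.append(p)
    return qy, ry


points = make_points(10_000)
ax = sorted(points, key=lambda p: p[0])
ay = sorted(points, key=lambda p: p[1])
for fn in (split_by_coordinate, split_by_membership):
    seconds = timeit.timeit(lambda: fn(ax, ay), number=50)
    print(f"{fn.__name__}: {seconds:.3f} s for 50 runs")
```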

Now, that’s where it gets interesting. First, the brute(ax) function:

def brute(ax):
    mi = dist(ax[0], ax[1])
    p1 = ax[0]
    p2 = ax[1]
    ln_ax = len(ax)
    if ln_ax == 2:
        return p1, p2, mi
    for i in range(ln_ax - 1):
        for j in range(i + 1, ln_ax):
            if not (i == 0 and j == 1):  # skip only the pair measured above
                d = dist(ax[i], ax[j])
                if d < mi:  # Update min_dist and points
                    mi = d
                    p1, p2 = ax[i], ax[j]
    return p1, p2, mi

Let us discuss this briefly. Why initialize mi to the distance between the first two points of the list rather than to some arbitrarily large number? It saves us a computation on each of the many calls to brute, and that’s a win. Furthermore, if len(ax) == 2, we are already done and the result can be returned immediately.

The second important point concerns the ranges of our two loops, which matter in the case of 3 points (recall that brute is called only if len(ax) ≤ 3). Why don’t we need to iterate the i index over all len(ax) points? Because we compare two points, ax[i] and ax[j], where j runs from i + 1 up to the last index, so the last point is still reached through j and every point gets compared anyway. Furthermore, the bounds on the j index mean that we never compare the same points twice: situations such as computing both dist(ax[0], ax[2]) and dist(ax[2], ax[0]), or computing dist(ax[1], ax[1]), are ruled out by the loop boundaries. This roughly halves the number of distance computations compared with simply running two loops over len(ax).

Back to our first point: the if condition inside the loops skips the pair (ax[0], ax[1]), whose distance has already been computed, which saves an extra computation and comparison.
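To check that the two loops in brute visit every unordered pair exactly once, and that the if condition skips only the pair measured before the loops, the index pairs can simply be listed. A small sketch (the function name brute_index_pairs is hypothetical), using the condition not (i == 0 and j == 1):

```python
import itertools


def brute_index_pairs(n):
    """Index pairs (i, j) that brute() measures inside its loops for n points.

    The pair (0, 1) is excluded because brute() computes it before the loops.
    """
    pairs = []
    for i in range(n - 1):
        for j in range(i + 1, n):
            if not (i == 0 and j == 1):
                pairs.append((i, j))
    return pairs


print(brute_index_pairs(3))  # [(0, 2), (1, 2)]
```

Adding (0, 1) back gives exactly itertools.combinations(range(n), 2), that is n(n - 1)/2 pairs. By contrast, the condition i != 0 and j != 1 keeps only (1, 2) for three points, so it would never compare ax[0] with ax[2].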

The distance function (dist) is nothing special:

import math

def dist(p1, p2):
    return math.sqrt((p1[0] - p2[0]) ** 2 + (p1[1] - p2[1]) ** 2)
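The standard library can compute the same Euclidean distance: math.hypot(dx, dy) takes the coordinate differences, and math.dist(p, q) (Python 3.8 and later) takes the two points directly. A short sketch comparing them with the hand-written formula (dist_hypot is a hypothetical name used only here):

```python
import math


def dist(p1, p2):
    # Hand-written Euclidean distance, as in the article.
    return math.sqrt((p1[0] - p2[0]) ** 2 + (p1[1] - p2[1]) ** 2)


def dist_hypot(p1, p2):
    # Same quantity via math.hypot on the coordinate differences.
    return math.hypot(p1[0] - p2[0], p1[1] - p2[1])


p, q = (1, 2), (4, 6)
print(dist(p, q), dist_hypot(p, q), math.dist(p, q))  # 5.0 5.0 5.0
```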

Finally, one of the most interesting pieces: closest_split_pair, the function responsible for finding the closest pair of points that straddles the split line:

def closest_split_pair(p_x, p_y, delta, best_pair):
    ln_x = len(p_x)  # store length - quicker
    mx_x = p_x[ln_x // 2][0]  # select midpoint on x-sorted array
    # Create a subarray of points not further than delta from
    # midpoint on x-sorted array
    s_y = [x for x in p_y if mx_x - delta <= x[0] <= mx_x + delta]
    best = delta  # assign best value to delta
    ln_y = len(s_y)  # store length of subarray for quickness
    for i in range(ln_y - 1):
        # j covers at most the seven points that follow s_y[i]
        for j in range(i + 1, min(i + 8, ln_y)):
            p, q = s_y[i], s_y[j]
            dst = dist(p, q)
            if dst < best:
                best_pair = p, q
                best = dst
    return best_pair[0], best_pair[1], best

Again, the key lies in the ranges of the two loops. They follow ideas similar to those in the brute function, with one important distinction: the upper bound on the j index is min(i + 8, ln_y), so j covers at most the seven points following s_y[i], for the reasons given in the correctness argument of Cormen et al. (Section 33.4). In short, it is enough to check only the seven points following each point in the s_y subarray. You should really read the proof of correctness, because it explains this ‘trick’, which gives a large speed-up, much better than I can here.
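The seven-point bound can also be checked empirically. The sketch below (all names are hypothetical) splits a random set of distinct points at the x-median, takes delta as the smaller of the two halves’ closest distances, builds the y-sorted strip, and records how far ahead in the strip a point’s partner closer than delta can sit. It uses squared integer distances so the comparisons are exact. By the packing argument in Cormen et al., the gap should never exceed 7:

```python
import itertools
import math
import random


def dist2(p, q):
    """Squared Euclidean distance; exact for integer coordinates."""
    return (p[0] - q[0]) ** 2 + (p[1] - q[1]) ** 2


def closest2_in(points):
    """O(n^2) smallest squared distance in one half (inf if < 2 points)."""
    return min(
        (dist2(p, q) for p, q in itertools.combinations(points, 2)),
        default=math.inf,
    )


def max_strip_gap(points):
    """Largest j - i such that strip[i], strip[j] are closer than delta."""
    ax = sorted(points)  # sorted by x, ties broken by y
    mid = len(ax) // 2
    mid_x = ax[mid][0]
    delta2 = min(closest2_in(ax[:mid]), closest2_in(ax[mid:]))
    strip = sorted(
        (p for p in points if (p[0] - mid_x) ** 2 <= delta2),
        key=lambda p: p[1],
    )
    gap = 0
    for i, j in itertools.combinations(range(len(strip)), 2):
        if dist2(strip[i], strip[j]) < delta2:
            gap = max(gap, j - i)
    return gap


rng = random.Random(1)
worst = 0
for _ in range(200):
    n = rng.randint(4, 40)
    pts = list({(rng.randint(0, 50), rng.randint(0, 50)) for _ in range(n)})
    worst = max(worst, max_strip_gap(pts))
print("largest gap observed:", worst)  # the packing argument bounds this by 7
```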

P.S.: tips on debugging and testing

Unit tests are mandatory. The PyCharm IDE is recommended (Ctrl + Shift + T creates a unit test for a method). Additional reading on stress testing is also advised.

I used the following code to generate a large random test case:

import random

def test_case(length: int = 10000):
    lst1 = [random.randint(-10**9, 10**9) for _ in range(length)]
    lst2 = [random.randint(-10**9, 10**9) for _ in range(length)]
    return lst1, lst2

It initially took about 40 seconds to run on my Intel i3 (2 cores, 4 threads, ~2.3 GHz) with 8 GB RAM and an SSD (~450 MB/s read/write); after the optimizations mentioned above, this dropped to about 20 to 30 seconds.
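In the spirit of the stress-testing advice above, one way to gain confidence is to compare the divide-and-conquer answer with a brute-force answer on many small random inputs. The sketch below is a compact, self-contained restatement of the functions in this article, using the set-membership split, a brute step that compares every pair, and seven points of lookahead in the strip. The names closest_distance_fast, closest_distance_slow and stress_test are hypothetical. Because the membership split assumes distinct points, the sketch returns 0 up front when a point is repeated:

```python
import itertools
import math
import random


def dist(p1, p2):
    return math.sqrt((p1[0] - p2[0]) ** 2 + (p1[1] - p2[1]) ** 2)


def brute(ax):
    """Closest pair among 2 or 3 points: check every unordered pair."""
    best = None
    for p, q in itertools.combinations(ax, 2):
        d = dist(p, q)
        if best is None or d < best[2]:
            best = (p, q, d)
    return best


def closest_split_pair(ax, ay, delta, best_pair):
    mid_x = ax[len(ax) // 2][0]
    s_y = [p for p in ay if mid_x - delta <= p[0] <= mid_x + delta]
    best = delta
    for i in range(len(s_y) - 1):
        for j in range(i + 1, min(i + 8, len(s_y))):  # next seven points
            d = dist(s_y[i], s_y[j])
            if d < best:
                best_pair, best = (s_y[i], s_y[j]), d
    return best_pair[0], best_pair[1], best


def closest_pair(ax, ay):
    if len(ax) <= 3:
        return brute(ax)
    mid = len(ax) // 2
    qx, rx = ax[:mid], ax[mid:]
    left = set(qx)  # membership split; both halves stay sorted by y
    qy = [p for p in ay if p in left]
    ry = [p for p in ay if p not in left]
    p1, q1, d1 = closest_pair(qx, qy)
    p2, q2, d2 = closest_pair(rx, ry)
    d, pair = (d1, (p1, q1)) if d1 <= d2 else (d2, (p2, q2))
    p3, q3, d3 = closest_split_pair(ax, ay, d, pair)
    return (pair[0], pair[1], d) if d <= d3 else (p3, q3, d3)


def closest_distance_fast(x, y):
    points = list(zip(x, y))
    if len(set(points)) < len(points):
        return 0.0  # a repeated point is at distance 0 from its copy
    ax = sorted(points, key=lambda p: p[0])
    ay = sorted(points, key=lambda p: p[1])
    return closest_pair(ax, ay)[2]


def closest_distance_slow(x, y):
    pairs = itertools.combinations(zip(x, y), 2)
    return min(dist(p, q) for p, q in pairs)


def stress_test(trials=500, seed=0):
    """Compare fast and slow answers on small random inputs."""
    rng = random.Random(seed)
    for _ in range(trials):
        n = rng.randint(2, 60)
        x = [rng.randint(-100, 100) for _ in range(n)]
        y = [rng.randint(-100, 100) for _ in range(n)]
        fast, slow = closest_distance_fast(x, y), closest_distance_slow(x, y)
        assert math.isclose(fast, slow), (x, y, fast, slow)
    return trials


print(stress_test(), "random cases agree")
```

Small coordinate ranges are chosen on purpose: they produce many ties on x and y and many points on the split line, which is where split bugs tend to hide.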

Another great debugging tool was my friend’s library of convenient timers (which I posted to my GitHub after some changes):

It made timing functions easy through convenient wrappers, and usage examples are built into the code.

I wrapped the functions described above with these timers, ran the test case, and collected the printed runtimes into a JSON file. I then loaded the results into an SQLite database and used aggregate functions to get the average runtime of each function. I repeated the same procedure after adding optimizations and looked at the percentage change between the average runtimes, to see whether an optimization improved the runtime of a specific function (the overall runtime can be compared simply by running the unit test example above). I designed this procedure to understand the results in depth; it is not necessary for general debugging.
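The friend’s timer library is not reproduced here. As a rough stand-in for the procedure just described, the sketch below uses a hypothetical decorator, timed, that records each call’s runtime, serializes the records to JSON, loads them into an in-memory SQLite table, and averages them per function with AVG. The function work is only a placeholder for the functions being measured:

```python
import functools
import json
import sqlite3
import time

RECORDS = []  # one {"function": name, "seconds": runtime} dict per call


def timed(fn):
    """Hypothetical timing wrapper: append each call's runtime to RECORDS."""
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        start = time.perf_counter()
        try:
            return fn(*args, **kwargs)
        finally:
            RECORDS.append(
                {"function": fn.__name__, "seconds": time.perf_counter() - start}
            )
    return wrapper


@timed
def work(n):
    return sum(i * i for i in range(n))


for n in (10_000, 20_000, 30_000):
    work(n)

payload = json.dumps(RECORDS)  # what would be written to the JSON file
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE runs (function TEXT, seconds REAL)")
conn.executemany(
    "INSERT INTO runs VALUES (:function, :seconds)", json.loads(payload)
)
for name, calls, avg in conn.execute(
    "SELECT function, COUNT(*), AVG(seconds) FROM runs GROUP BY function"
):
    print(f"{name}: {calls} calls, average {avg:.6f} s")
conn.close()
```

Running the same script before and after an optimization and comparing the two averages gives the percentage change per function.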

Good luck and contact me for extra details on the algorithm or for other suggestions: