Closest pair of points in Python (divide and conquer): the quick implementation
Computing the minimum distance between two points on a 2D plane
Given two lists holding the x and the corresponding y coordinates of a set of points, find the minimum distance between any pair of these points.
Every battle with a hardcore algorithm should start somewhere. I suggest reading Cormen et al., “Introduction to Algorithms”, 3rd edition (Section 33.4), but any decent book will do.
We start with a naive implementation of the divide-and-conquer approach to the closest pair of points problem.
Let us suppose that our inputs are two lists of size n, x and y, which together describe the points (x1, y1), …, (xn, yn), where n is the number of points.
First, let’s look at the following function:
    def solution(x, y):
        a = list(zip(x, y))  # This produces a list of (x, y) tuples
        ax = sorted(a, key=lambda p: p[0])  # Presorting x-wise
        ay = sorted(a, key=lambda p: p[1])  # Presorting y-wise
        p1, p2, mi = closest_pair(ax, ay)  # Recursive D&C function
        return mi
Here we address the concept of presorting. As noted in the book:
“Note that in order to attain the O(n lg n) time bound, we cannot afford to sort in each recursive call; if we did, the recurrence for the running time would be T(n) = 2T(n/2) + O(n lg n), whose solution is T(n) = O(n lg² n).”
Therefore, sorting once, outside the function that is called recursively, is what lets the whole solution run in O(n lg n) time instead of O(n lg² n).
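The quoted bound can be checked by unrolling both recurrences level by level. This sketch assumes n is a power of 2 and ignores the O(n) cost of the base cases; at depth k there are 2^k subproblems of size n/2^k, and there are lg n levels above the leaves.

```latex
\begin{aligned}
\text{Presorted: } T(n) &= 2T(n/2) + cn
  \;\Rightarrow\; \sum_{k=0}^{\lg n - 1} 2^k \cdot c\,\frac{n}{2^k} = cn\lg n = O(n \lg n) \\
\text{Sorting in every call: } T(n) &= 2T(n/2) + cn\lg n
  \;\Rightarrow\; \sum_{k=0}^{\lg n - 1} 2^k \cdot c\,\frac{n}{2^k}\lg\frac{n}{2^k}
  = cn\sum_{k=0}^{\lg n - 1}(\lg n - k) = cn\,\frac{\lg n\,(\lg n + 1)}{2} = O(n \lg^2 n)
\end{aligned}
```

The extra lg factor comes entirely from the per-level sorting cost, which is why moving the sort out of the recursion removes it.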
Let’s look at the recursive function (with the appropriate comments):
    def closest_pair(ax, ay):
        ln_ax = len(ax)  # It's quicker to assign a variable
        if ln_ax <= 3:
            return brute(ax)  # A call to brute-force comparison
        mid = ln_ax // 2  # Division without remainder, need int
        Qx = ax[:mid]  # Two-part split
        Rx = ax[mid:]
        # Determine midpoint on x-axis: the largest x in the left half, so that
        # Qy receives exactly the points of Qx (assumes distinct x-coordinates)
        midpoint = ax[mid - 1][0]
        Qy = list()
        Ry = list()
        for x in ay:  # split ay into 2 arrays using midpoint
            if x[0] <= midpoint:
                Qy.append(x)
            else:
                Ry.append(x)
        # Call recursively both arrays after split
        (p1, q1, mi1) = closest_pair(Qx, Qy)
        (p2, q2, mi2) = closest_pair(Rx, Ry)
        # Determine smaller distance between points of 2 arrays
        if mi1 <= mi2:
            d = mi1
            mn = (p1, q1)
        else:
            d = mi2
            mn = (p2, q2)
        # Call function to account for points on the boundary
        (p3, q3, mi3) = closest_split_pair(ax, ay, d, mn)
        # Determine smallest distance for the array
        if d <= mi3:
            return mn[0], mn[1], d
        else:
            return p3, q3, mi3
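Presorting pays off because the split step keeps both halves sorted by y with a single linear pass: filtering an already y-sorted list preserves its order, so no recursive call has to sort again. A small self-contained sketch of that step (split_by_x and is_sorted_by_y are hypothetical helper names; the threshold is the largest x-coordinate in the left half, which assumes distinct x-coordinates):

```python
import random


def split_by_x(ax, ay):
    """Split x-sorted ax in half and route y-sorted ay into matching halves.

    Hypothetical helper; assumes distinct x-coordinates so that the threshold
    separates the two halves exactly.
    """
    mid = len(ax) // 2
    qx, rx = ax[:mid], ax[mid:]
    threshold = ax[mid - 1][0]  # largest x-coordinate in the left half
    qy, ry = [], []
    for p in ay:  # one linear pass; appending preserves the y-order of ay
        if p[0] <= threshold:
            qy.append(p)
        else:
            ry.append(p)
    return qx, rx, qy, ry


def is_sorted_by_y(points):
    return all(a[1] <= b[1] for a, b in zip(points, points[1:]))


rng = random.Random(0)
xs = rng.sample(range(1000), 20)  # 20 distinct x-coordinates
pts = [(x, rng.randint(0, 1000)) for x in xs]
ax = sorted(pts, key=lambda p: p[0])
ay = sorted(pts, key=lambda p: p[1])
qx, rx, qy, ry = split_by_x(ax, ay)
print(is_sorted_by_y(qy), is_sorted_by_y(ry))  # True True
print(set(qx) == set(qy), set(rx) == set(ry))  # True True
```

Each point is routed in O(1), so the split costs O(n) per call, which is exactly the cn term of the presorted recurrence.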
The closest_pair implementation above follows the book. However, while debugging the algorithm, I found a peculiar feature. If we replace the midpoint split logic with:
        qx = set(Qx)
        Qy = list()
        Ry = list()
        for x in ay:
            if x in qx:
                Qy.append(x)
            else:
                Ry.append(x)
the code actually runs a little faster. I won’t dive into the low-level details, but a curious reader should compare the speed of the comparison
    x[0] <= midpoint
with the membership test
    x in qx
on a set(). That’s the only explanation I can think of.
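For readers who want to run that comparison, here is a minimal timeit sketch. compare_split_checks and its parameters are hypothetical; it times only the two split tests (including building the set, since the set-based split has to do that on every call) on random data similar to the test case shown later, and uses list comprehensions to isolate the cost of the test itself. Which variant wins can depend on the Python version and the data, so treat the numbers as machine-specific.

```python
import random
import timeit


def compare_split_checks(n=10_000, repeat=5, number=20, seed=0):
    """Best-of-`repeat` timings (seconds) of `number` splits with each test.

    Hypothetical benchmark: returns (seconds_midpoint, seconds_set).
    """
    rng = random.Random(seed)
    pts = [(rng.randint(-10**9, 10**9), rng.randint(-10**9, 10**9)) for _ in range(n)]
    ax = sorted(pts, key=lambda p: p[0])
    ay = sorted(pts, key=lambda p: p[1])
    mid = len(ax) // 2
    Qx = ax[:mid]
    midpoint = ax[mid - 1][0]  # largest x in the left half

    def split_by_midpoint():
        return [p for p in ay if p[0] <= midpoint]

    def split_by_set():
        qx = set(Qx)  # building the set is part of the cost of this variant
        return [p for p in ay if p in qx]

    t_midpoint = min(timeit.repeat(split_by_midpoint, repeat=repeat, number=number))
    t_set = min(timeit.repeat(split_by_set, repeat=repeat, number=number))
    return t_midpoint, t_set


t_midpoint, t_set = compare_split_checks()
print(f"x[0] <= midpoint: {t_midpoint:.4f} s, x in qx: {t_set:.4f} s")
```

Note that the set-based split relies on the points being hashable tuples; it does not depend on the x-coordinates being distinct.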
Now, this is where it gets interesting. First, the brute(ax) function:
    def brute(ax):
        mi = dist(ax[0], ax[1])
        p1 = ax[0]
        p2 = ax[1]
        ln_ax = len(ax)
        if ln_ax == 2:
            return p1, p2, mi
        for i in range(ln_ax - 1):
            for j in range(i + 1, ln_ax):
                if i != 0 or j != 1:  # skip only the pair (0, 1), already measured above
                    d = dist(ax[i], ax[j])
                    if d < mi:  # Update min_dist and points
                        mi = d
                        p1, p2 = ax[i], ax[j]
        return p1, p2, mi
Let us discuss it briefly. Why initialize mi with the distance between the first two points in the list, rather than with some arbitrary large number? Because that distance has to be computed anyway, so using it as the starting value saves us a computation on each of the many calls to brute. That’s a win. Furthermore, if len(ax) == 2, we’re done and the result can be returned immediately.
The second important point concerns the ranges of our two loops, which are needed in the case of 3 points (recall that brute is called only if len(ax) ≤ 3). Why doesn’t the i index need to run over all len(ax) points? Because each step compares two points, ax[i] and ax[j], and j runs from i + 1 to len(ax) - 1, so the last point is still reached as ax[j]. These bounds also guarantee that no pair is compared twice and no point is compared with itself: with three points we never compute both dist(ax[0], ax[2]) and dist(ax[2], ax[0]), nor dist(ax[1], ax[1]). Compared with two full loops over len(ax), this at least halves the number of distance computations in brute.
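Back to our first point: the if condition inside the loops skips the pair (ax[0], ax[1]), whose distance was already computed as the initial value of mi, which saves an extra distance computation.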
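The nested i/j loops in brute enumerate exactly the unordered pairs of distinct positions. For comparison, the same enumeration can be written with itertools.combinations; this is a sketch of an alternative, not the code used in this article (brute_combinations is a hypothetical name, and math.dist needs Python 3.8+). In the three-point example below, the closest pair sits at positions 0 and 2 of the x-sorted list, which shows why a brute-force check must include that pair.

```python
import itertools
import math


def brute_combinations(points):
    """Brute-force closest pair over every unordered pair (i < j).

    itertools.combinations(points, 2) yields each pair exactly once and never
    pairs a point with itself. Requires at least two points.
    """
    p, q = min(itertools.combinations(points, 2),
               key=lambda pair: math.dist(pair[0], pair[1]))
    return p, q, math.dist(p, q)


# x-sorted triple whose closest pair is (ax[0], ax[2])
print(brute_combinations([(0, 0), (1, 10), (2, 0)]))  # ((0, 0), (2, 0), 2.0)
```

For 3 points this evaluates 3 distances, the same count as brute’s loops when they include the initial pair.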
The distance function (dist) is nothing special:
import math
def dist(p1, p2):
    return math.sqrt((p1[0] - p2[0]) ** 2 + (p1[1] - p2[1]) ** 2)
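The standard library already provides the same computation: math.hypot(dx, dy) returns the Euclidean length of the vector (dx, dy), and math.dist(p, q) (Python 3.8+) takes the two points directly. Whether either is measurably faster than the hand-written version inside this algorithm is something to time rather than assume; the snippet below only checks that they agree.

```python
import math


def dist(p1, p2):
    return math.sqrt((p1[0] - p2[0]) ** 2 + (p1[1] - p2[1]) ** 2)


p, q = (1, 2), (4, 6)
# All three compute the Euclidean distance between p and q.
print(dist(p, q), math.hypot(q[0] - p[0], q[1] - p[1]), math.dist(p, q))  # 5.0 5.0 5.0
```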
Finally, one of the most interesting pieces: closest_split_pair, the function responsible for finding the closest pair of points that straddles the split line:
    def closest_split_pair(p_x, p_y, delta, best_pair):
        ln_x = len(p_x)  # store length - quicker
        mx_x = p_x[ln_x // 2][0]  # select midpoint on x-sorted array
        # Create a subarray of points not further than delta from
        # midpoint on x-sorted array
        s_y = [x for x in p_y if mx_x - delta <= x[0] <= mx_x + delta]
        best = delta  # assign best value to delta
        ln_y = len(s_y)  # store length of subarray for quickness
        for i in range(ln_y - 1):
            # compare s_y[i] with at most the 7 points that follow it
            for j in range(i + 1, min(i + 8, ln_y)):
                p, q = s_y[i], s_y[j]
                dst = dist(p, q)
                if dst < best:
                    best_pair = p, q
                    best = dst
        return best_pair[0], best_pair[1], best
Again, the key lies in the ranges of the two loops. They follow the same ideas as in the brute function, with one important distinction: the upper bound on the j index is capped so that each point of the s_y subarray is compared with at most the seven points that follow it (and never beyond the end of s_y). The reasons are given in the “Correctness” part of Section 33.4 in Cormen et al. In short: it is enough to check only the seven points following each point in the s_y subarray. You should really read through that proof, because it explains this ‘trick’, which gives a large speed-up, much better than I can here.
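A common alternative to the fixed window, named here plainly because it is not the code above, is to stop the inner loop as soon as the vertical gap alone reaches the current best distance: every later point in s_y is even higher, so none of them can be closer. A sketch under the assumption that s_y is already sorted by y (strip_closest is a hypothetical name; math.dist needs Python 3.8+):

```python
import math


def strip_closest(s_y, delta, best_pair=None):
    """Closest pair inside a y-sorted strip, using an early break.

    Returns (best_pair, best); best_pair stays as passed in (None by default)
    if no pair is closer than delta.
    """
    best = delta
    n = len(s_y)
    for i in range(n - 1):
        p = s_y[i]
        for j in range(i + 1, n):
            q = s_y[j]
            if q[1] - p[1] >= best:
                break  # all later points are at least this far away vertically
            d = math.dist(p, q)
            if d < best:
                best, best_pair = d, (p, q)
    return best_pair, best
```

Because best never exceeds delta, every point the inner loop measures lies less than delta above p, so the same packing argument that gives the seven-point bound also limits how many distances are computed per point, and the scan stays linear in the size of the strip.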
P.S.: tips on debugging and testing
Unit tests are mandatory. I recommend the PyCharm IDE (Ctrl + Shift + T creates a unit test for a method). Additional reading on stress testing is also advised.
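Stress testing here means comparing the fast solution with a slow but obviously correct one on many small random inputs. A minimal harness, assuming the candidate has the same signature as solution(x, y) and returns the minimum distance (naive_min_distance and stress_test are hypothetical names; math.dist needs Python 3.8+):

```python
import itertools
import math
import random


def naive_min_distance(x, y):
    """Reference answer: check all n(n-1)/2 pairs. O(n^2), fine for small n."""
    pts = list(zip(x, y))
    return min(math.dist(p, q) for p, q in itertools.combinations(pts, 2))


def stress_test(candidate, trials=1000, max_n=12, coord_range=10,
                distinct_x=False, seed=0):
    """Compare candidate(x, y) with naive_min_distance on random small inputs.

    Raises AssertionError describing the first failing input.
    """
    rng = random.Random(seed)
    values = range(-coord_range, coord_range + 1)
    for _ in range(trials):
        upper = min(max_n, len(values)) if distinct_x else max_n
        n = rng.randint(2, upper)
        if distinct_x:
            x = rng.sample(values, n)  # no repeated x-coordinates
        else:
            x = [rng.choice(values) for _ in range(n)]
        y = [rng.choice(values) for _ in range(n)]
        expected = naive_min_distance(x, y)
        got = candidate(x, y)
        if not math.isclose(got, expected, abs_tol=1e-9):
            raise AssertionError(f"x={x}, y={y}: got {got}, expected {expected}")
```

For example, stress_test(solution) runs 1000 random cases and reports the first input on which the answers differ. Small coordinate ranges make duplicate x-values and coincident points common, which is where split logic tends to break; pass distinct_x=True to test an implementation only within a distinct-x assumption.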
I used the following code to generate a large random test case:
    import random

    def test_case(length: int = 10000):
        lst1 = [random.randint(-10**9, 10**9) for _ in range(length)]
        lst2 = [random.randint(-10**9, 10**9) for _ in range(length)]
        return lst1, lst2
Running it initially took about 40 seconds on my Intel i3 (2 cores, 4 threads, ~2.3 GHz), 8 GB RAM, SSD (~450 MB/s read/write); this dropped to about 20–30 seconds after the optimizations mentioned above.
Another great debugging tool was a friend’s library of convenient timers, which I posted to my GitHub after some changes.
It lets you time functions with convenient wrappers, and usage examples are included in the code.
I wrapped the functions described above with these timers, ran the test case and saved the printed runtimes to a JSON file. I then loaded the results into an SQLite database and used its aggregate functions to get the average runtime of each function. After adding optimizations, I repeated the same procedure and looked at the percentage change between the average runtimes to see whether an optimization actually improved a specific function (the overall runtime can be compared simply by running the test case above). I designed this procedure to understand the results in depth; it is not necessary for general debugging.
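The timer library itself is not reproduced here. As a rough stand-in, the sketch below records every call’s duration in memory and computes per-function averages and the percentage change between two runs directly in Python, instead of going through JSON and SQLite. timed, TIMINGS, average_runtimes and percent_change are hypothetical names, not that library’s API.

```python
import functools
import statistics
import time
from collections import defaultdict

# Hypothetical helper: wall-clock duration of every call, grouped by function name.
TIMINGS = defaultdict(list)


def timed(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        start = time.perf_counter()
        try:
            return func(*args, **kwargs)
        finally:
            TIMINGS[func.__name__].append(time.perf_counter() - start)
    return wrapper


def average_runtimes():
    """Average seconds per call for each timed function."""
    return {name: statistics.mean(ts) for name, ts in TIMINGS.items()}


def percent_change(before, after):
    """Percentage change of each function's average runtime between two runs."""
    return {name: 100.0 * (after[name] - before[name]) / before[name]
            for name in before.keys() & after.keys()}
```

Rebinding a module-level name, e.g. closest_pair = timed(closest_pair), also times the recursive calls, because the function looks up its own global name at call time; keep in mind that each recorded duration then includes the time spent in nested calls.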
Good luck and contact me for extra details on the algorithm or for other suggestions: andriy.lazorenko@gmail.com
P.S.: this story is part of my series on algorithmic challenges. Check out other cool algorithms broken down with tests and Jupyter notebooks!