Closest pair of points in Python (divide and conquer): the quick implementation

Computing the minimum distance between two points on a 2D plane

Given two lists holding the x coordinates and the respective y coordinates of a set of points, find the minimal distance between any pair of points.

Every battle with a hardcore algorithm should start somewhere. I suggest reading Cormen et al., “Introduction to Algorithms”, 3rd edition (Section 33.4), but any decent book will do.

We start from a naive implementation of the divide-and-conquer approach to the closest pair of points problem:

Let us suppose that we have two lists of size n as our inputs, x and y, which correspond to the points (x1,y1) … (xn,yn), where n is the number of points.

First, let’s look at the following function:
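A minimal sketch of such a presorting entry point, assuming the points arrive as two coordinate lists `x` and `y`. The names `presort` and `ay` are illustrative assumptions; the recursive routine would then receive both sorted lists.

```python
def presort(x, y):
    """Pair up coordinates and sort the points once, outside the recursion."""
    points = list(zip(x, y))                 # [(x1, y1), ..., (xn, yn)]
    ax = sorted(points, key=lambda p: p[0])  # sorted by x coordinate
    ay = sorted(points, key=lambda p: p[1])  # sorted by y coordinate
    return ax, ay


ax, ay = presort([3, 1, 2], [5, 9, 4])
print(ax)  # points ordered by x
print(ay)  # points ordered by y
```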

Here we address the concept of presorting. As noted in the book,

“Note that in order to attain the O(n * lg(n)) time bound, we cannot afford to sort in each recursive call; if we did, the recurrence for the running time would be T(n) = 2T(n/2) + O(n * lg(n)), whose solution is T(n) = O(n * lg²(n)).”

Therefore, presorting once, outside of the function that is called recursively, lets us achieve the lower time complexity.
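For comparison, here is the recurrence once the sorting is done up front. Each call then does only linear work outside its two recursive calls (splitting the already sorted lists and scanning the points near the split line), so the analysis is the same as for merge sort:

```latex
T(n) = 2\,T(n/2) + O(n) \quad\Longrightarrow\quad T(n) = O(n \lg n)
```

The one-off presort costs O(n lg n) as well, so the total stays O(n lg n).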

Let’s look at the recursive call (with the appropriate comments):
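A minimal sketch of how such a recursive function could look, not a definitive implementation. It assumes all points are distinct `(x, y)` tuples (a duplicate would mean the answer is 0 and can be checked up front), splits the y-sorted list with a set membership test as the book's pseudocode does, and returns a `(point, point, distance)` triple, which is an assumed return shape. The helpers `dist`, `brute` and `closest_split_pair` are compact stand-ins so that the block runs on its own.

```python
import math


def dist(p, q):
    return math.hypot(p[0] - q[0], p[1] - q[1])


def brute(ax):
    # Compact stand-in: try every pair, used only when len(ax) <= 3.
    pairs = [(ax[i], ax[j]) for i in range(len(ax)) for j in range(i + 1, len(ax))]
    p, q = min(pairs, key=lambda pq: dist(*pq))
    return p, q, dist(p, q)


def closest_split_pair(ax, ay, delta, best_pair):
    # Compact stand-in: scan the strip of half-width delta around the split line.
    mid_x = ax[len(ax) // 2][0]
    s_y = [p for p in ay if abs(p[0] - mid_x) <= delta]
    best = delta
    for i in range(len(s_y) - 1):
        for j in range(i + 1, min(i + 8, len(s_y))):  # 7 following points
            d = dist(s_y[i], s_y[j])
            if d < best:
                best_pair, best = (s_y[i], s_y[j]), d
    return best_pair[0], best_pair[1], best


def closest_pair(ax, ay):
    """ax: points sorted by x; ay: the same points sorted by y (all distinct)."""
    n = len(ax)
    if n <= 3:
        return brute(ax)                  # small input: compare directly
    mid = n // 2
    left_x, right_x = ax[:mid], ax[mid:]  # split the x-sorted list in half
    in_left = set(left_x)                 # membership test for the y split
    left_y = [p for p in ay if p in in_left]       # stays sorted by y
    right_y = [p for p in ay if p not in in_left]  # stays sorted by y
    p1, q1, d1 = closest_pair(left_x, left_y)      # best pair on the left
    p2, q2, d2 = closest_pair(right_x, right_y)    # best pair on the right
    if d1 <= d2:
        best_pair, delta = (p1, q1), d1
    else:
        best_pair, delta = (p2, q2), d2
    # A closer pair may straddle the split line; check the strip around it.
    return closest_split_pair(ax, ay, delta, best_pair)


points = [(0, 0), (5, 4), (3, 1), (9, 6), (2, 7), (4, 2), (8, 1)]
ax = sorted(points)
ay = sorted(points, key=lambda p: p[1])
print(closest_pair(ax, ay))
```

Because every input of four or more points is split into halves of at least two points each, the brute-force base case never receives fewer than two points.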

The implementation above follows the book. However, while debugging the algorithm, I found a peculiar feature. If we were to replace the midpoint split logic with:

the code would actually run a little bit faster. I won’t dive into the low-level details, though a curious reader should compare the speeds of comparison


for a set(). That’s the only reason I can think of.
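To try such a comparison yourself, here is a minimal timing sketch. Both split variants are assumptions made for illustration, and the names `split_by_membership` and `split_by_midpoint` are hypothetical: the first tests each point for membership in a set built from the left half, the second compares the x coordinate against the split value. The second is only equivalent to the first when no two points share the x coordinate of the split point, which the sketch guarantees by generating distinct x values.

```python
import random
import timeit


def split_by_membership(ax, ay, mid):
    in_left = set(ax[:mid])                    # hash every left-half point once
    left_y = [p for p in ay if p in in_left]
    right_y = [p for p in ay if p not in in_left]
    return left_y, right_y


def split_by_midpoint(ax, ay, mid):
    mid_x = ax[mid][0]                         # x of the first right-half point
    left_y = [p for p in ay if p[0] < mid_x]
    right_y = [p for p in ay if p[0] >= mid_x]
    return left_y, right_y


random.seed(0)
n = 20_000
xs = random.sample(range(10 * n), n)           # distinct x values
points = [(x, random.randrange(10 * n)) for x in xs]
ax = sorted(points)
ay = sorted(points, key=lambda p: p[1])
mid = n // 2

# Both variants must produce the same halves before timing means anything.
assert split_by_membership(ax, ay, mid) == split_by_midpoint(ax, ay, mid)

for fn in (split_by_membership, split_by_midpoint):
    seconds = timeit.timeit(lambda: fn(ax, ay, mid), number=20)
    print(f"{fn.__name__}: {seconds:.3f} s for 20 runs")
```

Timings depend on the machine and the Python version, so treat the printed numbers as a local measurement rather than a general result.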

Now, that’s where it gets interesting. First, the brute(ax) function:

Let us discuss that in brief. Why is mi initialized to the distance between the first two points from the list? Why not some arbitrary large number? Well, it saves us a computation on each of the many calls to the brute function. That’s a win. Furthermore, if len(ax) == 2, we’re done and the result can be returned.

The second important point concerns the ranges of our two loops, which are needed in the case of 3 points (recall that brute is called only if len(ax) ≤ 3). Why don’t we need to iterate over all len(ax) points with the i index? Because we are comparing two points, ax[i] and ax[j], and j ranges from i+1 to len(ax), so every point still takes part in a comparison. Furthermore, the bounds on the j index mean that we never compare a pair twice and never compare a point with itself: having both dist(ax[0], ax[2]) and dist(ax[2], ax[0]), or dist(ax[1], ax[1]), is ruled out by the boundaries. This cuts the number of distance computations at least in half compared with simply running two loops over len(ax).

Back to our first point: the if condition inside the loops skips the pair that was already measured during initialization, which saves us one extra distance computation.
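The points above can be put together in a minimal sketch of a brute-force helper. This is an illustration under assumptions rather than the definitive listing: it assumes `ax` holds 2 or 3 `(x, y)` tuples and returns a `(point, point, distance)` triple.

```python
import math


def dist(p, q):
    return math.hypot(p[0] - q[0], p[1] - q[1])


def brute(ax):
    """Closest pair among 2 or 3 points by direct comparison."""
    p1, p2 = ax[0], ax[1]
    mi = dist(p1, p2)             # start from a real distance, not a sentinel
    ln_ax = len(ax)
    if ln_ax == 2:
        return p1, p2, mi         # only one pair exists
    for i in range(ln_ax - 1):
        for j in range(i + 1, ln_ax):
            if (i, j) == (0, 1):
                continue          # this pair was measured above
            d = dist(ax[i], ax[j])
            if d < mi:
                p1, p2, mi = ax[i], ax[j], d
    return p1, p2, mi


print(brute([(0, 0), (5, 5), (1, 1)]))
```

Note that the skip condition has to match the single pair (0, 1) exactly; a looser test such as "i is not 0 and j is not 1" would also skip the pair (0, 2).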

The distance function (dist) is nothing special:
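A minimal sketch of such a helper, assuming points are `(x, y)` tuples. It uses `math.hypot` from the standard library, which returns the Euclidean norm of the coordinate differences.

```python
import math


def dist(p1, p2):
    """Euclidean distance between two (x, y) points."""
    return math.hypot(p1[0] - p2[0], p1[1] - p2[1])


print(dist((0, 0), (3, 4)))  # 5.0
```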

Finally, one of the most interesting pieces: the function responsible for finding the closest pair of points across the split line, closest_split_pair:

Again, the key lies in the ranges of the two loops. They are produced using ideas similar to the ones used in the brute function, with one important distinction. The upper boundary on the j index is min(i+7, ln_y), for reasons discussed in the Correctness part of Cormen et al. In short: it is enough to check only the seven points following each point in the s_y subarray. You should really read through the proof of correctness, because it explains this ‘trick’, which allows for a great increase in running speed, far better than I can.
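The strip scan described above can be sketched as follows. This is an illustration under assumptions: the parameter names `p_x`, `p_y`, `delta` and `best_pair` are hypothetical, and the function returns a `(point, point, distance)` triple. One detail worth noting is that Python's `range` excludes its stop value, so a stop of `min(i + 7, ln_y)` visits six following points; the sketch uses `i + 8` so that seven are examined, matching the bound proved in the book.

```python
import math


def dist(p, q):
    return math.hypot(p[0] - q[0], p[1] - q[1])


def closest_split_pair(p_x, p_y, delta, best_pair):
    """Look for a pair closer than delta that straddles the split line."""
    mid_x = p_x[len(p_x) // 2][0]         # x coordinate of the split line
    # Keep only points within delta of the split line, still sorted by y.
    s_y = [p for p in p_y if mid_x - delta <= p[0] <= mid_x + delta]
    best = delta
    ln_y = len(s_y)
    for i in range(ln_y - 1):
        # range() excludes its stop value, so i + 8 covers 7 following points.
        for j in range(i + 1, min(i + 8, ln_y)):
            d = dist(s_y[i], s_y[j])
            if d < best:
                best_pair, best = (s_y[i], s_y[j]), d
    return best_pair[0], best_pair[1], best


# Left half: (0,0),(2,0) at distance 2; right half: (3,0),(6,0) at distance 3.
p_x = [(0, 0), (2, 0), (3, 0), (6, 0)]
p_y = sorted(p_x, key=lambda p: p[1])
print(closest_split_pair(p_x, p_y, 2.0, ((0, 0), (2, 0))))
```

If no straddling pair beats `delta`, the function hands back the pair it was given, so the caller does not need a separate comparison.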

P.S.: tips on debugging and testing

Unit tests are mandatory. I recommend the PyCharm IDE (Ctrl + Shift + T creates a unit test for a method). Additional reading on stress testing is also advised.

I used the following code to create a great test case:
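A minimal sketch of a random test-case generator in that spirit. The function name `make_test_case`, the default size and the coordinate range are all assumptions chosen for illustration; the optional seed is there so that a failing case can be reproduced.

```python
import random


def make_test_case(length=10_000, low=-10**9, high=10**9, seed=None):
    """Return two lists: random x coordinates and random y coordinates."""
    rng = random.Random(seed)     # a seed makes a failing case reproducible
    x = [rng.randint(low, high) for _ in range(length)]
    y = [rng.randint(low, high) for _ in range(length)]
    return x, y


x, y = make_test_case(length=5, seed=42)
print(list(zip(x, y)))
```

For stress testing, the output of the fast algorithm on many small random cases can be compared against a plain all-pairs check, which is easy to trust.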

It initially took about 40 seconds to run on my Intel i3 (2 cores, 4 threads, ~2.3 GHz) with 8 GB RAM and an SSD (~450 MB/s read/write); this dropped to about 20–30 seconds after the optimizations I mentioned.

Another great tool for debugging was my friend’s library of convenient timers (which I posted to my GitHub after some changes):

It helps to time functions using convenient wrappers, and examples are built into the code.
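To illustrate the idea of a timing wrapper, here is a generic sketch built only on the standard library. The decorator name `timed` and the `timings` dictionary are hypothetical and are not that library's API.

```python
import functools
import time
from collections import defaultdict

timings = defaultdict(list)  # function name -> list of runtimes in seconds


def timed(func):
    """Record the wall-clock runtime of every call to func."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        start = time.perf_counter()
        try:
            return func(*args, **kwargs)
        finally:
            # Runs even if func raises, so failed calls are timed too.
            timings[func.__name__].append(time.perf_counter() - start)
    return wrapper


@timed
def slow_sum(n):
    return sum(range(n))


slow_sum(100_000)
slow_sum(200_000)
print({name: len(runs) for name, runs in timings.items()})
```

When a recursive function is decorated this way, every nested call is recorded as well, so the outer call's time includes the inner ones.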

I wrapped the functions described above, ran the test case, and collected the printed runtimes into a JSON file. Later I loaded the results into an SQLite database and used aggregate functions to get the average runtime of each function. I repeated the same procedure after adding the optimizations and compared the percentage change between the average runtimes to see whether an optimization improved the runtime of a specific function (the overall runtime can be compared simply by running the unit test example above). I designed this procedure to understand the results in depth; it is not necessary for general debugging.
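The aggregation step could look roughly like the sketch below. The table layout, the field names and the timing values are all made up for illustration and are not measurements; an in-memory database stands in for a real file.

```python
import json
import sqlite3

# Made-up timing records for illustration only (not measurements).
records = json.loads("""[
  {"func": "brute", "phase": "before", "seconds": 0.40},
  {"func": "brute", "phase": "before", "seconds": 0.60},
  {"func": "brute", "phase": "after",  "seconds": 0.25},
  {"func": "closest_split_pair", "phase": "before", "seconds": 2.0},
  {"func": "closest_split_pair", "phase": "after",  "seconds": 1.5}
]""")

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE timings (func TEXT, phase TEXT, seconds REAL)")
conn.executemany(
    "INSERT INTO timings VALUES (:func, :phase, :seconds)", records
)

# AVG ignores NULLs, so each CASE averages only the rows of one phase.
query = """
    SELECT func,
           AVG(CASE WHEN phase = 'before' THEN seconds END) AS avg_before,
           AVG(CASE WHEN phase = 'after'  THEN seconds END) AS avg_after
    FROM timings
    GROUP BY func
    ORDER BY func
"""
results = []
for func, before, after in conn.execute(query):
    change = (after - before) / before * 100   # negative means faster
    results.append((func, before, after, change))
    print(f"{func}: {before:.3f}s -> {after:.3f}s ({change:+.1f}%)")
conn.close()
```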

Good luck and contact me for extra details on the algorithm or for other suggestions:

P.S.: this story is a part of my series on algorithmic challenges. Check out other cool algorithms decomposed with tests and Jupyter notebooks!

Lead Data Scientist @ T-shaped Crew