Kth Closest Points to Origin

Algorithm & Data Structures Simplified

Kevin
Aug 8, 2020

In this article, I will explain one of the problems you may run into when tackling data structures and algorithms questions. You will need some basic knowledge of data structures to understand the optimized solution. The code in this article is written in Python (note that Python is zero-indexed)!

Difficulty: ❤️️️❤️️💛
Ingredient: Priority Queue (or Heap)

At some point, you may have come across a data structures and algorithms question that goes like this:

Given an unordered (unsorted) array of coordinates and a value k, find the kth nearest point to the origin. The coordinates can be in 1D, 2D or 3D space.

For instance, if you have an array of 2D coordinates,

[ (1,2), (1,0), (9,8), (6,8), (3,3) ]

and also given the value of k,

k = 3

You are supposed to find the 3rd set of coordinates closest to the origin (0, 0). Let’s approach this step by step.
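To make the target concrete, here is a quick brute-force check of the example above. It ranks points by their squared distance x**2 + y**2 (no square root is needed just to compare them); `points` and `k` simply mirror the values given:

```python
points = [(1, 2), (1, 0), (9, 8), (6, 8), (3, 3)]
k = 3

# Pair each point with its squared distance to the origin (0, 0).
squared = [(x * x + y * y, (x, y)) for x, y in points]
# squared == [(5, (1, 2)), (1, (1, 0)), (145, (9, 8)), (100, (6, 8)), (18, (3, 3))]

ranked = sorted(squared)        # closest point first
kth_point = ranked[k - 1][1]    # k is 1-based, Python lists are 0-based
print(kth_point)                # (3, 3)
```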

Brute Force

A good first question to ask yourself: instead of the kth element, how do I get the element closest to the origin (k = 1)? And to simplify further, what if I am given 1D coordinates instead of 2D or 3D?

For instance, given the following array

[ 2, 3, 1, 5, 7, 6]

How do I get the value closest to the origin 0 (in layman's terms, the smallest value) in the 1D case? There are 2 distinct ways of doing so:

  1. Sort the array from smallest to largest value, and take the first value, or
  2. Go through every element in the array and record the smallest value you have seen so far. For a general k, this becomes remembering the k elements closest to the origin and replacing one of them whenever a closer element comes along.
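For k = 1, the second approach needs no extra data structure at all, just one pass. A minimal sketch; the function name `closest_to_origin_1d` is hypothetical, and it uses `abs()` so that negative values are ranked by their distance to 0 (for the non-negative example above, that is the same as picking the smallest value):

```python
def closest_to_origin_1d(array: list) -> int:
    """Return the value closest to 0 in a non-empty 1D array (k = 1) in one pass."""
    if not array:
        raise ValueError("array must not be empty")
    best = array[0]
    for num in array[1:]:
        # abs() is the distance to the origin, so negative values are handled too.
        if abs(num) < abs(best):
            best = num
    return best

print(closest_to_origin_1d([2, 3, 1, 5, 7, 6]))  # 1
```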

Both solutions work! But there are notable differences in their runtime and space complexity (see Big O Notation).

Brute Force — Method 1: Sorting

The first method is very straightforward. You sort the array,

[ 1, 2, 3, 5, 6, 7]

And to get the smallest element (k = 1), just take the element at index 0. What about the second (k = 2) element? It will be the element at index 1.

The code (written as a function) will look something like this:

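def kthClosestPoint(k: int, array: list):
    # k is 1-based: k = 1 asks for the closest (smallest) value
    if k < 1 or k > len(array):
        raise ValueError('Invalid k')
    # sorted() returns a new list and leaves the input untouched
    return sorted(array)[k - 1]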

Depending on the sorting algorithm, you will typically have a runtime complexity of O(n log n). The code above uses sorted(), which builds a new sorted list under the hood, so its space complexity is O(n). Sorting in place can reduce the extra space: an algorithm such as heapsort needs only O(1) extra space, although Python's in-place list.sort() (Timsort) may still allocate up to O(n) temporary memory.
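If the caller does not need the original ordering, the sort can be done in place with `list.sort()` instead of `sorted()`. A sketch under that assumption (`kth_closest_in_place` is a hypothetical name); the trade-off is that the input list gets reordered:

```python
def kth_closest_in_place(k: int, array: list):
    """Sort `array` in place (mutating the caller's list) and return the kth smallest."""
    if k < 1 or k > len(array):
        raise ValueError("Invalid k")
    array.sort()          # in place: reorders `array` and returns None
    return array[k - 1]   # k is 1-based

data = [2, 3, 1, 5, 7, 6]
print(kth_closest_in_place(2, data))  # 2
print(data)                           # [1, 2, 3, 5, 6, 7], the input was reordered
```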

But can this method be improved further in terms of runtime complexity? Not really: any comparison-based sort needs O(n log n) comparisons in the worst case.

Brute Force — Method 2: Remember k number of elements

Now, instead of sorting, what if you just keep track of the k elements closest to the origin?

Back to the same 1D example and given k = 1,

[ 2, 3, 1, 5, 7, 6]

You pick up every element in the array one by one and remember the smallest you have seen so far! Similarly, for k = 2, you remember only the 2 smallest values you have seen.
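Before reaching for a heap, the "remember the k smallest" idea can be sketched with an ordinary sorted list of at most k values, maintained with the standard `bisect` module (`k_smallest_by_memory` is a hypothetical helper). Each insertion may shift up to k elements, so this costs roughly O(n·k), which is exactly what a heap improves on:

```python
import bisect

def k_smallest_by_memory(k: int, array: list) -> list:
    """Keep only the k smallest values seen so far, stored in ascending order."""
    if k < 1:
        raise ValueError("Invalid k")
    kept = []
    for num in array:
        bisect.insort(kept, num)  # insert while keeping `kept` sorted
        if len(kept) > k:
            kept.pop()            # drop the largest, which is now at the end
    return kept

print(k_smallest_by_memory(2, [2, 3, 1, 5, 7, 6]))  # [1, 2]
```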

Now, if you are familiar with priority queues or heap queues (I will be using Python's heapq), you will realize that you can use this data structure to obtain the k smallest elements.

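import heapq

def kthClosestPoint(k: int, array: list):
    if k < 1:
        raise ValueError('Invalid k')
    # Convert array into a heap (in place, so the caller's list is reordered)
    heapq.heapify(array)
    # Returns the k smallest values in ascending order
    return heapq.nsmallest(k, array)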

If your array (a.k.a. heap queue) has length n, this method has a worst-case runtime complexity of O(n log n): heapify takes O(n), and each push or pop on a heap of n elements takes O(log n). The space complexity is O(n) if you copy the array first; in this example code heapify works in place, so building the heap needs only O(1) extra space (nsmallest itself still builds and returns a list of k elements).
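`heapq.nsmallest` already returns its result in ascending order, so the kth smallest value (the one the original question asks for) is the last element of that list. A small illustration on the 1D example; note that `nsmallest` accepts any iterable, so the `heapify` step is optional here:

```python
import heapq

array = [2, 3, 1, 5, 7, 6]
k = 3

# nsmallest returns the k smallest values in ascending order,
# so the kth smallest is simply the last one.
k_smallest = heapq.nsmallest(k, array)
print(k_smallest)      # [1, 2, 3]
print(k_smallest[-1])  # 3

# Without heapify, the caller's list keeps its original order.
print(array)           # [2, 3, 1, 5, 7, 6]
```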

Optimization

You can actually further improve the runtime complexity of this method by limiting the heap queue to k instead of the whole array length n:

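import heapq

def kthClosestPoint(k: int, array: list):
    if k < 1:
        raise ValueError('Invalid k')
    k_elements = []
    for num in array:
        # Negate so the min-heap root is the largest kept value
        heapq.heappush(k_elements, -num)
        if len(k_elements) > k:
            # Evict the largest value, keeping the k smallest
            heapq.heappop(k_elements)
    return [-num for num in k_elements]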

Note that heappop only ever removes the smallest element. To make it remove the largest instead, negate the elements: positive integers become negative and negative integers become positive. The largest values now look like the smallest, so heappop evicts the current largest value and the heap keeps the k smallest.
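A quick illustration of the negation trick on the 1D example: heapq only provides a min-heap, so storing -num puts the largest original value at the root:

```python
import heapq

values = [2, 3, 1, 5, 7, 6]

# Store negated values so heapq's min-heap behaves like a max-heap.
max_heap = []
for num in values:
    heapq.heappush(max_heap, -num)

print(-max_heap[0])              # 7, the largest value sits at the root
print(-heapq.heappop(max_heap))  # 7, popping removes the largest first
print(-max_heap[0])              # 6
```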

The runtime complexity will be O(n log k), since you heappush (and possibly heappop) every single element of the array while the heap queue length stays at most k. This is also the worst case: the cost is the same no matter what the values are.

Further Optimization

Can we improve this further for the typical case? Instead of pushing every element into the heap queue and then removing it, can we check first? Yes, we can!

If we already have a heap queue of size k, we should peek at the “largest” element in the heap queue (stored negated at the root, k_elements[0]) and only push the current element if it is smaller than that. While the heap queue is still shorter than k, we simply keep pushing elements into it!

import heapq

def kthClosestPoint(k: int, array: list):
    if k < 1:
        raise ValueError('Invalid k')
    k_elements = []
    for num in array:
        # Push only if the heap is not full yet, or num beats the largest kept value
        if len(k_elements) < k or k_elements[0] < -num:
            heapq.heappush(k_elements, -num)
            if len(k_elements) > k:
                heapq.heappop(k_elements)
    return [-num for num in k_elements]
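Because the heap holds the k smallest values with the largest of them (negated) at the root, the kth closest value itself can be read straight off the root. A sketch of that idea; it also swaps in `heapq.heapreplace`, which pops the root and pushes the new item in one step, in place of a separate push and pop, and the name `kth_smallest_value` is hypothetical:

```python
import heapq

def kth_smallest_value(k: int, array: list):
    """Return the kth smallest value using a size-k max-heap of negated values."""
    if k < 1 or k > len(array):
        raise ValueError("Invalid k")
    k_elements = []
    for num in array:
        if len(k_elements) < k:
            heapq.heappush(k_elements, -num)
        elif k_elements[0] < -num:
            # num is smaller than the largest kept value: swap them in one step.
            heapq.heapreplace(k_elements, -num)
    # The root is the negated largest of the k smallest, i.e. the kth smallest.
    return -k_elements[0]

print(kth_smallest_value(1, [2, 3, 1, 5, 7, 6]))  # 1
print(kth_smallest_value(3, [2, 3, 1, 5, 7, 6]))  # 3
```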

Similarly, if you are dealing with 2D or even 3D data, you can modify this code to accommodate them, using the exact same method.

Solving for 2D Data

Assume you have data points in an array like this:

[ (1, 2), (3, 5), (6, 7)]

The distance from each point to the origin (0, 0) follows from Pythagoras' theorem. Since we only need to compare distances, we can drop the square root and use the squared distance:

distance = x**2 + y**2
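Dropping the square root is safe because it is strictly increasing for non-negative arguments, so it never changes which of two points is closer. For the three points above, the squared distances are 1 + 4 = 5, 9 + 25 = 34 and 36 + 49 = 85, which already gives their order:

```latex
d(x, y) = \sqrt{x^2 + y^2}, \qquad
d(x_1, y_1) < d(x_2, y_2) \iff x_1^2 + y_1^2 < x_2^2 + y_2^2
```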

Nothing beats looking at code, so here is the previous 1D code modified for 2D points:

import heapq

def kthClosestPoint(k: int, array: list):
    if k < 1:
        raise ValueError('Invalid k')
    k_elements = []
    for x, y in array:
        # Squared distance is enough for comparing points
        dist = x**2 + y**2
        if len(k_elements) < k or k_elements[0][0] < -dist:
            heapq.heappush(k_elements, (-dist, x, y))
            if len(k_elements) > k:
                heapq.heappop(k_elements)
    return [[x, y] for _, x, y in k_elements]
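As mentioned earlier, the same approach carries over to 1D and 3D data. One way to cover every dimension at once is to compute the squared distance as a sum over however many coordinates each point has. A sketch, assuming points are given as tuples of numbers (`k_closest_points` is a hypothetical name); unlike the functions above, it also sorts the result from closest to farthest:

```python
import heapq

def k_closest_points(k: int, points: list) -> list:
    """Return the k points (tuples of any dimension) closest to the origin."""
    if k < 1:
        raise ValueError("Invalid k")
    k_elements = []  # entries are (-squared_distance, point)
    for point in points:
        dist = sum(c * c for c in point)  # works for 1D, 2D or 3D tuples alike
        if len(k_elements) < k or k_elements[0][0] < -dist:
            heapq.heappush(k_elements, (-dist, point))
            if len(k_elements) > k:
                heapq.heappop(k_elements)
    # Sorting by -dist in reverse puts the smallest distance first.
    return [point for _, point in sorted(k_elements, reverse=True)]

print(k_closest_points(3, [(1, 2), (1, 0), (9, 8), (6, 8), (3, 3)]))
# [(1, 0), (1, 2), (3, 3)]
```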

If you have any feedback or anything that you wish to share, feel free to drop a comment 👇!


Kevin

Technopreneur, Rocket Scientist, AI and Security Enthusiast, Serverless Advocate, Full Stack Engineer, Native Android & iOS, Pythonista