K Nearest Neighbors — Explanation and Implementation

Francesco Di Salvo
Published in Artificialis · 5 min read · Dec 4, 2021
Photo by Michael Tuszynski from Pexels

Introduction

K Nearest Neighbors is a supervised learning algorithm based on the assumption that the points in the neighborhood (i.e. the closest points) tend to belong to the same class.

Therefore, given a positive integer k and a test observation, KNN identifies the k closest training points and then makes an inference based on their targets.

Notice that it can be used for both classification and regression problems. In particular, for a classification problem, it estimates the conditional probability of class j as the fraction of points in the neighborhood whose target is equal to j:

Pr(Y = j | X = x_0) = \frac{1}{k} \sum_{i \in N_0} I(y_i = j)

where N_0 represents the neighborhood of the given observation x_0 and I(\cdot) is the indicator function. The observation is then classified with the class having the highest conditional probability. In short, it takes the majority class in the neighborhood!

On the other hand, in a regression problem, the output is simply the average of the targets of the k closest points.

Image source: https://it.wikipedia.org/wiki/K-nearest_neighbors

How to choose the proper value of k

The choice of k is always critical; in fact, we have to pay attention to two different scenarios:

  • small k : since you only consider points very close to the observation, you will have a low bias, but you may suffer from high variance due to the presence of outliers.
  • large k : if you enlarge the neighborhood you will be more robust to outliers (low variance), but you will end up with a higher bias because you will probably include points that are not so close.

The bias-variance tradeoff is always around the corner!

Ok great, but how can we select it? Unfortunately, there is no absolute answer: we have to try different values.

A common practice is to plot the reference error metric for different values of k. By doing so, you can see how your model behaves and pick a good trade-off.

Image source: https://towardsdatascience.com/how-to-find-the-optimal-value-of-k-in-knn-35d936e554eb
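For instance, a quick way to produce such a plot with scikit-learn (this snippet is my own sketch, not part of the original article's code) could be:

```python
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.model_selection import cross_val_score
from sklearn.neighbors import KNeighborsClassifier

X, y = load_iris(return_X_y=True)

ks = range(1, 31)
errors = []
for k in ks:
    # mean cross-validated accuracy for each candidate k, turned into an error
    acc = cross_val_score(KNeighborsClassifier(n_neighbors=k), X, y, cv=5).mean()
    errors.append(1 - acc)

plt.plot(list(ks), errors, marker="o")
plt.xlabel("k")
plt.ylabel("cross-validated error")
plt.show()
```

The "elbow" of this curve is usually a reasonable choice for k.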

Distance metrics

In order to quantify how close an observation x is to y, we need to define a proper distance metric. The most common ones are listed below, followed by their formulas:

  • manhattan : the distance between two points is the sum of the absolute differences of their coordinates.
  • euclidean : probably the most used one; it is defined as the square root of the sum of the squared differences of the coordinates.
  • chebyshev : the maximum absolute difference across all the coordinates. It is also called the supremum distance.
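Concretely, for two points x = (x_1, \dots, x_n) and y = (y_1, \dots, y_n), these read:

d_{\text{manhattan}}(x, y) = \sum_{i} |x_i - y_i|
d_{\text{euclidean}}(x, y) = \sqrt{\sum_{i} (x_i - y_i)^2}
d_{\text{chebyshev}}(x, y) = \max_{i} |x_i - y_i|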

Implementation

To sum up, the naive algorithm is quite straightforward:

  1. For each test observation, compute the distances between it and all the training observations.
  2. Keep only the k closest points and their targets.
  3. For a classification problem, take the majority class; for a regression problem, take the average of the closest targets.

The complete implementation can be found here. First, I’d like to report the skeleton of the class, with all the methods that we need to implement.

The __init__ method defines the number of neighbors, the chosen metric, and an auxiliary flag that tells whether we are dealing with a regression or a classification problem. We will incrementally enrich the code from here until the end!
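The original gist is not embedded in this version of the post, so here is a minimal sketch of what that skeleton could look like (the class name KNN and the parameter names k, metric and regression are my own choices):

```python
class KNN:
    def __init__(self, k=5, metric="euclidean", regression=False):
        self.k = k                    # number of neighbors
        self.metric = metric          # chosen distance metric
        self.regression = regression  # True for regression, False for classification

    def check_metric(self):
        ...

    def compute_distance(self, x1, x2):
        ...

    def euclidean_distance(self, x1, x2):
        ...

    def fit_predict(self, X_train, y_train, X_test):
        ...
```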

For the sake of brevity, here I have reported only the euclidean distance, but on the GitHub repository you will find all the implementations.

Moreover, since we can choose among different metrics, it is important to check that we are using a “valid” one:
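Again, the embedded snippet is missing here, so this is just one possible way to write such a check (the method name check_metric is my assumption):

```python
# inside the KNN class sketched above
def check_metric(self):
    # fail early if the requested metric is not one we know how to compute
    valid_metrics = ("euclidean", "manhattan", "chebyshev")
    if self.metric not in valid_metrics:
        raise ValueError(f"Unknown metric '{self.metric}', expected one of {valid_metrics}")
```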

Then, we must be able to compute the distance, so let us add two other methods: compute_distance() and euclidean_distance(). The first one calls the proper distance method based on the selected metric, while the second one (and also the ones that I haven’t reported) is in charge of returning the actual distance.
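A possible sketch of these two methods (the exact signatures in the original repository may differ):

```python
import numpy as np

# inside the KNN class sketched above
def compute_distance(self, x1, x2):
    # dispatch to the method matching the selected metric
    # (manhattan_distance and chebyshev_distance are defined analogously, see the repository)
    if self.metric == "euclidean":
        return self.euclidean_distance(x1, x2)
    if self.metric == "manhattan":
        return self.manhattan_distance(x1, x2)
    return self.chebyshev_distance(x1, x2)

def euclidean_distance(self, x1, x2):
    # square root of the sum of the squared coordinate differences
    return np.sqrt(np.sum((np.asarray(x1) - np.asarray(x2)) ** 2))
```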

Here the fun begins! We have defined all the supporting methods, but now we need to write the core fit_predict() method. Recall that in this case we are performing a brute-force approach, but far more efficient solutions exist (e.g. tree-based structures such as KD-trees).

So, for each test point, we compute (and store) the distances to all the training points. Then, we sort them and keep just the k closest ones (i.e. the ones with the lowest distance).

After that, we keep only their targets, and finally we can make our prediction as the mean or the majority class, for a regression or a classification problem, respectively.
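Putting this description into code, a brute-force fit_predict() could look like the sketch below (again, my own reconstruction rather than the author’s exact gist):

```python
import numpy as np
from collections import Counter

# inside the KNN class sketched above
def fit_predict(self, X_train, y_train, X_test):
    self.check_metric()
    predictions = []
    for x in X_test:
        # distances between this test point and every training point
        distances = [self.compute_distance(x, x_train) for x_train in X_train]
        # indices of the k closest training points and their targets
        k_idx = np.argsort(distances)[: self.k]
        k_targets = [y_train[i] for i in k_idx]
        if self.regression:
            predictions.append(np.mean(k_targets))                        # average of the k targets
        else:
            predictions.append(Counter(k_targets).most_common(1)[0][0])   # majority class
    return np.array(predictions)
```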

Finally, for the sake of completeness, I have compared the performance of this algorithm against scikit-learn’s implementation on the Iris dataset, and the performance is the same!
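The comparison itself is not shown in this version of the post; a sketch of how it could be run, assuming the KNN class above has been assembled with the method bodies sketched earlier:

```python
from sklearn.datasets import load_iris
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# scikit-learn baseline
sk_pred = KNeighborsClassifier(n_neighbors=5).fit(X_train, y_train).predict(X_test)

# our brute-force implementation (the KNN class sketched above)
my_pred = KNN(k=5, metric="euclidean").fit_predict(X_train, y_train, X_test)

print("sklearn accuracy:", accuracy_score(y_test, sk_pred))
print("ours    accuracy:", accuracy_score(y_test, my_pred))
```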

Conclusions

This algorithm is quite simple, yet extremely powerful whenever we have a moderate number of features and samples. It has been quite fun and challenging, and I hope you have understood all the steps along the way!

For any comments, doubts or feedback, feel free to reach out to me on LinkedIn; I would be more than happy to answer! :D
