Machine Learning: K-Nearest Neighbors (Theory Explained)

Ashwin Prasad · Published in Analytics Vidhya · Sep 9, 2020

What is K-Nearest Neighbors?

KNN is a supervised machine learning algorithm that can be used for both regression and classification problems.
The reasons for choosing KNN over other algorithms will be explained later.

How Does K-Nearest Neighbors Work? (Classification Problem)

There is no doubt that machine learning algorithms need data to train on, and KNN is no exception.
To explain how KNN works, let's look at an example:

Let us consider a problem where, given the height and weight of a person, we have to classify whether the person is fit to join the army or not (sorry, couldn't find a better example). Assume that Figure 1.1 represents this data for a lot of people, where red represents fit and blue represents unfit.

So, let's say a new person has to undergo this test to be classified into one of these classes (fit/unfit).

Figure 1.2: the new person plotted as a red star among the training data

Now, let the height of the new person be H and the weight be W, and let the red star in Figure 1.2 represent this new person.
The coordinates of the star are then (H, W).

In KNN, we have to manually select a value for a variable named K; its role will become clear shortly.

In this case, let's try K = 3: we are going to find the 3 closest existing data points in our training set to our red star (the new person).

How to Find the 3 Closest Points to the New Data Point?

1. Find the Euclidean distance between the new data point (H, W) and all the other data points in the training set, and put the distances in an array. The Euclidean distance between two coordinates (x₁, y₁) and (x₂, y₂) is d = √((x₂ − x₁)² + (y₂ − y₁)²).
2. Sort the array in ascending order.
3. Take the top 3 elements of the array (because K = 3). A short code sketch of these steps follows.
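To make these three steps concrete, here is a minimal Python sketch. The helper name k_nearest and the sample (height, weight) data are my own illustrative assumptions, not part of the original post:

```python
import math

def k_nearest(training_points, new_point, k=3):
    """Hypothetical helper: return the k training points closest to new_point."""
    # Step 1: Euclidean distance from the new point (H, W) to every training point
    distances = []
    for point, label in training_points:
        d = math.sqrt((point[0] - new_point[0]) ** 2 +
                      (point[1] - new_point[1]) ** 2)
        distances.append((d, label))
    # Step 2: sort the array in ascending order of distance
    distances.sort(key=lambda pair: pair[0])
    # Step 3: take the top k elements
    return distances[:k]

# Made-up (height, weight) data labelled "fit" / "unfit"
training = [((170, 65), "fit"), ((180, 80), "fit"),
            ((160, 90), "unfit"), ((175, 70), "fit"),
            ((155, 85), "unfit")]
print(k_nearest(training, (172, 68), k=3))
```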

How to Classify the New Person?

Now that we have the 3 data points closest to the coordinates of the red star (the new person), all that's left to do is find the classes to which those 3 data points belong.
If 2 of them belong to the class fit and the other one belongs to the class unfit, then we classify the new person (red star) as fit; otherwise, we classify the new person as unfit.

This is how KNN works: we classify data based on the majority class of its neighboring data points (by Euclidean distance), hence the name K-Nearest Neighbors.
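As a rough sketch, the majority vote can be written like this, reusing the hypothetical k_nearest helper and training data from the sketch above:

```python
from collections import Counter

def classify(neighbors):
    """Majority vote over the labels of the k nearest neighbors."""
    labels = [label for _, label in neighbors]
    # most_common(1) gives the (label, count) pair with the highest count
    return Counter(labels).most_common(1)[0][0]

neighbors = k_nearest(training, (172, 68), k=3)
print(classify(neighbors))  # "fit" if at least 2 of the 3 neighbors are fit
```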

How Does K-Nearest Neighbors Work? (Regression Problem)

All the steps are the same for regression and classification up to getting the K closest elements of the array (shown above).

After finding the K elements closest to the new data point, we take the mean of the dependent variable Y of those K elements, and that value is assigned as the Y value of the new data point.

Note: unlike classification, regression outputs are continuous values, and hence we take the mean of the top K elements of the array and assign that value as the predicted value for our new data.
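Here is a minimal sketch of the regression variant, assuming the same hypothetical k_nearest helper from above and made-up Y values:

```python
def knn_regress(training_points, new_point, k=3):
    """Predict Y for new_point as the mean Y of its k nearest neighbors."""
    neighbors = k_nearest(training_points, new_point, k)
    # Mean of the dependent variable Y over the k closest points
    return sum(y for _, y in neighbors) / len(neighbors)

# Made-up data: X is (height, weight), Y is some continuous target
reg_training = [((170, 65), 12.0), ((180, 80), 15.5),
                ((160, 90), 9.0), ((175, 70), 13.2)]
print(knn_regress(reg_training, (172, 68), k=3))  # mean Y of the 3 neighbors
```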

Why Is the Value of K Important?

K affects the accuracy of our model: the number of nearest neighbors from which our new data point is classified is highly significant.
A very small value for K could lead to overfitting, and a very large value for K could lead to underfitting. So, choosing the right value for K is very important.

How to Choose the K Value?

There is no perfect way to choose the value for K; it depends heavily on the dataset.

But there are still evaluation tools that make things easier: a test set and a loss function.
A loss function gives us a measure of how much our predicted values deviate from the actual values.

For regression problems, we can use mean squared error loss; for classification problems, we could use either binary cross-entropy or categorical cross-entropy loss, depending on the type of classification.

Mean squared error:

MSE = (1/n) Σᵢ (yᵢ − ŷᵢ)²

where yᵢ is the actual value, ŷᵢ is the predicted value, and n is the number of test points.

Explaining these loss functions in detail is beyond the scope of this blog post, but the basic takeaway is that we can try different K values and calculate the loss for each one using the loss function.
We finally select the K value for which the loss is the lowest.
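As a sketch of this procedure for the regression case, assuming the hypothetical knn_regress helper and reg_training data from earlier, plus a made-up test set:

```python
def mse(y_true, y_pred):
    """Mean squared error: average squared deviation of predictions."""
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)

# Made-up held-out test points
test_set = [((168, 66), 11.5), ((178, 78), 14.8)]

best_k, best_loss = None, float("inf")
for k in range(1, len(reg_training) + 1):  # candidate K values
    preds = [knn_regress(reg_training, x, k) for x, _ in test_set]
    loss = mse([y for _, y in test_set], preds)
    if loss < best_loss:
        best_k, best_loss = k, loss
print(best_k, best_loss)  # the K with the lowest test loss
```

In practice, people often use cross-validation rather than a single held-out test set, but the idea is the same: pick the K that minimizes the loss.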

Conclusion

So, KNN is a simple yet efficient algorithm: it depends entirely on the coordinates of its neighboring values, and this is how it works.

Thanks for Reading
