K Nearest Neighbors

Swapna Patil
Machine Learning Algorithms
9 min read · Jul 31, 2018

This blog focuses on how the KNN (K-Nearest Neighbors) algorithm works. We have tried to explain every concept in layman’s terms. You can find the code on the github link. If you are familiar with the concepts, please feel free to skip to the next section.

Let’s start with the basics: what is KNN?

KNN is one of the simplest machine learning algorithms, used for both classification and regression. It makes decisions based on the entire training dataset: no time is spent training the algorithm, only on pre-processing and testing.

The algorithm makes predictions by calculating the similarity between the input sample and each training instance. It does not make strong assumptions about the form of the mapping function, hence it is non-parametric. In simple words, by not making assumptions, the algorithm is free to learn any functional form from the training data.

Understanding KNN with an example:

To give an overview of where KNN can be used, let us consider an example.

The SPEED and AGILITY ratings of 10 college athletes are used to label whether or not each athlete was drafted by a professional team.

Visualizing the above data:

We want to predict the label of a query instance x (with SPEED = 6.75 and AGILITY = 3.00), as indicated in the image below.

We find the distance of x from all the other points present in the dataset.

We take the K samples closest to x, and based on those K samples, the label of x is decided.

Assuming K = 3, we take the 3 samples closest to x. Since the circle contains one ‘draft’ sample and two ‘no_draft’ samples, we assign the label ‘no_draft’ to our test sample x.

[Figure: the K = 3 nearest neighbors of x]

This is exactly what KNN does! It calculates the distance between the new instance and its neighbors and then decides the class of the instance from its K nearest neighbors.
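As a minimal sketch of the idea in plain Python: the ratings below are hypothetical stand-ins (the actual ten values appear in the table image above), so only the mechanics matter here.

```python
import math
from collections import Counter

# Hypothetical (SPEED, AGILITY, label) ratings; the real ten values
# are in the table shown above.
athletes = [
    (5.00, 2.50, 'no_draft'),
    (7.00, 2.00, 'no_draft'),
    (6.75, 5.50, 'draft'),
    (5.25, 9.50, 'draft'),
    (7.50, 8.00, 'draft'),
]
x = (6.75, 3.00)  # the query instance
K = 3

# Distance from x to every labeled sample, smallest first
nearest = sorted((math.dist(x, (s, a)), label) for s, a, label in athletes)

# Majority vote among the K nearest neighbors
votes = Counter(label for _, label in nearest[:K])
print(votes.most_common(1)[0][0])  # -> 'no_draft' (two of the three neighbors)
```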

Diving into the code

First, a few libraries are imported. If you are familiar with these, just skip their introduction.

NumPy is a library in Python which supports multi-dimensional arrays and matrices in addition to mathematical functions.

Pandas is an open-source library in Python for reading, manipulating and analyzing data.

Matplotlib is a plotting library for Python programming language.

Counter (from Python’s collections module) is used to count the occurrences of each unique sample in a list.
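Put together, the imports at the top of the code look something like this:

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from collections import Counter
```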

Dividing the code into steps for better understanding:

  1. Data: Open dataset from CSV file.
  2. Visualize the data.
  3. Distribute the dataset into train set and test set in the ratio 3:1
  4. Calculate the distance function.
  5. Sort the distance list in ascending order.
  6. Pick top K elements of the sorted distance list.
  7. Count the number of samples for every class in the selected K elements.
  8. Assign the class with the most samples among the K elements to the test data.

Step 1. Data: Open dataset from CSV file.

Here, names have been assigned to the columns of the dataset.

The IRIS dataset (a CSV file) has been downloaded and read into the data frame df.

[Figures — fig 1: Iris setosa, fig 2: Iris versicolor, fig 3: Iris virginica]

The iris flower dataset of 3 species (setosa, versicolor, virginica) has a total of 150 rows and 5 columns: the first 4 columns are sepal_length, sepal_width, petal_length and petal_width, and the last column is the label (class).

Since the dataset is in CSV format, the read_csv method of Pandas is used. The filename can be specified directly if the dataset is in the same folder as the Python code; if not, the absolute path should be specified.

A DataFrame (df) is a 2D labeled data structure with columns of potentially different types.

header=None: this tells Pandas that the dataset doesn’t have a header row, so the program starts reading data from the first row instead of mistaking it for column names. Without it, Pandas would treat the first row as the header.

names: a list which contains the column names to assign.
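A sketch of this step (assuming the file is saved as ‘iris.data’ next to the script, as in the UCI distribution):

```python
names = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width', 'class']

# header=None: the file has no header row; names supplies the column labels.
df = pd.read_csv('iris.data', header=None, names=names)
```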

Step 2: Visualize the data.

Visualizing a dataset helps you see how the data is distributed and how to pre-process or normalize it. To get an initial idea of the dataset, we can call df.head(), which returns the first 5 rows (by default).

Plotting the statistics:

In the above snippet, comparing the ‘class’ column against each species name (‘Iris-setosa’, ‘Iris-versicolor’, ‘Iris-virginica’) produces a boolean mask that is True where the class matches; indexing the DataFrame with that mask returns the rows belonging to that species.
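A minimal sketch of that selection (the variable names are our own):

```python
# Boolean masks pick out the rows of each species.
setosa     = df[df['class'] == 'Iris-setosa']
versicolor = df[df['class'] == 'Iris-versicolor']
virginica  = df[df['class'] == 'Iris-virginica']
```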

To get a feel of the dataset:

Plotting the sepal_length against the sepal_width.
Plotting the petal_length against the petal_width.
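A sketch of the sepal plot, reusing the per-species subsets from above (the petal plot follows the same pattern):

```python
# One color per species makes the class separation visible.
for subset, color in [(setosa, 'red'), (versicolor, 'green'), (virginica, 'blue')]:
    plt.scatter(subset['sepal_length'], subset['sepal_width'], color=color)
plt.xlabel('sepal_length')
plt.ylabel('sepal_width')
plt.show()
```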

The function len(df) returns the number of rows in the dataset.

[Figure: distribution of classes]

Step 3: Distribute the dataset into train set and test set in the ratio 3:1

In machine learning, a training set is used to build a model, while a test set is used to validate the model that was built.

For testing purposes, 1/4th of the dataset is kept aside, leaving as much data as possible in the training set, which helps to predict the class of new instances.

np.random.uniform generates 150 random numbers (the length of our dataset) between 0 and 1. Rows whose values are less than 0.75 are assigned True in the ‘is_train’ column of the DataFrame.

Next, we proceed to take the first 4 columns (sepal_length, sepal_width, petal_length, petal_width) as the features and the 5th column (class) as the label of the data:
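A sketch of the split (the variable names are our own):

```python
# ~75% train / ~25% test via a uniform random mask
df['is_train'] = np.random.uniform(0, 1, len(df)) < 0.75
train, test = df[df['is_train']], df[~df['is_train']]

features = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width']
X_train, y_train = train[features].values, train['class'].values
X_test,  y_test  = test[features].values,  test['class'].values
```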

Step 4: Calculate the distance function.

In the code, the formulas for the Euclidean and Manhattan distances are specified. Either one can be used to compute the distance between two samples.
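As a sketch, the two distance functions for feature vectors a and b:

```python
def euclidean_distance(a, b):
    # Square root of the sum of squared differences
    return np.sqrt(np.sum((np.asarray(a) - np.asarray(b)) ** 2))

def manhattan_distance(a, b):
    # Sum of absolute differences
    return np.sum(np.abs(np.asarray(a) - np.asarray(b)))
```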

Step 5: Sort the distance list in ascending order.

In the above snippet, the distance of the test sample to every training sample is calculated. Each distance is stored as a value in the dictionary ‘distances’, with the index of the corresponding training sample as its key; the dictionary is then sorted in ascending order of distance.

distances.items returns (key, value) tuples. Since we need the closest neighbors, sorting must be done by distance, not by index, so a lambda function is passed as the sort key to select the distance (the second element of each tuple). The result is a list sorted_dist of (index, distance) tuples.
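A sketch of this step for a single test sample (say, the first one), continuing from the snippets above:

```python
test_sample = X_test[0]

# Index of training sample -> distance from the test sample
distances = {i: euclidean_distance(test_sample, X_train[i])
             for i in range(len(X_train))}

# Sort the (index, distance) pairs by distance, not by index.
sorted_dist = sorted(distances.items(), key=lambda item: item[1])
```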

Step 6: Pick top K elements of the sorted distance list.

The top K elements are picked. K is generally taken to be an odd number so that majority voting cannot end in a tie.

The first element of each of the top K tuples is taken from the previously sorted list, as it is the index of the training sample. The ‘neighbors’ list then contains the indices of the K nearest training samples.
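Continuing the sketch:

```python
K = 5  # an odd value avoids voting ties

# Keep only the training-sample indices of the K closest tuples.
neighbors = [index for index, _ in sorted_dist[:K]]
```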

Step 7: Count the number of samples for every class in the selected K elements.

A list of classes is created, one for every index in the neighbors list. Then the ‘Counter’ class is used to count the frequency of each class present in the list.
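In code:

```python
# Look up the class label of each neighbor and tally them.
classes = [y_train[i] for i in neighbors]
counts = Counter(classes)  # e.g. Counter({'Iris-setosa': 4, 'Iris-versicolor': 1})
```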

Step 8: Assign the class with the most samples to the test data.

The list ‘list_keys’ contains the names of the classes and ‘list_values’ contains the frequencies of the corresponding classes. All that is left is to predict the class of the test sample: we take the class with the maximum frequency and return it as the result.
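A sketch matching that description:

```python
list_keys = list(counts.keys())      # class names
list_values = list(counts.values())  # their frequencies

# The class with the maximum frequency is the prediction.
prediction = list_keys[list_values.index(max(list_values))]
print(prediction)
```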

Accuracy of KNN:

Accuracy is the ratio of the number of correctly classified instances to the total number of instances. Here, our model is 97.3% accurate when the Euclidean distance is used as the similarity measure.
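A sketch of the accuracy computation, wrapping steps 4–8 into a small predict function (our own helper, not from the original code):

```python
def predict(x, K=5):
    # Steps 4-8: distances, sort, top K, vote, majority class
    dists = sorted((euclidean_distance(x, X_train[i]), i)
                   for i in range(len(X_train)))
    votes = Counter(y_train[i] for _, i in dists[:K])
    return votes.most_common(1)[0][0]

correct = sum(predict(x) == y for x, y in zip(X_test, y_test))
print('Accuracy:', correct / len(y_test))
```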

Pros and Cons of KNN:

Pros:

  1. It is one of the simplest algorithms - easy to grasp and implement as well.
  2. It takes time only for pre-processing and testing, no training time.
  3. New training samples can be added easily.

Cons:

  1. Distance of the new instance is computed with all the training samples every time.
  2. Adding training samples results in increase in time taken to predict the class of the test sample as the distance of the test sample with respect to the added training sample has to be calculated.
  3. Choosing the value of K is a tough choice: a low value of K makes the prediction sensitive to noise in the training data (overfitting), whereas a very high value of K over-smooths the decision and can underfit. Overfitting simply means the model learns the training data too closely but fails to generalize to the test dataset, so training accuracy is almost perfect while test accuracy is low.
  4. If a particular class is very frequent in the training set, it will tend to dominate the majority voting for the new instance.

Applications of KNN:

KNN is a good choice for applications where predictions are requested infrequently but accuracy is important.

Text mining: The KNN algorithm is one of the most popular algorithms for text categorization or text mining.

Agriculture: KNN has been applied to simulate daily precipitation and other weather variables. Other applications of KNN in agriculture include climate forecasting and estimating soil water parameters.

Finance: Stock market forecasting is one of the core financial applications of KNN. It includes uncovering market trends, planning investment strategies, and identifying the best time to purchase stocks and which stocks to purchase.

Medicine: KNN-based prediction models help doctors diagnose heart disease efficiently with fewer attributes, for example predicting whether a patient hospitalized due to a heart attack will have a second heart attack, based on demographic, diet and clinical measurements for that patient.

Well, this concludes KNN for now! :)


Footnotes:

Co-author: Abhash Sinha

This blog accompanies the code on our github. If you have any questions or suggestions, please feel free to reach out to us. We will cover more machine learning algorithms soon.
