Basics of K-Nearest Neighbour Classifier: Illustrations with Python

Topics in Machine Learning

Harikrishnan N B
Analytics Vidhya
6 min read · Dec 5, 2019

Introduction

In machine learning (ML), a key point is that no single algorithm works well for all datasets. Every algorithm makes some assumptions, and knowing them tells us when the algorithm is appropriate. Today we will try to understand the assumptions and principles of an ML algorithm called K-Nearest Neighbours (KNN).

Intuition

Before exploring the technical details, we will try to understand the intuition behind KNN. We shall do this by playing a game called Identify the Profession. The goal and rule of this game are as follows:

Goal: Identify the profession of two individuals A and B.

Rule: You only know the professions of the five people with whom A and B each spend the most time.

Let us come up with an algorithm to identify their professions.

Figure 1: Identify the profession of A and B.

We shall use a saying attributed to Jim Rohn to solve this problem: you are the average of the five people you spend the most time with.

Figure 2: Source: https://candicelatham.com/friends-inspire-me/

In this model, we classify the professions of A and B based on the professions of the five friends with whom each spends the most time. The majority of A’s top five friends are scientists. Hence, we classify A as a scientist. Similarly, the majority of B’s top five friends are police staff, so we classify B as police staff. We can call this classifier the Nearest Friends Classifier.

Assumptions

Remember that every model makes some assumptions and no model is perfect. What assumption does the Nearest Friends Classifier make? If we treat each person as a data instance and the profession as its label, then we are assuming that similar data instances have similar labels. Based on this assumption we classified A as a scientist and B as police staff.

Why did we choose this model?

One reason we chose this model is what we observed in the dataset (i.e., the information about the five friends of A and B). We could see a pattern in this dataset: a similarity among the professions of A’s friends and among those of B’s friends. So the choice of model is driven by the data and is not arbitrary. In both cases, we can classify the profession with high confidence. Now consider a third person, C, who is a forensic scientist (a forensic scientist spends time with both scientists and police staff). What will happen if you play the game again?

Can we classify C correctly? The answer depends on the professions of the five people with whom C spends the most time. If three of the five are scientists and two are police staff, we classify C as a scientist. On the other hand, if three are police staff and two are scientists, we classify C as police staff. Note that for A and B the confidence in the classification is high, whereas for C it is lower.
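The game can be sketched in a few lines of Python. This is only an illustration: the function name and the friend lists are hypothetical (the text fixes only C's three-versus-two split), and the share of friends holding the winning profession is used here as a rough stand-in for confidence.

```python
from collections import Counter

def nearest_friends_classifier(friend_professions):
    """Return the majority profession and the share of friends holding it."""
    counts = Counter(friend_professions)
    profession, votes = counts.most_common(1)[0]
    return profession, votes / len(friend_professions)

# Hypothetical friend lists; only C's 3-versus-2 split comes from the text.
friends_of_a = ["scientist"] * 5
friends_of_c = ["scientist", "police", "scientist", "police", "scientist"]

print(nearest_friends_classifier(friends_of_a))  # ('scientist', 1.0)
print(nearest_friends_classifier(friends_of_c))  # ('scientist', 0.6)
```

The lower vote share for C mirrors the reduced confidence described above.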

KNN Classifier

KNN also assumes that similar points (data instances) have similar labels. Now we shall go into the details of the algorithm.

Let us consider a binary classification task, classifying circles and crosses.

The training data and labels are drawn from a probability distribution P. We only have some data instances from this distribution; the actual distribution is unknown to us. Let the dataset D be as follows:
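A common way to write such a dataset is given below. The notation is an assumption made for illustration: n is the number of training instances, x^(i) is the i-th feature vector and y^(i) is its label.

```latex
D = \left\{ \left(x^{(1)}, y^{(1)}\right), \left(x^{(2)}, y^{(2)}\right), \ldots, \left(x^{(n)}, y^{(n)}\right) \right\},
\qquad \left(x^{(i)}, y^{(i)}\right) \sim P,
\qquad y^{(i)} \in \{\text{circle}, \text{cross}\}
```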

Algorithm

Goal: Given a test data instance, find the most common label among its k nearest training data instances and assign that label to the test instance.

Step 1: Choose the value of k. k is the number of nearest neighbours. It is a hyperparameter and has to be fixed by cross-validation. For binary classification, k is usually chosen to be odd so that the majority vote cannot end in a tie.

Step 2: Define the distance metric used in KNN

The classification output of KNN relies on the distance metric. A commonly used distance metric is the Minkowski distance, which is defined as follows:
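In its standard form, written with the symbols used in this section (the component index j is the only notation added here), the distance reads:

```latex
\operatorname{dist}(x, z) = \left( \sum_{j=1}^{r} \lvert x_j - z_j \rvert^{p} \right)^{1/p}
```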

x and z are vectors of length r.

  1. When p = 1, dist(x,z) is the Manhattan distance (L1 norm).
  2. When p = 2, dist(x,z) is the Euclidean distance (L2 norm).
  3. p is also a hyperparameter and has to be chosen by cross-validation.
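The two special cases can be checked with a small Python sketch of the standard Minkowski distance. The function name and the example vectors are hypothetical choices made for illustration.

```python
def minkowski_distance(x, z, p=2):
    """Minkowski distance between two equal-length sequences of numbers."""
    if len(x) != len(z):
        raise ValueError("x and z must have the same length")
    return sum(abs(a - b) ** p for a, b in zip(x, z)) ** (1 / p)

# Hypothetical example vectors of length r = 2.
x = [0.0, 0.0]
z = [3.0, 4.0]

print(minkowski_distance(x, z, p=1))  # Manhattan distance: 7.0
print(minkowski_distance(x, z, p=2))  # Euclidean distance: 5.0
```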

Step 3: Consider a single test data instance z. Compute the distance of z from every training instance and find its k nearest training instances.

Sx is a subset of the training set D consisting of the k training data instances nearest to the test data instance z. Every data instance in D that is not in Sx is at least as far from z as the farthest data instance in Sx. This is because Sx consists of the k nearest neighbours of z: the (k+1)-th neighbour is at least as far away as the k-th neighbour [1].

Step 4: Classify the test data instance. From the set Sx, choose the most frequently occurring label as the label for the test data instance z.
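Steps 2 to 4 can be put together in a short sketch. This is a minimal brute-force version written for illustration, not the listing from the original post; the function names and the toy data are hypothetical.

```python
from collections import Counter

def minkowski_distance(x, z, p=2):
    return sum(abs(a - b) ** p for a, b in zip(x, z)) ** (1 / p)

def knn_predict(train_points, train_labels, z, k=3, p=2):
    """Classify test instance z by majority vote among its k nearest training points."""
    # Distance from z to every training instance, paired with that instance's label.
    distances = [
        (minkowski_distance(x, z, p), label)
        for x, label in zip(train_points, train_labels)
    ]
    # Sx: the k training instances closest to z.
    nearest = sorted(distances, key=lambda pair: pair[0])[:k]
    # Most frequent label among the neighbours.
    votes = Counter(label for _, label in nearest)
    return votes.most_common(1)[0][0]

# Hypothetical toy data: two small clusters.
train_points = [(1.0, 1.0), (1.5, 2.0), (2.0, 1.0), (6.0, 6.0), (6.5, 7.0), (7.0, 6.0)]
train_labels = ["circle", "circle", "circle", "cross", "cross", "cross"]

print(knn_predict(train_points, train_labels, (1.8, 1.4), k=3))  # circle
print(knn_predict(train_points, train_labels, (6.2, 6.4), k=3))  # cross
```

Sorting all distances costs O(n log n) per test instance, which is fine for a toy dataset. If two labels receive the same number of votes, this sketch returns the one that appears first among the sorted neighbours.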

Toy Example

Dataset Generation

We shall create a dataset of circles and crosses. We will use Python to generate and plot the dataset.
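One possible way to generate such a dataset with the standard library is sketched below. The cluster centres, the spread, the number of points and the function name are all assumptions made for illustration; they are not the values behind Figure 3.

```python
import random

def make_dataset(n_per_class=50, seed=0):
    """Return (points, labels) for two Gaussian blobs in features f1 and f2."""
    rng = random.Random(seed)
    points, labels = [], []
    # Assumed cluster centres and unit spread, chosen only for illustration.
    centres = {"circle": (2.0, 2.0), "cross": (4.0, 4.0)}
    for label, (c1, c2) in centres.items():
        for _ in range(n_per_class):
            f1 = rng.gauss(c1, 1.0)
            f2 = rng.gauss(c2, 1.0)
            points.append((f1, f2))
            labels.append(label)
    return points, labels

points, labels = make_dataset()
print(len(points), labels.count("circle"), labels.count("cross"))  # 100 50 50
```

The points can then be drawn with any plotting library, for example as a scatter plot that uses a different marker for each label.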

Figure 3: Dataset with two features f1 and f2.

Three-fold cross-validation and hyperparameter tuning

The dataset is randomly divided into training and test sets. The training data is further split into training and validation sets. In three-fold cross-validation the training data is divided into three parts, and each part serves once as the validation set while the other two are used for training. We use this procedure to find the best value of k.

It is important to note that the test data should be used only once. We should never fix hyperparameters by repeatedly evaluating on the test data; hyperparameters must be chosen based on performance on the validation set. Once the hyperparameters are fixed, retrain the model on the complete training set and apply it to the test set. This procedure is the same for all machine learning algorithms.

Figure 4: Cross-validation

Splitting dataset into training and testing
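A minimal sketch of a random split, assuming a test fraction of 20% (the fraction, the function name and the toy data are assumptions, not values from the original post):

```python
import random

def split_train_test(points, labels, test_fraction=0.2, seed=0):
    """Shuffle the indices once and cut them into a training and a test part."""
    indices = list(range(len(points)))
    random.Random(seed).shuffle(indices)
    n_test = int(round(test_fraction * len(points)))
    test_idx, train_idx = indices[:n_test], indices[n_test:]
    train = ([points[i] for i in train_idx], [labels[i] for i in train_idx])
    test = ([points[i] for i in test_idx], [labels[i] for i in test_idx])
    return train, test

# Hypothetical data: ten labelled points.
points = [(float(i), float(i)) for i in range(10)]
labels = ["circle"] * 5 + ["cross"] * 5

(train_x, train_y), (test_x, test_y) = split_train_test(points, labels)
print(len(train_x), len(test_x))  # 8 2
```

The test part is set aside at this stage and is not touched again until the final evaluation.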

KNN main function

Three-fold Cross-validation for fixing k

In this example, we use cross-validation only to find the best value of k. We have fixed the value of p at 2. Ideally, the value of p should also be chosen by cross-validation.
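The tuning loop can be sketched as follows. This is an illustration under stated assumptions: the training data is a hypothetical pair of Gaussian blobs, the candidate k values are arbitrary, and the F1-score is computed with "cross" treated as the positive class, which the original post does not specify.

```python
import random
from collections import Counter

def knn_predict(train_x, train_y, z, k, p=2):
    def dist(x):
        return sum(abs(a - b) ** p for a, b in zip(x, z)) ** (1 / p)
    nearest = sorted(zip(train_x, train_y), key=lambda pair: dist(pair[0]))[:k]
    return Counter(label for _, label in nearest).most_common(1)[0][0]

def f1_score(y_true, y_pred, positive):
    tp = sum(t == positive and q == positive for t, q in zip(y_true, y_pred))
    fp = sum(t != positive and q == positive for t, q in zip(y_true, y_pred))
    fn = sum(t == positive and q != positive for t, q in zip(y_true, y_pred))
    return 2 * tp / (2 * tp + fp + fn) if tp else 0.0

def cross_validate_k(train_x, train_y, k_values, n_folds=3, seed=0):
    """Return {k: mean validation F1-score over n_folds folds}."""
    indices = list(range(len(train_x)))
    random.Random(seed).shuffle(indices)
    # Each fold takes every n_folds-th shuffled index, so the folds partition the data.
    folds = [indices[i::n_folds] for i in range(n_folds)]
    scores = {}
    for k in k_values:
        fold_scores = []
        for fold in folds:
            held_out = set(fold)
            fit_x = [train_x[i] for i in indices if i not in held_out]
            fit_y = [train_y[i] for i in indices if i not in held_out]
            preds = [knn_predict(fit_x, fit_y, train_x[i], k) for i in fold]
            truth = [train_y[i] for i in fold]
            fold_scores.append(f1_score(truth, preds, positive="cross"))
        scores[k] = sum(fold_scores) / n_folds
    return scores

# Hypothetical training data: two Gaussian blobs of 30 points each.
rng = random.Random(1)
train_x = ([(rng.gauss(2, 1), rng.gauss(2, 1)) for _ in range(30)]
           + [(rng.gauss(4, 1), rng.gauss(4, 1)) for _ in range(30)])
train_y = ["circle"] * 30 + ["cross"] * 30

scores = cross_validate_k(train_x, train_y, k_values=[1, 3, 5, 7, 9, 11])
best_k = max(scores, key=scores.get)
print(scores)
print("best k:", best_k)
```

The same loop could be extended to p by iterating over pairs of (k, p) values instead of k alone.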

Figure 5: Choosing k value for KNN using three-fold cross-validation.

During hyperparameter tuning, k = 11 gave the best F1-score.

Retraining with total training data and applying the model to classify test data

Evaluation Metric
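The F1-score used during tuning is the harmonic mean of precision and recall. A minimal sketch, assuming one class is treated as the positive class (the function name and the example labels are hypothetical):

```python
def precision_recall_f1(y_true, y_pred, positive):
    """Precision, recall and F1-score for the class named `positive`."""
    tp = sum(t == positive and q == positive for t, q in zip(y_true, y_pred))
    fp = sum(t != positive and q == positive for t, q in zip(y_true, y_pred))
    fn = sum(t == positive and q != positive for t, q in zip(y_true, y_pred))
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1

# Hypothetical labels: 2 true positives, 1 false positive, 1 false negative.
y_true = ["cross", "cross", "cross", "circle", "circle", "circle"]
y_pred = ["cross", "cross", "circle", "cross", "circle", "circle"]

precision, recall, f1 = precision_recall_f1(y_true, y_pred, positive="cross")
print(round(precision, 3), round(recall, 3), round(f1, 3))  # 0.667 0.667 0.667
```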

In this tutorial, we coded KNN from scratch. For a faster implementation, we can use scikit-learn [2].
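A short sketch of the scikit-learn route, assuming scikit-learn is installed. The toy data and the choice k = 3 are hypothetical; `n_neighbors` and `p` play the roles of k and p from the text.

```python
from sklearn.neighbors import KNeighborsClassifier

# Hypothetical toy data: two small clusters.
X_train = [[1.0, 1.0], [1.5, 2.0], [2.0, 1.0], [6.0, 6.0], [6.5, 7.0], [7.0, 6.0]]
y_train = ["circle", "circle", "circle", "cross", "cross", "cross"]

# n_neighbors is k; p=2 selects the Euclidean case of the Minkowski metric.
clf = KNeighborsClassifier(n_neighbors=3, p=2)
clf.fit(X_train, y_train)

predictions = clf.predict([[1.8, 1.4], [6.2, 6.4]])
print(predictions)  # ['circle' 'cross']
```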

References

  1. Weinberger, Kilian Q. “CORNELL CS4780 ‘Machine Learning for Intelligent Systems.’” YouTube, www.youtube.com/playlist?list=PLl8OlHZGYOQ7bkVbuRthEsaLr7bONzbXS.
  2. Pedregosa, Fabian, et al. “Scikit-learn: Machine Learning in Python.” Journal of Machine Learning Research 12 (2011): 2825–2830.

Harikrishnan N B

Research Associate, Consciousness Studies Programme, National Institute of Advanced Studies, Bengaluru, India