New Blooms: Classification with k-Nearest Neighbours

Beatrice de Waal
Published in AI Odyssey
Nov 20, 2023
Credits for the picture here.

As winter is coming, it’s common to feel a bit under the weather: it rains, fallen leaves become all mushy (blech!) and early dusk screws with your circadian rhythm even more than midterms could. A dream of spring is still possible though, and this year you came prepared: a month before the first frost you planted a big batch of iris seeds that will bloom next spring. However, life in Milan is expensive and you wanted to save some money: instead of buying those nice little packets of specific species, you chose a big bag of mixed seeds. Now that you’ve been taking care of them for so long, you’re disappointed that you won’t be able to tell the flowers apart. Or will you?

As I was saying, you have kind of developed an affection for your iris plants, and you have a lot of time on your hands to identify their species. So worry not: as long as you are willing to sit down and measure all the lengths and widths of the petals and sepals of your flowers, there is one handy algorithm and a dataset to come to your rescue: you need to learn about Iris and kNN.

The Iris Dataset

In the 1930s, American botanist Edgar Anderson shared your interest in iris flowers: he collected measurements of the length and width of the sepals and petals of three species of iris: Iris setosa, Iris versicolor, and Iris virginica. In case you’re wondering (because I sure did), sepals are the parts of the flower that sit under the corolla and protect the bud before it blooms. In most flowers sepals are green, but not in irises. Assuming your batch contains these three species, it will look something like this:

If you thought you’d be able to tell them apart just by looking at them, I guess this picture won’t be good news for you.

In 1936 the illustrious statistician and polymath Ronald Fisher used the data collected by Anderson and compiled what is today simply known as the “Iris” dataset, which contains 150 observations (50 for each species) and 5 variables: the species, and the length and width of the sepals and petals.
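If you want to check these figures yourself, the dataset ships with base R as the data frame iris (from the datasets package), so a quick look needs no download. A minimal sketch:

```r
# The Iris data frame is bundled with base R (package "datasets")
data(iris)

dim(iris)            # 150 rows (flowers) and 5 columns (variables)
names(iris)          # four measurements, in centimetres, plus the species
table(iris$Species)  # 50 observations for each of the three species
```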

The kNN Algorithm

kNN, short for k-Nearest Neighbours, is a supervised machine learning algorithm that can be used for both classification (when the output is categorical or discrete) and regression (when the output is numeric and continuous). The assumption underlying it is extremely simple: instances with similar characteristics (and thus belonging to the same class) will be found close to each other, while instances with different characteristics won’t be in each other’s proximity. Therefore, once you have trained and tested a kNN model, you can give it new data to classify, and it will label each new point according to the closest data points it already knows. Obviously, your new data must be compatible with what you used to train the model: you can’t train a kNN with iris data and then use it to classify species of ducks.

In your case, you can build a classification model using the k-Nearest Neighbours algorithm, and then feed it the measurements of your flowers to find out which species they belong to. The model is a classification one and not a regression one because the species is a categorical variable.

To start, you’ll need to split the Iris dataset into a training set and a test set, with the former larger than the latter (usually the training set is 60–80% of the data and the test set the remaining 20–40%). You use the training set to train the model, and the test set to see how well it performs. kNN is called a “lazy learner” because there is not much to do apart from loading the training set: each data point already has a class label (setosa, versicolor, or virginica)! This characteristic is also the reason why kNN is considered an instance-based and nonparametric model: each time you want to classify a query point, you are not comparing it to a global model (such as an equation), but rather to the individual instances that surround it locally. A nonparametric model does not have a fixed number of parameters (in kNN, the stored training set itself plays that role, so the model grows with the data); also, it does not require the sample data to satisfy assumptions such as following a particular distribution. A very familiar example of a nonparametric tool is the histogram: before you plot one you don’t know whether your data follow a Gaussian, a Poisson, or some other distribution; rather, the histogram itself helps you discover which parameters (if any are useful and practical) you can use to describe the data.
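A minimal sketch of a simple random 70/30 split in R, assuming the whole dataset is shuffled at once (the seed and the names iris_train, iris_test, train_x and train_y are arbitrary choices for illustration; the code later in this article splits each species separately instead):

```r
data(iris)
set.seed(42)  # arbitrary seed, only to make the split reproducible

# Draw 70% of the row indices at random for training; the rest is for testing
n <- nrow(iris)
train_idx <- sample(seq_len(n), size = round(0.7 * n))

iris_train <- iris[train_idx, ]
iris_test  <- iris[-train_idx, ]

# "Training" a lazy learner amounts to keeping these rows around:
# the feature values plus their class labels
train_x <- as.matrix(iris_train[, 1:4])
train_y <- iris_train$Species
```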

Before you can feed it the test sample, however, there are three key decisions you need to make:

First: you need to decide how many data points (how many neighbours) your model must take into consideration to assign a label to the observation to be classified. This number is called k, and it is where the name of the algorithm comes from; searching for the optimal k is called hyperparameter tuning. If you wanted to, you could even choose k = 1: in this case, the observation would simply receive the label of its closest point. However, it is easy to see why this choice could lead to mistakes: imagine your observation is surrounded by points of one class, apart from a single point of another class, which also happens to be the closest. If k > 1, your observation will be classified as the majority of the points surrounding it, but if k = 1 you risk giving it the label of what may well be an outlier.

How to classify the yellow point?
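The scenario above can be reproduced with a tiny hypothetical one-dimensional dataset: four points of class “A” surround the query, but a single “B” point is the closest. This sketch uses base R only (knn_label is a hypothetical helper, not part of any package):

```r
# Hypothetical 1-D toy data: class "A" all around the query, one "B" point very close
train_x <- c(0.10, 0.30, -0.32, 0.40, -0.45)
train_y <- c("B",  "A",   "A",   "A",  "A")
query <- 0

knn_label <- function(k) {
  nearest <- order(abs(train_x - query))[1:k]  # indices of the k closest points
  votes <- table(train_y[nearest])             # count labels among them
  names(votes)[which.max(votes)]               # most frequent label
}

knn_label(1)  # "B": the single closest point decides
knn_label(3)  # "A": two of the three neighbours are "A"
```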

Normally, a good choice for k is between 3 and 15, with k odd (especially in binary classification problems) to avoid ties, i.e. situations in which half of the k closest points belong to class A and exactly the other half belong to class B.

A common procedure is to adopt a trial-and-error approach: you start with a lower k, see how your model performs and then work your way up as performance improves, until eventually it starts declining again. In some cases though it may not be feasible to run the algorithm multiple times (for example, it may be very time consuming). A common choice then is to take as k the odd integer closest to the square root of the number n of your observations.
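The square-root rule of thumb can be written as a small function. This is a sketch; odd_sqrt_k is a hypothetical name, and when the square root lies exactly halfway between two odd integers it picks the smaller one:

```r
# Odd integer closest to sqrt(n): a rule-of-thumb starting value for k
odd_sqrt_k <- function(n) {
  r <- sqrt(n)
  lower <- 2 * floor((r - 1) / 2) + 1  # largest odd integer <= sqrt(n)
  upper <- lower + 2                   # next odd integer above it
  if (r - lower <= upper - r) lower else upper
}

odd_sqrt_k(150)  # 13 (sqrt(150) is about 12.25)
odd_sqrt_k(105)  # 11 (sqrt(105) is about 10.25)
```

For a 70% training set of the Iris data (105 flowers), the rule would suggest starting around k = 11.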

There are some serious risks associated with picking the wrong k: if k is too small you will have high variance and low bias, and the model will be overfitted; if k is too high, you will have low variance and high bias, and your model will be underfitted. If there are a lot of outliers and noise in your training set, it is better to go with a higher k, as you do not want your model to be too sensitive to the variations that characterise the outliers. Ultimately though, it is a balancing act to be perfected. A tool for such refining is cross-validation.
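As one possible way to use cross-validation for choosing k, the class package provides knn.cv(), which performs leave-one-out cross-validation: every flower is classified by all the others. This is a sketch under the assumption that the features are standardised and k ranges over 1 to 15, mirroring the choices made later in this article:

```r
library(class)  # knn.cv() performs leave-one-out cross-validation

data(iris)
set.seed(1)  # knn.cv() breaks voting ties at random
x <- scale(iris[, 1:4])
y <- iris$Species

ks <- 1:15
loocv_error <- sapply(ks, function(k) {
  pred <- knn.cv(train = x, cl = y, k = k)  # each point predicted from the other 149
  mean(pred != y)                           # share of misclassified flowers
})

best_k <- ks[which.min(loocv_error)]
best_k
```

Unlike a single train/test split, every observation is used for validation exactly once, so the estimate of the error for each k depends less on one lucky or unlucky split.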

Second: you need to decide how to actually calculate the distance between your data point and the k points you will use to classify it; which points are actually the closest depends on how you measure distance, so it is often not obvious. You thus need to choose between different types of distance. The best-known and most widely used one is Euclidean distance:
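For two points x = (x_1, ..., x_n) and y = (y_1, ..., y_n), where n is the number of features (4 for the Iris data), Euclidean distance is the length of the straight segment joining them:

```latex
d_{\text{Euclidean}}(\mathbf{x}, \mathbf{y}) = \sqrt{\sum_{i=1}^{n} (x_i - y_i)^2}
```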

An alternative is Manhattan, or taxicab, distance, which takes its name from the path a taxicab would need to follow in an area where all streets meet at 90° angles. With the Manhattan metric you can only move north, south, east, or west; nothing in between:
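With the same notation, Manhattan distance adds up the absolute differences along each axis instead of taking the straight line:

```latex
d_{\text{Manhattan}}(\mathbf{x}, \mathbf{y}) = \sum_{i=1}^{n} \lvert x_i - y_i \rvert
```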

In this example, which type of distance you choose to use actually has an impact on how your query point will be labelled:

If you were to choose Euclidean distance (dashed circle), your query point would be classified as a green square. If instead you chose Manhattan distance (dotted square), it would be classified as a red circle.
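The same effect can be shown numerically with a hypothetical two-point example and k = 1 (the coordinates are made up for illustration and are not taken from the figure): one candidate lies diagonally from the query, the other straight along an axis.

```r
# Hypothetical 2-D example: one query point and two labelled candidates
query <- c(0, 0)
candidates <- rbind(green_square = c(3, 3),
                    red_circle   = c(0, 5))

euclidean <- apply(candidates, 1, function(p) sqrt(sum((p - query)^2)))
manhattan <- apply(candidates, 1, function(p) sum(abs(p - query)))

euclidean  # green_square about 4.24, red_circle 5
manhattan  # green_square 6,          red_circle 5

names(which.min(euclidean))  # "green_square": nearest with Euclidean distance
names(which.min(manhattan))  # "red_circle":   nearest with Manhattan distance
```

Diagonal moves are “cheap” for Euclidean distance but cost the full sum of both legs for Manhattan distance, which is why the two metrics disagree here.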

Both Euclidean and Manhattan distances can be generalised using Minkowski distance, varying the value of p (you obtain Manhattan distance if p = 1 and Euclidean distance if p = 2):
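Written out, the Minkowski distance is d_p(x, y) = (Σ_i |x_i − y_i|^p)^(1/p). A minimal R sketch (minkowski is a hypothetical helper name), cross-checked against base R’s dist(), which supports the "manhattan", "euclidean" and "minkowski" methods:

```r
# Minkowski distance between two numeric vectors x and y (p >= 1)
minkowski <- function(x, y, p) {
  sum(abs(x - y)^p)^(1 / p)
}

a <- c(1, 2)
b <- c(4, 6)
minkowski(a, b, p = 1)  # 7: Manhattan distance, |1 - 4| + |2 - 6|
minkowski(a, b, p = 2)  # 5: Euclidean distance, sqrt(9 + 16)

# Cross-check with base R's dist()
dist(rbind(a, b), method = "manhattan")
dist(rbind(a, b), method = "minkowski", p = 2)
```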

Another commonly used distance is cosine distance, which is based on the angle between two vectors A and B (it equals 1 minus their cosine similarity):
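In formula form, cosine distance is 1 − (A · B) / (‖A‖ ‖B‖). A short sketch in R (cosine_distance is a hypothetical helper) shows its key property: only the direction of the vectors matters, not their length.

```r
# Cosine distance = 1 - cosine of the angle between vectors a and b
cosine_distance <- function(a, b) {
  1 - sum(a * b) / (sqrt(sum(a^2)) * sqrt(sum(b^2)))
}

cosine_distance(c(1, 0), c(0, 1))  # 1: perpendicular vectors
cosine_distance(c(1, 2), c(2, 4))  # 0 (up to rounding): same direction, different length
```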

Cosine distance is actually very important in certain areas of machine learning, such as for the recognition of handwriting.

Third: you need to decide how to combine the information from the k neighbours you have found, to ultimately decide which label (or value) to give your data point. In the case of regression, you can compute the mean of the target values of the k closest neighbours; for a more accurate representation of your data, especially if they are dispersed, you can assign each neighbour a weight inversely proportional to its distance, so that the closest points count more than the farthest ones.
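A minimal sketch of plain and distance-weighted kNN regression in base R, using Euclidean distance and made-up one-dimensional data (knn_regress and the toy values are hypothetical). The small constant added to the distances is an assumption to avoid dividing by zero when the query coincides with a training point:

```r
# kNN regression for a single query point, optionally distance-weighted
knn_regress <- function(train_x, train_y, query, k, weighted = TRUE) {
  d <- sqrt(rowSums(sweep(train_x, 2, query)^2))  # Euclidean distance to each row
  nn <- order(d)[seq_len(k)]                       # indices of the k nearest rows
  if (!weighted) return(mean(train_y[nn]))         # plain average of their targets
  w <- 1 / (d[nn] + 1e-9)                          # weights: closer means heavier
  sum(w * train_y[nn]) / sum(w)
}

# Hypothetical toy data: one feature, one numeric target
train_x <- matrix(c(1, 2, 3, 10), ncol = 1)
train_y <- c(10, 20, 30, 100)

knn_regress(train_x, train_y, query = 1.5, k = 3, weighted = FALSE)  # 20
knn_regress(train_x, train_y, query = 1.5, k = 3)                    # about 17.14
```

The weighted estimate is pulled towards 15, the average of the two neighbours at distance 0.5, because the third neighbour (distance 1.5) gets only a third of their weight.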

If you are instead dealing with a classification problem, you need to compute the mode of the k nearest neighbours: you thus obtain a modal class to which you assign your datapoint. This procedure in the literature is called “majority voting”, but this term is misleading and “plurality voting” would be a better fit. In fact, if you have more than two classes you do not need the absolute majority of the k-neighbours to belong to the same class, but only a relative majority.
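To make the procedure concrete, here is a from-scratch sketch of kNN classification with plurality voting, using Euclidean distance (plurality_vote and knn_classify are hypothetical helpers, not the class package’s implementation; ties are broken at random here, which is only one of several possible policies):

```r
# Plurality vote: the most frequent label wins, even without an absolute majority
plurality_vote <- function(labels) {
  counts <- table(labels)
  winners <- names(counts)[counts == max(counts)]
  if (length(winners) > 1) winners <- sample(winners, 1)  # tie: pick one at random
  winners
}

# Classify one query point from its k nearest (Euclidean) neighbours
knn_classify <- function(train_x, train_y, query, k) {
  d <- sqrt(rowSums(sweep(train_x, 2, query)^2))  # distance from query to each row
  nn <- order(d)[seq_len(k)]                       # indices of the k nearest rows
  plurality_vote(as.character(train_y[nn]))
}

plurality_vote(c("A", "B", "A", "C"))  # "A": 2 votes out of 4, a plurality, not a majority

data(iris)
x <- as.matrix(iris[, 1:4])
setosa_centre <- colMeans(x[iris$Species == "setosa", ])
knn_classify(x, iris$Species, setosa_centre, k = 5)  # "setosa"
```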

Now that we have seen and hopefully understood how kNN models work, we can comment on their advantages and disadvantages. As far as the advantages go, kNN is easy to implement and adapts easily as new training samples are added; furthermore, it needs only a few hyperparameters (which are always tricky to determine): k and the distance metric.

Among the disadvantages, aside from being prone to overfitting and taking up a substantial amount of memory as you scale it up (in its basic form, the whole training set must be stored and scanned at prediction time), the one that stands out is the curse of dimensionality. It is a phenomenon we may encounter when we have a fixed number of training examples but an increasing number of dimensions, each with its own range of variability.

Imagine having a cube, with points scattered throughout its volume; you could cut the cube into eight equal smaller cubes by bisecting it with three planes. Assuming your data is evenly distributed in the volume of the cube, you’d have one eighth of the data in each small cube. You then decide to eliminate a dimension, going from a cube to a square by projecting the data onto one of its faces. If you cut this square with two straight lines you obtain four squares, each containing 25% of your data. If you then eliminate another dimension, going from two to one and once again making a projection, you obtain a segment. If you cut it in half, you would have 50% of your data in each shorter segment. Read in reverse, the example shows the problem: with the same number of points, every added dimension halves the share of data in each cell (1/2, 1/4, 1/8, and so on), so the more dimensions you have, the farther away your data points become from each other, and the less meaningful distances become, even though kNN relies on similar points being close together. At the same time, if you were to eliminate some dimensions, for instance through feature selection or extraction techniques, you would lose some potentially valuable information. To summarise, high dimensionality is called a curse for a reason: you have to deal with a trade-off that, in some cases, may even lead you to use a different algorithm entirely.
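One way to see distances losing meaning is a small simulation, assuming uniformly distributed points in the unit hypercube (the number of points, the seed and the “relative contrast” measure, i.e. how much farther the farthest point is than the nearest one, are illustrative choices). As the number of dimensions grows, nearest and farthest neighbours end up at almost the same distance:

```r
set.seed(123)  # arbitrary seed for reproducibility
n <- 500       # fixed number of points, whatever the dimension

relative_contrast <- function(d) {
  points <- matrix(runif(n * d), ncol = d)  # n points uniform in the unit hypercube
  query  <- runif(d)                        # one random query point
  dists  <- sqrt(rowSums(sweep(points, 2, query)^2))
  (max(dists) - min(dists)) / min(dists)    # relative gap between farthest and nearest
}

dims <- c(1, 2, 3, 10, 100, 1000)
contrast <- sapply(dims, relative_contrast)
round(setNames(contrast, dims), 2)  # shrinks towards 0 as the dimension grows
```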

Time to Code!

It is now time to implement a kNN using the Iris dataset. We are going to utilise R, as it is able to provide great graphical representations.

First, we need to load some libraries:

# Library for kNN
library(class)

# Libraries for graphical representations
library(ggplot2)
library(GGally)

Then, we can assign the data contained in the Iris dataset, which is preloaded in R, to a variable, and visualise it both in the console and as a plot. In case you need to download the Iris dataset, you can find it here.

# Save and view the data
iris <- iris
summary(iris)

# Plot iris with a scatterplot matrix
ggpairs(iris, ggplot2::aes(color = Species),
        upper = list(continuous = wrap("cor", size = 3)),
        diag = list(continuous = wrap("densityDiag", alpha = 0.7)),
        lower = list(continuous = wrap("points", size = 0.4)),
        title = "Scatterplot matrix of Iris features grouped by Species")

We obtain the following graphical representation, which also shows the correlation between each pair of features:

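We now standardise the four measurements (so that each has mean 0 and standard deviation 1, and no feature dominates the distance just because of its scale) and create three dataframes, each containing the 50 observations of one species. We then split the data into a training set (70% of the data) and a test set (30% of the data), taking the same proportion from each species: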

# Scaling (the seed makes the random split below reproducible)
set.seed(436643)
iris[, 1:4] <- scale(iris[, 1:4])

# New separate dataframes, one per species
setosa <- iris[iris$Species == "setosa", ]
versicolor <- iris[iris$Species == "versicolor", ]
virginica <- iris[iris$Species == "virginica", ]

# Splitting the data into training and test sets:
# the same 35 row positions (70% of 50) are drawn for each species
ind <- sample(1:nrow(setosa), nrow(setosa) * 0.7)
iris.train <- rbind(setosa[ind, ], versicolor[ind, ], virginica[ind, ])
iris.test <- rbind(setosa[-ind, ], versicolor[-ind, ], virginica[-ind, ])

We now need to choose our distance metric and how many neighbours (k) to consider. The knn() function from the class package uses Euclidean distance; to find a good k, we compute the error rate on the test set for each k from 1 to 15 and look for the minimum:

# Hyperparameter tuning: test-set error rate for k = 1, ..., 15
error <- numeric(15)
for (i in 1:15) {
  knn.fit <- knn(train = iris.train[, 1:4], test = iris.test[, 1:4],
                 cl = iris.train$Species, k = i)
  error[i] <- 1 - mean(knn.fit == iris.test$Species)
}

ggplot(data = data.frame(k = 1:15, error = error), aes(x = k, y = error)) +
  geom_line(color = "blue") +
  scale_x_continuous(breaks = 1:15)

We obtain the following plot for the error function:

The value of k that we’ll go with is thus 5. It’s finally time then to check the outcome of our model:

# Running kNN with k = 5
iris_pred5 <- knn(train = iris.train[,1:4], test = iris.test[,1:4],
cl = iris.train$Species, k=5)
table(iris.test$Species,iris_pred5)

The command table() will show the confusion matrix for k = 5 in the console; here is a nicer representation:

We see that out of 45 predictions, 44 are correct. We can check how the results would change if we picked a different k; we’ll check for k = 3 (error is higher than for k = 5 but still quite low) and for k = 7 (which should be worse):

k = 3: 43 out of 45 predictions are correct
k = 7: 41 out of 45 predictions are correct
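To turn a confusion matrix produced by table() into a single accuracy figure, divide the diagonal (the correct predictions) by the total. A small sketch with hypothetical labels (accuracy_from_table is a hypothetical helper; setting the same factor levels on both vectors keeps the table square):

```r
# Accuracy = correct predictions (diagonal of the confusion matrix) / all predictions
accuracy_from_table <- function(tab) sum(diag(tab)) / sum(tab)

# Hypothetical true and predicted labels, for illustration only
truth <- factor(c("setosa", "versicolor", "virginica", "virginica"),
                levels = c("setosa", "versicolor", "virginica"))
pred  <- factor(c("setosa", "versicolor", "versicolor", "virginica"),
                levels = levels(truth))

tab <- table(truth, pred)
accuracy_from_table(tab)  # 0.75: 3 of the 4 predictions are correct
```

Applied to the k = 5 results above, accuracy_from_table(table(iris.test$Species, iris_pred5)) would give 44/45, roughly 0.978.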

We notice that, even when the choice of k is not optimal, Iris setosa is never misclassified. To understand why, we can simply refer to the scatterplots above and immediately see that the setosa points are much farther from both versicolor and virginica than these two species are from each other.

If we were instead to opt for an unsupervised approach, such as k-means clustering, we would most likely not obtain clusters that coincide with the classes, as it would be difficult to distinguish versicolor from virginica.
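This claim is easy to check with base R’s kmeans(). A sketch, assuming standardised features, three clusters and 25 random starts (all illustrative choices); the labels are used only afterwards, to see how well the clusters match the species:

```r
data(iris)
set.seed(2023)  # k-means starts from random centres
x <- scale(iris[, 1:4])

# Cluster into 3 groups without looking at the labels
km <- kmeans(x, centers = 3, nstart = 25)

# Cross-tabulate clusters against the true species
tab <- table(cluster = km$cluster, species = iris$Species)
tab  # typically setosa gets its own cluster, while versicolor and virginica mix

# Share of flowers that belong to their cluster's most common species
purity <- sum(apply(tab, 1, max)) / sum(tab)
purity
```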

To conclude, I would like to show you another representation of the neighbourhoods: given a set of data points, each point is assigned the region of the plane containing all the locations that are closer to it than to any other point. This representation is called a Voronoi tessellation (or Voronoi diagram). In the following pictures, I have used the data in the test set:

As you can see, there is one area, that of iris setosa, that stands out very evidently from the other two, which are instead a lot closer together. The code to obtain a Voronoi tessellation is the following:

# To obtain a Voronoi tessellation
library(deldir)

tesselation1 <- deldir(iris.test$Sepal.Width, iris.test$Petal.Width)
tiles1 <- tile.list(tesselation1)
v1 <- plot(tiles1, pch = 20, xlab = "Sepal Width", ylab = "Petal Width",
           close = TRUE, fillcol = hcl.colors(45, "GnBu"))

tesselation3 <- deldir(iris.test$Sepal.Length, iris.test$Petal.Length)
tiles3 <- tile.list(tesselation3)
v3 <- plot(tiles3, pch = 20, xlab = "Sepal Length", ylab = "Petal Length",
           close = TRUE, fillcol = hcl.colors(45, "GnBu"))

The reason why I find this representation truly beautiful is that it looks very similar to other patterns found in nature, such as dragonfly wings and epithelial cells:

Dragonfly wings. Credits here.
Epithelial tissue: simple squamous epithelium of a frog. Credits here.

This affinity shows how fundamentally simple, beyond all the coding and the technicalities, the concept behind kNN is: you have a set of points, around which you draw boundaries, and the next time you lay down a new point it will fall within one of these areas, so you’ll associate the new point with the one already inside that boundary. Once again, it’s all a matter of optimisation downward: which is closest? Which is less expensive?

So as you get your measuring tape ready for all your petals and sepals, I want to leave you with a quote by Italian author Primo Levi, from his short story “The Story of a Carbon Atom”, closing his book “The Periodic Table” (1975):

“Such is life”, although rarely is it described in this manner: an inserting itself, a drawing off to its advantage, a parasitising of the downward course of energy, from its noble solar form to the degraded one of low-temperature heat. In this downward course, which leads to equilibrium and thus death, life draws a bend and nests in it.

Bibliography

IBM. “Background of parametric and nonparametric statistics”. Accessed November 4, 2023. https://www.ibm.com/docs/en/db2woc?topic=nonparametric-background

IBM. “What is the k-nearest neighbors algorithm?”. Accessed November 4, 2023. https://www.ibm.com/topics/knn

Shalev-Shwartz, S., & Ben-David, S. (2014). Understanding Machine Learning: From Theory to Algorithms. Cambridge: Cambridge University Press.

Raschka, S. “Machine Learning Lecture Notes”. University of Wisconsin-Madison. Accessed November 5, 2023. https://sebastianraschka.com/pdf/lecture-notes/stat479fs18/02_knn_notes.pdf

D’Agostino, S. “Voronoi Tessellations and Scutoids Are Everywhere”. Scientific American. Accessed November 5, 2023. https://blogs.scientificamerican.com/observations/voronoi-tessellations-and-scutoids-are-everywhere/

Turner, R. “Voronoi diagrams in R with deldir”. Accessed November 5, 2023. https://r-charts.com/part-whole/voronoi-diagram/
