Implementing K-means clustering

Mohit Gupta
4 min read · Mar 6, 2024


This is the first post in the ‘From scratch’ series, where we develop the entire code for ML algorithms and visualizations ourselves, without any prebuilt functions. I will use only PyTorch (or NumPy) for matrix manipulation, pure Python, and Matplotlib for plotting.

K-means clustering process visualization

Let us start with a dataset as shown below.

Fig 1 — Unlabeled dataset to start with

To implement K-means clustering, we need to perform 5 steps. In plain English, these steps are:

1. Select the number of clusters (K) and randomly select cluster centers
2. Calculate the distance of each point from all cluster centers
3. Assign each point the label of its closest cluster center
4. Calculate the mean of each cluster and use it as the new cluster center
5. Repeat steps 2–4
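
Formally, these steps perform coordinate descent on the within-cluster sum of squared distances, the standard K-means objective:

J = \sum_{k=1}^{K} \sum_{x_i \in C_k} \lVert x_i - \mu_k \rVert^2

where C_k is the set of points assigned to cluster k and \mu_k is its center. Steps 2–3 lower J by reassigning points with the centers held fixed; step 4 lowers it by moving each center to its cluster mean, so J never increases from one iteration to the next.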

The full code is below, documented step by step so that it is easy to follow.

# Import libraries
from functools import partial
import torch
from torch import tensor
import matplotlib.pyplot as plt
from torch.distributions.multivariate_normal import MultivariateNormal
import random
torch.manual_seed(131) # for reproducibility

############# Setup - Generate Synthetic Dataset #########
#### Let's create a dataset of 6 clusters with 250 points each that are normally distributed #####
n_clusters = 6
n_samples = 250

centroids = torch.randn([n_clusters, 2])*100-35 # set the centroids
data = []
for i, centroid in enumerate(centroids):
    samples = MultivariateNormal(loc = centroid, covariance_matrix= torch.diag(tensor([50., 50.]))).sample((n_samples,))
    data.append(samples)
data = torch.cat(data)

# plot the data (a single scatter call is enough; the original loop replotted the same points once per centroid)
plt.figure(figsize = (4,3))
plt.scatter(data[:,0], data[:,1], s=1)
plt.title('Unlabelled data')

##################### all relevant functions #####################################
def get_dist_matrix(data, centers):
    ''' gets euclidean dist of each point from given cluster centers'''
    n_r = data.shape[0]
    n_c = centers.shape[0]
    dist_matrix = torch.zeros(n_r, n_c)
    for r in range(n_r):
        pt = data[r]
        for c in range(n_c):
            center = centers[c]
            # euclidean distance
            dist_matrix[r,c] = ((pt-center)**2).sum().sqrt()
    return dist_matrix
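
# Aside (my addition, not from the original post): the nested loops above are
# easy to follow but slow for large inputs. torch.cdist computes the same
# pairwise Euclidean distances in a single vectorized call:
def get_dist_matrix_fast(data, centers):
    ''' vectorized equivalent of get_dist_matrix '''
    return torch.cdist(data, centers, p=2)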

def plot_preds(data, rand_centers, preds):
    plt.figure(figsize = (4,3))
    plt.scatter(data[:,0], data[:,1], s=1, c=preds)
    plt.scatter(rand_centers[:,0], rand_centers[:,1], marker='x', s=100, c ='red')
    plt.title('after clustering')
###################################################################################

### Step-1: Select the number of clusters (K) and randomly select cluster centers
num_clusters = 6
rand_inds = torch.randperm(data.shape[0])[:num_clusters] # K distinct indices, chosen uniformly
rand_centers = data[rand_inds] # initialise random cluster centers

plt.figure(figsize = (4,3))
plt.scatter(data[:,0], data[:,1], s=1, c='blue')
plt.scatter(rand_centers[:,0], rand_centers[:,1], marker='x', s=100, c ='red')
plt.title('Initialized random cluster centers')

steps = 5
centers = rand_centers

for step in range(steps):

    ### Step-2. Calculate the distance of each point from all cluster centers
    dist_matrix = get_dist_matrix(data, centers)

    ### Step-3. Assign each point the label of its closest cluster center
    preds = torch.argmin(dist_matrix, dim=1) # predicted clusters
    plot_preds(data, centers, preds)

    ### Step-4. Calculate the mean of each cluster and use it as the new cluster center
    new_centers = []
    classes = preds.unique()
    for target in classes:
        idxs = (preds == target).nonzero(as_tuple=True)[0] # indices of points in this cluster
        id_cluster = data[idxs]
        cluster_mean = id_cluster.mean(dim=0)
        new_centers.append(cluster_mean)

    centers = torch.stack(new_centers, dim=0) # update the centers

Fig 2 — Final results of K-means clustering with 6 clusters
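
One refinement worth noting: the loop above runs for a fixed number of steps, but a common stopping rule is to iterate until the centers stop moving. Below is a minimal sketch of that rule reusing the helpers above; the tolerance value is my own choice, and it assumes every cluster keeps at least one point so the tensor shapes match between iterations.

centers = rand_centers
tol = 1e-3 # hypothetical tolerance; tune to your data scale
for step in range(100): # upper bound on the number of iterations
    preds = torch.argmin(get_dist_matrix(data, centers), dim=1)
    new_centers = torch.stack([data[preds == k].mean(dim=0) for k in preds.unique()])
    shift = torch.norm(new_centers - centers) # total movement of all centers
    centers = new_centers
    if shift < tol: # converged: centers barely moved
        break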

The result in Fig 2 looks good, but such a clean outcome is not guaranteed. Why?
1. K-means is sensitive to initialization. For example, Fig 3 shows the result with a different random initialization (a standard remedy is sketched after Fig 4).

Fig 3 — Suboptimal results

2. K-means requires the number of clusters as an input. Real data is often high-dimensional rather than two-dimensional, which makes it hard to visualize and to judge how many clusters exist. For example, if the number of clusters is set to 3 instead of 6, the results look like Fig 4.

Fig 4 — Result when K is set to 3 instead of 6
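
The standard fix for the initialization problem is k-means++ seeding: pick the first center uniformly at random, then pick each subsequent center with probability proportional to its squared distance from the nearest center chosen so far, which spreads the initial centers apart. A minimal sketch reusing get_dist_matrix (my own code, not from the original post):

def kmeans_pp_init(data, k):
    ''' k-means++ seeding: spread the initial centers apart '''
    centers = [data[torch.randint(data.shape[0], (1,))].squeeze(0)] # first center: uniform
    for _ in range(k - 1):
        d = get_dist_matrix(data, torch.stack(centers))
        d2 = d.min(dim=1).values ** 2 # squared distance to nearest chosen center
        idx = torch.multinomial(d2 / d2.sum(), 1).item() # far points are more likely
        centers.append(data[idx])
    return torch.stack(centers)

Swapping rand_centers = data[rand_inds] for rand_centers = kmeans_pp_init(data, num_clusters) makes bad runs like Fig 3 much rarer.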

Interesting questions to ponder:

  1. Are there any advanced initialization techniques (such as the k-means++ seeding sketched above)?
  2. What metrics can we use to decide when to stop the iterations? (The center-shift rule sketched after Fig 2 is one option.)
  3. Is there a better way to guess the number of clusters? (One common heuristic is sketched after this list.)
  4. How does K-means handle outliers and noisy data?
  5. What improvements to the basic algorithm exist?
  6. What are some practical use cases of K-means?
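
For question 3, one common heuristic is the elbow method: run K-means for a range of K values, plot the total within-cluster squared distance (often called inertia) against K, and pick the K where the curve bends. Below is a rough sketch reusing the helpers above; running a fixed 10 update steps per K is my simplification, and it assumes no cluster empties out along the way.

inertias = []
ks = range(1, 11)
for k in ks:
    centers_k = data[torch.randperm(data.shape[0])[:k]] # random initialization
    for _ in range(10): # a few update steps per K (simplification)
        preds_k = torch.argmin(get_dist_matrix(data, centers_k), dim=1)
        centers_k = torch.stack([data[preds_k == j].mean(dim=0) for j in preds_k.unique()])
    d = get_dist_matrix(data, centers_k)
    inertias.append((d.min(dim=1).values ** 2).sum().item()) # within-cluster SSE

plt.figure(figsize = (4,3))
plt.plot(list(ks), inertias, marker='o')
plt.xlabel('K'); plt.ylabel('inertia'); plt.title('Elbow method')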

If you are looking for the answers, see here.
