K-means clustering

Dhanoop Karunakaran
Intro to Artificial Intelligence
Aug 24, 2023
K-means clustering. Source: [1]

Clustering is an unsupervised technique that groups similar data points to form clusters. K-means clustering groups the data points into k clusters. Let's explain k-means clustering with an example of two clusters, that is, k=2.

Randomly initializing the two cluster centroids

Initially, the algorithm randomly picks two centroids, shown above as cross marks. It then goes through each data point to check which of the two cluster centroids (red or blue) is closer, and assigns the point to that centroid.
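The assignment step can be sketched in NumPy as follows. This is a minimal illustration, not code from the original article: the function name `assign_clusters` and the toy points are hypothetical, and "closer" is assumed to mean Euclidean distance, which is the distance standard k-means uses.

```python
import numpy as np

def assign_clusters(X: np.ndarray, centroids: np.ndarray) -> np.ndarray:
    """Return, for each row of X, the index of the closest centroid."""
    # Squared Euclidean distance from every point to every centroid: shape (n_points, k)
    dists = ((X[:, np.newaxis, :] - centroids[np.newaxis, :, :]) ** 2).sum(axis=2)
    # For each point, pick the centroid with the smallest distance
    return dists.argmin(axis=1)

# Hypothetical toy data: four 2D points and two centroids (k=2)
X = np.array([[1.0, 1.0], [1.5, 2.0], [8.0, 8.0], [9.0, 8.5]])
centroids = np.array([[0.0, 0.0], [10.0, 10.0]])
labels = assign_clusters(X, centroids)
print(labels.tolist())  # [0, 0, 1, 1]
```

Comparing squared distances gives the same nearest centroid as comparing true distances, so the square root can be skipped.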

Data points are assigned to the closest cluster centroid

As you can see in the figure above, each data point is now assigned to a cluster centroid. Next, we take all the red points and average them to find the new red centroid. Similarly, we find the new blue centroid by averaging the data points in the blue cluster.
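The averaging step can be sketched the same way. Again this is an illustration rather than the article's code: `update_centroids` is a hypothetical name, and keeping a centroid in place when no points are assigned to it is an assumption, since the article does not cover empty clusters.

```python
import numpy as np

def update_centroids(X: np.ndarray, labels: np.ndarray, centroids: np.ndarray) -> np.ndarray:
    """Move each centroid to the mean of the points assigned to it."""
    new_centroids = centroids.copy()
    for j in range(len(centroids)):
        members = X[labels == j]
        # Assumption: leave the centroid where it is if its cluster is empty
        if len(members) > 0:
            new_centroids[j] = members.mean(axis=0)
    return new_centroids

# Hypothetical toy data with assignments from the previous step
X = np.array([[1.0, 1.0], [1.5, 2.0], [8.0, 8.0], [9.0, 8.5]])
labels = np.array([0, 0, 1, 1])
centroids = np.array([[0.0, 0.0], [10.0, 10.0]])
new_centroids = update_centroids(X, labels, centroids)
print(new_centroids)
```

Here the first centroid moves to (1.25, 1.5), the mean of the first two points, and the second moves to (8.5, 8.25).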

New centroids obtained by averaging the data points previously assigned to each cluster.

As the figure above shows, we now have two new centroids. The whole process then repeats, starting again from assigning each data point to its closest centroid, until the cluster centroids no longer change.
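One way to see why this loop stops: k-means can be described as minimizing the total squared distance between each data point and the centroid it is assigned to. The notation below is not from the article and is assumed for illustration: there are m data points x^(i), c^(i) is the index of the centroid assigned to point i, and μ_j is the position of centroid j.

```latex
J\left(c^{(1)},\dots,c^{(m)},\mu_1,\dots,\mu_k\right)
  = \frac{1}{m}\sum_{i=1}^{m} \left\lVert x^{(i)} - \mu_{c^{(i)}} \right\rVert^2
```

The assignment step picks each c^(i) to minimize J with the centroids held fixed, and the averaging step sets each μ_j to the mean of its points, which minimizes J with the assignments held fixed. Neither step can increase J, and there are only finitely many ways to assign the points to k clusters, so the process eventually stops. It may stop at a local minimum, which is why the result can depend on the random initialization.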

Pseudocode of k-means clustering

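1. Randomly initialize the k cluster centroids
repeat {
  2. Assign each data point to its closest cluster centroid
  3. Move each cluster centroid to the mean of the data points assigned to it
} until the cluster centroids no longer change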
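The pseudocode maps almost line for line onto NumPy. This sketch is an illustration, not the article's code: `kmeans` and `_assign` are hypothetical names, initializing with k distinct data points chosen at random is one common way to "randomly initialize" (the article does not specify a method), and the `max_iters` cap is an added safety net.

```python
import numpy as np

def _assign(X, centroids):
    # Index of the closest centroid for every point (squared Euclidean distance)
    return ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2).argmin(axis=1)

def kmeans(X, k, max_iters=100, seed=0):
    """Minimal k-means: random initialization, then repeat assign/move until stable."""
    rng = np.random.default_rng(seed)
    # 1. Assumption: initialize with k distinct data points chosen at random
    centroids = X[rng.choice(len(X), size=k, replace=False)].astype(float)
    for _ in range(max_iters):
        # 2. Assign each data point to its closest cluster centroid
        labels = _assign(X, centroids)
        # 3. Move each centroid to the mean of its points (keep it if its cluster is empty)
        new_centroids = np.array([
            X[labels == j].mean(axis=0) if np.any(labels == j) else centroids[j]
            for j in range(k)
        ])
        # Stop once the centroids no longer change
        if np.allclose(new_centroids, centroids):
            break
        centroids = new_centroids
    # Final assignment so the returned labels match the returned centroids
    return centroids, _assign(X, centroids)

# Hypothetical data: two well-separated 2D blobs of 50 points each
rng = np.random.default_rng(42)
X = np.vstack([rng.normal(0, 0.5, size=(50, 2)), rng.normal(5, 0.5, size=(50, 2))])
centroids, labels = kmeans(X, k=2)
print(np.round(centroids, 2))  # one centroid near (0, 0), the other near (5, 5)
```

Libraries such as scikit-learn add refinements on top of this basic loop, for example smarter initialization and several restarts, which is one reason the article uses the library rather than a from-scratch version.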

Application

K-means clustering is used for a wide variety of tasks, such as customer segmentation and image segmentation. I found image compression using k-means clustering interesting, so let's dive into that, using the implementation provided by scikit-learn rather than implementing the k-means algorithm from scratch.

We have taken the examples provided by [2] and wrapped them in a Jupyter notebook and a Docker container so that you can run them easily. We have published this to a GitHub repo.

An image with RGB channels is represented as a 3-dimensional array of shape (rows, cols, 3). The first step is to reshape this 3D array into a 2D array with one row per pixel and one column per colour channel, using the code below.

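# image: NumPy array of shape (rows, cols, 3) holding the RGB values
rows = image.shape[0]
cols = image.shape[1]
# Flatten to shape (rows*cols, 3): one row per pixel
image = image.reshape(rows*cols, 3)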
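To see what the reshape does to pixel order, here is a small check on a hypothetical 2x2 image. It stores the flattened result in a separate variable, `pixels`, purely for clarity; the article overwrites `image` instead.

```python
import numpy as np

# Hypothetical 2x2 RGB image: shape (rows, cols, 3)
image = np.array([[[255, 0, 0], [0, 255, 0]],
                  [[0, 0, 255], [255, 255, 255]]], dtype=np.uint8)
rows, cols = image.shape[0], image.shape[1]

# One row per pixel, in row-major order (row 0 left to right, then row 1)
pixels = image.reshape(rows * cols, 3)
print(pixels.shape)       # (4, 3)
print(pixels[1].tolist()) # [0, 255, 0], the pixel at row 0, column 1

# Reshaping back restores the original image exactly
restored = pixels.reshape(rows, cols, 3)
print(np.array_equal(restored, image))  # True
```

Because the reshape is lossless and order-preserving, the per-pixel results from k-means can later be reshaped back into an image with the same `rows` and `cols`.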

The next step is to run the k-means algorithm on the pixel colours with k=6, so that each of the six cluster centroids becomes a representative colour.

from sklearn.cluster import KMeans
# Cluster the pixel colours into k=6 groups
kmeans = KMeans(n_clusters=6)
kmeans.fit(image)

Now, replace each pixel value with the centroid of the cluster it was assigned to.

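import numpy as np
compressed_image = kmeans.cluster_centers_[kmeans.labels_]  # centroid colour for every pixel
# Clip before casting to uint8, then restore the (rows, cols, 3) image shape
compressed_image = np.clip(compressed_image, 0, 255).astype('uint8').reshape(rows, cols, 3)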
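For readers who want to run the whole pipeline in one go, here is a self-contained sketch. It makes assumptions that are not part of the original notebook: a random 32x32 array stands in for the photo used in the article, and `n_init=10` and `random_state=0` are set only to make the run reproducible.

```python
import numpy as np
from sklearn.cluster import KMeans

# Hypothetical stand-in for a real photo: a random 32x32 RGB image
rng = np.random.default_rng(0)
image = rng.integers(0, 256, size=(32, 32, 3)).astype(np.uint8)
rows, cols = image.shape[0], image.shape[1]

# Flatten to (rows*cols, 3) and cluster the colours with k=6
pixels = image.reshape(rows * cols, 3).astype(float)
kmeans = KMeans(n_clusters=6, n_init=10, random_state=0).fit(pixels)

# Look up each pixel's centroid, clip, cast to uint8 and restore the image shape
compressed = kmeans.cluster_centers_[kmeans.labels_]
compressed_image = np.clip(compressed, 0, 255).astype(np.uint8).reshape(rows, cols, 3)

print(compressed_image.shape)  # (32, 32, 3)
print(len(np.unique(compressed_image.reshape(-1, 3), axis=0)))  # at most 6 distinct colours
```

Note that `astype(np.uint8)` truncates the fractional part of each centroid value rather than rounding it, which is harmless for display.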

Here is the comparison between the original image and the compressed image.

Original image
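Compressed image. Only 6 colours (k=6) are used to represent the entire image, so each pixel can be stored as a 3-bit index into a 6-colour palette instead of a full 24-bit RGB value, which reduces the size.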
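The size reduction can be estimated with a little arithmetic. This sketch assumes the compressed image is stored as a palette of k RGB colours plus one index per pixel, and it ignores file-format overhead and any further compression a real format would apply; `palette_bits` and the 100x100 image size are hypothetical.

```python
def palette_bits(rows: int, cols: int, k: int) -> tuple[int, int]:
    """Rough storage estimate (in bits) before and after palette-based compression."""
    n_pixels = rows * cols
    original = n_pixels * 24                     # 8 bits for each of R, G and B
    index_bits = (k - 1).bit_length()            # bits needed to number k colours: 3 for k=6
    compressed = n_pixels * index_bits + k * 24  # per-pixel indices plus the k-colour palette
    return original, compressed

original, compressed = palette_bits(100, 100, 6)
print(original, compressed)             # 240000 30144
print(round(original / compressed, 2))  # 7.96
```

The palette itself costs only 6 × 24 = 144 bits, so for any image of reasonable size the ratio approaches 24 / 3 = 8.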

If you like my write-up, follow me on GitHub, LinkedIn, and/or Medium.

References

  1. https://www.javatpoint.com/k-means-clustering-algorithm-in-machine-learning
  2. https://iq.opengenus.org/image-compression-using-k-means/
