Extracting Dominant Colours from an Image using K-means Clustering from Scratch

Tanvi Penumudy
Jan 16 · 4 min read

Extract the dominant colours from any image of your choice in less than 5 minutes from scratch!

You’ve probably heard the phrase “a picture is worth a thousand words.” In our digitally advanced age, this is more accurate than ever: a lot of information can be extracted from an image. High-level computer vision systems allow self-driving cars to recognize whether an object up ahead is a pedestrian crossing the road or a static road hazard, and Instagram filters detect faces and respond to them interactively. These advances build on fundamental machine learning approaches.

For more on Self-Driving Cars: A Beginner’s Guide to Reinforcement Learning and its Basic Implementation from Scratch

Machine learning is the process by which machines learn to understand data and answer questions about it. In image processing, one application of machine learning is to process an image digitally, treating the numbers that represent its pixels and colours as data.

For more on Machine Learning: A Beginner’s Guide for Getting Started with Machine Learning

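Approaches that don’t make predictions or assume a correct set of outputs, but instead uncover insights from a given dataset, are referred to as unsupervised. One such technique for image processing and information extraction is K-means clustering, an algorithm that partitions n data points into k groups (clusters) so that each point belongs to the cluster with the nearest mean. For an image, the data points are its pixels, each an (R, G, B) triple, so the k cluster means are the image's k most representative colours.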
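Stated a little more formally (the notation below is introduced here, as a standard textbook formulation rather than the article's own): K-means looks for clusters S_1, ..., S_k with means μ_1, ..., μ_k that minimise the within-cluster sum of squared distances. The usual way to search for them, Lloyd's algorithm, alternates an assignment step and an update step until the means stop moving, which is what the `_create_clusters` and `_get_centroids` methods in the code below do.

```latex
\min_{S_1,\dots,S_k} \; \sum_{j=1}^{k} \sum_{x_i \in S_j} \lVert x_i - \mu_j \rVert^2,
\qquad \mu_j = \frac{1}{|S_j|} \sum_{x_i \in S_j} x_i

\text{assignment step: } c_i \leftarrow \arg\min_{j} \lVert x_i - \mu_j \rVert,
\qquad S_j = \{\, x_i : c_i = j \,\}

\text{update step: } \mu_j \leftarrow \frac{1}{|S_j|} \sum_{x_i \in S_j} x_i
```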

For a conceptual overview of K-means clustering, refer to: Everything you need to know about K-Means Clustering

We shall now begin the code walkthrough for the implementation of the K-means clustering algorithm from scratch:

Fret not! I promise you that it’s going to turn out as fascinating as it sounds!

import numpy as np
import matplotlib.pyplot as plt

# fix the seed so the random centroid initialization is reproducible
np.random.seed(42)

def euclidean_distance(x1, x2):
    return np.sqrt(np.sum((x1 - x2)**2))

class KMeans():

    def __init__(self, K=5, max_iters=100, plot_steps=False):
        self.K = K
        self.max_iters = max_iters
        self.plot_steps = plot_steps

        # list of sample indices for each cluster
        self.clusters = [[] for _ in range(self.K)]
        # the centers (mean feature vector) for each cluster
        self.centroids = []

    def predict(self, X):
        # work in floating point: uint8 pixel values would wrap around when subtracted
        self.X = np.asarray(X, dtype=float)
        self.n_samples, self.n_features = self.X.shape

        # initialize the centroids with K distinct random samples
        random_sample_idxs = np.random.choice(self.n_samples, self.K, replace=False)
        self.centroids = [self.X[idx] for idx in random_sample_idxs]

        # optimize clusters
        for _ in range(self.max_iters):
            # assign samples to closest centroids (create clusters)
            self.clusters = self._create_clusters(self.centroids)
            if self.plot_steps:
                self.plot()

            # calculate new centroids from the clusters
            centroids_old = self.centroids
            self.centroids = self._get_centroids(self.clusters)

            # stop once the centroids no longer move
            if self._is_converged(centroids_old, self.centroids):
                break

            if self.plot_steps:
                self.plot()

        # classify samples as the index of their clusters
        return self._get_cluster_labels(self.clusters)

    def _get_cluster_labels(self, clusters):
        # each sample will get the label of the cluster it was assigned to
        labels = np.empty(self.n_samples)

        for cluster_idx, cluster in enumerate(clusters):
            for sample_index in cluster:
                labels[sample_index] = cluster_idx
        return labels

    def _create_clusters(self, centroids):
        # assign the samples to the closest centroids to create clusters
        clusters = [[] for _ in range(self.K)]
        for idx, sample in enumerate(self.X):
            centroid_idx = self._closest_centroid(sample, centroids)
            clusters[centroid_idx].append(idx)
        return clusters

    def _closest_centroid(self, sample, centroids):
        # distance of the current sample to each centroid
        distances = [euclidean_distance(sample, point) for point in centroids]
        closest_index = np.argmin(distances)
        return closest_index

    def _get_centroids(self, clusters):
        # assign the mean value of each cluster to its centroid
        centroids = np.zeros((self.K, self.n_features))
        for cluster_idx, cluster in enumerate(clusters):
            if len(cluster) > 0:
                centroids[cluster_idx] = np.mean(self.X[cluster], axis=0)
            else:
                # an empty cluster has no mean, so keep its previous centroid
                centroids[cluster_idx] = self.centroids[cluster_idx]
        return centroids

    def _is_converged(self, centroids_old, centroids):
        # distances between each old and new centroid, for all centroids
        distances = [euclidean_distance(centroids_old[i], centroids[i]) for i in range(self.K)]
        return sum(distances) == 0

    def plot(self):
        # scatter plot of the clusters and centroids; assumes 2-D features (x, y)
        fig, ax = plt.subplots(figsize=(12, 8))

        for i, index in enumerate(self.clusters):
            point = self.X[index].T
            ax.scatter(*point)

        for point in self.centroids:
            ax.scatter(*point, marker="x", color='black', linewidth=2)

        plt.show()

    def cent(self):
        # return the current cluster centers
        return self.centroids
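The `_create_clusters` step above calls `euclidean_distance` once per pixel per centroid inside a Python loop, which gets slow for a 1080 x 1920 image (over two million pixels). The sketch below is a possible vectorized variation, not the author's code: `assign_labels` and `update_centroids` are hypothetical helper names. It performs the same nearest-centroid assignment and mean update with NumPy, using the expansion ||x - μ||² = ||x||² - 2x·μ + ||μ||² so that only an n x k distance matrix is built.

```python
import numpy as np


def assign_labels(X, centroids):
    """Return the index of the nearest centroid for every row of X."""
    X = np.asarray(X, dtype=float)
    centroids = np.asarray(centroids, dtype=float)
    x_sq = (X ** 2).sum(axis=1)[:, None]           # (n, 1) squared norms of samples
    c_sq = (centroids ** 2).sum(axis=1)[None, :]   # (1, k) squared norms of centroids
    sq_dists = x_sq - 2.0 * X @ centroids.T + c_sq  # (n, k) squared distances
    # argmin of the squared distance is the same as argmin of the Euclidean distance
    return np.argmin(sq_dists, axis=1)


def update_centroids(X, labels, old_centroids):
    """Mean of each cluster; an empty cluster keeps its previous centroid."""
    X = np.asarray(X, dtype=float)
    new = np.array(old_centroids, dtype=float)  # copy, so the input is not modified
    for j in range(len(new)):
        members = X[labels == j]
        if len(members) > 0:
            new[j] = members.mean(axis=0)
    return new


# toy example: two obvious groups of RGB-like points
X = np.array([[0, 0, 0], [10, 0, 0], [250, 250, 250], [240, 255, 250]])
centroids = np.array([[0.0, 0.0, 0.0], [255.0, 255.0, 255.0]])
labels = assign_labels(X, centroids)
centroids = update_centroids(X, labels, centroids)
```

One way to plug this into the class would be for `_create_clusters` to compute `labels = assign_labels(self.X, centroids)` and return `[np.flatnonzero(labels == j).tolist() for j in range(self.K)]`, keeping the rest of the class unchanged.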
# Extracting Dominant Colours in an Image
import cv2
from skimage import io
# from google.colab.patches import cv2_imshow  # Colab-only display helper; not needed below

# load the image as an RGB uint8 array of shape (height, width, 3)
url = "https://www.teahub.io/photos/full/35-355143_windows-10-wallpaper-umbrella.jpg"
img = io.imread(url)
img.shape
Out:
(1080, 1920, 3)
# keep an unflattened copy of the original image for display
img_init = img.copy()
plt.figure(figsize=(6, 6))
plt.imshow(img_init)
Out:
<matplotlib.image.AxesImage at 0x7f49e6a7d6a0>
[Figure: the input image]
# flatten the image into one row per pixel: shape (1080 * 1920, 3)
img = img.reshape((img.shape[0] * img.shape[1], img.shape[2]))
k = KMeans(K=5)  # for the 5 most dominant colours
y_pred = k.predict(img)
k.cent()
Out:
array([[ 53.16708662,  93.69632655, 175.47967713],
       [133.07051658, 195.54817432,  47.66459003],
       [237.62760178,  76.63096981,  21.42656026],
       [248.665852  ,  31.23121874, 121.13346739],
       [206.22142881, 229.36967717, 152.23724866]])
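The five rows of `k.cent()` are cluster centres in RGB order (skimage's `io.imread` returns RGB), with fractional channel values. As a small sketch, such centres can be turned into 8-bit colours and hex strings; `to_hex` is a hypothetical helper, and the sample row is copied from the first row of the output above. It rounds to the nearest integer, whereas `astype("uint8")` in the colour-bar code truncates, so the two can differ by one level per channel.

```python
import numpy as np


def to_hex(centroids):
    """Round float RGB centres to 0-255 integers and format them as #rrggbb strings."""
    rgb = np.clip(np.rint(centroids), 0, 255).astype(np.uint8)
    return ["#{:02x}{:02x}{:02x}".format(*map(int, row)) for row in rgb]


centres = np.array([[53.16708662, 93.69632655, 175.47967713]])
print(to_hex(centres))  # ['#355eaf']
```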
y_pred
Out:
array([2., 2., 2., ..., 0., 0., 0.])
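Each entry of `y_pred` is the cluster index of one pixel, stored as a float, in the same row-major order that `reshape` used to flatten the image. As a sketch with tiny made-up stand-ins (a 2 x 3 image, two centres and hypothetical labels rather than the real photo), the labels can be reshaped back into a per-pixel label map, and indexing the centres with that map redraws the image using only its dominant colours.

```python
import numpy as np

# hypothetical stand-ins: a 2 x 3 image, K = 2 centres, and float labels like y_pred
height, width = 2, 3
centres = np.array([[250.0, 5.0, 5.0], [5.0, 3.0, 248.0]])
labels = np.array([0., 0., 1., 1., 0., 1.])  # one label per pixel, row-major

label_map = labels.astype(int).reshape(height, width)  # (H, W) cluster indices
recoloured = centres[label_map].astype(np.uint8)       # (H, W, 3) image in K colours
```

With the article's variables, the analogous expression would be `k.cent()[y_pred.astype(int)].reshape(img_init.shape)`, which could then be shown with `plt.imshow` after converting to `uint8`.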
# histogram bin edges 0, 1, ..., K: one bin per cluster label
label_indx = np.arange(0, k.K + 1)
label_indx
Out:
array([0, 1, 2, 3, 4, 5])
np.histogram(y_pred, bins = label_indx)
Out:
(array([565545, 172073, 559377, 593291, 183314]), array([0, 1, 2, 3, 4, 5]))
# normalize the pixel counts into the fraction of the image each colour covers
(hist, _) = np.histogram(y_pred, bins = label_indx)
hist = hist.astype("float")
hist /= hist.sum()
hist
Out:
array([0.27273582, 0.08298274, 0.26976128, 0.28611642, 0.08840374])
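Because the labels are the integers 0 to K - 1 (stored as floats), the histogram with edges 0..K is simply counting how many pixels carry each label. An equivalent sketch uses `np.bincount`; `colour_proportions` is a hypothetical helper, and `minlength` keeps a zero entry for any cluster that ended up empty, so the proportions always line up with the K centres.

```python
import numpy as np


def colour_proportions(labels, K):
    """Fraction of samples assigned to each of the K clusters."""
    counts = np.bincount(np.asarray(labels).astype(int), minlength=K)
    return counts / counts.sum()


labels = np.array([0., 0., 1., 2., 2., 2.])  # made-up labels, floats like y_pred
proportions = colour_proportions(labels, K=4)
```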
# a 50 x 300 colour bar: each colour's width is proportional to its share of the pixels
hist_bar = np.zeros((50, 300, 3), dtype = "uint8")
startX = 0
for (percent, color) in zip(hist, k.cent()):
    endX = startX + (percent * 300)  # scale the fraction to the bar width
    # thickness -1 draws a filled rectangle
    cv2.rectangle(hist_bar, (int(startX), 0), (int(endX), 50),
                  color.astype("uint8").tolist(), -1)
    startX = endX
plt.figure(figsize=(15,15))
plt.subplot(121)
plt.imshow(img_init)
plt.subplot(122)
plt.imshow(hist_bar)
plt.show()
[Figure: the original image (left) and its dominant-colour bar (right)]
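If OpenCV isn't available, the same kind of bar can be built with plain NumPy slicing. The sketch below is a variation rather than the author's method: `colour_bar` is a hypothetical helper, and it optionally orders the stripes from the largest to the smallest share, which the original loop does not do.

```python
import numpy as np


def colour_bar(proportions, centroids, height=50, width=300, sort=True):
    """Build a (height, width, 3) uint8 bar whose stripes are sized by proportion."""
    proportions = np.asarray(proportions, dtype=float)
    centroids = np.asarray(centroids, dtype=float)
    # largest share first when sort=True, otherwise keep the cluster order
    order = np.argsort(proportions)[::-1] if sort else np.arange(len(proportions))
    bar = np.zeros((height, width, 3), dtype=np.uint8)
    start = 0.0
    for j in order:
        end = start + proportions[j] * width
        # fill columns [start, end) with this cluster's colour
        bar[:, int(round(start)):int(round(end))] = centroids[j].astype(np.uint8)
        start = end
    return bar


bar = colour_bar([0.25, 0.75], [[255, 0, 0], [0, 0, 255]])
```

With the article's variables, `plt.imshow(colour_bar(hist, k.cent()))` should draw the bar with the most common colour on the left.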

Hope you enjoyed and made the most out of this article! Stay tuned for my upcoming blogs! Make sure to CLAP and FOLLOW if you find my content helpful/informative!

For the complete code implementation:

Analytics Vidhya

Analytics Vidhya is a community of Analytics and Data…


Written by Tanvi Penumudy

CS Undergrad at Bennett University

Analytics Vidhya

Analytics Vidhya is a community of Analytics and Data Science professionals. We are building the next-gen data science ecosystem https://www.analyticsvidhya.com
