# Extracting Dominant Colours from an Image using K-means Clustering from Scratch

*Extract the dominant colours from any image of your choice in less than 5 minutes from scratch!*

You’ve probably heard the phrase **“a picture is worth a thousand words.”** In our digitally advanced age, this is more accurate than ever; a wealth of information can be extracted from an image. High-level computer vision systems have allowed *self-driving cars* to recognize whether the object up ahead is a pedestrian crossing the road or a static road hazard, and Instagram filters to detect and interact with faces. These advancements stem from the fundamental approaches of machine learning.

For more on Self-Driving Cars:

A Beginner’s Guide to Reinforcement Learning and its Basic Implementation from Scratch

*Machine learning* refers to the process by which machines learn from data in order to answer questions about it. In the context of *image processing*, one application of machine learning is to process an image digitally, treating the numbers that represent its *pixels* and *colours* as data.
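To make this concrete, here is a minimal sketch (using NumPy, with a made-up 2×2 image) of how an image becomes plain numbers: each pixel is a row of three values, one per colour channel.

```python
import numpy as np

# A hypothetical 2x2 RGB image: height x width x 3 colour channels.
img = np.array([
    [[255, 0, 0], [0, 255, 0]],    # red pixel, green pixel
    [[0, 0, 255], [255, 255, 0]],  # blue pixel, yellow pixel
], dtype=np.uint8)

print(img.shape)  # → (2, 2, 3)

# Flatten to a list of pixels: each row is one (R, G, B) data point.
pixels = img.reshape(-1, 3)
print(pixels.shape)  # → (4, 3)
```

This flattened pixels-as-rows view is exactly the kind of dataset a clustering algorithm can work on.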

For more on Machine Learning:

A Beginner’s Guide for Getting Started with Machine Learning

Approaches that don’t predict or assume a correct set of outputs, but instead uncover insights from a given dataset, are referred to as *unsupervised*. One such technique for *image processing* and *information extraction* is **K-means clustering**, a learning approach that aims to *partition n data points into k groups.*
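As a toy sketch of that idea, here is a single K-means iteration on six hypothetical 2-D points with k = 2: each point joins the group of its nearest centroid, and each centroid then moves to the mean of its group.

```python
import numpy as np

# Six hypothetical 2-D points forming two loose groups.
points = np.array([[1.0, 1.0], [1.2, 0.8], [0.9, 1.1],
                   [8.0, 8.0], [8.2, 7.9], [7.8, 8.1]])

# Two candidate centroids, one near each group.
centroids = np.array([[1.0, 1.0], [8.0, 8.0]])

# Assignment step: each point gets the index of its nearest centroid.
dists = np.linalg.norm(points[:, None, :] - centroids[None, :, :], axis=2)
labels = dists.argmin(axis=1)
print(labels)  # → [0 0 0 1 1 1]

# Update step: each centroid moves to the mean of its assigned points.
new_centroids = np.array([points[labels == k].mean(axis=0) for k in range(2)])
```

Repeating these two steps until the centroids stop moving is the whole algorithm, which the from-scratch implementation in this article carries out.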

For a conceptual overview of K-means Clustering, refer to —

Everything you need to know about K-Means Clustering

*We shall now begin the code walkthrough for implementing the K-means Clustering algorithm from scratch:*

Fret not! I promise you that it’s going to turn out as fascinating as it sounds!

```python
import numpy as np
import matplotlib.pyplot as plt

np.random.seed(42)


def euclidean_distance(x1, x2):
    return np.sqrt(np.sum((x1 - x2) ** 2))


class KMeans():

    def __init__(self, K=5, max_iters=100, plot_steps=False):
        self.K = K
        self.max_iters = max_iters
        self.plot_steps = plot_steps
        # list of sample indices for each cluster
        self.clusters = [[] for _ in range(self.K)]
        # the centers (mean feature vector) for each cluster
        self.centroids = []

    def predict(self, X):
        self.X = X
        self.n_samples, self.n_features = X.shape

        # initialize centroids with K randomly chosen samples
        random_sample_idxs = np.random.choice(self.n_samples, self.K, replace=False)
        self.centroids = [self.X[idx] for idx in random_sample_idxs]

        # optimize clusters
        for _ in range(self.max_iters):
            # assign samples to closest centroids (create clusters)
            self.clusters = self._create_clusters(self.centroids)
            if self.plot_steps:
                self.plot()

            # calculate new centroids from the clusters
            centroids_old = self.centroids
            self.centroids = self._get_centroids(self.clusters)

            # check if clusters have changed
            if self._is_converged(centroids_old, self.centroids):
                break
            if self.plot_steps:
                self.plot()

        # classify samples as the index of their clusters
        return self._get_cluster_labels(self.clusters)

    def _get_cluster_labels(self, clusters):
        # each sample gets the label of the cluster it was assigned to
        labels = np.empty(self.n_samples)
        for cluster_idx, cluster in enumerate(clusters):
            for sample_index in cluster:
                labels[sample_index] = cluster_idx
        return labels

    def _create_clusters(self, centroids):
        # assign the samples to the closest centroids to create clusters
        clusters = [[] for _ in range(self.K)]
        for idx, sample in enumerate(self.X):
            centroid_idx = self._closest_centroid(sample, centroids)
            clusters[centroid_idx].append(idx)
        return clusters

    def _closest_centroid(self, sample, centroids):
        # distance of the current sample to each centroid
        distances = [euclidean_distance(sample, point) for point in centroids]
        closest_index = np.argmin(distances)
        return closest_index

    def _get_centroids(self, clusters):
        # assign the mean value of each cluster to its centroid
        centroids = np.zeros((self.K, self.n_features))
        for cluster_idx, cluster in enumerate(clusters):
            cluster_mean = np.mean(self.X[cluster], axis=0)
            centroids[cluster_idx] = cluster_mean
        return centroids

    def _is_converged(self, centroids_old, centroids):
        # distances between each old and new centroid, for all centroids
        distances = [euclidean_distance(centroids_old[i], centroids[i]) for i in range(self.K)]
        return sum(distances) == 0

    def plot(self):
        fig, ax = plt.subplots(figsize=(12, 8))
        for i, index in enumerate(self.clusters):
            point = self.X[index].T
            ax.scatter(*point)
        for point in self.centroids:
            ax.scatter(*point, marker="x", color="black", linewidth=2)
        plt.show()

    def cent(self):
        return self.centroids
```

## Extracting Dominant Colours in an Image

```python
import cv2
```

```python
from skimage import io
from google.colab.patches import cv2_imshow

url = "https://www.teahub.io/photos/full/35-355143_windows-10-wallpaper-umbrella.jpg"
img = io.imread(url)
img.shape
```

Out: `(1080, 1920, 3)`

```python
img_init = img.copy()
plt.figure(figsize=(6, 6))
plt.imshow(img_init)
```

Out: `<matplotlib.image.AxesImage at 0x7f49e6a7d6a0>`

```python
# Flatten the image so that each pixel becomes one data point.
img = img.reshape((img.shape[0] * img.shape[1], img.shape[2]))

k = KMeans(K=5)  # for the 5 most dominant colours
y_pred = k.predict(img)
k.cent()
```

Out:

```
array([[ 53.16708662,  93.69632655, 175.47967713],
       [133.07051658, 195.54817432,  47.66459003],
       [237.62760178,  76.63096981,  21.42656026],
       [248.665852  ,  31.23121874, 121.13346739],
       [206.22142881, 229.36967717, 152.23724866]])
```

```python
y_pred
```

Out: `array([2., 2., 2., ..., 0., 0., 0.])`

```python
label_indx = np.arange(0, len(np.unique(y_pred)) + 1)
label_indx
```

Out: `array([0, 1, 2, 3, 4, 5])`

```python
np.histogram(y_pred, bins=label_indx)
```

Out: `(array([565545, 172073, 559377, 593291, 183314]), array([0, 1, 2, 3, 4, 5]))`

```python
# Normalize the counts so they sum to 1 (fraction of pixels per cluster).
(hist, _) = np.histogram(y_pred, bins=label_indx)
hist = hist.astype("float")
hist /= hist.sum()
hist
```

Out: `array([0.27273582, 0.08298274, 0.26976128, 0.28611642, 0.08840374])`

```python
# Draw a 50x300 bar whose segments are proportional to each colour's share.
hist_bar = np.zeros((50, 300, 3), dtype="uint8")
startX = 0
for (percent, color) in zip(hist, k.cent()):
    endX = startX + (percent * 300)  # to match grid
    cv2.rectangle(hist_bar, (int(startX), 0), (int(endX), 50),
                  color.astype("uint8").tolist(), -1)
    startX = endX
```

```python
plt.figure(figsize=(15, 15))
plt.subplot(121)
plt.imshow(img_init)
plt.subplot(122)
plt.imshow(hist_bar)
plt.show()
```
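As a small follow-up (my own addition, not part of the original walkthrough), the centroid RGB values can be turned into hex colour codes, which is often the most convenient form for design tools:

```python
def rgb_to_hex(rgb):
    """Convert an (R, G, B) triple of 0-255 values to a hex colour string."""
    r, g, b = (int(round(c)) for c in rgb)
    return "#{:02x}{:02x}{:02x}".format(r, g, b)

# Example with the first centroid printed above.
print(rgb_to_hex([53.16708662, 93.69632655, 175.47967713]))  # → #355eaf
```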

Hope you enjoyed and made the most out of this article! Stay tuned for my upcoming blogs! Make sure to **CLAP** and **FOLLOW** if you find my content helpful/informative!

*For complete code implementation:*