Machine learning applied to the design industry: K-Means for image palette generation

Maria Magdalena Balos
Feb 25, 2024

In this era of advanced technology, where Artificial Intelligence and automation have become commonplace, I find myself remembering my first encounter with technology. At that time, I was a 9-year-old child living in Romania, where agriculture was the main industry.

I was captivated by a vintage GPO rotary phone, not only for the mechanical process of dialing — inserting my tiny finger into a hole and rotating the disk to enter numbers — but also for the underlying “magic” of communication that it represented.

Photo by engin akyurt on Unsplash

Since I was young, I have seen technology change and evolve towards automation and improvement, making people’s lives easier.

Today, I find myself contributing to this cause by developing a machine learning application for the design industry. The purpose of this project is to create a website where users can upload an image and get a 10-colour palette magically extracted from it, recreating the sense of fascination I felt when I first interacted with technology.

However, unlike a magician, I will reveal how the application works!

If you’re 9 years old and want to keep the ✨ magic ✨ alive, watch the following video and skip the rest of the article.

Video showing how HueHarvest extracts a palette from the image

Project Schema

I successfully achieved the goal of this project by combining my knowledge of website design, data science, and various Python libraries. Before going through the details, let’s break the project down into simple, accessible tasks:

  • Create a website that allows the user to upload an image.
  • Acquire the image from the user.
  • Process the image to extract the ten most representative colours.
  • Display the colour palette extracted from the image to the user.

Now that we are all on the same page, let me explain how I tackled each of these tasks separately and how I combined them in the final step of this project.

User interface — website

I created a simple HTML homepage that provides users with a brief description of the website. Users can browse and upload an image from their device. Once they upload the image, they need to click on the “generate” button to extract the color palette from the image.

A screenshot of the index.html file after the image was uploaded
A screenshot of the homepage

Once the extraction process is finished, a new page loads to display the uploaded image and the extracted colour palette.

I added some styling (typography, colours, and visual hierarchy) to both the homepage and the resulting page. To make the layout responsive and arrange the elements on the screen, I used the Bootstrap grid system.

A screenshot of the colors.html file after the image was processed
A screenshot of the resulting page

Since I will be using the Flask framework in the upcoming stages of this project, I organized these files according to the Flask app folder structure shown below:

-- main_folder
    -- image_processing.py
    -- main.py
    -- static
        -- css
            -- styles.css
    -- templates
        -- colors.html
        -- index.html

The content of the HTML templates is shockingly simple. For the sake of brevity, I will skip that part, but please have a look at the templates folder in my GitHub repository for the details.

Process the image and extract the palette

With the interface in place, this section focuses on extracting the 10 most dominant colours from an image.

To achieve this, I wrote a function that takes the image path as its input argument. It returns an array with the RGB colour of every pixel, along with the image itself as an array. As you may know, an RGB image stores its colours in three channels, representing red, green, and blue respectively.

To extract the colours, I first load the image with the PIL (Pillow) library and convert it to a NumPy array. This is a 3D matrix of shape height x width x colour channels, which enables numerical operations on the image data. Then I flatten the pixels into a 2D array where each row represents a pixel and each column one of its RGB channels (shape height•width x channels).

import numpy as np
from PIL import Image


def get_colors(img_path: str):
    """Get all the colors from the image found at the given path.

    Args:
        img_path (str): The path of the image to open and extract the colors from.

    Returns:
        numpy.ndarray: All the pixel colors of the image, shape (H•W, 3).
        numpy.ndarray: The image in numpy array format, shape (H, W, 3).
    """
    # Open the image and make sure it has exactly three RGB channels
    image = Image.open(img_path).convert("RGB")

    # Convert the image to np array format
    image = np.array(image)  # shape (H, W, C)

    # Flatten the pixels: one row per pixel, one column per RGB channel
    colors = image.reshape(image.shape[0] * image.shape[1], 3)  # shape (H•W, C)

    return colors, image


colors, image = get_colors(image_path)
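The following check makes the shape changes concrete. It uses a made-up 2 x 3 pixel image built in memory rather than a file from the project. Note two details. Pillow reports size as (width, height), while the NumPy array is (height, width, channels). And reshape walks the pixels row by row.

```python
import numpy as np
from PIL import Image

# Made-up 2 x 3 RGB image (2 rows, 3 columns) built in memory instead of loaded from disk
pixels = np.array(
    [
        [[255, 0, 0], [0, 255, 0], [0, 0, 255]],
        [[0, 0, 0], [128, 128, 128], [255, 255, 255]],
    ],
    dtype=np.uint8,
)
pil_image = Image.fromarray(pixels)  # uint8 array of shape (H, W, 3) -> RGB image

print(pil_image.size)  # Pillow uses (width, height): (3, 2)

image = np.array(pil_image)    # back to NumPy: shape (H, W, C) = (2, 3, 3)
colors = image.reshape(-1, 3)  # one row per pixel: shape (H*W, C) = (6, 3)

print(image.shape, colors.shape)
print(colors[3])  # first pixel of the second row (black)
```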

Due to the large number of pixels, I decided to randomly select a sample of 5000 colors from the image and base the following analysis on them.

sample_size = min(5000, colors.shape[0])  # small images may have fewer than 5000 pixels
algorithm_colors = colors[np.random.choice(colors.shape[0], sample_size, replace=False)]

Next, I wrote a function that takes the 5000 sampled colours and trains a K-Means model to group them into 10 clusters. The centers of these 10 clusters are the most representative colours of the image. The function returns the model and the cluster centers, which we will use in the next steps. The centers form an array of shape 10 x 3, with one row per cluster and one column per RGB channel.


from sklearn.cluster import KMeans


def kmeans_model(algorithm_colors):
    """Use the colours of an image as input for a KMeans model that groups them into 10 clusters.

    Args:
        algorithm_colors (numpy.ndarray): The 5000 colours randomly sampled from the image.

    Returns:
        sklearn.cluster.KMeans: The trained KMeans model.
        numpy.ndarray: The centers of the clusters generated by the model, shape (10, 3).
    """
    # RANDOM_SEED is a constant defined elsewhere (its value is not shown here); it makes runs reproducible
    model = KMeans(init="random", n_clusters=10, random_state=RANDOM_SEED)
    model.fit(algorithm_colors)
    return model, model.cluster_centers_


model, centers = kmeans_model(algorithm_colors=algorithm_colors)
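Under the hood, K-Means alternates two steps, known as Lloyd's algorithm. First it assigns each colour to its nearest centre. Then it moves each centre to the mean of the colours assigned to it. The NumPy sketch below illustrates that loop in simplified form and is not scikit-learn's implementation, which adds options such as k-means++ initialisation, several restarts (n_init) and tolerance checks. The name `lloyd_kmeans` is hypothetical.

```python
import numpy as np


def lloyd_kmeans(points, k, n_iter=100, seed=0):
    """Simplified Lloyd's algorithm; returns (centers, labels)."""
    points = np.asarray(points, dtype=float)
    rng = np.random.default_rng(seed)
    # Random initialisation: k distinct rows of the data become the starting centres
    centers = points[rng.choice(len(points), size=k, replace=False)]

    def assign(current):
        # Squared Euclidean distance from every point to every centre, shape (n, k)
        dists = ((points[:, None, :] - current[None, :, :]) ** 2).sum(axis=-1)
        return dists.argmin(axis=1)  # index of the nearest centre for each point

    for _ in range(n_iter):
        labels = assign(centers)
        # Update step: move each centre to the mean of its points;
        # a centre with no points keeps its previous position
        new_centers = np.array([
            points[labels == j].mean(axis=0) if np.any(labels == j) else centers[j]
            for j in range(k)
        ])
        if np.allclose(new_centers, centers):
            break  # converged: the centres stopped moving
        centers = new_centers
    return centers, assign(centers)
```

With the 5000 sampled colours and k=10, this would return a (10, 3) array comparable to `model.cluster_centers_`. The exact values depend on the random initialisation, which is the sensitivity that scikit-learn's `n_init` restarts are meant to reduce.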

At this point I had already reached the goal of this section. However, when I plotted the palette, the colours had no order and looked messy. So I went a bit further and wrote a function that orders the colours so that the total Euclidean distance between neighbouring colours is as small as possible. This function returns the same list of predominant colours, now sorted.

import itertools

import numpy as np


def order_palette(palette_rgb):
    """Order the colors so that the resulting palette is more useful and visually pleasant.

    Args:
        palette_rgb (list): The chosen colors (RGB tuples) that represent the image.

    Returns:
        list of tuples: A list of 10 tuples representing the final color palette.
    """
    # Find all possible permutations (orderings) of the RGB palette
    permutations = np.array(list(itertools.permutations(palette_rgb)))  # shape (n!, n, 3)

    # Sum the Euclidean distances between adjacent colors for every permutation
    steps = permutations[:, :-1] - permutations[:, 1:]
    distances = np.sqrt((steps ** 2).sum(axis=-1)).sum(axis=-1)
    index_min = distances.argmin()  # permutation with the smallest total distance

    # Keep the permutation with the minimum total distance
    ordered_palette = permutations[index_min]
    return [tuple(color) for color in ordered_palette.tolist()]


# Convert the float cluster centers into integer RGB tuples before ordering them
palette_rgb = [tuple(int(round(c)) for c in center) for center in centers]
palette_rgb = order_palette(palette_rgb=palette_rgb)

Note that this method is very inefficient. It evaluates every possible ordering of the colours (10! = 3,628,800 of them) and keeps all of them in memory at once. I decided to keep it as is for the sake of simplicity. As a next step, I plan to look into shortest path algorithms to make it more efficient.
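Finding the ordering with the smallest total distance is a shortest Hamiltonian path problem, a close relative of the travelling salesman problem, so exact solutions get expensive quickly. A common cheap alternative is a greedy nearest-neighbour heuristic. It starts from one colour and repeatedly jumps to the closest unused colour, trying every starting colour and keeping the shortest result. The sketch below is a hypothetical `order_palette_greedy` that is not part of the project. For 10 colours it needs only a few hundred distance computations instead of 10! permutations. Unlike the brute-force version, it is not guaranteed to find the optimal order.

```python
import numpy as np


def path_length(palette):
    """Total Euclidean distance between consecutive colours."""
    arr = np.asarray(palette, dtype=float)
    return float(np.sqrt(((arr[1:] - arr[:-1]) ** 2).sum(axis=1)).sum())


def order_palette_greedy(palette_rgb):
    """Greedy nearest-neighbour ordering, trying every colour as the starting point."""
    palette = [tuple(color) for color in palette_rgb]
    if not palette:
        return []
    arr = np.asarray(palette, dtype=float)
    best_order, best_len = None, float("inf")
    for start in range(len(palette)):
        order = [start]
        remaining = set(range(len(palette))) - {start}
        while remaining:
            last = arr[order[-1]]
            # Jump to the closest colour that has not been used yet
            nxt = min(remaining, key=lambda i: np.linalg.norm(arr[i] - last))
            order.append(nxt)
            remaining.remove(nxt)
        length = path_length(arr[order])
        if length < best_len:
            best_order, best_len = order, length
    return [palette[i] for i in best_order]
```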

The last step is to transform the RGB colours into HEX colours, for which I wrote the following function:

def rgb2hex(color_rgb):
    """Transform an RGB color into a HEX color.

    Args:
        color_rgb (tuple): A color as three integers (red, green, blue) between 0 and 255.

    Returns:
        str: The color in HEX format, e.g. "#ff8000".
    """
    return "#%02x%02x%02x" % color_rgb


palette_hex = [rgb2hex(color) for color in palette_rgb]
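The `%02x` format used by rgb2hex only accepts integers, but `model.cluster_centers_` holds floats. The centres therefore have to be rounded to integers at some point before conversion. The sketch below uses a hypothetical helper, `centers_to_hex`, that does this in one place. It rounds each channel, clips it to the valid 0 to 255 range, and writes it as two hexadecimal digits.

```python
import numpy as np


def centers_to_hex(centers):
    """Convert K-Means centres (array-like of shape (k, 3), floats) to '#rrggbb' strings."""
    # Round to the nearest integer and keep every channel inside the valid 0-255 range
    rgb = np.clip(np.rint(np.asarray(centers, dtype=float)), 0, 255).astype(int)
    # Each channel becomes two lowercase hex digits, e.g. 255 -> "ff", 16 -> "10"
    return ["#{:02x}{:02x}{:02x}".format(*color) for color in rgb.tolist()]


print(centers_to_hex([[255.0, 127.6, 0.2], [-3.0, 300.0, 16.0]]))
# ['#ff8000', '#00ff10']
```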

Other functions created

Although they are not part of the final application, I also wrote a couple of functions to help me experiment and understand the problem better. Here is a function that takes the hexadecimal codes and plots the colour palette.

import matplotlib.pyplot as plt
import seaborn as sns


def plot_palette(hex_palette):
    """Plot and save the palette image.

    Args:
        hex_palette (list): The list of HEX colors that make up the final palette.
    """
    palette = sns.color_palette(hex_palette)
    sns.palplot(palette)
    plt.title("Color palette")
    # Label the plot before saving so the HEX codes appear in the saved image
    plt.xlabel(" ".join(hex_palette))
    plt.savefig("./day_92/static/uploads/img_palette.jpg")
    plt.show()


plot_palette(palette_hex)

And here is a function that compresses the image. It uses the k-means model to predict which of the 10 clusters each pixel belongs to, then replaces the pixel's colour with the centroid of that cluster.

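import numpy as np


def compress_img(image, model):
    """Compress the image by replacing each pixel with the centroid of its cluster.

    Args:
        image (numpy.ndarray): The image used to get the colors and train the model, shape (H, W, 3).
        model (sklearn.cluster.KMeans): The trained model.

    Returns:
        numpy.ndarray: A compressed version of the image with at most 10 distinct colors.
    """
    image_flat = image.reshape(-1, 3)  # one row per pixel
    labels = model.predict(image_flat)  # cluster index of every pixel
    new_colors = model.cluster_centers_[labels]  # centroid color of every pixel
    # Round back to 8-bit integers so the result can be displayed or saved like the original
    image_compressed = new_colors.round().astype(np.uint8).reshape(image.shape)
    return image_compressed


img_compressed = compress_img(image, model)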

Example of image compression: the original image is on the left and the compressed one on the right.

A collage of a compressed image
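compress_img reduces the number of distinct colours. However, the array it returns still stores three 8-bit channels per pixel. The file only gets smaller if it is saved in a palette-based form, with one small index per pixel plus a table of the 10 colours, as GIF or palette-mode PNG do. The sketch below gives a rough back-of-the-envelope estimate using a hypothetical helper, `palette_storage_bits`. It ignores file headers and any further lossless compression.

```python
import math


def palette_storage_bits(height, width, n_colors=10, channels=3, bits_per_channel=8):
    """Return (raw_bits, indexed_bits) for an image of height x width pixels."""
    raw_bits = height * width * channels * bits_per_channel
    # Smallest whole number of bits that can address n_colors palette entries
    index_bits = max(1, math.ceil(math.log2(n_colors)))
    # One index per pixel plus the palette table itself
    indexed_bits = height * width * index_bits + n_colors * channels * bits_per_channel
    return raw_bits, indexed_bits


raw, indexed = palette_storage_bits(1000, 1000)
print(raw, indexed, round(raw / indexed, 2))  # 24000000 4000240 6.0
```

With 10 colours, each pixel needs a 4-bit index instead of 24 bits of RGB, so the pixel data shrinks roughly sixfold.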

Get the image and return the palette to the user

Now it is time to put everything together. The remaining tasks are getting the image from the user and displaying the palette, for which I used the Flask Python framework.

Import the libraries and initialize the Flask application

from flask import Flask, render_template, request, send_from_directory, url_for
import os

from image_processing import extract_palette_from_img

app = Flask(__name__)

The @app.route("/") decorator above the home() function maps the root URL to the home page. The home() function returns the rendered index.html template: the main page, where users are asked to upload an image.

@app.route("/")
def home():
    return render_template("index.html")

The get_image_palette() function handles the /palette URL (see the decorator). When the user submits the upload form, the browser sends the image to this URL through a POST request, and the response shows the resulting palette.

This function first checks that the request method is "POST". It then retrieves the uploaded image from the request, saves it to the upload folder and builds a URL for the saved image.

Next, it calls extract_palette_from_img(), a function imported from image_processing.py that runs the previously introduced processing functions in turn and returns the palette.

Finally, it renders the colors.html template to show the extracted colour palette alongside the image.

@app.route("/palette", methods=["POST"])
def get_image_palette():
    if request.method == "POST":
        # The file input in the upload form is named "img"
        image = request.files["img"]
        path = os.path.join(app.config["UPLOAD"], image.filename)
        image.save(path)
        # URL the browser can use to display the saved image
        image_path = url_for("uploaded_file", filename=image.filename)
        colours = extract_palette_from_img(path)
        return render_template("colors.html", image_path=image_path, colours=colours)
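get_image_palette() relies on two things that are not shown above: an "UPLOAD" entry in app.config and an endpoint named uploaded_file for url_for(). The sketch below shows one possible way to provide them. The folder location and the /uploads/<path:filename> URL are assumptions, not the project's actual configuration. The helper `upload_path` is hypothetical. It uses werkzeug's `secure_filename`, which strips path components such as "../" from user-supplied names before they are joined into a path.

```python
import os
import tempfile

from flask import Flask, send_from_directory
from werkzeug.utils import secure_filename

app = Flask(__name__)

# Assumed location for uploads; the key name "UPLOAD" matches the code above
app.config["UPLOAD"] = os.path.join(tempfile.gettempdir(), "hueharvest_uploads")
os.makedirs(app.config["UPLOAD"], exist_ok=True)


def upload_path(filename):
    """Hypothetical helper: build a safe path inside the upload folder."""
    return os.path.join(app.config["UPLOAD"], secure_filename(filename))


@app.route("/uploads/<path:filename>")
def uploaded_file(filename):
    # Serve a saved upload; send_from_directory refuses paths outside the folder
    return send_from_directory(app.config["UPLOAD"], filename)
```

With this in place, get_image_palette() could call upload_path(image.filename) instead of joining image.filename directly. It should also pass the same sanitised name to url_for(), so that the displayed image matches the saved file.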

Learnings and Results

If you are curious about the code and want to try it yourself, follow the link below to the repository:

Link to the HueHarvest project on GitHub

Possible improvements for this project

  • allow the users to download the colour palette in a file.
  • provide the users with a way of copying all the colours at once.
  • handle errors and exceptions in the code.
  • give the user the possibility to produce a k-means compressed image.
  • optimize the color palette sorting algorithm.
  • allow the user to select the number of colours in the palette.

Conclusion

This project was part of 100 Days of Code, a course I am following to gain experience with Python. I am a UX designer by training, and after realising how disconnected this field is from automation, data science and analytics, I decided to delve into this learning adventure. I am about to finish Datamecum, an intensive data science program that has helped me a lot in this process.

This implementation is just a small example of what can be done at the intersection of graphic design and data science. If you like it, be brave and come on board: it’s a great learning experience.

If you have any questions, ideas for improvement or any other feedback, I’m happy to chat about it. Thank you for taking the time to read this article.
