Image Similarity Search in PyTorch

A Simple Image Search Engine.

Aditya Oke
PyTorch
Sep 2, 2020


This blog post walks you through how to create a simple image similarity search engine using PyTorch. This tutorial is great for machine learning beginners who are interested in computer vision.

The Image Similarity Search Problem

The problem is simple. A user gives us an image, and we have a large set of images available to us. We want to find the images in that set that are most similar to the given one.

The Image Similarity Problem

To search over images, we first need to understand how we can `learn` about images. If our algorithm understands what images look like, it can find similar ones. Hence, our first challenge is learning to represent these images.

Learning to learn with Auto-Encoders

Suppose we learn to reconstruct an image. To reconstruct an image, we need to learn what the image looks like. This notion is captured by an `encoding network` called a convolutional encoder. The convolutional encoder converts images into feature representations, and these `feature representations` help to recognize images.

To reconstruct an image, we then need to convert these `feature representations` back into the original image. To achieve this, we use a `decoding network` called a convolutional decoder. The convolutional decoder reconstructs an image from its feature representation.

In short, these two networks work in cooperation. One learns how an image can be transformed into features, while the other learns how those features can be converted back into the original image. Because they are trained together to reconstruct the input, each helps the other learn.

Basic Idea behind Auto-Encoder

Learning suitable representations of data in this way is called `representation learning`: we aim to find `representations` or `features` that describe our data well. We do not label the images in the dataset; we simply have hundreds or thousands of images from which we wish to recommend similar ones. Hence, this is an `unsupervised learning method`.
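Here is a minimal sketch that makes the reconstruction objective concrete. It uses a deliberately tiny encoder and decoder, and the layer sizes are our own illustrative choices, not the models described later. The loss compares the reconstruction with the input image itself, so no labels appear anywhere.

```python
import torch
import torch.nn as nn

# A deliberately tiny autoencoder (illustrative sizes): encoder compresses, decoder reconstructs.
encoder = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2))
decoder = nn.Sequential(nn.ConvTranspose2d(8, 3, kernel_size=2, stride=2), nn.Sigmoid())

images = torch.rand(4, 3, 32, 32)   # a batch of unlabeled images
features = encoder(images)          # learned representation: (4, 8, 16, 16)
reconstruction = decoder(features)  # back to image space: (4, 3, 32, 32)

# No labels anywhere: the input image is its own training target.
loss = nn.functional.mse_loss(reconstruction, images)
loss.backward()  # gradients flow into both the decoder and the encoder
print(features.shape, reconstruction.shape)
# torch.Size([4, 8, 16, 16]) torch.Size([4, 3, 32, 32])
```

A single loss drives both networks. This is the sense in which the encoder and decoder "help each other": the decoder can only reconstruct well if the encoder keeps the information that matters.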

From Idea to Code using PyTorch

Let us convert these ideas to code. PyTorch makes it very simple to do so. We will create our `dataset` class and the `model` for training.

The Dataset Class

Here is a simple dataset class that turns all the images in a folder into a PyTorch dataset.

Simple dataset from folder

The dataset returns each image twice: one copy is the input to our model, and the other is the target that the reconstructed image is compared against.
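A minimal sketch of such a folder dataset follows. It makes several assumptions of its own. The class name `FolderDataset` is hypothetical. PIL loads the images, and every image is resized to a fixed size (512 × 512 by default, our choice). Pixels are converted to float tensors in [0, 1] without torchvision transforms.

```python
import os

import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset


class FolderDataset(Dataset):
    """Loads every image in a folder and returns (input, target) pairs."""

    def __init__(self, main_dir, size=(512, 512)):
        self.main_dir = main_dir
        self.size = size  # (width, height) passed to PIL's resize; assumed value
        # Sorted so that dataset index i always refers to the same file.
        self.all_imgs = sorted(os.listdir(main_dir))

    def __len__(self):
        return len(self.all_imgs)

    def __getitem__(self, idx):
        img_path = os.path.join(self.main_dir, self.all_imgs[idx])
        image = Image.open(img_path).convert("RGB").resize(self.size)
        # HWC uint8 pixels -> CHW float32 tensor in [0, 1]
        array = np.asarray(image, dtype=np.float32) / 255.0
        tensor_image = torch.from_numpy(array).permute(2, 0, 1)
        # The same tensor is both the model input and the reconstruction target.
        return tensor_image, tensor_image
```

Usage would look like `FolderDataset("path/to/images")`, where the path is a placeholder. Sorting the file names keeps the index stable, which matters later when stored embeddings are mapped back to files.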

The Model

Our encoder model is a stack of repeated convolution, ReLU, and max-pooling layers.

Encoder Model in PyTorch

The encoder thus converts our input image into a feature representation of size (1, 256, 16, 16): a batch of one image with 256 channels on a 16 × 16 grid. This shape can be calculated by passing a dummy image through the encoder. The feature representation serves as the input to our decoder model.
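Below is a sketch of such an encoder, along with the dummy-image check. It assumes a 3 × 512 × 512 input and five Conv → ReLU → MaxPool blocks with channel widths 16, 32, 64, 128, 256. These are our choices: one configuration consistent with the stated output size, since each pool halves the resolution and 512 / 2⁵ = 16. The class name `ConvEncoder` is hypothetical.

```python
import torch
import torch.nn as nn


class ConvEncoder(nn.Module):
    """Repeated Conv -> ReLU -> MaxPool blocks; each block halves height and width."""

    def __init__(self, channels=(3, 16, 32, 64, 128, 256)):
        super().__init__()
        layers = []
        for c_in, c_out in zip(channels[:-1], channels[1:]):
            layers += [
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),  # keeps H and W
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=2, stride=2),  # halves H and W
            ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)


encoder = ConvEncoder()
with torch.no_grad():
    dummy = torch.randn(1, 3, 512, 512)  # one fake RGB image
    print(encoder(dummy).shape)  # torch.Size([1, 256, 16, 16])
```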

Decoder Model in PyTorch

The decoder takes the feature representations as input and reconstructs the image from them. We upsample the feature representations back to the original image size using transposed convolution layers with kernel size (2, 2) and stride (2, 2); each such layer doubles the height and width.
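The following sketch mirrors the channel widths of a five-block encoder (256 → 128 → 64 → 32 → 16 → 3). Five transposed convolutions with kernel (2, 2) and stride (2, 2) take a 16 × 16 grid back to 16 × 2⁵ = 512 pixels per side. The ReLU activations between blocks and the final Sigmoid (which keeps pixels in [0, 1]) are our own assumptions. The class name `ConvDecoder` is hypothetical.

```python
import torch
import torch.nn as nn


class ConvDecoder(nn.Module):
    """Each ConvTranspose2d(kernel 2, stride 2) doubles height and width."""

    def __init__(self, channels=(256, 128, 64, 32, 16, 3)):
        super().__init__()
        layers = []
        pairs = list(zip(channels[:-1], channels[1:]))
        for i, (c_in, c_out) in enumerate(pairs):
            layers.append(nn.ConvTranspose2d(c_in, c_out, kernel_size=(2, 2), stride=(2, 2)))
            # ReLU between blocks; Sigmoid on the last block keeps pixels in [0, 1] (assumption).
            layers.append(nn.ReLU(inplace=True) if i < len(pairs) - 1 else nn.Sigmoid())
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)


decoder = ConvDecoder()
with torch.no_grad():
    features = torch.randn(1, 256, 16, 16)  # stands in for an encoder output
    print(decoder(features).shape)  # torch.Size([1, 3, 512, 512])
```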

Training and Saving the Feature Representations

Training our image similarity model is simple. We create the PyTorch `dataset` and the `dataloaders`. To measure the difference between the reconstructed image and the original image, we use the Mean Squared Error (MSE) loss, which averages the squared pixel-wise differences between the two.
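Here is a sketch of this setup. Random tensors stand in for the folder dataset. The 15/5 train/validation split and the batch size are hypothetical. The last lines confirm that `nn.MSELoss` is the mean of the squared element-wise differences.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset, random_split

# Random tensors stand in for the folder dataset: input and target are the same images.
images = torch.rand(20, 3, 64, 64)
full_dataset = TensorDataset(images, images)

# Hypothetical 75/25 train/validation split, seeded for reproducibility.
train_dataset, val_dataset = random_split(
    full_dataset, [15, 5], generator=torch.Generator().manual_seed(0)
)
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=4)

loss_fn = nn.MSELoss()  # mean of squared element-wise differences

reconstruction = torch.zeros(2, 3, 4, 4)
original = torch.ones(2, 3, 4, 4)
# Every element differs by 1, so the mean squared error is exactly 1.0.
print(loss_fn(reconstruction, original).item())  # 1.0
print(((reconstruction - original) ** 2).mean().item())  # 1.0, computed by hand
```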

Simplified Training

Our train step and validation step functions are simple. We feed the training image to the encoder, and the encoder's output goes through the decoder. The reconstructed image is compared with the original to calculate the loss, which we return.
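Below is a sketch of what such step functions could look like. Several details are our own assumptions: the function names and signatures, one optimizer (Adam) over both networks' parameters, and returning the mean batch loss per epoch. The tiny stand-in networks are used only to make the example runnable, and the same loader is reused for validation for brevity.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def train_step(encoder, decoder, train_loader, loss_fn, optimizer, device):
    """One training epoch; returns the mean reconstruction loss over batches."""
    encoder.train()
    decoder.train()
    total = 0.0
    for train_img, target_img in train_loader:
        train_img, target_img = train_img.to(device), target_img.to(device)
        optimizer.zero_grad()
        enc_output = encoder(train_img)   # image -> features
        dec_output = decoder(enc_output)  # features -> reconstructed image
        loss = loss_fn(dec_output, target_img)
        loss.backward()
        optimizer.step()
        total += loss.item()
    return total / len(train_loader)


@torch.no_grad()
def val_step(encoder, decoder, val_loader, loss_fn, device):
    """Evaluates the reconstruction loss without updating any weights."""
    encoder.eval()
    decoder.eval()
    total = 0.0
    for val_img, target_img in val_loader:
        val_img, target_img = val_img.to(device), target_img.to(device)
        total += loss_fn(decoder(encoder(val_img)), target_img).item()
    return total / len(val_loader)


device = "cuda" if torch.cuda.is_available() else "cpu"
# Tiny stand-in networks: 3x32x32 -> 8x16x16 -> 3x32x32.
encoder = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2)).to(device)
decoder = nn.Sequential(nn.ConvTranspose2d(8, 3, 2, stride=2), nn.Sigmoid()).to(device)
images = torch.rand(8, 3, 32, 32)
loader = DataLoader(TensorDataset(images, images), batch_size=4)
params = list(encoder.parameters()) + list(decoder.parameters())
optimizer = torch.optim.Adam(params, lr=1e-3)
loss_fn = nn.MSELoss()
for epoch in range(2):
    train_loss = train_step(encoder, decoder, loader, loss_fn, optimizer, device)
    val_loss = val_step(encoder, decoder, loader, loss_fn, device)
    print(f"epoch {epoch}: train {train_loss:.4f}, val {val_loss:.4f}")
```

A single optimizer over both parameter lists is one simple way to train the two networks jointly. Two separate optimizers stepped together would also work.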

Simple Training and Validation Steps

Finally, we save the feature representations of all images in the dataset. These are called image embeddings. We store them in NumPy `.npy` format; they serve as the `image index` that we search for similar images.
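A sketch of building and saving the embeddings follows. Each image's (C, H, W) feature map is flattened into one row, so a (1, 256, 16, 16) representation would become a vector of 65,536 values. The function name `create_embedding`, the tiny stand-in encoder, and the file name are hypothetical. The loader must not shuffle, so that row i of the index corresponds to image i.

```python
import os
import tempfile

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


@torch.no_grad()
def create_embedding(encoder, full_loader, device):
    """Runs every image through the encoder and stacks the flattened features."""
    encoder.eval()
    chunks = []
    for img, _ in full_loader:
        chunks.append(encoder(img.to(device)).cpu())
    # (N, C, H, W) -> (N, C*H*W): one row vector per image
    return torch.cat(chunks, dim=0).flatten(start_dim=1)


# Tiny stand-in encoder and data: 3x32x32 images -> 8x16x16 features.
encoder = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2))
images = torch.rand(10, 3, 32, 32)
full_loader = DataLoader(TensorDataset(images, images), batch_size=4, shuffle=False)

embedding = create_embedding(encoder, full_loader, device="cpu")
print(embedding.shape)  # torch.Size([10, 2048])

path = os.path.join(tempfile.gettempdir(), "data_embedding.npy")  # hypothetical file name
np.save(path, embedding.numpy())
print(np.load(path).shape)  # (10, 2048)
```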

Searching for Similar Images

Great! Now we have feature representations (embeddings) for our complete dataset, and we need to search them for images similar to a query. The query image can be converted to a feature representation using the same encoder network. For a moment, let us think of these `feature representations` as points.

What we need to find are the `points closest to a given point`, as illustrated below.

Consider each point as Feature Representation

Recalling our machine learning basics, one way of finding them is K-Nearest Neighbors, where “K” is the number of similar images the user requires.

The final missing piece, Nearest Neighbors Search
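To ground the "points" picture, here is a brute-force K-Nearest Neighbors sketch in NumPy using made-up 2-D points and Euclidean distance. Real embeddings have many more dimensions, but the computation is the same. The helper name `k_nearest_neighbors` is hypothetical.

```python
import numpy as np


def k_nearest_neighbors(points, query, k):
    """Brute-force k-NN: indices and Euclidean distances of the k closest rows."""
    distances = np.linalg.norm(points - query, axis=1)  # one distance per point
    order = np.argsort(distances)[:k]                   # smallest distances first
    return order, distances[order]


# Five 2-D "feature representations" and a query point.
points = np.array([[0.0, 0.0], [1.0, 1.0], [5.0, 5.0], [0.5, 0.0], [9.0, 9.0]])
query = np.array([0.0, 0.1])

indices, dists = k_nearest_neighbors(points, query, k=2)
print(indices)  # [0 3]
```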

Let us put these ideas into code.
First, we convert the user’s image to an embedding using the encoder. Then we compute the similar images using the K-Nearest Neighbors algorithm.
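Here is a minimal sketch of the search step. It rests on our own assumptions: the function name `compute_similar_images` and its parameters are illustrative, the distance is Euclidean, and the search is brute force with `torch.cdist` and `torch.topk` rather than a dedicated k-NN library. The tiny encoder and random images only make the example runnable.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def compute_similar_images(image_tensor, num_images, embedding, encoder, device):
    """Returns indices of the `num_images` stored embeddings closest to the query."""
    encoder.eval()
    # (C, H, W) -> (1, C, H, W), since the encoder expects a batch; then flatten.
    query = encoder(image_tensor.unsqueeze(0).to(device)).cpu().flatten(start_dim=1)
    # Euclidean distance from the query to every stored embedding: shape (N,)
    distances = torch.cdist(query, embedding).squeeze(0)
    # Indices of the k smallest distances, closest first
    return torch.topk(distances, k=num_images, largest=False).indices.tolist()


encoder = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2))
images = torch.rand(10, 3, 32, 32)
with torch.no_grad():
    embedding = encoder(images).flatten(start_dim=1)  # (10, 2048) image index

# Querying with a dataset image should return that image first (distance ~0).
print(compute_similar_images(images[3], 3, embedding, encoder, "cpu")[0])  # 3
```

If scikit-learn is available, `sklearn.neighbors.NearestNeighbors(n_neighbors=k).fit(embedding)` followed by `.kneighbors(query)` does the same job and also supports other metrics, such as cosine distance.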

Searching For Similar Images

Voilà! We are done!

We have successfully found images similar to a given image using our image similarity system!

Let us have a look at some outputs. I have trained this on a dataset containing images of animals.

Certainly not bad for our simple image search system.

Final Thoughts

We built a basic image search system from scratch. There are multiple ways to achieve this task. For example, one can use pre-trained models such as ResNet or VGG as feature extractors; these models can directly produce feature representations.
Also, if you are looking for a production-ready system, you can use the following libraries or techniques.

FAISS: A library from Facebook AI Research for efficient similarity search and clustering of dense vectors, such as image embeddings. You can find more information about it here. It is an advanced, state-of-the-art, open-source implementation that is highly scalable.

Deep Ranking: Another technique for image similarity. It formulates the problem differently, comparing three images (a triplet) at a time: a query image, a similar (positive) image, and a dissimilar (negative) image. You can read more about it in this paper.
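To illustrate the triplet idea, here is a sketch of a triplet hinge loss using PyTorch's built-in `nn.TripletMarginLoss`. This is the general triplet technique, not the exact formulation from the Deep Ranking paper, which differs in its details. The embeddings are random and hypothetical. The loss is zero once the positive is closer to the query than the negative by at least the margin.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical embeddings for a batch of 4 triplets, 128 dimensions each.
query = torch.randn(4, 128)
positive = query + 0.1 * torch.randn(4, 128)  # near the query: "similar" images
negative = torch.randn(4, 128)                # unrelated: "dissimilar" images

# Per triplet: max(d(query, positive) - d(query, negative) + margin, 0), then averaged.
triplet_loss = nn.TripletMarginLoss(margin=1.0, p=2)
print(triplet_loss(query, positive, negative).item())
```

Swapping the roles of `positive` and `negative` produces a large loss. During training, that gradient pushes the embedding network to pull similar images together and push dissimilar ones apart.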

The code and documentation for this blog post are here.

Citations

The images are taken from these sources.

  1. Image similarity problem: thanks to Europeana Pro.
  2. Basic idea behind auto-encoder: taken from this Medium blog post by HackerNoon.
  3. K-Nearest Neighbors feature representations: adapted from this website.
