Building multimodal image search with PyTorch (part 1)

Mikhail Korotkov
Jan 25, 2023


Cross-modal retrieval

The problem of cross-modal retrieval is becoming more and more popular every day. We mostly use text to specify our requests, as it is the most convenient and compact type of information for a human. At the same time, this text description can serve as a search request for data of a completely different modality, such as images or video, or an even more complex structure, such as an Instagram post that contains both an image and text.

Recent models for cross-modal retrieval have benefited from an increasingly rich understanding of visual scenes, afforded by scene graphs and object interactions. This has resulted in an improved matching between the visual representation of an image and the textual representation of its caption. Yet, current visual representations overlook a key aspect: the text appearing in images, which may contain crucial information for retrieval.

In this article, I will discuss some important aspects of how a cross-modal retrieval pipeline can be framed and designed, and show how to build an example model for multimodal image retrieval, both image-to-text and text-to-image, with an extra block for image captioning.

Text-to-Image search in general

As mentioned before, we tackle the problem of multimodal information retrieval. Given one view (image or text), we want to retrieve the most relevant item of the other view from a database. Mathematically, we want to build a system that computes a similarity measure between an image and a text. To do so, we use the general machine learning framework: from a dataset of aligned images and textual descriptions, we try to learn a way to compute the similarity between those views.

For this purpose, we will need:

  • Dataset with image-caption pairs
  • Model pipeline designed to map input data of different modalities to a single vector space
  • Similarity metric for computing distance between embeddings
  • Loss for training our pipeline
  • Setup for efficient search through calculated embeddings

In this part of the article, I will cover the overall task definition, from the dataset to the different types of loss that can be applied. In Part 2, I will dive deeper into specific encoder architectures and search setups.

Let’s start with the dataset.

Dataset

We will use the CTC dataset proposed in the StacMR paper, which allows exploration of cross-modal retrieval where images contain scene-text instances. For evaluation purposes, the StacMR authors define two test splits. The first one, referred to as CTC-1K, is a subset of CTC-explicit. The second test set, CTC-5K, contains the 1,000 explicit images of CTC-1K plus 4,000 non-explicit images. The remaining 738 explicit plus 4,945 non-explicit images are used for training and validation.

Full pipeline architecture

You may have heard of the STARNet (Scene-Text Aware Retrieval Network) model. It is composed of the following modules: a joint encoder for an image and its scene text, a caption encoder for the image caption, and an additional caption generation module.

STARNet architecture

Unfortunately, this architecture has several drawbacks, such as its reliance on the Google OCR API, excessive GRU blocks and the slow two-stage Faster R-CNN detector. Below, we will build a similar architecture but avoid the Google OCR API, as well as precomputed features for the image encoder. This will make the pipeline easy to reproduce and improve. We will also add more advanced methods for analysing images and scene text, and make the whole pipeline ready for end-to-end inference.

In cross-modal retrieval in general, and text-image retrieval in particular, the most common method is to transform both entities into a shared vector space where the embeddings of images and their text captions end up close in terms of some distance metric (such as cosine or Euclidean). So, what do we need to create such embeddings?

First, we need an encoder for images and an encoder for image captions to convert our data into a vector space. Then we need a loss function to train these embedding pairs to be close for relevant captions and distant for irrelevant ones. We may also need an additional image captioning head if we want our model to be able to generate captions for a given image.
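A minimal sketch of this two-encoder idea, assuming precomputed image feature vectors and tokenised captions. The class name `ToyDualEncoder`, the layer sizes and the mean-pooled `nn.EmbeddingBag` text branch are hypothetical placeholders, not the encoders used later in this series; the point is only that both branches project into the same dimension and are L2-normalised, so image and caption embeddings can be compared directly.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyDualEncoder(nn.Module):
    """Hypothetical two-branch encoder mapping images and captions into one space."""

    def __init__(self, image_feat_dim=2048, vocab_size=10000, embed_dim=256):
        super().__init__()
        # image branch: project precomputed image features into the shared space
        self.image_proj = nn.Linear(image_feat_dim, embed_dim)
        # text branch: average the token embeddings, then project into the shared space
        self.token_embed = nn.EmbeddingBag(vocab_size, embed_dim, mode="mean")
        self.text_proj = nn.Linear(embed_dim, embed_dim)

    def encode_image(self, image_feats):
        # image_feats: (batch_size, image_feat_dim) -> unit-length (batch_size, embed_dim)
        return F.normalize(self.image_proj(image_feats), dim=-1)

    def encode_text(self, token_ids):
        # token_ids: (batch_size, seq_len) integers; each row is treated as one bag
        return F.normalize(self.text_proj(self.token_embed(token_ids)), dim=-1)


model = ToyDualEncoder()
images = torch.randn(4, 2048)                   # stand-in for CNN image features
captions = torch.randint(0, 10000, (4, 12))     # stand-in for tokenised captions
img_emb = model.encode_image(images)            # (4, 256)
txt_emb = model.encode_text(captions)           # (4, 256)
similarity = img_emb @ txt_emb.t()              # (4, 4) image-caption cosine similarities
```

Because both outputs are unit-length, the matrix product gives cosine similarities, which is one of the metrics discussed below.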

This is how the whole pipeline will look:

Modified Scene Text Aware Cross Modal Retrieval architecture (MStacMR)

Before diving deeper into specific encoder architectures, let's discuss what types of loss functions can be applied in this domain, what Contrastive Loss and Triplet Ranking Loss are, and how they can be used for cross-modal retrieval.

Loss functions for metric-learning

Several real-world applications in industry, ranging from face recognition to object detection and from POS tagging to document ranking in NLP, are formulated as multi-class classification problems. At first glance, an image retrieval pipeline looks quite similar to an image classification task in which different captions act as different classes. The most common way to deal with a classification task is the softmax function.

Softmax

The softmax function takes a vector and turns it into numbers in the range 0 to 1 that sum to 1. Another nice property of softmax is that, because it exponentiates its inputs, it amplifies the differences between them, so the largest value usually ends up much bigger than the others. When calculating the categorical cross-entropy loss, the first step is to take the softmax of the values, and then the negative log of the probability of the labelled category.

Softmax
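A small PyTorch check of those two steps, using arbitrary example logits and label: softmax turns the logits into probabilities that sum to 1, and categorical cross-entropy is the negative log of the probability assigned to the labelled class, which is exactly what `F.cross_entropy` computes directly from logits.

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 1.0, 0.1]])  # one sample, three classes (arbitrary values)
label = torch.tensor([0])                 # index of the correct class

# softmax(z)_i = exp(z_i) / sum_j exp(z_j)
probs = F.softmax(logits, dim=-1)

# step-by-step cross-entropy: negative log of the labelled class probability
manual_ce = -torch.log(probs[0, label[0]])

# built-in version: fused log-softmax + negative log-likelihood on raw logits
builtin_ce = F.cross_entropy(logits, label)
```

In practice the fused built-in is preferred, since computing log-softmax directly is more numerically stable than taking the log of softmax outputs.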

Metric learning

Unfortunately, a typical softmax-based deep network does not help much when the number of classes in the output layer is very large: the output layer grows with the number of classes, each class gets only a few training examples, and new classes (for retrieval, new captions) cannot be handled without retraining. Instead, this kind of problem can be formulated differently. The idea is to learn distributed embedding representations of data points such that, in the high-dimensional vector space, contextually similar data points are projected into nearby regions, whereas dissimilar data points are projected far away from each other.

In this case the model is trying to solve a metric learning problem: to learn some sort of similarity metric between images and text captions. The model reorganises the input space, pulling similar images and texts together into clusters while pushing dissimilar pairs apart.

Similarity metric

In order to measure how similar two vectors are, we need a way of measuring distance. In 2 or 3 dimensions, the Euclidean distance is a great choice for measuring the distance between two points. However, in a high-dimensional space, all points tend to be far apart by the Euclidean measure, and the angle between vectors is often a more effective measure. Cosine similarity measures the cosine of the angle between two vectors: it is 1 for vectors pointing in the same direction, 0 for orthogonal vectors and -1 for opposite vectors, so more similar vectors give a larger number (cosine distance is usually defined as 1 minus this value). Cosine similarity can be computed as the dot product of the two vectors after normalising each of them to unit length.
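The relationship between cosine similarity, the dot product and cosine distance can be verified in a few lines (random vectors of an arbitrary dimension, plus three hand-picked 2-D directions): after L2-normalising both vectors, their dot product equals `F.cosine_similarity`.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
a = torch.randn(512)
b = torch.randn(512)

cos_builtin = F.cosine_similarity(a, b, dim=0)
# the dot product of unit-length vectors is the cosine of the angle between them
cos_manual = torch.dot(F.normalize(a, dim=0), F.normalize(b, dim=0))
cosine_distance = 1.0 - cos_manual

# same, orthogonal and opposite directions
e1 = torch.tensor([1.0, 0.0])
e2 = torch.tensor([0.0, 1.0])
print(F.cosine_similarity(e1, e1, dim=0).item())   # 1.0
print(F.cosine_similarity(e1, e2, dim=0).item())   # 0.0
print(F.cosine_similarity(e1, -e1, dim=0).item())  # -1.0
```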

So then, a metric-learning loss (such as Contrastive, Triplet or others) takes the network's output for an example, calculates its distance to an example of the same class (a positive) and contrasts that with its distance to examples of other classes (negatives). Said another way, the loss is low if positive samples are encoded to similar (closer) representations and negative samples are encoded to different (farther) representations.

Contrastive Loss

The general formula for Contrastive Loss is shown in Fig. 1.

Figure 1 — Generalised Contrastive Loss
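In LaTeX notation, the generalised form weights the two per-case losses by the label, with Y = 0 for similar pairs and Y = 1 for dissimilar ones:

```latex
\mathcal{L}(W, Y, X_1, X_2) = (1 - Y)\, L_S\big(D_W\big) + Y\, L_D\big(D_W\big)
```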

The Y term here specifies whether the two given data points (X₁ and X₂) are similar (Y = 0) or dissimilar (Y = 1). The Ls term in Fig. 1 stands for the loss function that should be applied to the output if the given samples are similar, and the Ld term for the loss function to apply when the given data points are dissimilar. The Dw term is the similarity (or, rather, dissimilarity) between the 2 transformed data points, defined by LeCun as follows:

Figure 2 — Distance measure between transformed data points
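Written out, with G_W denoting the network with parameters W applied to each input:

```latex
D_W(X_1, X_2) = \left\lVert G_W(X_1) - G_W(X_2) \right\rVert_2
```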

The G in this formula stands for the mapping function itself, i.e. a neural network in our case. This is a regular Euclidean distance (calculated between the outputs of the neural network), which is what LeCun used in the paper; however, you can use other metrics such as Manhattan distance, cosine similarity, etc.

The formula in Fig. 1 is highly reminiscent of the (binary) cross-entropy loss: it has the same structure. The difference is that cross-entropy is a classification loss that operates on class probabilities produced by the network independently for each sample, whereas Contrastive Loss is a metric learning loss that operates on the data points produced by the network and their positions relative to each other. This is also part of the reason cross-entropy loss is not usually used for metric learning tasks like face verification: it does not impose any constraints on the distribution of the model's inner representation of the given data, i.e. the model can learn any features regardless of whether similar data points end up close to each other after the transformation.

The exact loss function LeCun came up with is presented in Fig. 3.

Figure 3 — Actual Contrastive Loss function
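In LaTeX form, the standard margin-based contrastive loss reads as follows, where m > 0 is the margin and the first and second terms play the roles of L_S and L_D:

```latex
\mathcal{L}(W, Y, X_1, X_2) = (1 - Y)\,\frac{1}{2}\,D_W^2 + Y\,\frac{1}{2}\,\big\{\max(0,\; m - D_W)\big\}^2
```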

So Ls (the loss for similar data points) is simply (half of) the squared distance Dw: if two data points are labelled as similar, we minimise the Euclidean distance between them. Ld, on the other hand, needs some explanation. One may think that for two dissimilar data points we just need to maximise the distance between them, i.e. minimise something like 1/Dw. So why didn't LeCun just use 1/Dw?

Time for a little visualisation. Let's say we have a data point (blue dot), a couple of data points that are similar to it (black dots) and a couple that are dissimilar to it (white dots), as in Fig. 4. We would naturally like to pull the black dots closer to the blue dot and push the white dots farther away from it. Specifically, we would like to minimise the intra-class distances (blue arrows) and maximise the inter-class distances (red arrows).

Figure 4 — What we would like the algorithm to do. Notice how the white dots that were outside weren’t moved farther away from the margin.

What we would like to achieve is that, for each group of similar points (in a face recognition task, all the photos of the same person), the maximum intra-class distance is smaller than the minimum inter-class distance. This means that if we define some radius m (the margin), all the black dots should fall inside this margin and all the white dots outside it (Fig. 4). This way we could use a nearest neighbour algorithm for new data: if a new data point lies within distance m of another one, they are similar (belong to the same group). The same goes for face recognition: if a new face image is located within distance m of another one, it is likely the same person in both images.

So we need to make sure that the black dots are inside the margin m and the white dots are outside it. And that's exactly what the function proposed by LeCun does! In Fig. 5 you can see that the right part of the loss penalises the model when dissimilar data points are at a distance Dw < m. If Dw ≥ m, the expression m - Dw is non-positive, so the whole right part of the loss function is 0 due to the max() operation, and so is its gradient, i.e. we don't force dissimilar points farther away than necessary.

Figure 5 — Again, the loss function itself, so that you don’t have to scroll back.
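A minimal PyTorch sketch of this loss, assuming Euclidean distance, the Y = 0 (similar) / Y = 1 (dissimilar) convention described above, the ½D² form for both terms and a mean over the batch. The class name `ContrastiveLoss` and its default margin are illustrative choices, not a specific library's API.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ContrastiveLoss(nn.Module):
    """Margin-based contrastive loss: y = 0 for similar pairs, y = 1 for dissimilar pairs."""

    def __init__(self, margin=1.0):
        super().__init__()
        self.margin = margin

    def forward(self, emb1, emb2, y):
        # emb1, emb2: (batch_size, embedding_dim); y: (batch_size,) with values 0 or 1
        dist = F.pairwise_distance(emb1, emb2)                     # D_W for each pair
        similar_term = 0.5 * dist.pow(2)                           # L_S: pull similar pairs together
        dissimilar_term = 0.5 * F.relu(self.margin - dist).pow(2)  # L_D: zero once D_W >= margin
        y = y.float()
        return ((1 - y) * similar_term + y * dissimilar_term).mean()


loss_fn = ContrastiveLoss(margin=1.0)
x1 = torch.tensor([[0.0, 0.0], [0.0, 0.0]])
x2 = torch.tensor([[0.0, 0.0], [3.0, 4.0]])
y = torch.tensor([0, 1])  # identical similar pair; dissimilar pair 5.0 apart (beyond the margin)
loss = loss_fn(x1, x2, y)  # both pairs are already where they should be, so the loss is ~0
```

`F.pairwise_distance` adds a small `eps` (1e-6 by default) to the difference before taking the norm, which keeps the gradient finite when two embeddings coincide.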

Siamese network model architecture

Using Contrastive Loss leads us to a model architecture that is often called a Siamese network (Fig. 6):

Figure 6 — Siamese network architecture

For the same face verification task, a single convolutional neural network is applied to the 2 images, the loss is calculated on its outputs, and then back-propagation is run.

This siamese architecture consists of two copies of the function Gw, which share the same set of parameters w, and a cost module. A loss module, whose input is the output of this architecture, is placed on top of it. The input to the entire system is a pair of images (X₁, X₂) and a label Y. The images are passed through the functions, yielding two outputs Gw(X₁) and Gw(X₂). The cost module then generates the distance Dw(Gw(X₁), Gw(X₂)). The loss function combines Dw with the label Y to produce the scalar loss Ls or Ld, depending on Y. The parameters w are updated using stochastic gradient descent. The gradients can be computed by back-propagation through the loss, the cost, and the two instances of Gw. The total gradient is the sum of the contributions from the two instances.
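In PyTorch, "two copies sharing parameters" usually just means calling one module twice. The sketch below uses a hypothetical tiny MLP `g_w` in place of Gw and the margin-based contrastive loss written inline; by detaching one branch at a time it checks that the total gradient on a shared weight equals the sum of the two branches' contributions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# hypothetical G_W: a tiny MLP (a CNN would play this role for images)
g_w = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))

x1, x2 = torch.randn(5, 8), torch.randn(5, 8)  # a batch of input pairs
y = torch.randint(0, 2, (5,)).float()           # 0 = similar, 1 = dissimilar
margin = 1.0


def first_layer_grad(block_branch1=False, block_branch2=False):
    """Gradient of the contrastive loss w.r.t. the first layer's weight.

    Detaching a branch's output keeps only the other branch's contribution.
    """
    g_w.zero_grad()
    out1, out2 = g_w(x1), g_w(x2)  # same module, i.e. the same parameters w, for both inputs
    if block_branch1:
        out1 = out1.detach()
    if block_branch2:
        out2 = out2.detach()
    dist = F.pairwise_distance(out1, out2)  # cost module: D_W
    loss = ((1 - y) * 0.5 * dist.pow(2) + y * 0.5 * F.relu(margin - dist).pow(2)).mean()
    loss.backward()
    return g_w[0].weight.grad.clone()


total = first_layer_grad()
from_branch1 = first_layer_grad(block_branch2=True)
from_branch2 = first_layer_grad(block_branch1=True)
# the total gradient is the sum of the contributions from the two instances
print(torch.allclose(total, from_branch1 + from_branch2, atol=1e-6))
```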

Contrastive Loss Summary

  1. Contrastive Loss is a metric learning loss function introduced by Yann LeCun and colleagues in 2005.
  2. It operates on pairs of embeddings produced by the model and on a ground-truth similarity flag: a boolean label specifying whether the two samples are similar or dissimilar. So the input must be not one but 2 data points (two images, an image-text pair, or other).
  3. It penalises similar samples for being far from each other in terms of the chosen metric (Euclidean distance, cosine similarity or another distance metric).
  4. Dissimilar samples are penalised for being too close to each other, but in a somewhat different way: Contrastive Loss introduces the concept of a margin, a minimal distance that dissimilar points need to keep. So it penalises dissimilar samples for being closer than the given margin.
  5. It can also be applied to the cross-modal retrieval task in the following way:
    • Train the model on image + caption pairs from the given dataset, using captions of the given image as positive examples and captions of other images as negative ones.
    • To verify whether a caption is relevant to an image, first calculate the embedding of the image and the embedding of the caption you are trying to associate with it.
    • Then calculate the distance between these 2 embeddings (using the same metric you used during training).
    • If the distance is smaller than the margin used during training, the image is likely relevant to this caption; otherwise it is not.
    • To find the images most relevant to a given text caption, sort all image embeddings in the database by their distance to the caption embedding and take the closest results.
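The two usage patterns in point 5 (a threshold check and a ranking) can be sketched as follows. Random tensors stand in for embeddings produced by already trained image and caption encoders, and the margin of 1.0 and the image index 42 are arbitrary assumptions.

```python
import torch

torch.manual_seed(0)
image_db = torch.randn(1000, 128)  # embeddings of all images in the database
caption_emb = torch.randn(128)     # embedding of the query caption
margin = 1.0                       # should match the margin used during training

# Euclidean distance from the caption to every image: shape (1000,)
dists = torch.cdist(caption_emb.unsqueeze(0), image_db).squeeze(0)

# verification: is one specific image (index 42) relevant to the caption?
is_relevant = bool(dists[42] < margin)

# retrieval: the 5 closest images, nearest first
top_dists, top_idx = torch.topk(dists, k=5, largest=False)
```

For large databases, this brute-force scan is usually replaced by an approximate nearest neighbour index, which is one of the search setups Part 2 deals with.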

Contrastive loss, like triplet and magnet loss, is used to map vectors that model the similarity of input items. These mappings can support many tasks, like unsupervised learning, one-shot learning, and other distance metric learning tasks.

Triplet Ranking Loss

Triplet Loss relies on a similar concept but can be even more effective.

Triplet Loss was popularised by FaceNet: A Unified Embedding for Face Recognition and Clustering in 2015, and it has been one of the most popular loss functions for supervised similarity or metric learning ever since. In its simplest explanation, Triplet Loss encourages the distance of a dissimilar pair to exceed the distance of a similar pair by at least a certain margin value. Mathematically, the loss value can be calculated as:

Triplet Loss = max(Dw(a, p) - Dw(a, n) + m, 0)

where:

  • p (positive), is a sample that has the same label as a (anchor),
  • n (negative), is another sample that has a label different from a,
  • Dw is a function that measures the distance between two of these samples (anchor-positive or anchor-negative),
  • and m is a margin value to keep negative samples far apart.

The paper uses squared Euclidean distance, but it is equally valid to use any other distance metric, e.g., cosine distance. The learning objective of the function can be visualised as follows:
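A direct translation of the formula for a batch of (anchor, positive, negative) embeddings, assuming plain (non-squared) Euclidean distance and a mean over the batch. PyTorch ships this same form as `nn.TripletMarginLoss`, so the two can be compared; the synthetic data are arbitrary and chosen so that the loss is non-zero.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
anchor = torch.randn(16, 64)
positive = anchor + 0.10 * torch.randn(16, 64)  # close to the anchors
negative = anchor + 0.12 * torch.randn(16, 64)  # only slightly farther: violates the margin
margin = 1.0

# max(D(a, p) - D(a, n) + m, 0), averaged over the batch
d_ap = F.pairwise_distance(anchor, positive)
d_an = F.pairwise_distance(anchor, negative)
manual = F.relu(d_ap - d_an + margin).mean()

builtin = nn.TripletMarginLoss(margin=margin, p=2)(anchor, positive, negative)
```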

Triplet Loss visualisation

Notice that Triplet Loss, unlike Contrastive Loss, does not push anchor and positive samples to be encoded into the same point in the vector space. This lets Triplet Loss tolerate some intra-class variance, whereas Contrastive Loss essentially forces the distance between an anchor and any positive to 0. In other words, Triplet Loss allows clusters to stretch so as to include outliers, while still ensuring a margin between samples from different clusters, i.e., negative pairs.

Additionally, Triplet Loss is less greedy. Unlike Contrastive Loss, it is already satisfied when dissimilar samples are easily distinguishable from similar ones: it does not change the distances within a positive cluster if there is no interference from negative examples. This is because Triplet Loss tries to ensure a margin between the distances of negative pairs and the distances of positive pairs. Contrastive Loss, in contrast, takes the margin value into account only for dissimilar pairs and does not consider where the similar pairs are at that moment. This means that Contrastive Loss may reach a local minimum earlier, while Triplet Loss may continue to improve the organisation of the vector space.
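A tiny numerical illustration of this difference, with hand-picked 2-D points, Euclidean distance and a margin of 1.0. The contrastive term uses the ½D² form with a margin on the dissimilar pair, and the triplet term follows the formula above: an anchor whose positive is 0.5 away and whose negative is 3.0 away still incurs contrastive loss, while its triplet loss is already zero.

```python
import torch
import torch.nn.functional as F

anchor = torch.tensor([[0.0, 0.0]])
positive = torch.tensor([[0.5, 0.0]])  # D(a, p) = 0.5
negative = torch.tensor([[3.0, 0.0]])  # D(a, n) = 3.0
margin = 1.0

d_ap = torch.norm(anchor - positive, dim=1)
d_an = torch.norm(anchor - negative, dim=1)

# contrastive loss on the similar pair plus the dissimilar pair
contrastive = 0.5 * d_ap.pow(2) + 0.5 * F.relu(margin - d_an).pow(2)
# triplet loss on the single (anchor, positive, negative) triplet
triplet = F.relu(d_ap - d_an + margin)

print(contrastive.item())  # 0.125: the positive is still being pulled towards the anchor
print(triplet.item())      # 0.0: the margin is already satisfied, nothing to update
```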

Let's demonstrate with animations how the two loss functions organise the vector space. For simpler visualisation, the vectors are represented by points in a 2-dimensional space, sampled randomly from a normal distribution.

Animation that shows how Contrastive Loss moves points in the course of training.
Animation that shows how Triplet Loss moves points in the course of training.

From the mathematical interpretation of these two loss functions, it is clear that Triplet Loss is theoretically stronger, and in practice it also comes with additional tricks that help it work better. Most importantly, Triplet Loss training commonly uses online triplet mining strategies that automatically form the most useful triplets.

Why does triplet mining matter?

An important decision when training with Triplet Ranking Loss is the selection of negatives, i.e. triplet mining. The chosen strategy has a high impact on training efficiency and final performance.

The formulation of Triplet Loss demonstrates that it works on three objects at a time:

  1. anchor
  2. positive - a sample that has the same label as the anchor,
  3. negative - a sample with a different label from the anchor and the positive.

Let's analyse three situations for this loss:

  • Easy Triplets:
    Dw(a, n) > Dw(a, p) + m

    The negative sample is already sufficiently distant from the anchor, relative to the positive sample, in the embedding space. The loss is 0 and the network parameters are not updated.
  • Hard Triplets:
    Dw(a, n) < Dw(a, p)

    The negative sample is closer to the anchor than the positive. The loss is positive (and greater than m).
  • Semi-Hard Triplets:
    Dw(a, p) < Dw(a, n) < Dw(a, p) + m

    The negative sample is farther from the anchor than the positive, but by less than the margin, so the loss is still positive (and smaller than m).

An obvious conclusion is that training on Easy Triplets should be avoided, since their resulting loss is 0.
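The three cases can be expressed as a small helper that labels a triplet from its two distances and also returns its loss value. The function name `classify_triplet` and the example distances are illustrative; boundary cases (equalities) are lumped into the semi-hard branch here, which is an arbitrary choice.

```python
def classify_triplet(d_ap, d_an, margin):
    """Return (category, loss) for a triplet given D(a, p), D(a, n) and the margin."""
    loss = max(d_ap - d_an + margin, 0.0)
    if d_an > d_ap + margin:
        category = "easy"       # loss is 0, so no gradient
    elif d_an < d_ap:
        category = "hard"       # loss is greater than the margin
    else:
        category = "semi-hard"  # loss lies between 0 and the margin
    return category, loss


print(classify_triplet(1.0, 2.5, margin=1.0))  # ('easy', 0.0)
print(classify_triplet(1.0, 0.5, margin=1.0))  # ('hard', 1.5)
print(classify_triplet(1.0, 1.5, margin=1.0))  # ('semi-hard', 0.5)
```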

In a naive implementation, we could form such triplets at the beginning of each epoch and then feed batches of them to the model throughout that epoch. This is called the offline strategy. However, it is not very efficient, for several reasons:

  • It needs to pass 3n samples through the model to get loss values for n triplets.
  • Not all of these triplets will be useful: many will not yield a positive loss value, so the model learns nothing from them.
  • Even if we form useful triplets at the beginning of each epoch with one of the methods that I will be implementing in this series, they may become useless at some point during the epoch, as the model weights are constantly updated.

Instead, we can take a batch of n samples and their associated labels, and form triplets on the fly. This is called the online strategy. A batch of n samples gives n³ possible index triplets, but only a subset of them is actually valid. Even so, the loss value is computed from many more triplets than with the offline strategy. A triplet (a, p, n) is valid only if:

  1. a and p have the same label
  2. a and p are distinct samples
  3. n has a different label from a and p

These constraints may seem to require expensive computation with nested loops, but they can be implemented efficiently with tricks such as a distance matrix, masking, and broadcasting.
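For reference, a straightforward nested-loop version of the three validity rules might look like this (pure Python; the helper name `valid_triplets_naive` is hypothetical). It visits all n³ index triplets, which is exactly the cost that the distance-matrix and masking tricks avoid.

```python
from itertools import product


def valid_triplets_naive(labels):
    """List all (anchor, positive, negative) index triplets that satisfy the validity rules."""
    n = len(labels)
    triplets = []
    for a, p, neg in product(range(n), repeat=3):
        same_label = labels[a] == labels[p]        # rule 1: a and p share a label
        distinct = a != p                          # rule 2: a and p are different samples
        different_neg = labels[neg] != labels[a]   # rule 3 (since labels[a] == labels[p])
        if same_label and distinct and different_neg:
            triplets.append((a, p, neg))
    return triplets


labels = [0, 0, 1, 1]
print(len(valid_triplets_naive(labels)))  # 8 valid out of 4**3 = 64 index triplets
```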

Distance matrix

A distance matrix is a matrix of shape (n, n) that holds the distance values between all possible pairs made from items in two n-sized collections. This matrix can be used to vectorise calculations that would otherwise need inefficient loops. Its calculation can be optimised as well, and we will implement the Euclidean Distance Matrix Trick (PDF) explained by Samuel Albanie. You may want to read this three-page document for the full intuition behind the trick, but a brief explanation is as follows:

  1. Calculate the dot products between all pairs of vectors from the two collections (in our case, the batch of embeddings with itself).
  2. Extract the diagonal from this matrix that holds the squared Euclidean norm of each embedding.
  3. Calculate the squared Euclidean distance matrix based on the following equation: ||a-b||² = ||a||² -2(a, b) + ||b||²
  4. Get the square root of this matrix for non-squared distances.

PyTorch implementation:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


eps = 1e-8  # an arbitrary small value used for numerical stability tricks


def euclidean_distance_matrix(x):
    """Efficient computation of the Euclidean distance matrix
    Args:
        x: Input tensor of shape (batch_size, embedding_dim)

    Returns:
        Distance matrix of shape (batch_size, batch_size)
    """
    # step 1 - compute the dot product

    # shape: (batch_size, batch_size)
    dot_product = torch.mm(x, x.t())

    # step 2 - extract the squared Euclidean norm from the diagonal

    # shape: (batch_size,)
    squared_norm = torch.diag(dot_product)

    # step 3 - compute squared Euclidean distances

    # shape: (batch_size, batch_size)
    distance_matrix = squared_norm.unsqueeze(0) - 2 * dot_product + squared_norm.unsqueeze(1)

    # get rid of negative distances due to numerical instabilities
    distance_matrix = F.relu(distance_matrix)

    # step 4 - compute the non-squared distances

    # handle numerical stability:
    # the derivative of the square root at 0 is infinite,
    # so we temporarily replace every 0 with eps
    mask = (distance_matrix == 0.0).float()

    # use this mask to set entries with a value of 0 to eps
    # (out-of-place: autograd needs the unmodified ReLU output for backward)
    distance_matrix = distance_matrix + mask * eps

    # now it is safe to take the square root
    distance_matrix = torch.sqrt(distance_matrix)

    # undo the trick for numerical stability
    # (out-of-place: autograd needs the unmodified sqrt output for backward)
    distance_matrix = distance_matrix * (1.0 - mask)

    return distance_matrix
```
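As a standalone sanity check of the identity in step 3 (not part of the original code), the distances obtained from the dot-product matrix can be compared with PyTorch's built-in `torch.cdist` on random embeddings:

```python
import torch

torch.manual_seed(0)
x = torch.randn(8, 32)

gram = x @ x.t()                                                   # step 1: pairwise dot products
sq_norm = torch.diag(gram)                                         # step 2: squared norms
sq_dist = sq_norm.unsqueeze(0) - 2 * gram + sq_norm.unsqueeze(1)   # step 3: ||a||² - 2(a, b) + ||b||²
dist_trick = torch.sqrt(sq_dist.clamp_min(0.0))                    # step 4 (no gradient needed here)

dist_reference = torch.cdist(x, x)  # PyTorch's pairwise Euclidean distances
```

`torch.cdist` itself can switch to a matrix-multiplication based computation along these lines for larger inputs, controlled by its `compute_mode` argument.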

Invalid triplet masking

Now that we can compute a distance matrix for all possible pairs of embeddings in a batch, we can apply broadcasting to enumerate distance differences for all possible triplets and represent them in a tensor of shape (batch_size, batch_size, batch_size). However, as mentioned earlier, only a subset of these triplets is actually valid, and we need a corresponding mask to compute the loss value correctly. We will implement such a helper function in three steps:

  1. Compute a mask for distinct indices, i.e., (i != j and i != k and j != k)
  2. Compute a mask for valid anchor-positive-negative triplets, i.e.,
    labels[i] == labels[j] and labels[i] != labels[k]
  3. Combine the two masks

```python
def get_triplet_mask(labels):
    """Compute a mask for valid triplets
    Args:
        labels: Batch of integer labels. shape: (batch_size,)
    Returns:
        Mask tensor to indicate which triplets are actually valid. Shape: (batch_size, batch_size, batch_size)
        A triplet is valid if:
        `labels[i] == labels[j] and labels[i] != labels[k]`
        and `i`, `j`, `k` are different.
    """
    # step 1 - get a mask for distinct indices

    # shape: (batch_size, batch_size)
    indices_equal = torch.eye(labels.size()[0], dtype=torch.bool, device=labels.device)
    indices_not_equal = torch.logical_not(indices_equal)
    # shape: (batch_size, batch_size, 1)
    i_not_equal_j = indices_not_equal.unsqueeze(2)
    # shape: (batch_size, 1, batch_size)
    i_not_equal_k = indices_not_equal.unsqueeze(1)
    # shape: (1, batch_size, batch_size)
    j_not_equal_k = indices_not_equal.unsqueeze(0)
    # shape: (batch_size, batch_size, batch_size)
    distinct_indices = torch.logical_and(torch.logical_and(i_not_equal_j, i_not_equal_k), j_not_equal_k)

    # step 2 - get a mask for valid anchor-positive-negative triplets

    # shape: (batch_size, batch_size)
    labels_equal = labels.unsqueeze(0) == labels.unsqueeze(1)
    # shape: (batch_size, batch_size, 1)
    i_equal_j = labels_equal.unsqueeze(2)
    # shape: (batch_size, 1, batch_size)
    i_equal_k = labels_equal.unsqueeze(1)
    # shape: (batch_size, batch_size, batch_size)
    valid_indices = torch.logical_and(i_equal_j, torch.logical_not(i_equal_k))

    # step 3 - combine the two masks
    mask = torch.logical_and(distinct_indices, valid_indices)

    return mask
```

Batch-all strategy for online triplet mining

Now we are ready to implement Triplet Loss itself. Triplet Loss involves several strategies to form or select triplets, and the simplest one is to use all valid triplets that can be formed from the samples in a batch. This can be achieved in four easy steps thanks to the utility functions we've already implemented:

  1. Get a distance matrix of all possible pairs that can be formed from embeddings in a batch
  2. Apply broadcasting to this matrix to compute loss values for all possible triplets
  3. Set loss values of invalid or easy triplets to 0
  4. Average the remaining positive values to return a scalar loss

Implementation example:

```python
class BatchAllTripletLoss(nn.Module):
    """Uses all valid triplets to compute Triplet loss
    Args:
        margin: Margin value in the Triplet Loss equation
    """
    def __init__(self, margin=1.):
        super().__init__()
        self.margin = margin

    def forward(self, embeddings, labels):
        """Computes the loss value.
        Args:
            embeddings: Batch of embeddings, e.g., output of the encoder. shape: (batch_size, embedding_dim)
            labels: Batch of integer labels associated with embeddings. shape: (batch_size,)
        Returns:
            Scalar loss value.
        """
        # step 1 - get distance matrix
        # shape: (batch_size, batch_size)
        distance_matrix = euclidean_distance_matrix(embeddings)

        # step 2 - compute loss values for all triplets by applying broadcasting to distance matrix

        # shape: (batch_size, batch_size, 1)
        anchor_positive_dists = distance_matrix.unsqueeze(2)
        # shape: (batch_size, 1, batch_size)
        anchor_negative_dists = distance_matrix.unsqueeze(1)
        # get loss values for all possible n^3 triplets
        # shape: (batch_size, batch_size, batch_size)
        triplet_loss = anchor_positive_dists - anchor_negative_dists + self.margin

        # step 3 - filter out invalid or easy triplets by setting their loss values to 0

        # shape: (batch_size, batch_size, batch_size)
        mask = get_triplet_mask(labels)
        triplet_loss *= mask
        # easy triplets have negative loss values
        triplet_loss = F.relu(triplet_loss)

        # step 4 - compute scalar loss value by averaging positive losses
        num_positive_losses = (triplet_loss > eps).float().sum()
        triplet_loss = triplet_loss.sum() / (num_positive_losses + eps)

        return triplet_loss
```

Triplet Ranking Loss for Multi-Modal Retrieval

So, in our case, the anchor sample is the image, the positive sample is the text associated with that image, and the negative sample is the text of another “negative” image. To choose the negative text, we explored different online negative mining strategies, using distances in the GloVe space to the positive text embedding. Triplet mining is particularly delicate in this problem, since there are no established classes. Given the diversity of the images, we have many easy triplets. But we also have to be careful when mining hard negatives, since the text associated with another image can also be valid for the anchor image.
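A minimal sketch of such a cross-modal triplet ranking loss, assuming L2-normalised image and caption embeddings where row i of each batch is a matching pair, cosine similarity as the score (so the hinge becomes max(0, m - s(a, p) + s(a, n))), and every other item in the batch as a negative. This in-batch, sum-over-negatives formulation is a common choice shown for illustration; it is not the GloVe-based mining strategy described above, and the function name and default margin are hypothetical.

```python
import torch
import torch.nn.functional as F


def cross_modal_triplet_loss(img_emb, txt_emb, margin=0.2):
    """Bidirectional hinge-based triplet ranking loss with in-batch negatives.

    img_emb, txt_emb: (batch_size, dim), L2-normalised; row i of each is a matching pair.
    """
    scores = img_emb @ txt_emb.t()          # scores[i, j] = s(image_i, caption_j)
    positives = scores.diag().unsqueeze(1)  # s(image_i, caption_i), shape (batch_size, 1)

    # image i as anchor: captions j of other images are negatives
    cost_caption = F.relu(margin + scores - positives)
    # caption j as anchor: other images i are negatives
    cost_image = F.relu(margin + scores - positives.t())

    # the matching pairs on the diagonal are not negatives
    off_diag = ~torch.eye(scores.size(0), dtype=torch.bool, device=scores.device)
    return (cost_caption[off_diag].sum() + cost_image[off_diag].sum()) / scores.size(0)


torch.manual_seed(0)
img = F.normalize(torch.randn(8, 64), dim=-1)
txt = F.normalize(torch.randn(8, 64), dim=-1)
loss_random = cross_modal_triplet_loss(img, txt)           # unaligned pairs: positive loss
loss_aligned = cross_modal_triplet_loss(img, img.clone())  # perfectly aligned pairs
```

Replacing the sums with a maximum over each row and column would keep only the hardest in-batch negative; as noted above, hard negatives need care here, because another image's caption may also describe the anchor.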

In its classic single-modality form, a Triplet Loss setup learns distributed embeddings from the notions of similarity and dissimilarity, using multiple parallel networks that share weights. At prediction time, each input is passed through a single network to compute its embedding. In our cross-modal case, the image and caption branches are separate encoders.

To be continued…

In the next part, we'll discuss the architecture of the image encoder, the caption encoder, OCR for scene text, a GCN for scene graph reconstruction, and the implementation and metrics of the whole pipeline.

Links

GitHub project

Articles used for inspiration:
https://towardsdatascience.com/triplet-loss-advanced-intro-49a07b7d8905
https://towardsdatascience.com/contrastive-loss-explaned-159f2d4a87ec
https://medium.com/@maksym.bekuzarov/losses-explained-contrastive-loss-f8f57fe32246
