Image Captioning with Attention: Part 1
The first part includes an overview of the “Encoder-Decoder” model for image captioning and its implementation in PyTorch
Introduction
After years of research, image captioning remains an active and, at the same time, mature topic in deep learning. The end goal is to describe the content of an image by mapping it to a sequence of words.
This comprehensive survey is a great reference on the various approaches that have evolved over time, as well as on datasets and evaluation metrics³.
In this short series of articles, I will stick with one particular technique that accomplishes the caption generation task with an “attention”-based sequence-to-sequence model.
Initially, I took a baseline model from my Udacity Nanodegree program and made major adjustments to data loading (added a validation set), training (scheduled sampling) and inference (beam search).
During the implementation, I followed the best practices from the “Show, Attend and Tell” paper² and this comprehensive tutorial for image captioning⁷.
The code is written in the PyTorch framework and can be found in this original repo.
Data loading
In this study, I used the MS COCO dataset to train and validate the model.
This version of the dataset was released in 2014 and contains:
- 82,783 training images;
- 40,504 validation images;
- 40,775 test images;
- 5 captions per image in the training and validation splits.
Let’s have a closer look at the data loading. The full code of data_loader.py is available in the repo.
1. The get_loader() function from data_loader.py receives mode as an argument that takes one of three possible values: 'train', 'valid' or 'test'.
2. If mode is either 'train' or 'valid', we retrieve an image batch of the specified batch_size and the corresponding captions, which all share one randomly sampled length, using the get_indices() method of the CoCoDataset class.
```python
import numpy as np

# method of the CoCoDataset class
def get_indices(self):
    # randomly select the caption length from the list of lengths
    sel_length = np.random.choice(self.caption_lengths)
    # indices of all captions that have the selected length
    all_indices = np.where(np.array(self.caption_lengths) == sel_length)[0]
    # select m = batch_size captions (with replacement) from the list above
    indices = list(np.random.choice(all_indices, size=self.batch_size))
    # return the caption indices of the specified batch
    return indices
```
3. The indices are passed to the data loader, which returns data points depending on the mode:
- image and caption if mode == 'train';
- image, caption and all captions corresponding to the image if mode == 'valid';
- original image and pre-processed image if mode == 'test'.
All reference captions of each image are required to compute the BLEU score at the validation stage.
We store the vocabulary in the vocab.pkl file and set vocab_from_file=True for the validation and test datasets, so that the vocabulary is loaded from that file.
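The vocab.pkl mechanism can be sketched with the standard pickle module. This is a minimal sketch, not the repo's code: the function load_or_build_vocab and its build_fn argument are hypothetical, and the vocabulary is assumed to be a plain word-to-index dictionary.

```python
import pickle
from typing import Callable, Dict


def load_or_build_vocab(vocab_file: str,
                        vocab_from_file: bool,
                        build_fn: Callable[[], Dict[str, int]]) -> Dict[str, int]:
    """Hypothetical helper: load the word-to-index mapping, or build and cache it."""
    if vocab_from_file:
        # Validation / test: reuse the vocabulary built from the training captions.
        # open() raises FileNotFoundError if training has not created the file yet.
        with open(vocab_file, "rb") as f:
            return pickle.load(f)
    # Training: build the vocabulary from scratch and cache it for the other splits
    vocab = build_fn()
    with open(vocab_file, "wb") as f:
        pickle.dump(vocab, f)
    return vocab
```

Raising an error instead of silently rebuilding keeps the validation and test splits from ever creating a vocabulary of their own.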
Model Architecture
Generally speaking, the model comprises three main parts:
- Encoder (pre-trained CNN);
- Attention network;
- Decoder (trainable RNN model).
1. Encoder
The current implementation (Figure 1) uses ResNet-152 as a feature extractor; it consists of convolutional building blocks with shortcut connections.
It receives a randomly cropped 224x224 image with transforms applied (described later) and extracts 2048 feature maps of size 7x7 each.
Note that we obtain the feature vectors from the last convolutional block without applying the fully-connected layer. This allows the attention network to be selective and focus on different image regions during decoding.
Details:
- To load the pre-trained ResNet-152, we use the models subpackage of torchvision: models.resnet152(pretrained=True);
- We use the output of the last convolutional block, dropping the adaptive average pooling and the final linear layer: modules = list(resnet.children())[:-2];
- To prepare the features for decoding, we permute the dimensions with features.permute(0, 2, 3, 1) and reshape the result. The output has a size of (batch, 49, 2048).
2. Attention
Using the attention mechanism, we place emphasis on the most relevant regions of the image.
To focus on the relevant parts at each decoding step, the attention network outputs a context vector, which is a weighted sum of the Encoder’s outputs (features).
To produce the context vector:
- First, we score each of the Encoder’s outputs (features) passed to the attention network with a scoring function.
- Then we get probabilities by applying the softmax function to the scores. These values express how relevant each feature vector is for the Decoder at the current step.
- Finally, we calculate the weighted sum, multiplying the features by their corresponding probabilities.
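A tiny worked example of these three steps, with toy numbers chosen by hand (three feature vectors of size 2 instead of 49 vectors of size 2048), shows that the context vector leans toward the highest-scoring feature:

```python
import math
import torch

# Three toy feature vectors (L = 3 positions, feature size 2)
features = torch.tensor([[1.0, 0.0],
                         [0.0, 1.0],
                         [2.0, 2.0]])
# Toy scores: the third position scores ln(2) higher than the other two
scores = torch.tensor([0.0, 0.0, math.log(2.0)])

weights = torch.softmax(scores, dim=0)                   # -> [0.25, 0.25, 0.5]
context = (weights.unsqueeze(1) * features).sum(dim=0)   # -> [1.25, 1.25]
```

The third feature gets twice the weight of each of the others, because exp(ln 2) = 2, so it dominates the weighted sum.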
I highly recommend this blog to learn about the different types of attention in depth⁶.
The current model employs soft Bahdanau (additive) attention, which learns attention scores with a feed-forward network during training⁴.
The use of soft attention for image captioning is described in Section 4.2 of the “Show, Attend and Tell” paper².
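Details:
- Attention scores atten_score are calculated using a feed-forward network (notation can vary depending on the source and may differ from the original paper). For the i-th feature vector a_i (i = 1, …, L, with L = 49) and the previous hidden state h_{t-1}:
$$e_{t,i} = w^\top \tanh\left(W_e\, a_i + W_h\, h_{t-1}\right)$$
- Next, we apply the softmax to calculate the probabilities atten_weights:
$$\alpha_{t,i} = \frac{\exp(e_{t,i})}{\sum_{k=1}^{L} \exp(e_{t,k})}$$
- Finally, we derive the context vector:
$$z_t = \sum_{i=1}^{L} \alpha_{t,i}\, a_i$$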
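Additive attention of this kind can be written as a small PyTorch module. This is a sketch, not the repo's implementation: the class name BahdanauAttention and the sizes num_features, hidden_dim and attention_dim are assumptions, and each nn.Linear also carries a bias term that the usual textbook formulas leave out.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class BahdanauAttention(nn.Module):
    """Additive (soft) attention over encoder features."""

    def __init__(self, num_features: int, hidden_dim: int, attention_dim: int):
        super().__init__()
        self.W_e = nn.Linear(num_features, attention_dim)  # projects features a_i
        self.W_h = nn.Linear(hidden_dim, attention_dim)    # projects h_{t-1}
        self.w = nn.Linear(attention_dim, 1)               # one score per position

    def forward(self, features: torch.Tensor, hidden: torch.Tensor):
        # features: (batch, L, num_features); hidden: (batch, hidden_dim)
        proj = torch.tanh(self.W_e(features) + self.W_h(hidden).unsqueeze(1))
        atten_score = self.w(proj).squeeze(2)                  # (batch, L)
        atten_weights = F.softmax(atten_score, dim=1)          # sums to 1 over L
        # Weighted sum of the feature vectors -> (batch, num_features)
        context = (features * atten_weights.unsqueeze(2)).sum(dim=1)
        return context, atten_weights
```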
3. Decoder
Before jumping to a Decoder’s architecture, let’s formalize the image captioning task.
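Given the input feature maps X and a target caption Y = (y_1, …, y_T) of length T, the model learns to predict the sequence Y by computing the log probability P(Y|X):
$$\log P(Y \mid X) = \sum_{t=1}^{T} \log P(y_t \mid y_1, \ldots, y_{t-1}, X)$$
The model learns a set of parameters θ* that maximizes the log-likelihood of the correct sequences¹:
$$\theta^{*} = \arg\max_{\theta} \sum_{(X, Y)} \log P(Y \mid X; \theta)$$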
To process the sequence, we use an LSTM (long short-term memory) cell that outputs a hidden state (short-term memory) h and a cell state (long-term memory) c.
Then we feed the hidden state h to a fully-connected layer, followed by a softmax, to compute probabilities for all tokens in the vocabulary.
Important: if we train with Cross-Entropy loss (nn.CrossEntropyLoss in PyTorch), the loss function itself applies a log-softmax to the outputs, so the Decoder should return the raw outputs (logits) without a softmax.
Thus, given the hidden state and the previous token, the model learns to generate the next token in the sequence.
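The note about Cross-Entropy loss can be checked directly. In the sketch below (random logits and targets, sizes chosen for illustration), nn.CrossEntropyLoss on raw logits matches log-softmax followed by NLLLoss, and the sum of per-token log-probabilities, which is log P(Y|X) for one caption, equals the negative summed loss.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size, seq_len = 10, 5
logits = torch.randn(seq_len, vocab_size)            # raw decoder outputs, no softmax
targets = torch.randint(0, vocab_size, (seq_len,))   # target token ids y_1..y_T

ce = nn.CrossEntropyLoss()(logits, targets)          # mean over the T tokens
# Same value computed by hand: log-softmax followed by negative log-likelihood
nll = F.nll_loss(F.log_softmax(logits, dim=1), targets)

# Summing the per-token log-probabilities gives log P(Y|X) for this caption
log_p = F.log_softmax(logits, dim=1).gather(1, targets.unsqueeze(1)).sum()
```

Minimizing the mean cross-entropy is therefore the same as maximizing the log-likelihood of the target caption, up to the constant factor T.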
The step-by-step decoding process is shown below.
Decoding steps:
1. Create embeddings from the target captions: embed = self.embeddings(captions). This tensor has a size of (batch, t, embed_dim), where t corresponds to the sequence length.
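A short shape check for step 1, assuming self.embeddings is an nn.Embedding layer (vocabulary size and embedding size are illustrative):

```python
import torch
import torch.nn as nn

vocab_size, embed_dim = 1000, 256                  # illustrative sizes
embeddings = nn.Embedding(vocab_size, embed_dim)

captions = torch.randint(0, vocab_size, (8, 12))   # (batch=8, t=12) token ids
embed = embeddings(captions)                       # (batch, t, embed_dim)
```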
2. Initialize the hidden states (h, c) using fully-connected layers: h, c = self.init_hidden(features).
The initial memory state and hidden state of the LSTM are predicted from the average of the annotation vectors, fed through two separate MLPs (init_c and init_h)².
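Step 2 can be sketched as follows. The module name HiddenInit is hypothetical (the article uses an init_hidden method), and each MLP is reduced to a single linear layer, which is a simplifying assumption.

```python
import torch
import torch.nn as nn


class HiddenInit(nn.Module):
    """Hypothetical module: predicts the initial LSTM states (h, c) from the features."""

    def __init__(self, num_features: int, hidden_dim: int):
        super().__init__()
        self.init_h = nn.Linear(num_features, hidden_dim)
        self.init_c = nn.Linear(num_features, hidden_dim)

    def forward(self, features: torch.Tensor):
        # features: (batch, L, num_features) -> average over the L positions
        mean_features = features.mean(dim=1)
        h = self.init_h(mean_features)     # (batch, hidden_dim)
        c = self.init_c(mean_features)     # (batch, hidden_dim)
        return h, c
```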
3. Perform scheduled sampling at each step during training. The sampling probability determines whether the next input token ỹ is taken from the target sequence (so-called teacher forcing) or sampled from the model’s output.
The main idea behind scheduled sampling is to narrow the gap between training with teacher forcing and the inference stage, when the target tokens are not available¹.
If the sampling probability sample_prob is higher than a random number drawn uniformly from [0, 1), we sample the next token from the output (top_idx); otherwise, the target token is passed.
Before the softmax, the outputs are divided by the sampling temperature self.sample_temp. A temperature below 1 amplifies the outputs: for example, a temperature of 0.5 doubles them, which sharpens the distribution and lowers diversity.
This makes the LSTM more confident, but also more conservative in its samples⁵.
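The selection logic of step 3 can be sketched for a whole batch. The function next_input_token is hypothetical, and drawing the token with torch.multinomial from the temperature-scaled softmax is an assumption about how top_idx is sampled.

```python
import random
import torch
import torch.nn.functional as F


def next_input_token(outputs: torch.Tensor,
                     target_tokens: torch.Tensor,
                     sample_prob: float,
                     sample_temp: float = 0.5) -> torch.Tensor:
    """Hypothetical helper: choose the next decoder input (scheduled sampling)."""
    if sample_prob > random.random():
        # Sample from the model: divide logits by the temperature, then softmax
        probs = F.softmax(outputs / sample_temp, dim=1)     # (batch, vocab)
        top_idx = torch.multinomial(probs, num_samples=1)   # (batch, 1)
        return top_idx.squeeze(1)
    # Teacher forcing: feed the ground-truth tokens
    return target_tokens
```

With sample_prob = 0 this reduces to pure teacher forcing; raising sample_prob over the course of training gradually exposes the model to its own predictions.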
4. Concatenate the embeddings and the context vector into a single input to the LSTM cell.
5. Finally, we apply dropout regularization with p=0.5 to the hidden state h and supply it to a fully-connected layer.
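Steps 4 and 5 together form one decoding step, sketched below. The layer sizes and the decode_step function are illustrative assumptions; the function returns raw logits so that they can be fed straight into nn.CrossEntropyLoss.

```python
import torch
import torch.nn as nn

# Illustrative sizes: 256-d embeddings, 2048-d context, 512-d LSTM, 1000 tokens
embed_dim, num_features, hidden_dim, vocab_size = 256, 2048, 512, 1000

lstm = nn.LSTMCell(embed_dim + num_features, hidden_dim)
dropout = nn.Dropout(p=0.5)
fc = nn.Linear(hidden_dim, vocab_size)


def decode_step(word_embed, context, h, c):
    """One decoding step: concatenate inputs, update the LSTM, score the vocabulary."""
    # Step 4: embedding of the current token + attention context -> LSTM input
    lstm_input = torch.cat([word_embed, context], dim=1)  # (batch, embed_dim + num_features)
    h, c = lstm(lstm_input, (h, c))
    # Step 5: dropout on h, then logits over the vocabulary (no softmax here)
    logits = fc(dropout(h))                               # (batch, vocab_size)
    return logits, h, c
```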
The full code for the Decoder is available in the repo.
Next steps
In the second part, I’ll share insights into the training process (choice of hyperparameters, validation metrics) and caption sampling (greedy and beam search decoding).
References:
[1] Samy Bengio, Oriol Vinyals, Navdeep Jaitly, Noam Shazeer. (September 23 2015). Scheduled Sampling for Sequence Prediction with Recurrent Neural Networks.
[2] Kelvin Xu, Jimmy Lei Ba, Ryan Kiros, Kyunghyun Cho, Aaron Courville, Ruslan Salakhutdinov, Richard S. Zemel, Yoshua Bengio. (April 19 2016). Show, Attend and Tell: Neural Image Caption Generation with Visual Attention.
[3] Md. Zakir Hossain, Ferdous Sohel, Mohd Fairuz Shiratuddin, Hamid Laga. (October 14 2018). A Comprehensive Survey of Deep Learning for Image Captioning.
[4] Dzmitry Bahdanau, Kyunghyun Cho, Yoshua Bengio. (September 1 2014). Neural Machine Translation by Jointly Learning to Align and Translate.
[5] Andrej Karpathy. (May 21 2015). The Unreasonable Effectiveness of Recurrent Neural Networks.
[6] Lilian Weng. (June 24 2018). Attention? Attention!
[7] Sagar Vinodababu. A PyTorch Tutorial to Image Captioning.
GitHub repository of project: https://github.com/MakarovArtyom/Image-Captioning-with-Attention