Show, Attend and Tell
Intro
If you are interested in reproducing some Computer Vision papers, a good choice is “Show, Attend and Tell: Neural Image Caption Generation with Visual Attention” (Xu et al., 2016). It’s not completely basic like the cs231n assignments, but it’s also much easier than many SOTA models.
There’s a good tutorial with (almost) working code that I’m using — see here. This code is based on the original repo published by the authors of the paper. The only issue is that there’s still quite a gap between cs231n and this paper: the tutorial and the comments in the code are useful but probably not sufficient. So the goal of this post is to provide a bit more detail. See this Github repo for all the files.
01 Can we run the model?
Suppose we have an image and want to run the pre-trained model in just a few lines of code to see how it works. It’s better to make sure the code actually works and produces a sound result before moving further. See the file 01_run_model.ipynb.
02 What files do we need?
We have the following groups of files:

- Data processing: datasets.py and utils.py.
- Models: models.py.
- Captioning: caption.py.

We don’t consider other files in this tutorial; in particular, we don’t consider the files for training the model. See details in the file 02_files_description.md.
03 Data processing
We are using the small flickr8k dataset for our purposes. We use the train/val/test splits by Andrej Karpathy, from the so-called karpathy_json. This file contains a dictionary for each image with filename and captions (both tokenized and raw; usually 5 captions per image).

Details of data processing are in the file 03_data_processing.ipynb.
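To make the structure concrete, here is a minimal sketch of how entries in the Karpathy JSON are organized and how one might group images by split. The structure mirrors the description above; the concrete filenames and captions are made up for illustration (in real use you would `json.load` the file from disk).

```python
# Toy stand-in for the loaded karpathy_json (entries are invented)
karpathy_json = {
    "images": [
        {
            "filename": "1000268201.jpg",
            "split": "train",
            "sentences": [
                {"tokens": ["a", "child", "in", "a", "pink", "dress"],
                 "raw": "A child in a pink dress."},
                # ... usually 5 captions per image
            ],
        },
        {
            "filename": "1001773457.jpg",
            "split": "val",
            "sentences": [
                {"tokens": ["two", "dogs", "playing"],
                 "raw": "Two dogs playing."},
            ],
        },
    ]
}

def group_by_split(data):
    """Collect (filename, tokenized captions) pairs per train/val/test split."""
    splits = {"train": [], "val": [], "test": []}
    for img in data["images"]:
        captions = [s["tokens"] for s in img["sentences"]]
        splits[img["split"]].append((img["filename"], captions))
    return splits

splits = group_by_split(karpathy_json)
```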
03–1 Auxiliary files
Before we can create CaptionDataset we have to build:

- word_map: a dictionary {word: index}. This is a traditional approach in NLP: to use a token as input for an LSTM we first need an int index for it, which we then use to get an embedding vector. To build it we first iterate over all tokenized captions to create word_freq, and then apply the hyper-parameter min_word_freq to filter out rare words.
- Special hdf5 files that we use in CaptionDataset. Using these files is just one of the possible options; we could read image files directly from disk as usual. We create one for each split: 6000/1000/1000 images for train/val/test. Each file contains images reshaped to (256, 256).
- json files with captions and their lengths, again for each split. Using captions of different lengths for efficiency is a standard approach in pytorch, so we don’t discuss it here.
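The word_map construction can be sketched in a few lines. This is an illustration of the steps described above, not the tutorial’s exact code; the toy captions and the convention of reserving index 0 for <pad> follow the description in the next section.

```python
from collections import Counter

# Toy tokenized captions (made up for illustration)
captions = [
    ["a", "dog", "runs"],
    ["a", "dog", "plays"],
    ["a", "child", "smiles"],
]
min_word_freq = 2  # hyper-parameter: rarer words are dropped (mapped to <unk>)

# Step 1: count word frequencies over all tokenized captions
word_freq = Counter()
for tokens in captions:
    word_freq.update(tokens)

# Step 2: keep only frequent words; indices start at 1 so <pad> can be 0
words = [w for w in word_freq if word_freq[w] >= min_word_freq]
word_map = {w: i + 1 for i, w in enumerate(words)}
word_map["<unk>"] = len(word_map) + 1
word_map["<start>"] = len(word_map) + 1
word_map["<end>"] = len(word_map) + 1
word_map["<pad>"] = 0
```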
03–2 CaptionDataset
Once we have the files described above, we can create our custom dataset:
- The main idea behind it: in __getitem__() we supply the next caption, not the next image as usual. Why is that? We have 5 captions per image, so to supply a unique pair we first choose a caption and then the corresponding image.
- All captions have the same length — 52; they use <start> and <end> tokens and are padded with the <pad> token (index zero). Captions are encoded into int indices using word_map.
- We use standard normalization parameters for pre-trained models (see here). For some reason converting to tensor is done manually in the code. We apply transforms per image as usual.
- We return 3 values (not 2 as usual) from this method: img, caption, caplens. Again, we use captions of different lengths for efficiency.
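The index logic can be sketched in pure Python: the dataset is indexed by caption, so with 5 captions per image, caption index i maps to image i // 5. The names and the tiny word_map here are illustrative, not the tutorial’s exact code.

```python
captions_per_image = 5
word_map = {"<pad>": 0, "<start>": 1, "<end>": 2, "<unk>": 3, "a": 4, "dog": 5}

def encode_caption(tokens, word_map, max_len=52):
    """Encode tokens to int indices with <start>/<end>, padded with <pad>."""
    enc = ([word_map["<start>"]]
           + [word_map.get(t, word_map["<unk>"]) for t in tokens]
           + [word_map["<end>"]])
    caplen = len(enc)  # true length including <start> and <end>
    enc = enc + [word_map["<pad>"]] * (max_len - caplen)
    return enc, caplen

def get_item(i, images, all_captions, word_map):
    img = images[i // captions_per_image]  # one image serves 5 captions
    caption, caplen = encode_caption(all_captions[i], word_map)
    return img, caption, caplen

# Usage: caption index 7 belongs to the second image (7 // 5 == 1)
images = ["img0", "img1"]               # stand-ins for image tensors
all_captions = [["a", "dog"]] * 10      # 5 captions per image, made up
img, caption, caplen = get_item(7, images, all_captions, word_map)
```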
04 Encoder / Decoder model
There are quite a few questions about this model, so let’s go step by step.
04–1 Encoder
We use a pre-trained resnet101 model. See the file 04_model_encoder.ipynb.
The main question:

- How do we change the last layers? We remove the last fc layer and change the pooling output_size from (1, 1) to (14, 14). This is the key change we need for our attention module (see the Attention description). The size (14, 14) is from the paper (as you may see in the picture above).
There are also some technical questions:

- How do we change these layers? The approach used in the code seems to be quite popular (see here): we create a list of the first-level children, remove the last 2 layers and wrap the rest into a Sequential module. It’s not very elegant for a few reasons. I’d prefer another approach (see here), where we change the pooling layer directly: resnet101.avgpool = .... We also change the last fc to identity.
- Can we use pooling with (14, 14)? It turns out that we have an (8, 8) output from the CNN layers, so what we do is closer to unpooling. I’m not sure that’s a good idea; maybe we should use some other CNN or take features from earlier layers.
04–2 Decoder
The decoder is pretty involved and requires a detailed analysis, which is in the file 04_model_decoder.ipynb. Here we mention some key ideas behind it. Let’s first look at the steps of the forward() method:
- Step 1. We get images, captions and caption lengths from a data loader. We then feed the images to the CNN encoder to get encoder_out, slightly modifying it at this step so we can use it in attention.
- Step 2. We embed our captions and create an initial state.
- Step 3. We use LSTMCell, so we have to loop over the length of the captions: a) we feed encoder_out and the previous state to the Attention module to get a context vector attention_weighted_encoding; b) we run the LSTM cell on the combined input (embeddings, i.e. the embedded captions, plus encoder_out after attention, together with the hidden states); c) we use a linear projection of the hidden state into the vocabulary space to get predictions.
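One iteration of Step 3 can be sketched as follows. Names loosely follow the tutorial; the attention context is faked with a random tensor here, and the zero initial state stands in for the MLP-based initialization discussed below.

```python
import torch
import torch.nn as nn

embed_dim, encoder_dim, decoder_dim, vocab_size = 512, 2048, 512, 1000
decode_step = nn.LSTMCell(embed_dim + encoder_dim, decoder_dim)
fc = nn.Linear(decoder_dim, vocab_size)

batch = 4
embeddings_t = torch.randn(batch, embed_dim)  # embedded caption tokens at step t
attention_weighted_encoding = torch.randn(batch, encoder_dim)  # from Attention
h = torch.zeros(batch, decoder_dim)  # real code: predicted from encoder_out
c = torch.zeros(batch, decoder_dim)

# b) run the LSTM cell on the combined input; c) project h to vocab space
combined = torch.cat([embeddings_t, attention_weighted_encoding], dim=1)
h, c = decode_step(combined, (h, c))
predictions = fc(h)  # (batch, vocab_size)
```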
Probably the most interesting parts here are creating the initial state and using a combined input to the LSTM. The combined input appears in formulas (1)-(3) in the paper (see the detailed explanation in the notebook). The paper also describes the approach for creating the initial state (and we do exactly this):
The initial memory state and hidden state of the LSTM are predicted by an average of the annotation vectors fed through two separate MLPs.
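The quoted initialization can be sketched directly: average the annotation vectors and feed the mean through two separate linear layers, one for h and one for c. Dimensions are illustrative.

```python
import torch
import torch.nn as nn

encoder_dim, decoder_dim = 2048, 512
init_h = nn.Linear(encoder_dim, decoder_dim)  # MLP for the hidden state
init_c = nn.Linear(encoder_dim, decoder_dim)  # MLP for the memory state

encoder_out = torch.randn(4, 196, encoder_dim)  # (batch, 14*14 pixels, 2048)
mean_encoder_out = encoder_out.mean(dim=1)      # average of annotation vectors
h = init_h(mean_encoder_out)                    # (batch, 512)
c = init_c(mean_encoder_out)                    # (batch, 512)
```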
At this stage, to get predictions we use a projection from the hidden space to the vocabulary space. All the logic for sampling a caption (including softmax and BEAM search) is incorporated in caption_image_beam_search() (file caption.py).
04–3 Attention
In the paper the authors mention that the incorporation of the attention mechanism (see details in 04_attention.ipynb):
is inspired by recent success in employing attention in machine translation (Bahdanau et al., 2014)
They also mention that they closely follow this paper:
There has been a long line of previous work incorporating attention into neural networks for vision related tasks. In particular however, our work directly extends the work of Bahdanau et al. (2014); Mnih et al. (2014); Ba et al. (2014).
Computing attention takes 3 steps (which are similar in the two papers):
- Compute alignment scores using an alignment model (which is basically yet another neural net). As input we use image features from the Encoder CNN and the hidden state from the Decoder LSTM. This is the main idea behind attention: create scores that are relevant to this particular moment of caption generation.
- Normalize them using softmax to get attention weights α.
- Finally, build a context vector as a weighted sum of the image features from the Encoder, using the attention weights. As mentioned above, we use this context vector together with the caption as input to our Decoder LSTM. We don’t consider the difference between hard and soft attention from the paper here.
The code is quite straightforward once we understand these 3 steps.
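The 3 steps above can be sketched as an additive (Bahdanau-style) attention module. Dimensions are illustrative and names loosely follow the tutorial; this is a sketch, not the tutorial’s exact code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Attention(nn.Module):
    def __init__(self, encoder_dim=2048, decoder_dim=512, attention_dim=512):
        super().__init__()
        self.encoder_att = nn.Linear(encoder_dim, attention_dim)
        self.decoder_att = nn.Linear(decoder_dim, attention_dim)
        self.full_att = nn.Linear(attention_dim, 1)

    def forward(self, encoder_out, decoder_hidden):
        # 1. alignment scores from image features and the hidden state
        att = self.full_att(torch.tanh(
            self.encoder_att(encoder_out)                      # (B, 196, A)
            + self.decoder_att(decoder_hidden).unsqueeze(1)    # (B, 1, A)
        )).squeeze(2)                                          # (B, 196)
        # 2. softmax over the 196 pixel locations -> attention weights alpha
        alpha = F.softmax(att, dim=1)
        # 3. context vector: weighted sum of image features
        context = (encoder_out * alpha.unsqueeze(2)).sum(dim=1)  # (B, 2048)
        return context, alpha

attn = Attention()
context, alpha = attn(torch.randn(2, 196, 2048), torch.randn(2, 512))
```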
05 Caption generation
Caption generation incorporates BEAM search and is quite tricky for this reason. We explain it in great detail in 05_caption_gen.ipynb. We’re talking about caption_image_beam_search in caption.py.
I’d suggest starting by printing some variables on a simple example (see the notebook). This lets you familiarize yourself with the many variables in this function.
First of all there’s a very good video about BEAM search by Andrew Ng.
Suppose for simplicity that k=3. For the BEAM search we need to do the following:

1. Keep track of 3 generated captions (sequences of words); we use seqs for this. At each step we have to produce the 3 words with the top scores; we use k_prev_words. For each such word we keep a score; we use top_k_scores.
2. We feed k_prev_words into the LSTM cell and produce vocab_size scores for each of the 3 previous words (see the picture above).

- Here’s a tricky point: we basically use the k dimension as the batch_size.
3. We have to choose the next 3 best words:

- We need to add the scores of the previous words:

# add scores of the previous words to the generated scores
# top_k_scores has shape (3, 1); expand it explicitly
# to match scores of shape (3, vocabulary_size)
scores = top_k_scores.expand_as(scores) + scores
- We also need to choose the 3 max scores out of these updated scores. We just flatten scores to (3 * vocab_size,) and take the top k:

# unroll and find top scores, and their unrolled indices
top_k_scores, top_k_words = scores.view(-1).topk(k=k, dim=0, largest=True, sorted=True)
- Then we can compute the actual vocabulary indices using the modulo operation:

# we chose top_k_words from range(3 * vocab_size),
# so we need to map back to range(vocab_size)
next_word_inds = top_k_words % vocab_size
- Here’s the first tricky point: we use only the scores of the previous word and the next one, not an accumulated score over the whole sequence.
- Here’s the second tricky point: we may choose, for example, 2 words extending the same previous word (“Jane is”, “Jane visits” in the lecture).
- Now we need to update seqs with the newly chosen words (prev_word_inds = top_k_words // vocab_size tells us which previous caption each new word extends):

seqs = torch.cat([seqs[prev_word_inds], next_word_inds.unsqueeze(1)], dim=1)
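The index arithmetic above can be checked on a tiny worked example with k=3 and a toy vocab_size=4 (the score values are made up). Note how two of the chosen words come from the same beam, illustrating the second tricky point.

```python
import torch

k, vocab_size = 3, 4
scores = torch.tensor([[0.1, 0.9, 0.2, 0.0],   # beam 0
                       [0.8, 0.1, 0.7, 0.0],   # beam 1
                       [0.3, 0.2, 0.1, 0.0]])  # beam 2

# unroll to length 3 * vocab_size and take the k best scores
top_k_scores, top_k_words = scores.view(-1).topk(k, 0, True, True)

prev_word_inds = top_k_words // vocab_size  # which beam each word extends
next_word_inds = top_k_words % vocab_size   # index within the vocabulary
```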