Captioning Images with CNN and RNN, using PyTorch

One of the most impressive applications of deep learning I have seen is image captioning. I have wanted to implement one myself from scratch to delve deeper into the architecture details. Here I am going to describe the general algorithm behind automatic image captioning and build the architecture using my favorite deep learning library, PyTorch.

Problem Statement.

First we need to state the problem in order to solve it. The statement is pretty straightforward:

Given an image, we want to obtain a sentence that describes what the image consists of.

Now that we have defined what we want our model to do, we can see that it should take a batch of images as its input and produce a batch of sentences as its output. Neural networks are a perfect family of machine learning models for such a task.

Dataset.

I am going to use the COCO (Common Objects in Context) dataset for training the model. COCO is commonly used for this kind of task since one of its annotation types is captions. Every image comes with 5 different captions produced by different humans, so every caption is slightly (sometimes greatly) different from the other captions for the same image. Here is an example of a data point from the COCO dataset:

Point in the COCO dataset

COCO is a great dataset for other challenges such as semantic segmentation or object detection due to the provided targets and diversity of the data.
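For reference, a minimal sketch of loading the captioned images through torchvision's CocoCaptions wrapper might look like this (the dataset paths are placeholders, pycocotools must be installed, and the normalization values are just the usual ImageNet statistics, not necessarily what the original code used):

```python
import torchvision.transforms as transforms
from torchvision.datasets import CocoCaptions

# standard resize/crop/normalize preprocessing (a common choice of values)
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# root points at the image folder, annFile at the captions JSON (placeholder paths)
coco = CocoCaptions(root="coco/train2014",
                    annFile="coco/annotations/captions_train2014.json",
                    transform=transform)

image, captions = coco[0]   # captions is a list of (usually 5) strings
print(image.shape, len(captions))
```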

Model Definition.

So we established that our inputs are going to be images and our outputs are going to be sentences. We can think of sentences as sequences of words. Luckily, sequential models can help us process sequences of words (or characters, or other sequential data such as time series).

Hence, we can map our data from the image space into some hidden space and then map that hidden space into the sentence space.

Here is an overview of the architecture we will be using to build our caption model:

Caption model architecture overview

As you can see in the figure above, we are going to break our whole model into an encoder and a decoder that communicate through a latent space vector. It is easy to think about this type of architecture as a pair of functions: we map an image to a latent space via encoding, and map that latent representation of the image to the sentence space via decoding.

Encoder.

We are going to use a convolutional neural network to encode our images. We could perform transfer learning here, using a network pretrained on the ImageNet dataset. However, since the data-generating distributions of ImageNet and COCO are different, the performance of the whole model may end up subpar. Hence, transfer learning is recommended mainly if you are operating under severe resource constraints.

Dense block: arXiv:1608.06993v5

For this particular model I have used the DenseNet121 architecture in the encoder. It is a relatively lightweight and well-performing architecture for computer vision applications; however, you can use any other convolutional architecture to map the images to the latent space. The DenseNet I used was not pretrained, to avoid the shift between the data-generating distributions.

We can easily import the model in PyTorch, using torchvision's models package. Even though the imported dense network is not pretrained, it still comes with the classifier for the ImageNet dataset, which has 1000 classes. Luckily, it is very easy to replace the classifier in PyTorch. I have replaced the classifier of the network with a two-layer perceptron with a parametric ReLU activation function and applied dropout to reduce overfitting. You can find the implementation of the encoder in PyTorch below:
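A minimal sketch of such an encoder might look like the following (the hidden width and dropout rate here are illustrative choices, not necessarily the exact values used in the original code):

```python
import torch.nn as nn
import torchvision.models as models

class EncoderCNN(nn.Module):
    """DenseNet121 backbone with its ImageNet classifier replaced by a
    two-layer perceptron that outputs the latent image vector."""

    def __init__(self, latent_size=1024, hidden_size=2048, dropout=0.5):
        super().__init__()
        # randomly initialized weights, i.e. no ImageNet pretraining,
        # to avoid the distribution shift discussed above
        densenet = models.densenet121()
        in_features = densenet.classifier.in_features  # 1024 for DenseNet121
        # swap the 1000-way ImageNet classifier for a small two-layer head:
        # Linear -> PReLU -> Dropout -> Linear
        densenet.classifier = nn.Sequential(
            nn.Linear(in_features, hidden_size),
            nn.PReLU(),
            nn.Dropout(p=dropout),
            nn.Linear(hidden_size, latent_size),
        )
        self.densenet = densenet

    def forward(self, images):
        # images: (batch, 3, H, W) -> latent vectors: (batch, latent_size)
        return self.densenet(images)
```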

One detail you should pay attention to is the output dimensionality of the encoder network. Notice that the network results in a 1024-dimensional vector in the latent space, which we will feed as the first input to our LSTM model (at time t=0).

Decoder.

LSTM cell.

Ok, up until now everything is pretty straightforward: we have an image and we pass it through a slightly modified densely connected network to obtain a 1024-dimensional output vector. The decoder part of the architecture is where things can get messy and difficult to debug and understand, but let's give it a shot.

As I showed in the overview of the architecture, the decoder consists of a recurrent neural network. We could use a Gated Recurrent Unit (GRU) or a Long Short-Term Memory (LSTM) unit; in our particular case I have used the latter. It is worth mentioning that there is still a lot of debate about which recurrent cell is better, GRU or LSTM. Both have been shown to work very well in numerous applications, with mixed performance gains. The differences between GRU and LSTM are beyond the scope of this article, but this paper sums up some empirical results.

Ok, back on track. An LSTM cell has long- and short-term memory (duh), as the name implies. Here is a high-level anatomy of an LSTM cell:

  • LSTM cell has an input for a data point (at time t=n)
  • LSTM cell has an input for a cell state (previous cell state)
  • LSTM cell has an input for a hidden state (previous hidden state)
  • LSTM cell has an output for a cell state (current cell state)
  • LSTM cell has an output for a hidden state (current hidden state)
  • LSTM cell’s data output is the LSTM cell’s hidden state output
LSTM cell

I am not going to throw a bunch of mathematical equations in your face since they are not important for implementation tasks, but if you are interested in the detailed anatomy of an LSTM cell, refer to the original paper. What is more important to understand for implementation is that for every step in the sequence we use exactly the same LSTM (or GRU) cell, so the goal of optimization is to find the right set of weights for that one cell to accommodate the whole dictionary of words (characters in character-level models). This means that for every word in our sentence (which is a sequence) we feed the word as input and get some output, which is typically a probability distribution over the whole dictionary of words. This way we can obtain the word that the model thinks fits best given the previous word.

It is important to understand that the practitioner chooses the dimensionality of the input, hidden, and cell states, so these will be hyperparameters of our model.
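As a concrete illustration of those inputs and outputs, here is a toy nn.LSTMCell call with arbitrary sizes:

```python
import torch
import torch.nn as nn

# arbitrary sizes, chosen only to show the shapes involved
batch_size, input_size, hidden_size = 4, 1024, 512

cell = nn.LSTMCell(input_size, hidden_size)

x   = torch.randn(batch_size, input_size)    # data input at time t=n
h_0 = torch.zeros(batch_size, hidden_size)   # previous hidden state
c_0 = torch.zeros(batch_size, hidden_size)   # previous cell state

h_1, c_1 = cell(x, (h_0, c_0))               # current hidden and cell states
print(h_1.shape, c_1.shape)                  # both (4, 512); h_1 is also the data output
```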

Dictionary/Vocabulary.

Now what the hell is a dictionary and why the hell do we need one? I am sure most readers already know the answer to these questions, but let me elaborate for people who are just starting out with sequence models. The point is: sequence models (actually, models in general) do not understand symbolic language, i.e. images have to be represented as tensors of real numbers for a model to be able to process them, since neural networks are a number of parallel (vectorized) computations with non-linearities in between (and other bells and whistles once you dig deeper). Converting images to the language that models understand is pretty easy: the most common approach is to take the intensity of every pixel, which is already a real number. Luckily, there is a way to turn words into this language too.

Now, we know that there is a limited number of meaningful words that our data-generating distribution can produce as part of the target sentences. So what we want to do is take every single word that appears in the captions of our training dataset and enumerate it to obtain a mapping between words and integers. We are halfway to being able to use the words in our decoder. We could now build a 1-of-K (one-hot) representation of each integer and map it (typically with a single-layer perceptron) into a K-dimensional space that we can use as an input to our LSTM cell. However, there is a better way to do this: we can obtain embeddings of the integer representations using the built-in embedding layer in PyTorch. Please refer to this great PyTorch tutorial for more details.
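A minimal sketch of this word-to-integer-to-embedding pipeline (the special tokens, example captions, and embedding size are all illustrative):

```python
import torch
import torch.nn as nn

# toy captions standing in for the full set of COCO training captions
captions = ["a dog runs on the beach", "a man rides a horse"]

# enumerate every word, reserving a few special tokens
vocab = {"<start>": 0, "<end>": 1, "<unk>": 2}
for caption in captions:
    for word in caption.lower().split():
        if word not in vocab:
            vocab[word] = len(vocab)

embed_size = 512
embedding = nn.Embedding(num_embeddings=len(vocab), embedding_dim=embed_size)

# words -> integers -> real-valued embedding vectors
indices = torch.tensor([vocab[w] for w in "a dog runs".split()])
embedded = embedding(indices)
print(embedded.shape)   # (3, 512): one embed_size-dimensional vector per word
```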

We can now see that every word is going to be embedded in a higher-dimensional real-valued space, which is what we feed into the recurrent neural network. Embeddings are also useful in other natural language processing applications, as they allow the practitioner to examine the word or character manifold once it is projected down to a 2-dimensional space, typically using the t-SNE algorithm.

Teacher Forcing.

Here is a common scenario for using a recurrent network to generate a sequence:

  • Feed a <start> token (marking the start of a sentence, or of a word in character-level models) to the LSTM cell as the input at time t=0.
  • Get a vector of vocabulary size as the output of the LSTM at time t=0.
  • Find the index of the most probable character (word) at time t=0, using argmax.
  • Embed the most probable character (word).
  • Pass the resulting embedding as the input to the LSTM cell at time t=1.
  • Repeat until the <end> token is obtained as the output of the cell.

To sum up the aforementioned algorithm, we pass the most probable word or character as the input to the LSTM cell at the next time step and repeat the procedure.
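A sketch of this greedy sampling loop in PyTorch, assuming a decoder with embedding, lstm and fc submodules like the DecoderRNN sketched later in this article:

```python
import torch

def sample_greedy(decoder, latent, end_index, max_len=20):
    """Greedily decode a caption from a latent image vector.
    `decoder` is assumed to expose `embedding`, `lstm` (batch_first=True)
    and `fc` submodules, as in the DecoderRNN sketch further below."""
    inputs = latent.unsqueeze(1)       # latent image vector is the t=0 input
    states = None                      # (h, c) default to zeros inside nn.LSTM
    sampled = []
    for _ in range(max_len):
        hidden, states = decoder.lstm(inputs, states)   # one time step
        scores = decoder.fc(hidden.squeeze(1))          # (batch, vocab_size)
        predicted = scores.argmax(dim=1)                # most probable word index
        sampled.append(predicted)
        if (predicted == end_index).all():              # stop once <end> is produced
            break
        inputs = decoder.embedding(predicted).unsqueeze(1)  # embed and feed back
    return torch.stack(sampled, dim=1)                  # (batch, steps) word indices
```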

However, deep learning practitioners came up with an algorithm called teacher forcing and, in most cases (where applicable), it helps the convergence of the recurrent neural network. It is essential to remember that during training we have the whole caption (sentence) available to us as a target, not just a part of it or a single word.

The teacher forcing algorithm can be summed up as follows:

  • Feed a <start> token (marking the start of a sentence, or of a word in character-level models) to the LSTM cell as the input at time t=0.
  • Find the index of the most probable character (word) at time t=0, using argmax (this prediction contributes to the loss, but it is not fed back into the cell).
  • Feed the next token (the next embedded word from our target caption) to the LSTM cell as the input at time t=1.
  • Repeat until the <end> token of the target caption is reached.

Notice that we no longer feed the most probable word from the previous step; instead, we feed the embedding of the next ground-truth word, which is already available to us.
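Because the inputs no longer depend on the model's own predictions, every time step can be processed in a single call to the LSTM. A sketch of this idea, using the same assumed decoder submodules as before (prepending the latent image vector at t=0 is shown in the full decoder sketch in the next section):

```python
def teacher_forcing_scores(decoder, captions):
    """Compute word scores for a batch of ground-truth captions.
    `captions` is a (batch, seq_len) tensor of word indices; `decoder`
    exposes `embedding`, `lstm` (batch_first=True) and `fc` as before."""
    embedded = decoder.embedding(captions)   # (batch, seq_len, embed_size) ground-truth inputs
    hidden, _ = decoder.lstm(embedded)       # all time steps in one call
    return decoder.fc(hidden)                # (batch, seq_len, vocab_size) word scores
```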

Putting The Decoder Together.

Here is a general overview of the LSTM structure for the first 2 time steps:

Notice that we pass the latent vector as the input at the first time step.

There are a lot of nuances in the decoder part of the model and I will try to be as thorough as possible.

First, we are going to use a single layer LSTM to map the latent space vector to the word space.

Second, as I have mentioned earlier, the output of an LSTM cell is the hidden state vector (shown in purple in the LSTM cell diagram). Hence, we need some kind of mapping from the hidden state space to the vocabulary (dictionary) space. We can achieve this with a fully connected layer between the hidden state space and the vocabulary space.

The forward pass is pretty straightforward if you have at least some experience with recurrent neural networks. If not, I find it very useful to read over the PyTorch documentation or tutorials to understand what dimensions the LSTM cell expects for the hidden state, the cell state, and the input. In fact, data preparation and wrangling take up 90% of building such models.

The key idea here is to feed the latent space vector that represents the image as the input to the LSTM cell at time t=0. Beginning at time t=1, we can start feeding our embedded target sentence into the LSTM cell as part of the teacher forcing algorithm.
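Putting those pieces together, a minimal sketch of the decoder might look as follows (the hidden size is an illustrative choice; the embedding size is set to 1024 so that the latent image vector and the embedded words share the same input dimensionality):

```python
import torch
import torch.nn as nn

class DecoderRNN(nn.Module):
    """Single-layer LSTM decoder: the latent image vector is the input at
    t=0, the embedded ground-truth words are the inputs afterwards
    (teacher forcing), and a fully connected layer maps hidden states
    to scores over the vocabulary."""

    def __init__(self, vocab_size, embed_size=1024, hidden_size=512):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers=1, batch_first=True)
        self.fc = nn.Linear(hidden_size, vocab_size)   # hidden space -> vocabulary space

    def forward(self, latent, captions):
        # latent: (batch, embed_size); captions: (batch, seq_len) word indices
        embedded = self.embedding(captions[:, :-1])          # drop the last target token
        # latent image vector at t=0, embedded ground-truth words from t=1 on
        inputs = torch.cat([latent.unsqueeze(1), embedded], dim=1)
        hidden, _ = self.lstm(inputs)                        # (batch, seq_len, hidden_size)
        return self.fc(hidden)                               # (batch, seq_len, vocab_size)
```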

Training.

The training loop is very basic:
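A minimal sketch of a joint training loop, assuming the EncoderCNN and DecoderRNN sketches above, the vocab mapping from earlier, and a data_loader that yields batches of image tensors and padded caption index tensors (the optimizer and learning rate are illustrative choices; device handling and checkpointing are omitted):

```python
import torch
import torch.nn as nn

encoder = EncoderCNN(latent_size=1024)
decoder = DecoderRNN(vocab_size=len(vocab), embed_size=1024, hidden_size=512)

criterion = nn.CrossEntropyLoss()   # optionally pass ignore_index=<pad index>
params = list(encoder.parameters()) + list(decoder.parameters())
optimizer = torch.optim.Adam(params, lr=1e-3)

for epoch in range(3):
    for images, captions in data_loader:
        optimizer.zero_grad()
        latent = encoder(images)               # (batch, 1024) latent image vectors
        scores = decoder(latent, captions)     # (batch, seq_len, vocab_size)
        # flatten the batch and time dimensions for the cross-entropy loss
        loss = criterion(scores.reshape(-1, scores.size(-1)), captions.reshape(-1))
        loss.backward()
        optimizer.step()
```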

Notice that, even though we have two model components, i.e. the encoder and the decoder, we train them jointly by passing the output of the encoder, the latent space vector, to the decoder, which, in turn, is the recurrent neural network.

I trained the model on an NVIDIA GTX 1080 Ti with a batch size of 48 for 3 epochs, which took around 1 day. After 3 epochs the results of the model were already pretty good, suggesting that the model had converged.

Results.

Here are some results from running the model on the validation part of the COCO dataset:

And here are some of the captions to the photographs from my Facebook:

My graduation from the University of Texas, my family in Russia, and the area outside my old apartment.

It is worth mentioning that the sampling step can be implemented using beam search for better diversity of the captions. There are also attention mechanisms that might help form better captions, since an attention mechanism pays different degrees of attention to different parts of the image. I hope to write more about attention models in the future.

As I already mentioned, there are a lot of nuances and many ways to implement this kind of model; however, I hope this article makes you interested in sequence models and gives you a little head start on your own project.

Some of the code for data wrangling was taken from Udacity; however, you can use PyTorch's built-in utilities to pack the padded sequences of words yourself.
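For example, here is a toy use of PyTorch's built-in packing utility, with arbitrary word indices and 0 used as padding:

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence

# two captions of lengths 5 and 3, padded to the same length with 0
captions = torch.tensor([[2, 5, 9, 4, 3],
                         [2, 7, 3, 0, 0]])
lengths = [5, 3]   # true lengths, sorted in decreasing order

packed = pack_padded_sequence(captions, lengths, batch_first=True)
print(packed.data)   # only the non-padded tokens, reordered time step by time step
```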
