Attention mechanism

Heuritech
Dec 5

Many researchers are interested in the « attention mechanism » in neural networks. This post aims to give a high-level explanation of what the attention mechanism in deep learning is, and to detail a few technical steps in the computation of attention.

If you’re looking for more equations or examples, the references give a large number of details, in particular the review by Cho et al. [3]. Unfortunately, these models are not always straightforward to implement yourself, and only a few open-source implementations have been released up to now (note: this post dates from 2015).

Attention

Neural processes involving attention have been extensively studied in neuroscience and computational neuroscience [1, 2]. Visual attention is a particularly well-studied aspect: many animals focus on specific parts of their visual inputs to compute adequate responses. This principle has a large impact on neural computation. We need to select the most pertinent pieces of information rather than use all available information, because a large part of it is irrelevant to computing the neural response.
A similar idea, focusing on specific parts of the input, has been applied in deep learning to speech recognition, translation, reasoning, and visual identification of objects.

Attention for Image Captioning

Let’s introduce an example to explain the attention mechanism. The task we want to achieve is image captioning: we want to generate a caption for a given image.

A “classic” image captioning system would encode the image using a pre-trained Convolutional Neural Network that would produce a hidden state h.

Then, it would decode this hidden state using a Recurrent Neural Network (RNN) and recursively generate each word of the caption. Several groups have applied such a method, including Vinyals et al. [11].

The problem with this method is that the next word of the caption usually describes only a part of the image. When every word is conditioned on the whole image representation h, the model cannot easily produce different words for different parts of the image. This is exactly where an attention mechanism is helpful.

With an attention mechanism (as in Xu et al. [4]), the image is first divided into n parts. A Convolutional Neural Network (CNN) then computes a representation of each part, h_1, …, h_n. When the RNN generates a new word, the attention mechanism focuses on the relevant part of the image, so the decoder only uses specific parts of the image.

In the figure below (upper row), we can see, for each word of the caption, which part of the image (in white) is used to generate it.

For more examples, look at the “relevant” part of each of these images that is used to generate the underlined word.

Examples of attending to the correct object. (Taken from [4])

We are now going to explain how an attention model works in a general setting. A comprehensive review of applications of attention models [3] details the implementation of an attention-based encoder-decoder network.

What is an attention model?

An attention model is a method that takes n arguments y_1, …, y_n and a context c. In the preceding examples, the y_i would be the h_i. It returns a vector z, which is supposed to be the “summary” of the y_i, focusing on information linked to the context c. More formally, it returns a weighted arithmetic mean of the y_i, where the weights are chosen according to the relevance of each y_i given the context c.
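Written as a formula, with s_i denoting the weight given to y_i, the output is a convex combination of the inputs:

$$z = \sum_{i=1}^{n} s_i \, y_i, \qquad s_i \ge 0, \qquad \sum_{i=1}^{n} s_i = 1$$

The rest of this section explains how the s_i are computed from c and the y_i.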

In the example presented before, the context is the beginning of the generated sentence, and the y_i are the representations of the parts of the image (h_i). The output is a representation of the filtered image, where the filter puts the focus on the part of interest for the word currently being generated.

One interesting feature of the attention model is that the weights of the arithmetic mean are accessible and can be plotted. This is exactly what the figures shown before display: a pixel is whiter when the weight of its part of the image is higher.

But what exactly is this black box doing? The whole attention model is shown in the figure below:

This network may seem complicated, but we are going to explain it step by step.

First, we recognize the inputs: c is the context, and the y_i are the “parts of the data” we are looking at.

Next, the network computes m_1, …, m_n with a tanh layer. This means that we compute an “aggregation” of the values of y_i and c. An important remark here is that each m_i is computed without looking at the other y_j for j ≠ i: they are computed independently.
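The tanh step can be sketched in plain Python. This is a minimal sketch that assumes c and each y_i are lists of floats and that the layer has the form m_i = tanh(W_cm c + W_ym y_i). W_cm and W_ym are hypothetical names for weight matrices that a real model would learn; here they are toy values.

```python
import math

def matvec(W, x):
    """Multiply a matrix (list of rows) by a vector (list of floats)."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def tanh_layer(c, ys, W_cm, W_ym):
    """Compute m_i = tanh(W_cm c + W_ym y_i) for every y_i (assumed form).

    Each m_i depends only on c and its own y_i; the other y_j never enter.
    """
    wc = matvec(W_cm, c)  # the context term is shared by all i
    ms = []
    for y in ys:
        wy = matvec(W_ym, y)
        ms.append([math.tanh(a + b) for a, b in zip(wc, wy)])
    return ms

# Hypothetical toy parameters: 2-dimensional context and three 2-dimensional inputs.
W_cm = [[1.0, 0.0], [0.0, 1.0]]
W_ym = [[0.5, 0.5], [-0.5, 0.5]]
c = [0.2, -0.1]
ys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
ms = tanh_layer(c, ys, W_cm, W_ym)
```

The loop body only reads c and the current y, so the m_i could just as well be computed in parallel.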

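Then, we compute each weight using a softmax. As its name suggests, the softmax behaves almost like an argmax, but it is differentiable. Let’s say that we have an argmax function such that

$$\operatorname{argmax}(x_1, \dots, x_n) = (0, \dots, 0, 1, 0, \dots, 0)$$

where the only 1 in the output indicates which input is the max. Then, the softmax is defined by

$$\operatorname{softmax}(x_1, \dots, x_n) = \left( \frac{e^{x_i}}{\sum_{j=1}^{n} e^{x_j}} \right)_{i=1,\dots,n}$$

If one of the x_i is much bigger than the others, then softmax(x_1, …, x_n) will be very close to argmax(x_1, …, x_n).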

Here, the s_i are the softmax of the m_i projected onto a learned direction. So the softmax can be thought of as a (soft) max of the “relevance” of the variables, according to the context.
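Below is a minimal sketch of the softmax, of the argmax it approximates, and of the weights s_i. It assumes the learned direction is a plain vector w, a hypothetical name for a parameter that would be trained with the rest of the network. Subtracting the maximum before exponentiating does not change the result, but it avoids overflow.

```python
import math

def softmax(xs):
    """Softmax of a list of floats (max subtracted for numerical stability)."""
    mx = max(xs)
    exps = [math.exp(x - mx) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def argmax_onehot(xs):
    """One-hot vector with a single 1 at the position of the largest input."""
    k = max(range(len(xs)), key=lambda i: xs[i])
    return [1.0 if i == k else 0.0 for i in range(len(xs))]

def relevance_weights(ms, w):
    """s = softmax(<w, m_1>, ..., <w, m_n>): project each m_i onto direction w."""
    return softmax([sum(wk * mk for wk, mk in zip(w, m)) for m in ms])

# When one input clearly dominates, the softmax is close to the one-hot argmax.
xs = [1.0, 2.0, 10.0]
print(softmax(xs))        # last entry is close to 1 (about 0.9995)
print(argmax_onehot(xs))  # [0.0, 0.0, 1.0]
```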

The output z is the weighted arithmetic mean of all the y_i, where each weight represents the relevance of the corresponding variable according to the context c.
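The three steps of the tanh-based model can be put together in one self-contained sketch. Its assumptions are hypothetical, not taken from a specific implementation: the tanh layer is m_i = tanh(W_cm c + W_ym y_i), the learned direction is a vector w, and all three parameters are toy values rather than trained weights.

```python
import math

def soft_attention(c, ys, W_cm, W_ym, w):
    """Soft attention with a tanh relevance layer; returns (z, s).

    m_i = tanh(W_cm c + W_ym y_i)          computed independently for each i
    s   = softmax(<w, m_1>, ..., <w, m_n>)  relevance weights, summing to 1
    z   = sum_i s_i y_i                     weighted arithmetic mean of the y_i
    """
    def matvec(W, x):
        return [sum(a * b for a, b in zip(row, x)) for row in W]

    wc = matvec(W_cm, c)
    scores = []
    for y in ys:
        m = [math.tanh(a + b) for a, b in zip(wc, matvec(W_ym, y))]
        scores.append(sum(wk * mk for wk, mk in zip(w, m)))
    mx = max(scores)
    exps = [math.exp(sc - mx) for sc in scores]
    total = sum(exps)
    s = [e / total for e in exps]
    z = [sum(si * y[k] for si, y in zip(s, ys)) for k in range(len(ys[0]))]
    return z, s

# Toy usage with identity weight matrices and 2-dimensional vectors.
W_cm = [[1.0, 0.0], [0.0, 1.0]]
W_ym = [[1.0, 0.0], [0.0, 1.0]]
w = [1.0, 1.0]
c = [0.5, 0.5]
ys = [[1.0, 0.0], [0.0, 1.0], [3.0, 3.0]]
z, s = soft_attention(c, ys, W_cm, W_ym, w)
```

Every operation here is differentiable, which is what allows gradients to flow through the whole model.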

Another computation of “relevance”

The attention model presented above can be modified. First, the tanh layer can be replaced by any other network; the only important thing is that this function mixes up c and y_i. One variant simply computes a dot product between c and y_i.

Attention model with another method for computing relevance.

This version is even easier to understand: the attention model “softly chooses” the variable most correlated with the context. As far as we know, both systems seem to produce comparable results.
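The dot-product variant is even shorter to sketch. This minimal version assumes that c and the y_i have the same dimension, which the dot product requires, and it uses no learned parameters at all.

```python
import math

def dot_attention(c, ys):
    """Soft attention with dot-product relevance; returns (z, s).

    s = softmax(<c, y_1>, ..., <c, y_n>) and z = sum_i s_i y_i.
    Assumes c and every y_i have the same length.
    """
    scores = [sum(a * b for a, b in zip(c, y)) for y in ys]
    mx = max(scores)
    exps = [math.exp(sc - mx) for sc in scores]
    total = sum(exps)
    s = [e / total for e in exps]
    z = [sum(si * y[k] for si, y in zip(s, ys)) for k in range(len(c))]
    return z, s

c = [0.0, 5.0]
ys = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]
z, s = dot_attention(c, ys)
# s puts almost all of its weight on ys[1], the input most aligned with c.
```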

Another important modification is hard attention.

Soft Attention and Hard Attention

The mechanism we described previously is called “soft attention” because it is a fully differentiable, deterministic mechanism that can be plugged into an existing system. The gradients are propagated through the attention mechanism at the same time as they are propagated through the rest of the network.

Hard attention is a stochastic process: instead of using all the hidden states as an input for the decoding, the system samples a hidden state y_i with probability s_i. To propagate a gradient through this process, we estimate the gradient by Monte Carlo sampling.

A hard attention model. The output is a random choice of one of the y_i, with probability s_i.
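Here is a minimal sketch of the sampling step. It uses Python's `random.choices` to draw one y_i with probability s_i; the weights are fixed toy values, whereas in a model they would come from the softmax. The sketch also shows how the two mechanisms are related: the expected value of the sampled output is the soft-attention output sum_i s_i y_i, and a Monte Carlo average over many samples approaches it. The gradient estimator itself is not shown.

```python
import random

def hard_attention(ys, s, rng):
    """Return one y_i, drawn at random with probability s_i."""
    (y,) = rng.choices(ys, weights=s, k=1)
    return y

ys = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
s = [0.2, 0.5, 0.3]       # toy weights standing in for a softmax output
rng = random.Random(0)    # fixed seed for reproducibility

# Soft-attention output, i.e. the expected value of the sampled y.
soft_z = [sum(si * y[k] for si, y in zip(s, ys)) for k in range(2)]
# Monte Carlo average over many hard-attention samples.
samples = [hard_attention(ys, s, rng) for _ in range(20000)]
mc_z = [sum(y[k] for y in samples) / len(samples) for k in range(2)]
```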

Both systems have their pros and cons, but the trend is to focus on soft attention mechanisms, because the gradient can be computed directly instead of being estimated through a stochastic process.

Back to image captioning

Now we are able to understand how the image captioning system presented before works.

Attention model for image captioning

We can recognize the figure of the « classic » model for image captioning, with a new attention layer added. What happens when we want to predict the next word of the caption? If we have predicted i words, the hidden state of the LSTM is h_i. We select the « relevant » part of the image by using h_i as the context. The output of the attention model, z_i, is the representation of the image filtered so that only its relevant parts remain, and it is used as an input for the LSTM. Finally, the LSTM predicts a new word and returns a new hidden state h_{i+1}.
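The decoding loop described above can be sketched as a toy program under explicit assumptions:

- The LSTM is replaced by a simple tanh recurrent cell (`rnn_step`, a hypothetical stand-in).
- Relevance is computed with the dot product, for brevity.
- The previously generated word is not fed back into the cell.
- All weights are made-up values rather than a trained model.

In the code, `parts` holds the image-part features (the h_1, …, h_n of the attention section), while `h` is the decoder state.

```python
import math

def softmax(xs):
    mx = max(xs)
    exps = [math.exp(x - mx) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def attend(h, parts):
    """Dot-product attention over image parts, using the decoder state h as context."""
    s = softmax([sum(a * b for a, b in zip(h, p)) for p in parts])
    z = [sum(si * p[k] for si, p in zip(s, parts)) for k in range(len(h))]
    return z, s

def rnn_step(h, z, W_h, W_z):
    """Toy tanh cell standing in for the LSTM: h_next = tanh(W_h h + W_z z)."""
    return [math.tanh(sum(a * x for a, x in zip(rh, h)) + sum(b * x for b, x in zip(rz, z)))
            for rh, rz in zip(W_h, W_z)]

def decode(parts, vocab, W_h, W_z, W_out, h0, n_words):
    """Greedy decoding: attend with h_i, update the state, then pick a word."""
    h, words, all_weights = h0, [], []
    for _ in range(n_words):
        z, s = attend(h, parts)       # z_i: image filtered using h_i as context
        h = rnn_step(h, z, W_h, W_z)  # new hidden state h_{i+1}
        logits = [sum(a * x for a, x in zip(row, h)) for row in W_out]
        words.append(vocab[max(range(len(vocab)), key=lambda j: logits[j])])
        all_weights.append(s)         # attention weights, one list per word
    return words, all_weights

# Hypothetical toy values: three 2-dimensional image-part features, tiny vocabulary.
parts = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
vocab = ["a", "dog", "<end>"]
W_h = [[0.5, 0.0], [0.0, 0.5]]
W_z = [[1.0, 0.0], [0.0, 1.0]]
W_out = [[0.1, 0.1], [1.0, -1.0], [-1.0, 1.0]]
words, all_weights = decode(parts, vocab, W_h, W_z, W_out, h0=[0.0, 0.0], n_words=3)
```

The list `all_weights` holds, for each generated word, the weights that the figures render as white regions.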

Learning to Align in Machine Translation

Bahdanau et al. [5] proposed a neural translation model that learns to translate sentences from one language to another and introduced an attention mechanism.

Before explaining the attention mechanism, let us describe how the vanilla encoder-decoder translation model works. The encoder, a Recurrent Neural Network (RNN, usually a GRU or an LSTM), is fed a sentence in English and produces a hidden state h. This hidden state h conditions the decoder RNN to produce the right output sentence in French.

A model for translation without attention

For translation, we have the same intuition as for image captioning: when generating a new word, we are usually translating a single word of the original language. An attention model allows the decoder to focus, for each new word, on a part of the original text.

The only difference between this model and the image captioning model is that the h_i are the successive hidden states of an RNN.

Attention model for translation

Instead of producing just a single hidden state corresponding to the whole sentence, the encoder produces hidden states h_j, each corresponding to a word. Each time the decoder RNN produces a word, it determines how much each hidden state contributes to its input, usually focusing on a single one (see figure below). The contributions are computed using a softmax, which means that the attention weights a_j are computed such that

$$a_j \ge 0 \quad \text{for all } j, \qquad \sum_{j} a_j = 1$$

and all hidden states h_j contribute to the decoding with weight a_j.
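For reference, in [5] these weights are produced by a learned alignment model. This is a small feed-forward network that scores each encoder state h_j against the previous decoder state. Several names are chosen here only for readability: d_{i-1} for that decoder state, align for the alignment model, and a_{ij} for the weight a_j used when producing output word i. With these names, the computation reads:

$$e_{ij} = \mathrm{align}(d_{i-1}, h_j), \qquad a_{ij} = \frac{\exp(e_{ij})}{\sum_{k} \exp(e_{ik})}, \qquad z_i = \sum_{j} a_{ij} \, h_j$$

The vector z_i plays the role of the attention output z in the general model: it is the weighted summary of the source sentence that the decoder uses to produce word i.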

In our case, the attention mechanism is fully differentiable and does not require any additional supervision; it is simply added on top of an existing encoder-decoder.

This process can be seen as an alignment, because the network usually learns to focus on a single input word each time it produces an output word. This means that most of the attention weights are close to 0 (black) while a single one is activated (white). The image below shows the attention weights during the translation process. It reveals the alignment and makes it possible to interpret what the network has learned, which is usually a problem with RNNs!
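Reading an alignment off the attention weights amounts to taking, for each output word, the input word with the largest weight. Here is a minimal sketch; the weight matrix is invented for illustration and is not taken from [5].

```python
def hard_alignment(weights, src, tgt):
    """For each output word, return the input word with the largest attention weight."""
    pairs = []
    for word, row in zip(tgt, weights):
        j = max(range(len(row)), key=lambda k: row[k])
        pairs.append((word, src[j]))
    return pairs

# Hypothetical attention weights (rows: output words, columns: input words).
src = ["the", "black", "cat"]
tgt = ["le", "chat", "noir"]
weights = [
    [0.90, 0.05, 0.05],
    [0.05, 0.10, 0.85],
    [0.05, 0.80, 0.15],
]
print(hard_alignment(weights, src, tgt))
# [('le', 'the'), ('chat', 'cat'), ('noir', 'black')]
```

The alignment does not have to be monotonic: here the adjective and the noun swap order between the two languages.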

Word alignment in translation with an attention model. (Taken from [5])

Attention without Recurrent Neural Networks

Up to now, we have only described attention models in an encoder-decoder framework (i.e. with RNNs). However, when the order of the inputs does not matter, it is possible to consider independent hidden states h_j.

This is the case, for instance, in Raffel and Ellis [10], where the attention model is fully feed-forward. The same applies to the simple case of Memory Networks [6] (see next section).
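Below is a minimal sketch of this order-independent case, loosely in the spirit of [10]. Each h_j gets a relevance score from its own value only, and the output is the softmax-weighted mean of the h_j. The scoring function tanh(<v, h_j>) and the vector v are hypothetical choices for illustration; in [10] the scoring function is learned.

```python
import math

def feedforward_attention(hs, v):
    """Order-independent attention: each h_j is scored from its own value only.

    The score tanh(<v, h_j>) is a hypothetical choice; v would be learned.
    Returns the softmax-weighted mean of the h_j.
    """
    scores = [math.tanh(sum(a * b for a, b in zip(v, h))) for h in hs]
    mx = max(scores)
    exps = [math.exp(e - mx) for e in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    return [sum(wj * h[k] for wj, h in zip(weights, hs)) for k in range(len(hs[0]))]

hs = [[1.0, 2.0], [0.0, -1.0], [3.0, 0.5]]
v = [0.3, -0.2]
out = feedforward_attention(hs, v)
shuffled = feedforward_attention([hs[2], hs[0], hs[1]], v)
# Shuffling the inputs leaves the output unchanged (up to floating-point rounding).
```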

From Attention to Memory Addressing

NIPS 2015 hosted a very interesting (and packed!) workshop called RAM, for Reasoning, Attention, and Memory. It included work on attention, but also on Memory Networks [6], Neural Turing Machines [7], Differentiable Stack RNNs [8] and many others. What these models have in common is that they use a form of external memory, which they can read from and possibly write to.

Comparing and explaining these models is out of the scope of this post, but the link between attention mechanism and memory is interesting.

In Memory Networks, for instance, we consider an external memory (a set of facts or sentences x_i) and an input q.

The network learns to address the memory, that is, to select which fact x_i to focus on to produce the answer. This corresponds exactly to an attention mechanism over the external memory. The only difference in Memory Networks is that two separate embeddings are used. The soft selection of the facts relies on one (blue Embedding A in the image below), and the weighted sum of the facts relies on the other (pink Embedding C in the image). Neural Turing Machines and many very recent memory-based QA models use a soft attention mechanism.
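A single memory lookup in the spirit of end-to-end Memory Networks [6] can be sketched as below. This is a simplified illustration:

- The facts and the question are bags of words.
- The embedding tables A, B and C are tiny hand-written dictionaries with hypothetical values.
- The final answer-prediction layer is omitted.

Addressing uses the A embeddings, while the returned summary uses the C embeddings. This is the separation between selecting facts and summing them mentioned above.

```python
import math

def embed(words, table):
    """Bag-of-words embedding: sum of the word vectors from a lookup table."""
    dim = len(next(iter(table.values())))
    vec = [0.0] * dim
    for w in words:
        vec = [a + b for a, b in zip(vec, table[w])]
    return vec

def memory_hop(facts, question, A, B, C):
    """One memory hop: address with embedding A, read out with embedding C."""
    u = embed(question, B)               # question representation
    ms = [embed(f, A) for f in facts]    # memory keys used for addressing
    cs = [embed(f, C) for f in facts]    # memory values used for the output
    scores = [sum(a * b for a, b in zip(u, m)) for m in ms]
    mx = max(scores)
    exps = [math.exp(s - mx) for s in scores]
    total = sum(exps)
    p = [e / total for e in exps]        # soft selection of the facts
    o = [sum(pi * c[k] for pi, c in zip(p, cs)) for k in range(len(u))]
    return o, p

# Hypothetical toy embeddings.
A = {"john": [3.0, 0.0], "mary": [0.0, 3.0], "in": [0.0, 0.0],
     "kitchen": [0.0, 0.0], "garden": [0.0, 0.0]}
B = {"where": [0.0, 0.0], "john": [3.0, 0.0]}
C = {"john": [0.0, 0.0], "mary": [0.0, 0.0], "in": [0.0, 0.0],
     "kitchen": [1.0, 0.0], "garden": [0.0, 1.0]}
facts = [["john", "in", "kitchen"], ["mary", "in", "garden"]]
o, p = memory_hop(facts, ["where", "john"], A, B, C)
```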

Memory Network (Taken from [6])

Final Word

Attention mechanisms and other fully differentiable addressable memory systems are being extensively studied by many researchers right now. They are still young and not yet used in real-world systems. Even so, they have been shown to beat the state of the art on many problems where the encoder-decoder framework previously held the record.

At Heuritech, we became interested in the attention mechanism a few months ago. We organized a workshop to get our hands dirty and code encoder-decoders with attention mechanisms. We do not use attention mechanisms in production yet. However, we envision them playing an important role in advanced text understanding where some reasoning is necessary, similar to the recent work by Hermann et al. [9].

Léonard Blier and Charles Ollion

Note: this post was written at the end of 2015.

Acknowledgments

We thank Mickael Eickenberg and Olivier Grisel for their helpful remarks.

Bibliography

[1] Itti, Laurent, Christof Koch, and Ernst Niebur. “A model of saliency-based visual attention for rapid scene analysis.” IEEE Transactions on Pattern Analysis and Machine Intelligence 20.11 (1998): 1254–1259.

[2] Desimone, Robert, and John Duncan. “Neural mechanisms of selective visual attention.” Annual Review of Neuroscience 18.1 (1995): 193–222.

[3] Cho, Kyunghyun, Aaron Courville, and Yoshua Bengio. “Describing Multimedia Content using Attention-based Encoder-Decoder Networks.” arXiv preprint arXiv:1507.01053 (2015).

[4] Xu, Kelvin, et al. “Show, attend and tell: Neural image caption generation with visual attention.” arXiv preprint arXiv:1502.03044 (2015).

[5] Bahdanau, Dzmitry, Kyunghyun Cho, and Yoshua Bengio. “Neural machine translation by jointly learning to align and translate.” arXiv preprint arXiv:1409.0473 (2014).

[6] Sukhbaatar, Sainbayar, Jason Weston, and Rob Fergus. “End-to-end memory networks.” Advances in Neural Information Processing Systems (2015).

[7] Graves, Alex, Greg Wayne, and Ivo Danihelka. “Neural Turing Machines.” arXiv preprint arXiv:1410.5401 (2014).

[8] Joulin, Armand, and Tomas Mikolov. “Inferring Algorithmic Patterns with Stack-Augmented Recurrent Nets.” arXiv preprint arXiv:1503.01007 (2015).

[9] Hermann, Karl Moritz, et al. “Teaching machines to read and comprehend.” Advances in Neural Information Processing Systems (2015).

[10] Raffel, Colin, and Daniel P. W. Ellis. “Feed-Forward Networks with Attention Can Solve Some Long-Term Memory Problems.” arXiv preprint arXiv:1512.08756 (2015).

[11] Vinyals, Oriol, et al. “Show and tell: A neural image caption generator.” arXiv preprint arXiv:1411.4555 (2014).

Heuritech

Heuritech is a cutting-edge artificial intelligence company that provides fashion brands with predictive analytics on trends. Read our Tech and Fashion blog.
