The rise of Attention in Neural Networks

Elena Fortina
Published in Analytics Vidhya
32 min read · Jun 2, 2020


You may have noticed that something very special has been going on in Natural Language Processing in the last couple of years. First, Google Translate has reached unprecedented translation quality. You may also have bumped into some AI-written stories about unicorns found in the Andes.

What happened? Attention mechanisms came into play!

Attention-based architectures currently form the basis of state-of-the-art NLP models. Their key idea is to mimic the human ability to selectively pay attention to certain parts of a sentence when performing classification or prediction tasks. This post will discuss the origins of attention mechanisms, their basic underlying principles and the most important attention-based models for NLP (the Transformer, BERT and GPT-2).

This is meant to be an introductory post aimed at those who are approaching the topic for the first time and would like to grasp the main ideas before diving deep into scientific articles or advanced TensorFlow tutorials. Some of the explanations are intentionally simplified to build intuition.

The origins of Attention

In order to explain how attention was born, we will start by discussing a related and more widely known concept: memory.

The notion of memory was first introduced in deep learning to enable the modeling of sequential data, i.e., all those cases where data have an intrinsic order and the past must somehow be remembered when processing new elements. This includes natural language and time series.

The most famous networks based on the idea of memory are recurrent neural networks, which have already been around for a few decades. Below is a typical recurrent neural network architecture for a text classification task - in this case, classifying newspaper article titles (e.g. “Juventus forward Ronaldo defies gravity with jump”) into one of the three topics Football, Economics or Fashion:

Recurrent Neural Network for text classification

The green cells displayed in the picture are actually just the same cell (with the same parameters) represented multiple times. This is called the “unfolded” network representation and is meant to facilitate visualization and understanding of the underlying process. Sequence elements (words, in this case) are processed by the cell one by one, in sequential order. The cell has its own internal memory, encoded by a vector hᵢ (also called hidden state), whose elements are initialized to 0 and then updated every time a new word is processed.

In this unfolded schema, all intermediate memory updates are represented explicitly: hᵢ is the memory up to word i. Therefore, hₙ is the memory of the whole sequence or, said otherwise, a compressed synthetic representation of the full sentence.
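To make the unfolded picture concrete, here is a minimal NumPy sketch of the loop. It assumes a plain tanh recurrent cell (much simpler than LSTM or GRU cells) with random, untrained parameters; the function name `run_rnn` and all sizes are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)

emb_dim, hidden_dim = 4, 3
# Hypothetical parameters of a simple tanh cell; in practice they are learned
W_x = rng.normal(size=(hidden_dim, emb_dim))     # input-to-memory weights
W_h = rng.normal(size=(hidden_dim, hidden_dim))  # memory-to-memory weights
b = np.zeros(hidden_dim)

def run_rnn(word_vectors):
    """Process word vectors one by one and return every memory h_1..h_n."""
    h = np.zeros(hidden_dim)  # memory initialized to 0
    states = []
    for x in word_vectors:  # strictly sequential: step i needs h from step i-1
        h = np.tanh(W_x @ x + W_h @ h + b)
        states.append(h)
    return np.stack(states)

sentence = rng.normal(size=(7, emb_dim))  # 7 stand-in word vectors
H = run_rnn(sentence)  # all intermediate memories, one row per word
h_n = H[-1]            # the whole-sequence memory traditionally passed on
print(H.shape)  # (7, 3)
```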

If the cell mechanisms are properly designed and optimized, as in the case of LSTM and GRU cells, the memory vector will be updated efficiently to remember the relevant parts of the input sequence and forget the irrelevant ones. In our example, the memory is expected to remember “Juventus” and “Ronaldo”, but forget “jump” and “gravity”.

Relevant, of course, is relative to what the network is trained to do. If the target had been some other weird task like predicting the body part involved in the action described (legs, arms, head, abdominals…), the memory would instead have focused on the word “jump”.

Traditionally, the last memory hₙ was considered to be the only one with dignity. Intermediate memories would just be thrown away and only the last memory would be passed on to subsequent layers to perform prediction or classification tasks.

The bottleneck: translation

Up to about 2015, recurrent neural networks were also state-of-the-art for most NLP tasks, including translation.

The architecture was analogous to that for topic classification¹ ². Again, the input sentence (the one to be translated) would be passed through the network one element at a time and encoded in the final memory / sequence representation hₙ. This network was called the Encoder.

hₙ would then be passed to another complex structure, the Decoder, which would generate the translation one word at a time.

hₙ must remember a lot of things: destination, travel period, traveler’s name…

We will not discuss the details of the Decoder architecture this time (by the way, it is just another recurrent neural network). The important point is that all the Decoder could see about the initial sentence was this single synthetic representation hₙ.

Of course, this made translation quite a hard job for the Decoder, especially as texts grew longer. For example, the Decoder had no way to dynamically focus on specific parts of the original sentence while generating the translation, which is what a human would do and seems a more sensible and efficient approach.

However, in 2015, someone got an idea: “Why not pass all intermediate memories to the Decoder?” ³.

This was a fairly straightforward way to let the Decoder focus on specific parts of the input when generating specific parts of the translation: exactly what we had been looking for. This was the first form of attention in neural networks: nothing more than a clever trick added on top of a recurrent neural network to help it with translation.
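A minimal sketch of the idea, assuming the Encoder memories h₁…hₙ have already been computed: the Decoder scores every memory against its current state, turns the scores into weights with a softmax and takes a weighted average. The plain dot-product score is a simplifying assumption (early attention models learned their scoring function), and `context_vector` is a hypothetical name.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())  # subtract the max for numerical stability
    return e / e.sum()

def context_vector(decoder_state, encoder_states):
    """Weighted average of ALL encoder memories, weighted by relevance to the decoder state."""
    scores = encoder_states @ decoder_state  # one score per source word
    weights = softmax(scores)                # non-negative, sums to 1
    return weights @ encoder_states, weights

rng = np.random.default_rng(1)
H = rng.normal(size=(6, 4))  # stand-in Encoder memories for a 6-word source sentence
s = rng.normal(size=4)       # stand-in Decoder state while emitting one target word
c, w = context_vector(s, H)  # c changes at every Decoder step, unlike a fixed h_n
```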

A detailed description of this new architecture and a Tensorflow tutorial can be found here.

Contextualized word representations

The reason why this new construction works is that the hᵢ effectively work as contextualized word representations. In fact, because of the way they are constructed in this architecture, we can expect hᵢ to primarily encode (or remember) word i plus something about what happened right before it, i.e., plus some context.

To see this, consider first what happens at the lower layer of the network. Words must be converted to vectors prior to being passed to the recurrent cell (of course, the cell cannot process bare words!). This transformation is performed via a simple Embedding layer that maps each word (or, better to say, its one-hot encoding) to a continuous vector separately. The parameters of the transformation are learnt during training.

Being mapped separately, words do not see each other in this phase. This means that their initial representations (schematized as dark green rectangles in the picture below) are context-free, just like word2vec or other similar embeddings.
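The Embedding layer is essentially a lookup table. This sketch (a toy vocabulary, with random values standing in for learned ones) shows that multiplying a one-hot vector by the embedding matrix just selects one row, so “bank” gets the same vector whatever its neighbors are:

```python
import numpy as np

vocab = ["this", "is", "a", "river", "business", "bank"]
word_to_id = {w: i for i, w in enumerate(vocab)}
rng = np.random.default_rng(2)
E = rng.normal(size=(len(vocab), 3))  # embedding matrix: one (learned) row per word

def embed(words):
    ids = [word_to_id[w] for w in words]
    one_hot = np.eye(len(vocab))[ids]  # (len(words), vocab_size)
    return one_hot @ E                 # same result as the row lookup E[ids]

X1 = embed(["this", "is", "a", "river", "bank"])
X2 = embed(["this", "is", "a", "business", "bank"])
# "bank" has the same vector in both sentences: the representation is context-free
print(np.allclose(X1[-1], X2[-1]))  # True
```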

Then, essentially, what the recurrent neural network does is combine these initial context-free word representations into context-aware ones:

The search for context-aware representations is indeed a crucial point in machine translation. For example, in the two sentences

This is a river bank

This is a business bank

you can only disambiguate the meaning of “bank” by looking at the previous word. Thanks to the recurrent structure, which looks behind at each step, the vector representation of “bank” can effectively be adjusted depending on the sentence. In this way, the Decoder receives a disambiguated representation of “bank”, which makes it easier to steer it towards the appropriate translation.

But what about those cases when we instead need to look ahead in order to disambiguate a word?

I arrived at the bank after crossing the river

I arrived at the bank after crossing the road

We can just use a bidirectional network, and everything works fine. The final word representations are obtained by concatenating the hidden states from the two networks:

Using a bidirectional RNN, words can be disambiguated by looking at other words in both directions (behind and ahead).

And if a single layer is not enough to obtain satisfactory representations, we can build a more complex architecture by stacking multiple recurrent layers on top of each other:

A bidirectional RNN with two stacked layers. The hidden states of the first layer are passed as input to the second layer.
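A sketch of a two-layer bidirectional RNN, again assuming simple tanh cells with random parameters: the backward network reads the sentence in reverse, its states are flipped back into place and concatenated with the forward states, and the second layer consumes the first layer's output. Names such as `bidirectional` are illustrative.

```python
import numpy as np

rng = np.random.default_rng(3)

def make_cell(in_dim, hid_dim):
    """Hypothetical random parameters for a tanh recurrent cell."""
    return rng.normal(size=(hid_dim, in_dim)), rng.normal(size=(hid_dim, hid_dim))

def run(cell, xs):
    W_x, W_h = cell
    h = np.zeros(W_h.shape[0])
    out = []
    for x in xs:
        h = np.tanh(W_x @ x + W_h @ h)
        out.append(h)
    return np.stack(out)

def bidirectional(fwd, bwd, xs):
    forward = run(fwd, xs)               # state i has seen words 1..i
    backward = run(bwd, xs[::-1])[::-1]  # state i has seen words i..n, realigned to position i
    return np.concatenate([forward, backward], axis=1)

emb, hid = 4, 3
X = rng.normal(size=(8, emb))  # 8 stand-in word vectors
fwd1, bwd1 = make_cell(emb, hid), make_cell(emb, hid)
layer1 = bidirectional(fwd1, bwd1, X)  # (8, 6)
# The second layer takes the first layer's concatenated states as its input
layer2 = bidirectional(make_cell(2 * hid, hid), make_cell(2 * hid, hid), layer1)  # (8, 6)
```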

Limitations of the recurrent architecture

The addition of attention mechanisms boosted the translation capabilities of recurrent neural networks. However, the recurrent architecture was still inherently limited in its ability to connect related words and disambiguate them efficiently. What if related words were many positions away from each other, as in this case?

If two connected words are many steps away from each other, an RNN will have difficulty remembering and encoding the relationship. If only a direct connection existed…

Even if a bidirectional network is used to capture context from both sides, the network memory will still have a hard time remembering that “it” was connected to “museum”. It would certainly be more efficient to open a direct connection between these two words, but how can we do this without giving up the sequential architecture?

Actually, giving up the sequential architecture would be desirable here. Being forced to process the sequence in order means the computation cannot be parallelized, which dramatically increases training time and therefore severely limits the number of documents that can be used for training. But most NLP models need tons of training documents to learn all the subtleties and nuances of language, so this is a critical point.

The sequential architecture is a burden, but it nevertheless seems necessary: how else could we take into account the order of words in the sentence? Of course, order matters…

The lion chased the zebra and ate it

The zebra chased the lion and ate it

An order-agnostic (or bag-of-words) model can sometimes be acceptable for text classification or sentiment analysis, but it will simply not work for translation tasks!
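A short check makes the point: under a bag-of-words representation (here simply word counts via Python's `collections.Counter`), the two sentences above are indistinguishable.

```python
from collections import Counter

s1 = "The lion chased the zebra and ate it"
s2 = "The zebra chased the lion and ate it"

def bag_of_words(sentence):
    # Lowercase and count words, discarding their order entirely
    return Counter(sentence.lower().split())

print(bag_of_words(s1) == bag_of_words(s2))  # True: the model cannot tell who ate whom
```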

Positional encodings

Finally, someone came up with the idea of using positional encodings (or positional embeddings). The key fact is that information about the position of a word in a sentence can effectively be encoded in a vector — and, crucially, a low-dimensional one.

One may instinctively think that a vector of 10000 components is necessary to encode positions 1 to 10000. Actually, far fewer components are needed if we use the following formula:

Formula to calculate the jth component of the vector encoding position i

See this link for an intuitive explanation of why it works. In the following plot, rows correspond to progressive positional encodings for positions 1 to 10000 and are represented via a heatmap. Vectors have just 512 components:

Heatmap displaying positional encodings of 512 components (one per row) for positions 1 to 10000

Zooming in to the first 10 positions, we can appreciate how the encodings very gradually change as the position increases:

Same as above, zooming on positions 1 to 10

Why are these positional encodings such a breakthrough? By, e.g., concatenating them to the initial (context-free) word embeddings, we can get rid of the sequential architecture and open up direct connections between all pairs of words. The result is an entirely new architecture where the recurrent cells are replaced by an attention layer:

Actually, there is no need to concatenate at all: it is sufficient to add the encodings to the initial word embeddings (of course, the two must then have the same dimensionality). See here for a discussion on this point.
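The heatmap appears to show the sinusoidal encodings introduced in “Attention is all you need”. Assuming that is the pictured formula, a NumPy sketch looks like this: even components use a sine, odd components a cosine, and positions are counted from 0 rather than 1. The last line adds the encodings to stand-in word embeddings, as described above.

```python
import numpy as np

def positional_encoding(num_positions, d):
    """Sinusoidal encodings (assumed to match the formula shown above)."""
    pos = np.arange(num_positions)[:, None]       # (num_positions, 1)
    k = np.arange(d // 2)[None, :]                # index of each sin/cos pair
    angles = pos / np.power(10000.0, 2 * k / d)   # (num_positions, d // 2)
    pe = np.zeros((num_positions, d))
    pe[:, 0::2] = np.sin(angles)  # even components
    pe[:, 1::2] = np.cos(angles)  # odd components
    return pe

PE = positional_encoding(10000, 512)  # same shape as the heatmap: 10000 rows, 512 components
rng = np.random.default_rng(4)
word_embeddings = rng.normal(size=(6, 512))  # stand-in embeddings for a 6-word sentence
inputs = word_embeddings + PE[:6]  # summing keeps the dimension at 512
```

Because every component is bounded between -1 and 1 and varies at a different frequency, 512 components are enough to give each of the 10000 positions a distinct vector.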

In the new architecture, each word looks at all others when constructing its own context-aware representation and there are direct links between each pair of words. Moreover, training is highly parallelizable.

RNN with attention VS Attention-only network: global architecture comparison. This is a simplified representation with one single layer; in practice, multiple layers are stacked on top of each other in both architectures.

As with RNN, multiple attention layers are normally stacked on top of each other, the output representations of each layer being passed as input to the subsequent layer.

Attention Layer: how does it work?

Thanks to this innovation, attention was no longer just an additional feature of RNNs, but a new and independent architecture. But how does it all work exactly, i.e., what happens inside the green cells in the attention layer?

As usual, the first step is to turn words into initial context-free vector representations via a simple Embedding layer (the results are represented as dark green rectangles in the above picture). These initial, context-free representations will intuitively look like this:

This is a simplified 2-dimensional representation, but you can actually expect to see something similar if you, e.g., take pre-trained word2vec embeddings and project them onto a 2-dimensional subspace using PCA (word embeddings typically have several hundred dimensions).

Bank is somewhere in the middle between money and river, perhaps a bit more inclined towards money, since that is its most frequent use. The attention layer (the light-green cell) must learn to dynamically transform the representation of “bank” (for example) according to the other words in the sentence, just like the RNN already did:

This is indeed very similar to how disambiguation works in the human mind. RNNs achieved this using their own internal mechanisms, which will not be discussed here. What about the new attention layer?

We illustrate the functioning of attention mechanisms by focusing on a single sentence, “I am on the river bank”, and seeing what happens to the word “bank” (the same happens to all other words in the sentence).

First, all words go through the Embedding layer separately and are mapped to their initial context-free representations, including “bank”.

Then, the following steps are performed:

STEP 1: for each word in the sentence, calculate how related it is to “bank”. “Related” means that it can somehow influence or help disambiguate / better understand the meaning of bank, or vice versa. We will see in the section Mathematical details how “relatedness” can be measured quantitatively.

In this case, river is the only word in the sentence that is significantly “related” to bank

STEP 2: once all “relatednesses” have been calculated, calculate the attention scores of bank towards each of the other words. Attention scores are higher for words that are highly related to “bank” compared to the other words in the sentence.

Attention between the same two words may therefore change depending on the particular sentence. Consider the following examples:

In the first sentence, “river” is the only word highly related to “bank”, so the attention of “bank” is almost entirely directed towards “river”. In the second sentence, instead, “laying” is also highly related to “bank” (one would lay on a bank = river bank, not on a bank = business bank), so the attention of “bank” is partially directed towards “river” and partially towards “laying”. This implies that the attention score of “bank” towards “river” is higher in the first than in the second case. Attention is a finite quantity that must be distributed among all related words.

STEP 3: move the representation of “bank” closer to those of the words with the highest attention scores. Geometrically, what happens is something like this (brackets [ ] denote the initial, context-free word representations):

If instead the sentence had been “I put ten millions in bank”, “bank” would have been pulled towards the money-related words instead.

In other words, the final representation of the target word (“bank”, in this case) is a weighted average of all the words in the sentence (including the word itself), with weights given by the attention scores. These scores are constructed so that they lie between 0 and 1 and sum up to 1.

The same process is repeated for all other words in the sentence (actually, all contextualized representations are computed at once using matrices; see Mathematical details).
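The three steps can be sketched in a few lines of NumPy on hand-made 2-dimensional embeddings. The vectors are invented for illustration, with one axis loosely standing for “river-ness” and the other for “money-ness”; as in the description above, the affine transformation and the normalizing factor are left out.

```python
import numpy as np

# Invented context-free embeddings: axis 0 ~ "river-ness", axis 1 ~ "money-ness"
emb = {
    "I": [0.1, 0.0], "am": [0.0, 0.1], "on": [0.1, 0.1], "the": [0.05, 0.05],
    "river": [1.0, 0.0], "bank": [0.4, 0.6],  # "bank" leans slightly towards money
}
sentence = ["I", "am", "on", "the", "river", "bank"]
X = np.array([emb[w] for w in sentence])  # (6, 2)

relatedness = X @ X.T  # STEP 1: dot product between every pair of words
scores = np.exp(relatedness - relatedness.max(axis=1, keepdims=True))
attention = scores / scores.sum(axis=1, keepdims=True)  # STEP 2: softmax by row
contextual = attention @ X  # STEP 3: each row is a weighted average of all words

def cosine(a, b):
    return a @ b / (np.linalg.norm(a) * np.linalg.norm(b))

bank_before, bank_after = X[-1], contextual[-1]
river = np.array(emb["river"])
# The direction of "bank" rotates towards "river"
print(cosine(bank_before, river), cosine(bank_after, river))
```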

An important aspect of attention is that it is not symmetric. Consider, for example, the words “river” and “bank” in the following sentence: “The water was flowing on the river bank due to the floods”.

Here, there are many words that are highly related to “river”: besides “bank”, we have “water”, “flowing” and “floods”, which indicate that the river is a swollen river, not a dry one. This may be relevant in translation (there may exist some exotic language where there is a single and specific word for “swollen river”…). Conversely, “bank” is only directly related to “river”. “Water”, “floods” etc., by themselves, add little to our understanding of the meaning of “bank”.

This implies that “bank” will pay a lot of attention to “river”; conversely, “river” will not pay so much attention to “bank”, since its attention will be distributed among all related words.

Attention visualized

The attention scores calculated in STEP 2 can be stored in a matrix whose entry ij is the attention that word i pays to word j — which can be seen as a measure of how relevant word j is to disambiguate / better understand the meaning of word i. The matrix of attention scores is typically visualized with a heatmap (lighter squares indicate higher attention scores):

But there is still one open point: how can the model perform STEP 1 — i.e., how does it know when two words are “related”?

The secret is in the computation of the initial, context-free word representations: they must be placed sensibly, so that related words are somehow geometrically close to each other. In other words, a correspondence must be established between geometric and linguistic properties. Intuitively, the left schema would work, the right schema would not:

Of course, we don’t have to hand-craft these initial representations — the Embedding layer will take care of learning them during training.

Unlike attention, “relatedness” is symmetric.
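A quick numerical check of the asymmetry, using invented 3-dimensional vectors for five words of the “floods” sentence, chosen so that “river” is related to every word and “bank” only to “river”: the relatedness matrix equals its transpose, but the attention matrix does not.

```python
import numpy as np

# Invented embeddings: "river" relates to every word, "bank" only to "river"
words = ["water", "flowing", "river", "bank", "floods"]
X = np.array([
    [1.0, 0.0, 0.0],   # water
    [1.0, 0.0, 0.2],   # flowing
    [1.0, 1.0, 0.0],   # river
    [0.0, 1.0, 0.0],   # bank
    [1.0, 0.0, -0.2],  # floods
])

R = X @ X.T  # relatedness: symmetric by construction
A = np.exp(R - R.max(axis=1, keepdims=True))
A = A / A.sum(axis=1, keepdims=True)  # attention: softmax by row

i_river, i_bank = words.index("river"), words.index("bank")
# "bank" -> "river" vs "river" -> "bank"
print(round(A[i_bank, i_river], 3), round(A[i_river, i_bank], 3))  # 0.322 0.149
```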

Mathematical details

The output of a generic attention layer is given by the following formula (this is the one you will find in all articles; n denotes the embedding dimension):

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{n}}\right)V$$

This formula can be a bit confusing, since it is very general and includes many other cases besides the one we have just examined. In our case, the three matrices coincide, i.e. Q = K = V. Q is a matrix of shape (input_seq_len, embedding_dim) whose rows contain the initial word representations calculated by the Embedding layer (= the dark green rectangles at the lower layer, positional encoding included).

The entries of Q are learnt during training.

The “relatedness” scores in STEP 1 are calculated as

$$QQ^\top$$

This is a symmetric matrix of shape (input_seq_len, input_seq_len) whose entry ij (or ji) provides a measure of how “related” words i and j are. This measure is just the dot product between the two vector representations. The dot product is equal to the cosine similarity in the case when the two vectors are normalized, but does not provide a proper distance in the general case. Nevertheless, the authors chose not to normalize the rows, perhaps to leave the network some additional flexibility (the Embedding layer is still free to learn normalized vectors, if it deems it appropriate).

In order to turn the relatedness scores into attention scores (STEP 2), you then simply take the softmax by row. Before doing so, you divide the relatedness scores by a fixed normalizing factor, √n, in order to obtain a “gentler” softmax (otherwise, only very highly related words would get significant attention scores). The resulting matrix is made of numbers between 0 and 1, with each row summing up to 1.

Non-symmetric matrix of attention scores; shape = (input_seq_len, input_seq_len)
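To see why the normalizing factor helps, compare the softmax of raw and scaled dot products between random stand-in vectors of dimension 512. Dividing by √n acts like raising a softmax temperature, so it can only lower the largest weight and spread attention more evenly:

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

rng = np.random.default_rng(5)
n = 512                          # embedding dimension
q = rng.normal(size=n)           # stand-in representation of one word
keys = rng.normal(size=(6, n))   # stand-in representations of six words
raw = keys @ q                   # unscaled relatedness: its spread grows with n
print(softmax(raw).max())                # often very close to 1 for large n
print(softmax(raw / np.sqrt(n)).max())   # noticeably gentler
```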

Finally, you multiply the above matrix by Q. The result is a matrix of shape (input_seq_len, embedding_dim) whose rows contain the new word representations, each of which has therefore been obtained as a weighted average of the initial word representations.

A few additional remarks:

  • Q’s rows are transformed via a (learned) affine transformation prior to undergoing steps 1, 2 and 3 (or, said otherwise, Q is passed through a Dense layer with linear activation). This is itself a simplification: the original Transformer learns three separate projections, one each for queries, keys and values, so after projection Q, K and V no longer coincide.
  • As already mentioned, multiple attention layers are typically stacked on top of each other. This means that the output of step 3 is transformed via another learned affine transformation and then passed again through steps 1, 2 and 3… And so on.
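Putting the formula and these remarks together, a minimal NumPy sketch of a stack of simplified self-attention layers might look as follows. It follows this section's simplification (a single affine transformation, then Q = K = V); the parameters are random stand-ins and the helper names are illustrative.

```python
import numpy as np

def softmax_rows(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def attention(Q, K, V):
    """softmax(Q K^T / sqrt(n)) V, with n the embedding dimension."""
    n = Q.shape[-1]
    scores = softmax_rows(Q @ K.T / np.sqrt(n))  # (len_q, len_k); each row sums to 1
    return scores @ V, scores

def self_attention_layer(X, W, b):
    """Simplified layer from the text: one learned affine map, then Q = K = V."""
    Q = X @ W + b
    return attention(Q, Q, Q)

rng = np.random.default_rng(6)
seq_len, dim = 5, 8
X = rng.normal(size=(seq_len, dim))  # stand-in for embeddings + positional encodings
out = X
for _ in range(2):  # two stacked layers, each with its own hypothetical parameters
    W, b = rng.normal(size=(dim, dim)) / np.sqrt(dim), np.zeros(dim)
    out, scores = self_attention_layer(out, W, b)
```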

Much more than semantics

The previous examples may lead to the idea that word representations only capture the words’ meaning in a very strict sense. This would imply that only nouns, adjectives and some verbs are actively involved in the attention process, while particles like articles, conjunctions etc. are to be deemed irrelevant.

In fact, contextualized word representations model much richer linguistic properties than basic “word meaning”. This should not be surprising, after all. Consider the following two sentences:

I can’t bear eating meat

I saw a bear eating meat

“Meaningful” words are the same in the two sentences: what makes it clear that bear is “a bear” VS “to bear” is the presence of apparently “meaningless” particles like “a” or “can’t”. Word representations must be able to capture all these types of relations. To help attention layers in this challenging task, multi-head attention was introduced. In practice, this means that multiple word representation systems (typically 8 to 16) are computed in parallel, giving birth to a bunch of independent attention score matrices:

Multi-head self-attention. Source: https://docs.dgl.ai/en/0.4.x/tutorials/models/4_old_wines/7_transformer.html

This is somewhat similar to using multiple filters in convolutional neural networks: each attention head is meant to capture a particular linguistic property. In the end, we can still obtain a single contextualized word representation by concatenating the representations calculated by the different heads.
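A sketch of multi-head self-attention, assuming each head has its own hypothetical query, key and value projections (as in the original Transformer) of size `dim // heads`. The final linear layer the Transformer applies after concatenation is omitted for brevity.

```python
import numpy as np

def softmax_rows(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def multi_head_self_attention(X, projections):
    """One attention computation per head, then concatenation of the head outputs.
    `projections` is a list of hypothetical (W_q, W_k, W_v) triples, one per head."""
    outputs, all_scores = [], []
    for W_q, W_k, W_v in projections:
        Q, K, V = X @ W_q, X @ W_k, X @ W_v  # each head gets its own view of the words
        d_head = Q.shape[-1]
        scores = softmax_rows(Q @ K.T / np.sqrt(d_head))  # independent attention matrix
        outputs.append(scores @ V)
        all_scores.append(scores)
    return np.concatenate(outputs, axis=-1), np.stack(all_scores)

rng = np.random.default_rng(7)
seq_len, dim, heads = 5, 16, 8
d_head = dim // heads
X = rng.normal(size=(seq_len, dim))
proj = [tuple(rng.normal(size=(dim, d_head)) for _ in range(3)) for _ in range(heads)]
out, scores = multi_head_self_attention(X, proj)  # out: (5, 16); scores: (8, 5, 5)
```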

A world of applications

Summing up, attention mechanisms allow us to very efficiently calculate contextualized vector representations for any word in any sentence. Thanks to positional encodings, we can even include positional information in the word representation without the need for a sequential architecture.

The same mechanisms work very well whenever a finite set of interacting discrete inputs (either sequential or not) must be processed, as in video games. Google’s DeepMind recently released AlphaStar, a model that learned to play StarCraft and defeated top human players. The model was first trained on historical game data in a supervised way, then started learning entirely new strategies by repeatedly playing against itself within a reinforcement learning framework.

Guess what, attention mechanisms came into play in the architecture definition, modeling relations between the various game elements (characters, buildings, etc.). The basic idea is the same as before: assume you want to decide the best move for your character or calculate your victory odds. You can first map your “character” (as well as all other game elements) to a vector using a simple Embedding layer and get an initial context-free representation of it, encoding properties like “I am a hunter, I am humanoid, I have these weapons”. In the attention layer, your character “pays attention” to the other game elements and dynamically adjusts its representation to take context into account. For example, it may notice that a dragon is about to kill it, or that a huge stone is about to fall on its head. The character’s final context-aware representation may encode something like “I am a hunter, a humanoid… And I’m not very well off at all”. Passing this vector to a final regression layer, we can find out that the odds of winning are pretty low in this case. This is a very simplified explanation, but it should give an idea of the potential of attention mechanisms.

Cracking translation: the Transformer

But let’s now go back to translation, which was our initial problem. We had left off with RNN-with-attention being state-of-the-art. Can we move forward, now that we are equipped with powerful attention-only networks?

As we have seen, this new architecture allows us to very effectively disambiguate all words in any sentence and any language. Also, the Embedding layer provides initial context-free word representations that are sensibly placed. If this is the situation, it should be pretty easy to go from, e.g., English to Italian.

However, there is an issue. Consider what happens, for example, if we rotate all Italian representations by the same angle. What we obtain is still a perfectly valid representation of the Italian language: all distances and relationships between words have remained the same. However, any alignment with their English counterparts is lost. In fact, if two language models are trained independently, we have no guarantee that the resulting representations will be “aligned” in any way.
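A small NumPy check of the rotation argument, with random stand-in 2-dimensional “Italian” embeddings: rotating all of them by the same angle preserves every pairwise distance and dot product, yet changes every coordinate.

```python
import numpy as np

rng = np.random.default_rng(8)
italian = rng.normal(size=(5, 2))  # stand-in 2-D embeddings for five Italian words

theta = np.pi / 3  # rotate everything by 60 degrees
R = np.array([[np.cos(theta), -np.sin(theta)],
              [np.sin(theta),  np.cos(theta)]])
rotated = italian @ R.T

def pairwise_distances(X):
    return np.linalg.norm(X[:, None, :] - X[None, :, :], axis=-1)

# Every distance between words is preserved...
print(np.allclose(pairwise_distances(italian), pairwise_distances(rotated)))  # True
# ...but the coordinates themselves have moved, so any alignment with English vectors is lost
print(np.allclose(italian, rotated))  # False
```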

So, how to effectively connect these two worlds? This is exactly what the Transformer achieves.

We have this. How to make these two worlds communicate?

The Transformer was introduced in 2017 and immediately became state-of-the-art for translation, setting the basis for the modern Google Translate. It was the first model for translation based entirely on attention mechanisms — in fact, the paper is eloquently titled “Attention is all you need”. Its name stems from the fact that the Transformer, well, transforms sentences from one language to another.

A distinct Transformer model must be trained for each translation direction: e.g., three distinct Transformer models are needed to perform Italian-to-English, English-to-Italian and French-to-English translation. This tutorial shows a complete TensorFlow / Keras implementation.

Let us examine the architecture of an Italian-to-English Transformer model. We will omit some architectural details (layer normalization, dropout…) to focus on the core components. The Transformer is made of two parts, the Encoder and the Decoder, just like previous models based on recurrent networks. The difference is that the inner mechanisms are now entirely based on attention.

Training the Transformer

During training, pairs of corresponding sentences in the source and target language are fed into the Encoder and Decoder, respectively.

The Encoder performs the exact steps we have just seen: it uses an attention layer (or, more precisely, multiple stacked attention layers) to compute context-aware representations of each word in the source sentence. From a functional point of view, this is exactly what the Encoder part of the RNN-based architecture was doing: the difference is merely in the increased efficiency of the network’s internal mechanisms.

The Decoder does essentially the same as the Encoder with the sentence in the target language, but adds a second attention-based layer that somehow aligns word representations in the target language with those in the source language. We call this additional layer “cross-attention layer” to distinguish it from the “self-attention” layer we have considered so far.

Transformer architecture: the Decoder takes the Encoder’s output as a side input to produce contextualized word representations in the target language that are also somehow “aligned” to those in the source language.

Let us see in detail what happens inside the Decoder by focusing on the word “pizza” in the above example.

First, all words in the target sentence are mapped to initial context-free representations (+ positional encoding) via a simple Embedding layer, including “pizza”.

In the self-attention layer, “pizza” adjusts its own representation according to the other words in the sentence (Joe and likes) applying the mechanisms described in the previous paragraphs. In this case, Joe and likes add little to our understanding of pizza, so let’s say that the context-aware representation of “pizza” (output of the self-attention layer) is just

(we are omitting the preliminary affine transformation for simplicity).

This output is then passed to the cross-attention layer, whose mechanisms are analogous to those of the self-attention layer. Just like the self-attention layer learns to express each word in the sentence in terms of the other words in the same sentence, the cross-attention layer learns to express each word in the sentence in terms of the words of another sentence — in this case, it learns to express English words in terms of the Italian words (or, better to say, in terms of the context-aware representations of Italian words that are provided as output from the Encoder). In this case, we may expect something like

The same process is repeated for all other words in the target sentence. The final Decoder outputs are context-aware representations of the English words that are also somewhat “aligned” with the context-aware representations of the Italian words.

The Transformer: focus on the Decoder’s internal layers

As in the self-attention layers, cross-attention scores can be plotted in a matrix and help us visualize the correspondences between the two languages. Again, we can use multiple attention heads to capture more linguistic properties.

Cross-attention scores for an English-to-German translation task (source: https://docs.dgl.ai/en/0.4.x/tutorials/models/4_old_wines/7_transformer.html)

Mathematical details

Again, let’s dive a bit into the mathematical details. We said that the Decoder learns to “express English words in terms of the Italian words”; let us see how this is accomplished in practice.

The self-attention layer works exactly as in the Encoder. The formula for the output of the cross-attention layer is also the usual one

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{n}}\right)V$$

but, this time, K = V and Q is different. Precisely, Q’s rows contain the context-aware representations of the English words calculated by the Decoder’s self-attention layer, while K’s rows contain the final (context-aware) representations of the Italian words — i.e., the Encoder’s outputs:

If you have a closer look at the formula, you will see that the cross-attention layer actually outputs new representations of the English words as weighted averages of the (context-aware representations of the) Italian words.

As before, the rows of both Q and K undergo learned affine transformations before being passed to the cross-attention layer.
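A sketch of the cross-attention computation with random stand-in matrices, omitting the affine transformations: Q comes from the Decoder's self-attention layer (English words), K = V from the Encoder (Italian words), and each output row is a weighted average of the Encoder rows. The function name is illustrative.

```python
import numpy as np

def softmax_rows(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def cross_attention(decoder_states, encoder_outputs):
    """Q from the Decoder; K = V from the Encoder (affine maps omitted)."""
    Q, K = decoder_states, encoder_outputs
    n = Q.shape[-1]
    scores = softmax_rows(Q @ K.T / np.sqrt(n))  # (english_len, italian_len)
    return scores @ K, scores  # each English row becomes a weighted average of Italian rows

rng = np.random.default_rng(9)
italian = rng.normal(size=(4, 8))  # stand-in Encoder outputs for a 4-word Italian sentence
english = rng.normal(size=(3, 8))  # stand-in self-attention outputs for "Joe likes pizza"
out, scores = cross_attention(english, italian)
```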

Transformer: some results

Let’s have a look at some outputs from Google Translate to see the power of the Transformer model. Words are correctly disambiguated…

… and even pronouns.

Positional embeddings do work!

All great, but… Wait! Isn’t there something we have forgotten? There are still a couple of things to clarify.

What is the target?

First, we should clarify what target the Transformer is trained on. During training, two special tokens, [START] and [END], are added at the beginning and at the end of each sentence in the target language (actually, the same is done to sentences in the source language, but we can ignore this detail for now). The resulting training set is

Then, the Transformer is trained to predict… Well, just the target sentence shifted ahead by one position.

Using the same pictures as before,

Sounds like a trivial task, right?

Based on what we have seen, the Decoder’s internal structure should be the following:

So, does the Decoder trivially do this when predicting, for example, “Joe”?

Actually, no, because this shortcut is made illegal in the Decoder!

Precisely, all backward paths are forbidden. The self-attention layer in the Decoder has a fundamental difference from that in the Encoder: each word can only look behind when constructing its own context-aware representation.

The actual Decoder architecture

This is a form of masked self-attention: some elements cannot pay attention to some other elements, i.e., some connections are closed. This type of masking is called look-ahead masking, in that each word can only “look behind” itself and not ahead. In practice, the “illegal” entries of the relatedness matrix are set to -inf, so that the softmax will effectively ignore those entries. The attention scores matrix will look like this, with some elements obscured. Each word’s attention will thus only be distributed among “legally connected” elements.
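A minimal sketch of look-ahead masking on random stand-in vectors: illegal entries are set to -inf before the row-wise softmax, so each row's attention is spread only over the current and previous positions. The helper `masked_softmax_rows` is a hypothetical name.

```python
import numpy as np

def masked_softmax_rows(relatedness, allowed):
    """Set illegal entries to -inf before the row-wise softmax, so they get zero attention."""
    masked = np.where(allowed, relatedness, -np.inf)
    e = np.exp(masked - masked.max(axis=-1, keepdims=True))  # exp(-inf) = 0
    return e / e.sum(axis=-1, keepdims=True)

seq_len = 5
look_ahead = np.tril(np.ones((seq_len, seq_len), dtype=bool))  # word i may see words 0..i only
rng = np.random.default_rng(10)
X = rng.normal(size=(seq_len, 4))
attention = masked_softmax_rows(X @ X.T / np.sqrt(4), look_ahead)
print(np.round(attention, 2))  # the upper triangle is exactly 0
```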

If you examine the Decoder target more closely, you will realize that this is nothing but a compact way to train the Decoder to predict the next word for all subsentences in the input sentence:

Instead of passing all subsentences one by one, we can exploit the look-ahead masking trick to pass just the full sentence and make all next word predictions in a single forward pass.
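This equivalence can be checked numerically. The sketch below uses a deliberately simplified self-attention (no learned projections: queries, keys and values are all the input rows) and applies look-ahead masking by scoring only positions j <= i, which is equivalent to setting the other scores to -inf. The embeddings are made-up numbers.

```python
import math

def self_attention(x, causal):
    """Simplified single-head self-attention with Q = K = V = x (illustration only)."""
    d = len(x[0])
    out = []
    for i, q in enumerate(x):
        visible = range(i + 1) if causal else range(len(x))  # look-ahead mask
        scores = [sum(a * b for a, b in zip(q, x[j])) / math.sqrt(d) for j in visible]
        m = max(scores)
        w = [math.exp(s - m) for s in scores]
        total = sum(w)
        out.append([sum(w[k] * x[j][c] for k, j in enumerate(visible)) / total for c in range(d)])
    return out

emb = [[0.1, 0.3], [0.7, -0.2], [0.4, 0.9], [-0.5, 0.2]]  # hypothetical word embeddings
full = self_attention(emb, causal=True)  # one pass over the whole sentence

for n in range(1, len(emb) + 1):
    prefix_out = self_attention(emb[:n], causal=True)  # one pass per subsentence
    # The last position of each prefix matches the corresponding row of the full pass.
    assert all(abs(a - b) < 1e-12 for a, b in zip(prefix_out[-1], full[n - 1]))
```

Because row i never sees positions after i, appending more words to the sentence cannot change it; this is what makes the single masked pass a valid shortcut.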

So, we can say that the Transformer is trained to perform guided next word predictions in the destination language — guided by the sentence in the source language.

Again, multiple pairs of self-attention + cross-attention layers are stacked on top of each other in the Decoder.

A few more words about masking

The concept of masking is actually quite general: in principle, any connection can be declared illegal and excluded from the attention scores computation. This can be applied to other situations besides Natural Language Processing. For example, OpenAI recently released an RL-based (reinforcement learning) model teaching two teams of puppets to play hide-and-seek against each other. If you have a look at the model architecture, here is what you will find:

Similarly to AlphaStar, game elements are first mapped to context-free embeddings, then change their representations by “paying attention” to each other in the self-attention layer. And, look, the self-attention layer is masked! But this is no longer a look-ahead masking: similarly to what would happen in the real world, puppets are simply not allowed to pay attention to elements that are out of their vision cone and line of sight.
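As an illustration of a mask that is not look-ahead, the sketch below builds a boolean matrix from a hypothetical 2D vision cone: entity i may attend to entity j only if j lies within half_angle of i's heading. Positions, headings and cone width are invented, line-of-sight occlusion is ignored, and this is not OpenAI's actual implementation; the point is only that any such matrix can drive the same -inf masking.

```python
import math

def vision_mask(positions, headings, half_angle):
    """mask[i][j] is True when entity j lies inside entity i's (hypothetical) vision cone.

    positions: list of (x, y); headings: facing angle in radians for each entity.
    Occlusion by walls or other objects is ignored in this sketch.
    """
    n = len(positions)
    mask = []
    for i in range(n):
        row = []
        for j in range(n):
            if i == j:
                row.append(True)  # an entity can always attend to itself
                continue
            dx = positions[j][0] - positions[i][0]
            dy = positions[j][1] - positions[i][1]
            angle = math.atan2(dy, dx) - headings[i]
            angle = (angle + math.pi) % (2 * math.pi) - math.pi  # wrap to [-pi, pi)
            row.append(abs(angle) <= half_angle)
        mask.append(row)
    return mask

# Entity 0 faces +x, entity 1 faces -x, entity 2 faces +x.
mask = vision_mask([(0, 0), (1, 0), (-1, 0)], [0.0, math.pi, 0.0], math.pi / 4)
print(mask)  # [[True, True, False], [True, True, True], [True, True, True]]
```

Entity 0 cannot see entity 2, which stands behind it, so the corresponding attention score would be set to -inf.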

Another (a bit less exciting, but still important) form of masking is the padding mask. Both the Encoder and Decoder take fixed-length sequences as input. If you want to pass them sentences that are shorter than this fixed length, you have to first complete them with a series of so-called [PAD] tokens. For example, if the established fixed length is 10 and you want to pass the usual “Joe likes pizza” to the Decoder, then you have to first transform it to

[START] Joe likes pizza [END] [PAD] [PAD] [PAD] [PAD] [PAD]

(remember, you must add [START] and [END] tokens as well!).

But, of course, you don’t want “Joe” and “pizza” to pay attention to the [PAD] token when constructing their own context-aware representations. Therefore, all [PAD] tokens must be masked in both the self-attention and cross-attention layers and in both the Encoder and the Decoder. The padding mask must always be defined and passed as input to the Transformer.
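A minimal sketch of padding and of combining the padding mask with the look-ahead mask for the Decoder's self-attention, assuming right-padding to a fixed length and boolean masks where True means "allowed to attend". The helper names are illustrative, not taken from any particular library.

```python
PAD = "[PAD]"

def pad_sequence(tokens, length):
    """Right-pad a token list with [PAD] up to a fixed length."""
    if len(tokens) > length:
        raise ValueError("sequence is longer than the fixed length")
    return tokens + [PAD] * (length - len(tokens))

def padding_mask(tokens):
    """key_ok[j] is True when position j holds a real token rather than [PAD]."""
    return [t != PAD for t in tokens]

def decoder_self_attention_mask(tokens):
    """Word i may attend to word j only if j <= i (look-ahead) and j is not [PAD]."""
    key_ok = padding_mask(tokens)
    n = len(tokens)
    return [[j <= i and key_ok[j] for j in range(n)] for i in range(n)]

seq = pad_sequence(["[START]", "Joe", "likes", "pizza", "[END]"], 10)
mask = decoder_self_attention_mask(seq)
print(mask[2])  # [True, True, True, False, False, False, False, False, False, False]
```

For the Encoder's self-attention and for cross-attention, the padding mask would be built from the source sentence's tokens instead. Rows belonging to [PAD] positions still produce outputs, but those outputs are simply ignored.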

Generating the translation

All right, so the Transformer is trained! Now, how can we generate new translations from scratch for sentences that were not in the training set?

First, the sentence to be translated is passed as input to the Encoder, while the Decoder just takes the [START] token as input. The predicted word is the first word of the final translation. Then, the [START] token and this first generated word are passed to the Decoder, which predicts the second word of the translation. The process is repeated until the [END] token is predicted.
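The generation loop can be sketched as follows. Here predict_next is a hypothetical stand-in for the trained Transformer (it returns the most likely next token given the source sentence and the tokens generated so far), and the toy lookup table below only mimics a model that has learned one translation. Always taking the single most likely word is greedy decoding; real systems often use beam search instead.

```python
def greedy_translate(source, predict_next, max_len=50):
    """Generate a translation token by token until [END] is predicted (or max_len is hit)."""
    prefix = ["[START]"]
    while len(prefix) < max_len:
        next_token = predict_next(source, prefix)
        if next_token == "[END]":
            break
        prefix.append(next_token)  # feed the generated word back to the Decoder
    return prefix[1:]  # drop [START]

# Hypothetical toy "model": looks only at the last generated token.
TOY_NEXT = {"[START]": "Joe", "Joe": "likes", "likes": "pizza", "pizza": "[END]"}

def toy_predict_next(source, prefix):
    return TOY_NEXT[prefix[-1]]

print(greedy_translate("<sentence in the source language>", toy_predict_next))
# ['Joe', 'likes', 'pizza']
```

The max_len guard stops generation even if the model never predicts [END].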

What about monolingual tasks?

The Transformer was born to tackle the problem of translation, but its potential to generalize to monolingual problems was immediately clear. In fact, by “disassembling” the Transformer, we obtain the two models that are currently state-of-the-art for NLP: BERT and gpt-2.

BERT

BERT is an Encoder-only model and the current state-of-the-art for almost all NLP tasks. (Actually, the original BERT has already evolved into more advanced models, just like the Transformer has already been further developed into the Reformer and gpt-2 into gpt-3. However, the main underlying ideas have remained the same).

Given an input sentence, BERT computes powerful context-aware representations of all words in the sentence using self-attention mechanisms. This is exactly what the Encoder does in the Transformer model. These representations can then be passed to custom additional layers to perform various NLP tasks such as sentence classification, sentiment analysis, question answering, etc.

Actually, BERT (like many other language models) works with tokens and not with words: input sentences are split into tokens, not into words. If you have a look at the vocabulary file (vocab.txt) of the multilingual BERT model, what you will see is a list of short sequences of characters and symbols in multiple alphabets, from Latin to Chinese, some of which will appear meaningless. This drastically reduces the vocabulary size, because rare words can be assembled from shorter, frequent pieces instead of needing their own entries. A contextualized representation is calculated for each token, and all principles discussed before are still valid, except that they are applied at a more abstract level.
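BERT's tokenizer (WordPiece) splits each word greedily, always taking the longest piece found in the vocabulary, and marks word-internal pieces with a "##" prefix. The sketch below shows only that core loop, under simplifying assumptions: a tiny invented vocabulary, no lowercasing or punctuation splitting, and no maximum word length.

```python
def wordpiece_tokenize(word, vocab, unk="[UNK]"):
    """Greedy longest-match-first split of a single word into sub-word tokens."""
    tokens, start = [], 0
    while start < len(word):
        end = len(word)
        piece = None
        while start < end:  # try the longest candidate first, then shrink it
            candidate = word[start:end]
            if start > 0:
                candidate = "##" + candidate  # continuation of the same word
            if candidate in vocab:
                piece = candidate
                break
            end -= 1
        if piece is None:
            return [unk]  # no piece matches: the whole word becomes unknown
        tokens.append(piece)
        start = end
    return tokens

vocab = {"play", "##ing", "##ed", "un", "##believ", "##able"}  # toy vocabulary
print(wordpiece_tokenize("playing", vocab))       # ['play', '##ing']
print(wordpiece_tokenize("unbelievable", vocab))  # ['un', '##believ', '##able']
```

Six vocabulary entries are enough to cover "play", "playing", "played" and "unbelievable", which is the vocabulary-size saving described above.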

What task is BERT trained on? This time, what we do is randomly mask out some words in the input sentence (this physically means replacing them with the token word [MASK]) and train the model to re-predict them. Doing this on e.g. all English Wikipedia forces the model to learn nontrivial language structures.
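A minimal sketch of how masked training examples could be produced, assuming a 15% masking rate and a seeded random generator for reproducibility. It simplifies BERT's actual recipe: in the BERT paper, 15% of positions are selected and, of those, 80% become [MASK], 10% become a random token and 10% are left unchanged, whereas here every selected token becomes [MASK].

```python
import random

def mask_tokens(tokens, rate=0.15, seed=0):
    """Replace a random subset of tokens with [MASK]; return masked input and labels.

    labels[i] holds the original token where the model must re-predict it, else None.
    """
    rng = random.Random(seed)
    n_mask = max(1, round(rate * len(tokens)))
    positions = set(rng.sample(range(len(tokens)), n_mask))
    masked = ["[MASK]" if i in positions else t for i, t in enumerate(tokens)]
    labels = [t if i in positions else None for i, t in enumerate(tokens)]
    return masked, labels

sentence = "the cat sat on the mat because it was tired".split()
masked, labels = mask_tokens(sentence)
print(masked)  # the same sentence with a couple of words replaced by [MASK]
```

The loss is computed only at positions whose label is not None, so the model is rewarded for reconstructing the hidden words from the context on both sides.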

CAUTION: this “masked language prediction” task, i.e. this process of “masking out” and then re-predicting some words in the input text, has nothing to do with the previously discussed concept of masked attention (i.e., look-ahead masking, padding masking…)! Unfortunately, the terminology overlaps.

There is also a second and slightly more complex task BERT is trained on. We will discuss it later, in the section “Fine-tuning BERT for classification tasks”.

The BERT acronym, Bidirectional Encoder Representations from Transformers, should now be clear: context-aware representations of single words are built by looking at all other words in the text, both behind and ahead of the target word, in a bidirectional fashion.


gpt-2

gpt-2 is instead a Decoder-only model. You may wonder: what does “Decoder-only” mean? The Decoder as we have defined it makes no sense without the Encoder: it even takes the Encoder’s output as a side input!

Actually, the expression “Decoder-only” is a bit improper. In fact, gpt-2 does not keep the whole Decoder: the cross-attention layer is thrown away. What remains is basically a stack of self-attention layers with look-ahead masking applied.

The target is the same as in the Transformer: the input sentence shifted ahead by one position. Thanks to the look-ahead masking, this is equivalent to a nontrivial next word prediction task. Exactly as in the Transformer, except that, this time, the next word prediction task is free and no longer “guided” by a sentence in the source language.

Again, gpt-2 must be trained on an extremely large corpus (OpenAI used roughly 40 GB of text collected from web pages). As you may guess, the crucial difference with respect to BERT is that word representations are no longer bidirectional: they are built by only looking at words behind.

This seems like a weakness, since less context can be incorporated in the final representation of each word. In fact, gpt-2’s predecessor, named gpt, was sensationally defeated by BERT on all benchmark NLP tasks. However, gpt-2 (which is essentially gpt with more parameters, trained on a larger corpus) still outperforms BERT at one task: text prediction, which is equivalent to text generation. A hilarious report on the BERT-vs-gpt challenge can be found here (look for the Manager VS Random Engineer discussion).

There are some links where you can try out the full power of gpt-2. The quality of generated texts is really pretty impressive. Seems like giving up bidirectionality is not always a disadvantage, after all!

Fine-tuning gpt-2

The pre-trained gpt-2 model (which you can freely download and try out) tends to generate quite generic text, having been trained on such a large and heterogeneous corpus. However, it is quite straightforward to fine-tune gpt-2 on a corpus of your choice and teach it to generate text of a more specific type or genre. For example, you may fine-tune gpt-2 to generate Shakespeare plays, horror stories or rap songs — all with just a few lines of Python!

This interactive book is also a cool example of what you can achieve by fine-tuning gpt-2.

Unfortunately, a pre-trained version of gpt-2 is currently only available in English. However, you can generate text in English and then use the mtranslate library (which calls Google Translate APIs) to obtain high-quality text in virtually any language of your choice.

Fine-tuning BERT for classification tasks

While gpt-2 is specialized in text generation, BERT and its descendants are usually the way to go for all other NLP tasks. We will now discuss in more detail how to fine-tune BERT for a custom classification task, which is a fairly simple example and one of the most common tasks in industrial contexts (think, for example, of FAQ matching).

But, first, we have to say something about the second task that BERT is trained on (besides the “masked word prediction” task). This second task is called next sentence classification (the BERT paper calls it next sentence prediction). Basically, a training set is built by selecting pairs of sentences from the training corpus (e.g. Wikipedia). Some pairs are made up of consecutive sentences, while others consist of two randomly selected sentences. BERT is then trained to predict which pairs of sentences are consecutive and which are not.

The two sentences in each pair are not fed into the network separately, but concatenated and joined by a [SEP] token word (which is also added at the end of the second sentence). Moreover, an additional [CLS] token word is added at the beginning of the resulting concatenated sentence. Crucially, the final classification is performed using the output representation of the [CLS] token only. In fact, [CLS] is shorthand for “classification”.
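The input construction can be sketched as below, assuming already-tokenized sentences. One detail not mentioned above: BERT also receives segment ids (0 for the first sentence, 1 for the second) so that it can tell the two halves apart. The 50/50 split between consecutive and random pairs follows the BERT paper; the helper names and the tiny corpus are illustrative.

```python
import random

def build_pair_input(sent_a, sent_b):
    """Return [CLS] A [SEP] B [SEP] and matching segment ids (0 for A, 1 for B)."""
    tokens = ["[CLS]"] + sent_a + ["[SEP]"] + sent_b + ["[SEP]"]
    segment_ids = [0] * (len(sent_a) + 2) + [1] * (len(sent_b) + 1)
    return tokens, segment_ids

def make_pair_example(sentences, i, rng):
    """Label 1: sentence i and its true successor. Label 0: sentence i and a random other one.

    Assumes i is not the last sentence and that the corpus has at least three sentences.
    """
    if rng.random() < 0.5:
        return build_pair_input(sentences[i], sentences[i + 1]), 1
    candidates = [k for k in range(len(sentences)) if k not in (i, i + 1)]
    return build_pair_input(sentences[i], sentences[rng.choice(candidates)]), 0

corpus = [["Joe", "likes", "pizza"], ["He", "eats", "it", "daily"], ["Rain", "is", "expected"]]
(tokens, segments), label = make_pair_example(corpus, 0, random.Random(42))
```

Only the final representation at position 0 (the [CLS] token) is fed to the binary classifier that predicts label.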

In other words, it is only the final (context-aware) representation of [CLS] that is passed on to the binary classification layer. All other representations are basically thrown away for this secondary task (but they are still used for the masked word prediction task).

This architecture can be used with minor modifications to fine-tune BERT for single-sentence classification. In fact, it is also possible to pass a single sentence (instead of two concatenated sentences) to BERT, provided that it is enclosed between [CLS] and [SEP] tokens. The output representation of the [CLS] token can then be passed to a custom classification layer. Typically, a simple Dense(num_classes) layer followed by a softmax will be sufficient, perhaps with the addition of a Dropout layer in the middle.
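To show what the Dropout, Dense(num_classes) and softmax head computes, here is a pure-Python sketch operating on the [CLS] output vector. The weights, bias and vector values are invented for illustration; in practice they would be learned during fine-tuning, and a framework such as Keras would provide these layers directly.

```python
import math
import random

def classification_head(cls_vec, weights, bias, dropout_rate=0.1, training=False, rng=None):
    """Dropout -> Dense(num_classes) -> softmax applied to BERT's [CLS] output vector.

    weights: num_classes rows, each of length hidden_size; bias: num_classes values.
    """
    if training:
        rng = rng or random.Random()
        keep = 1.0 - dropout_rate
        # Inverted dropout: zero some units and rescale the rest to keep the expected value.
        cls_vec = [x / keep if rng.random() < keep else 0.0 for x in cls_vec]
    logits = [sum(w * x for w, x in zip(row, cls_vec)) + b for row, b in zip(weights, bias)]
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical 4-dimensional [CLS] vector and a 3-class head (e.g. Football, Economics, Fashion).
cls_vec = [0.5, -1.0, 2.0, 0.0]
W = [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0]]
probs = classification_head(cls_vec, W, [0.0, 0.0, 0.0])
```

At inference time dropout is switched off (training=False), so the head reduces to a single dense layer followed by a softmax.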

To conclude, a small Q&A session containing some questions I asked myself when I started playing with BERT, together with the answers I found after a few searches and discussions:

Q: Why use just the [CLS] token for sentence classification? I would be tempted to pass all output word representations to the classification layer.

A: In principle, the BERT architecture has been designed so that the [CLS] token should be able to provide an effective sentence representation (or sentence embedding) on its own. If this is the case, including the other output representations would just add computational burden without increasing performance.

Q: You said the [CLS] token should provide a sentence embedding. Then, if I download a pre-trained BERT model, feed it with a sentence of my choice (enclosed between [CLS] and [SEP] tokens, as prescribed) and retrieve the output representation of [CLS], I may use it as a sentence embedding and e.g. calculate cosine distances between sentences, right?

A: Unfortunately, the [CLS] token does not seem to work well for this particular use case. Performance is even worse than the notorious baseline of taking the average of the word2vec representations of the single words in the sentence as a sentence representation. This could be due to the fact that BERT has not been trained explicitly on sentence similarity tasks, unlike other models such as the Universal Sentence Encoder (USE), which are in fact more suitable for this particular use case. A variant of the original BERT model called Sentence-BERT has also been proposed to improve performance on sentence similarity. In general, one should keep in mind that there exists no universally “best” natural language model: the optimal choice depends on the particular problem under study. When using pre-trained models for transfer learning or other applications, a good practice is to always check which task(s) the original model has been trained on and to verify how well they align with the target task.

Q: What if I instead take all the output word representations (excluding the [CLS] and [SEP] tokens), average them out and use the resulting vector as a sentence embedding?

A: This approach actually works nicely when applied with context-free word representations, like those provided by word2vec (this is a well-known sentence embedding baseline to validate more complex models against). Unfortunately, and perhaps surprisingly, experimental results are not as satisfactory when BERT’s context-aware representations are used. One possible reason is that these representations tend to be very context-aware, i.e., very “polluted” by the surrounding context. The boundaries between single word and surrounding context are actually quite loose and difficult to define in BERT. For the same reason, it can be dangerous to use BERT’s word representations to e.g. compute cosine similarities between words or perform other word2vec-typical tasks.
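For reference, the averaging baseline discussed in the last two answers boils down to the computation below. The token vectors are made-up numbers; in practice they would come from word2vec or from BERT's output layer, and, as noted above, the results differ a lot depending on which one is used.

```python
import math

def mean_pool(token_vectors, tokens, skip=("[CLS]", "[SEP]")):
    """Average the vectors of all tokens except the special ones."""
    kept = [v for v, t in zip(token_vectors, tokens) if t not in skip]
    return [sum(col) / len(kept) for col in zip(*kept)]

def cosine(u, v):
    """Cosine similarity between two non-zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))

tokens = ["[CLS]", "joe", "likes", "pizza", "[SEP]"]
vectors = [[9.0, 9.0], [1.0, 2.0], [3.0, 4.0], [2.0, 3.0], [9.0, 9.0]]  # hypothetical
sentence_vec = mean_pool(vectors, tokens)
print(sentence_vec)  # [2.0, 3.0]
```

Two sentences would then be compared through cosine(sentence_vec_1, sentence_vec_2).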

Play with attention-based models

If you now want to play with attention-based models, here are some resources to start:

  • there are several tutorials available to experiment with BERT and gpt-2; huggingface provides PyTorch code and pre-trained models for nearly anything you could possibly need. If you are more familiar with the Keras framework, you can check out the notebooks on my GitHub. An interesting library to try is ktrain, which you can use to fine-tune BERT on custom classification tasks with only a few lines of code.
  • I found the Keras tutorial on the Transformer to be very instructive to understand the architecture in depth, including some details that have not been mentioned in the article (e.g. layer normalization and dropout).

Have fun!

[1] K. Cho et al., Learning Phrase Representations using RNN Encoder–Decoder for Statistical Machine Translation, Sep 2014

[2] I. Sutskever et al., Sequence to Sequence Learning with Neural Networks, Sep 2014

[3] D. Bahdanau, K. Cho, Y. Bengio, Neural Machine Translation by Jointly Learning to Align and Translate, Apr 2015

[4] Y. Wu et al., Google’s Neural Machine Translation System: Bridging the Gap between Human and Machine Translation, Oct 2016

[5] A. Vaswani et al., Attention Is All You Need, Jun 2017

[6] A. Radford et al., Improving Language Understanding by Generative Pre-Training, Jun 2018

[7] J. Devlin et al., BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding, Apr 2019
