Let’s build ‘Attention is all you need’ — 1/2

On why RNNs suck!

Yash Bonde
DataDrivenInvestor


Have you ever been with someone who talks a lot, and even though you pay close attention, you miss some of the details? Well, computers are absolutely horrible at that; they can't remember anything. Even the state-of-the-art machines out there can't remember things beyond a few steps. Information that depends on time, i.e. sequential information, is called temporal knowledge. And solving temporal problems has produced solutions no one thought were possible; think of the practically global language translator, Google Translate.

Most temporal deep learning models use a special type of neural network called a Recurrent Neural Network (RNN). Google's translation system uses hard-coded rules along with RNNs to produce seamless results across the languages of this planet. The main idea behind these models is to use the same network every time a new piece of information arrives, i.e. at each time step. Once all the input has been fed in, the input network has done what we call encoding. After encoding, a different network (though structurally similar) decodes that information and predicts the translated result.

The most commonly used network of this kind is the Long Short-Term Memory (LSTM), but there are variants suited to different requirements, such as Gated Recurrent Units (GRU), which perform exceptionally well in audio models. RNNs have been very effective at translation tasks, but they come with a number of drawbacks:

  1. They are computationally inefficient: most of the hardware that runs these behemoth networks today is GPUs, which are optimised to do one thing very well, matrix multiplication. RNNs also need matrix multiplications, but each time step depends on the output of the previous one, so the work cannot be parallelised across the sequence, and data has to be shuttled to and from memory every time the network is re-run.
  2. They are slow: as explained above, they take a long time to run. Take a sentence of 12 words: the encoder network has to be run 14 times (12 words plus the <START> and <END> tags), and the decoder about as many times, for a total of ~28 runs. That means storing and retrieving information from memory and pushing it to the processor 28 times for a single sentence! And according to the folks at Google, the average sequence length they get is 70 words.
  3. The distance each word has to travel: this is a bit tricky to explain in words, so look at the diagram below. It shows a simple RNN; ignore the x, W, h and y labels. The word ‘Echt’ has to travel through multiple steps before it reaches the last red layer, which stores the encoded information. For long sentences of over 50 words, the distance each word has to travel grows linearly with the sentence length. And since we keep overwriting that encoded information, we are likely to lose important words that come early in the sentence. After encoding, the information also has to travel on to its decoded destination.
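To make points 1 and 2 concrete, here is a minimal sketch of a vanilla RNN encoder in plain Python. It is only an illustration, not what Google or any library actually runs: the weights are random and untrained, the tiny sizes are arbitrary, and `rnn_encode` is a hypothetical helper. What it shows is the control flow: every step needs the hidden state from the step before, so the 12-word sentence really does take 14 strictly sequential steps.

```python
import math
import random

def rnn_encode(tokens, embed, W_x, W_h, b):
    """Run a vanilla RNN over `tokens`, strictly one time step at a time.

    Returns the final hidden state (the 'encoding') and the number of
    sequential steps taken.
    """
    hidden_size = len(b)
    h = [0.0] * hidden_size
    steps = 0
    for tok in tokens:
        x = embed[tok]
        # h_t = tanh(W_x x_t + W_h h_{t-1} + b): needs h from the previous step,
        # so this loop cannot be run in parallel across time steps.
        h = [
            math.tanh(
                sum(wx * xj for wx, xj in zip(W_x[i], x))
                + sum(wh * hj for wh, hj in zip(W_h[i], h))
                + b[i]
            )
            for i in range(hidden_size)
        ]
        steps += 1
    return h, steps

rng = random.Random(0)
emb_size, hidden_size = 4, 3

words = "this is a twelve word sentence used to count the encoder steps".split()
tokens = ["<START>"] + words + ["<END>"]

# Random (untrained) parameters: only the control flow matters here.
embed = {tok: [rng.uniform(-1, 1) for _ in range(emb_size)] for tok in tokens}
W_x = [[rng.uniform(-1, 1) for _ in range(emb_size)] for _ in range(hidden_size)]
W_h = [[rng.uniform(-1, 1) for _ in range(hidden_size)] for _ in range(hidden_size)]
b = [0.0] * hidden_size

encoding, steps = rnn_encode(tokens, embed, W_x, W_h, b)
print(len(words), "words ->", steps, "sequential encoder steps")
```

A decoder runs a similar loop again, which is where the ~28 sequential steps per sentence come from.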

The single most important issue with using RNNs is point number 3. To deal with it, new methods were introduced, the most popular of them being attention mechanisms. With an attention mechanism we no longer try to encode the full source sentence into a fixed-length vector. Instead, we allow the decoder to “attend” to different parts of the source sentence at each step of output generation. Attention mechanisms are loosely based on human visual attention: when we perform tasks on sequential data, like reading comprehension, we immediately look for similar words in the text and branch out around them to search for the answer.

The significant achievement of attention mechanisms was to improve the model's sense of where relevant information sits in a sentence, allowing it to focus on different parts of the sentence as it performs tasks like translation. A typical attention-based model looks like the one below.

The encoder still creates hidden states that are fed into the decoder, as in conventional models, but now at each decoder time step a context vector is also computed, as a weighted sum of all the encoder hidden states, and fed into the decoder. The weights express how relevant each source word is to the word currently being generated. The exact place and method by which these context vectors are added varies from model to model, but this is the gist of the idea. Attention models soon outperformed the conventional models and became the industry standard, from translation models at various companies to more specialised cases where RNNs were used.
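Here is a minimal sketch of how one context vector can be computed. The scoring function is an assumption: a plain dot product is used for brevity, while many attention models score with a small learned network instead. `context_vector` and the toy hidden states are hypothetical.

```python
import math

def softmax(scores):
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def context_vector(decoder_state, encoder_states):
    # Score every encoder hidden state against the current decoder state.
    scores = [sum(d * e for d, e in zip(decoder_state, enc)) for enc in encoder_states]
    # Turn the scores into attention weights that sum to 1.
    weights = softmax(scores)
    # Context = attention-weighted sum of all encoder hidden states.
    dim = len(encoder_states[0])
    context = [sum(w * enc[i] for w, enc in zip(weights, encoder_states)) for i in range(dim)]
    return context, weights

# Three toy encoder hidden states (one per source word) and one decoder state.
encoder_states = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
decoder_state = [2.0, 0.0]
context, weights = context_vector(decoder_state, encoder_states)
print(weights, context)
```

Source words whose hidden states align with the decoder state (the first and third here) get most of the weight, so the context vector is dominated by them. This is recomputed at every decoder step.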

Transformer Network

But even though we have solved the 3rd point, we are still left with the first two. This is where the transformer network comes into play. “Attention is all you need” is not only a very catchy title for a research paper but also a very appropriate one. The authors demonstrated that by using attention alone, with no recurrence, they outperformed the conventional models. But for me the greatest achievement is the clever way their model works. The network looks something like this:

We have an encoder on the left and a decoder on the right, and each is a stack of identical layers repeated N times (N = 6 in the paper). The most distinctive element is the Multi-Head attention mechanism that it uses.

The way it works is very different from the way conventional RNN models work, and it takes time to wrap your head around it, especially if, like me, you decide to code it first (i.e. being stupid). The simplest way to explain how this model runs is this: unlike a conventional model, we don't feed it the input word by word first and then fetch the output later.

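We feed it both the input and output sentences at the same time. The not-yet-generated positions of the output can initially be filled with anything, because the model ignores them (masking, explained below, takes care of that). It uses the entire input sentence and the output generated so far to predict the next word in a single pass. Once we predict a word, we write it into the output sequence; the model only considers the output up to that point and ignores what lies ahead of it. We continue doing that until we have a complete sentence. During training this loop is not needed at all: the full target sentence is fed in and, thanks to the mask, every next word is predicted in parallel.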
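Here is a minimal sketch of that generation loop. `toy_model` is a hypothetical stand-in for a trained Transformer (it just translates word by word with a tiny dictionary); the point is the loop itself: a buffer of placeholders, one prediction per pass, and each prediction written back into the buffer. The model only reads the buffer up to the current position, which is the job the mask does in the real network.

```python
PAD, START, END = "<PAD>", "<START>", "<END>"

def toy_model(source, target, position):
    """Hypothetical stand-in for a trained Transformer.

    It only looks at target[:position + 1]; everything after that is
    'in the future', so whatever is stored there cannot change the result.
    """
    visible = target[: position + 1]
    produced = len(visible) - 1  # words generated so far (excluding <START>)
    lookup = {"das": "that", "ist": "is", "echt": "really", "gut": "good"}
    return lookup[source[produced]] if produced < len(source) else END

def greedy_decode(model, source, max_len=10):
    # Output buffer: <START> followed by placeholders the model ignores.
    target = [START] + [PAD] * max_len
    for pos in range(max_len):
        next_word = model(source, target, pos)
        target[pos + 1] = next_word  # write the prediction back into the buffer
        if next_word == END:
            break
    # Keep the generated words, dropping <START>, <END> and leftover padding.
    return [w for w in target[1:] if w not in (PAD, END)]

print(greedy_decode(toy_model, ["das", "ist", "echt", "gut"]))
```

Swapping `PAD` for any other filler leaves the result unchanged, since the model never reads past the current position.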

Usage of Attention

To implement attention, the transformer network uses something called multi-head attention (MHA). The main idea behind attention is a lookup table: a table that stores values under keys, where you give it a query and it returns the value whose key is closest to the query. Attention is a soft version of this: instead of returning a single value, it returns a weighted average of all the values, where each weight reflects how similar the corresponding key is to the query. In the method used here we feed it three things: keys, values and queries. There are a large number of keys, each basically a vector in n-dimensional space, and each key has a corresponding value.
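The difference between a hard lookup and the soft, attention-style lookup can be sketched in a few lines. This assumes dot-product similarity (the kind the Transformer uses); `hard_lookup`, `soft_lookup` and the toy keys and values are hypothetical names for illustration.

```python
import math

def similarities(query, keys):
    # Dot product of the query with every key.
    return [sum(q * k for q, k in zip(query, key)) for key in keys]

def hard_lookup(query, keys, values):
    # Classic lookup: return the single value whose key is most similar.
    scores = similarities(query, keys)
    return values[scores.index(max(scores))]

def soft_lookup(query, keys, values):
    # Attention-style lookup: blend *all* values, weighted by softmax(similarity).
    scores = similarities(query, keys)
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    dim = len(values[0])
    return [sum(w * v[i] for w, v in zip(weights, values)) for i in range(dim)]

keys = [[1.0, 0.0], [0.0, 1.0]]
values = [[10.0, 0.0], [0.0, 10.0]]
query = [5.0, 0.0]
print(hard_lookup(query, keys, values))
print(soft_lookup(query, keys, values))
```

The soft version lands very close to the value of the best-matching key but still mixes in a little of the other one. Because every step is differentiable, the network can learn what its queries and keys should look like.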

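The attention mechanism used here is scaled dot-product attention: the similarity between a query and a key is their dot product, divided by √dk before the softmax. (The paper contrasts this with additive attention, which scores query-key pairs with a small feed-forward network instead.) The reasoning behind the formula is much easier to explain with a visualisation, so look at the image below: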

Query (Q), Key (K) and Value (V), with dk = dmodel / num_heads
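For reference, this is the scaled dot-product attention formula from the paper. Each row of QKᵀ holds one query's similarity to every key. Dividing by √dk keeps those dot products from growing with the dimension (the paper's reason is that large values push the softmax into regions with tiny gradients). The softmax turns each row into weights that sum to 1, and multiplying by V takes the weighted average of the values. With the paper's dmodel = 512 and 8 heads, dk = 64.

```latex
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d_k}}\right) V,
\qquad d_k = \frac{d_{\mathrm{model}}}{\mathrm{num\_heads}}
```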

MHA is just a larger version of the attention mechanism. Rather than applying one attention function over the text, we apply 8 attention heads (as in the paper), concatenate their outputs and pass them through a further linear projection. The main advantage of using more than one head is that each head can attend to information from a different representation subspace, so different heads may learn to attend to different things. Because each head works in only dk = dmodel/num_heads dimensions, the concatenated output has the same size, and roughly the same computational cost, as a single full-size head. Scaled Dot-Product Attention is the one explained above. The unit looks as follows:

The masking is optional; its use is explained below
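Here is a minimal sketch of multi-head self-attention in plain Python, following the recipe described in the paper: project the input to dk dimensions per head, run scaled dot-product attention in each head, concatenate the heads and project back with an output matrix. The projection matrices are random stand-ins for learned weights, helper names like `matmul` and `random_matrix` are hypothetical, and the sizes are tiny for readability.

```python
import math
import random

def matmul(A, B):
    # (n x m) @ (m x p) -> (n x p), with matrices stored as lists of rows
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def transpose(A):
    return [list(col) for col in zip(*A)]

def softmax_rows(S):
    out = []
    for row in S:
        m = max(row)
        exps = [math.exp(s - m) for s in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out

def scaled_dot_product_attention(Q, K, V):
    # softmax(Q K^T / sqrt(d_k)) V
    d_k = len(K[0])
    scores = [[s / math.sqrt(d_k) for s in row] for row in matmul(Q, transpose(K))]
    return matmul(softmax_rows(scores), V)

def random_matrix(rows, cols, rng):
    return [[rng.gauss(0.0, 0.1) for _ in range(cols)] for _ in range(rows)]

def multi_head_attention(X, num_heads, rng):
    """Self-attention over X (seq_len x d_model) using `num_heads` heads.

    Random projections stand in for learned weights (an assumption), so only
    the shapes and the data flow are meaningful here.
    """
    d_model = len(X[0])
    d_k = d_model // num_heads
    heads = []
    for _ in range(num_heads):
        # Each head projects X down to d_k dimensions for its own Q, K and V.
        W_q, W_k, W_v = (random_matrix(d_model, d_k, rng) for _ in range(3))
        heads.append(scaled_dot_product_attention(matmul(X, W_q), matmul(X, W_k), matmul(X, W_v)))
    # Concatenate the heads back to num_heads * d_k = d_model columns ...
    concat = [sum((head[t] for head in heads), []) for t in range(len(X))]
    # ... and mix them with an output projection W_o.
    W_o = random_matrix(num_heads * d_k, d_model, rng)
    return matmul(concat, W_o)

rng = random.Random(0)
seq_len, d_model, num_heads = 5, 16, 8  # the paper uses d_model = 512 with 8 heads
X = random_matrix(seq_len, d_model, rng)
out = multi_head_attention(X, num_heads, rng)
print(len(out), "x", len(out[0]))
```

The output has the same shape as the input (seq_len x d_model), which is what lets these blocks be stacked on top of each other.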

Ideally we would like to attend over the entire input sequence, but over the generated output only up to the current point, because in language tasks the model must not peek at future words when predicting the next one. Doing that naively would again require looping over the already generated sequence to predict each next word, and loops are exactly what we want to avoid. So the authors proposed a simple method that does this in one go. At the bottom of each decoder layer there is a special type of multi-head attention called masked multi-head attention (MaskedMHA). It masks out all future positions so that their attention weights become 0. In practice the scores for future positions are set to -∞ (or a very large negative number) before the softmax, which is what turns their weights into 0.
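Here is a minimal sketch of that masking step in plain Python. `causal_mask` and `masked_softmax_rows` are hypothetical helper names, and the scores are made-up numbers; what matters is that row i can only put weight on positions 0..i.

```python
import math

def causal_mask(n):
    # mask[i][j] is True when position j lies in the future of position i
    return [[j > i for j in range(n)] for i in range(n)]

def masked_softmax_rows(scores, mask):
    out = []
    for row, mrow in zip(scores, mask):
        # Future positions get -inf, so exp(-inf) = 0 after the softmax.
        masked = [-math.inf if m else s for s, m in zip(row, mrow)]
        top = max(masked)  # always finite: the diagonal is never masked
        exps = [math.exp(s - top) for s in masked]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out

scores = [[0.5, 1.0, 2.0, 0.1] for _ in range(4)]  # made-up attention scores
weights = masked_softmax_rows(scores, causal_mask(4))
for row in weights:
    print([round(w, 3) for w in row])
```

The first row puts all its weight on position 0, and each later row spreads its weight only over the positions up to itself. The whole matrix is computed at once, with no loop over the sequence.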

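We combine MHA, MaskedMHA and a feed-forward network to create one module: an encoder layer is MHA followed by a feed-forward network, while a decoder layer is MaskedMHA, then MHA over the encoder output, then a feed-forward network, with each sub-layer wrapped in a residual connection and layer normalisation. Each module is a complete processing unit in itself, and we can stack as many of them as we want. In the paper the authors stacked 6 of them in both the encoder and the decoder, i.e. Nx = 6. Finally, the output of the decoder is passed through a linear layer followed by a softmax over the entire vocabulary, which gives the probability of each word; from that we pick the required word.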
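The last step of that pipeline (decoder output, then linear layer, then softmax over the vocabulary, then word) can be sketched as follows. The 5-word vocabulary, the weights and the 3-dimensional decoder output are all hypothetical, and storing one weight row per vocabulary word is just the layout chosen for this sketch.

```python
import math

def output_distribution(decoder_output, W, b):
    """Project one decoder output vector to vocabulary logits, then softmax.

    W has one row per vocabulary word (an assumed layout for this sketch).
    """
    logits = [sum(w * x for w, x in zip(row, decoder_output)) + bias for row, bias in zip(W, b)]
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

vocab = ["<END>", "that", "is", "really", "good"]
# Hypothetical weights: one 3-dimensional row per vocabulary word.
W = [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 1.0, 0.0]]
b = [0.0] * len(vocab)

decoder_output = [0.1, 0.2, 3.0]
probs = output_distribution(decoder_output, W, b)
next_word = vocab[probs.index(max(probs))]  # greedy choice: most probable word
print(next_word)
```

Picking the most probable word is the simplest (greedy) choice; in a real model the vocabulary has tens of thousands of entries and W is learned.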

Usage of Positional Encoding

Since we process all the words of the input and output sentences simultaneously, the model has no inherent understanding of the position of each word in the sequence. To give it knowledge of positions, the authors use a simple sine- and cosine-based encoding, which is added to the word embeddings. The idea behind using sines and cosines is to give each position a vector of values between -1 and +1. Each positional vector then behaves like a giant array of switches that are either off, on, or somewhere in between. The encoding can be visualised as:

For this example the embedding dimension (dmodel) is 512

The following functions were chosen because the authors hypothesised that they would allow the model to easily learn to attend by relative positions, and because their simplicity may let the model handle sentences longer than the ones it was trained on.

Positional Encoding
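Written out, the functions from the paper are as follows, where pos is the position and i indexes pairs of dimensions (even dimensions use sine, odd ones use cosine). The second line is a worked consequence, just the angle-addition identities with ω_i = 1/10000^(2i/dmodel): moving from position pos to pos + k is a fixed rotation of each (sine, cosine) pair that depends only on the offset k, which is why relative positions should be easy to learn.

```latex
PE_{(pos,\,2i)} = \sin\!\left(\frac{pos}{10000^{2i/d_{\mathrm{model}}}}\right),
\qquad
PE_{(pos,\,2i+1)} = \cos\!\left(\frac{pos}{10000^{2i/d_{\mathrm{model}}}}\right)

\begin{pmatrix} \sin(\omega_i (pos + k)) \\ \cos(\omega_i (pos + k)) \end{pmatrix}
=
\begin{pmatrix} \cos(\omega_i k) & \sin(\omega_i k) \\ -\sin(\omega_i k) & \cos(\omega_i k) \end{pmatrix}
\begin{pmatrix} \sin(\omega_i\, pos) \\ \cos(\omega_i\, pos) \end{pmatrix}
```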

Think of it this way: we have a large number of sine and cosine functions of varying wavelengths/frequencies, and for each position we select the column corresponding to that index. Something like this:

Selection of columns
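Here is a minimal sketch of building that table in plain Python, assuming an even dmodel. Here each position is stored as a row, so "selecting the column" for a position is just `table[pos]`. `positional_encoding` and `shift` are hypothetical helper names; `shift` checks the relative-position property by rotating each (sine, cosine) pair.

```python
import math

def positional_encoding(pos, d_model):
    """Sinusoidal positional encoding for one position (d_model assumed even)."""
    pe = []
    for i in range(d_model // 2):
        angle = pos / (10000 ** (2 * i / d_model))
        pe.extend([math.sin(angle), math.cos(angle)])  # dimensions 2i and 2i+1
    return pe

def shift(pe, k, d_model):
    """Map PE(pos) to PE(pos + k) with a fixed rotation per (sin, cos) pair.

    The rotation depends only on the offset k, never on pos itself.
    """
    out = []
    for i in range(d_model // 2):
        w = 1 / (10000 ** (2 * i / d_model))
        s, c = pe[2 * i], pe[2 * i + 1]
        out.extend([
            math.cos(w * k) * s + math.sin(w * k) * c,   # sin(w * (pos + k))
            -math.sin(w * k) * s + math.cos(w * k) * c,  # cos(w * (pos + k))
        ])
    return out

d_model = 512
table = [positional_encoding(pos, d_model) for pos in range(50)]  # one row per position
print(table[0][:4])  # position 0: sin(0) = 0 and cos(0) = 1 in every pair
```

In the model, each row is added to the embedding of the word at that position before the first layer.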

This is how the main components of the Transformer Network are built. The amount of effort and detail in each small part of the model is commendable. What I found especially interesting while building this model was the intuition and logic behind the attention mechanism. The same simple method can also be used to find similar vectors in a large space; in fact, I have used it in another project to determine the similarity of multiple trajectories. I have not added a lot of images showing attention at work, as they can be seen on other blogs.

Next time we will build the transformer network using TensorFlow in Python and train it on a simple toy dataset. Stay tuned!


Final-year undergrad at NIT Raipur! My interests range from Artificial Intelligence to Graphic Design. If you're interested in my musings, hit that clap! 🐾