From Recurrent Neural Network (RNN) to Attention explained intuitively

Cakra
7 min read · Aug 22, 2021


Background

RNN is one of the fundamental building blocks in Deep Learning. It's as fundamental as feedforward networks & convolutional networks. Many advanced Deep Learning architectures basically mix & match these different building blocks.

Machine learning models that have text/audio as part of the input/output will likely involve an RNN. Its role is mostly seen in natural language processing (NLP) tasks like text generation (by building a language model), speech recognition, text classification, question answering, and many others.

You might've heard terms like Transformer, Attention, Encoder-Decoder, etc. around RNN that seem overwhelming. Actually, they are all related, and there's a path that connects them together. This article walks that path to build the intuition up to the concept of Attention.

Let’s begin with a “basic” neural network.

Feedforward Neural Network (FNN)

FNN consumes all inputs at the same time (aka in one time step). An example of an FNN is a simple image classifier which takes an image as the input and outputs the category of the image (whether it's an image of a cat/dog/bird/etc).

Simple feedforward network. Inside, it can contain many hidden layers. The input is consumed as a whole at once. Note: The rectangles with sharp edges represent a vector.
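To make this concrete, here's a minimal sketch of an FNN in PyTorch (the layer sizes here are just illustrative):

```python
import torch
import torch.nn as nn

# A tiny feedforward classifier: the whole input (e.g. a flattened 28x28
# image) is consumed at once, in a single forward pass.
model = nn.Sequential(
    nn.Linear(28 * 28, 128),  # input vector -> hidden layer
    nn.ReLU(),
    nn.Linear(128, 10),       # hidden layer -> scores for 10 categories
)

image = torch.randn(1, 28 * 28)  # one flattened image as a single vector
logits = model(image)            # all of the input processed in one step
```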

Recurrent Neural Network (RNN)

The way Recurrent Neural Network (RNN) processes the input is different from FNN. In FNN we consume all inputs in one time step, whereas in RNN we consume one input for each time step. This one input can be one character (if we consume the input by characters), one word (if we consume the input by words), one sample / one audio frame (for audio input), etc.

Note: Since each input word has to be a vector, we either pass the word as a one-hot encoded vector or pass it through an Embedding Layer to get an embedding vector of the word. The Embedding Layer is omitted from the diagram for simplicity.

The defining characteristic of RNN input is time dependency: the input is a sequence of values in time order, meaning this value comes first, then the next value, and so on. Time series data is one example. A sentence is another (it starts with the first word, then the second word follows the first, and so on).

Hence, the input of an RNN is a sequence, or a list, of values. In a neural network, a value is represented as a vector. In other words, the input of an RNN is a sequence of vectors (a list of vectors).

In real code, we process the input in batches to make the computation more efficient. Different inputs from the same batch are independent from each other, so the model behaves the same as if we were processing the inputs one by one.
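As a rough sketch of what this looks like in PyTorch (the vocabulary and layer sizes here are made up), a batch of word indices goes through an Embedding Layer and then an RNN, which steps through the sequence one word at a time:

```python
import torch
import torch.nn as nn

# A batch of sentences, each a sequence of word indices.
vocab_size, embed_dim, hidden_dim = 1000, 32, 64
embedding = nn.Embedding(vocab_size, embed_dim)
rnn = nn.RNN(embed_dim, hidden_dim, batch_first=True)

tokens = torch.randint(0, vocab_size, (8, 5))  # 8 sentences, 5 words each
vectors = embedding(tokens)                    # (8, 5, 32): a sequence of vectors per sentence
outputs, last_hidden = rnn(vectors)            # the RNN walks through the 5 time steps internally
print(outputs.shape, last_hidden.shape)        # (8, 5, 64) and (1, 8, 64)
```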

More about State in RNN

During processing of the input, we can pass information from one time step to the next. This is done by keeping a State (often called the hidden state). The State has a similar role to memory (RAM) in computers.

For example, when you add two numbers digit by digit, you keep track of the carry (when the sum of digits is greater than 9) to apply it at the next digit position. This carry is the state, stored in some memory.
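In code, that carry analogy looks like this (a plain Python sketch, nothing neural about it yet):

```python
# Adding two numbers digit by digit: the carry is the "state"
# passed from one step (digit position) to the next.
def add_digit_by_digit(a_digits, b_digits):
    carry = 0                                # the state, initialized to zero
    result = []
    for a, b in zip(a_digits, b_digits):     # least-significant digit first
        total = a + b + carry                # read the state + current inputs
        result.append(total % 10)            # this step's output digit
        carry = total // 10                  # write the new state for the next step
    if carry:
        result.append(carry)
    return result

print(add_digit_by_digit([7, 4], [5, 2]))    # 47 + 25 -> [2, 7], i.e. 72
```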

Just like RAM, we can read from or write to the State. In short, in each time step of RNN, it does all of these operations:

  1. Read value(s) from State (from previous time step) and an input value (at the current time step)
  2. Compute a result using those values.
  3. Write the result to State (at the current time step). Of course we can keep the history of previous States as well.

Variations in step (2) and (3) define different flavors of RNN (vanilla RNN, LSTM, GRU, etc).
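Here's a rough sketch of that read-compute-write loop using PyTorch's built-in vanilla RNN cell (sizes are illustrative; in LSTM or GRU, only the details of steps 2 and 3 change):

```python
import torch
import torch.nn as nn

input_dim, hidden_dim, seq_len = 32, 64, 5
cell = nn.RNNCell(input_dim, hidden_dim)

inputs = torch.randn(seq_len, 1, input_dim)  # one sequence of 5 input vectors
state = torch.zeros(1, hidden_dim)           # initial state
states = []
for x_t in inputs:
    state = cell(x_t, state)  # 1. read previous state + current input, 2. compute a result
    states.append(state)      # 3. write the result as the new state (keeping the history too)
```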

In vanilla RNN and GRU, we only use one vector to represent the state. In LSTM, we use two vectors to represent the state (called the cell state and the hidden state respectively, each with a different role). In Neural Turing Machines (NTM), we use N vectors (we can define how many) to represent the state.
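You can see this difference directly in PyTorch: a GRU returns a single state tensor, while an LSTM returns a pair (just a sketch, sizes are illustrative):

```python
import torch
import torch.nn as nn

x = torch.randn(1, 5, 32)  # batch of 1 sequence, 5 time steps, 32 features
gru = nn.GRU(32, 64, batch_first=True)
lstm = nn.LSTM(32, 64, batch_first=True)

_, h_gru = gru(x)               # a single state tensor: the hidden state
_, (h_lstm, c_lstm) = lstm(x)   # two state tensors: hidden state and cell state
```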

From now on when we mention RNN, it could refer to any variation of RNN (can be LSTM, GRU, etc).

In a pure RNN (without any additional layer after the RNN), the hidden state is the output. We can pass the hidden state through an output layer to transform it into the output shape that we want. For example, we can use a linear+softmax layer as the output layer if we want a categorical output (aka a classifier). In RNN, the terms output & hidden state are sometimes used interchangeably.

Fun fact: Softmax is the continuous, or smooth, version of hardmax (commonly known as argmax). Softmax is also a generalization of sigmoid to N dimensions.

We can add an output layer to transform the hidden state into the output shape that we want.
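For instance, a linear+softmax output layer on top of a hidden state could look like this (a sketch with a made-up 3-class classifier):

```python
import torch
import torch.nn as nn

hidden_dim, num_classes = 64, 3
output_layer = nn.Linear(hidden_dim, num_classes)

hidden_state = torch.randn(1, hidden_dim)   # the RNN's hidden state at some time step
scores = output_layer(hidden_state)
probs = torch.softmax(scores, dim=-1)       # smooth "argmax": a probability per class
predicted = probs.argmax(dim=-1)            # hardmax: just the index of the largest score
```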

There are two common ways to use the output of an RNN. The first is to use only the output from the last time step. For example, in sentiment classification (classifying whether the input expresses a positive or negative opinion), we only use the output from the last time step, since we only need a single output value (no matter how long the input sentence is).

The second is to use the output from all time steps. For example, in named entity recognition, each word gets a probability of being part of a named entity. In this case, we use the output from every time step (because at each time step we compute the probability for one word).
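In PyTorch terms, the two patterns differ only in which part of the RNN output we feed to the output layer (a sketch; the heads and sizes are made up):

```python
import torch
import torch.nn as nn

rnn = nn.RNN(32, 64, batch_first=True)
sentiment_head = nn.Linear(64, 2)   # positive / negative
ner_head = nn.Linear(64, 4)         # e.g. 4 entity tags per word

x = torch.randn(1, 5, 32)           # one sentence of 5 word vectors
outputs, _ = rnn(x)                 # (1, 5, 64): one output vector per time step

sentiment_logits = sentiment_head(outputs[:, -1, :])  # only the last time step
ner_logits = ner_head(outputs)                        # every time step, one prediction per word
```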

Stacking RNNs

Now that we have the basics of RNN, we can upgrade it further. We have only talked about one-layer RNNs so far. How about stacking multiple RNNs? When it comes to stacking RNNs, there are two ways to stack them: vertically or horizontally.

Two ways to stack two RNNs: vertical or horizontal

In vertical stacking, we add more computation per time step (each time step becomes more complex, adding model capacity). We put one RNN on top of another and pass the hidden state from the previous layer as the input to the next layer. This way, we proceed to the next time step only after we have computed every RNN layer at the current time step.

The final output is the hidden state of the last RNN layer (the one furthest from the input).

Stacking RNN vertically
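A sketch of vertical stacking in PyTorch, either with the built-in num_layers argument or by wiring two RNNs by hand (the two versions express the same idea, each with its own weights):

```python
import torch
import torch.nn as nn

x = torch.randn(1, 5, 32)

# Built-in vertical stacking: two layers computed at each time step.
stacked = nn.RNN(32, 64, num_layers=2, batch_first=True)
outputs, _ = stacked(x)     # outputs come from the top (last) layer

# The same idea wired manually: layer 1's hidden states feed layer 2.
layer1 = nn.RNN(32, 64, batch_first=True)
layer2 = nn.RNN(64, 64, batch_first=True)
h1, _ = layer1(x)           # hidden states of the first layer, all time steps
h2, _ = layer2(h1)          # used as the inputs to the second layer
```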

In horizontal stacking, one RNN layer has to process all inputs before passing the output to the next RNN layer. This is often called Encoder-decoder.

There doesn't seem to be much point in stacking more than two layers horizontally. Maybe it has some use, but the term horizontal stacking is used here mainly to draw the contrast with vertical stacking.

Encoder-decoder

When the input is a sequence and the output is a sequence, we often also call it sequence to sequence (seq2seq). We can use Encoder-decoder architecture to deal with seq2seq problems.

In the most basic form of Encoder-decoder, the Decoder is another RNN whose initial state is the last hidden state of the Encoder.

Encoder's hidden state at the last time step becomes the initial state of the Decoder. Note: <start> and <end> are special words/tokens indicating the start and end of the sequence in the Decoder. There's nothing special about the <start> and <end> tokens other than their role as markers.

A famous example of the encoder-decoder architecture is machine translation, where we have an input sentence in one language (like English) and we want the output to be the translated sentence (the same sentence but in Chinese, for example). The example in the animation above translates the English sentence "I love you" into the French sentence "Je t'aime". Note that the lengths of the source & target sentences are different (3 words vs 2 words, excluding the <start> & <end> tokens).

In encoder-decoder, the original input (the sequence of vectors representing words in a sentence) is only consumed by the Encoder (the first RNN). Since the Decoder only gets information from the Encoder's hidden states (and not the original input), the Decoder can produce a sequence with a different length than the original input sequence. This makes the final output more flexible, as the output length is not tied to the input length.
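Here's a bare-bones sketch of that idea (the vocabularies, token ids, and greedy decoding loop are all just illustrative):

```python
import torch
import torch.nn as nn

embed_dim, hidden_dim = 32, 64
src_embed = nn.Embedding(1000, embed_dim)   # source (English) vocabulary
tgt_embed = nn.Embedding(1200, embed_dim)   # target (French) vocabulary
encoder = nn.GRU(embed_dim, hidden_dim, batch_first=True)
decoder = nn.GRU(embed_dim, hidden_dim, batch_first=True)
out_layer = nn.Linear(hidden_dim, 1200)     # scores over the target vocabulary

src = torch.tensor([[11, 12, 13]])          # "I love you" as (made-up) word ids
_, enc_last_hidden = encoder(src_embed(src))

# Decode step by step, starting from <start> (id 1 here), using the encoder's
# last hidden state as the decoder's initial state.
token, state = torch.tensor([[1]]), enc_last_hidden
for _ in range(10):                              # cap on the output length
    dec_out, state = decoder(tgt_embed(token), state)
    token = out_layer(dec_out[:, -1]).argmax(-1, keepdim=True)  # pick the next word
```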

As we can see, the only information about the input that the Decoder receives is the last hidden state of the Encoder. From this one vector, the Decoder needs to generate the entire output sequence. That's a pretty heavy responsibility. Although it might work fine in some cases, it's certainly very limiting.

Why do we only consider the last Encoder hidden state? Why not consider all Encoder hidden states? Enter Attention.

Attention

In the Decoder, we can get better information about the input by considering all Encoder hidden states (instead of just the last one). This information (called the context vector) is used to enhance the decoder hidden state and generate a better final output.

We use the context vector combined with the decoder hidden state to generate the final output at each time step.

The Attention Layer computes the context vector from all encoder hidden states and the decoder hidden state of the current time step. It recomputes the context vector for each time step in the Decoder.

We can see the Attention Layer as asking: "OK, given the current decoder hidden state, which part of the input is relevant?". One way to compute this relevance is to compute a similarity score between the decoder hidden state and each of the encoder hidden states. That's why we often see dot product operations when computing attention (the dot product can be used to measure similarity).

We can also see attention as doing a lookup on a memory (containing information from the encoder hidden states). The Decoder hidden state acts as the Query and each Encoder hidden state acts as a Key; the lookup operation looks for Keys that are relevant to the Query. The Value inside the memory is the encoder hidden state itself. This is called content-based attention.
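A minimal sketch of content-based (dot-product) attention for a single decoder time step: score each encoder hidden state against the decoder hidden state, turn the scores into weights with softmax, and take the weighted sum as the context vector:

```python
import torch

def attention_context(decoder_state, encoder_states):
    # decoder_state: (hidden_dim,)          -- the Query
    # encoder_states: (src_len, hidden_dim) -- the Keys (and the Values)
    scores = encoder_states @ decoder_state  # dot-product similarity of the Query to each Key
    weights = torch.softmax(scores, dim=0)   # how relevant each input position is
    context = weights @ encoder_states       # weighted sum of the Values = context vector
    return context, weights

encoder_states = torch.randn(3, 64)  # one hidden state per source word ("I", "love", "you")
decoder_state = torch.randn(64)      # decoder hidden state at the current time step
context, weights = attention_context(decoder_state, encoder_states)
```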

When the Keys & Values themselves come from the current sequence, it's called self-attention.

Hopefully you've now formed a base understanding that will help you follow more detailed resources.

This article mainly explains these architectures in prediction mode. We haven't talked about how they're trained. Training is related to prediction, but that's a story for another time.

In the next article we'll talk about how we can use attention without a recurrent network, which is what the Transformer does.
