Explaining Transformers in the context of RNNs

Jacob Stern
Deep Learning for Protein Design
Oct 28, 2020

Have you ever wanted to really understand Transformers? This post diagrams self-attention — the core component of Transformers — in a way that illustrates its similarity to RNNs. We go through intermediate tensor shapes, which are crucial to understanding information flow.

This is a self-attention layer:

[Figure: Single-Layer Self-Attention Network]

But before we tackle that monstrosity, let’s build up the connection to RNNs.

RNN -> Attention RNN -> Self-Attention

This post compares recurrent neural networks (RNNs), attention-based RNNs, and self-attention-based neural networks (Transformers). I have not adhered strictly to the architectures originally proposed by the paper authors, but instead adapted those architectures to illustrate their fundamental similarities. I encourage you to carefully study the arrows and dimensions in the diagrams to track the information flow. It may also be helpful to pull up two windows to compare the diagrams side-by-side.

RNNs

Recurrent neural networks are neural networks structured to operate on sequences. At each step, a recurrent unit takes in a token (a word, sub-word, or letter) together with the previous hidden state. It produces a new hidden state and a prediction for the next token. An RNN can model one-to-one, one-to-many, many-to-one, and many-to-many relationships (see Karpathy’s blog). A vanilla many-to-one RNN is diagrammed below.
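
The following is a minimal NumPy sketch of a vanilla many-to-one RNN. It assumes a tanh recurrent unit and random, untrained weights. The parameter names (W_xh, W_hh, W_hy, and so on) are illustrative and do not come from the original diagram. The loop makes the sequential dependency explicit.

```python
import numpy as np

def rnn_many_to_one(tokens, W_xh, W_hh, b_h, W_hy, b_y):
    """Vanilla many-to-one RNN: read the whole sequence, predict from the last state.

    tokens: (L, d_in) array, one embedding per token.
    Returns all hidden states (L, d_h) and the output logits (d_out,).
    """
    h = np.zeros(W_hh.shape[0])            # initial hidden state h_0
    hidden_states = []
    for x_t in tokens:                     # sequential: step t needs h_{t-1}
        h = np.tanh(x_t @ W_xh + h @ W_hh + b_h)
        hidden_states.append(h)
    logits = h @ W_hy + b_y                # only the final hidden state is used
    return np.stack(hidden_states), logits

rng = np.random.default_rng(0)
L, d_in, d_h, d_out = 5, 4, 8, 10
tokens = rng.normal(size=(L, d_in))        # stand-in for token embeddings
params = (rng.normal(size=(d_in, d_h)) * 0.5,   # W_xh
          rng.normal(size=(d_h, d_h)) * 0.5,    # W_hh
          np.zeros(d_h),                        # b_h
          rng.normal(size=(d_h, d_out)),        # W_hy
          np.zeros(d_out))                      # b_y
H, logits = rnn_many_to_one(tokens, *params)
print(H.shape, logits.shape)  # (5, 8) (10,)
```

Each token gets its own d_h-dimensional hidden state, but only the last one feeds the prediction.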

[Figure: Vanilla RNN]

Problem: Long-term dependencies

A common problem with RNNs is their inability to “remember” tokens from the beginning of the sequence. (Hochreiter 1997) introduced the long short-term memory (LSTM) recurrent architecture to allow error to flow backward to the earliest tokens in the sequence. However, recurrent neural networks continued to struggle with long-term dependencies, because an entire context sentence still had to be compressed into a single fixed-length vector.
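
Below is a sketch of one LSTM step in NumPy. It uses the widely used variant with a forget gate, which was added after the original 1997 paper. The stacked-weight layout and gate order are arbitrary choices made here for illustration. The additive cell update `c_t = f * c_prev + i * g` is the path that lets error flow back over many steps. The final `h` is still a fixed-length vector, however long the input.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x_t, h_prev, c_prev, W, U, b):
    """One LSTM step. W: (d_in, 4*d_h), U: (d_h, 4*d_h), b: (4*d_h,).

    Gate order inside the stacked weights (an arbitrary choice here):
    input, forget, output, candidate.
    """
    d_h = h_prev.shape[0]
    z = x_t @ W + h_prev @ U + b
    i = sigmoid(z[0 * d_h:1 * d_h])      # input gate
    f = sigmoid(z[1 * d_h:2 * d_h])      # forget gate
    o = sigmoid(z[2 * d_h:3 * d_h])      # output gate
    g = np.tanh(z[3 * d_h:4 * d_h])      # candidate cell update
    c_t = f * c_prev + i * g             # additive update of the cell state
    h_t = o * np.tanh(c_t)
    return h_t, c_t

rng = np.random.default_rng(0)
d_in, d_h, L = 4, 6, 7
W = rng.normal(size=(d_in, 4 * d_h)) * 0.3
U = rng.normal(size=(d_h, 4 * d_h)) * 0.3
b = np.zeros(4 * d_h)
h, c = np.zeros(d_h), np.zeros(d_h)
for x_t in rng.normal(size=(L, d_in)):
    h, c = lstm_step(x_t, h, c, W, U, b)
# However long the sequence, everything read so far is squeezed into d_h numbers.
print(h.shape)  # (6,)
```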

Solution: Attention

(Bahdanau 2015) introduced an “attention mechanism” to fix this problem. Rather than forcing the neural network to remember an entire context sequence, they allowed the network to “attend” to the parts of the input sequence that were relevant for predicting the next word. Thus, the context vector is a weighted combination of the hidden states, with the largest weights on the states most relevant to the next word. Furthermore, in the computation graph, error can propagate more directly to the part of the network that produced those hidden states.
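
In symbols, the general idea can be sketched as follows. The notation is my own, not Bahdanau’s. Here s is the state making the prediction, h_1 through h_n are the hidden states being attended to, and score is a similarity function. Bahdanau et al. used a small feed-forward network for the score. The diagrams in this post use a plain dot product and drop the softmax.

```latex
e_i = \operatorname{score}(s, h_i), \qquad
\alpha_i = \frac{\exp(e_i)}{\sum_{j=1}^{n} \exp(e_j)}, \qquad
c = \sum_{i=1}^{n} \alpha_i h_i
```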

The diagram below is slightly different from the proposed architecture of (Bahdanau 2015). My diagram is encoder-only, unidirectional, and uses dot-product attention. I also leave out activation functions like the softmax function. I do this for simplicity, so you can see the parallels to a vanilla RNN and a self-attention network.

[Figure: Attention RNN]

Things to notice about this diagram: the attention weights are dot-product “similarities” between the last hidden state and the previous hidden states. The context vector c_{L-1} is a linear combination of the previous hidden states, weighted by the attention weights a_1, … , a_{L-2}.
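
Here is a small NumPy sketch of the computation in the diagram. It assumes the hidden states are stacked as rows h_1 through h_{L-1}, with the last row acting as the query. Like the diagram, it skips the softmax by default. The `use_softmax` flag is an illustrative addition that shows the normalized version.

```python
import numpy as np

def attention_context(H, use_softmax=False):
    """Dot-product attention of the last hidden state over the earlier ones.

    H: (L-1, d) array with rows h_1 .. h_{L-1}; the last row is the query.
    Returns (weights a_1 .. a_{L-2}, context vector c_{L-1}).
    """
    query = H[-1]                 # h_{L-1}
    previous = H[:-1]             # h_1 .. h_{L-2}
    a = previous @ query          # a_i = h_i . h_{L-1}
    if use_softmax:               # omitted in the diagram; real models normalize
        a = np.exp(a - a.max())
        a = a / a.sum()
    c = a @ previous              # c_{L-1} = sum_i a_i h_i
    return a, c

H = np.array([[1.0, 0.0],
              [0.0, 1.0],
              [1.0, 1.0]])
a, c = attention_context(H)
print(a, c)  # [1. 1.] [1. 1.]
```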

More Problems: Batch Training, Positional Information, and Shallow Representations

There are still some limitations to the attention-based RNN. 1) Because of its sequential nature, the recurrent unit is a bottleneck. Step t cannot be computed until step t-1 is finished, which makes training hard to parallelize. 2) There is no explicit positional information. If the context vector is a linear combination of hidden states, the combination itself does not record the order of those hidden states. That is like trying to predict the next word from a shuffled bag of words rather than an ordered context. 3) The context vector offers a contextual representation of the previous sentence that is relevant for predicting the next word. But you can imagine obtaining an even better context vector by making this attention network deeper, so that each context vector is a mix of context vectors from the previous layer.
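
Point 2 can be checked directly. The sketch below uses random vectors as stand-ins for hidden states. It shuffles the previous hidden states and shows that the dot-product context vector stays the same.

```python
import numpy as np

rng = np.random.default_rng(0)
previous = rng.normal(size=(6, 4))   # stand-ins for h_1 .. h_6, in sequence order
query = rng.normal(size=4)           # stand-in for the last hidden state

def context(prev, q):
    a = prev @ q                     # dot-product attention weights
    return a @ prev                  # weighted sum of the hidden states

shuffled = previous[rng.permutation(6)]   # same hidden states, different order
print(np.allclose(context(previous, query), context(shuffled, query)))  # True
```

The combination step ignores order. In an attention RNN, any order information must already be encoded in the hidden states, which were produced sequentially.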

Solution: Self-Attention

(Vaswani 2017), in “Attention Is All You Need”, introduced the Transformer, which addressed these limitations with a new architecture based on “self-attention.”

As with the attention-RNN, I’ve left out many details. This diagram is a single-head, single-layer attention network. I’ve left out activation functions, softmax functions, masks, and normalization. I also pre-multiplied by V-transpose instead of post-multiplying by V. The objective is for this to look similar to the attention-based RNN.
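
For comparison, below is a single-head self-attention layer in NumPy. Unlike the diagram, it keeps the softmax and the 1/sqrt(d_k) scaling from Vaswani et al. It still leaves out masks, normalization, and multiple heads. The shapes and the random weights are illustrative assumptions. The last two lines check that the diagram’s pre-multiplication by V-transpose gives the transpose of the usual A @ V.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)      # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(X, W_q, W_k, W_v):
    """Single-head scaled dot-product self-attention (no mask, no normalization).

    X: (L, d_model) token embeddings. W_q, W_k: (d_model, d_k). W_v: (d_model, d_v).
    Returns (L, d_v): one context vector per position.
    """
    Q, K, V = X @ W_q, X @ W_k, X @ W_v          # (L, d_k), (L, d_k), (L, d_v)
    A = softmax(Q @ K.T / np.sqrt(K.shape[1]))   # (L, L); each row sums to 1
    return A @ V                                 # (L, d_v)

rng = np.random.default_rng(0)
L, d_model, d_k, d_v = 5, 8, 4, 4
X = rng.normal(size=(L, d_model))
W_q = rng.normal(size=(d_model, d_k))
W_k = rng.normal(size=(d_model, d_k))
W_v = rng.normal(size=(d_model, d_v))
C = self_attention(X, W_q, W_k, W_v)
print(C.shape)  # (5, 4)

# Pre-multiplying by V-transpose, as in the diagram, yields the transpose of A @ V.
A = softmax((X @ W_q) @ (X @ W_k).T / np.sqrt(d_k))
V = X @ W_v
print(np.allclose(V.T @ A.T, C.T))  # True
```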

[Figure: Single-Layer Self-Attention Network]

Some things to point out here: the self-attention mechanism is almost identical to normal attention. However, the last hidden state is not the only one that gets an attention-based context vector. Each hidden state gets its own. If a single layer is only used to predict the next word, we would use just the last context vector, and the other context vectors (the “self” part of self-attention) would go unused.
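
To see that self-attention is normal attention repeated once per position, the sketch below computes it both ways. It follows the diagram’s simplifications: the hidden states serve directly as queries, keys, and values, with no projections and no softmax. The loop asks one query at a time. The matrix form computes every position’s context vector in one shot.

```python
import numpy as np

rng = np.random.default_rng(0)
H = rng.normal(size=(5, 3))        # hidden states h_1 .. h_5 (no projections, no softmax)

# One query at a time: position i attends over all positions with query h_i.
C_loop = np.stack([(H @ H[i]) @ H for i in range(len(H))])

# All queries at once: the same computation as two matrix products.
C_matrix = (H @ H.T) @ H

print(np.allclose(C_loop, C_matrix))  # True
```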

So why bother with self-attention over normal attention? 1) In practice, self-attention networks are stacked into many layers rather than used as a single layer. The real value of self-attention is the recombination of attention information over multiple layers. The output of the first self-attention layer is a contextual embedding of each input token. The output of the second self-attention layer is a contextual embedding of a contextual embedding, and so on. 2) Self-attention removes the sequential constraint during training. With attention, contextual information no longer has to be passed step by step through an RNN. Every position in a sequence can therefore be processed in parallel instead of one word at a time. 3) Self-attention networks also add positional encodings to the token embeddings. These restore the word-order information that attention alone discards.
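
The sketch below covers points 1 and 3 under simplifying assumptions. It adds the sinusoidal position encodings from Vaswani et al. to random stand-in embeddings and then stacks two bare self-attention layers. The residual connections, layer normalization, and feed-forward sublayers of a real Transformer are omitted. Each layer processes all positions at once, with no loop over time.

```python
import numpy as np

def sinusoidal_positions(L, d_model):
    """Sinusoidal position encodings (Vaswani et al., 2017); d_model assumed even."""
    pos = np.arange(L)[:, None]                       # (L, 1)
    i = np.arange(0, d_model, 2)[None, :]             # (1, d_model/2): even dimensions
    angles = pos / np.power(10000.0, i / d_model)     # (L, d_model/2)
    pe = np.zeros((L, d_model))
    pe[:, 0::2] = np.sin(angles)                      # even dims: sine
    pe[:, 1::2] = np.cos(angles)                      # odd dims: cosine
    return pe

def softmax(z):
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def attention_layer(X, W_q, W_k, W_v):
    Q, K, V = X @ W_q, X @ W_k, X @ W_v
    return softmax(Q @ K.T / np.sqrt(K.shape[1])) @ V

rng = np.random.default_rng(0)
L, d = 6, 8
tokens = rng.normal(size=(L, d))                  # stand-in token embeddings
X = tokens + sinusoidal_positions(L, d)           # inject order before layer 1

# Two stacked layers: layer 2 mixes the contextual embeddings produced by layer 1.
for layer in range(2):
    W_q, W_k, W_v = (rng.normal(size=(d, d)) / np.sqrt(d) for _ in range(3))
    X = attention_layer(X, W_q, W_k, W_v)
print(X.shape)  # (6, 8)
```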

I have one more diagram for you to think about. I have taken the self-attention network, stripped out all of the context vectors except the last, and removed the computations that only those other context vectors depended on. The result computes only the last context vector.

[Figure: Single-Layer (non) Self-Attention network predicting only one context vector.]

This is almost identical to the attention-RNN. It is no longer a self-attention network, as we only have a context vector for the last token. The remaining differences from a normal attention RNN are 1) the initial embeddings are not generated sequentially, so they are not contextual, 2) the use of position embeddings, 3) a linear unit instead of a recurrent unit, and 4) each object (the query, key, and value) gets its own linear projection layer rather than sharing the recurrent unit. You can compare this diagram to the attention RNN diagram to see the similarities.
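
A NumPy sketch of the stripped-down network follows, assuming separate random projections for the query, key, and value. It checks that computing only the last token’s query gives exactly the last row of full self-attention. One detail may differ from the post’s diagram: here the last query also attends to its own position.

```python
import numpy as np

def softmax(z, axis=-1):
    e = np.exp(z - z.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
L, d, d_k = 5, 8, 4
X = rng.normal(size=(L, d))                  # embeddings (+ positions), not contextual
W_q, W_k, W_v = (rng.normal(size=(d, d_k)) for _ in range(3))  # separate projections

# Full self-attention: one context vector per position.
Q, K, V = X @ W_q, X @ W_k, X @ W_v
C_all = softmax(Q @ K.T / np.sqrt(d_k)) @ V  # (L, d_k)

# Stripped-down network: only the last token produces a query.
q_last = X[-1] @ W_q                         # (d_k,)
a = softmax(K @ q_last / np.sqrt(d_k))       # (L,) one weight per position
c_last = a @ V                               # (d_k,)

print(np.allclose(c_last, C_all[-1]))        # True
```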

Conclusion

We have traced the roots of self-attention in RNNs and attention-RNNs. We have seen the limitations of RNNs and attention-based RNNs, and how self-attention networks seek to remedy them.

The concept of self-attention can be complicated to grasp. Hopefully, putting self-attention in the context of RNNs helps it make more sense. Till next time!

Follow me on Twitter: @jacobastern

PhD student at Brigham Young University. Researching protein design with deep learning.