Deep Learning: The Transformer
Seq2Seq
Sequence-to-Sequence (Seq2Seq) models consist of two sub-models: an Encoder and a Decoder (hence Seq2Seq models are also referred to as Encoder-Decoders)
RNNs
Recurrent Neural Networks (RNNs) like LSTMs and GRUs have long been the model of choice for sequential data, making RNNs a popular choice for both the encoder and decoder
The Encoder's job is to take in an input sequence and output a context vector / thought vector, i.e. the encoder RNN's final hidden state (or, if the encoder is a bidirectional RNN (bi-RNN), the concatenation of both directions' final hidden states). This context vector is the input to the Decoder, whose job is to output a different sequence (e.g. a translation of, or reply to, the input text, etc).
Unfortunately, performance drops drastically for longer sentences, because signals from earlier inputs get diluted as they are passed along the sequence.
One way to keep earlier signals strong is to use skip-connections that feed every hidden state of the encoder RNN into every input of the decoder RNN (rather than just the encoder’s final hidden state being fed into the decoder’s initial state).
However, there is then the question of how to combine multiple hidden states into a single context vector. You could simply concatenate them.
Or you could sum them together (or average them, or take the max, or min, etc)
However, all these suggestions assume that every hidden state (and its corresponding input word) is equally important. We can do better with a weighted sum, e.g. a tf-idf-weighted sum of the vectors (to pay more attention to relatively rarer words)
However, this still assumes that the same hidden states (and inputs) are equally important for every output, whereas the importance of each input actually depends on the specific output being predicted (it is often the case that some words are more relevant to some predictions than others)
Attention
The solution is to use dynamic weighting (aka global alignment weights)
An attention mechanism calculates the dynamic (alignment) weights representing the relative importance of the inputs in the sequence (the keys) for that particular output (the query). Multiplying the dynamic weights (the alignment scores) with the input sequence (the values) then weights the sequence. A single context vector can then be calculated as the sum of the weighted vectors
Calculating the attended context vector (weighted sum of input vectors) is as simple as performing the dot-product (MatMul) of the dynamically attended weights (a) with the input sequence (V).
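As a minimal sketch (using NumPy, with made-up toy numbers), the attended context vector really is just one matrix multiplication of the weights with the values:

```python
import numpy as np

# Hypothetical toy example: 4 input vectors of dimension 3 (the values, V)
# and dynamically computed alignment weights (a) for one query, summing to 1.
V = np.array([[1.0, 0.0, 0.0],
              [0.0, 1.0, 0.0],
              [0.0, 0.0, 1.0],
              [1.0, 1.0, 1.0]])
a = np.array([0.1, 0.2, 0.3, 0.4])

# The attended context vector is the weighted sum: a · V
context = a @ V
print(context)  # -> [0.5 0.6 0.7]
```

Here the fourth input receives the largest weight (0.4), so it contributes most to the resulting context vector.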
Dot-Product Attention
There are numerous methods for obtaining the dynamically attended weights (a), such as Dot-Product Attention (specifically the dot-product of the Query (Q) and Keys (K) )
A useful byproduct of the dynamic alignment weights (attention) is that they can be displayed as a heat map to visualise exactly which words in the sequence are most important for predicting each output word.
Scaled Dot-Product Attention
Some normalisation can then be applied, such as softmax(a) to non-linearly squash the weight values into the range 0 to 1 (summing to 1). Because the dot-product can produce very large magnitudes when the vector dimension (d) is large, which results in very small gradients when passed into the softmax function, we can scale the values beforehand (scale = 1 / √d).
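Putting the pieces together, a sketch of scaled dot-product attention in NumPy (with random stand-in vectors; the shapes are illustrative, not from the paper):

```python
import numpy as np

def scaled_dot_product_attention(Q, K, V):
    """softmax(Q K^T / sqrt(d)) V -- scaled dot-product attention."""
    d = Q.shape[-1]                    # dimension of the query/key vectors
    scores = Q @ K.T / np.sqrt(d)      # scaled alignment scores
    # row-wise softmax (shifted by the max for numerical stability)
    e = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights = e / e.sum(axis=-1, keepdims=True)
    return weights @ V, weights        # context vectors + attention map

Q = np.random.randn(2, 4)   # 2 queries of dimension 4
K = np.random.randn(3, 4)   # 3 keys
V = np.random.randn(3, 4)   # 3 values
context, weights = scaled_dot_product_attention(Q, K, V)
```

The returned `weights` matrix (2 × 3 here) is exactly what gets plotted as the attention heat map.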
Efficient Calculations
For efficiency, a set of queries can be computed simultaneously, so multiple query vectors are packed together into a matrix Q; the keys and values are likewise packed into matrices K and V. For further computational efficiency, the query, key and value vectors can be made smaller using projection matrices that reduce the dimension of some vector (X) via linear weight transformations / projections (these projection weights (W_Q, W_K, W_V) are learnt during training)
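A sketch of this packing and projection step (the sizes are hypothetical, and the "learned" matrices are random stand-ins for illustration; in practice they are trained):

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, d_k = 8, 4                  # toy sizes: project dimension 8 down to 4
X = rng.normal(size=(5, d_model))    # 5 input vectors packed into a matrix

# Learned projection matrices (random placeholders here)
W_Q = rng.normal(size=(d_model, d_k))
W_K = rng.normal(size=(d_model, d_k))
W_V = rng.normal(size=(d_model, d_k))

# All queries, keys and values for the whole sequence in three matmuls
Q, K, V = X @ W_Q, X @ W_K, X @ W_V
```

Packing the vectors into matrices lets the whole sequence be processed in a few large, highly parallelisable matrix multiplications rather than one position at a time.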
…but what are Q,K,V exactly?
The exact values for the Queries, Keys and Values depend on which attention mechanism is being referred to. The Transformer contains three separate attention mechanisms (which we shall refer to as (i) the Encoder Attention, (ii) the Decoder Attention and (iii) the Encoder-Decoder Attention, which also sits on the decoder side but is named encoder-decoder because it feeds all of the encoder's outputs into each of the decoder's positions)
Encoder Attention
- Q = the current position-word vector in the input sequence
- K = all the position-word vectors in the input sequence
- V = all the position-word vectors in the input sequence
Decoder Attention
- Q = the current position-word vector in the output sequence
- K = all the position-word vectors in the output sequence
- V = all the position-word vectors in the output sequence
Encoder-Decoder Attention
- Q = the output of the decoder’s masked attention
- K = all the encoder’s hidden state vectors
- V = all the encoder’s hidden state vectors
…why does K = V ???
The attention mechanism is generally used to relate two different sequences to one another (the keys/values come from one sequence, the queries from another). Self-attention / intra-attention, however, relates different positions of the same sequence to one another, so the mechanism can be adapted by replacing the target sequence (the values) with the same sequence used for the keys
Transformer
Problem
Since we collect the signals at each state directly (dynamically weighted using Attention), we no longer need to propagate and combine the signals down to the final hidden state of the RNN. But this was the main power of RNNs!
Solution: Attention Only
Therefore, as the authors of the paper “Attention Is All You Need” concluded, RNNs are no longer needed. Instead, the architecture can be simplified by applying attention mechanisms directly to the encoder inputs, bypassing the RNNs.
This is also more efficient to calculate and process, so deeper models with more parameters than ever before can be, and have been, trained (e.g. BERT, GPT-2, etc)
Problem
Attention Vectors alone are essentially a set of matrix multiplications and thus linear transformations. How can a network learn hierarchical features to approximate complex functions without non-linear activation functions?
Solution: Feed Forward Neural Networks
Right after the attention mechanism, introduce some non-linear transformations by including a fully-connected feedforward neural network (with simple, yet non-linear, ReLU activation functions) at each position (with parameters shared across positions for efficiency). The FFNN's output vectors essentially replace the hidden states of the original RNN encoder.
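A sketch of this position-wise feed-forward network (toy sizes and random stand-in weights; the original paper uses dimensions 512 and 2048):

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, d_ff = 8, 32   # hypothetical model and inner dimensions

# One set of weights, shared across (applied identically at) every position
W1, b1 = rng.normal(size=(d_model, d_ff)), np.zeros(d_ff)
W2, b2 = rng.normal(size=(d_ff, d_model)), np.zeros(d_model)

def position_wise_ffn(x):
    """FFN(x) = max(0, x W1 + b1) W2 + b2 -- a ReLU between two linear layers."""
    return np.maximum(0.0, x @ W1 + b1) @ W2 + b2

attended = rng.normal(size=(5, d_model))   # e.g. output of the attention sub-layer
hidden = position_wise_ffn(attended)       # same shape in and out, per position
```

Because the same weights are applied independently at each position, the sequence length can vary freely without changing the parameters.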
Problem
Without a sequence-aligned recurrent architecture, how do we account for sequence order? If order is ignored, the resulting context vector for the input sequence “dogs chase cats” would be identical to the context vector for “cats chase dogs”, and the decoder would be unable to differentiate between two sequences with identical words but differing order (thus hindering performance)
Solution: Positional Encoding
Positional embeddings encode information about the position of each token in the sequence. These embeddings can be learnt, or simple sine / cosine functions can be used
Positional encodings should have the same dimensions as the token embeddings so the two can be summed together (essentially injecting positional information into the input embedding representations)
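A sketch of the sinusoidal variant from the original paper (the sequence length and embedding size here are arbitrary toy values; `d_model` is assumed to be even):

```python
import numpy as np

def positional_encoding(seq_len, d_model):
    """PE[pos, 2i]   = sin(pos / 10000**(2i / d_model))
       PE[pos, 2i+1] = cos(pos / 10000**(2i / d_model))"""
    pos = np.arange(seq_len)[:, None]          # positions, as a column
    i = np.arange(0, d_model, 2)[None, :]      # even dimension indices
    angles = pos / np.power(10000.0, i / d_model)
    pe = np.zeros((seq_len, d_model))
    pe[:, 0::2] = np.sin(angles)               # even dims get sine
    pe[:, 1::2] = np.cos(angles)               # odd dims get cosine
    return pe

# Same dimensions as the token embeddings, so they can simply be summed
embeddings = np.random.randn(10, 16)           # 10 tokens, dimension 16
x = embeddings + positional_encoding(10, 16)
```

Each position gets a unique pattern across the dimensions, so identical tokens at different positions end up with different input vectors.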
Tips and Tricks to improve performance further
Training
Dropout (very helpful in avoiding overfitting)
Label smoothing (penalises overconfident predictions so the model learns to be less certain, which improves accuracy)
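A minimal sketch of label smoothing on one-hot targets (the smoothing value `eps=0.1` is a common choice, used here purely for illustration):

```python
import numpy as np

def smooth_labels(one_hot, eps=0.1):
    """Soften hard 0/1 targets: the true class keeps 1 - eps (plus its share
    of eps), and the remaining eps is spread uniformly over all classes."""
    n_classes = one_hot.shape[-1]
    return one_hot * (1.0 - eps) + eps / n_classes

y = np.array([0.0, 0.0, 1.0, 0.0])       # hard one-hot target
print(smooth_labels(y))                  # -> [0.025 0.025 0.925 0.025]
```

Training against these softened targets penalises the model for pushing all its probability mass onto a single class.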
Architecture
Multi-head attention (similar to how you have several kernels in CNNs, you can have several self-attention layers in a Transformer which run in parallel; each linearly projects the queries, keys and values with its own learned projections)
Each attention head's output (a calculated context vector) is concatenated with the others
A projection layer (of linear weights) transforms this long concatenated vector back down into a smaller vector with reduced dimensions. The projection matrix (W) also serves to learn which attention heads and features to pay more/less attention to.
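The three steps above (parallel heads, concatenation, output projection) can be sketched as follows (toy sizes, random stand-in weights in place of trained ones):

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def multi_head_attention(X, heads, W_O):
    """heads = list of (W_Q, W_K, W_V) per head. Each head attends
    independently; the context vectors are concatenated and projected by W_O."""
    outputs = []
    for W_Q, W_K, W_V in heads:
        Q, K, V = X @ W_Q, X @ W_K, X @ W_V
        a = softmax(Q @ K.T / np.sqrt(Q.shape[-1]))   # per-head attention
        outputs.append(a @ V)                         # per-head context vectors
    return np.concatenate(outputs, axis=-1) @ W_O     # concat, then project

rng = np.random.default_rng(1)
d_model, d_k, n_heads = 8, 2, 4
heads = [tuple(rng.normal(size=(d_model, d_k)) for _ in range(3))
         for _ in range(n_heads)]
W_O = rng.normal(size=(n_heads * d_k, d_model))  # back down to d_model
X = rng.normal(size=(5, d_model))
out = multi_head_attention(X, heads, W_O)        # shape (5, d_model)
```

Note how `d_k` is chosen so that `n_heads * d_k` matches the concatenated width, which W_O then projects back to the model dimension.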
Deeper (more layers)
Residual connections (merging the input to the output of the sub-network)