Deep Learning: The Transformer

Mohammed Terry-Jack
9 min read · Jun 23, 2019

Seq2Seq

Sequence-to-Sequence (Seq2Seq) models are composed of two sub-models: an Encoder and a Decoder (hence Seq2Seq models are also referred to as Encoder-Decoders)

RNNs

Recurrent Neural Networks (RNNs) like LSTMs and GRUs have long been the models of choice for sequential data, making them a popular choice for both the encoder and the decoder

The Encoder’s job is to take in an input sequence and output a context vector / thought vector (i.e. the encoder RNN’s final hidden state, or, if the encoder is a bidirectional RNN (bi-RNN), the concatenation of both directions’ final hidden states). This context vector is the input for the Decoder, whose job is to output a different sequence (e.g. the translation or reply of the input text, etc).
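
For concreteness, here is a minimal sketch (in PyTorch, with made-up sizes) of a bi-RNN encoder whose context vector is the concatenation of both directions’ final hidden states:

```python
import torch
import torch.nn as nn

embedding_dim, hidden_dim = 32, 64
encoder = nn.GRU(embedding_dim, hidden_dim, bidirectional=True, batch_first=True)

embedded_inputs = torch.randn(1, 7, embedding_dim)     # (batch, seq_len, embedding_dim)
outputs, h_n = encoder(embedded_inputs)                # h_n: (2, batch, hidden_dim)
context_vector = torch.cat([h_n[0], h_n[1]], dim=-1)   # (batch, 2 * hidden_dim)
```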

Unfortunately the performance drops drastically for longer sentences because signals from earlier inputs get diluted as they pass down the longer sequences.

One way to keep earlier signals strong is to use skip-connections that feed every hidden state of the encoder RNN into every input of the decoder RNN (rather than just the encoder’s final hidden state being fed into the decoder’s initial state).

However, there is then the question of how to combine multiple hidden states into a single context vector. You could simply concatenate them.

Or you could sum them together (or average them, or take the max, or min, etc)

However, all these suggestions make the assumption that all hidden states (and their corresponding input words) are equally important. We can do better with a weighted sum, e.g. a TF-IDF weighted sum of vectors (to pay more attention to relatively rarer words)
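
For illustration, a minimal numpy sketch of these simple combination options (the hidden states and tf-idf weights below are made-up placeholders):

```python
import numpy as np

hidden_states = np.random.randn(5, 64)            # 5 encoder hidden states, dimension 64
tfidf_weights = np.array([0.1, 0.4, 0.05, 0.3, 0.15])

concatenated = hidden_states.reshape(-1)          # (320,)  one long vector
summed       = hidden_states.sum(axis=0)          # (64,)
averaged     = hidden_states.mean(axis=0)         # (64,)
weighted_sum = tfidf_weights @ hidden_states      # (64,)   tf-idf weighted sum
```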

However, this still assumes that the same hidden states (and inputs) are equally important for every output, whereas the importance of each input depends on the specific output being predicted (it is often the case that some words are more relevant for some predictions than for others)

Attention

The solution is to use dynamic weighting (aka global alignment weights)

context vector (c_t) = the dynamically-weighted (a_t,i) sum of the encoder’s hidden state vectors (h_i), i.e. c_t = Σ_i a_t,i · h_i

An attention mechanism calculates the dynamic (alignment) weights representing the relative importance of the inputs in the sequence (the keys) for that particular output (the query). Multiplying the dynamic weights (the alignment scores) with the input sequence (the values) then weights the sequence, and a single context vector can be calculated as the sum of these weighted vectors

Calculating the attended context vector (the weighted sum of the input vectors) is as simple as a matrix multiplication (MatMul) of the dynamic attention weights (a) with the packed input vectors (V).
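
A minimal numpy sketch of this step, with illustrative shapes:

```python
import numpy as np

a = np.array([[0.7, 0.2, 0.1],
              [0.1, 0.1, 0.8]])   # 2 outputs attending over 3 inputs (rows sum to 1)
V = np.random.randn(3, 64)        # the 3 input vectors, each of dimension 64

context = a @ V                   # (2, 64): one weighted-sum context vector per output
```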

Dot-Product Attention

There are numerous methods for obtaining the dynamic attention weights (a), such as Dot-Product Attention (specifically, the dot-product of the Query (Q) with the Keys (K))

a = Q @ Kᵀ (the matrix product of Q with the transpose of K)

A useful byproduct of the dynamic alignment weights (attention) is that they can be displayed as a heat map to visualise exactly which words in the sequence are most important for predicting each output word.

Scaled Dot-Product Attention

Some normalisation can also be applied, such as softmax(a), which non-linearly scales the weight values to between 0 and 1 (and makes them sum to 1). Because the dot-product can produce very large magnitudes when the vector dimension (d) is large, which in turn results in very small gradients when passed through the softmax function, the values can be scaled beforehand (scale = 1 / √d).

Therefore, the normalised scaled dot-product attention = softmax(a * scale) = softmax(Q @ Kᵀ / √d)
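
A minimal numpy sketch of scaled dot-product attention, assuming illustrative shapes and following the formula above:

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))   # subtract max for numerical stability
    return e / e.sum(axis=-1, keepdims=True)

def scaled_dot_product_attention(Q, K, V):
    d = K.shape[-1]                       # dimension of the key / query vectors
    scores = Q @ K.T / np.sqrt(d)         # scaled alignment scores
    weights = softmax(scores)             # each row sums to 1
    return weights @ V                    # weighted sum of the value vectors

Q = np.random.randn(5, 64)                # 5 queries of dimension 64
K = np.random.randn(7, 64)                # 7 keys
V = np.random.randn(7, 64)                # 7 values
context = scaled_dot_product_attention(Q, K, V)   # (5, 64)
```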

Efficient Calculations

For efficiency, a set of queries can be computed simultaneously: multiple query vectors are packed together into a matrix Q, and the keys and values are likewise packed into matrices K and V. For further computational efficiency, the query, key and value vectors can be made smaller by reducing the dimension of the packed inputs (X) via linear weight transformations / projections (these projection matrices / weights (W_Q, W_K, W_V) are learnt during training)

Q = matrix of Query vectors. K = matrix of Key vectors. V = matrix of Value vectors. (W_Q, W_K, W_V are the learnt projection matrices / weights)
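
A minimal numpy sketch of these projections (X, d_model and d_k are illustrative placeholders):

```python
import numpy as np

seq_len, d_model, d_k = 7, 512, 64
X = np.random.randn(seq_len, d_model)     # packed input vectors (e.g. embeddings)

W_Q = np.random.randn(d_model, d_k)       # learnt during training
W_K = np.random.randn(d_model, d_k)
W_V = np.random.randn(d_model, d_k)

Q, K, V = X @ W_Q, X @ W_K, X @ W_V       # each of shape (seq_len, d_k)
```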

…but what are Q,K,V exactly?

The exact values for Queries, Keys and Values depend on which attention mechanism is being referred to. For the Transformer, there are three separate attention mechanisms (which we shall refer to as (i) the Encoder Attention, (ii) the Decoder Attention and (iii) the Encoder-Decoder Attention, which also sits on the decoder side but is named encoder-decoder because it feeds all of the encoder’s outputs into each of the decoder’s positions)

Encoder Attention

  • Q = the current position-word vector in the input sequence
  • K = all the position-word vectors in the input sequence
  • V = all the position-word vectors in the input sequence

Decoder Attention

  • Q = the current position-word vector in the output sequence
  • K = all the position-word vectors in the output sequence
  • V = all the position-word vectors in the output sequence

Encoder-Decoder Attention

  • Q = the output of the decoder’s masked attention
  • K = all the encoder’s hidden state vectors
  • V = all the encoder’s hidden state vectors

…why does K = V ???

The attention mechanism is generally used to relate two different sequences to one another (the queries on one side, the keys and values on the other). Self-attention / intra-attention, however, relates different positions of the same sequence to one another, so the mechanism is adapted by using that same sequence for both the keys and the values (hence K = V)
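
A small sketch to make the distinction concrete (the arrays below are hypothetical placeholders):

```python
import numpy as np

encoder_states = np.random.randn(7, 64)   # hypothetical encoder output vectors
decoder_states = np.random.randn(5, 64)   # hypothetical decoder vectors

# Encoder (self-)attention: Q, K and V are all taken from the same input sequence
Q_self, K_self, V_self = encoder_states, encoder_states, encoder_states

# Encoder-Decoder attention: Q comes from the decoder, while K = V come from the encoder
Q_cross, K_cross, V_cross = decoder_states, encoder_states, encoder_states
```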

Transformer

Problem

Since we collect the signals at each state directly (dynamically weighted using Attention), we no longer need to propagate and combine the signals down to the final hidden state of the RNN. But this was the main power of RNNs!

Solution: Attention Only

Therefore, as the authors of the paper “Attention is all you need” concluded, RNNs are no longer needed. Instead, the architecture can be simplified to include Attention mechanisms directly on the encoder inputs, bypassing the RNNs.

This is also more efficient to compute and process, so deeper models with more parameters than ever before can be, and have been, trained (e.g. BERT, GPT-2, etc)

Problem

Attention Vectors alone are essentially a set of matrix multiplications and thus linear transformations. How can a network learn hierarchical features to approximate complex functions without non-linear activation functions?

Solution: Feed Forward Neural Networks

Right after the attention mechanism, introduce some non-linear transformations by including fully-connected feedforward neural networks (with simple, yet non-linear, ReLU activation functions) for each input position (with shared parameters for efficiency). The FFNN’s output vectors essentially replace the hidden states of the original RNN encoder.
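
A minimal sketch (in PyTorch, with illustrative sizes) of such a position-wise feed-forward network:

```python
import torch
import torch.nn as nn

d_model, d_ff = 512, 2048
ffn = nn.Sequential(
    nn.Linear(d_model, d_ff),
    nn.ReLU(),                            # the simple, yet non-linear, activation
    nn.Linear(d_ff, d_model),
)

attended = torch.randn(1, 7, d_model)     # (batch, seq_len, d_model) attention outputs
hidden_states = ffn(attended)             # same shape; applied identically at every position
```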

Problem

Without a sequence-aligned recurrent architecture, how do we account for sequence order? The resulting context vector for the input sequence “dogs chase cats” would appear identical to the context vector for “cats chase dogs” if sequence order were ignored, and the decoder would be unable to differentiate between two sequences with identical tokens but different order (thus hindering performance)

Solution: Positional Encoding

Positional embeddings encode information about the position of each token in the sequence. These embeddings can be learnt, or simple sine / cosine functions can be used.

Positional encodings should have the same dimensions as the token embeddings so that the two can be summed together (essentially injecting positional information into the input embedding representations)
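
A minimal numpy sketch of sine / cosine positional encodings being added to (hypothetical) token embeddings:

```python
import numpy as np

def positional_encodings(seq_len, d_model):
    positions = np.arange(seq_len)[:, None]               # (seq_len, 1)
    dims = np.arange(0, d_model, 2)[None, :]               # (1, d_model / 2)
    angles = positions / np.power(10000, dims / d_model)
    encodings = np.zeros((seq_len, d_model))
    encodings[:, 0::2] = np.sin(angles)                    # even dimensions: sine
    encodings[:, 1::2] = np.cos(angles)                    # odd dimensions: cosine
    return encodings

token_embeddings = np.random.randn(7, 512)                 # hypothetical embeddings
inputs = token_embeddings + positional_encodings(7, 512)   # inject positional information
```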

Tips and Tricks to improve performance further

Training

Dropout (very helpful in avoiding overfitting)

Label smoothing (penalises overconfident predictions so the model learns to be less certain, which improves accuracy)

Architecture

Multi-head attention (similar to how you have several kernels in a CNN, you can have several attention heads in a Transformer which run in parallel; each linearly projects the queries, keys and values with its own different, learned projections; see the sketch after this list)

Each attention head’s output (its calculated context vector) is then concatenated with the others

A projection layer (of linear weights) transforms this long concatenated vector down into a smaller vector with reduced dimensions. The projection matrix (W) also serves to learn which attention heads and features to take more/less notice of.

Deeper (more layers)

Residual connections (adding the input to the output of each sub-network)
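
A minimal numpy sketch of multi-head attention as described above (all names and sizes here are illustrative assumptions):

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def attention(Q, K, V):
    return softmax(Q @ K.T / np.sqrt(K.shape[-1])) @ V

seq_len, d_model, n_heads = 7, 512, 8
d_k = d_model // n_heads
X = np.random.randn(seq_len, d_model)

heads = []
for _ in range(n_heads):
    # each head has its own learned projections of the queries, keys and values
    W_Q, W_K, W_V = [np.random.randn(d_model, d_k) for _ in range(3)]
    heads.append(attention(X @ W_Q, X @ W_K, X @ W_V))     # (seq_len, d_k)

concatenated = np.concatenate(heads, axis=-1)              # (seq_len, d_model)
W_O = np.random.randn(d_model, d_model)                    # learns which heads / features matter
output = concatenated @ W_O                                # (seq_len, d_model)
```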
