Building the Mighty Transformer for Sequence Tagging in PyTorch : Part I
Attention mechanisms have taken the deep learning world by storm in the last few years. It is not uncommon nowadays to have an attention-related component somewhere in your model. Last year Google pushed attention mechanisms to the extreme when it came out with the Transformer model in its famous “Attention is all you need” paper. As the title says, Transformer models ditch recurrent layers for a combination of attention and regular feedforward networks. Not only is the model faster and more scalable, it is also more interpretable because we can examine its attention weights to see what it focuses on. While this newfound approach has earned the Transformer new records in translation tasks, no one seems to have attempted to apply this model to sequence tagging tasks yet. This is an area where RNNs still rule the roost. Would the Transformer be able to take out the venerable BiLSTM? I felt it was a great excuse to build our own implementation and that’s what this post will explain.
In this two part series we are going to build the Transformer model from scratch and make it compete with a BiLSTM model on the chunking task. We’ll do zero math and instead look at the concepts visually.
- This first part will go through the implementation of the core components that make the Transformer — Multi Head Attention and Position wise Feed Forward Network. We’ll especially detail out the inner workings of attention mechanisms.
- The second part will focus on implementing the remaining bits of the Transformer and putting all the parts together. It will conclude with a face off between the Transformer and BiLSTM.
We’ll follow the original Transformer paper to implement our model in PyTorch. But we will incorporate the latest improvements from the TensorFlow implementation of Transformer as well. We will use the freely available CoNLL 2000 Chunking dataset to run the experiments. If you just want to grab the code, it’s all there on GitHub.
At the top level, the Transformer has an Encoder and Decoder just like sequence-to-sequence models. If we strip away standard neural network embellishments like Dropout, LayerNorm and residual connections, we’ll find that both blocks have two unique components — Multi Head Attention and Position wise Feedforward. They may sound complicated but are not very different from the regular attention and feed-forward layers.
Multi-Head Attention — The Beast
Multi head attention is essentially attention repeated several times in parallel. (If you are not clear about the intuition behind attention I suggest you see this short video explanation by Andrew Ng on Coursera.) Attention in general can be considered to have three inputs and one output:
- Keys: A sequence of vectors also known as the memory. It is the contextual information that we want to look at. In traditional sequence-to-sequence learning they are usually the RNN encoder outputs.
- Values: A sequence of vectors from which we aggregate the output through a weighted linear combination. Often Keys serve as Values.
- Query: A single vector that we use to probe the Keys. By probing we mean the Query is independently combined with each key to arrive at a single probability per key. The type of attention determines how the combination is done. In traditional sequence-to-sequence learning the Query is usually the decoder RNN state at a given time step.
- Output: A single vector which is derived from a linear combination of the Values using the probabilities from the previous step as weights.
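To make these four roles concrete, here is a minimal sketch of attention for a single query. It assumes plain dot-product scoring, which is only one of several possible ways to combine a query with the keys, and the function name `attend` and the tensor sizes are made up for illustration.

```python
import torch


def attend(query, keys, values):
    """Single-query dot-product attention (illustrative sketch).

    query:  [depth]
    keys:   [seq_len, depth]
    values: [seq_len, value_depth]
    """
    scores = keys @ query                    # one score per key: [seq_len]
    weights = torch.softmax(scores, dim=0)   # probabilities over the keys
    output = weights @ values                # weighted sum of values: [value_depth]
    return output, weights


torch.manual_seed(0)
keys = torch.randn(5, 8)    # a "memory" of 5 vectors
values = keys               # keys often serve as values
query = torch.randn(8)

output, weights = attend(query, keys, values)
print(weights)        # 5 probabilities, one per key
print(output.shape)   # a single vector of the same depth as the values
```

The output is one vector per query; the Transformer simply does this for every query in the sequence at once.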
With these definitions in place let’s look at the actual implementation in PyTorch. I will assume that you know how nn.Module works. (If not, here is an excellent introduction.) Let’s jump right into forward(). We’ll go step by step, repeating code snippets as needed:
1. Project each of the queries, keys and values linearly. By linear projection I mean a single neural network layer without the nonlinear activation. The learnable parameters of these linear projections are important for the attention mechanism:
queries = self.query_linear(queries)
keys = self.key_linear(keys)
values = self.value_linear(values)
Each of these Linears is a plain linear layer whose sizes, output_depth among them, are hyperparameters of the model. Note that the key and query sizes must match because we take a dot product between them. Because Transformers don’t have recurrent layers we need not process the queries one step at a time, so the query input here is the entire sequence of queries for a particular layer.
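The projection layers could be defined along the following lines. This is a sketch, not the exact code from the repository: the class name `AttentionProjections`, the parameters `input_depth` and `total_value_depth`, and the choice of `bias=False` are my assumptions, while the attribute names match the snippets used in this post.

```python
import torch
from torch import nn


class AttentionProjections(nn.Module):
    """Holds the four linear projections used by multi head attention (sketch)."""

    def __init__(self, input_depth, total_key_depth, total_value_depth, output_depth):
        super().__init__()
        # Queries and keys share total_key_depth so their dot product is defined.
        # bias=False is an assumption made for this sketch.
        self.query_linear = nn.Linear(input_depth, total_key_depth, bias=False)
        self.key_linear = nn.Linear(input_depth, total_key_depth, bias=False)
        self.value_linear = nn.Linear(input_depth, total_value_depth, bias=False)
        self.output_linear = nn.Linear(total_value_depth, output_depth, bias=False)


projections = AttentionProjections(
    input_depth=16, total_key_depth=32, total_value_depth=32, output_depth=16)

inputs = torch.randn(2, 10, 16)              # [batch, sequence length, input depth]
queries = projections.query_linear(inputs)   # the whole sequence is projected at once
print(queries.shape)                         # torch.Size([2, 10, 32])
```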
2. Split the projected inputs into the required number of partitions for multi head attention. Instead of working with separate linear projections for each head (or parallel attention), we divide a single projection into partitions and perform attention independently on each:
queries = self._split_heads(queries)
keys = self._split_heads(keys)
values = self._split_heads(values)
The way we split is to reshape the inputs to have an extra heads dimension.
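A possible version of the split, written as a standalone function so it can be run on its own. The assumption here is that the heads dimension ends up right after the batch dimension, which is what the later `keys.permute(0, 1, 3, 2)` call suggests; the tensor sizes are illustrative.

```python
import torch


def split_heads(x, num_heads):
    """[batch, seq_len, depth] -> [batch, num_heads, seq_len, depth // num_heads]"""
    batch, seq_len, depth = x.shape
    assert depth % num_heads == 0, "depth must be divisible by num_heads"
    # Cut the last dimension into num_heads contiguous partitions...
    x = x.reshape(batch, seq_len, num_heads, depth // num_heads)
    # ...then move the heads dimension in front of the sequence dimension.
    return x.permute(0, 2, 1, 3)


x = torch.randn(2, 10, 32)
heads = split_heads(x, num_heads=4)
print(heads.shape)   # torch.Size([2, 4, 10, 8])
```

Each head sees a contiguous slice of the projected depth: head 1 of the first example, for instance, holds columns 8 to 15 of the original tensor.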
3. Scale the queries by a certain factor. This is supposed to prevent the outputs from growing too big which can make learning difficult:
queries *= self.query_scale
The factor is calculated as the inverse square root of the partition size after splitting:
self.query_scale = (total_key_depth//num_heads)**-0.5
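One way to see what the scaling buys us is a small experiment, sketched below with random data. It assumes query and key components are independent with unit variance, which real projections will not exactly satisfy. Under that assumption the dot product of two vectors of size d has a standard deviation of about the square root of d, and scaling the queries brings it back to roughly 1.

```python
import torch

torch.manual_seed(0)
head_depth = 8                       # partition size after splitting
query_scale = head_depth ** -0.5

# 10,000 random query/key pairs with unit-variance components (an assumption).
queries = torch.randn(10000, head_depth)
keys = torch.randn(10000, head_depth)

raw = (queries * keys).sum(dim=-1)                     # unscaled dot products
scaled = (queries * query_scale * keys).sum(dim=-1)    # scaled dot products

print(raw.std())      # close to sqrt(8)
print(scaled.std())   # close to 1
```

Keeping the logits in this range stops the softmax from saturating as the partition size grows.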
4. Take the dot product between the queries and the keys (since we are doing multiplicative attention). We do that for every query and key pair with one single matmul call:
logits = torch.matmul(queries, keys.permute(0, 1, 3, 2))
A lot is happening in that single call so let’s break it down. For illustration we’ll use num_heads=4, which gives us partitions of size 8 after splitting (so the total key depth is 32).
To get the dot product we must transpose the (last two dimensions of the) key tensor before the matrix multiplication.
The above figures only show the last two dimensions of the input. But they are actually 4D Tensors. How can matmul do 4D Tensors?
Turns out it does a matrix multiplication for each higher dimension separately. So as long as the higher dimensions match (or can be broadcast against each other), matmul can handle any number of them. Of course the last two dimensions should follow matrix multiplication rules.
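The batched behaviour of matmul is easy to check by comparing it with an explicit loop over the batch and heads dimensions. In this sketch num_heads=4 and the partition size of 8 come from the illustration, while the batch size of 2 and sequence length of 10 are arbitrary values picked for the example.

```python
import torch

torch.manual_seed(0)
batch, num_heads, seq_len, head_depth = 2, 4, 10, 8
queries = torch.randn(batch, num_heads, seq_len, head_depth)
keys = torch.randn(batch, num_heads, seq_len, head_depth)

# One call: a [seq_len, head_depth] x [head_depth, seq_len] product
# for every (batch, head) pair.
logits = torch.matmul(queries, keys.permute(0, 1, 3, 2))
print(logits.shape)   # torch.Size([2, 4, 10, 10])

# The same thing spelled out with loops over the higher dimensions.
expected = torch.empty(batch, num_heads, seq_len, seq_len)
for b in range(batch):
    for h in range(num_heads):
        expected[b, h] = queries[b, h] @ keys[b, h].t()

print(torch.allclose(logits, expected, atol=1e-5))
```

Entry [b, h, i, j] of logits is the dot product between query i and key j in head h of example b.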
5. Apply a mask on the query key dot products. This step is required mainly for the decoder to prevent future inputs from influencing attention. We’ll cover this in part II.
6. Apply Softmax on the dot products to convert them to probabilities:
weights = nn.functional.softmax(logits, dim=-1)
7. Add Dropout. This is something that is not there in the original paper but added in the latest TensorFlow implementation.
8. Using the probabilities as weights, do a linear combination of the values. Again we use a single matmul to do the job across all values in all partitions:
contexts = torch.matmul(weights, values)
Continuing the earlier illustration:
With matrix multiplication we get all the outputs in one shot. This is what makes vectorization neat!
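Steps 6 to 8 can be tried out in isolation with random tensors standing in for the logits and values. The shapes follow the illustration (4 heads, partitions of size 8); the batch size, sequence length and dropout rate of 0.1 are assumptions made for this sketch.

```python
import torch
from torch import nn

torch.manual_seed(0)
logits = torch.randn(2, 4, 10, 10)   # [batch, heads, query length, key length]
values = torch.randn(2, 4, 10, 8)    # [batch, heads, key length, head depth]

# Step 6: softmax over the key dimension turns each row into probabilities.
weights = nn.functional.softmax(logits, dim=-1)
row_sums = weights.sum(dim=-1)       # every row sums to 1 at this point

# Step 7: dropout on the attention weights (rate assumed for illustration).
dropout = nn.Dropout(p=0.1)
weights = dropout(weights)

# Step 8: weighted combination of the values, for all queries and heads at once.
contexts = torch.matmul(weights, values)
print(contexts.shape)   # torch.Size([2, 4, 10, 8])
```

Note that during training dropout zeroes some weights and rescales the rest, so the rows no longer sum exactly to 1 after step 7.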
9. Now merge the heads back. Or rather reshape the outputs to get the original 3D shape [batch, sequence length, output depth]:
contexts = self._merge_heads(contexts)
_merge_heads() is just the reverse of _split_heads().
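A merge along these lines undoes the split. As before this is a standalone sketch with illustrative sizes, assuming the heads dimension sits right after the batch dimension.

```python
import torch


def merge_heads(x):
    """[batch, num_heads, seq_len, head_depth] -> [batch, seq_len, num_heads * head_depth]"""
    batch, num_heads, seq_len, head_depth = x.shape
    # Put the heads dimension back next to the depth dimension...
    x = x.permute(0, 2, 1, 3).contiguous()
    # ...and flatten the two into a single depth dimension.
    return x.view(batch, seq_len, num_heads * head_depth)


original = torch.randn(2, 10, 32)
# Split into 4 heads by hand, then merge again.
heads = original.view(2, 10, 4, 8).permute(0, 2, 1, 3)
merged = merge_heads(heads)
print(merged.shape)                    # torch.Size([2, 10, 32])
print(torch.equal(merged, original))   # True
```

The call to contiguous() is needed because permute only changes strides, and view cannot flatten dimensions that are not laid out contiguously in memory.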
10. Finally we do another linear projection to get the output:
outputs = self.output_linear(contexts)
So there you have it, Multi Head attention in ten steps. See the full implementation on GitHub here. Now let’s go into the next core component, the Positionwise Feedforward network.
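For reference, the ten steps fit into one compact module. Treat this as a sketch assembled from the snippets in this post rather than the code in the repository: the constructor arguments `input_depth`, `total_value_depth` and `dropout`, the use of `bias=False`, and the omission of masking (step 5, covered in part II) are all my simplifications.

```python
import torch
from torch import nn


class MultiHeadAttention(nn.Module):
    """Multi head attention without masking (illustrative sketch)."""

    def __init__(self, input_depth, total_key_depth, total_value_depth,
                 output_depth, num_heads, dropout=0.0):
        super().__init__()
        assert total_key_depth % num_heads == 0
        assert total_value_depth % num_heads == 0
        self.num_heads = num_heads
        self.query_scale = (total_key_depth // num_heads) ** -0.5
        self.query_linear = nn.Linear(input_depth, total_key_depth, bias=False)
        self.key_linear = nn.Linear(input_depth, total_key_depth, bias=False)
        self.value_linear = nn.Linear(input_depth, total_value_depth, bias=False)
        self.output_linear = nn.Linear(total_value_depth, output_depth, bias=False)
        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, x):
        batch, seq_len, depth = x.shape
        x = x.reshape(batch, seq_len, self.num_heads, depth // self.num_heads)
        return x.permute(0, 2, 1, 3)

    def _merge_heads(self, x):
        batch, num_heads, seq_len, head_depth = x.shape
        x = x.permute(0, 2, 1, 3).contiguous()
        return x.view(batch, seq_len, num_heads * head_depth)

    def forward(self, queries, keys, values):
        queries = self._split_heads(self.query_linear(queries))    # steps 1-2
        keys = self._split_heads(self.key_linear(keys))
        values = self._split_heads(self.value_linear(values))
        queries = queries * self.query_scale                       # step 3
        logits = torch.matmul(queries, keys.permute(0, 1, 3, 2))   # step 4
        # Step 5 (masking) is left out of this sketch.
        weights = nn.functional.softmax(logits, dim=-1)            # step 6
        weights = self.dropout(weights)                            # step 7
        contexts = torch.matmul(weights, values)                   # step 8
        contexts = self._merge_heads(contexts)                     # step 9
        return self.output_linear(contexts)                        # step 10


attention = MultiHeadAttention(input_depth=16, total_key_depth=32,
                               total_value_depth=32, output_depth=16, num_heads=4)
inputs = torch.randn(2, 10, 16)
outputs = attention(inputs, inputs, inputs)   # self-attention: same tensor three times
print(outputs.shape)                          # torch.Size([2, 10, 16])
```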
Positionwise Feedforward Network — The Sidekick
Like the name indicates, this is a regular feedforward network applied to each time step of the Multi Head attention outputs. The network has three layers with a non-linearity like ReLU for the hidden layer. You might be wondering why we need a feedforward network after attention; after all, isn’t attention all we need 😈? I suspect it is needed to improve model expressiveness. As we saw earlier, multi head attention partitions the inputs and applies attention to each partition independently. There is only a linear projection to the outputs, i.e. the partitions are combined only linearly. The Positionwise Feedforward network thus brings in some non-linear ‘mixing’, if we can call it that. In fact for the sequence tagging task we use convolutions instead of fully connected layers. A filter of width 3 allows interactions to happen with adjacent time steps to improve performance. The implementation in PyTorch is straightforward:
self.layers is an nn.ModuleList which can take either linear or convolutional layers, e.g. a list of nn.Linear layers. The layer sizes, output_depth among them, are all hyperparameters. (My actual implementation is slightly more complex because it allows the total number of layers and each layer type to be configured.)
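A stripped-down version with linear layers might look like the sketch below. I am reading "three layers" as input, hidden and output, i.e. two nn.Linear modules with a ReLU in between; that reading, the parameter names `input_depth` and `filter_size`, and the dropout placement are assumptions for this example rather than details of the actual implementation.

```python
import torch
from torch import nn


class PositionwiseFeedForward(nn.Module):
    """Feedforward network applied independently at every time step (sketch)."""

    def __init__(self, input_depth, filter_size, output_depth, dropout=0.0):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.Linear(input_depth, filter_size),    # input -> hidden
            nn.Linear(filter_size, output_depth),   # hidden -> output
        ])
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, inputs):
        x = inputs
        for i, layer in enumerate(self.layers):
            x = layer(x)
            # Non-linearity and dropout after every layer except the last.
            if i < len(self.layers) - 1:
                x = self.dropout(self.relu(x))
        return x


feed_forward = PositionwiseFeedForward(input_depth=16, filter_size=64, output_depth=16)
feed_forward.eval()
inputs = torch.randn(2, 10, 16)    # [batch, sequence length, depth]
outputs = feed_forward(inputs)
print(outputs.shape)               # torch.Size([2, 10, 16])
```

Because nn.Linear acts only on the last dimension, feeding a single time step through the network gives the same result as slicing that time step out of the full output. Swapping in convolutions of width 3 would break this independence on purpose, letting neighbouring time steps interact.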
With these two components implemented our Transformer is almost complete. All we need to do now is put everything together. The next part will cover that along with some tricks that we apply to improve the model. And of course we’ll run the actual experiments.