Transformer From Scratch in PyTorch
Introduction
Implementing the transformer architecture from scratch is as important as understanding it theoretically. In my previous article, I covered the theoretical explanation of the transformer and how it works. In this article, we will implement a transformer from scratch in PyTorch and then train it on a very small dataset for a Neural Machine Translation task.
If you have not read my previous article, please give it a read.
Attention Is All You Need: Understanding transformers
To keep the flow easy to follow, we will first code each component and then combine everything at the end.
Components
Let’s code the components in the order of their usage.
A transformer block consists of an Encoder and a Decoder; both rely on an attention mechanism and a feed-forward layer.
The Encoder and Decoder are finally combined to form the Transformer block.
Attention
It takes Q, K and V vectors as input, which are passed through linear layers before scaled dot-product attention is applied. The outputs of the heads are then concatenated and passed through a final linear layer.
Note: In the original paper the dimension of W_Q_i, W_K_i and W_V_i is (d_model, d_k), where i ranges from 0 to num_of_heads − 1. Stacking up the W_Q_i's therefore results in a matrix of dimension (d_model, d_model).
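Here is a minimal sketch of a multi-head attention module in PyTorch, following the description above. The class and parameter names (d_model, num_heads) are my own choices for illustration, and a single (d_model, d_model) projection is used per Q, K and V, which is equivalent to stacking the per-head matrices as described in the note.

```python
import math
import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_k = d_model // num_heads      # dimension per head
        self.num_heads = num_heads
        # One big projection per Q, K, V is equivalent to the stacked per-head W_Q_i's
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)

    def forward(self, q, k, v, mask=None):
        batch_size = q.size(0)
        # Project and split into heads: (batch, heads, seq_len, d_k)
        q = self.w_q(q).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        k = self.w_k(k).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        v = self.w_v(v).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        # Scaled dot-product attention
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn = torch.softmax(scores, dim=-1)
        out = torch.matmul(attn, v)
        # Concatenate the heads and apply the final linear layer
        out = out.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.d_k)
        return self.w_o(out)
```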
Position Wise Feed Forward
The output of the multi-head attention layer is passed through a feed-forward layer in both the encoder and decoder blocks. It consists of two linear transformations with a ReLU activation in between.
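A minimal sketch of the position-wise feed-forward layer; the hidden size d_ff is a parameter name I have assumed here:

```python
import torch.nn as nn

class PositionWiseFeedForward(nn.Module):
    def __init__(self, d_model, d_ff):
        super().__init__()
        # Two linear transformations with a ReLU activation in between
        self.net = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model),
        )

    def forward(self, x):
        return self.net(x)
```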
Now let's code the encoder layer and decoder layer, which will then be used to form the encoder and decoder blocks.
Encoder
The input tokens are first converted into input embeddings, and since these don't carry any information about their positions, we need to add positional encoding to them, which can be calculated using the formula below.
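For reference, the positional encoding formula from the original paper, where pos is the token position and i indexes the embedding dimensions:

$$PE_{(pos,\,2i)} = \sin\!\left(\frac{pos}{10000^{2i/d_{model}}}\right), \qquad PE_{(pos,\,2i+1)} = \cos\!\left(\frac{pos}{10000^{2i/d_{model}}}\right)$$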
You can find the code for calculating the PE inside the encoder and decoder blocks, which uses the input tensor x to generate the PEs.
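Below is a sketch of how the PE can be generated from the input tensor x, together with a single encoder layer that wires the attention and feed-forward modules sketched above with residual connections and layer normalization. Names such as positional_encoding and EncoderLayer, and the dropout rate, are my own assumptions.

```python
import math
import torch
import torch.nn as nn

def positional_encoding(x):
    # x: (batch, seq_len, d_model) -> PE of shape (seq_len, d_model); assumes d_model is even
    seq_len, d_model = x.size(1), x.size(2)
    pos = torch.arange(seq_len, dtype=torch.float, device=x.device).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float, device=x.device)
                         * (-math.log(10000.0) / d_model))
    pe = torch.zeros(seq_len, d_model, device=x.device)
    pe[:, 0::2] = torch.sin(pos * div_term)   # even dimensions use sin
    pe[:, 1::2] = torch.cos(pos * div_term)   # odd dimensions use cos
    return pe

class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.ff = PositionWiseFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, src_mask=None):
        # Self-attention sub-layer with residual connection and layer norm
        x = self.norm1(x + self.dropout(self.self_attn(x, x, x, src_mask)))
        # Feed-forward sub-layer with residual connection and layer norm
        x = self.norm2(x + self.dropout(self.ff(x)))
        return x
```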
Above is a simple encoder layer. The input is passed through a stack of such layers (6, as mentioned in the original paper), which together form the encoder block, as shown below.
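A sketch of the encoder block, reusing the positional_encoding helper and EncoderLayer from above; the default hyperparameters are assumptions based on the original paper:

```python
import torch.nn as nn

class Encoder(nn.Module):
    def __init__(self, src_vocab_size, d_model, num_layers=6, num_heads=8, d_ff=2048, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(src_vocab_size, d_model)
        self.layers = nn.ModuleList(
            [EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)]
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, src, src_mask=None):
        # Input embedding + positional encoding
        x = self.embedding(src)
        x = self.dropout(x + positional_encoding(x))
        # Pass the result through the stack of encoder layers
        for layer in self.layers:
            x = layer(x, src_mask)
        return x
```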
Decoder
The output embedding is generated from the output (target) tokens, the positional encoding is added to it, and the result is passed as input to the decoder's self-attention block.
The decoder differs from the encoder by having a cross-attention block, which uses the encoder's output as the Key and Value, while the Query comes from the output of the decoder's self-attention block.
The decoder block consists of multiple decoder layers (6, as mentioned in the original paper), which take the output embedding and the encoder block's output as input; the result is finally passed through a linear layer.
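A corresponding sketch of a decoder layer and the decoder block, again with names and default hyperparameters chosen for illustration and reusing the modules from above:

```python
import torch.nn as nn

class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)
        self.ff = PositionWiseFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_out, src_mask=None, trg_mask=None):
        # Masked self-attention over the target sequence
        x = self.norm1(x + self.dropout(self.self_attn(x, x, x, trg_mask)))
        # Cross-attention: queries from the decoder, keys/values from the encoder output
        x = self.norm2(x + self.dropout(self.cross_attn(x, enc_out, enc_out, src_mask)))
        # Feed-forward sub-layer
        x = self.norm3(x + self.dropout(self.ff(x)))
        return x

class Decoder(nn.Module):
    def __init__(self, trg_vocab_size, d_model, num_layers=6, num_heads=8, d_ff=2048, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(trg_vocab_size, d_model)
        self.layers = nn.ModuleList(
            [DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)]
        )
        self.fc_out = nn.Linear(d_model, trg_vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, trg, enc_out, src_mask=None, trg_mask=None):
        # Output embedding + positional encoding
        x = self.embedding(trg)
        x = self.dropout(x + positional_encoding(x))
        for layer in self.layers:
            x = layer(x, enc_out, src_mask, trg_mask)
        # Final linear projection to target-vocabulary logits
        return self.fc_out(x)
```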
Transformer Block
Now let's combine it all together: stacking up the encoder and decoder blocks forms a transformer. The inputs to the transformer are the source and target sequences, which are used to form the src_mask and trg_mask.
The source mask in the encoder prevents attention to padding tokens, while the target mask in the decoder blocks future tokens (look-ahead masking) to ensure autoregressive generation.
The source sequence and source mask are passed as input to the encoder block to obtain the encoder's output, which is then passed, along with the target sequence, source mask and target mask, to the decoder, which produces the final output.
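A sketch of how the pieces can be tied together, including one possible way to build src_mask and trg_mask; the padding-index arguments and default hyperparameters are my own assumptions:

```python
import torch
import torch.nn as nn

class Transformer(nn.Module):
    def __init__(self, src_vocab_size, trg_vocab_size, src_pad_idx, trg_pad_idx,
                 d_model=512, num_layers=6, num_heads=8, d_ff=2048, dropout=0.1):
        super().__init__()
        self.encoder = Encoder(src_vocab_size, d_model, num_layers, num_heads, d_ff, dropout)
        self.decoder = Decoder(trg_vocab_size, d_model, num_layers, num_heads, d_ff, dropout)
        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx

    def make_src_mask(self, src):
        # (batch, 1, 1, src_len): blocks attention to padding positions
        return (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)

    def make_trg_mask(self, trg):
        # Padding mask combined with a lower-triangular look-ahead mask
        trg_len = trg.size(1)
        pad_mask = (trg != self.trg_pad_idx).unsqueeze(1).unsqueeze(2)
        look_ahead = torch.tril(torch.ones(trg_len, trg_len, device=trg.device)).bool()
        return pad_mask & look_ahead

    def forward(self, src, trg):
        src_mask = self.make_src_mask(src)
        trg_mask = self.make_trg_mask(trg)
        enc_out = self.encoder(src, src_mask)
        return self.decoder(trg, enc_out, src_mask, trg_mask)
```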
Training on NMT Task
Introduction
We will now train the transformer we implemented on a language translation task (English to Hindi). I have taken a very small dataset and vocabulary so that we can overfit the model on these sample translation sentences and check whether it works on the examples we trained on.
Dataset Preparation
The dataset class takes an index as input and, for that particular index, converts the sentences into tokens using the source and target vocabularies. If a word is not present in the vocabulary, it is replaced with the <unk> token. The token sequences are finally padded to max_len.
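A sketch of such a dataset class; the special tokens <sos>, <eos> and <pad>, and the structure of pairs and the vocabularies, are my own assumptions:

```python
import torch
from torch.utils.data import Dataset

class TranslationDataset(Dataset):
    def __init__(self, pairs, src_vocab, trg_vocab, max_len=20):
        # pairs: list of (english_sentence, hindi_sentence) strings
        # src_vocab / trg_vocab: dicts mapping word -> index
        self.pairs = pairs
        self.src_vocab = src_vocab
        self.trg_vocab = trg_vocab
        self.max_len = max_len

    def __len__(self):
        return len(self.pairs)

    def encode(self, sentence, vocab):
        # Replace unknown words with <unk>, wrap with <sos>/<eos>, pad to max_len
        tokens = [vocab.get(w, vocab['<unk>']) for w in sentence.split()]
        tokens = [vocab['<sos>']] + tokens + [vocab['<eos>']]
        tokens = tokens[:self.max_len]
        tokens += [vocab['<pad>']] * (self.max_len - len(tokens))
        return torch.tensor(tokens)

    def __getitem__(self, idx):
        src_sentence, trg_sentence = self.pairs[idx]
        return self.encode(src_sentence, self.src_vocab), self.encode(trg_sentence, self.trg_vocab)
```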
Training code
Now let's train the model on our sample dataset. Note that we are deliberately overfitting the model, just to check that it is working.
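A sketch of a training loop under the assumptions above; pairs, src_vocab and trg_vocab are placeholders for your own sample data, and the hyperparameters are arbitrary:

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dataset = TranslationDataset(pairs, src_vocab, trg_vocab, max_len=20)
loader = DataLoader(dataset, batch_size=4, shuffle=True)

model = Transformer(len(src_vocab), len(trg_vocab),
                    src_pad_idx=src_vocab['<pad>'], trg_pad_idx=trg_vocab['<pad>']).to(device)
criterion = nn.CrossEntropyLoss(ignore_index=trg_vocab['<pad>'])
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

model.train()
for epoch in range(100):                      # enough epochs to overfit the tiny dataset
    total_loss = 0.0
    for src, trg in loader:
        src, trg = src.to(device), trg.to(device)
        # Teacher forcing: feed trg[:, :-1], predict trg[:, 1:]
        logits = model(src, trg[:, :-1])
        loss = criterion(logits.reshape(-1, logits.size(-1)), trg[:, 1:].reshape(-1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f'epoch {epoch}: loss={total_loss / len(loader):.4f}')
```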
Inference
We can now test the trained model on some sample sentences.
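One way to do this is simple greedy decoding, sketched below under the same vocabulary assumptions as above: start from <sos> and repeatedly append the most likely next token.

```python
import torch

def translate(model, sentence, src_vocab, trg_vocab, max_len=20, device='cpu'):
    model.eval()
    inv_trg_vocab = {i: w for w, i in trg_vocab.items()}
    # Encode the source sentence the same way as during training
    src_tokens = ([src_vocab['<sos>']]
                  + [src_vocab.get(w, src_vocab['<unk>']) for w in sentence.split()]
                  + [src_vocab['<eos>']])
    src = torch.tensor([src_tokens], device=device)
    trg = torch.tensor([[trg_vocab['<sos>']]], device=device)
    with torch.no_grad():
        for _ in range(max_len):
            logits = model(src, trg)
            next_token = logits[:, -1, :].argmax(-1, keepdim=True)   # greedy choice
            trg = torch.cat([trg, next_token], dim=1)
            if next_token.item() == trg_vocab['<eos>']:
                break
    return ' '.join(inv_trg_vocab[t] for t in trg[0, 1:].tolist() if t != trg_vocab['<eos>'])

print(translate(model, 'how are you', src_vocab, trg_vocab, device=device))
```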
Conclusion
So finally we have learned how a transformer works, and we have a working implementation from scratch. I hope this helps you gain a deeper understanding of transformers.