Transformer From Scratch in PyTorch

Ritik Nandwal
4 min read · Sep 5, 2024


Introduction

Implementing the transformer architecture from scratch is as important as understanding it theoretically. In my previous article I covered the theoretical explanation of the transformer and how it works. In this article we will implement a transformer from scratch in PyTorch and then train it on a very small dataset for a Neural Machine Translation task.

If you have not read my previous article, please give it a read.

Attention Is All You Need: Understanding transformers

To keep the flow easy to follow, we will first code each component and then combine them all at the end.

Components

Let’s code the components in the order of their usage.

A transformer block consists of an Encoder and a Decoder, both of which rely on the attention mechanism and a feed-forward layer.

The Encoder and Decoder are then combined to form a Transformer block.

Attention

1. Scaled Dot-Product Attention | 2. Multi-Head Attention

Multi-head attention takes the Q, K and V vectors as input, which are passed through linear layers before applying scaled dot-product attention. The outputs of the heads are then concatenated and passed through a final linear layer.

Note: In the original paper the dimension of each W_Q_i, W_K_i and W_V_i is (d_model, d_k), where i ranges from 1 to num_of_heads. Since d_k = d_model / num_of_heads, stacking the W_Q_i matrices results in a matrix of dimension (d_model, d_model), which is why a single linear layer of that size can be used per projection.

Multihead Attention Block
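As a rough sketch of how this can be coded, the block below combines scaled dot-product attention and the multi-head wrapper in one module. The class name `MultiHeadAttention`, the single stacked projection per Q/K/V and the default sizes (d_model=512, num_heads=8) are illustrative choices following the paper's setup, not necessarily the article's exact code.

```python
import math
import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    """Multi-head attention with scaled dot-product attention inside (sketch)."""
    def __init__(self, d_model=512, num_heads=8):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_k = d_model // num_heads          # per-head dimension
        self.num_heads = num_heads
        # One stacked projection per Q/K/V is equivalent to stacking the per-head W_i matrices
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)   # final output projection

    def scaled_dot_product_attention(self, q, k, v, mask=None):
        # q, k, v: (batch, heads, seq_len, d_k)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn = torch.softmax(scores, dim=-1)
        return torch.matmul(attn, v)

    def forward(self, q, k, v, mask=None):
        batch_size = q.size(0)
        # Project, then split into heads: (batch, seq_len, d_model) -> (batch, heads, seq_len, d_k)
        def split(x, proj):
            return proj(x).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        q, k, v = split(q, self.w_q), split(k, self.w_k), split(v, self.w_v)
        out = self.scaled_dot_product_attention(q, k, v, mask)
        # Concatenate the heads back: (batch, seq_len, d_model)
        out = out.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.d_k)
        return self.w_o(out)
```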

Position Wise Feed Forward

In both the encoder and decoder blocks, the output of the multi-head attention layer is passed through a feed-forward layer. It consists of two linear transformations with a ReLU activation in between.

Position Wise Feed Forward
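A minimal sketch of this layer, assuming the conventional sizes d_model=512 and d_ff=2048 from the paper (the class name `PositionWiseFeedForward` and the dropout placement are illustrative):

```python
import torch.nn as nn

class PositionWiseFeedForward(nn.Module):
    """Two linear layers with a ReLU in between, applied independently at every position."""
    def __init__(self, d_model=512, d_ff=2048, dropout=0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
        )

    def forward(self, x):
        # x: (batch, seq_len, d_model) -> (batch, seq_len, d_model)
        return self.net(x)
```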

Now let's code the encoder layer and the decoder layer, which will then be used to form the encoder and decoder blocks.

Encoder

Encoder Layer

The input tokens are first converted into input embeddings. Since these embeddings carry no information about token positions, we add positional encodings to them, which are calculated using the formula below.

Positional Encoding:

PE(pos, 2i) = sin(pos / 10000^(2i / d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))

You can find the code for calculating the PE inside the encoder and decoder blocks, where it uses the input tensor x to generate the PEs.
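As a sketch, the same logic can be written as a standalone helper that builds the sinusoidal encodings for an input tensor x of shape (batch, seq_len, d_model). The function name `positional_encoding` is mine; in the article's code the equivalent computation lives inside the encoder and decoder blocks:

```python
import math
import torch

def positional_encoding(x):
    """Sinusoidal positional encodings with the same shape as x: (batch, seq_len, d_model)."""
    batch_size, seq_len, d_model = x.shape   # d_model assumed even
    pe = torch.zeros(seq_len, d_model, device=x.device)
    pos = torch.arange(seq_len, device=x.device).float().unsqueeze(1)        # (seq_len, 1)
    div = torch.exp(torch.arange(0, d_model, 2, device=x.device).float()
                    * (-math.log(10000.0) / d_model))                        # 10000^(-2i/d_model)
    pe[:, 0::2] = torch.sin(pos * div)   # even dimensions -> sin
    pe[:, 1::2] = torch.cos(pos * div)   # odd dimensions  -> cos
    return pe.unsqueeze(0).expand(batch_size, -1, -1)
```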

Encoder Layer
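A minimal sketch of one encoder layer, reusing the `MultiHeadAttention` and `PositionWiseFeedForward` modules sketched above and using the post-norm residual arrangement of the original paper (class and argument names are illustrative):

```python
import torch.nn as nn

class EncoderLayer(nn.Module):
    """Self-attention followed by a feed-forward network, each with residual + layer norm."""
    def __init__(self, d_model=512, num_heads=8, d_ff=2048, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.ffn = PositionWiseFeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, src_mask=None):
        # Residual connection around self-attention
        x = self.norm1(x + self.dropout(self.self_attn(x, x, x, src_mask)))
        # Residual connection around the feed-forward network
        x = self.norm2(x + self.dropout(self.ffn(x)))
        return x
```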

Above is a single encoder layer. The input is passed through a stack of such layers (6 in the original paper), which together form the encoder block, as shown below.

Encoder Block
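A sketch of the encoder block, stacking N encoder layers on top of the token embedding plus positional encoding (the `Encoder` class name and constructor arguments are assumptions for this sketch):

```python
import torch.nn as nn

class Encoder(nn.Module):
    """Embedding + positional encoding followed by a stack of N encoder layers."""
    def __init__(self, src_vocab_size, d_model=512, num_layers=6,
                 num_heads=8, d_ff=2048, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(src_vocab_size, d_model)
        self.layers = nn.ModuleList(
            [EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)]
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, src, src_mask=None):
        # src: (batch, src_len) of token ids
        x = self.embedding(src)
        x = self.dropout(x + positional_encoding(x))   # add the sinusoidal PEs
        for layer in self.layers:
            x = layer(x, src_mask)
        return x
```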

Decoder

Decoder Layer

The output embedding is generated from the output (target) tokens, the positional encoding is added to it, and the result is passed as input to the self-attention block of the decoder.

The decoder differs from the encoder by having a cross-attention block, which uses the encoder's output as the Key and Value, while the Query comes from the decoder's self-attention block.

Decoder Layer
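A minimal sketch of one decoder layer, again reusing the attention and feed-forward modules from above; in the cross-attention call the queries come from the decoder while the keys and values come from the encoder output (names are illustrative):

```python
import torch.nn as nn

class DecoderLayer(nn.Module):
    """Masked self-attention, cross-attention over the encoder output, then a feed-forward network."""
    def __init__(self, d_model=512, num_heads=8, d_ff=2048, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)
        self.ffn = PositionWiseFeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_out, src_mask=None, trg_mask=None):
        # Masked self-attention over the target sequence
        x = self.norm1(x + self.dropout(self.self_attn(x, x, x, trg_mask)))
        # Cross-attention: queries from the decoder, keys/values from the encoder output
        x = self.norm2(x + self.dropout(self.cross_attn(x, enc_out, enc_out, src_mask)))
        x = self.norm3(x + self.dropout(self.ffn(x)))
        return x
```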

The decoder block consists of multiple decoder layers (6 in the original paper), which take the output embedding and the encoder block's output as input; the result is finally passed through a linear layer.

Decoder Block
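A sketch of the decoder block along these lines, with the final linear layer projecting to target-vocabulary logits (the `Decoder` class name and defaults are assumptions):

```python
import torch.nn as nn

class Decoder(nn.Module):
    """Embedding + positional encoding, a stack of N decoder layers, and a final linear projection."""
    def __init__(self, trg_vocab_size, d_model=512, num_layers=6,
                 num_heads=8, d_ff=2048, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(trg_vocab_size, d_model)
        self.layers = nn.ModuleList(
            [DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)]
        )
        self.fc_out = nn.Linear(d_model, trg_vocab_size)   # projects to vocabulary logits
        self.dropout = nn.Dropout(dropout)

    def forward(self, trg, enc_out, src_mask=None, trg_mask=None):
        # trg: (batch, trg_len) of token ids, enc_out: (batch, src_len, d_model)
        x = self.embedding(trg)
        x = self.dropout(x + positional_encoding(x))
        for layer in self.layers:
            x = layer(x, enc_out, src_mask, trg_mask)
        return self.fc_out(x)
```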

Transformer Block

Transformer Architecture

Now let's combine it all together: stacking the encoder and decoder blocks forms a transformer. The inputs to the transformer are the source and target vectors, which are used to form the src_mask and trg_mask.

The source mask in the encoder prevents attention to padding tokens, while the target mask in the decoder blocks future tokens (look-ahead masking) to ensure autoregressive generation.

The source vector and source mask are passed as input to the encoder block to obtain the encoder's output, which is then passed, along with the target vector, source mask and target mask, to the decoder, which produces the final output.

Transformer Block
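A sketch of the full model tying everything together, including the mask construction described above. The `make_src_mask`/`make_trg_mask` helpers and the padding-index arguments are illustrative choices consistent with the mask behaviour described in the text:

```python
import torch
import torch.nn as nn

class Transformer(nn.Module):
    """Ties the encoder and decoder together and builds the source/target masks."""
    def __init__(self, src_vocab_size, trg_vocab_size, src_pad_idx, trg_pad_idx,
                 d_model=512, num_layers=6, num_heads=8, d_ff=2048, dropout=0.1):
        super().__init__()
        self.encoder = Encoder(src_vocab_size, d_model, num_layers, num_heads, d_ff, dropout)
        self.decoder = Decoder(trg_vocab_size, d_model, num_layers, num_heads, d_ff, dropout)
        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx

    def make_src_mask(self, src):
        # (batch, 1, 1, src_len): True where the token is not padding
        return (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)

    def make_trg_mask(self, trg):
        # Padding mask combined with a lower-triangular look-ahead mask
        trg_len = trg.size(1)
        pad_mask = (trg != self.trg_pad_idx).unsqueeze(1).unsqueeze(2)          # (batch, 1, 1, trg_len)
        look_ahead = torch.tril(torch.ones(trg_len, trg_len, device=trg.device)).bool()
        return pad_mask & look_ahead                                            # (batch, 1, trg_len, trg_len)

    def forward(self, src, trg):
        src_mask = self.make_src_mask(src)
        trg_mask = self.make_trg_mask(trg)
        enc_out = self.encoder(src, src_mask)
        return self.decoder(trg, enc_out, src_mask, trg_mask)
```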

Training on NMT Task

Introduction

We will train the transformer we implemented on a language translation task (English to Hindi). I have taken a very small dataset and vocabulary so that we can overfit the model on these sample translation sentences and check whether it works on the examples it was trained on.

Dataset Preparation

The dataset class takes an index as input and, for that particular index, converts the sentences into tokens using the source and target vocabularies. If a word is not present in the vocabulary, it is replaced with the <unk> token. The token sequences are finally padded up to max_len.
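A minimal sketch of such a dataset class, assuming word-level vocabularies stored as dicts with <pad>, <unk>, <sos> and <eos> entries (the class name `TranslationDataset` and the exact special-token handling are illustrative):

```python
import torch
from torch.utils.data import Dataset

class TranslationDataset(Dataset):
    """Turns (source, target) sentence pairs into padded token-id tensors."""
    def __init__(self, pairs, src_vocab, trg_vocab, max_len=20):
        self.pairs = pairs              # list of (english_sentence, hindi_sentence) strings
        self.src_vocab = src_vocab      # dicts mapping word -> id, containing the special tokens
        self.trg_vocab = trg_vocab
        self.max_len = max_len

    def encode(self, sentence, vocab):
        unk, pad = vocab['<unk>'], vocab['<pad>']
        tokens = [vocab['<sos>']]
        tokens += [vocab.get(word, unk) for word in sentence.split()]  # <unk> for unseen words
        tokens.append(vocab['<eos>'])
        tokens = tokens[:self.max_len]
        tokens += [pad] * (self.max_len - len(tokens))                 # pad up to max_len
        return torch.tensor(tokens, dtype=torch.long)

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        src_sentence, trg_sentence = self.pairs[idx]
        return self.encode(src_sentence, self.src_vocab), self.encode(trg_sentence, self.trg_vocab)
```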

Training code

Now let's train the model on our sample dataset. Note that we are deliberately overfitting the model, just to check that it works.
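A sketch of a training loop along these lines. The sentence pairs, vocabularies and hyperparameters below are small placeholder values for illustration, not the article's actual data; the loop uses teacher forcing, feeding trg[:, :-1] and predicting trg[:, 1:]:

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

# Tiny placeholder vocabularies and sentence pairs, just for illustration
specials = {'<pad>': 0, '<unk>': 1, '<sos>': 2, '<eos>': 3}
pairs = [("i am happy", "main khush hoon"),
         ("he is reading a book", "vah ek kitab padh raha hai")]
src_vocab, trg_vocab = dict(specials), dict(specials)
for en, hi in pairs:
    for w in en.split():
        src_vocab.setdefault(w, len(src_vocab))
    for w in hi.split():
        trg_vocab.setdefault(w, len(trg_vocab))

device = 'cuda' if torch.cuda.is_available() else 'cpu'
loader = DataLoader(TranslationDataset(pairs, src_vocab, trg_vocab, max_len=10),
                    batch_size=2, shuffle=True)

# Small model so the tiny dataset overfits quickly
model = Transformer(len(src_vocab), len(trg_vocab), src_pad_idx=0, trg_pad_idx=0,
                    d_model=128, num_layers=2, num_heads=4, d_ff=256).to(device)
criterion = nn.CrossEntropyLoss(ignore_index=0)        # ignore padding positions in the loss
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

model.train()
for epoch in range(200):
    for src, trg in loader:
        src, trg = src.to(device), trg.to(device)
        logits = model(src, trg[:, :-1])               # teacher forcing: predict the next token
        loss = criterion(logits.reshape(-1, logits.size(-1)), trg[:, 1:].reshape(-1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    if (epoch + 1) % 50 == 0:
        print(f"epoch {epoch + 1}: loss {loss.item():.4f}")
```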

Inference

We can now test the trained model on some sample sentences.
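A simple greedy-decoding sketch, assuming the model, vocabularies and device from the training sketch above (the `translate` helper is hypothetical):

```python
import torch

def translate(model, sentence, src_vocab, trg_vocab, max_len=10, device='cpu'):
    """Greedy decoding: repeatedly feed the decoder its own predictions."""
    model.eval()
    inv_trg_vocab = {idx: word for word, idx in trg_vocab.items()}
    # Encode the source sentence the same way the dataset class does
    ids = [src_vocab['<sos>']]
    ids += [src_vocab.get(w, src_vocab['<unk>']) for w in sentence.split()]
    ids.append(src_vocab['<eos>'])
    ids = (ids + [src_vocab['<pad>']] * max_len)[:max_len]
    src = torch.tensor([ids], device=device)                # (1, max_len)

    trg_ids = [trg_vocab['<sos>']]
    with torch.no_grad():
        for _ in range(max_len - 1):
            trg = torch.tensor([trg_ids], device=device)
            logits = model(src, trg)                        # (1, len(trg_ids), vocab_size)
            next_id = logits[0, -1].argmax().item()         # most likely next token
            if next_id == trg_vocab['<eos>']:
                break
            trg_ids.append(next_id)
    return " ".join(inv_trg_vocab[i] for i in trg_ids[1:])  # drop <sos>

print(translate(model, "i am happy", src_vocab, trg_vocab, device=device))
```

If the model has overfit the tiny dataset, the printed translation should match the training target for that sentence.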

Conclusion

So finally we have learned how a transformer works, and we have a working implementation from scratch. I hope this helps you gain a deeper understanding of transformers.

References

Vaswani et al., "Attention Is All You Need", NeurIPS 2017.
