Deep Dive into AI: Analyzing ‘Attention is all you need.’

Ada Choudhry
Apr 7, 2024


Breaking down the transformers architecture to first principles

I’ve been wanting to break down this architecture for a long time! In the next article, I’ll apply what I learn here to build my own mini LLM!

Welcome or welcome back to the Deep Dive into AI tutorial series where I go deep into the fundamentals of neural networks (specifically large language models) through the bottom-up approach!

I will be drawing heavily on Andrej Karpathy’s series Neural Networks: Zero to Hero, because I have found it to be one of the most practical and in-depth series on building neural networks. As always, I’ll break everything I learn down into first principles and share the life lessons I have picked up from programming along the way!

Transformers have taken the world by storm, but their architecture can be intimidating to newcomers to the field. In this article, I break the concepts in transformers down into first principles!

Here is the original paper, “Attention Is All You Need”:

The diagram shows one encoder layer and one decoder layer and how they interact. The architecture uses multi-head attention, positional encodings, residual connections, and layer normalization. Let’s first understand the new concepts introduced in this paper, and then we’ll piece them together to understand the encoder-decoder architecture.

Table of contents:

Understanding new concepts:

  • Attention
  • Positional Encoding
  • Residual Connections

Understanding the architecture:

  • Encoder
  • Decoder

Attention

Think of attention as a mechanism that helps individual tokens gain context. It is a communication mechanism between tokens that can be pictured as a directed graph: each token receives information from the tokens it is connected to. In the masked (causal) setting, each token communicates with itself and with every token that comes before it.

In this graph, every token is connected to itself and to the tokens before it. For example, the 8th token is connected to all seven previous tokens and to itself, while the 1st token is connected only to itself.
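To make the graph concrete, here’s a minimal sketch (plain Python, 0-based positions, so the "8th token" is index 7) that lists which positions each token can attend to in the masked setting. The function name `causal_connections` is just an illustrative choice.

```python
def causal_connections(num_tokens):
    """For each position t (0-based), list the positions that token t can attend to."""
    # Masked (causal) setting: each token sees itself and every earlier token.
    return [list(range(t + 1)) for t in range(num_tokens)]


graph = causal_connections(8)
print(graph[0])  # [0]  (the 1st token only sees itself)
print(graph[7])  # [0, 1, 2, 3, 4, 5, 6, 7]  (the 8th token sees all 8)
```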

The attention function deals with three kinds of vectors:

  • The query vector represents what the token is looking for.
  • The key vector represents what the token contains.
  • The value vector is a linear projection of the token to the attention head size and represents what the token passes along to the tokens that attend to it.

The output is a weighted sum of the value vectors. The weight assigned to each value is computed by a compatibility function of the query with the corresponding key (usually the dot product between the query and the key).
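Here’s a minimal sketch of where the three vectors come from, assuming a single head and made-up sizes (`d_model = 4`, `head_size = 3`). The weight matrices are random stand-ins for what a real model would learn, and everything uses plain Python lists rather than PyTorch tensors. It stops at the raw dot-product compatibility scores; turning those into weights is covered next.

```python
import random


def project(W, x):
    """Linear projection: W has shape out_dim x in_dim, x has length in_dim."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def random_matrix(rows, cols):
    return [[random.gauss(0, 1) for _ in range(cols)] for _ in range(rows)]


random.seed(0)
d_model, head_size, n_tokens = 4, 3, 5
# Hypothetical weights: random here, learned during training in a real model.
W_q = random_matrix(head_size, d_model)
W_k = random_matrix(head_size, d_model)
W_v = random_matrix(head_size, d_model)
tokens = random_matrix(n_tokens, d_model)  # 5 token embeddings of size d_model

queries = [project(W_q, x) for x in tokens]  # what each token is looking for
keys = [project(W_k, x) for x in tokens]     # what each token contains
values = [project(W_v, x) for x in tokens]   # what each token passes along

# Compatibility of the last token's query with every key: one score per token.
scores = [dot(queries[-1], k) for k in keys]
```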

In self-attention, the keys, queries, and values are all generated from the same source (for example, the input sequence). In cross-attention, the queries are produced from one source while the keys and values come from another, such as a different sequence or the output of another part of the transformer, which gives the model additional context. Cross-attention comes in handy in language translation: you need an encoder that can understand the context and grammatical structure of the whole source sentence before passing it to the decoder, which uses that additional context to generate new tokens.

Scaled Dot-Product Attention

The input consists of queries and keys of dimension d_k and values of dimension d_v. Here the compatibility function between a query and a key is the dot product (which is why queries and keys must have the same dimension), and each dot product is scaled by dividing by √d_k. In matrix form, Attention(Q, K, V) = softmax(QKᵀ / √d_k) V. The scaling matters because of how the scores are distributed. If the components of the queries and keys are independent with zero mean and unit variance (roughly the situation at initialization), then their dot product, a sum of d_k such products, still has zero mean but has variance d_k. Dividing by √d_k brings the variance back to 1, which keeps the scores in a range where softmax behaves well and makes training easier.
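We can check the variance argument numerically. This sketch assumes unit Gaussian queries and keys with `d_k = 64` (the per-head size in the paper) and draws a few thousand random pairs. It is an empirical illustration, so the printed numbers will only be close to 64 and 1, not exact.

```python
import math
import random

random.seed(0)
d_k = 64
n_samples = 5000

raw, scaled = [], []
for _ in range(n_samples):
    q = [random.gauss(0, 1) for _ in range(d_k)]  # unit Gaussian query
    k = [random.gauss(0, 1) for _ in range(d_k)]  # unit Gaussian key
    s = sum(qi * ki for qi, ki in zip(q, k))      # unscaled dot product
    raw.append(s)
    scaled.append(s / math.sqrt(d_k))             # scaled by 1 / sqrt(d_k)


def variance(xs):
    mean = sum(xs) / len(xs)
    return sum((x - mean) ** 2 for x in xs) / len(xs)


print(variance(raw))     # roughly d_k, i.e. around 64
print(variance(scaled))  # roughly 1
```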

Applying softmax to the scaled scores gives us the weights for the values. Softmax exponentiates each score and divides it by the sum of the exponentiated scores in that row (one row per query), so every row of weights is positive and sums to 1.
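Putting the pieces together, here’s a minimal sketch of scaled dot-product attention for one head, softmax(QKᵀ / √d_k) V, in plain Python. For brevity it assumes Q, K, and V are already projected (no learned weights here). The same function covers self-attention (Q, K, V from one sequence) and cross-attention (queries from a different sequence).

```python
import math


def softmax(row):
    """Exponentiate each score and divide by the row's sum (max subtracted for stability)."""
    m = max(row)
    exps = [math.exp(s - m) for s in row]
    total = sum(exps)
    return [e / total for e in exps]


def attention(Q, K, V):
    """Scaled dot-product attention.

    Q: n_q x d_k, K: n_kv x d_k, V: n_kv x d_v (lists of rows). Returns n_q x d_v.
    """
    d_k = len(K[0])
    out = []
    for q in Q:
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d_k) for k in K]
        weights = softmax(scores)  # one row of weights per query, sums to 1
        # Weighted sum of the value vectors.
        out.append([sum(w * v[j] for w, v in zip(weights, V)) for j in range(len(V[0]))])
    return out


X = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # 3 tokens, d_k = d_v = 2
Y = [[0.5, -0.5]]                         # 1 token from another sequence
self_out = attention(X, X, X)   # self-attention: Q, K, V from the same sequence
cross_out = attention(Y, X, X)  # cross-attention: queries from Y, keys/values from X
print(len(self_out), len(cross_out))  # 3 1
```

Note that cross-attention returns one output per query, so its length follows the query sequence, not the key/value sequence.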

If we don’t scale the scores, some of them can be very large, and softmax will then put almost all of the weight on the largest one. The weights effectively become one-hot: one value is close to 1 and the rest are close to 0, so each token ends up communicating with just one other token. The paper also notes that this pushes softmax into regions where its gradients are extremely small.
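A tiny example shows the effect. The scores below are made up; multiplying them by 10 keeps their order but makes softmax collapse toward a one-hot vector.

```python
import math


def softmax(row):
    m = max(row)
    exps = [math.exp(s - m) for s in row]
    total = sum(exps)
    return [e / total for e in exps]


scores = [1.0, 2.0, 3.0]
print([round(p, 3) for p in softmax(scores)])  # [0.09, 0.245, 0.665]

big = [10 * s for s in scores]  # same ordering, 10x larger magnitude
print([round(p, 3) for p in softmax(big)])     # [0.0, 0.0, 1.0]
```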

Multi-head attention

Multi-head attention means that the queries, keys, and values are projected not once but h times, where h is the number of heads (not the head size). Each of the h learned linear projections maps them to dimensions d_k, d_k, and d_v respectively. Queries and keys share the dimension d_k because they are paired up in a dot product. (In the paper, h = 8 and d_k = d_v = d_model / h = 64.)

The attention function is performed on each set of projected queries, keys, and values, so it runs h times in parallel, and each head produces a d_v-dimensional output for every position. But why d_v?

Because the scaled query-key dot products only produce the weights; each head’s output is a weighted sum of value vectors, and those have d_v dimensions.

The h head outputs are then concatenated (giving h × d_v dimensions) and projected linearly one final time to form the result of multi-head attention.

In summary, multi-head attention means applying attention h times in parallel (once per head) and then concatenating and projecting the results.
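Here’s a minimal sketch of the whole multi-head pipeline: project per head, attend, concatenate, and apply the final projection. It assumes tiny made-up sizes (`d_model = 8`, `h = 2`, so `d_k = d_v = 4`) and random weights standing in for learned ones; the helper names are illustrative, not from any library.

```python
import math
import random


def matmul(A, B):
    """(n x m) @ (m x p) for matrices stored as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


def softmax(row):
    m = max(row)
    exps = [math.exp(s - m) for s in row]
    total = sum(exps)
    return [e / total for e in exps]


def attention(Q, K, V):
    d_k = len(K[0])
    scores = matmul(Q, [list(c) for c in zip(*K)])  # Q K^T
    weights = [softmax([s / math.sqrt(d_k) for s in row]) for row in scores]
    return matmul(weights, V)


def rand_matrix(rows, cols):
    return [[random.gauss(0, 1) for _ in range(cols)] for _ in range(rows)]


def multi_head_attention(X, heads, W_o):
    """heads: one (W_q, W_k, W_v) triple per head; each maps d_model -> d_k (or d_v)."""
    head_outputs = [attention(matmul(X, W_q), matmul(X, W_k), matmul(X, W_v))
                    for W_q, W_k, W_v in heads]
    # Concatenate the h head outputs along the feature axis: n x (h * d_v).
    concat = [sum((head_out[i] for head_out in head_outputs), []) for i in range(len(X))]
    return matmul(concat, W_o)  # final linear projection back to d_model


random.seed(0)
n_tokens, d_model, h = 4, 8, 2
d_k = d_v = d_model // h
X = rand_matrix(n_tokens, d_model)
heads = [(rand_matrix(d_model, d_k), rand_matrix(d_model, d_k), rand_matrix(d_model, d_v))
         for _ in range(h)]
W_o = rand_matrix(h * d_v, d_model)
out = multi_head_attention(X, heads, W_o)
print(len(out), len(out[0]))  # 4 8
```

Because each head works in only d_model / h dimensions, the total cost is similar to one full-size head.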

Why go through all this trouble?

Multi-head attention allows the model to attend jointly to information from different representation subspaces. A single attention head cannot do this as well, because its weighted averaging blends these different kinds of information together.

Even though the tokens have context about each other, they don’t have any information about their positions. This is solved through the next topic!

Positional Encoding

Positional encoding injects information about the tokens’ relative and absolute positions into the input embeddings.

Since transformers process tokens in parallel, they do not inherently understand the order of tokens in a sequence. The positional encoding vectors are added element-wise to the input embeddings, preserving the original information while introducing positional information. The positional encodings have the same dimension as the input embeddings (d_model), so the two can be summed.

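The positional encoding vectors are typically computed from sine and cosine functions of different frequencies; the frequency assigned to each dimension determines how positions are encoded in it. In the original paper, these were the functions they used:

PE(pos, 2i) = sin(pos / 10000^(2i / d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))

where pos is the position and i indexes the dimension, so each dimension of the encoding corresponds to a sinusoid.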
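The sinusoidal encoding from the paper can be sketched in a few lines. The sizes (`max_len = 50`, `d_model = 16`) and the constant embedding are made up for illustration; the key point is that the encoding is simply added element-wise to a token’s embedding.

```python
import math


def positional_encoding(max_len, d_model):
    """PE[pos][2i] = sin(pos / 10000^(2i/d_model)), PE[pos][2i+1] = cos(same angle)."""
    pe = [[0.0] * d_model for _ in range(max_len)]
    for pos in range(max_len):
        for i in range(0, d_model, 2):  # i here is the even index, i.e. "2i" in the formula
            angle = pos / (10000 ** (i / d_model))
            pe[pos][i] = math.sin(angle)
            if i + 1 < d_model:
                pe[pos][i + 1] = math.cos(angle)
    return pe


pe = positional_encoding(max_len=50, d_model=16)
embedding = [0.1] * 16  # hypothetical token embedding at position 3
x = [e + p for e, p in zip(embedding, pe[3])]  # element-wise sum, still d_model wide
```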

In the encoder and decoder, the sub-layers are wrapped in residual connections, so let’s take a look at those before moving on!

Residual connection

A residual connection adds an identity mapping alongside a layer’s output. In other words, we add the layer’s input to its output before passing the result on to the next layer.

Source: Towards Data Science

Why use residual connection?

Residual connections help with the problem of exploding or vanishing gradients in deep networks. Without them, the gradient reaching an early layer has to pass back through every layer above it, so the path length grows with the number of layers. Exploding gradients happen when the per-layer factors multiplied together along that path are large: the gradient grows as more layers are added, which makes training with backpropagation unstable. Vanishing gradients are the opposite: the gradients become too small for the early layers to learn. For example, the sigmoid function squishes a neuron’s output between 0 and 1, and its derivative is at most 0.25, so even a large change in the input produces only a small change in the output. Multiplying many of these small derivatives together across layers makes the gradient vanish.

Residual connections mitigate this by creating paths of varying lengths through the network. As a result, a residual network behaves like an ensemble of many shallower networks that do not depend strongly on each other.
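A quick numerical check of the sigmoid argument: the derivative σ(x)(1 − σ(x)) peaks at 0.25, so backpropagating through n stacked sigmoids multiplies at most 0.25 per layer. This sketch ignores the weights between layers, which is a simplifying assumption.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def sigmoid_grad(x):
    s = sigmoid(x)
    return s * (1.0 - s)  # largest at x = 0


print(sigmoid_grad(0.0))  # 0.25

# Even in the best case, n stacked sigmoids shrink the gradient like 0.25 ** n.
for n in (1, 5, 10, 20):
    print(n, sigmoid_grad(0.0) ** n)
```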

Residual connections are powerful because there are two paths through each block:

  1. Computation by passing through the intermediate layers
  2. Residual pathway, which skips the computation by adding the input directly to the computed output

During backpropagation, an addition node passes its incoming gradient unchanged to each of its inputs, so the computation path and the residual path receive the same upstream gradient. These skip connections (also called shortcut connections) let gradients bypass one or more layers.

  • During the backward pass (backpropagation), the gradient of the loss with respect to the block’s output is computed first, and from there the gradient has two paths.
  • The gradient that passes through the intermediate computation layers is multiplied by each layer’s local derivatives via the chain rule, so it can shrink or grow along the way.
  • The gradient that flows through the residual connection is added directly to the gradient of the block’s input. Because the residual connection links the input straight to the output, this gradient flows directly back to the earlier layers of the network.
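Here’s a scalar sketch of those two paths. For a block y = x + f(x), the chain rule gives dy/dx = 1 + f′(x): the 1 comes from the residual path and f′(x) from the computation path. The sub-layer `f` below is a hypothetical stand-in (a squashing function with a small derivative), chosen only to show the contrast when 20 blocks are stacked with and without the residual path.

```python
import math


def f(x):
    """Hypothetical sub-layer: a squashing nonlinearity with a small derivative."""
    return math.tanh(0.1 * x)


def f_grad(x):
    return 0.1 * (1.0 - math.tanh(0.1 * x) ** 2)


def residual_block(x):
    return x + f(x)  # residual path + computation path


def residual_block_grad(x):
    # The addition sends the upstream gradient down both paths:
    # 1 from the identity path, plus f'(x) via the chain rule through f.
    return 1.0 + f_grad(x)


# Backpropagate through 20 stacked blocks, with and without the residual path.
h_plain = h_res = 0.5
plain_grad = res_grad = 1.0
for _ in range(20):
    plain_grad *= f_grad(h_plain)
    h_plain = f(h_plain)
    res_grad *= residual_block_grad(h_res)
    h_res = residual_block(h_res)

print(plain_grad)  # vanishingly small (below 0.1 ** 20)
print(res_grad)    # stays at least 1
```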

The gradients can easily bypass problematic layers, which helps in alleviating the vanishing gradient problem and enables the training of much deeper networks.

Understanding the architecture

The Transformer has two main components:

  1. Encoder: Think of it as an algorithm to convert an input to a representation, an internal language of sorts.
  2. Decoder: Think of this as an algorithm to convert the internal language to an output, which is comprehensible by humans. The decoder is our generator.

The input tokens are first converted into embeddings of dimension d_model (512 in the paper’s base model).
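An embedding layer is essentially a lookup table with one learnable row per token. Here’s a minimal sketch with a hypothetical three-word vocabulary and `d_model = 8`; the random rows stand in for learned parameters. The paper also multiplies the embedding weights by √d_model, which the sketch includes.

```python
import math
import random

random.seed(0)
vocab = {"the": 0, "cat": 1, "sat": 2}  # hypothetical toy vocabulary
d_model = 8  # 512 in the paper's base model
# Embedding table: one d_model-dimensional row per token (random here, learned in practice).
embedding_table = [[random.gauss(0, 1) for _ in range(d_model)] for _ in vocab]


def embed(words):
    # Look up each token's row and scale by sqrt(d_model), as in the paper.
    return [[w * math.sqrt(d_model) for w in embedding_table[vocab[word]]] for word in words]


x = embed(["the", "cat", "sat"])
print(len(x), len(x[0]))  # 3 8
```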

Encoder

The encoder is a stack of 6 identical layers. Each layer has two sub-layers: multi-head attention and a position-wise fully connected feed-forward neural network. The multi-head attention computes the context between the input tokens by producing key, query, and value vectors. There is a residual connection around each of the two sub-layers.

The encoder lets every token attend to every other token; unlike the decoder, it does not mask future tokens. This is very helpful for gaining context about the general rules and themes of the input.

First, the input embeddings are summed with positional encodings, which help the model understand the tokens’ positions.

The encoder contains self-attention layers. In a self-attention layer, all of the keys, values, and queries come from the same place: in this case, the output of the previous layer in the encoder (or, for the first encoder layer, the input embeddings plus positional encodings). Each position in the encoder can attend to all positions in the previous layer of the encoder, which gives it an overall context.

After the self-attention layer, each position’s vector is passed through a fully connected feed-forward neural network.

This network has 2 layers: 2 linear transformations with a ReLU activation in between, FFN(x) = max(0, x W_1 + b_1) W_2 + b_2. In the paper, the inner layer has dimension 2048.

Within a layer, the same transformation is applied to every position, but the parameters differ from one encoder or decoder layer to the next, so that each layer can focus on different aspects of the text.
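Here’s a minimal sketch of the position-wise feed-forward network for a single layer, assuming small made-up sizes (`d_model = 4`, `d_ff = 16`) and random weights. Because the same weights are applied to every position, two positions holding the same vector get the same output. A second encoder or decoder layer would get its own `W1, b1, W2, b2`.

```python
import random


def random_matrix(rows, cols):
    return [[random.gauss(0, 1) for _ in range(cols)] for _ in range(rows)]


def linear(x, W, b):
    """x @ W + b for one vector x (length n_in), W (n_in x n_out), b (length n_out)."""
    return [sum(x_i * W[i][j] for i, x_i in enumerate(x)) + b[j] for j in range(len(b))]


def position_wise_ffn(X, W1, b1, W2, b2):
    """FFN(x) = max(0, x W1 + b1) W2 + b2, applied to every position separately
    with the SAME weights for all positions in this layer."""
    outputs = []
    for x in X:
        hidden = [max(0.0, a) for a in linear(x, W1, b1)]  # ReLU
        outputs.append(linear(hidden, W2, b2))
    return outputs


random.seed(0)
d_model, d_ff = 4, 16  # the paper's base model uses 512 and 2048
W1, b1 = random_matrix(d_model, d_ff), [0.0] * d_ff
W2, b2 = random_matrix(d_ff, d_model), [0.0] * d_model

token = [0.5, -1.0, 2.0, 0.0]
X = [token, [1.0, 1.0, 1.0, 1.0], token]  # positions 0 and 2 hold the same vector
out = position_wise_ffn(X, W1, b1, W2, b2)
```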

Each sub-layer (either multi-head attention or the feed-forward network) is followed by layer normalization, so the output of each sub-layer is LayerNorm(x + Sublayer(x)).

Layer normalization is a technique used in machine learning, particularly in neural networks, to normalize the activations of neurons within a layer. It is similar to batch normalization but normalizes across the features in a layer instead of across the batch dimension.

In layer normalization, the mean and standard deviation are computed separately for each input (for example, each token’s vector), across all the features (neurons) in the layer. Each feature is then normalized with these statistics, so that the layer’s activations for that input are centered around zero with unit variance, before a learned gain and bias are applied.
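Here’s a minimal sketch of layer normalization for one token vector, plus the post-norm residual wrapper LayerNorm(x + Sublayer(x)). The small `eps` (added for numerical safety) and the default gain of 1 and bias of 0 are assumptions standing in for learned parameters. Note that the two example tokens are normalized independently, so their very different scales disappear.

```python
import math


def layer_norm(x, gain=None, bias=None, eps=1e-5):
    """Normalize ONE token's feature vector to zero mean and unit variance,
    then apply a learned gain and bias (defaults: 1 and 0)."""
    n = len(x)
    mean = sum(x) / n                          # statistics over the features
    var = sum((v - mean) ** 2 for v in x) / n
    gain = [1.0] * n if gain is None else gain
    bias = [0.0] * n if bias is None else bias
    return [g * (v - mean) / math.sqrt(var + eps) + b for v, g, b in zip(x, gain, bias)]


def sublayer_connection(x, sublayer):
    """Post-norm residual wrapper: LayerNorm(x + Sublayer(x))."""
    return layer_norm([a + b for a, b in zip(x, sublayer(x))])


tokens = [[1.0, 2.0, 3.0, 4.0], [10.0, 20.0, 30.0, 40.0]]
normed = [layer_norm(t) for t in tokens]  # each token normalized independently
```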

Decoder

The function of the decoder is to take the context and generate new tokens. Whereas the encoder stack starts from the input embeddings, the decoder stack starts from the embeddings of the output tokens generated so far, and each decoder layer takes in the output of the previous decoder layer. The decoder is also a stack of 6 layers, but instead of two sub-layers like the encoder, each decoder layer has three.

Like the encoder, it has a self-attention sub-layer and a feed-forward neural network sub-layer. Unlike the encoder, its self-attention is masked: each position can attend only to itself and earlier positions, not to future tokens. This prevents leftward information flow in the decoder and preserves the auto-regressive property.
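One common way to implement the mask is to set the scores for future positions to −∞ before the softmax, so they receive exactly zero weight. Here’s a minimal sketch that assumes Q = K = V = X (no learned projections) to keep it short. Changing a later token leaves the outputs at earlier positions untouched.

```python
import math


def softmax(row):
    m = max(row)
    exps = [math.exp(s - m) for s in row]  # exp(-inf) == 0.0, so masked scores get zero weight
    total = sum(exps)
    return [e / total for e in exps]


def masked_self_attention(X):
    """Causal self-attention with Q = K = V = X (no learned projections, for brevity)."""
    d_k = len(X[0])
    out = []
    for t, q in enumerate(X):
        scores = []
        for s, k in enumerate(X):
            if s > t:
                scores.append(float("-inf"))  # future token: masked out
            else:
                scores.append(sum(a * b for a, b in zip(q, k)) / math.sqrt(d_k))
        weights = softmax(scores)
        out.append([sum(w * v[j] for w, v in zip(weights, X)) for j in range(d_k)])
    return out


X = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
X_changed = [[1.0, 0.0], [0.0, 1.0], [5.0, -5.0]]  # only the last (future) token differs
out, out_changed = masked_self_attention(X), masked_self_attention(X_changed)
print(out[0])                      # [1.0, 0.0]: the first token only attends to itself
print(out[:2] == out_changed[:2])  # True: earlier positions never see the future token
```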

Sandwiched between these two sub-layers is a new cross-attention sub-layer, which performs multi-head attention over the output of the encoder stack. Its queries come from the previous decoder sub-layer, while its keys and values come from the output of the encoder. This cross-attention mechanism lets every position in the decoder attend to all positions in the input sequence, so the decoder can focus on different parts of the input while generating the output sequence.

The output embeddings are shifted one position to the right so that, together with the masking, the prediction for position i can depend only on the known outputs at positions before i.
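The right shift is easiest to see with tokens. In this sketch the start token `"<s>"` and the French target sentence are hypothetical examples: the decoder input is the target with a start token prepended and the last token dropped, so at each position the model predicts the next target token from what came before.

```python
def shift_right(target_tokens, start_token="<s>"):
    """Decoder input: prepend a start token and drop the last target token."""
    return [start_token] + target_tokens[:-1]


target = ["Le", "chat", "dort"]  # hypothetical target sentence
decoder_input = shift_right(target)
for inp, tgt in zip(decoder_input, target):
    print(inp, "->", tgt)  # having seen inp (and earlier tokens), predict tgt
# <s> -> Le
# Le -> chat
# chat -> dort
```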

After the final decoder layer, a linear transformation followed by a softmax converts each position’s output into probabilities for the next output token.
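Here’s a minimal sketch of that last step for one position. The vocabulary, the decoder output `h`, and the projection weights are all made-up numbers; picking the most probable token (greedy decoding) is just one simple way to use the probabilities.

```python
import math


def softmax(row):
    m = max(row)
    exps = [math.exp(s - m) for s in row]
    total = sum(exps)
    return [e / total for e in exps]


def next_token_probs(h, W_out, b_out):
    """Project the final decoder vector h (length d_model) to one logit per vocabulary entry."""
    logits = [sum(h_i * W_out[i][j] for i, h_i in enumerate(h)) + b_out[j]
              for j in range(len(b_out))]
    return softmax(logits)


vocab = ["<s>", "Le", "chat", "dort"]  # hypothetical vocabulary
h = [0.2, -1.0, 0.5]                   # hypothetical final decoder output, d_model = 3
W_out = [[0.1, 0.4, -0.3, 0.2],
         [0.0, -0.5, 0.8, 0.1],
         [0.3, 0.2, 0.1, 0.9]]         # d_model x vocab_size, made-up numbers
b_out = [0.0, 0.0, 0.0, 0.0]
probs = next_token_probs(h, W_out, b_out)
print(vocab[probs.index(max(probs))])  # Le (greedy choice: most probable next token)
```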

And that is all there is to the Transformer architecture! Please let me know in the comments if you have any questions, or how I can better explain abstract concepts!

Until then, keep learning, keep building!
