Understanding Self-attention & GPT models

Taus Noor
Oct 16, 2023

“Attention Is All You Need,” a 2017 paper from Google, kicked off a new generation of AI technology: one that can write code or essays, answer questions, search the Internet and do much more. Unless you’ve been living under a rock, you already know about ChatGPT and have most likely used it. Here, I attempt to explain at a high level how the GPT models inside a tool like ChatGPT work, in a way that is hopefully easy to understand for folks who aren’t deep into deep learning (though it still requires a basic understanding of linear algebra and machine learning).

The What

Before we get into self-attention, we need to understand what GPT models are trying to do. For the purpose of this article, we’ll specifically discuss the use case of ChatGPT, which takes in an input text sequence and produces an output text sequence (e.g. “What is basketball” → “Team sport to score by shooting hoops.”). So, what is the model trying to do here? Search the internet? Answer a question? Learn basketball? Nope. The model is simply trying to predict the “right” text sequence to follow the input it just got, based on what it saw during training. What follows a question? An answer, that’s right. This is obviously a high-level view: depending on the kind of training data used, the model can learn to do a wide variety of things, but at its core it is simply predicting the text that should come next based on what it learned during training (hence the quotation marks around “right”).

Embeddings, positions & what not

GPT-3/4 (like all deep learning models, including those behind Bard, LLaMA, etc.) are mathematical models, represented as mathematical functions. They don’t understand words or text, just numbers. As such, we first need to convert the sentence (text sequence) into a bunch of numbers. We won’t get into the details of how this happens, but essentially each “word” in the text sequence gets converted to a fixed-length vector of numbers that is unique to that word (e.g. “What” → [0.1, 0.03, 0.93]). In practice, text is split into tokens, which can be whole words or pieces of words, but thinking in terms of words is close enough here. Another way to think about it is that there is a mapping between words and vectors, and deep learning models use the vector corresponding to each word so they can build elaborate mathematical models of language (it’s usually more complex than that, but you get the idea).

By doing so, we end up with a vector (i.e. a collection of numbers) for each word in the input sequence. If the input was “What is basketball”, we end up with 3 different vectors, one for each word (e.g. [0.1, 0.03, 0.93], [0.24, 0.61, 0.43], [0.11, 0.39, 0.20]). These vectors are called embeddings (i.e. mathematical representations of the words in your input), and the numbers in this word-to-vector mapping are themselves learned during training. There are many ways to create embeddings, and different models do it differently, so we will not go into details here.
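As a minimal sketch of that word-to-vector mapping, here is a toy lookup in Python. The three-word `EMBEDDINGS` table and the `embed` helper are hypothetical, the numbers are the made-up example vectors above, and splitting on spaces stands in for a real tokenizer.

```python
# Hypothetical toy embedding table: each word maps to a fixed-length vector.
# Real models split text into tokens and learn these numbers during training.
EMBEDDINGS = {
    "What": [0.1, 0.03, 0.93],
    "is": [0.24, 0.61, 0.43],
    "basketball": [0.11, 0.39, 0.20],
}


def embed(text):
    """Split on whitespace (a stand-in for a tokenizer) and look up each word's vector."""
    return [EMBEDDINGS[word] for word in text.split()]


vectors = embed("What is basketball")
print(len(vectors))  # 3: one vector per word
```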

Once we have the embedding vectors, we have to mathematically incorporate “word position” information into them somehow (i.e. each vector should be modified so that it contains information on where in the sentence its word was). This happens through a process known as positional encoding, which we won’t go into in detail. All you need to know is that these embeddings are transformed using mathematical functions to produce a new vector for each word, and these new vectors incorporate word position information into the embedding. Now your vectors may look something like this: [0.22, 0.93, 0.12], [0.33, 0.19, 0.26], [0.27, 0.89, 0.13]. One way to think about it is that these vectors mathematically represent “Word: What; Position: 1”, “Word: is; Position: 2” and “Word: basketball; Position: 3” respectively (i.e. each vector represents the word and its position in the input sequence).
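One common positional encoding scheme is the sinusoidal one from the original 2017 Transformer paper, sketched below under the assumption that the position vector is simply added to the embedding. GPT-2 and GPT-3 instead learn their position embeddings during training, so treat this as an illustration of the idea; `sinusoidal_position` and `add_positions` are hypothetical names, and positions here start at 0 rather than 1.

```python
import math


def sinusoidal_position(pos, d_model):
    """Sinusoidal position vector from the 2017 Transformer paper (pos is 0-indexed)."""
    enc = []
    for i in range(d_model):
        # Each pair of dimensions shares a frequency; even dims use sin, odd dims use cos.
        angle = pos / (10000 ** ((i - i % 2) / d_model))
        enc.append(math.sin(angle) if i % 2 == 0 else math.cos(angle))
    return enc


def add_positions(embeddings):
    """Add each word's position vector to its embedding, element by element."""
    d_model = len(embeddings[0])
    return [
        [e + p for e, p in zip(vec, sinusoidal_position(pos, d_model))]
        for pos, vec in enumerate(embeddings)
    ]


embeddings = [[0.1, 0.03, 0.93], [0.24, 0.61, 0.43], [0.11, 0.39, 0.20]]
encoded = add_positions(embeddings)
```

Because the position vector differs from slot to slot, the same word placed at two different positions ends up with two different vectors.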

These position encoded vectors form the input to the next stage of the GPT model — self-attention.

Note: we now have 1 position encoded vector for each word (so a total of 3 vectors corresponding to the 3 words from the input prompt “What is basketball”).

Query, Key, Value — Could attention get any more complicated?

Your first self-attention layer now performs the following operations on each position-encoded embedding vector:

  1. Use the vector to create a query vector for that word in that position: the input position-encoded vector goes through a mathematical function (whose parameters are learned during training) to produce a new vector, referred to as the query vector.
  2. Use the vector to create a key vector for that word in that position: the same input vector goes through another (different) mathematical function to produce another new vector, referred to as the key vector.
  3. Use the vector to create a value vector for that word in that position: the vector goes through yet another mathematical function to produce yet another new vector, referred to as the value vector.

At this stage, you end up with a query vector, a key vector and a value vector for each word (thus, a total of 9 vectors for the 3 words in the input sequence).

This figure shows how query, key and value vectors are created from the position-encoded vector for one word. The Query, Key and Value Weight matrices consist of numbers that are “learned” during training.

Note: The model within a self-attention layer has 3 different mathematical functions that take in the same input position-encoded vector for a specific word to produce the query, key and value vectors for that word in that position. These mathematical functions use matrix multiplication with weight matrices to do this, where the numbers in the weight matrices are “learned” during training. All position encoded vectors for all words go through these 3 functions independently to produce query, key and value vectors for each word.
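Here is a minimal Python sketch of those three functions as matrix multiplications. The weight matrices `W_q`, `W_k` and `W_v` are filled with seeded random numbers standing in for trained weights, the query/key/value size of 2 is an arbitrary assumption, and the input vectors are the example position-encoded vectors from earlier.

```python
import random


def matvec(vector, matrix):
    """Multiply a row vector (length d_in) by a d_in x d_out matrix."""
    return [sum(v * row[j] for v, row in zip(vector, matrix)) for j in range(len(matrix[0]))]


def random_matrix(rows, cols, rng):
    """Seeded random numbers standing in for weights learned during training."""
    return [[rng.uniform(-1, 1) for _ in range(cols)] for _ in range(rows)]


rng = random.Random(0)
d_model, d_head = 3, 2
W_q = random_matrix(d_model, d_head, rng)
W_k = random_matrix(d_model, d_head, rng)
W_v = random_matrix(d_model, d_head, rng)

# Position-encoded vectors for "What", "is", "basketball" (example numbers from the text).
X = [[0.22, 0.93, 0.12], [0.33, 0.19, 0.26], [0.27, 0.89, 0.13]]

# The same three weight matrices are applied independently to every word's vector.
queries = [matvec(x, W_q) for x in X]
keys = [matvec(x, W_k) for x in X]
values = [matvec(x, W_v) for x in X]
print(len(queries) + len(keys) + len(values))  # 9
```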

The dot product of a query vector (of any word) and a key vector (of any word) is a single number that represents the “similarity” between them, i.e. how much attention the network should pay to one word in the context of the other. So, Query_vector_basketball ⋅ Key_vector_What produces the attention score for the word “basketball” with respect to the word “What”. Similarly, Query_vector_What ⋅ Key_vector_What represents the attention score for the word “What” in context of itself. Intuitively, the attention score denotes how important one word is in the context of another. (In the original Transformer, each dot product is also divided by the square root of the key vector’s length before the next step, which keeps the scores in a numerically stable range.)
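The score itself is just a dot product. This sketch uses made-up 2-dimensional vectors and includes the optional division by the square root of the key length used in the original Transformer; `attention_score` is a hypothetical helper name.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def attention_score(query, key, scaled=True):
    """Raw attention score of the query's word toward the key's word."""
    score = dot(query, key)
    # Optional scaling by sqrt(key length), as in the original Transformer.
    return score / math.sqrt(len(key)) if scaled else score


q_basketball = [1.0, 2.0]  # made-up query vector for "basketball"
k_what = [3.0, 0.5]        # made-up key vector for "What"
print(attention_score(q_basketball, k_what, scaled=False))  # 4.0
```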

This figure shows how the attention score is calculated as a dot product (prior to applying softmax activation)

This dot product is calculated for each word with itself and with every word preceding it (i.e. coming before it). In our example, the model would calculate attention scores for What-What, is-is, What-is, basketball-basketball, is-basketball and What-basketball. Each word therefore has a different number of attention scores, and we collect them into a vector of attention scores for each word: the first word gets a vector with only 1 number, its attention score with itself (What-What); the second word gets a vector with 2 numbers, its scores for What-is and is-is; and so on.

Once we have the attention vectors for each word, we apply the softmax function to each one so that its attention scores (with respect to the word itself and its preceding words) add up to 1.0. For example, after softmax the 3rd word “basketball” may get attention scores of 0.7, 0.1 and 0.2 for “What”, “is” and “basketball” respectively (i.e. all words preceding it and itself); adding them up gives you 1.0.

Note: This process of calculating attention scores for only words preceding a word (and not words that appear after a certain word) is called masked self-attention.
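Putting the last few steps together, the sketch below computes each word’s scores against itself and the words before it (the mask), then applies softmax row by row. The query and key vectors are made-up numbers, `masked_attention_weights` is a hypothetical helper, and the scaling by the square root of the key length follows the original Transformer.

```python
import math


def softmax(scores):
    """Turn scores into positive weights that add up to 1.0."""
    top = max(scores)  # subtracting the max avoids overflow and doesn't change the result
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def masked_attention_weights(queries, keys):
    """Word i is scored against words 0..i only (masked self-attention), then softmaxed."""
    scale = math.sqrt(len(keys[0]))
    weights = []
    for i, q in enumerate(queries):
        scores = [sum(a * b for a, b in zip(q, keys[j])) / scale for j in range(i + 1)]
        weights.append(softmax(scores))
    return weights


# Made-up 2-dimensional query/key vectors for "What", "is", "basketball".
queries = [[1.0, 0.0], [0.5, 0.5], [0.0, 1.0]]
keys = [[1.0, 1.0], [0.0, 2.0], [1.0, 0.0]]
for word, row in zip(["What", "is", "basketball"], masked_attention_weights(queries, keys)):
    print(word, [round(w, 2) for w in row])
```

The first word always gets a single weight of 1.0, the second gets two weights, the third gets three, and every row sums to 1.0.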

Now that we have attention scores, we have to calculate the output of the self-attention layer. The self-attention layer will produce one output vector for each word, so we’ll go through the process word by word.

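Let’s start with the first word “What”. The output vector for this word from the self-attention layer will be Value_vector_what x Attention_score_what_what (i.e. just multiply the value vector by the attention score of the word with itself). Since “What” has only one attention score, softmax turns it into exactly 1.0, so this output is simply the value vector of “What”.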

Note: Attention_score_<word1>_<word2> represents the attention score for word1 in context of word2 (which is a word preceding word1 or word1 itself); it is a scalar (single) number.

Since the value vector is a vector of numbers and the attention score for one word with respect to another is a scalar (single) number, the output of this multiplication operation is a vector of the same size as the value vector.

The process is similar for the second word “is”. The only difference is that, on top of using its own value vector and its attention score with itself, we also add the value vectors of the words preceding it, each multiplied by the respective attention score (in this case just 1 other word: “What”). Here’s how we would calculate the final output of the self-attention layer for the second word “is”:

(Value_vector_is x Attention_score_is_is) + (Value_vector_what x Attention_score_is_what)

Similarly, for the third word “basketball”, we would compute:

(Value_vector_basketball x Attention_score_basketball_basketball) + (Value_vector_is x Attention_score_basketball_is) + (Value_vector_what x Attention_score_basketball_what)

Note: We are adding multiple vectors of the same size as the value vector, so we end up with a vector of that same size.

In other words, the final output of the self-attention layer for a specific word is a weighted sum of the value vectors of the word itself and all preceding words, where each value vector is weighted by that word’s attention score with respect to the current word. Intuitively, you can think of it like this: when we reach the 3rd word “basketball”, the word “What” represents 70% of the meaning (denoting that it is a question), the word “basketball” represents 20% of the meaning (denoting the subject of the question) and the word “is” represents 10% of the meaning (since it doesn’t have a huge impact on the meaning of the question); as such, the self-attention layer outputs 70% of Value_What + 10% of Value_is + 20% of Value_basketball. Another way of looking at it is that these attention weights denote how important it is to retain information about each word in the preceding input sequence in order to generate the next one. Once this is done, you have the final output vector for that word after applying self-attention.
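The weighted sum can be checked with a few lines of Python. The attention weights 0.7, 0.1 and 0.2 come from the example above; the 2-dimensional value vectors are made up so the arithmetic is easy to follow, and `weighted_sum` is a hypothetical helper.

```python
def weighted_sum(weights, vectors):
    """Sum of vectors, each scaled by its attention weight."""
    out = [0.0] * len(vectors[0])
    for w, vec in zip(weights, vectors):
        for k, v in enumerate(vec):
            out[k] += w * v
    return out


# Made-up value vectors for "What", "is", "basketball".
value_what = [1.0, 0.0]
value_is = [0.0, 1.0]
value_basketball = [1.0, 1.0]

# Attention weights for "basketball" from the example: 0.7 (What), 0.1 (is), 0.2 (itself).
output_basketball = weighted_sum([0.7, 0.1, 0.2], [value_what, value_is, value_basketball])
print(output_basketball)
```

This gives approximately [0.9, 0.3] (up to floating-point rounding): 0.7 + 0.2 in the first position and 0.1 + 0.2 in the second.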

This figure shows how the self-attention layer would calculate the output vector for the 3rd word “basketball” using the query, key and value vectors of itself and the words preceding it (an (over-?)simplified view)

Now, once again we have 3 vectors for the 3 words (one vector per word). Remember what we started with? 3 vectors for the 3 words (1 for each word). As such, these 3 vectors can form the input for the next self-attention layer, which will apply self-attention to them again. This can go on sequentially for as many layers as the model has (output from layer 1 goes into layer 2 as input, output from layer 2 goes into layer 3 as input, and so on until the last layer). A model like the one behind ChatGPT has many such layers stacked on top of one another (e.g. the largest GPT-3 model has 96), with each layer having multiple “self-attention heads”. Like the one we just walked through, every attention layer takes an input vector for each word and produces an output vector for each word. In the end, we take the output from the very last layer in the model, which is still going to be 1 vector per word (so 3 vectors for the example above, “What is basketball”).
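To see that stacking works because each layer keeps one vector per word, here is a toy single-head masked self-attention layer applied four times in a row. It is a simplified sketch under several assumptions: the weights are seeded random numbers, there is only one attention head, and the residual connections, layer normalization and per-layer feed forward networks of real GPT layers are left out. `make_layer` and `self_attention_layer` are hypothetical names.

```python
import math
import random


def matvec(vector, matrix):
    return [sum(v * row[j] for v, row in zip(vector, matrix)) for j in range(len(matrix[0]))]


def softmax(scores):
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def make_layer(d_model, rng):
    """One toy layer: random d_model x d_model query, key and value weight matrices."""
    def rand():
        return [[rng.uniform(-0.5, 0.5) for _ in range(d_model)] for _ in range(d_model)]
    return {"W_q": rand(), "W_k": rand(), "W_v": rand()}


def self_attention_layer(X, layer):
    """Masked self-attention: one output vector per input vector."""
    Q = [matvec(x, layer["W_q"]) for x in X]
    K = [matvec(x, layer["W_k"]) for x in X]
    V = [matvec(x, layer["W_v"]) for x in X]
    scale = math.sqrt(len(K[0]))
    outputs = []
    for i in range(len(X)):
        # Word i only scores itself and the words before it (j <= i).
        scores = [sum(a * b for a, b in zip(Q[i], K[j])) / scale for j in range(i + 1)]
        weights = softmax(scores)
        outputs.append(
            [sum(w * V[j][k] for j, w in enumerate(weights)) for k in range(len(V[0]))]
        )
    return outputs


rng = random.Random(42)
layers = [make_layer(3, rng) for _ in range(4)]  # a 4-layer toy stack
X = [[0.22, 0.93, 0.12], [0.33, 0.19, 0.26], [0.27, 0.89, 0.13]]
for layer in layers:
    X = self_attention_layer(X, layer)  # the output of one layer is the input to the next
print(len(X), len(X[0]))  # 3 3
```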

Feed Forward Network (FFN)

To predict the immediate next “word” (i.e. the first word of the output sequence), we would take the final output of the last self-attention layer and use only the vector representing the last word “basketball” for the remainder of the process. If you recall how self-attention works, the self-attention layers would output one vector for each word in the input sequence; from those vectors, we would only use the output vector representing the last word to predict the next word.

Note: The output vector for the last word was calculated after applying self-attention on all words prior to it and itself (therefore: all words in the input sequence) and as such represents information that is most relevant/important to predict the next word. Also note that self-attention was applied to position encoded vectors, which means that not only does the “attention” mechanism take into account its relationship with preceding words, but it also takes into account where (which position) in the sentence each word was.

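The output vector for this last word (in our example: “basketball”) goes through a feed forward network made of fully connected layers. Essentially, it takes a vector of size m and produces an output vector of size n using matrix multiplication followed by an activation function (a non-linear mathematical function). (In actual GPT models, a feed forward network like this also sits inside every layer, right after self-attention, and is applied to every word’s vector, not just the last one; this article folds it into a single step for simplicity.)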
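A single fully connected layer, as described above, can be sketched as a matrix multiplication plus a bias, followed by an activation. The sketch uses the tanh approximation of GELU, the activation used in GPT-style models; the sizes m = 3 and n = 4, the seeded random weights and the input vector are all made up, and `dense` is a hypothetical helper. (The feed forward blocks inside GPT layers chain two such layers, expanding to 4 times the model width and then projecting back.)

```python
import math
import random


def gelu(x):
    """Tanh approximation of the GELU activation."""
    return 0.5 * x * (1 + math.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * x ** 3)))


def dense(vector, W, b, activation=None):
    """Fully connected layer: vector (size m) times W (m x n), plus bias b (size n)."""
    out = [sum(v * row[j] for v, row in zip(vector, W)) + b[j] for j in range(len(b))]
    return [activation(o) for o in out] if activation else out


rng = random.Random(0)
m, n = 3, 4
W = [[rng.gauss(0, 0.5) for _ in range(n)] for _ in range(m)]  # stand-in for trained weights
b = [0.0] * n
last_word_vector = [0.4, -0.2, 0.9]  # hypothetical final-layer output for "basketball"
ffn_output = dense(last_word_vector, W, b, activation=gelu)
print(len(ffn_output))  # 4
```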

The output vector from the FFN is then used to predict the probability of each word in the vocabulary being the next “word”. This is done in a very similar fashion to the FFN: a weight matrix is applied to produce a score (the “chance”) for each word to be the next word, and the softmax function then converts those scores into probabilities. In the simplest case, the network outputs the word with the maximum probability; in practice, chat models often sample from this probability distribution instead, which is why the same prompt can produce different answers.
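The output step can be sketched the same way: multiply by an output weight matrix to get one score per vocabulary word, apply softmax, then either take the most likely word (greedy) or sample from the distribution. The four-word vocabulary and the 2 x 4 weight matrix are made up, and `next_word_distribution`, `pick_greedy` and `pick_sampled` are hypothetical helpers.

```python
import math
import random


def softmax(scores):
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def next_word_distribution(vector, W_out, vocab):
    """One score per vocabulary word (W_out is d x len(vocab)), turned into probabilities."""
    scores = [sum(v * row[j] for v, row in zip(vector, W_out)) for j in range(len(vocab))]
    return dict(zip(vocab, softmax(scores)))


def pick_greedy(probs):
    """Always choose the most likely word."""
    return max(probs, key=probs.get)


def pick_sampled(probs, rng):
    """Choose a word at random, in proportion to its probability."""
    words = list(probs)
    return rng.choices(words, weights=[probs[w] for w in words], k=1)[0]


vocab = ["Team", "sport", "hoops", "<end>"]           # made-up 4-word vocabulary
W_out = [[2.0, 0.1, 0.0, -1.0], [0.5, 0.3, 0.2, 0.0]]  # made-up 2 x 4 output weights
probs = next_word_distribution([1.0, 0.5], W_out, vocab)
print(pick_greedy(probs))  # Team
```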

Once this is done, you have the first word in your output.

Note: Since the algorithm above can apply self-attention to any number of words, the input could in theory be of any size, and the model can take every earlier word into account when predicting the next one (since each self-attention layer looks at every word up to and including the last word to calculate the final output). In practice, each model has a maximum input length (2,048 tokens for GPT-3), and the number of attention scores grows roughly with the square of the input length, so longer inputs quickly become expensive.
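A quick count shows why input length matters. Under masked self-attention, word i is scored against i words (itself and everything before it), so n words need n(n + 1)/2 dot products per attention head in each layer. `masked_score_count` is a hypothetical helper; the 2,048 figure is GPT-3’s context length.

```python
def masked_score_count(n):
    """Query-key dot products in one masked self-attention head for n words."""
    return n * (n + 1) // 2


for n in [3, 2048]:  # 3 words for "What is basketball"; 2,048 tokens for GPT-3
    print(n, masked_score_count(n))
```

For 3 words this gives 6, matching the six pairs listed for “What is basketball”; for 2,048 tokens it is 2,098,176 per head, per layer.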

Are we done yet?

We just got to the first word of the output. Typically, the output can be of any length. So, how does the network output the next word?

Well, we append the new (predicted) word to the input sequence, and that becomes the new input sequence (i.e. input sequence + [new predicted word] = new input sequence). The whole process repeats to predict the next word (i.e. the second word of the output). This goes on and on until the network outputs a special word, or special token, often called the “end token”, that denotes the end of the output; at that point we stop running the model and have the final output text sequence.
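The generation loop itself is short. In this sketch, `toy_model` is a hypothetical stand-in for the whole self-attention + FFN pipeline (it just looks up the last word in a table), `<end>` plays the role of the end token, and `max_new_words` is an added safety cap rather than part of the process described above.

```python
def generate(prompt_words, predict_next, end_token="<end>", max_new_words=20):
    """Repeatedly predict the next word and append it to the input until the end token."""
    sequence = list(prompt_words)
    output = []
    for _ in range(max_new_words):  # safety cap in case the end token never appears
        word = predict_next(sequence)
        if word == end_token:
            break
        output.append(word)
        sequence.append(word)  # input sequence + [new predicted word] = new input sequence
    return output


# Hypothetical stand-in for the full model: it only looks at the last word.
TOY_NEXT = {"basketball": "Team", "Team": "sport", "sport": "to", "to": "score",
            "score": "by", "by": "shooting", "shooting": "hoops.", "hoops.": "<end>"}


def toy_model(sequence):
    return TOY_NEXT[sequence[-1]]


print(" ".join(generate(["What", "is", "basketball"], toy_model)))
# Team sport to score by shooting hoops.
```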

This figure demonstrates, at a high level, how a GPT model using self-attention responds to a user input (text sequence), starting at the arrow in the top left corner (an (over-?)simplified view)

Remember: ChatGPT and LLMs like it take an input sequence of text and predict the next word at any given point in time. Once you have the next word, you add it to the input to predict the following word, and keep repeating until the model generates an end token denoting the end of the output. In practice, the input text can look like “User: [user prompt] System: [system output] User: [user prompt 2] System: [system output 2]”, so the same text sequence includes both the user’s input and the system’s output, which is used to keep generating the remainder of the output and to handle future inputs in the conversation.
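A minimal sketch of how such a running transcript might be assembled, assuming the plain “User:”/“System:” labels from the example above (real chat models use their own formatting and special tokens); `build_transcript` is a hypothetical helper.

```python
def build_transcript(turns):
    """Flatten (role, text) pairs into one text sequence, ending with a cue for the next reply."""
    parts = [f"{role}: {text}" for role, text in turns]
    parts.append("System:")  # the model continues generating from here
    return " ".join(parts)


history = [
    ("User", "What is basketball?"),
    ("System", "Team sport to score by shooting hoops."),
    ("User", "Who invented it?"),
]
print(build_transcript(history))
```

Each new reply is appended to `history`, so the whole conversation is fed back in as the input for the next turn.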

How does all of it happen so fast so many times?

Well, first and foremost, running a large language model (LLM) like ChatGPT/GPT-4 can be quite expensive as it requires a lot of computing resources. That being said, it’s still extremely impressive how fast ChatGPT responds to queries. What other tips and tricks could they be using?

First, the model does not have to recompute the key and value vectors of the words preceding the last output word each time, since it already computed them in previous steps and can re-use them (earlier words’ query vectors are not needed again at all). This saves a lot of computation and means the network only has to run the full process for the newest word at each step. This is commonly called key-value (KV) caching, and in general the system can benefit from various caching techniques to avoid recomputing vectors it already computed in a previous step.
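The caching idea can be sketched for a single attention head. `KVCache` is a hypothetical class: each call to `step` computes the query, key and value for only the newest word, stores the key and value, and attends over everything stored so far. The weights are seeded random stand-ins, and real implementations keep one such cache per layer and per head.

```python
import math
import random


def matvec(vector, matrix):
    return [sum(v * row[j] for v, row in zip(vector, matrix)) for j in range(len(matrix[0]))]


def attend(query, keys, values):
    """Scaled dot-product attention of one query over the given keys and values."""
    scale = math.sqrt(len(query))
    scores = [sum(a * b for a, b in zip(query, k)) / scale for k in keys]
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [sum(e / total * v[j] for e, v in zip(exps, values)) for j in range(len(values[0]))]


class KVCache:
    """Keeps every processed word's key and value so they are never recomputed."""

    def __init__(self, W_q, W_k, W_v):
        self.W_q, self.W_k, self.W_v = W_q, W_k, W_v
        self.keys, self.values = [], []

    def step(self, x):
        """Handle one new word: compute only its q, k, v, then attend over all cached words."""
        self.keys.append(matvec(x, self.W_k))
        self.values.append(matvec(x, self.W_v))
        return attend(matvec(x, self.W_q), self.keys, self.values)


rng = random.Random(1)


def random_weights():
    # Stand-in for trained 3 x 3 weights.
    return [[rng.uniform(-1, 1) for _ in range(3)] for _ in range(3)]


W_q, W_k, W_v = random_weights(), random_weights(), random_weights()
cache = KVCache(W_q, W_k, W_v)
X = [[0.22, 0.93, 0.12], [0.33, 0.19, 0.26], [0.27, 0.89, 0.13]]
cached_outputs = [cache.step(x) for x in X]  # one word at a time, as during generation
```

Each cached output matches what a full recomputation over all words up to that point would give, but only one new key and value are computed per step.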

Secondly, most of the steps in the self-attention layer can be parallelized. Since the query, key and value vectors of each word are independent of one another, they can all be calculated simultaneously (on GPUs, typically as one large matrix multiplication) and then used when it’s time to calculate the attention scores and the final output vectors.
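The independence is easiest to see in matrix form: stack the word vectors as rows of a matrix X, and one multiplication X · W_q produces every word’s query at once, because each output row depends only on its own input row. This pure-Python sketch only shows the structure (it does not actually run anything in parallel; GPUs do that with batched matrix multiplications), and the weights are made up.

```python
def matmul(A, B):
    """Plain matrix multiply: (n x d) times (d x k)."""
    return [[sum(a * B[t][j] for t, a in enumerate(row)) for j in range(len(B[0]))] for row in A]


def transpose(M):
    return [list(col) for col in zip(*M)]


X = [[0.22, 0.93, 0.12], [0.33, 0.19, 0.26], [0.27, 0.89, 0.13]]  # one row per word
W_q = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]  # made-up weights
W_k = [[0.6, 0.5], [0.4, 0.3], [0.2, 0.1]]

# Each row of Q depends only on the matching row of X, so all rows can be computed at once.
Q = matmul(X, W_q)
K = matmul(X, W_k)

# Entry [i][j] is the (unscaled) score of word i toward word j; masking later hides j > i.
scores = matmul(Q, transpose(K))
```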

Lastly, there are likely various optimizations, both within the network, like pruning (i.e. removing weights or computations that contribute little to the output, so less work is done each time), and at the hardware level, that can be leveraged to speed up inference even more.

Okay, but how does the model learn?

The model is initialized with random numbers as weights (across both the self-attention layers and the FFN) and then uses gradient descent through backpropagation to “learn” (i.e. get better at producing the “right” output). On a high level, the process works as follows –

  1. The model is given an input sequence and it predicts the next word using self-attention.
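  2. The model’s predicted probabilities are compared with the expected next word using a loss function (typically cross-entropy), and the model parameters are updated using gradient descent so that the expected word becomes more likely, even when the prediction was already correct but not yet confident.
  3. This process is repeated numerous times over a large training dataset. For pre-training, the “expected output” is simply the next word in the training text, so no manual labelling is needed; chat models like ChatGPT are additionally fine-tuned on curated examples.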
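As a toy version of this loop, the sketch below trains a tiny next-word predictor with gradient descent on a cross-entropy loss. To keep it short, it swaps self-attention for a lookup table of scores (one row per previous word) and starts from all-zero weights rather than random ones; the training text is made up. The key step is the same as in an LLM: compare the predicted probabilities with the actual next word and nudge the parameters to make that word more likely.

```python
import math


def softmax(scores):
    top = max(scores)
    exps = [math.exp(x - top) for x in scores]
    total = sum(exps)
    return [e / total for e in exps]


text = "what is basketball <end> what is soccer <end>".split()
vocab = sorted(set(text))
index = {w: i for i, w in enumerate(vocab)}
V = len(vocab)

# Parameters: one row of next-word scores per previous word (a stand-in for the real model).
W = [[0.0] * V for _ in range(V)]
pairs = [(index[a], index[b]) for a, b in zip(text, text[1:])]  # (input word, expected next)


def average_loss():
    """Mean cross-entropy: -log(probability assigned to the expected next word)."""
    return sum(-math.log(softmax(W[x])[y]) for x, y in pairs) / len(pairs)


learning_rate = 0.5
initial_loss = average_loss()
for _ in range(200):
    for x, y in pairs:
        probs = softmax(W[x])
        # Gradient of cross-entropy w.r.t. the scores is (probs - one_hot(expected)).
        for j in range(V):
            W[x][j] -= learning_rate * (probs[j] - (1.0 if j == y else 0.0))
print(round(initial_loss, 3), round(average_loss(), 3))
```

The loss starts at log(5) (a uniform guess over the 5-word vocabulary) and drops as training proceeds; it cannot reach zero here, because “is” is followed by “basketball” once and “soccer” once.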

The math behind this process can be quite complicated, especially for an LLM, since we have to train shared query, key and value parameters across multiple words in an iteration. Given the process is obviously a lot more complex than we can cover here, we’ll leave it at that.

One More Thing

ChatGPT, Bard and most of the new LLMs used for conversational AI are decoder-only transformers that leverage self-attention as explained above. The initial paper from Google introduced transformers with an encoder-decoder architecture, where the encoder used self-attention to understand the input and the decoder used attention to generate output from the “understood” input (very useful for tasks like translation, where you have an input and an output and that’s it; not necessarily conversational in nature). However, more recent conversational AI models like the one behind ChatGPT (where you have an indefinite sequence of inputs and outputs as part of a conversation) tend to use a decoder-only architecture that leverages masked self-attention to produce its output. This post is written in the context of those decoder-only models, but the idea of self-attention is the same in an encoder-decoder transformer (the encoder’s self-attention simply isn’t masked, so every word can attend to every other word).
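The difference between the two kinds of self-attention mostly comes down to the mask. In this sketch (hypothetical `attention_mask` helper), `True` at row i, column j means word i may attend to word j.

```python
def attention_mask(n, causal):
    """mask[i][j] is True when word i may attend to word j."""
    return [[(j <= i) if causal else True for j in range(n)] for i in range(n)]


words = ["What", "is", "basketball"]
for label, causal in [("encoder (no mask)", False), ("decoder (masked)", True)]:
    print(label)
    for word, row in zip(words, attention_mask(len(words), causal)):
        print(f"  {word:>10}: " + " ".join("x" if allowed else "." for allowed in row))
```

In the decoder version, each row only allows the word itself and the words before it, which is exactly the masked self-attention described earlier.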

Note: This article is a high-level (over-?)simplified version of how GPT models (and transformers in general) work. This is not intended to be a detailed explanation of the math, algorithms or performance of different LLMs. There are numerous optimizations, mathematical functions and transformations that an LLM can actually use in practice which are outside the scope of this article. Also, the numbers (for input, weights, etc.) used in this article are random and may not be representative of actual numbers that you may see in a GPT model.

Taus Noor

Engineering leader @ DoorDash. Ex Amazon/AWS, NVIDIA & startup founder. Forbes 30U30.