( Why the math gotta be like that tho? )

Decoding the Decoder King: GPT-2

An Intuitive Approach to Understanding the “Why” and “How” Behind Mathematics of GPT-2.

Pickleprat
12 min read · Feb 23, 2024


Image taken from Jay Alammar’s Illustrated GPT-2

Recently, I found myself learning about the language models released after the famous paper “Attention Is All You Need”, which introduced transformers. These models essentially fall into the following two categories.

  1. Encoder Only Transformers.
  2. Decoder Only Transformers.

Of course, there’s a third category, the encoder-decoder transformers, but I think of them more as a super-category inside which these two categories lie, and I refer to them simply as transformers.

If you’re a double digit IQ person like me, you also need illustrations to understand these complex state-of-the-art architectures, and for people like us, Jay Alammar has made a beautiful illustrated series explaining state-of-the-art models like BERT and GPT-2.

If you haven’t read his articles on the transformer and on BERT’s encoder-only architecture, and have no idea how any of those architectures work, then I’m afraid you’re missing out on some of the best explanations of transformer architectures of all time. So I implore you to check those out first, because most of this article takes influence from his original ideas and tries to answer questions that he didn’t in his article.

In these you might get an insight as to how encoder only and vanilla transformer architectures work.

BERT, which is an encoder only architecture, contains two heads:

  1. Masked Language Modelling Head : In this head the model randomly masks certain input tokens and tries to predict each masked token from both the left and the right context. In attempting to predict the masked token, the model gains a thorough understanding of how words depend on other words in the left and right context.
  2. Next Sentence Prediction Head : BERT doesn’t simply take a single sentence as input during the training phase. It takes in two input sentences: one is the current sentence, and the other is either the sentence that actually follows it or a sentence randomly selected from the corpus, each with 50% probability. This allows the model to learn the semantic framing of sentences and the dependencies between sentences.

After the pretraining is complete, a custom head can be added to the BERT transformer and the model can be fine tuned on your specific use case. The model in itself isn’t that complex. Encoder in the general sense is used for “understanding” a language either way. So it makes sense, that BERT is not great at text generation but excellent at tasks that require a thorough understanding of the language.

However, what felt eerie to me were the decoder-only architectures. In the original attention paper, the decoder receives the encoder’s output as the keys and values [K_enc, V_enc] and uses that information, in combination with its own Query matrix, to predict the next token. But if we are only using decoders, then how can we train the model to predict the next token?

The key thing you have to understand about decoder architectures is that the training phase and the inference phase do not execute in the same fashion. Let’s unpack this in greater detail from the beginning: let’s train our GPT-2 and see how it goes.

Training GPT-2

Input Stage: Preparing the input for the attention Mechanism

Let us consider that we have a sentence “What is the best Anime”. The first step of course is tokenization. The words in this sentence are going to be broken down into tokens using the Byte-Pair Encoding Algorithm, which is something that we won’t get into right now, but all you need to understand is that it’s going to do this to our words:

words = [“<s>”, “What”, “is”, “the”, “best”, “Anime” ]

The “<s>” is called the start token, and is used as the first token in case the model does not get any input.

What happens in the backend is that, if you provide no input to the model, it predicts a probability distribution over the next token given only the <s> token. A hyperparameter called “top_k” then restricts sampling to the top_k most probable tokens, and the first word is drawn from those.

This process is called Unconditional Sampling.
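To make the top_k idea concrete, here is a minimal NumPy sketch of top-k sampling from a vector of scores. The function name `top_k_sample` and the toy five-token vocabulary are my own assumptions for illustration, not GPT-2’s actual code.

```python
import numpy as np

def top_k_sample(logits, k, rng):
    """Sample one token id, considering only the k highest-scoring tokens."""
    top_ids = np.argsort(logits)[-k:]              # indices of the k largest logits
    top_logits = logits[top_ids]
    probs = np.exp(top_logits - top_logits.max())  # softmax over the survivors only
    probs /= probs.sum()
    return int(rng.choice(top_ids, p=probs))

rng = np.random.default_rng(0)
logits = np.array([0.1, 5.0, 0.2, 4.0, -1.0])      # toy scores for a 5-token vocabulary
print(top_k_sample(logits, k=2, rng=rng))          # always 1 or 3, the two highest-scoring ids
```

With k=1 this collapses to always picking the single most likely token; larger k lets less likely tokens through.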

Note that the goal of training is to teach the GPT-2 Transformer to predict the next word in the sentence given the previous string of tokens, so in our case: ( Here → stands for predicting )

<s> → What

<s> What → is

<s> What is → the

<s> What is the → best

<s> What is the best → Anime

<s> What is the best Anime → <end>

If output == <end>: stop

So our target is going to be trying to predict

target = [“What”, “is”, “the”, “best”, “Anime”, “<end>”]
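The pairs above are just the same sequence shifted by one position. A small Python sketch of that shift (the variable names `tokens`, `inputs` and `targets` are mine):

```python
# Build next-token training pairs by shifting the sequence one step.
tokens = ["<s>", "What", "is", "the", "best", "Anime", "<end>"]

inputs = tokens[:-1]    # everything except the final <end>
targets = tokens[1:]    # everything except the leading <s>

# Position t sees inputs[0..t] and must predict targets[t].
for t in range(len(inputs)):
    context = " ".join(inputs[: t + 1])
    print(f"{context} -> {targets[t]}")
```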

This is called autoregressive modelling. It is also the point where BERT and GPT-2 become two completely different types of language models.

BERT uses bi-directional context, looking at both the left and the right of a masked token to predict it. GPT-2, on the other hand, uses only the previous tokens to predict the next token.

But wait, our model doesn’t understand UTF-8 characters! That would be crazy to code out computation-wise. It has to convert the tokens into word embeddings. You can visualize this intuitively as a matrix of 50,257 rows and 768 columns, with each token being mapped to a row index between 0 and 50,256.

Image for the word_map matrix: Taken from Jay Alammar’s Illustrated GPT-2

These numbers are chosen because GPT-2’s vocabulary contains 50,257 unique tokens and the small model uses an embedding vector size of 768.

The index a token maps to is known as its word id, and let’s assume the embedding matrix is stored in a variable called word_map.

So the model is first going to map the words in the input to their respective ids.

words = [100, 54, 7892, 345, 123, 984]

And then it is going to convert each of the inputs into 768 dimensional vectors. So now our input matrix is X.shape = [6, 768] where X is the input matrix.
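In code, the lookup is nothing more than row indexing. The sketch below uses a randomly filled stand-in for word_map (the real matrix is learned), and the ids are the illustrative ones from this article, not real GPT-2 token ids.

```python
import numpy as np

rng = np.random.default_rng(0)
vocab_size, d_model = 50257, 768

# Random stand-in for the learned embedding matrix (word_map in this article).
word_map = rng.standard_normal((vocab_size, d_model), dtype=np.float32)

# Illustrative ids from the article, not real GPT-2 token ids.
ids = np.array([100, 54, 7892, 345, 123, 984])

X = word_map[ids]   # row lookup: one 768-dimensional vector per token
print(X.shape)      # (6, 768)
```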

Okay, but notice that given the input “What”, the model will always return the exact same embedding regardless of where “What” appears in the sentence. The embeddings are mapped to the unique id, and the unique id does not change with the position of the word in the text.

So even if we provide the input sentence “is Anime the best What”, each word gets the same vector as before and the model will not know the difference between the two inputs. For that reason we add positional encodings. To visualize this you can consider a matrix, pos, of size [1024 x 768], where 1024 is the maximum number of tokens allowed in a single input. In this case we’ll access the first six rows of this matrix, as we need only those.

pos matrix visualized: Also taken from Jay Alammar’s illustrations

pos_enc = pos[:6, :]

And then,

embs = X + pos_enc

Now, given a certain word AND position we will always get the same input for the model. This feels a lot better and now we’re ready to feed our input into the decoder.
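Here is a quick sketch of why the addition helps, again with random stand-ins for the learned matrices (`what_vec` is a hypothetical embedding for “What”): the same word vector ends up different once its position is added.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, max_len = 768, 1024

pos = rng.standard_normal((max_len, d_model), dtype=np.float32)  # stand-in for learned positions
X = rng.standard_normal((6, d_model), dtype=np.float32)          # stand-in token embeddings

embs = X + pos[:6, :]   # one position row per input token
print(embs.shape)       # (6, 768)

# The same word vector placed at position 1 and at position 5.
what_vec = X[1]
at_pos_1 = what_vec + pos[1]
at_pos_5 = what_vec + pos[5]
print(np.allclose(at_pos_1, at_pos_5))  # False: position now changes the model input
```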

Attention Stage: Creating the attention Matrix with a Mask.

Now, the illustration by Jay Alammar shows the inputs being passed down one by one for illustrative purposes. But there actually isn’t a need to pass the input tokens one by one in the training phase.

You can simply pass the entire matrix all at once to get the desired output vectors. To put it more mathematically, we first take our input matrix of shape [6, 768] and then transform it using a weight matrix of shape [768, 3 x 768] called W <q, k, v>. This is going to yield a single matrix of shape [6, 3 x 768].

But why did we use a transformation matrix of shape [768, 3 x 768 ]?

Well, we wanted the output to be a concatenated version of the Q, K and V matrices. We can now simply split it into each of them:

Let output = embs.dot(W <q, k, v>), then

Q = output[:, :768]

K = output[:, 768: 2 * 768]

V = output[:, 2 * 768: ]

These are then going to be split once more into 12 chunks, one for each attention head, since the GPT-2 small model has 12 attention heads. So the Q, K and V matrices are each split column-wise into 12 parts, and our set of variables becomes [Q, K, V] = < Qᵢ, Kᵢ, Vᵢ > ∀ i ∈ [1, 12]

Each Qᵢ, Kᵢ, Vᵢ being of dimensions [6, 64]. ( because 768/12 = 64)
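The fused projection and the two splits can be sketched in NumPy as follows. The weights are random stand-ins, and the helper name `split_heads` is my own; I am assuming each head simply takes a contiguous block of 64 columns.

```python
import numpy as np

rng = np.random.default_rng(0)
seq_len, d_model, n_heads = 6, 768, 12
d_head = d_model // n_heads                 # 64

embs = rng.standard_normal((seq_len, d_model))
W_qkv = rng.standard_normal((d_model, 3 * d_model)) * 0.02   # random stand-in weights

output = embs @ W_qkv                       # [6, 2304]
Q, K, V = np.split(output, 3, axis=1)       # three [6, 768] matrices

def split_heads(M):
    # [6, 768] -> [12, 6, 64]: head i holds columns i*64 .. (i+1)*64 - 1
    return M.reshape(seq_len, n_heads, d_head).transpose(1, 0, 2)

Qh, Kh, Vh = split_heads(Q), split_heads(K), split_heads(V)
print(Qh.shape)   # (12, 6, 64)
```

Doing one big matrix multiply and slicing afterwards gives the same numbers as three separate projections, it is just friendlier to the hardware.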

Each Qᵢ is going to get multiplied by the transpose of Kᵢ, scaled by sqrt( dₖ ), masked, passed through a softmax and then multiplied by Vᵢ, to produce our typical attention output.

Zᵢ = softmax( ( Qᵢ x Kᵢ.T ) / sqrt( dₖ ) + M ) x Vᵢ

Here dₖ = 64 is the per-head dimension, and M is the mask of shape [6, 6]: it has 0’s on and below the diagonal and -∞ above it, so that after the softmax each token puts zero weight on the tokens that come after it.
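A minimal single-head sketch of masked attention in NumPy, with random stand-ins for one head’s Q, K and V (the function names are mine). The mask is added to the scaled scores before the softmax, so the -∞ entries turn into exact zeros in the attention weights.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)   # subtract the row max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def causal_attention(Q, K, V):
    seq_len, d_k = Q.shape
    scores = Q @ K.T / np.sqrt(d_k)                           # [6, 6] scaled scores
    M = np.triu(np.full((seq_len, seq_len), -np.inf), k=1)    # -inf above the diagonal, 0 elsewhere
    weights = softmax(scores + M)                             # each row sums to 1
    return weights @ V, weights

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((6, 64)) for _ in range(3))   # stand-ins for one head
Z, weights = causal_attention(Q, K, V)
print(Z.shape)                # (6, 64)
print(np.round(weights, 2))   # upper triangle is all zeros
```

Notice that the first row can only attend to itself, so its output is exactly the first row of V.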

Getting out of the decoder

In this case, the output of each head will be a [6, 64] matrix. The outputs from all 12 heads are then concatenated to form Z of shape [6, 768], which then goes through W𝓏 of shape [768, 768].
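The concatenation is just the head split run in reverse. A small sketch with random stand-in values (`Z_heads` and `W_z` are my names for the per-head outputs and for W𝓏):

```python
import numpy as np

rng = np.random.default_rng(0)
n_heads, seq_len, d_head = 12, 6, 64
d_model = n_heads * d_head                                   # 768

Z_heads = rng.standard_normal((n_heads, seq_len, d_head))    # stand-in per-head outputs

# [12, 6, 64] -> [6, 768]: place the heads side by side along the columns.
Z = Z_heads.transpose(1, 0, 2).reshape(seq_len, d_model)

W_z = rng.standard_normal((d_model, d_model)) * 0.02         # random stand-in weights
projected = Z @ W_z
print(Z.shape, projected.shape)   # (6, 768) (6, 768)
```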

This in turn is going to produce the input for the final feed forward neural network layer, S = FFNN(Z), whose output will also be of the same shape [6, 768].

The next task, of course, is to convert these outputs into token logits by passing them through a final layer of token embeddings, W𝒻 of shape [768, 50,257]. So now finally you have the output matrix O.

O[6, 50,257] = softmax(S[6, 768] * W𝒻 [768, 50,257] )

Now, you know the drill from here. We compare the outputs with the actual labels which are [“What”, “is”, “the”, “best”, “Anime”, “<end>”] and then compute the loss and adjust the weights.
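For the curious, here is a rough sketch of that last step with random stand-ins for S and W𝒻. I am assuming the usual cross-entropy loss for next-token prediction, and the target ids (including 50256 for “<end>”) are illustrative placeholders in the spirit of the ids used earlier.

```python
import numpy as np

rng = np.random.default_rng(0)
seq_len, d_model, vocab_size = 6, 768, 50257

S = rng.standard_normal((seq_len, d_model), dtype=np.float32)                        # stand-in FFNN output
W_f = rng.standard_normal((d_model, vocab_size), dtype=np.float32) * np.float32(0.02)  # stand-in weights

logits = S @ W_f                                    # [6, 50257]
logits = logits - logits.max(axis=1, keepdims=True) # stabilise before exponentiating
log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
O = np.exp(log_probs)                               # softmax: one distribution per position

# Illustrative ids for ["What", "is", "the", "best", "Anime", "<end>"].
target_ids = np.array([54, 7892, 345, 123, 984, 50256])

# Cross-entropy: average negative log-probability assigned to the correct next token.
loss = -log_probs[np.arange(seq_len), target_ids].mean()
print(O.shape, float(loss))
```

All six predictions are scored in one go, which is exactly why the whole sentence can be trained on in a single pass.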

End of Training

This is how the model works in its training phase. This concept was bugging me throughout learning GPT-2 from Jay Alammar’s Illustrated GPT-2, because there wasn’t any mention of training happening in a non-sequential manner.

The problem with sending it sequentially is that during training when you send the first start token <s>, the model will have no idea what to predict next since it doesn’t really have enough context. It will therefore resort to predicting the next token using unconditional sampling, and then start rambling randomly.

But wait… we didn’t actually resolve anything. If the masking is still involved, then the first token is still going to use ONLY <s> to predict the next token, is it not? Wouldn’t it still have to use the output of the unconditional sampling to predict the next token of the sentence?

Well, that is in fact a great question to ask. The answer to that is this:

When you’re in the training phase, this is essentially what is happening when you do it sequentially. Let F be the model that takes in an input and predicts an output.

input = [“<s>”, “What”, “is”, “the”, “best”, “Anime”]

label = [“What”, “is”, “the”, “best”, “Anime”, “<end>”]

Scenario 1: Sequential Inputs

The first token <s> goes into the model and then by unconditional sampling the model produces a random token:

<s> → F → “Blast”

Next the model uses <s> Blast to predict the next token:

<s> Blast → F → that

And so on…

<s> Blast that → F → Funky

<s> Blast that Funky → F → Music

<s> Blast that Funky Music → F → Boy

<s> Blast that Funky Music Boy → F → <end>

predicted = [Blast, that, funky, music, boy, <end>]

Note that this will always be the case. Despite the model being trained for hours, the token <s> on its own will always produce a random output, and using that output as the input to predict the next token would be a bad idea.

By sending the entire matrix inside, what we have done instead is the following:

Scenario 2: All at once

When we do the training all at once, instead of using the previous output as the next input, we use the CORRECT WORD THAT SHOULD HAVE BEEN PREDICTED as the input for the next word output. This is commonly known as teacher forcing.

Something like this…

<s> → F → Blast

<s> What → F → is

<s> What is → F → the

<s> What is the → F → best

<s> What is the best → F → Anime

<s> What is the best Anime → F → <end>

predicted = [Blast, is, the, best, Anime, <end>] and then adjust weights.

This will adjust weights slightly better than the last scenario because most of the output is correct and will allow the language model to understand language modelling more accurately.
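If you want to convince yourself that one masked pass really is the same as feeding the correct prefixes one at a time, here is a toy check. A single masked attention head with random weights stands in for the whole model F (the name `toy_layer` and the width of 16 are my own simplifications).

```python
import numpy as np

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def toy_layer(X, Wq, Wk, Wv):
    # One masked self-attention head standing in for the whole model F.
    n = X.shape[0]
    Q, K, V = X @ Wq, X @ Wk, X @ Wv
    mask = np.triu(np.full((n, n), -np.inf), k=1)   # block attention to later tokens
    return softmax(Q @ K.T / np.sqrt(Q.shape[1]) + mask) @ V

rng = np.random.default_rng(0)
X = rng.standard_normal((6, 16))                    # the 6 correct input tokens, toy width 16
Wq, Wk, Wv = (rng.standard_normal((16, 16)) for _ in range(3))

all_at_once = toy_layer(X, Wq, Wk, Wv)              # one pass over the full sentence

# Feed the correct prefix of length t+1 and keep only its last row, six separate times.
one_by_one = np.stack([toy_layer(X[: t + 1], Wq, Wk, Wv)[-1] for t in range(6)])

print(np.allclose(all_at_once, one_by_one))         # True
```

Because of the mask, row t of the single pass never sees anything after position t, so the six sequential calls add nothing except extra compute.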

The weights, if done right, will adjust so that the top_k probabilities given <s> cover only the sentence openings that are used most frequently in the training corpus. So the model will have a limited number of possible sentence starting points to choose from, and from there on out it learns to construct the rest of the sentence accurately.

But what about inference? In inference we don’t have the entire sentence. How do we predict the next word in inference?

Inference for GPT-2

Now there isn’t any solution to the unconditional sampling during inference phase. If you don’t provide any input, the model has no choice but to start randomly rambling and making up a sentence that may make sense but would be completely random.

What you have to understand is that these language models are entirely input dependent. If you provide it with a stupid input, it is just going to start spewing a random output consistent with the stupid input. No input is still a stupid input.

ChatGPT, which is based on GPT-3.5, would also spew a random output if you provided it with absolutely no context. However, it is clearly set up to avoid doing that and instead asks for some context if none is provided, which is, I guess, a handy workaround.

To test this even further I decided to do some trial and error with ChatGPT.

I provided it with the following prompt:

“Every input prompt I give from here on out, you have to complete the sentence.

For example:

Input: The night was dark
Output: The night was dark and I died.

Input: I do not like
Output: I do not like pickles, they make me sad.

Input: Hello how
Output: Hello how are you doing ? Are you okay?

Answer yes if you understand.”

You may try it if you’d like and see what you observe.

Next I gave it a couple of prompts to ensure it understands well.

It gave appropriate responses for both of these inputs

But then I provided it with an empty input.

And here was the output.

Voilà! But just in case, let’s try it a few more times.

Okay fairly decent so far..

And then once again…

Same output.

If you don’t already know, the ChatGPT we know and love was GPT-3.5 fine-tuned on a large chat-format dataset in a supervised fashion. So I’m just spitballing here, but they must have added extra training examples that output certain possible answers when provided with an empty string. This could’ve been one of the many responses that GPT-3.5 was fine-tuned to generate instead of doing unconditional sampling.

But that’s the thing, we don’t typically provide a model with a null token. We provide it with some left context, for which GPT-2 autoregressively generates the next tokens.

And that’s the key. Given a certain context, the model will have no problem completing it, provided it was trained on enough data. The program will sequentially keep generating the next words. For example:

For input: <s> What is the best Anime ?

The model may go :

<s> What is the best Anime ? → F → Vinland

<s> What is the best Anime ? Vinland → F → Saga

<s> What is the best Anime ? Vinland Saga → F → <end>

output = What is the best Anime? Vinland Saga <end>
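The loop behind this is short. Below is a sketch where `toy_F` is a hypothetical canned lookup standing in for the real model; a real F would return a distribution over the vocabulary and we would sample from it.

```python
def generate(next_token_fn, prompt, max_new_tokens=10, end_token="<end>"):
    """Keep appending the predicted token to the context until <end> shows up."""
    tokens = list(prompt)
    for _ in range(max_new_tokens):
        nxt = next_token_fn(tokens)   # F looks at everything generated so far
        tokens.append(nxt)
        if nxt == end_token:
            break
    return tokens

# Hypothetical stand-in for F: a canned lookup keyed on the last token only.
canned = {"?": "Vinland", "Vinland": "Saga", "Saga": "<end>"}

def toy_F(tokens):
    return canned.get(tokens[-1], "<end>")

prompt = ["<s>", "What", "is", "the", "best", "Anime", "?"]
output = generate(toy_F, prompt)
print(" ".join(output))
```

Unlike training, there is no correct next word to feed in here, so each prediction really does become part of the next input.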

And it wouldn’t be wrong. Vinland Saga is the best Anime.

Yes. This whole thing was a commercial for my favorite Anime.

And with that we’re going to end this article. See you next time. I’ll be publishing more.

Thank you for listening to me rant.
