Breaking Down the Components of ChatGPT

Shitij Nigam
12 min read · May 7, 2024


Table of Contents

· Setup
  ∘ Raw Data & Data Loader
  ∘ Encoder Decoder Solution
  ∘ Loss Function
  ∘ Key Hyperparameters
  ∘ Generative Function
  ∘ Optimizer function
  ∘ Training Loop
· Transformers
  ∘ (Basic) Scaled Dot-Product Attention
  ∘ Enhancing your Attention model
· Appendix
  ∘ Tensor Dimension Gotchas
  ∘ Other Open Questions

* This is not a how-to, simply an overview of the different parts that make up a GPT. For any unfamiliar terms, definitions can be found here. For other concepts such as batch / layer normalization, etc. — and why some things are done the way they are, check out neural network gotchas.

Setup

Raw Data & Data Loader

Text-based content is the core component on which you train your model to predict the next word, with each subsequent word acting as the label for the sequence that precedes it. e.g. Wikipedia pages, Shakespeare, etc.

You also need a Data Loader to split data into training and validation sets (and often a test set too) and to stack up batches of data (more on that below). The idea is to train the model to predict the next letter / word on the training set, then run it on the validation set to see how far off its predictions are from what actually appears there.
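As a rough illustration, here is a minimal sketch of this step, assuming the raw text has already been read into a Python string called text (an assumed variable name) and using a simple character-level encoding (tokenization proper is covered in the next section):

import torch

# build a character-level vocabulary and an encoder from the raw text
chars = sorted(set(text))
stoi = {ch: i for i, ch in enumerate(chars)}
encode = lambda s: [stoi[c] for c in s]

# encode the full corpus, then split into training and validation sets
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9 * len(data))  # 90% train, 10% validation
train_data, val_data = data[:n], data[n:]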

Encoder Decoder Solution

This converts letters / words into tokens. A token is simply a numerical representation of a word (or piece of a word). e.g. Google uses SentencePiece, OpenAI uses tiktoken.
Note: Things can get quite gnarly depending on how words are converted into numbers, e.g. “The ” could be the number 53914, but “ the” could be the number 29145. Sometimes when the tokenizer doesn’t recognize a word, it splits it up into multiple tokens. It can be quite messy.
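For a quick feel of this in practice, here is a small sketch using OpenAI's tiktoken package (the exact token ids you get back depend on the encoding; no specific ids are asserted here):

import tiktoken

enc = tiktoken.get_encoding("gpt2")  # GPT-2's byte-pair-encoding vocabulary
print(enc.encode("The "))  # a list of integer token ids
print(enc.encode(" the"))  # typically different ids than "The ", despite the same word
print(enc.encode("antidisestablishmentarianism"))  # rare words get split into several tokens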

Loss Function

This will help you evaluate the quality of your model, e.g. Cross Entropy Loss

Key Hyperparameters

  • Set Block Size T or Time (i.e. context length) for reading the training data; if the block size is 8 (at a character level), then 9 characters are passed in, so that every position in the block has a next-character target (see the batch-sampling sketch after this list)
    e.g. If the entire training data is the string of the English alphabet ABCD…XYZ and a random block of 8 is picked up, say LMNOPQRS (plus the following character T as the final target), then the contexts (x) would be L, LM, LMN, LMNO, LMNOP, LMNOPQ, LMNOPQR, LMNOPQRS, and the corresponding targets (y) would be M, N, O, P, Q, R, S, T
    N.B. The letters would be represented as tokens, not as actual letters :)
  • Set Batch size B, i.e. how many independent sequences we process in a forward / backward pass
  • Set Class Size C or Channels, i.e. the number of potential options (letters or words) that can be predicted. This is typically also the embedding dimension, especially once the embedding size is decoupled from the vocabulary size
    N.B. For a single attention head, C may be equal to head size (more on this below in multi-headed attention)
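A minimal sketch of how B and T come together when sampling batches, assuming train_data is the 1-D tensor of token ids from the data-loader sketch above; batch_size and block_size are the hyperparameters being set here:

import torch

batch_size = 4  # B: independent sequences processed in parallel
block_size = 8  # T: context length

def get_batch(data):
    # pick random starting offsets, then slice out block_size tokens as context (x)
    # and the same window shifted by one token as the targets (y)
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i : i + block_size] for i in ix])          # (B, T)
    y = torch.stack([data[i + 1 : i + block_size + 1] for i in ix])  # (B, T)
    return x, y

xb, yb = get_batch(train_data)  # note: block_size + 1 tokens are read per sequence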

Generative Function

A generative function uses the model’s predictions to determine the next letter. e.g. in a Bigram generative model (a sketch follows the list below):

  • (i) Input = an xB-like tensor of shape (B,T)
  • (ii) The input is fed into the model to generate logits of shape (B,T,C), with C holding a score for each possible next token
  • (iii) Extract the last time (T) step of these logits, of shape (B, C), since we want to predict what comes right after the input
  • (iv) Turn these scores into estimated probabilities using a Softmax function
  • (v) Use the probabilities to sample the index of the next letter (Multinomial sampling works best for multiple reasons); sometimes several letters have similar probabilities, so this randomization helps
  • (vi) Extract and append this letter’s index to the input
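A minimal sketch of such a generative loop. The assumption here (borrowed from the bigram setup) is that model(idx) returns a (logits, loss) pair with logits of shape (B, T, C):

import torch
import torch.nn.functional as F

@torch.no_grad()
def generate(model, idx, max_new_tokens):
    # idx is a (B, T) tensor of token ids that acts as the running context
    for _ in range(max_new_tokens):
        logits, _ = model(idx)             # (i) + (ii): logits of shape (B, T, C)
        logits = logits[:, -1, :]          # (iii): keep only the last time step -> (B, C)
        probs = F.softmax(logits, dim=-1)  # (iv): scores -> probabilities
        idx_next = torch.multinomial(probs, num_samples=1)  # (v): sample the next token id -> (B, 1)
        idx = torch.cat((idx, idx_next), dim=1)             # (vi): append to the context
    return idx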

Optimizer function

This is your gradient-based optimizer for the backward pass: once gradients are computed, it uses them to update the parameters and optimize the model weights. e.g. SGD, Adam, etc.
N.B. A good learning rate for Adam (a more advanced optimizer) is 1e-4. For smaller networks, a higher learning rate of 1e-3 works
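A minimal sketch, assuming model is the module being trained (AdamW is used here; plain Adam or SGD would look the same):

import torch

# learning rate of 1e-3 for a small network; drop towards 1e-4 as the network grows
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)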

Training Loop

A training loop essentially does the following:

  • (i) Sample a new batch of data per step
  • (ii) Evaluate the loss using your model
  • (iii) Zero out the gradients to make sure they aren’t carried over from the previous step
  • (iv) Get the gradients for your parameters via backward() on the loss
  • (v) Update (‘step’) the parameters
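The five steps above, as a minimal sketch. It assumes get_batch, model and optimizer from the earlier sketches, an assumed max_iters hyperparameter, and the same (logits, loss) return signature as before:

for step in range(max_iters):
    xb, yb = get_batch(train_data)         # (i) sample a new batch of data
    logits, loss = model(xb, yb)           # (ii) evaluate the loss
    optimizer.zero_grad(set_to_none=True)  # (iii) zero out stale gradients
    loss.backward()                        # (iv) compute gradients via backprop
    optimizer.step()                       # (v) update the parameters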

Transformers

(Basic) Scaled Dot-Product Attention

Attention of tokens (or nodes) is the crux of how relationships are determined between tokens. Arguably this deserves its own separate chapter, so here goes.

This amazing paper 😇 https://arxiv.org/pdf/1706.03762

Let’s say we start with a batch x of shape (B,T,C), which holds a series of embeddings for a batch size B, a block size or time T, and channels C.

  • Head size: This determines the complexity of relationships that each token or node is allowed to hold about other tokens.
    The larger the head size, the more information it can hold, but the higher the computational cost; the smaller the head size, the cheaper the computation, but the less each head can capture
  • Key k: This is a linear layer which holds information about every token corresponding to “What do I know”
key = nn.Linear(C, head_size, bias = False)
# Here, the # of inputs accepted by the linear layer will be of size C
# .. the outputs of size head_size
# .. with the weights W of shape (C, head_size)

k = key(x)
# .. i.e. k = x @ W
# .. resulting in k of shape (B, T, C) @ (C, head_size) = (B, T, head_size)
  • Query q: Similar to Key, this is a linear layer that holds information about every token corresponding to “What am I looking for”
query = nn.Linear(C, head_size, bias = False)
q = query(x)
# A similar math holds true for the query layer too
  • q @ k — Transposed: This holds information about the relationships between keys and queries; the larger the product, the more relevant the key is to the query. In case of self attention / decoder block, only q @ k-transposed pairs corresponding to the previous tokens in the time step T are used and normalized using softmax.
    Important note: In an encoder block, the masking line wouldn’t be needed (e.g. for things like sentiment analysis, a transformer would need to understand the full context of the string of tokens)
qk = q @ k.transpose(-2, -1)
# i.e. (B, T, head_size) @ (B, head_size, T), to give qk of shape (B, T, T)

tril = torch.tril(torch.ones(T, T))
# to create a lower triangular tensor of ones

qk = qk.masked_fill(tril == 0, float('-inf'))
# this fills the positions that are 0 in tril with -infinity in qk
  • Softmax ( q @ k-transposed ÷ sqrt (head_size) ) to get the “net” probabilities of the query key pairs with the highest relevance
    N.B. q & k-transposed needs to be normalized (i.e. divided) by square root of our head_size, since during initialization — esp. for large head sizes — “the dot products grow large in magnitude”
qk = qk * head_size**-0.5
# scale qk by 1 / sqrt(head_size)

qk = F.softmax(qk, dim = -1)
# dim = -1 ensures softmax is computed over the last dimension
# softmax exponentiates the values and normalizes each row to sum to 1
# qk is of shape (B, T, T)
  • Value: A linear layer similar to queries and keys, with output of shape (B, T, head_size); the attention output is “computed as a weighted sum of the values”, where the weights are the softmaxed q @ k-transposed affinities
    i.e. assign a set of values to each token in a batch, then use the query/key pairs to determine which values are most important
value = nn.Linear(C, head_size, bias = False)
v = value(x)
output = qk @ v
# i.e. (B, T, T) @ (B, T, head_size); the attention output has shape (B, T, head_size)

Additional notes

  • Self attention: This type of deployment is called self attention because keys, queries, values all come from the same source x
  • Cross attention: This comes into play when a separate source of nodes is used to produce the keys and values, while the queries still come from x

Enhancing your attention model

Once attention is set up, a few more optimizations need to happen

  • Attention → Multi-headed Attention: This basically introduces additional context to attention by splitting up or adding more independent channels, applying them in parallel, and then concatenating their results over the channel dimension C
    E.g. if attention was simply allowing tokens to understand basic relationships with each other (e.g. grammar), multi-headed attention allows more sophisticated relationships between tokens (e.g. grammar, sentiment, ..)
    N.B. Given that we now add multiple heads, we could theoretically reduce the head_size when initializing multi-head attention in proportion to the number of heads (e.g. head_size = n_embeddings // num_heads), so that the concatenated output keeps the original channel size
class Head(nn.Module):

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(C, head_size, bias = False)
        self.query = nn.Linear(C, head_size, bias = False)
        self.value = nn.Linear(C, head_size, bias = False)
        self.register_buffer('tril', torch.tril(torch.ones(T, T)))

    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x)                                   # (B, T, head_size)
        q = self.query(x)                                 # (B, T, head_size)
        qk = q @ k.transpose(-2, -1) * k.shape[-1]**-0.5  # (B, T, T) affinities, scaled
        qk = qk.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        qk = F.softmax(qk, dim = -1)
        v = self.value(x)                                 # (B, T, head_size)
        return qk @ v                                     # output of shape (B, T, head_size)
class MultiHeadAttention(nn.Module):

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        # create multiple heads, run all of them in parallel

    def forward(self, x):
        return torch.cat([h(x) for h in self.heads], dim = -1)
        # concatenated over the channel dimension (last dimension)
  • Multi-headed attention (“Communication”) → Feed Forward (“Computation”): Typically, with only multi-headed attention, you go “too fast” into the logits calculation (i.e. tokens “look at each other but don’t have time to think about what they found”). Feed forward is a multi-layer perceptron (MLP) with at least one hidden layer, which introduces more context and adds non-linearity. Feed forward theoretically allows capturing of other aspects (e.g. sentiment, etc.), letting the model go beyond surface-level connections
class FeedForward(nn.Module):

    def __init__(self, n_embeddings):
        super().__init__()

        self.net = nn.Sequential(
            nn.Linear(n_embeddings, n_embeddings),
            # takes inputs with n_embeddings as the last dimension
            # .. and produces outputs with n_embeddings

            nn.ReLU()
            # adds non-linearity to avoid the entire network being linear
        )

    def forward(self, x):
        return self.net(x)
  • Block = Multi Headed Attention + Feed Forward: Blocks are grouped representations of communication (using attention) followed by computation (using feed forward) that can be stacked multiple times
class Block(nn.Module):

    def __init__(self, n_embeddings, num_heads):
        super().__init__()

        head_size = n_embeddings // num_heads
        # to allow embeddings to spread across multiple heads

        self.sa = MultiHeadAttention(num_heads, head_size)
        self.ffwd = FeedForward(n_embeddings)

    def forward(self, x):
        x = self.sa(x)
        x = self.ffwd(x)
        return x

## ## ##
## in your transformer
## ##

token_embedding_table = nn.Embedding(vocab_size, n_embeddings)
position_embedding_table = nn.Embedding(block_size, n_embeddings)
blocks = nn.Sequential(
    Block(n_embeddings, num_heads = 4),
    Block(n_embeddings, num_heads = 4),
    Block(n_embeddings, num_heads = 4)
)
  • Blocks → Blocks with Residual Connections or Skip Connections: With multiple blocks stacked sequentially, the network becomes deeper and harder to optimize, since gradients reaching the earlier blocks shrink across the many transformations (typically referred to as the vanishing gradient problem), leaving those blocks with little influence on the overall outcome. Residual connections solve this by giving earlier blocks a direct say in the final output through an ongoing ‘pathway’ of information.
class Block(nn.Module):

    def __init__(self, n_embeddings, num_heads):
        super().__init__()

        ## same as above
        head_size = n_embeddings // num_heads
        self.sa = MultiHeadAttention(num_heads, head_size)
        self.ffwd = FeedForward(n_embeddings)

    def forward(self, x):
        # this is where things depart from the previous implementation
        x = x + self.sa(x)    # "Fork off, communicate, come back"
        x = x + self.ffwd(x)  # "Fork off, compute, come back"
        return x
  • Blocks with Residual Connections need Projections: When residual connections are added, the inputs and outputs being summed have to match in dimensionality. Projection layers map a block’s output back into the residual pathway with matching dimensions (and add an extra learned linear transformation on top of the concatenated head outputs)
class MultiHeadAttention(nn.Module):

    def __init__(self, num_heads, head_size):
        # following bits same as before
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])

        # projection layer added in
        self.proj = nn.Linear(n_embeddings, n_embeddings)

    def forward(self, x):
        # changed forward pass with projection
        out = torch.cat([h(x) for h in self.heads], dim = -1)

        ##
        ## simple linear transformation of the concatenated outputs
        ##
        out = self.proj(out)
        return out
class FeedForward(nn.Module):

    def __init__(self, n_embeddings):
        # following bits same as before
        super().__init__()

        self.net = nn.Sequential(
            nn.Linear(n_embeddings, n_embeddings),
            nn.ReLU(),

            ##
            ## add simple linear projection layer
            ##
            nn.Linear(n_embeddings, n_embeddings)
        )

    def forward(self, x):
        return self.net(x)
  • Blocks with Residual Connections → LayerNorm: LayerNorm normalizes features (across C) using their mean and variance, ensuring that they have unit Gaussian distribution (i.e. 0 mean, 1 standard deviation). It is typically applied before determining attention and feed forward (‘pre norm formulation’).
    N.B. Batch & Time act as batch dimensions here; this is per-token normalization. Eventually, because LayerNorm’s gamma and beta are trainable parameters, the outputs may not stay exactly unit Gaussian
    N.B. Contrast this with batch normalization, which normalizes across the batch dimension B. Since ‘the computation doesn’t span across examples (or batches)’, LayerNorm needs no running buffers. (Good NLP example here at the bottom)
Source: https://arxiv.org/pdf/1803.08494#page=3
class Block(nn.Module):

    def __init__(self, n_embeddings, num_heads):
        super().__init__()

        ## same as above
        head_size = n_embeddings // num_heads
        self.sa = MultiHeadAttention(num_heads, head_size)
        self.ffwd = FeedForward(n_embeddings)

        ## layer norms
        self.ln1 = nn.LayerNorm(n_embeddings)
        self.ln2 = nn.LayerNorm(n_embeddings)

    def forward(self, x):
        # layer norms are applied before passing x into attention & feed forward
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x

##
## at the end of the transformer, there should be another LayerNorm
##

token_embedding_table = nn.Embedding(vocab_size, n_embeddings)
position_embedding_table = nn.Embedding(block_size, n_embeddings)
blocks = nn.Sequential(
    Block(n_embeddings, num_heads = 4),
    Block(n_embeddings, num_heads = 4),
    Block(n_embeddings, num_heads = 4),

    ##
    ## here
    ##
    nn.LayerNorm(n_embeddings),
)
# ..
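As a quick check of the per-token normalization point above, here is a tiny standalone sketch (the shapes are arbitrary) showing that nn.LayerNorm normalizes each (batch, time) position independently across its channels:

import torch
import torch.nn as nn

x = torch.randn(4, 8, 32)  # (B, T, C)
ln = nn.LayerNorm(32)      # normalizes over the last (channel) dimension
out = ln(x)

# each (b, t) position now has roughly zero mean and unit std across its 32 channels
print(out[0, 0].mean().item(), out[0, 0].std().item())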
  • Dropouts: Another regularization technique to avoid overfitting! Dropout “randomly prevents some of the nodes from communicating” by zeroing them out with an allocated probability. The mask of zeroed nodes changes every forward / backward pass, which effectively trains an ensemble of sub-networks, with the ensemble merged at inference time.
    N.B. These are typically added right before a connection back into the residual pathway (i.e. at the end of multi-headed attention, at the end of feed forward, and right after calculating the affinities); a sketch follows below
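A sketch of what that placement looks like in the feed-forward block, with dropout_p as an assumed hyperparameter; analogous nn.Dropout calls would go after the projection in MultiHeadAttention and on the softmaxed affinities inside each Head:

import torch.nn as nn

dropout_p = 0.2  # assumed dropout probability

class FeedForward(nn.Module):

    def __init__(self, n_embeddings):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embeddings, n_embeddings),
            nn.ReLU(),
            nn.Linear(n_embeddings, n_embeddings),
            nn.Dropout(dropout_p),  # added right before rejoining the residual pathway
        )

    def forward(self, x):
        return self.net(x)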

Woof. This took longer than usual.

Appendix

Tensor Dimension Gotchas

Play with and match up dimensions of B/T/C with your loss function
e.g. If you’re going with a simple bigram model which takes one character and predicts another character, you can select cross entropy loss. However, PyTorch’s cross entropy function requires inputs to be set up with the class size as the second dimension, thus requiring some tweaking.

C refers to the class size, i.e. number of possible outcomes that the model is expected to predict

What does this mean? Let’s say we have training data which contains all letters from A–Z tokenized as 0–25 (with one extra token, 26, for space), containing the sentences “how are you” and “are you okay”

  • Your xB (x-Batch) is a tensor of dimensions (B, T) with batch size B and Time/Block Size T
    e.g. xB = [how, are] = [ [7, 14, 22], [0, 17, 4] ] with each letter corresponding to its relevant tokens. Here, the batch size is 2 i.e. 2 rows, the block size is 3, i.e. 3 letters, with xB of shape (2,3)
  • yB (y-Batch) is the target, and is of the same corresponding shape as xB (B, T)
    e.g. yB = [ow_, re_] (_ referring to space) = [ [14, 22, 26], [17, 4, 26] ] with each letter corresponding to relevant tokens, with yB of shape (2,3)
  • In a Bigram model, the Class Size C = vocab. size, implying a token embedding table of shape (C, C), i.e. for every token appearing in the training data, there are scores for every possible outcome, with varying probabilities
    e.g. using the above example, the embedding table will be of shape (27, 27), where 27 is the vocab size (26 letters plus space)
  • Therefore, Logits = Embedding(xB), with a shape of (B, T, C), with logits representing a tensor which contains, for each element in xB, an unnormalized score for every token in the vocabulary 🤯
    e.g. the logits tensor will basically be of the shape (2, 3, 27) and look something like this ..
[[[-8.6626e-01, -1.3694e+00,  1.0005e-01, -2.4852e-01, -9.1255e-01,
-8.4218e-01, 9.0579e-01, 8.8845e-01, -7.9284e-01, 5.3425e-01,
-7.9848e-01, 5.1823e-01, -1.1608e+00, -1.0227e+00, -5.0643e-01,
7.8016e-01, -3.3962e-02, 4.1387e-01, -3.7672e-02, -1.4856e+00,
-2.3876e-01, -2.6241e-01, -5.8329e-01, -7.2831e-01, 3.9296e-01,
-1.3158e+00], # one unnormalized score (logit) per token in the vocabulary :)
[ .. ],
[ .. ]],
        [[ .. ],
[ .. ],
[ .. ]]]
  • Now, since Cross Entropy function requires the class size (in this case the vocab size) to be the second dimension, we would need to modify logits and targets so that the second dimension of logits is C (instead of the third)
B, T, C = logits.shape
logits = logits.view(B*T, C)
yB = yB.view(B*T) # also the targets
  • After this reshaping, cross entropy can evaluate the loss because C is now the second dimension for logits
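To close the loop, the actual loss call after the reshaping above would look something like this (assuming the usual functional import):

import torch.nn.functional as F

# logits is now (B*T, C) and yB is (B*T,), which is what cross_entropy expects
loss = F.cross_entropy(logits, yB)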

Other Open Questions

  • Why is LayerNorm needed? Is it just for normalization in general? I’m not fully sure, and I think once I do some passes before/after LayerNorm I’ll feel more comfortable. So far the answers have been unsatisfactory (This paper suggests ‘expressivity’)
  • Utility of projections: I get that it helps with dimensionality, but beyond that I’m unclear
  • Do residual connections lead to exploding gradients? That was the issue LSTMs aimed to avoid. Does LayerNorm prevent this?

