Building a Vision Transformer Model From Scratch

Matt Nguyen
Published in Correll lab
12 min read · Apr 4, 2024

The self-attention-based transformer model was first introduced by Vaswani et al. in their 2017 paper Attention Is All You Need and has since been widely used in natural language processing; it is, for example, the architecture behind OpenAI's ChatGPT. Transformers work not only on text but also on images and, more generally, on almost any data that can be represented as a sequence. In 2021, Dosovitskiy et al. introduced the idea of using transformers for computer vision tasks such as image classification in their paper An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale. Their vision transformer achieved excellent results compared to convolutional networks while requiring substantially fewer computational resources to train.

In this tutorial, we are going to build a vision transformer model from scratch and test it on the MNIST dataset, a collection of handwritten digits that has become a standard benchmark in machine learning. A notebook with the code from this tutorial can be found here.

Figure 1: Vision Transformer Model Overview. Image: ViT Paper.

Import Libraries and Modules

import torch
import torch.nn as nn
import torchvision.transforms as T
from torch.optim import Adam
from torchvision.datasets.mnist import MNIST
from torch.utils.data import DataLoader
import numpy as np

We are going to build our vision transformer using PyTorch, so we need to import it along with the other libraries used in this tutorial.

import torch
import torch.nn as nn

We also need to import torchvision.transforms in order to resize the input images and convert them to tensors. Resizing the input images is optional; you just need to make sure that the image size is divisible by the patch size.

import torchvision.transforms as T

We are going to use Adam as our optimizer, so we need to import it from torch.optim.

from torch.optim import Adam

We are going to import the MNIST dataset used in this tutorial from torchvision. We will also use PyTorch's DataLoader to help load the data, so we need to import it as well.

from torchvision.datasets.mnist import MNIST
from torch.utils.data import DataLoader

Finally, we need to import NumPy, which we will use to compute the sine and cosine values when creating the positional encodings.

import numpy as np

Patch Embeddings

class PatchEmbedding(nn.Module):
    def __init__(self, d_model, img_size, patch_size, n_channels):
        super().__init__()

        self.d_model = d_model # Dimensionality of Model
        self.img_size = img_size # Image Size
        self.patch_size = patch_size # Patch Size
        self.n_channels = n_channels # Number of Channels

        self.linear_project = nn.Conv2d(self.n_channels, self.d_model, kernel_size=self.patch_size, stride=self.patch_size)

    # B: Batch Size
    # C: Image Channels
    # H: Image Height
    # W: Image Width
    # P_col: Patch Column
    # P_row: Patch Row
    def forward(self, x):
        x = self.linear_project(x) # (B, C, H, W) -> (B, d_model, P_col, P_row)

        x = x.flatten(2) # (B, d_model, P_col, P_row) -> (B, d_model, P)

        x = x.transpose(1, 2) # (B, d_model, P) -> (B, P, d_model)

        return x

The first step in creating a vision transformer is to split the input image into patches and create a sequence of linear embeddings of those patches. PyTorch's Conv2d layer does both in a single step.

The Conv2d layer takes the input images, splits them into patches, and linearly projects each patch to a vector whose size equals the width of the model, d_model. Setting both kernel_size and stride to the patch size ensures that each patch has the correct size and that the patches do not overlap.

self.linear_project = nn.Conv2d(self.n_channels, self.d_model, kernel_size=self.patch_size, stride=self.patch_size)

In the forward method, we pass the input, which has shape (B, C, H, W), through the linear_project (Conv2d) layer and receive an output of shape (B, d_model, P_col, P_row).

def forward(self, x):
    x = self.linear_project(x) # (B, C, H, W) -> (B, d_model, P_col, P_row)
Figure 2: Conv2D applied on a single image. Each color represents which patch an element belongs to. Image: own.

We use the flatten method to combine the patch column and patch row dimensions into a single patch dimension, giving us a shape of (B, d_model, P).

x = x.flatten(2) # (B, d_model, P_col, P_row) -> (B, d_model, P)
Figure 3: Flatten applied to Conv2d output. Image: own.

Finally, we use the transpose method to swap the d_model and patch dimensions, giving us a shape of (B, P, d_model).

x = x.transpose(-2, -1) # (B, d_model, P) -> (B, P, d_model)
Figure 4: Transpose applied to flatten output. Image: own.
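A single Conv2d layer is equivalent to cutting the image into patches, flattening each one, and applying a linear layer. The sketch below checks this by comparing the two approaches on a random batch. It assumes the tutorial's 32x32 single-channel inputs, 16x16 patches, and d_model = 9. The manual route uses Tensor.unfold and reuses the convolution's own weights, so the two outputs should match up to floating-point error.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, C, H, W = 2, 1, 32, 32   # batch of 2 single-channel 32x32 images
p, d_model = 16, 9          # patch size and model width, as in the tutorial

conv = nn.Conv2d(C, d_model, kernel_size=p, stride=p)
x = torch.randn(B, C, H, W)

# Route 1: Conv2d, flatten, transpose -> (B, P, d_model)
conv_out = conv(x).flatten(2).transpose(1, 2)

# Route 2: cut non-overlapping p x p patches explicitly
patches = x.unfold(2, p, p).unfold(3, p, p)                            # (B, C, H/p, W/p, p, p)
patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(B, -1, C * p * p)  # (B, P, C*p*p)

# Apply the convolution's weights as an ordinary linear layer
weight = conv.weight.reshape(d_model, -1)                               # (d_model, C*p*p)
manual_out = patches @ weight.T + conv.bias                             # (B, P, d_model)

print(conv_out.shape)                                    # torch.Size([2, 4, 9])
print(torch.allclose(conv_out, manual_out, atol=1e-5))   # True
```

Both routes order the patches row by row, so the flattened Conv2d output lines up with the explicitly cut patches.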

Class Token and Positional Encoding

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_seq_length):
        super().__init__()

        self.cls_token = nn.Parameter(torch.randn(1, 1, d_model)) # Classification Token

        # Creating positional encoding
        pe = torch.zeros(max_seq_length, d_model)

        for pos in range(max_seq_length):
            for i in range(d_model):
                if i % 2 == 0:
                    pe[pos][i] = np.sin(pos/(10000 ** (i/d_model)))
                else:
                    pe[pos][i] = np.cos(pos/(10000 ** ((i-1)/d_model)))

        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        # Expand to have class token for every image in batch
        tokens_batch = self.cls_token.expand(x.size()[0], -1, -1)

        # Adding class tokens to the beginning of each embedding
        x = torch.cat((tokens_batch,x), dim=1)

        # Add positional encoding to embeddings
        x = x + self.pe

        return x

The vision transformer follows the standard approach of prepending a learnable classification token to the sequence of patch embeddings. The encoder's output at this token's position is later used to perform classification.

self.cls_token = nn.Parameter(torch.randn(1, 1, d_model))
Figure 5: Changing patch order can change an O into an X

Unlike recurrent models such as LSTMs, which process embeddings one at a time, transformers process all embeddings in parallel. This increases speed, but it also means transformers have no built-in notion of sequence order: self-attention treats its input as an unordered set of tokens. That is a problem, because changing the order of an image's patches would most likely alter the content of the image and what it represents. For example, Figure 5 shows that reordering an image's patches can change it from an O into something that more closely resembles an X. To fix this problem, positional encodings are added to the patch embeddings. Each positional encoding is unique to the position it represents, which allows the model to identify where each embedding belongs in the sequence. For the positional encodings to be added to the embeddings, both must have the same dimension, d_model. We get the positional encodings by using the equations in Figure 6.
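For reference, the equations in Figure 6 are the sinusoidal positional encodings from Vaswani et al. Here pos is the position in the sequence, and 2i and 2i + 1 index the even and odd embedding dimensions. This matches the loop in the code: an even column i uses the exponent i/d_model, and an odd column uses (i - 1)/d_model.

```latex
PE_{(pos,\,2i)}   = \sin\left(\frac{pos}{10000^{2i/d_{model}}}\right)
PE_{(pos,\,2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{model}}}\right)
```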

Figure 6: Positional encoding equations
pe = torch.zeros(max_seq_length, d_model)

for pos in range(max_seq_length):
    for i in range(d_model):
        if i % 2 == 0:
            pe[pos][i] = np.sin(pos/(10000 ** (i/d_model)))
        else:
            pe[pos][i] = np.cos(pos/(10000 ** ((i-1)/d_model)))

self.register_buffer('pe', pe.unsqueeze(0))
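The double loop for building the positional encodings is easy to read, but it runs max_seq_length × d_model Python iterations. Below is a vectorized sketch of the same formula. The tutorial's loop is included as a reference so the two can be compared. The function names sinusoidal_pe and loop_pe are hypothetical and not part of the tutorial's code, and the reference loop uses math.sin and math.cos in place of NumPy.

```python
import math
import torch

def loop_pe(max_seq_length, d_model):
    # Reference: the tutorial's double loop (math instead of NumPy)
    pe = torch.zeros(max_seq_length, d_model)
    for pos in range(max_seq_length):
        for i in range(d_model):
            if i % 2 == 0:
                pe[pos][i] = math.sin(pos / (10000 ** (i / d_model)))
            else:
                pe[pos][i] = math.cos(pos / (10000 ** ((i - 1) / d_model)))
    return pe

def sinusoidal_pe(max_seq_length, d_model):
    # Hypothetical vectorized version; computed in float64, returned as float32
    pos = torch.arange(max_seq_length, dtype=torch.float64).unsqueeze(1)  # (L, 1)
    even_i = torch.arange(0, d_model, 2, dtype=torch.float64)             # 0, 2, 4, ...
    div = 10000.0 ** (even_i / d_model)                                   # 10000^(i/d_model)
    pe = torch.zeros(max_seq_length, d_model, dtype=torch.float64)
    pe[:, 0::2] = torch.sin(pos / div)                   # even columns: sin
    pe[:, 1::2] = torch.cos(pos / div[: d_model // 2])   # odd column i reuses the exponent of i - 1
    return pe.float()

# Works for odd widths such as the tutorial's d_model = 9 as well as even ones
print(torch.allclose(sinusoidal_pe(5, 9), loop_pe(5, 9), atol=1e-6))     # True
print(torch.allclose(sinusoidal_pe(50, 8), loop_pe(50, 8), atol=1e-6))   # True
```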

In the forward method, the input is a batch of patch embeddings for multiple images, so we use the expand method to turn the single self.cls_token into one class token per image in the batch.

def forward(self, x):
    tokens_batch = self.cls_token.expand(x.size()[0], -1, -1)

These classification tokens are then prepended to each image's sequence of patch embeddings using torch.cat.

x = torch.cat((tokens_batch,x), dim=1)

Finally, the positional encodings are added to the embeddings before the result is returned.

x = x + self.pe

return x
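The shape bookkeeping in this forward method can be checked in isolation. The sketch below assumes a batch of 4 images with 4 patches each and d_model = 9, matching the tutorial's settings. A zero tensor stands in for the registered pe buffer. Note that expand does not copy data: every image in the batch shares the same underlying class token, so during training the gradients from the whole batch are summed into the single nn.Parameter.

```python
import torch

B, P, d_model = 4, 4, 9
cls_token = torch.randn(1, 1, d_model)          # stand-in for self.cls_token
patch_embeddings = torch.randn(B, P, d_model)   # stand-in for the PatchEmbedding output
pe = torch.zeros(1, P + 1, d_model)             # stand-in for self.pe

tokens_batch = cls_token.expand(B, -1, -1)              # (B, 1, d_model), a view with stride 0 on the batch dim
x = torch.cat((tokens_batch, patch_embeddings), dim=1)  # (B, P + 1, d_model)
x = x + pe                                              # pe broadcasts over the batch dimension

print(tokens_batch.stride(0))   # 0
print(x.shape)                  # torch.Size([4, 5, 9])
```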

Attention Head

class AttentionHead(nn.Module):
    def __init__(self, d_model, head_size):
        super().__init__()
        self.head_size = head_size

        self.query = nn.Linear(d_model, head_size)
        self.key = nn.Linear(d_model, head_size)
        self.value = nn.Linear(d_model, head_size)

    def forward(self, x):
        # Obtaining Queries, Keys, and Values
        Q = self.query(x)
        K = self.key(x)
        V = self.value(x)

        # Dot Product of Queries and Keys
        attention = Q @ K.transpose(-2,-1)

        # Scaling
        attention = attention / (self.head_size ** 0.5)

        attention = torch.softmax(attention, dim=-1)

        attention = attention @ V

        return attention
Figure 7: Scaled Dot-Product Attention and Multi-Head Attention diagrams

Vision transformers use attention, a communication mechanism between tokens that allows the model to focus on the most relevant parts of an image. Attention scores can be calculated using the equation in Figure 8.

Figure 8: Attention equation
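For reference, the equation in Figure 8 is the scaled dot-product attention of Vaswani et al. Here d_k is the dimension of the queries and keys, which corresponds to head_size in the AttentionHead class:

```latex
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^{\top}}{\sqrt{d_k}}\right)V
```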

The first step in calculating attention is obtaining the queries, keys, and values of the tokens. Intuitively, a token's query describes what it is looking for, its key describes what it contains, and its value is the information it communicates to other tokens. The queries, keys, and values are computed by passing the tokens through separate linear layers.

def forward(self, x):
    # Obtaining Queries, Keys, and Values
    Q = self.query(x)
    K = self.key(x)
    V = self.value(x)

We obtain the relationship between each pair of tokens in a sequence by taking the dot product of their queries and keys.

# Dot Product of Queries and Keys
attention = Q @ K.transpose(-2,-1)

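We need to scale these values to control their variance at initialization. If the queries and keys have roughly unit-variance entries, their dot product has a variance of about head_size. Scores that large would push the softmax toward a one-hot distribution, in which each token aggregates information from only a single other token. Scaling is applied by dividing the dot product by the square root of the size of the attention head.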

attention = attention / (self.head_size ** 0.5)

We then apply a softmax to the scaled dot products along the last dimension, so that each token's attention weights sum to 1.

attention = torch.softmax(attention, dim=-1)
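The effect of the scaling factor can be checked numerically. The sketch below draws random queries and keys with unit-variance entries, as an assumption about what they look like at initialization. It uses head_size = 64 and 16 tokens purely for illustration. The unscaled dot products have a variance close to head_size, while dividing by the square root of head_size brings it back to about 1. The unscaled softmax rows are also much closer to one-hot, as measured by the average largest attention weight per row.

```python
import torch

torch.manual_seed(0)
head_size, n_tokens, n_samples = 64, 16, 1000

Q = torch.randn(n_samples, n_tokens, head_size)
K = torch.randn(n_samples, n_tokens, head_size)

scores = Q @ K.transpose(-2, -1)   # (n_samples, n_tokens, n_tokens)
scaled = scores / head_size ** 0.5

print(scores.var().item())   # roughly head_size (about 64)
print(scaled.var().item())   # roughly 1

# Average of the largest weight in each softmax row: near 1 means nearly one-hot
peak_unscaled = torch.softmax(scores, dim=-1).max(dim=-1).values.mean().item()
peak_scaled = torch.softmax(scaled, dim=-1).max(dim=-1).values.mean().item()
print(peak_unscaled, peak_scaled)
```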

Finally, we multiply the softmax output by the values matrix. Each token's output becomes a weighted average of all the tokens' values, which is how information is communicated between the corresponding tokens.

attention = attention @ V

return attention
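As a sanity check, the steps above can be compared with PyTorch's built-in torch.nn.functional.scaled_dot_product_attention, which applies the same 1/sqrt(d_k) scaling by default. This sketch assumes PyTorch 2.0 or later, where that function is available. The random Q, K, and V tensors stand in for the outputs of the query, key, and value layers.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, N, head_size = 2, 5, 3   # batch, tokens (4 patches + class token), head size
Q, K, V = (torch.randn(B, N, head_size) for _ in range(3))

# Manual computation, following the AttentionHead steps
attention = Q @ K.transpose(-2, -1)          # (B, N, N) raw scores
attention = attention / (head_size ** 0.5)   # scaling
weights = torch.softmax(attention, dim=-1)   # each row sums to 1
manual = weights @ V                         # (B, N, head_size)

# Fused built-in (PyTorch 2.0+); default scale is 1 / sqrt(head_size)
fused = F.scaled_dot_product_attention(Q, K, V)

print(torch.allclose(manual, fused, atol=1e-5))  # True
```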

Multi-Head Attention

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.head_size = d_model // n_heads

        self.W_o = nn.Linear(d_model, d_model)

        self.heads = nn.ModuleList([AttentionHead(d_model, self.head_size) for _ in range(n_heads)])

    def forward(self, x):
        # Combine attention heads
        out = torch.cat([head(x) for head in self.heads], dim=-1)

        out = self.W_o(out)

        return out

Multi-head attention simply runs multiple self-attention heads in parallel and combines their outputs. We can do this by adding the attention heads to a module list,

self.heads = nn.ModuleList([AttentionHead(d_model, self.head_size) for _ in range(n_heads)])

passing the input through each head, and concatenating the results along the last dimension.

def forward(self, x):
    # Combine attention heads
    out = torch.cat([head(x) for head in self.heads], dim=-1)

We then pass the concatenated output through another linear layer, W_o, which combines the outputs of the different heads.

out = self.W_o(out)
return out
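Looping over a ModuleList is easy to follow, but each head runs as a separate set of small matrix multiplications. A common alternative uses one d_model × d_model linear layer each for the queries, keys, and values, then reshapes the result into n_heads heads. It is shown here as a hypothetical sketch, not the tutorial's code. Rows h * head_size through (h + 1) * head_size of those weights form head h's projection. The per_head_reference function below runs each head separately with its slice and concatenates the outputs, as MultiHeadAttention does, and should produce the same result.

```python
import torch
import torch.nn as nn

class BatchedMultiHeadAttention(nn.Module):
    # Hypothetical vectorized variant: all heads computed with single matrix multiplications
    def __init__(self, d_model, n_heads):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.n_heads = n_heads
        self.head_size = d_model // n_heads
        self.query = nn.Linear(d_model, d_model)
        self.key = nn.Linear(d_model, d_model)
        self.value = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def split_heads(self, t):
        # (B, N, d_model) -> (B, n_heads, N, head_size)
        B, N, _ = t.shape
        return t.view(B, N, self.n_heads, self.head_size).transpose(1, 2)

    def forward(self, x):
        B, N, d_model = x.shape
        Q, K, V = (self.split_heads(layer(x)) for layer in (self.query, self.key, self.value))
        attention = torch.softmax(Q @ K.transpose(-2, -1) / self.head_size ** 0.5, dim=-1)
        out = (attention @ V).transpose(1, 2).reshape(B, N, d_model)  # concatenate heads
        return self.W_o(out)

def per_head_reference(m, x):
    # Run each head separately with its slice of the weights, then concatenate
    outs = []
    for h in range(m.n_heads):
        s = slice(h * m.head_size, (h + 1) * m.head_size)
        q = x @ m.query.weight[s].T + m.query.bias[s]
        k = x @ m.key.weight[s].T + m.key.bias[s]
        v = x @ m.value.weight[s].T + m.value.bias[s]
        weights = torch.softmax(q @ k.transpose(-2, -1) / m.head_size ** 0.5, dim=-1)
        outs.append(weights @ v)
    return m.W_o(torch.cat(outs, dim=-1))

torch.manual_seed(0)
mha = BatchedMultiHeadAttention(d_model=9, n_heads=3)
x = torch.randn(2, 5, 9)   # (batch, tokens, d_model)
with torch.no_grad():
    out = mha(x)
    ref = per_head_reference(mha, x)
print(out.shape)                             # torch.Size([2, 5, 9])
print(torch.allclose(out, ref, atol=1e-5))   # True
```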

Transformer Encoder

class TransformerEncoder(nn.Module):
    def __init__(self, d_model, n_heads, r_mlp=4):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads

        # Sub-Layer 1 Normalization
        self.ln1 = nn.LayerNorm(d_model)

        # Multi-Head Attention
        self.mha = MultiHeadAttention(d_model, n_heads)

        # Sub-Layer 2 Normalization
        self.ln2 = nn.LayerNorm(d_model)

        # Multilayer Perceptron
        self.mlp = nn.Sequential(
            nn.Linear(d_model, d_model*r_mlp),
            nn.GELU(),
            nn.Linear(d_model*r_mlp, d_model)
        )

    def forward(self, x):
        # Residual Connection After Sub-Layer 1
        out = x + self.mha(self.ln1(x))

        # Residual Connection After Sub-Layer 2
        out = out + self.mlp(self.ln2(out))

        return out

The transformer encoder is made up of two sub-layers: the first performs multi-head attention and the second contains a multi-layer perceptron (MLP). The multi-head attention sub-layer handles communication between tokens, while the MLP sub-layer allows each token to individually “think” about what was communicated to it.

Layer normalization is a technique that normalizes each input in the batch independently across its features, which helps stabilize training. Following the ViT paper, we pass the input through a layer norm module at the beginning of each sub-layer.

# Sub-Layer 1 Normalization
self.ln1 = nn.LayerNorm(d_model)

# Sub-Layer 2 Normalization
self.ln2 = nn.LayerNorm(d_model)
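What nn.LayerNorm computes can be written out by hand. For each token, it subtracts the mean over that token's d_model features and divides by their standard deviation. It then applies a learnable scale and shift, initialized to 1 and 0. The sketch below checks this against a freshly initialized nn.LayerNorm, using the layer's own eps (1e-5 by default) and the biased (population) variance.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model = 9
x = torch.randn(2, 5, d_model)   # (batch, tokens, d_model)

ln = nn.LayerNorm(d_model)       # weight = 1, bias = 0 at initialization

# Manual version: statistics over the feature dimension only, separately for each token
mean = x.mean(dim=-1, keepdim=True)
var = ((x - mean) ** 2).mean(dim=-1, keepdim=True)   # biased variance
manual = (x - mean) / torch.sqrt(var + ln.eps) * ln.weight + ln.bias

print(torch.allclose(ln(x), manual, atol=1e-5))   # True
```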

The MLP consists of two linear layers with a GELU activation in between. GELU is used instead of ReLU because it is smooth, whereas ReLU is not differentiable at zero.

# Encoder Multilayer Perceptron
self.mlp = nn.Sequential(
    nn.Linear(d_model, d_model*r_mlp),
    nn.GELU(),
    nn.Linear(d_model*r_mlp, d_model)
)
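For reference, GELU is defined as GELU(x) = x · Φ(x), where Φ is the standard normal cumulative distribution function. The short check below writes this out with torch.erf and compares it with nn.GELU, whose default is the exact form rather than the tanh approximation. It also shows that, unlike ReLU, GELU passes small negative inputs through as small negative outputs.

```python
import math
import torch
import torch.nn as nn

x = torch.linspace(-3, 3, steps=13)

# Exact GELU: x * Phi(x), with Phi the standard normal CDF
phi = 0.5 * (1 + torch.erf(x / math.sqrt(2)))
manual_gelu = x * phi

print(torch.allclose(nn.GELU()(x), manual_gelu, atol=1e-6))  # True
print(nn.GELU()(torch.tensor(-0.5)).item())  # about -0.154, while ReLU gives 0
```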

In the encoder's forward method, the input is passed through the first layer normalization module before multi-head attention is performed. The original input is then added to the output of multi-head attention to create a residual connection.

The result is then passed through another layer normalization module before being fed into the MLP. A second residual connection is created by adding the output of the MLP to the output of the first residual connection.

The residual connections help mitigate the vanishing gradient problem by providing a path along which the gradient can flow back unimpeded to earlier layers.

def forward(self, x):
    # Residual Connection After Sub-Layer 1
    out = x + self.mha(self.ln1(x))

    # Residual Connection After Sub-Layer 2
    out = out + self.mlp(self.ln2(out))

    return out
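The benefit of the identity path shows up directly in the gradients. In the sketch below, a linear layer with deliberately tiny weights stands in for a sub-layer whose gradient has nearly vanished. This is an artificial setup chosen only for illustration. Without the residual connection, the gradient reaching the input is tiny. With out = x + f(x), it stays close to 1, because the derivative of x with respect to itself is the identity.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
f = nn.Linear(4, 4)
with torch.no_grad():
    f.weight.mul_(1e-4)   # shrink the weights so f passes back almost no gradient

x = torch.randn(1, 4, requires_grad=True)

# Without a residual connection
f(x).sum().backward()
grad_plain = x.grad.clone()

# With a residual connection: out = x + f(x)
x.grad = None
(x + f(x)).sum().backward()
grad_residual = x.grad.clone()

print(grad_plain)      # entries on the order of 1e-4 or smaller
print(grad_residual)   # entries close to 1
```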

Vision Transformer

class VisionTransformer(nn.Module):
    def __init__(self, d_model, n_classes, img_size, patch_size, n_channels, n_heads, n_layers):
        super().__init__()

        assert img_size[0] % patch_size[0] == 0 and img_size[1] % patch_size[1] == 0, "img_size dimensions must be divisible by patch_size dimensions"
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"

        self.d_model = d_model # Dimensionality of model
        self.n_classes = n_classes # Number of classes
        self.img_size = img_size # Image size
        self.patch_size = patch_size # Patch size
        self.n_channels = n_channels # Number of channels
        self.n_heads = n_heads # Number of attention heads

        self.n_patches = (self.img_size[0] * self.img_size[1]) // (self.patch_size[0] * self.patch_size[1])
        self.max_seq_length = self.n_patches + 1

        self.patch_embedding = PatchEmbedding(self.d_model, self.img_size, self.patch_size, self.n_channels)
        self.positional_encoding = PositionalEncoding(self.d_model, self.max_seq_length)
        self.transformer_encoder = nn.Sequential(*[TransformerEncoder(self.d_model, self.n_heads) for _ in range(n_layers)])

        # Classification MLP
        self.classifier = nn.Sequential(
            nn.Linear(self.d_model, self.n_classes),
            nn.Softmax(dim=-1)
        )

    def forward(self, images):
        x = self.patch_embedding(images)

        x = self.positional_encoding(x)

        x = self.transformer_encoder(x)

        x = self.classifier(x[:,0])

        return x

When creating our vision transformer class, we first need to make sure that the input images can be split evenly into patches of size patch_size and that the dimensionality of the model is divisible by the number of attention heads.

assert img_size[0] % patch_size[0] == 0 and img_size[1] % patch_size[1] == 0, "img_size dimensions must be divisible by patch_size dimensions"
assert d_model % n_heads == 0, "d_model must be divisible by n_heads"

We also need to calculate the maximum sequence length for the positional encoding, which will be equal to the number of patches plus one. The number of patches can be found by dividing the product of the height and width of the input image by the product of the height and width of the patch size.

self.n_patches = (self.img_size[0] * self.img_size[1]) // (self.patch_size[0] * self.patch_size[1])
self.max_seq_length = self.n_patches + 1
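With the training parameters used later in this tutorial (32x32 images and 16x16 patches), this works out to only 4 patches and a sequence length of 5. The small helper below repeats the calculation. It is a hypothetical function, not part of the tutorial's code. The second call shows that the native 28x28 MNIST images would instead divide evenly into 7x7 patches.

```python
def sequence_length(img_size, patch_size):
    # Hypothetical helper mirroring the n_patches / max_seq_length computation
    assert img_size[0] % patch_size[0] == 0 and img_size[1] % patch_size[1] == 0
    n_patches = (img_size[0] * img_size[1]) // (patch_size[0] * patch_size[1])
    return n_patches, n_patches + 1   # +1 for the classification token

print(sequence_length((32, 32), (16, 16)))  # (4, 5)
print(sequence_length((28, 28), (7, 7)))    # (16, 17)
```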

The vision transformer also needs to stack multiple encoder modules. This can be achieved by wrapping a list of encoder layers in an nn.Sequential container.

self.transformer_encoder = nn.Sequential(*[TransformerEncoder(self.d_model, self.n_heads) for _ in range(n_layers)])

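The last part of the vision transformer model is the MLP classification head, which consists of a linear layer followed by a softmax layer. Note that nn.CrossEntropyLoss, which we use for training, already applies a log-softmax to its inputs. The explicit Softmax layer is therefore not required for training, and raw logits could be passed to the loss instead.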

self.classifier = nn.Sequential(
    nn.Linear(self.d_model, self.n_classes),
    nn.Softmax(dim=-1)
)
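The sketch below uses a single example to show what happens when softmax outputs, rather than raw logits, are passed to nn.CrossEntropyLoss. The prediction over three classes is made up: very confident and correct. Given raw logits, the loss is nearly zero. Given probabilities, the loss stays at about 0.55, because the loss treats the probabilities themselves as logits. With 10 classes, even a perfect one-hot probability output cannot bring the loss below log(1 + 9/e), about 1.46.

```python
import math
import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()
logits = torch.tensor([[10.0, 0.0, 0.0]])   # confident prediction for class 0
labels = torch.tensor([0])

loss_from_logits = criterion(logits, labels).item()
loss_from_probs = criterion(torch.softmax(logits, dim=-1), labels).item()

print(round(loss_from_logits, 4))   # 0.0001
print(round(loss_from_probs, 2))    # 0.55

# Best case with probabilities fed to the loss: 1.0 for the true class, 0.0 elsewhere
print(round(math.log(1 + 9 / math.e), 2))  # 1.46 for 10 classes
```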

In the forward method, the input images are first passed through the patch embedding layer, which splits each image into patches and produces a sequence of linear embeddings for those patches. The embeddings then go through the positional encoding layer, which prepends the classification token and adds the positional encodings, and then through the encoder modules. Finally, the encoder output at the classification token's position (x[:, 0]) is passed through the classification MLP to determine the class of each image.

def forward(self, images):
    x = self.patch_embedding(images)

    x = self.positional_encoding(x)

    x = self.transformer_encoder(x)

    x = self.classifier(x[:,0])

    return x

We are done building the model. Now we need to train and test it.

Training Parameters

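d_model = 9           # Dimensionality of the model
n_classes = 10        # Digits 0 through 9
img_size = (32,32)    # MNIST images (28x28) are resized to 32x32
patch_size = (16,16)  # Gives (32*32)/(16*16) = 4 patches per image
n_channels = 1        # Grayscale images
n_heads = 3           # Number of attention heads
n_layers = 3          # Number of encoder layers
batch_size = 128
epochs = 5
alpha = 0.005         # Learning rate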

Loading MNIST Dataset

transform = T.Compose([
    T.Resize(img_size),
    T.ToTensor()
])

train_set = MNIST(
    root="./../datasets", train=True, download=True, transform=transform
)
test_set = MNIST(
    root="./../datasets", train=False, download=True, transform=transform
)

train_loader = DataLoader(train_set, shuffle=True, batch_size=batch_size)
test_loader = DataLoader(test_set, shuffle=False, batch_size=batch_size)

Training

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device: ", device, f"({torch.cuda.get_device_name(device)})" if torch.cuda.is_available() else "")

transformer = VisionTransformer(d_model, n_classes, img_size, patch_size, n_channels, n_heads, n_layers).to(device)

optimizer = Adam(transformer.parameters(), lr=alpha)
criterion = nn.CrossEntropyLoss()

for epoch in range(epochs):

    training_loss = 0.0
    for i, data in enumerate(train_loader, 0):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()

        outputs = transformer(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        training_loss += loss.item()

    print(f'Epoch {epoch + 1}/{epochs} loss: {training_loss / len(train_loader) :.3f}')

Testing

correct = 0
total = 0

with torch.no_grad():
    for data in test_loader:
        images, labels = data
        images, labels = images.to(device), labels.to(device)

        outputs = transformer(images)

        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    print(f'\nModel Accuracy: {100 * correct // total} %')

Results

Using this model, we were able to achieve an accuracy of about 92% on the MNIST dataset after training for only 5 epochs. This example demonstrates that self-attention can be used as a stand-in for a deep convolutional network. Read on to learn how to combine vision transformers with text.

https://medium.com/correll-lab/building-clip-from-scratch-68f6e42d35f4
