Part 1: Building a Vision Transformer from Scratch: A PyTorch Deep Dive Plus a Teaser on LoRA for Part 2

Pasha Shaik
10 min read · Nov 1, 2023
Vision Transformer

If you’ve delved into deep learning, you’re likely aware of the impact transformer architectures have had on artificial intelligence. These architectures sit at the core of numerous breakthroughs in AI. In this article, we will take an in-depth look at building a Vision Transformer from the ground up.

This article is the first in a four-part series. The next one will show how to build LoRA from scratch for the Vision Transformer we build here.

I’ve also shared a fully functional example on Colab. You can find the link below.

Let’s first import the libraries used throughout this article, then define a configuration for the Vision Transformer model using a dataclass:

import math
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F

@dataclass
class ModelArgs:
    dim: int = 256  # Dimension of the model embeddings
    hidden_dim: int = 512  # Dimension of the hidden layer in the feed-forward network
    n_heads: int = 8  # Number of attention heads (must divide dim)
    n_layers: int = 6  # Number of layers in the transformer
    patch_size: int = 4  # Size of the patches (typically square)
    n_channels: int = 3  # Number of input channels (e.g., 3 for RGB images)
    n_patches: int = 64  # Number of patches per image: (32 // 4) ** 2 for 32x32 images
    n_classes: int = 10  # Number of target classes
    dropout: float = 0.2  # Dropout rate for regularization
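Several fields depend on each other: head_dim is dim // n_heads, and n_patches should equal (image_size // patch_size) ** 2. If it is smaller, the positional-embedding slice in the model will be too short for the patch sequence. The sketch below assumes 32x32 inputs (the CIFAR-10 images used later) and uses a hypothetical helper, check_args, that is not part of the original code, to catch inconsistent settings early.

```python
from dataclasses import dataclass

@dataclass
class ModelArgs:
    dim: int = 256
    hidden_dim: int = 512
    n_heads: int = 8
    n_layers: int = 6
    patch_size: int = 4
    n_channels: int = 3
    n_patches: int = 64
    n_classes: int = 10
    dropout: float = 0.2

def check_args(args: ModelArgs, image_size: int = 32) -> None:
    """Hypothetical sanity check for square images of side image_size; raises ValueError on mismatch."""
    if args.dim % args.n_heads != 0:
        raise ValueError("dim must be divisible by n_heads")
    if image_size % args.patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    expected = (image_size // args.patch_size) ** 2
    if args.n_patches != expected:
        raise ValueError(f"n_patches should be {expected}, got {args.n_patches}")

args = ModelArgs()
check_args(args)  # passes: 256 % 8 == 0 and (32 // 4) ** 2 == 64
print(args.dim // args.n_heads)  # head_dim -> 32
```

Changing patch_size to 8 without updating n_patches (which would then need to be 16) makes check_args raise, which is easier to debug than a shape error deep inside the forward pass.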

Multi-Head Attention

The MultiHeadAttention module below implements multi-head self-attention, a fundamental component of transformer architectures. Self-attention lets the model weigh input elements differently, so it can focus more on certain parts of the input when computing each output. Running several heads in parallel on lower-dimensional projections lets each head attend to different relationships between tokens.

First, let’s code the multi-head attention block. Afterward, I’ll break down its key components in detail.


class MultiHeadAttention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.dim // args.n_heads

        # Linear projections for Q, K, and V
        self.wq = nn.Linear(self.dim, self.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(self.dim, self.n_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(self.dim, self.n_heads * self.head_dim, bias=False)
        # Output projection applied after the heads are merged
        self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False)

    def forward(self, x):
        b, seq_len, dim = x.shape  # b: batch size, seq_len: sequence length

        assert dim == self.dim, "dim is not matching"

        q = self.wq(x)  # [b, seq_len, n_heads*head_dim]
        k = self.wk(x)  # [b, seq_len, n_heads*head_dim]
        v = self.wv(x)  # [b, seq_len, n_heads*head_dim]

        # Reshape the tensors for multi-head operations
        q = q.contiguous().view(b, seq_len, self.n_heads, self.head_dim)  # [b, seq_len, n_heads, head_dim]
        k = k.contiguous().view(b, seq_len, self.n_heads, self.head_dim)  # [b, seq_len, n_heads, head_dim]
        v = v.contiguous().view(b, seq_len, self.n_heads, self.head_dim)  # [b, seq_len, n_heads, head_dim]

        # Transpose to bring the head dimension to the front
        q = q.transpose(1, 2)  # [b, n_heads, seq_len, head_dim]
        k = k.transpose(1, 2)  # [b, n_heads, seq_len, head_dim]
        v = v.transpose(1, 2)  # [b, n_heads, seq_len, head_dim]

        # Compute scaled attention scores and apply softmax
        attn = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(self.head_dim)  # [b, n_heads, seq_len, seq_len]
        attn_scores = F.softmax(attn, dim=-1)  # [b, n_heads, seq_len, seq_len]

        # Compute the attended features
        out = torch.matmul(attn_scores, v)  # [b, n_heads, seq_len, head_dim]
        # Move the head dimension back next to head_dim before merging the heads
        out = out.transpose(1, 2).contiguous().view(b, seq_len, -1)  # [b, seq_len, n_heads*head_dim]

        return self.wo(out)  # [b, seq_len, dim]

The MultiHeadAttention module performs the following operations:

1. Linear transformations of the input tensor into “query” (Q), “key” (K), and “value” (V) representations.

    q = self.wq(x)
    k = self.wk(x)
    v = self.wv(x)

2. Splitting these tensors into multiple “heads” and moving the head dimension in front of the sequence dimension, so that each head attends over the whole sequence independently.

    q = q.contiguous().view(b, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
    k = k.contiguous().view(b, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
    v = v.contiguous().view(b, seq_len, self.n_heads, self.head_dim).transpose(1, 2)

3. Computing attention scores as the dot product of Q and K, scaled by 1/sqrt(head_dim) so the softmax does not saturate as head_dim grows.

    attn = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(self.head_dim)

4. Applying softmax over the last dimension to turn these scores into attention weights that sum to 1 for each query.

    attn_scores = F.softmax(attn, dim=-1)

5. Multiplying the attention weights with the V tensor, yielding the attended features.

    out = torch.matmul(attn_scores, v)

6. Moving the head dimension back, concatenating the heads, and applying the output projection wo to produce the final output.

    out = out.transpose(1, 2).contiguous().view(b, seq_len, -1)
    return self.wo(out)
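Steps 3 to 6 can be checked in isolation. The sketch below works directly on random per-head tensors (the sizes are arbitrary assumptions) and compares the manual computation, softmax(QK^T / sqrt(head_dim)) V, with torch.nn.functional.scaled_dot_product_attention, which requires PyTorch 2.0 or later. It also shows why the heads must be transposed back before the final view: flattening without the transpose stitches together rows from different positions and heads.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
b, seq_len, n_heads, head_dim = 2, 5, 4, 8
dim = n_heads * head_dim

# Random per-head tensors in the [b, n_heads, seq_len, head_dim] layout used above
q = torch.randn(b, n_heads, seq_len, head_dim)
k = torch.randn(b, n_heads, seq_len, head_dim)
v = torch.randn(b, n_heads, seq_len, head_dim)

# Steps 3-5: scaled dot-product attention for all heads at once
attn = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_dim)
manual = torch.matmul(F.softmax(attn, dim=-1), v)  # [b, n_heads, seq_len, head_dim]

# PyTorch's built-in kernel uses the same default 1/sqrt(head_dim) scaling
fused = F.scaled_dot_product_attention(q, k, v)

# Step 6: token t should receive the concatenation of every head's output at position t
t = 1
expected = torch.cat([manual[:, h, t] for h in range(n_heads)], dim=-1)  # [b, dim]
merged = manual.transpose(1, 2).reshape(b, seq_len, dim)  # heads moved next to head_dim first
naive = manual.reshape(b, seq_len, dim)                   # no transpose: mixes positions and heads

print(torch.allclose(manual, fused, atol=1e-5))  # True
print(torch.allclose(merged[:, t], expected))    # True
print(torch.allclose(naive[:, t], expected))     # False
```

In practice, F.scaled_dot_product_attention can replace steps 3 to 5 inside forward; the explicit version is kept in the article because it makes each operation visible.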

Attention Block

The AttentionBlock module encapsulates a typical block found within transformer architectures. It primarily consists of two significant components: a multi-head self-attention mechanism and a feed-forward neural network (FFN). Additionally, layer normalization and skip connections (residual connections) are employed to facilitate better learning and gradient flow.

Let’s code the attention block.

class AttentionBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.layer_norm_1 = nn.LayerNorm(args.dim)
        self.attn = MultiHeadAttention(args)

        self.layer_norm_2 = nn.LayerNorm(args.dim)

        self.ffn = nn.Sequential(
            nn.Linear(args.dim, args.hidden_dim),
            nn.GELU(),
            nn.Dropout(args.dropout),
            nn.Linear(args.hidden_dim, args.dim),
            nn.Dropout(args.dropout)
        )

    def forward(self, x):
        # Pre-norm residual sublayers: normalize, transform, then add back to the input
        x = x + self.attn(self.layer_norm_1(x))
        x = x + self.ffn(self.layer_norm_2(x))
        return x

Let’s delve deeper into its structure:

1. Layer Normalization (Pre-Attention): Before the input x is fed into the multi-head attention mechanism, it is normalized with nn.LayerNorm (the “pre-norm” arrangement). The normalized tensor goes only into the attention sublayer; the un-normalized x is kept for the residual connection.

    self.layer_norm_1 = nn.LayerNorm(args.dim)
    h = self.layer_norm_1(x)

2. Multi-Head Self-Attention: This component lets the model focus on different parts of the input sequence when generating its output.

    self.attn = MultiHeadAttention(args)
    x = x + self.attn(self.layer_norm_1(x))

3. Layer Normalization (Pre-Feed-Forward Network): As before the attention mechanism, the output is normalized again with LayerNorm before it enters the FFN.

    self.layer_norm_2 = nn.LayerNorm(args.dim)
    h = self.layer_norm_2(x)

4. Feed-Forward Neural Network (FFN): The FFN consists of two linear layers separated by a GELU activation function, with dropout after the activation and after the second linear layer for regularization.

    self.ffn = nn.Sequential(
        nn.Linear(args.dim, args.hidden_dim),
        nn.GELU(),
        nn.Dropout(args.dropout),
        nn.Linear(args.hidden_dim, args.dim),
        nn.Dropout(args.dropout)
    )
    x = x + self.ffn(self.layer_norm_2(x))

5. Residual Connections: Residual (skip) connections are vital for deep architectures like transformers. They help mitigate the vanishing-gradient problem and speed up convergence. In the code, they are the additions in which each sublayer’s input is added back to the output of the attention mechanism and of the FFN.

    x = x + self.attn(...)
    x = x + self.ffn(...)

Because each sublayer only adds an update to its input, the block maps a [b, seq_len, dim] tensor to a tensor of the same shape, so any number of blocks can be stacked one after another.
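The pattern used twice in forward, x + sublayer(LayerNorm(x)), can be isolated to see why the skip connection helps gradient flow. The sketch below wraps it in a hypothetical PreNormResidual module (not part of the article’s code) around a small FFN with arbitrary sizes. When the sublayer’s last linear layer is zeroed, the block becomes exactly the identity and the gradient reaching the input is all ones: the residual path carries both signal and gradient straight through, whatever the sublayer has learned.

```python
import torch
import torch.nn as nn

class PreNormResidual(nn.Module):
    """Hypothetical helper computing x + fn(LayerNorm(x)), the pre-norm pattern in AttentionBlock."""

    def __init__(self, dim: int, fn: nn.Module):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.fn(self.norm(x))

torch.manual_seed(0)
dim, hidden_dim = 16, 32
ffn = nn.Sequential(nn.Linear(dim, hidden_dim), nn.GELU(), nn.Linear(hidden_dim, dim))
block = PreNormResidual(dim, ffn)

x = torch.randn(2, 5, dim, requires_grad=True)
y = block(x)    # same shape as x, so blocks can be stacked
print(y.shape)  # torch.Size([2, 5, 16])

# Zero the last layer: the sublayer now outputs exactly 0, so the block is the identity
nn.init.zeros_(ffn[2].weight)
nn.init.zeros_(ffn[2].bias)
y_zero = block(x)
y_zero.sum().backward()
print(torch.equal(y_zero, x))                   # True
print(torch.equal(x.grad, torch.ones_like(x)))  # True
```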

Converting Image into Patches

Before building the full Vision Transformer model, we need a utility function that splits images into non-overlapping patches.

def img_to_patch(x, patch_size, flatten_channels=True):
    # x: Input image tensor
    # B: Batch size, C: Channels, H: Height, W: Width
    B, C, H, W = x.shape  # (B, C, H, W)

    # Reshape the image tensor to get non-overlapping patches
    x = x.reshape(B, C, H // patch_size, patch_size, W // patch_size, patch_size)  # (B, C, H/patch_size, patch_size, W/patch_size, patch_size)

    # Permute to group the patches and channels
    x = x.permute(0, 2, 4, 1, 3, 5)  # (B, H/patch_size, W/patch_size, C, patch_size, patch_size)

    # Flatten the height and width dimensions for patches
    x = x.flatten(1, 2)  # (B, (H/patch_size * W/patch_size), C, patch_size, patch_size)

    # Option to flatten the channel and spatial dimensions
    if flatten_channels:
        x = x.flatten(2, 4)  # (B, (H/patch_size * W/patch_size), (C * patch_size * patch_size))

    return x

The img_to_patch function takes an image tensor and converts it into non-overlapping patches of a specified size. This operation is typically used in vision transformers to represent an image as a sequence of flattened patches. The function provides an option to flatten the channels or keep them separate.
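For a 32x32 CIFAR-10 image with patch_size = 4, this gives 8 x 8 = 64 patches of 3 * 4 * 4 = 48 values each, which matches n_patches = 64 and the input size of the embedding layer used later. The sketch below copies img_to_patch to stay self-contained and checks the patch ordering two ways: patch 0 should be the top-left square and patch 1 the square to its right, and the result should match torch.nn.functional.unfold, which with kernel_size and stride both equal to the patch size extracts the same non-overlapping patches as (B, C * p * p, num_patches).

```python
import torch
import torch.nn.functional as F

def img_to_patch(x, patch_size, flatten_channels=True):
    # Same logic as the function above
    B, C, H, W = x.shape
    x = x.reshape(B, C, H // patch_size, patch_size, W // patch_size, patch_size)
    x = x.permute(0, 2, 4, 1, 3, 5)  # (B, H/p, W/p, C, p, p)
    x = x.flatten(1, 2)              # (B, num_patches, C, p, p)
    if flatten_channels:
        x = x.flatten(2, 4)          # (B, num_patches, C*p*p)
    return x

torch.manual_seed(0)
images = torch.randn(2, 3, 32, 32)  # a CIFAR-10-sized batch
p = 4

patches = img_to_patch(images, p)
print(patches.shape)  # torch.Size([2, 64, 48])

# Patches are numbered row by row: patch 0 is top-left, patch 1 is the square to its right
first = torch.equal(patches[:, 0], images[:, :, :p, :p].flatten(1))
second = torch.equal(patches[:, 1], images[:, :, :p, p:2 * p].flatten(1))

# unfold returns (B, C*p*p, num_patches), so transpose to compare
unfolded = F.unfold(images, kernel_size=p, stride=p).transpose(1, 2)
print(first, second, torch.equal(patches, unfolded))  # True True True
```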

Vision Transformer

The VisionTransformer class combines the components discussed above into the final model. It is an encoder-only architecture similar to BERT, in which every token attends to every other token. We also prepend a learnable class token (cls_token) to every sequence in the batch; its final representation is used for classification, much as BERT uses its special [CLS] token.

class VisionTransformer(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        # Define the patch size
        self.patch_size = args.patch_size

        # Embedding layer to transform flattened patches to desired dimension
        self.input_layer = nn.Linear(args.n_channels * (args.patch_size ** 2), args.dim)

        # Create the attention blocks for the transformer
        attn_blocks = []
        for _ in range(args.n_layers):
            attn_blocks.append(AttentionBlock(args))

        # Create the transformer by stacking the attention blocks
        self.transformer = nn.Sequential(*attn_blocks)

        # Define the classifier
        self.mlp = nn.Sequential(
            nn.LayerNorm(args.dim),
            nn.Linear(args.dim, args.n_classes)
        )

        # Dropout layer for regularization
        self.dropout = nn.Dropout(args.dropout)

        # Define the class token (similar to BERT's [CLS] token)
        self.cls_token = nn.Parameter(torch.randn(1, 1, args.dim))

        # Positional embeddings to give positional information to transformer
        self.pos_embedding = nn.Parameter(torch.randn(1, 1 + args.n_patches, args.dim))

    def forward(self, x):
        # Convert image to patches and flatten
        x = img_to_patch(x, self.patch_size)
        b, seq_len, _ = x.shape

        # Transform patches using the embedding layer
        x = self.input_layer(x)

        # Add the class token to the beginning of each sequence
        cls_token = self.cls_token.repeat(b, 1, 1)
        x = torch.cat([cls_token, x], dim=1)

        # Add positional embeddings to the sequence
        x = x + self.pos_embedding[:, :seq_len + 1]

        # Apply dropout
        x = self.dropout(x)

        # Process sequence through the transformer
        x = self.transformer(x)

        # Retrieve the class token's representation (for classification)
        x = x.transpose(0, 1)
        cls = x[0]

        # Classify using the representation of the class token
        out = self.mlp(cls)
        return out

Let’s dive into its key aspects:

1. Patch Embedding: Instead of operating on raw pixels, the image is divided into fixed-size patches. Each patch is flattened and passed through a linear layer that projects it to the model dimension (args.dim).

    self.patch_size = args.patch_size
    self.input_layer = nn.Linear(args.n_channels * (args.patch_size ** 2), args.dim)
    x = img_to_patch(x, self.patch_size)
    x = self.input_layer(x)

2. Transformer Blocks: A stack of attention blocks processes the embedded patches. The number of blocks is set by args.n_layers.

    attn_blocks = []
    for _ in range(args.n_layers):
        attn_blocks.append(AttentionBlock(args))
    self.transformer = nn.Sequential(*attn_blocks)
    x = self.transformer(x)

3. CLS Token and Position Embeddings: A learnable class token is prepended to the sequence of embedded patches; its output representation is later used for classification. Learnable positional embeddings are then added to tell the transformer where each patch sits in the image, because self-attention on its own is insensitive to token order.

    self.cls_token = nn.Parameter(torch.randn(1, 1, args.dim))
    self.pos_embedding = nn.Parameter(torch.randn(1, 1 + args.n_patches, args.dim))
    cls_token = self.cls_token.repeat(b, 1, 1)
    x = torch.cat([cls_token, x], dim=1)
    x = x + self.pos_embedding[:, :seq_len + 1]

4. Dropout: Dropout is applied to the embedded sequence for regularization.

    self.dropout = nn.Dropout(args.dropout)
    x = self.dropout(x)

5. Classifier: The classification head (LayerNorm followed by a linear layer) takes the [CLS] token’s representation after it has passed through all the transformer blocks.

    self.mlp = nn.Sequential(
        nn.LayerNorm(args.dim),
        nn.Linear(args.dim, args.n_classes)
    )
    x = x.transpose(0, 1)
    cls = x[0]
    out = self.mlp(cls)
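With the default ModelArgs and CIFAR-10 inputs, the embedding stage turns a (B, 3, 32, 32) batch into 64 patch tokens of 48 values, projects them to 256 dimensions, and prepends the class token for 65 tokens in total. The sketch below traces only those shapes, with standalone tensors standing in for the model’s parameters (the batch size of 2 is arbitrary). It also shows that x[:, 0] selects the same class-token vectors as the x.transpose(0, 1)[0] used in forward.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, C, H, W, p, dim = 2, 3, 32, 32, 4, 256
n_patches = (H // p) * (W // p)  # 8 * 8 = 64

images = torch.randn(B, C, H, W)
# Same patching as img_to_patch: (B, 64, 3*4*4) = (B, 64, 48)
patches = (images.reshape(B, C, H // p, p, W // p, p)
           .permute(0, 2, 4, 1, 3, 5)
           .flatten(1, 2)
           .flatten(2, 4))

input_layer = nn.Linear(C * p * p, dim)
cls_token = nn.Parameter(torch.randn(1, 1, dim))
pos_embedding = nn.Parameter(torch.randn(1, 1 + n_patches, dim))

tokens = input_layer(patches)                                # (B, 64, 256)
seq = torch.cat([cls_token.repeat(B, 1, 1), tokens], dim=1)  # (B, 65, 256)
seq = seq + pos_embedding[:, :n_patches + 1]                 # broadcast over the batch

print(patches.shape, tokens.shape, seq.shape)
# torch.Size([2, 64, 48]) torch.Size([2, 64, 256]) torch.Size([2, 65, 256])

# Both forms pick position 0 (the class token) of every sequence
print(torch.equal(seq[:, 0], seq.transpose(0, 1)[0]))  # True
```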

The code below preprocesses and loads the CIFAR-10 dataset:

from torchvision import transforms
from torchvision.datasets import CIFAR10

# Path to the directory where CIFAR10 data will be stored/downloaded
DATA_DIR = "../data"

# Define the transformation for testing dataset:
# 1. Convert images to tensors.
# 2. Normalize the tensors using the mean and standard deviation of CIFAR10 dataset.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.49139968, 0.48215841, 0.44653091], [0.24703223, 0.24348513, 0.26158784])
])

# Define the transformation for training dataset:
# 1. Apply random horizontal flip for data augmentation.
# 2. Perform random resizing and cropping of images for data augmentation.
# 3. Convert images to tensors.
# 4. Normalize the tensors using the mean and standard deviation of CIFAR10 dataset.
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomResizedCrop((32, 32), scale=(0.8, 1.0), ratio=(0.9, 1.1)),
    transforms.ToTensor(),
    transforms.Normalize([0.49139968, 0.48215841, 0.44653091], [0.24703223, 0.24348513, 0.26158784])
])

# Load the CIFAR10 training images twice: once with the augmenting training transform, and
# once with the deterministic test transform so the validation split is not randomly augmented.
# The dataset will be downloaded if not present in DATA_DIR.
train_dataset = CIFAR10(root=DATA_DIR, train=True, transform=train_transform, download=True)
val_dataset = CIFAR10(root=DATA_DIR, train=True, transform=test_transform, download=True)

# Load the CIFAR10 testing dataset with the defined testing transformation.
test_set = CIFAR10(root=DATA_DIR, train=False, transform=test_transform, download=True)

# Split the 50000 training images into a training set of 45000 and a validation set of 5000.
# The seed (42, an arbitrary choice) makes the split reproducible.
indices = torch.randperm(len(train_dataset), generator=torch.Generator().manual_seed(42)).tolist()
train_set = torch.utils.data.Subset(train_dataset, indices[:45000])
val_set = torch.utils.data.Subset(val_dataset, indices[45000:])
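The Normalize constants are per-channel statistics: ToTensor scales pixels to [0, 1], and Normalize then computes (x - mean) / std separately for each channel. The values above appear to be the mean and standard deviation of the CIFAR-10 training images. The sketch below shows how such statistics can be computed and applied, using random tensors as a hypothetical stand-in for the real images; for CIFAR-10 you would stack the ToTensor outputs of the training set instead.

```python
import torch

def channel_stats(images: torch.Tensor):
    """Per-channel mean and std over a (N, C, H, W) stack of images scaled to [0, 1]."""
    mean = images.mean(dim=(0, 2, 3))
    std = images.std(dim=(0, 2, 3))
    return mean, std

torch.manual_seed(0)
# Hypothetical stand-in for the training images, already scaled to [0, 1]
images = torch.rand(500, 3, 32, 32)
mean, std = channel_stats(images)

# What Normalize(mean, std) does: subtract and divide channel by channel
normalized = (images - mean[None, :, None, None]) / std[None, :, None, None]
print(normalized.mean(dim=(0, 2, 3)), normalized.std(dim=(0, 2, 3)))  # roughly 0 and 1 per channel
```

The statistics should come from the training split only, so that no information from the test set leaks into preprocessing.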

Let’s set up the data loaders for the training, validation, and test datasets:

# Define the batch size for training, validation, and testing.
batch_size = 16

# Define the number of subprocesses to use for data loading.
num_workers = 4

# Create a DataLoader for the training dataset:
# 1. Shuffle the training data for each epoch.
# 2. Drop the last batch if its size is not equal to `batch_size` to maintain consistency.
train_loader = torch.utils.data.DataLoader(dataset=train_set,
                                           batch_size=batch_size,
                                           shuffle=True,
                                           num_workers=num_workers,
                                           drop_last=True)

# Create a DataLoader for the validation dataset:
# Do not drop any data; process all the validation data.
val_loader = torch.utils.data.DataLoader(dataset=val_set,
                                         batch_size=batch_size,
                                         shuffle=False,
                                         num_workers=num_workers,
                                         drop_last=False)

# Create a DataLoader for the testing dataset:
# Do not drop any data; process all the test data.
test_loader = torch.utils.data.DataLoader(dataset=test_set,
                                          batch_size=batch_size,
                                          shuffle=False,
                                          num_workers=num_workers,
                                          drop_last=False)
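drop_last only affects the final, incomplete batch. With batch_size = 16, the 45000 training images give 45000 / 16 = 2812.5, so 2812 full batches per epoch and 8 images skipped (a different 8 each epoch, since the data is shuffled). The validation loader keeps its partial batch: 313 batches, the last holding 8 images. The sketch below checks these lengths with a plain range standing in for each dataset, an assumption that works because only the dataset length matters here.

```python
from torch.utils.data import DataLoader

batch_size = 16
# range(n) stands in for a dataset with n samples
train_batches = len(DataLoader(range(45000), batch_size=batch_size, shuffle=True, drop_last=True))
val_batches = len(DataLoader(range(5000), batch_size=batch_size, drop_last=False))
test_batches = len(DataLoader(range(10000), batch_size=batch_size, drop_last=False))
print(train_batches, val_batches, test_batches)  # 2812 313 625

# Training samples skipped by drop_last in each epoch
print(45000 - train_batches * batch_size)  # 8
```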

Let’s configure the model, the optimizer with its learning-rate schedule, and the training criterion (cross-entropy loss):

# Model, Loss and Optimizer
device = "cuda:0" if torch.cuda.is_available() else "cpu"
args = ModelArgs()
model = VisionTransformer(args).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
# Divide the learning rate by 10 after epochs 80 and 130
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[80, 130], gamma=0.1)
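MultiStepLR multiplies the learning rate by gamma each time the epoch counter reaches a milestone. With lr_scheduler.step() called once per epoch, as in the training loop, the schedule is 3e-4 for epochs 0 to 79, 3e-5 for epochs 80 to 129, and 3e-6 from epoch 130 on (counting from 0). The sketch below replays the schedule on a dummy parameter, an assumption standing in for model.parameters(), to confirm this.

```python
import torch

# A dummy parameter stands in for model.parameters()
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.Adam([param], lr=3e-4)
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[80, 130], gamma=0.1)

lrs = []
for epoch in range(150):
    lrs.append(optimizer.param_groups[0]["lr"])  # learning rate used during this epoch
    optimizer.step()     # normally preceded by a full epoch of training
    lr_scheduler.step()  # once per epoch, as in the training loop

print(lrs[79], lrs[80], lrs[130])
```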

Time to bring our model to life! Let’s train it.

num_epochs = 150  # example value, adjust as needed

for epoch in range(num_epochs):

    # Training Phase
    model.train()
    total_loss = 0.0
    for i, (inputs, labels) in enumerate(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        # Backward pass and optimize
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_train_loss = total_loss / len(train_loader)
    print(f"Epoch [{epoch+1}/{num_epochs}], Training Loss: {avg_train_loss:.4f}")

    # Validation Phase
    model.eval()
    total_val_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            total_val_loss += loss.item()

            _, predicted = outputs.max(dim=-1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

    avg_val_loss = total_val_loss / len(val_loader)
    val_accuracy = 100 * correct / total
    print(f"Epoch [{epoch+1}/{num_epochs}], Validation Loss: {avg_val_loss:.4f}, Validation Accuracy: {val_accuracy:.2f}%")

    # Update the learning rate
    lr_scheduler.step()

print("Training complete!")

Now, let’s evaluate our model’s accuracy on the test set.

model.eval()
correct = 0
total = 0
with torch.no_grad():
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        # Forward pass
        outputs = model(inputs)
        _, predicted = outputs.max(dim=-1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

test_accuracy = 100 * correct / total
print(f"Test Accuracy: {test_accuracy:.2f}%")
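The validation and test loops repeat the same logic. One possible refactor is a small evaluate function (a hypothetical helper, not part of the original notebook) that both phases can call. It makes one design choice that differs from the validation loop above: each batch’s loss is weighted by its size, so a smaller final batch (the validation split ends with a batch of 8 when batch_size is 16) does not count as much as a full one. The toy check uses nn.Identity as a stand-in “model” whose inputs are already logits pointing at the right class.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

def evaluate(model, loader, criterion, device):
    """Return (mean loss per sample, accuracy in %) of model over loader."""
    model.eval()
    total_loss, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            # criterion averages over the batch, so multiply back by the batch size
            total_loss += criterion(outputs, labels).item() * labels.size(0)
            correct += (outputs.argmax(dim=-1) == labels).sum().item()
            total += labels.size(0)
    return total_loss / total, 100.0 * correct / total

# Toy check: 50 samples whose "logits" already point at the correct class
labels = torch.arange(10).repeat(5)
logits = 10.0 * nn.functional.one_hot(labels, num_classes=10).float()
loader = DataLoader(TensorDataset(logits, labels), batch_size=16)  # batches of 16, 16, 16, 2
loss, acc = evaluate(nn.Identity(), loader, nn.CrossEntropyLoss(), "cpu")
print(acc)  # 100.0
```

In the training script, the call would look like val_loss, val_accuracy = evaluate(model, val_loader, criterion, device), and the same for test_loader.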

Test Accuracy: 75.07%

Colab link:

https://colab.research.google.com/github/pashanitw/pytorch_examples/blob/main/transformer/VisionTransformer.ipynb

Summary:

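Our model’s accuracy is lower than what a basic CNN typically reaches on CIFAR-10. Vision transformers lack the built-in locality and translation equivariance of convolutions, so they need more data to excel. Trained from scratch on CIFAR-10 as here, we expect around 75–80% accuracy. When pre-trained on much larger datasets, however, vision transformers often outperform CNNs.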

Part 2 Sneak Peek:

Coming up in Part 2: I’ll guide you through a hands-on implementation of LoRA (Low-Rank Adaptation) tailored to the vision transformer we’ve built here. Stay tuned!

Conclusion:

If you found value in this, connect or reach out to me on LinkedIn.

Artificial Intelligence | Deep Learning | NLP | Computer Vision | Generative AI | LinkedIn https://www.linkedin.com/in/pasha-shaik