Vision Transformers

amirsina torfi
Machine Learning Mindset
Mar 18, 2023

The Vision Transformer (ViT) is a neural network architecture for image recognition tasks. It was proposed in the 2020 paper “An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale” by Dosovitskiy et al.

Ref: https://arxiv.org/pdf/2010.11929.pdf

The Motivation

Convolutional neural networks (CNNs) have traditionally been the go-to architecture for image recognition. ViT takes a different approach: it applies a transformer-based architecture, originally proposed for natural language processing, directly to images.

The key idea behind ViT is to treat an image as a sequence of patches and process that sequence with a transformer. Each patch is flattened into a vector and linearly embedded, and the resulting sequence of patch embeddings is fed to the transformer. This allows the model to capture long-range dependencies between patches anywhere in the image, potentially leading to better performance than traditional CNNs.
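
To make the patch-to-sequence idea concrete, here is a minimal, self-contained sketch (illustrative only, not the paper’s code) that splits a 224x224 RGB image into 16x16 patches and flattens each patch into a vector, yielding a sequence of 196 tokens; the image size and patch size are assumed values.

import torch

# illustrative: turn a 224x224 RGB image into a sequence of flattened 16x16 patches
image = torch.randn(1, 3, 224, 224)   # (batch, channels, height, width)
patch_size = 16

patches = image.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
# patches: (1, 3, 14, 14, 16, 16) -- one 16x16 crop per grid cell
patches = patches.permute(0, 2, 3, 1, 4, 5).flatten(3).flatten(1, 2)
# patches: (1, 196, 768) -- a sequence of 196 vectors of length 3*16*16
print(patches.shape)  # torch.Size([1, 196, 768])

In the full model, each of these flattened patches is linearly projected to the transformer’s embedding dimension and a positional embedding is added before the sequence enters the encoder.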

ViT has achieved state-of-the-art results on several image recognition benchmarks and has been shown to be effective even on tasks that require fine-grained recognition, such as recognizing individual bird species in images. However, ViT is also computationally expensive and requires a large amount of training data to achieve optimal performance.

Code Example

Here’s an example implementation of the Vision Transformer architecture using PyTorch:

import torch
import torch.nn as nn
import torch.nn.functional as F

class ViT(nn.Module):
    def __init__(self, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim):
        super().__init__()
        assert image_size % patch_size == 0, "image size must be divisible by patch size"
        num_patches = (image_size // patch_size) ** 2

        # create patch embeddings: a conv with stride = patch size splits the image
        # into non-overlapping patches and projects each one to `dim`
        self.patch_embeddings = nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size, bias=False)

        # learnable class token and positional embeddings (one per token, including the class token)
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.positional_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, dim))

        # create transformer encoder (batch_first so inputs are [batch, seq, dim])
        encoder_layer = nn.TransformerEncoderLayer(d_model=dim, nhead=heads,
                                                   dim_feedforward=mlp_dim, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=depth)

        # create output layer
        self.layer_norm = nn.LayerNorm(dim)
        self.fc = nn.Linear(dim, num_classes)

    def forward(self, x):
        x = self.patch_embeddings(x)                  # extract patches: [batch, dim, H/P, W/P]
        x = x.flatten(2).transpose(1, 2)              # flatten patches: [batch, num_patches, dim]
        cls_tokens = self.cls_token.expand(x.size(0), -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)         # prepend class token
        x = x + self.positional_embeddings            # add positional embeddings
        x = self.transformer_encoder(x)               # apply transformer encoder
        x = self.layer_norm(x[:, 0])                  # layer-normalize the class-token output
        x = self.fc(x)                                # project to class logits
        return x

This code defines a PyTorch module, ViT, that implements the Vision Transformer architecture. The constructor takes the image size, patch size, number of classes, embedding dimension, transformer depth (number of encoder layers), number of attention heads, and the hidden dimension of the feed-forward (MLP) blocks.

The forward method splits the input image into patch embeddings, prepends a learnable class token, adds positional embeddings, and passes the resulting sequence through the transformer encoder. The class-token output is layer-normalized and fed to the final linear layer, which produces the class logits returned by the module.
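
As a quick sanity check, the module above can be instantiated with illustrative hyperparameters (these particular values are assumptions, not from the paper) and run on a batch of random CIFAR-sized images to confirm the output shape:

# sanity check with assumed hyperparameters for 32x32 inputs and 4x4 patches
model = ViT(image_size=32, patch_size=4, num_classes=10,
            dim=512, depth=6, heads=8, mlp_dim=1024)
dummy = torch.randn(8, 3, 32, 32)   # a batch of 8 random images
logits = model(dummy)
print(logits.shape)                 # torch.Size([8, 10])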

And here’s an example of how to train the Vision Transformer model on the CIFAR-10 dataset using PyTorch.

import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.optim import Adam
import torch.nn.functional as F
from vit import ViT

# define hyperparameters
image_size = 32
patch_size = 4
num_classes = 10
dim = 512
depth = 6
heads = 8
mlp_dim = 1024
batch_size = 128
epochs = 10
lr = 1e-4

# load CIFAR-10 dataset
transform = transforms.Compose([
    transforms.RandomCrop(image_size, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# create model and optimizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ViT(image_size=image_size, patch_size=patch_size, num_classes=num_classes,
            dim=dim, depth=depth, heads=heads, mlp_dim=mlp_dim).to(device)
optimizer = Adam(model.parameters(), lr=lr)

# train model
for epoch in range(epochs):
    model.train()
    for i, (images, labels) in enumerate(train_loader):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = F.cross_entropy(outputs, labels)
        loss.backward()
        optimizer.step()

        if (i + 1) % 100 == 0:
            print(f"Epoch [{epoch+1}/{epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item():.4f}")

In this example, we use the torchvision.transforms module to apply data augmentation to the CIFAR-10 training set, such as random cropping and horizontal flipping, and then create a DataLoader to feed the dataset to the model in batches.

We create the ViT model, move it to the available device (GPU or CPU), and use the Adam optimizer. We then train for the specified number of epochs, iterating over batches of images and labels: we compute the cross-entropy loss, backpropagate it, and update the model’s weights. Every 100 steps we print the loss to monitor training progress.

Note: The vit module contains the implementation of the ViT class from the previous example. You can either define it in your code or import it from a separate module.
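
After training, a short evaluation pass over the CIFAR-10 test split gives a rough measure of generalization. The snippet below is a minimal sketch that reuses the model, device, batch_size, and normalization values from the training script above; the random augmentations are intentionally dropped at test time.

# evaluate top-1 accuracy on the CIFAR-10 test split (reuses objects from the training script)
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transform)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

model.eval()
correct = 0
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        correct += (model(images).argmax(dim=1) == labels).sum().item()
print(f"Test accuracy: {correct / len(test_dataset):.4f}")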

What do Vision Transformers promise?

The Vision Transformer (ViT) promises a competitive alternative to convolutional neural networks (CNNs) for image classification. CNNs have been the dominant approach for many years, but they rely on hand-designed architectural priors (local receptive fields and pooling hierarchies), and their learned features can be difficult to interpret. ViT offers some advantages over CNNs:

  1. Scalability: ViT scales well with model and dataset size, and because it operates on fixed-size patches processed by a transformer, the same architecture can be applied to different image resolutions (in practice, the positional embeddings are interpolated when the resolution changes).
  2. Adaptability: ViT can be adapted to computer vision tasks beyond image classification, such as object detection, segmentation, and image generation.
  3. Interpretability: because ViT uses self-attention to build global representations of the input image, its attention weights can be visualized to see which parts of the image the model attends to when making a prediction, as sketched after this list.
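
As a rough illustration of this, the sketch below reuses the ViT model and hyperparameters from the training example and probes the first encoder layer’s self-attention directly: it embeds one image the same way the forward pass does, asks the layer’s attention module for its head-averaged attention weights, and reshapes the class token’s attention over the patches into an 8x8 grid. This is a simplified probe of a single layer, not a full attention-rollout analysis, and the random input is just a stand-in for a real image.

# sketch: where does the class token attend in the first encoder layer?
model.eval()
with torch.no_grad():
    img = torch.randn(1, 3, image_size, image_size).to(device)   # stand-in for a real image
    tokens = model.patch_embeddings(img).flatten(2).transpose(1, 2)
    cls = model.cls_token.expand(1, -1, -1)
    tokens = torch.cat((cls, tokens), dim=1) + model.positional_embeddings
    # query the first layer's attention module directly (weights averaged over heads)
    _, attn = model.transformer_encoder.layers[0].self_attn(tokens, tokens, tokens, need_weights=True)
    cls_to_patches = attn[0, 0, 1:]            # class token -> each image patch
    grid = cls_to_patches.reshape(image_size // patch_size, image_size // patch_size)
    print(grid)                                # an 8x8 map of attention over image patches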

Overall, the Vision Transformer is a promising direction for computer vision research, and it has already shown state-of-the-art performance on several benchmarks.

Conclusion

In conclusion, the Vision Transformer (ViT) is a recent neural network architecture that uses self-attention to compute global feature representations of an input image for classification. It has shown competitive performance on several benchmarks and promises advantages such as scalability, adaptability, and interpretability over traditional convolutional neural networks. ViT is a promising direction for computer vision research, and its success has spurred interest in exploring the potential of transformer-based architectures for other computer vision tasks beyond image classification.
